transparent_edge.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. import itertools
  5. import socket
  6. import struct
  7. from dataclasses import dataclass, field
  8. from typing import Awaitable, Callable
  9. from .config import Config
  10. from .protocol import STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, encode_json
  11. from .relay_client import RelayConnection, RelayManager
# Socket-option numbers that this module passes to getsockopt()/setsockopt()
# as raw integers in order to recover the address a redirected connection or
# datagram was originally sent to.
# NOTE(review): the values look like the Linux kernel header constants of the
# same names — confirm before running on another platform.
SO_ORIGINAL_DST = 80  # getsockopt(SOL_IP, ...) on an accepted IPv4 TCP socket
IP6T_SO_ORIGINAL_DST = 80  # getsockopt(IPPROTO_IPV6, ...) on an accepted IPv6 TCP socket
IP_RECVORIGDSTADDR = 20  # enabled on the IPv4 UDP listener; also the cmsg type matched after recvmsg()
IPV6_RECVORIGDSTADDR = 74  # IPv6 counterpart of IP_RECVORIGDSTADDR
  16. @dataclass(frozen=True)
  17. class TargetAddress:
  18. host: str
  19. port: int
  20. family: int
  21. @dataclass(frozen=True)
  22. class PeerAddress:
  23. host: str
  24. port: int
  25. family: int
  26. def parse_sockaddr(raw: bytes) -> TargetAddress:
  27. if len(raw) < 8:
  28. raise ValueError("invalid transparent destination payload")
  29. family = struct.unpack_from("=H", raw, 0)[0]
  30. port = struct.unpack_from("!H", raw, 2)[0]
  31. if family == socket.AF_INET:
  32. host = socket.inet_ntoa(raw[4:8])
  33. return TargetAddress(host=host, port=port, family=family)
  34. if family == socket.AF_INET6:
  35. if len(raw) < 28:
  36. raise ValueError("invalid IPv6 transparent destination payload")
  37. host = socket.inet_ntop(socket.AF_INET6, raw[8:24])
  38. return TargetAddress(host=host, port=port, family=family)
  39. raise ValueError(f"unsupported family={family}")
  40. class BasePath:
  41. def __init__(self, name: str, on_frame: Callable[["BasePath", str, bytes | None], Awaitable[None]]) -> None:
  42. self.name = name
  43. self.on_frame = on_frame
  44. self.opened = False
  45. self.closed = False
  46. async def open(self, target: TargetAddress) -> None:
  47. raise NotImplementedError
  48. async def send(self, data: bytes) -> None:
  49. raise NotImplementedError
  50. async def close(self) -> None:
  51. raise NotImplementedError
  52. class DirectTcpPath(BasePath):
  53. def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]]) -> None:
  54. super().__init__(name, on_frame)
  55. self.reader: asyncio.StreamReader | None = None
  56. self.writer: asyncio.StreamWriter | None = None
  57. self.pump_task: asyncio.Task | None = None
  58. async def open(self, target: TargetAddress) -> None:
  59. try:
  60. family = socket.AF_INET6 if target.family == socket.AF_INET6 else socket.AF_INET
  61. self.reader, self.writer = await asyncio.open_connection(host=target.host, port=target.port, family=family)
  62. self.opened = True
  63. self.pump_task = asyncio.create_task(self._pump())
  64. await self.on_frame(self, "status", b"ok")
  65. except Exception as exc:
  66. await self.on_frame(self, "status", str(exc).encode())
  67. async def _pump(self) -> None:
  68. assert self.reader is not None
  69. try:
  70. while True:
  71. chunk = await self.reader.read(65536)
  72. if not chunk:
  73. break
  74. await self.on_frame(self, "data", chunk)
  75. except Exception:
  76. pass
  77. finally:
  78. await self.on_frame(self, "close", None)
  79. async def send(self, data: bytes) -> None:
  80. if self.closed or self.writer is None:
  81. return
  82. self.writer.write(data)
  83. await self.writer.drain()
  84. async def close(self) -> None:
  85. if self.closed:
  86. return
  87. self.closed = True
  88. if self.pump_task and self.pump_task is not asyncio.current_task():
  89. self.pump_task.cancel()
  90. with contextlib.suppress(Exception):
  91. await self.pump_task
  92. if self.writer:
  93. self.writer.close()
  94. with contextlib.suppress(Exception):
  95. await self.writer.wait_closed()
  96. class RelayTcpPath(BasePath):
  97. def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], connection: RelayConnection, session_id: int, stream_id: int) -> None:
  98. super().__init__(name, on_frame)
  99. self.connection = connection
  100. self.session_id = session_id
  101. self.stream_id = stream_id
  102. async def open(self, target: TargetAddress) -> None:
  103. if self.connection.closed:
  104. await self.on_frame(self, "status", b"relay unavailable")
  105. return
  106. self.connection.bind(self.session_id, self.stream_id, self._handle_frame)
  107. try:
  108. await self.connection.send(Frame(TCP_OPEN, self.session_id, self.stream_id, 0, 0, encode_json({"host": target.host, "port": target.port, "family": target.family})))
  109. except Exception as exc:
  110. await self.on_frame(self, "status", str(exc).encode())
  111. async def _handle_frame(self, _conn: RelayConnection, frame: Frame) -> None:
  112. if frame.kind == TCP_STATUS:
  113. if frame.packet_id == STATUS_OK:
  114. self.opened = True
  115. await self.on_frame(self, "status", b"ok")
  116. else:
  117. await self.on_frame(self, "status", frame.payload)
  118. return
  119. if frame.kind == TCP_DATA:
  120. await self.on_frame(self, "data", frame.payload)
  121. return
  122. if frame.kind == TCP_CLOSE:
  123. await self.on_frame(self, "close", None)
  124. async def send(self, data: bytes) -> None:
  125. if self.closed or self.connection.closed:
  126. return
  127. await self.connection.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, data))
  128. async def close(self) -> None:
  129. if self.closed:
  130. return
  131. self.closed = True
  132. self.connection.unbind(self.session_id, self.stream_id)
  133. if not self.connection.closed:
  134. with contextlib.suppress(Exception):
  135. await self.connection.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b""))
  136. @dataclass
  137. class TransparentSession:
  138. session_id: int
  139. target: TargetAddress
  140. reader: asyncio.StreamReader
  141. writer: asyncio.StreamWriter
  142. paths: list[BasePath]
  143. warmup_bytes: int
  144. loser_grace_ms: int
  145. stats: dict[str, int]
  146. target_stats: dict[tuple[str, int], dict[str, int]]
  147. opened_count: int = 0
  148. status_count: int = 0
  149. errors: list[str] = field(default_factory=list)
  150. winner: BasePath | None = None
  151. uplink_bytes: int = 0
  152. open_event: asyncio.Event = field(default_factory=asyncio.Event)
  153. winner_event: asyncio.Event = field(default_factory=asyncio.Event)
  154. closed: bool = False
  155. pump_task: asyncio.Task | None = None
  156. loser_close_task: asyncio.Task | None = None
  157. def _record_win(self, winner: BasePath) -> None:
  158. self.stats[winner.name] = self.stats.get(winner.name, 0) + 1
  159. key = (self.target.host, self.target.port)
  160. target_stats = self.target_stats.setdefault(key, {})
  161. target_stats[winner.name] = target_stats.get(winner.name, 0) + 1
  162. direct_wins = self.stats.get("direct", 0)
  163. relay_wins = sum(count for name, count in self.stats.items() if name != "direct")
  164. target_direct = target_stats.get("direct", 0)
  165. target_relay = sum(count for name, count in target_stats.items() if name != "direct")
  166. relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.stats.items()) if name != "direct") or "none"
  167. target_detail = ", ".join(f"{name}={count}" for name, count in sorted(target_stats.items()) if name != "direct") or "none"
  168. target_pref = "relay" if target_relay > target_direct else "direct"
  169. print(f"[edge] tcp win session={self.session_id} target={self.target.host}:{self.target.port} winner={winner.name} direct={direct_wins} relay={relay_wins} relay_breakdown={relay_detail} target_pref={target_pref} target_direct={target_direct} target_relay={target_relay} target_breakdown={target_detail}")
  170. async def start(self) -> None:
  171. await asyncio.gather(*(path.open(self.target) for path in self.paths), return_exceptions=True)
  172. await asyncio.wait_for(self.open_event.wait(), timeout=10)
  173. if self.opened_count == 0:
  174. raise ConnectionError(self.errors[0] if self.errors else "all paths failed")
  175. self.pump_task = asyncio.create_task(self._pump_local())
  176. async def _pump_local(self) -> None:
  177. try:
  178. while True:
  179. chunk = await self.reader.read(65536)
  180. if not chunk:
  181. break
  182. self.uplink_bytes += len(chunk)
  183. active = [path for path in self.paths if path.opened and not path.closed]
  184. if not active:
  185. break
  186. if self.uplink_bytes <= self.warmup_bytes:
  187. await asyncio.gather(*(path.send(chunk) for path in active), return_exceptions=True)
  188. else:
  189. if self.winner is None:
  190. await self.winner_event.wait()
  191. if self.winner:
  192. await self.winner.send(chunk)
  193. except Exception:
  194. pass
  195. finally:
  196. await self.close()
  197. async def handle_path(self, path: BasePath, event: str, payload: bytes | None) -> None:
  198. if self.closed:
  199. return
  200. if event == "status":
  201. self.status_count += 1
  202. if payload == b"ok":
  203. self.opened_count += 1
  204. elif payload is not None:
  205. self.errors.append(payload.decode("utf-8", errors="replace"))
  206. if self.opened_count > 0 or self.status_count == len(self.paths):
  207. self.open_event.set()
  208. return
  209. if event == "data":
  210. if self.winner is None:
  211. self.winner = path
  212. self._record_win(path)
  213. self.winner_event.set()
  214. if self.loser_grace_ms > 0:
  215. self.loser_close_task = asyncio.create_task(self._close_losers_after_grace(path))
  216. else:
  217. await self._close_losers(path)
  218. if path is self.winner and payload is not None:
  219. self.writer.write(payload)
  220. await self.writer.drain()
  221. return
  222. if event == "close":
  223. path.closed = True
  224. if self.winner is None:
  225. remaining = [candidate for candidate in self.paths if candidate.opened and not candidate.closed]
  226. if not remaining:
  227. await self.close()
  228. elif path is self.winner:
  229. await self.close()
  230. async def _close_losers(self, winner: BasePath) -> None:
  231. await asyncio.gather(*(path.close() for path in self.paths if path is not winner), return_exceptions=True)
  232. async def _close_losers_after_grace(self, winner: BasePath) -> None:
  233. await asyncio.sleep(self.loser_grace_ms / 1000)
  234. if not self.closed:
  235. await self._close_losers(winner)
  236. async def close(self) -> None:
  237. if self.closed:
  238. return
  239. self.closed = True
  240. print(f"[edge] session={self.session_id} closed target={self.target.host}:{self.target.port}")
  241. if self.pump_task and self.pump_task is not asyncio.current_task():
  242. self.pump_task.cancel()
  243. with contextlib.suppress(Exception):
  244. await self.pump_task
  245. if self.loser_close_task and self.loser_close_task is not asyncio.current_task():
  246. self.loser_close_task.cancel()
  247. with contextlib.suppress(Exception):
  248. await self.loser_close_task
  249. await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True)
  250. self.writer.close()
  251. with contextlib.suppress(Exception):
  252. await self.writer.wait_closed()
  253. class DirectUdpPath(BasePath):
  254. def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], target: TargetAddress) -> None:
  255. super().__init__(name, on_frame)
  256. self.target = target
  257. self.socket: socket.socket | None = None
  258. self.read_task: asyncio.Task | None = None
  259. async def open(self, _target: TargetAddress) -> None:
  260. try:
  261. family = socket.AF_INET6 if self.target.family == socket.AF_INET6 else socket.AF_INET
  262. self.socket = socket.socket(family, socket.SOCK_DGRAM)
  263. self.socket.setblocking(False)
  264. await asyncio.get_running_loop().sock_connect(self.socket, (self.target.host, self.target.port))
  265. self.opened = True
  266. self.read_task = asyncio.create_task(self._pump())
  267. await self.on_frame(self, "status", b"ok")
  268. except Exception as exc:
  269. await self.on_frame(self, "status", str(exc).encode())
  270. async def _pump(self) -> None:
  271. assert self.socket is not None
  272. loop = asyncio.get_running_loop()
  273. try:
  274. while True:
  275. data = await loop.sock_recv(self.socket, 65535)
  276. if not data:
  277. break
  278. await self.on_frame(self, "data", data)
  279. except Exception:
  280. pass
  281. finally:
  282. await self.on_frame(self, "close", None)
  283. async def send(self, data: bytes) -> None:
  284. if self.closed or self.socket is None:
  285. return
  286. await asyncio.get_running_loop().sock_sendall(self.socket, data)
  287. async def close(self) -> None:
  288. if self.closed:
  289. return
  290. self.closed = True
  291. if self.read_task and self.read_task is not asyncio.current_task():
  292. self.read_task.cancel()
  293. with contextlib.suppress(Exception):
  294. await self.read_task
  295. if self.socket:
  296. self.socket.close()
  297. class RelayUdpPath(BasePath):
  298. def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], connection: RelayConnection, session_id: int, stream_id: int, target: TargetAddress) -> None:
  299. super().__init__(name, on_frame)
  300. self.connection = connection
  301. self.session_id = session_id
  302. self.stream_id = stream_id
  303. self.target = target
  304. async def open(self, _target: TargetAddress) -> None:
  305. if self.connection.closed:
  306. await self.on_frame(self, "status", b"relay unavailable")
  307. return
  308. self.connection.bind(self.session_id, self.stream_id, self._handle_frame)
  309. self.opened = True
  310. await self.on_frame(self, "status", b"ok")
  311. async def _handle_frame(self, _conn: RelayConnection, frame: Frame) -> None:
  312. if frame.kind == UDP_RECV:
  313. await self.on_frame(self, "data", frame.payload)
  314. async def send(self, data: bytes) -> None:
  315. if self.closed or self.connection.closed:
  316. return
  317. meta = encode_json({"host": self.target.host, "port": self.target.port, "family": self.target.family})
  318. payload = meta + data
  319. await self.connection.send(Frame(UDP_SEND, self.session_id, self.stream_id, 0, len(meta), payload))
  320. async def close(self) -> None:
  321. if self.closed:
  322. return
  323. self.closed = True
  324. self.connection.unbind(self.session_id, self.stream_id)
  325. @dataclass
  326. class UdpFlow:
  327. flow_id: int
  328. source: PeerAddress
  329. target: TargetAddress
  330. send_response: Callable[[PeerAddress, bytes], Awaitable[None]]
  331. paths: list[BasePath]
  332. winner: BasePath | None = None
  333. closed: bool = False
  334. last_activity: float = 0.0
  335. async def start(self) -> None:
  336. await asyncio.gather(*(path.open(self.target) for path in self.paths), return_exceptions=True)
  337. async def send(self, payload: bytes) -> None:
  338. self.last_activity = asyncio.get_running_loop().time()
  339. active = [path for path in self.paths if path.opened and not path.closed]
  340. if self.winner is None:
  341. await asyncio.gather(*(path.send(payload) for path in active), return_exceptions=True)
  342. elif not self.winner.closed:
  343. await self.winner.send(payload)
  344. async def handle_path(self, path: BasePath, event: str, payload: bytes | None) -> None:
  345. self.last_activity = asyncio.get_running_loop().time()
  346. if event == "data" and payload is not None:
  347. if self.winner is None:
  348. self.winner = path
  349. print(f"[edge] udp flow={self.flow_id} winner={path.name} target={self.target.host}:{self.target.port}")
  350. if path is self.winner:
  351. await self.send_response(self.source, payload)
  352. if event == "close":
  353. path.closed = True
  354. async def close(self) -> None:
  355. if self.closed:
  356. return
  357. self.closed = True
  358. await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True)
  359. class TransparentUdpListener:
  360. def __init__(self, edge: "TransparentEdge", family: int, bind_host: str, port: int) -> None:
  361. self.edge = edge
  362. self.family = family
  363. self.bind_host = bind_host
  364. self.port = port
  365. self.socket: socket.socket | None = None
  366. def start(self) -> None:
  367. sock = socket.socket(self.family, socket.SOCK_DGRAM)
  368. sock.setblocking(False)
  369. if self.family == socket.AF_INET:
  370. sock.setsockopt(socket.SOL_IP, IP_RECVORIGDSTADDR, 1)
  371. sock.bind((self.bind_host, self.port))
  372. else:
  373. sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
  374. sock.setsockopt(socket.IPPROTO_IPV6, IPV6_RECVORIGDSTADDR, 1)
  375. sock.bind((self.bind_host, self.port, 0, 0))
  376. self.socket = sock
  377. asyncio.get_running_loop().add_reader(sock.fileno(), self._on_readable)
  378. print(f"[edge] transparent udp listening on {sock.getsockname()}")
  379. def _on_readable(self) -> None:
  380. assert self.socket is not None
  381. try:
  382. data, ancdata, _flags, src = self.socket.recvmsg(65535, 512)
  383. except BlockingIOError:
  384. return
  385. except Exception:
  386. return
  387. original = None
  388. for level, ctype, cdata in ancdata:
  389. if self.family == socket.AF_INET and level == socket.SOL_IP and ctype == IP_RECVORIGDSTADDR:
  390. original = parse_sockaddr(cdata)
  391. break
  392. if self.family == socket.AF_INET6 and level == socket.IPPROTO_IPV6 and ctype == IPV6_RECVORIGDSTADDR:
  393. original = parse_sockaddr(cdata)
  394. break
  395. if original is None:
  396. return
  397. if self.family == socket.AF_INET:
  398. source = PeerAddress(host=src[0], port=src[1], family=socket.AF_INET)
  399. else:
  400. source = PeerAddress(host=src[0], port=src[1], family=socket.AF_INET6)
  401. if original.port == self.port and (original.host in ("127.0.0.1", "::1") or original.host == self.bind_host):
  402. return
  403. asyncio.create_task(self.edge.handle_udp_datagram(source, original, data, self))
  404. async def send_response(self, source: PeerAddress, payload: bytes) -> None:
  405. assert self.socket is not None
  406. if source.family == socket.AF_INET:
  407. self.socket.sendto(payload, (source.host, source.port))
  408. else:
  409. self.socket.sendto(payload, (source.host, source.port, 0, 0))
  410. async def close(self) -> None:
  411. if self.socket is None:
  412. return
  413. asyncio.get_running_loop().remove_reader(self.socket.fileno())
  414. self.socket.close()
  415. self.socket = None
  416. class TransparentEdge:
  417. def __init__(self, listen_host: str, listen_port: int, config: Config, enable_udp: bool = False) -> None:
  418. self.listen_host = listen_host
  419. self.listen_port = listen_port
  420. self.config = config
  421. self.enable_udp = enable_udp
  422. self.manager = RelayManager(config)
  423. self.session_ids = itertools.count(1)
  424. self.stream_ids = itertools.count(1)
  425. self.udp_listeners: list[TransparentUdpListener] = []
  426. self.udp_flows: dict[tuple[PeerAddress, TargetAddress], UdpFlow] = {}
  427. self.udp_flow_ids = itertools.count(1)
  428. self.udp_gc_task: asyncio.Task | None = None
  429. self.tcp_win_counts: dict[str, int] = {}
  430. self.tcp_target_wins: dict[tuple[str, int], dict[str, int]] = {}
  431. async def start(self) -> None:
  432. await self.manager.start()
  433. print(f"[edge] relay snapshot: {self.manager.snapshot()}")
  434. server4 = await asyncio.start_server(self._accept, self.listen_host, self.listen_port, family=socket.AF_INET)
  435. sockets = [str(sock.getsockname()) for sock in server4.sockets or []]
  436. server6 = None
  437. if self.listen_host in ("::", "::1", "0.0.0.0", "127.0.0.1"):
  438. host6 = "::1" if self.listen_host == "127.0.0.1" else "::"
  439. try:
  440. server6 = await asyncio.start_server(self._accept, host6, self.listen_port, family=socket.AF_INET6)
  441. sockets.extend(str(sock.getsockname()) for sock in server6.sockets or [])
  442. except Exception as exc:
  443. print(f"[edge] ipv6 tcp listener skipped: {exc!r}")
  444. if self.enable_udp:
  445. self._start_udp_listeners()
  446. self.udp_gc_task = asyncio.create_task(self._gc_udp_flows())
  447. print(f"[edge] transparent tcp listening on {', '.join(sockets)}")
  448. if server6 is None:
  449. async with server4:
  450. await server4.serve_forever()
  451. else:
  452. async with server4, server6:
  453. await asyncio.gather(server4.serve_forever(), server6.serve_forever())
  454. def _start_udp_listeners(self) -> None:
  455. binds = []
  456. if self.listen_host == "127.0.0.1":
  457. binds = [(socket.AF_INET, "127.0.0.1"), (socket.AF_INET6, "::1")]
  458. elif self.listen_host == "0.0.0.0":
  459. binds = [(socket.AF_INET, "0.0.0.0"), (socket.AF_INET6, "::")]
  460. else:
  461. family = socket.AF_INET6 if ":" in self.listen_host else socket.AF_INET
  462. binds = [(family, self.listen_host)]
  463. for family, host in binds:
  464. try:
  465. listener = TransparentUdpListener(self, family, host, self.listen_port)
  466. listener.start()
  467. self.udp_listeners.append(listener)
  468. except Exception as exc:
  469. print(f"[edge] udp listener skipped family={family} host={host} error={exc!r}")
  470. async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  471. peer = writer.get_extra_info("peername")
  472. try:
  473. target = self._get_original_dst(writer)
  474. session_id = next(self.session_ids)
  475. session = TransparentSession(session_id=session_id, target=target, reader=reader, writer=writer, paths=[], warmup_bytes=self.config.tcp_warmup_bytes, loser_grace_ms=self.config.tcp_loser_grace_ms, stats=self.tcp_win_counts, target_stats=self.tcp_target_wins)
  476. paths: list[BasePath] = [DirectTcpPath(name="direct", on_frame=lambda path, event, payload, s=session: self._handle_tcp_session(s, path, event, payload))]
  477. for connection in self.manager.available():
  478. stream_id = next(self.stream_ids)
  479. paths.append(RelayTcpPath(name=connection.node.name, on_frame=lambda path, event, payload, s=session: self._handle_tcp_session(s, path, event, payload), connection=connection, session_id=session_id, stream_id=stream_id))
  480. session.paths = paths
  481. print(f"[edge] accept peer={peer} session={session_id} target={target.host}:{target.port} candidates={[path.name for path in paths]}")
  482. await session.start()
  483. except Exception as exc:
  484. print(f"[edge] accept failed peer={peer} error={exc!r}")
  485. writer.close()
  486. with contextlib.suppress(Exception):
  487. await writer.wait_closed()
  488. async def _handle_tcp_session(self, session: TransparentSession, path: BasePath, event: str, payload: bytes | None) -> None:
  489. await session.handle_path(path, event, payload)
  490. def _get_original_dst(self, writer: asyncio.StreamWriter) -> TargetAddress:
  491. sock = writer.get_extra_info("socket")
  492. if sock is None:
  493. raise RuntimeError("socket unavailable")
  494. family = sock.family
  495. if family == socket.AF_INET:
  496. raw = sock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, 16)
  497. return parse_sockaddr(raw)
  498. if family == socket.AF_INET6:
  499. raw = sock.getsockopt(socket.IPPROTO_IPV6, IP6T_SO_ORIGINAL_DST, 128)
  500. return parse_sockaddr(raw)
  501. raise RuntimeError(f"unsupported socket family={family}")
  502. async def handle_udp_datagram(self, source: PeerAddress, target: TargetAddress, payload: bytes, listener: TransparentUdpListener) -> None:
  503. if not self.enable_udp:
  504. return
  505. if target.port == self.listen_port and target.host in ("127.0.0.1", "::1", self.listen_host):
  506. return
  507. key = (source, target)
  508. flow = self.udp_flows.get(key)
  509. if flow is None:
  510. flow_id = next(self.udp_flow_ids)
  511. paths: list[BasePath] = [DirectUdpPath(name="direct", on_frame=lambda path, event, data, fid=flow_id: self._handle_udp_path(fid, path, event, data), target=target)]
  512. for connection in self.manager.available():
  513. stream_id = next(self.stream_ids)
  514. paths.append(RelayUdpPath(name=connection.node.name, on_frame=lambda path, event, data, fid=flow_id: self._handle_udp_path(fid, path, event, data), connection=connection, session_id=flow_id, stream_id=stream_id, target=target))
  515. flow = UdpFlow(flow_id=flow_id, source=source, target=target, send_response=listener.send_response, paths=paths)
  516. self.udp_flows[key] = flow
  517. print(f"[edge] udp flow={flow_id} target={target.host}:{target.port} candidates={[path.name for path in paths]}")
  518. await flow.start()
  519. await flow.send(payload)
  520. async def _handle_udp_path(self, flow_id: int, path: BasePath, event: str, payload: bytes | None) -> None:
  521. for flow in list(self.udp_flows.values()):
  522. if flow.flow_id == flow_id:
  523. await flow.handle_path(path, event, payload)
  524. break
  525. async def _gc_udp_flows(self) -> None:
  526. loop = asyncio.get_running_loop()
  527. while True:
  528. await asyncio.sleep(30)
  529. now = loop.time()
  530. stale = [key for key, flow in self.udp_flows.items() if flow.last_activity and now - flow.last_activity > 120]
  531. for key in stale:
  532. flow = self.udp_flows.pop(key, None)
  533. if flow:
  534. await flow.close()