# transparent_edge.py
from __future__ import annotations

import asyncio
import contextlib
import itertools
import os
import socket
import struct
from dataclasses import dataclass, field
from pathlib import Path
from typing import Awaitable, Callable

from .config import Config
from .protocol import STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, encode_json
from .relay_client import RelayConnection, RelayManager
  14. SO_ORIGINAL_DST = 80
  15. IP6T_SO_ORIGINAL_DST = 80
  16. IP_RECVORIGDSTADDR = 20
  17. IPV6_RECVORIGDSTADDR = 74
  18. @dataclass(frozen=True)
  19. class TargetAddress:
  20. host: str
  21. port: int
  22. family: int
  23. @dataclass(frozen=True)
  24. class PeerAddress:
  25. host: str
  26. port: int
  27. family: int
  28. def parse_sockaddr(raw: bytes) -> TargetAddress:
  29. if len(raw) < 8:
  30. raise ValueError("invalid transparent destination payload")
  31. family = struct.unpack_from("=H", raw, 0)[0]
  32. port = struct.unpack_from("!H", raw, 2)[0]
  33. if family == socket.AF_INET:
  34. host = socket.inet_ntoa(raw[4:8])
  35. return TargetAddress(host=host, port=port, family=family)
  36. if family == socket.AF_INET6:
  37. if len(raw) < 28:
  38. raise ValueError("invalid IPv6 transparent destination payload")
  39. host = socket.inet_ntop(socket.AF_INET6, raw[8:24])
  40. return TargetAddress(host=host, port=port, family=family)
  41. raise ValueError(f"unsupported family={family}")
  42. class BasePath:
  43. def __init__(self, name: str, on_frame: Callable[["BasePath", str, bytes | None], Awaitable[None]]) -> None:
  44. self.name = name
  45. self.on_frame = on_frame
  46. self.opened = False
  47. self.closed = False
  48. async def open(self, target: TargetAddress) -> None:
  49. raise NotImplementedError
  50. async def send(self, data: bytes) -> None:
  51. raise NotImplementedError
  52. async def close(self) -> None:
  53. raise NotImplementedError
  54. class DirectTcpPath(BasePath):
  55. def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], open_timeout: float, happy_eyeballs_delay: float | None, tcp_nodelay: bool = True) -> None:
  56. super().__init__(name, on_frame)
  57. self.reader: asyncio.StreamReader | None = None
  58. self.writer: asyncio.StreamWriter | None = None
  59. self.pump_task: asyncio.Task | None = None
  60. self.open_timeout = open_timeout
  61. self.happy_eyeballs_delay = happy_eyeballs_delay
  62. self.tcp_nodelay = tcp_nodelay
  63. async def open(self, target: TargetAddress) -> None:
  64. try:
  65. family = socket.AF_INET6 if target.family == socket.AF_INET6 else socket.AF_INET
  66. kwargs = {"host": target.host, "port": target.port, "family": family}
  67. if self.happy_eyeballs_delay is not None:
  68. kwargs["happy_eyeballs_delay"] = self.happy_eyeballs_delay
  69. self.reader, self.writer = await asyncio.wait_for(asyncio.open_connection(**kwargs), timeout=self.open_timeout)
  70. sock = self.writer.get_extra_info("socket")
  71. if sock is not None and self.tcp_nodelay:
  72. with contextlib.suppress(OSError):
  73. sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  74. self.opened = True
  75. self.pump_task = asyncio.create_task(self._pump())
  76. await self.on_frame(self, "status", b"ok")
  77. except Exception as exc:
  78. await self.on_frame(self, "status", str(exc).encode())
  79. async def _pump(self) -> None:
  80. assert self.reader is not None
  81. try:
  82. while True:
  83. chunk = await self.reader.read(65536)
  84. if not chunk:
  85. break
  86. await self.on_frame(self, "data", chunk)
  87. except Exception:
  88. pass
  89. finally:
  90. await self.on_frame(self, "close", None)
  91. async def send(self, data: bytes) -> None:
  92. if self.closed or self.writer is None:
  93. return
  94. self.writer.write(data)
  95. await self.writer.drain()
  96. async def close(self) -> None:
  97. if self.closed:
  98. return
  99. self.closed = True
  100. if self.pump_task and self.pump_task is not asyncio.current_task():
  101. self.pump_task.cancel()
  102. with contextlib.suppress(Exception):
  103. await self.pump_task
  104. if self.writer:
  105. self.writer.close()
  106. with contextlib.suppress(Exception):
  107. await self.writer.wait_closed()
  108. class RelayTcpPath(BasePath):
  109. def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], connection: RelayConnection, session_id: int, stream_id: int) -> None:
  110. super().__init__(name, on_frame)
  111. self.connection = connection
  112. self.session_id = session_id
  113. self.stream_id = stream_id
  114. async def open(self, target: TargetAddress) -> None:
  115. if self.connection.closed:
  116. await self.on_frame(self, "status", b"relay unavailable")
  117. return
  118. self.connection.bind(self.session_id, self.stream_id, self._handle_frame)
  119. try:
  120. await self.connection.send(Frame(TCP_OPEN, self.session_id, self.stream_id, 0, 0, encode_json({"host": target.host, "port": target.port, "family": target.family})))
  121. except Exception as exc:
  122. await self.on_frame(self, "status", str(exc).encode())
  123. async def _handle_frame(self, _conn: RelayConnection, frame: Frame) -> None:
  124. if frame.kind == TCP_STATUS:
  125. if frame.packet_id == STATUS_OK:
  126. self.opened = True
  127. await self.on_frame(self, "status", b"ok")
  128. else:
  129. await self.on_frame(self, "status", frame.payload)
  130. return
  131. if frame.kind == TCP_DATA:
  132. await self.on_frame(self, "data", frame.payload)
  133. return
  134. if frame.kind == TCP_CLOSE:
  135. await self.on_frame(self, "close", None)
  136. async def send(self, data: bytes) -> None:
  137. if self.closed or self.connection.closed:
  138. return
  139. await self.connection.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, data))
  140. async def close(self) -> None:
  141. if self.closed:
  142. return
  143. self.closed = True
  144. self.connection.unbind(self.session_id, self.stream_id)
  145. if not self.connection.closed:
  146. with contextlib.suppress(Exception):
  147. await self.connection.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b""))
  148. @dataclass
  149. class TransparentSession:
  150. session_id: int
  151. target: TargetAddress
  152. reader: asyncio.StreamReader
  153. writer: asyncio.StreamWriter
  154. paths: list[BasePath]
  155. warmup_bytes: int
  156. loser_grace_ms: int
  157. stats: dict[str, int]
  158. target_stats: dict[tuple[str, int], dict[str, int]]
  159. opened_count: int = 0
  160. status_count: int = 0
  161. errors: list[str] = field(default_factory=list)
  162. winner: BasePath | None = None
  163. uplink_bytes: int = 0
  164. open_event: asyncio.Event = field(default_factory=asyncio.Event)
  165. winner_event: asyncio.Event = field(default_factory=asyncio.Event)
  166. closed: bool = False
  167. pump_task: asyncio.Task | None = None
  168. loser_close_task: asyncio.Task | None = None
  169. def _record_win(self, winner: BasePath) -> None:
  170. self.stats[winner.name] = self.stats.get(winner.name, 0) + 1
  171. key = (self.target.host, self.target.port)
  172. target_stats = self.target_stats.setdefault(key, {})
  173. target_stats[winner.name] = target_stats.get(winner.name, 0) + 1
  174. direct_wins = self.stats.get("direct", 0)
  175. relay_wins = sum(count for name, count in self.stats.items() if name != "direct")
  176. target_direct = target_stats.get("direct", 0)
  177. target_relay = sum(count for name, count in target_stats.items() if name != "direct")
  178. relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.stats.items()) if name != "direct") or "none"
  179. target_detail = ", ".join(f"{name}={count}" for name, count in sorted(target_stats.items()) if name != "direct") or "none"
  180. target_pref = "relay" if target_relay > target_direct else "direct"
  181. print(f"[edge] tcp win session={self.session_id} target={self.target.host}:{self.target.port} winner={winner.name} direct={direct_wins} relay={relay_wins} relay_breakdown={relay_detail} target_pref={target_pref} target_direct={target_direct} target_relay={target_relay} target_breakdown={target_detail}")
  182. async def start(self) -> None:
  183. await asyncio.gather(*(path.open(self.target) for path in self.paths), return_exceptions=True)
  184. await asyncio.wait_for(self.open_event.wait(), timeout=15)
  185. if self.opened_count == 0:
  186. raise ConnectionError(self.errors[0] if self.errors else "all paths failed")
  187. self.pump_task = asyncio.create_task(self._pump_local())
  188. async def _pump_local(self) -> None:
  189. try:
  190. while True:
  191. chunk = await self.reader.read(65536)
  192. if not chunk:
  193. break
  194. self.uplink_bytes += len(chunk)
  195. active = [path for path in self.paths if path.opened and not path.closed]
  196. if not active:
  197. break
  198. if self.uplink_bytes <= self.warmup_bytes:
  199. await asyncio.gather(*(path.send(chunk) for path in active), return_exceptions=True)
  200. else:
  201. if self.winner is None:
  202. await self.winner_event.wait()
  203. if self.winner:
  204. await self.winner.send(chunk)
  205. except Exception:
  206. pass
  207. finally:
  208. await self.close()
  209. async def handle_path(self, path: BasePath, event: str, payload: bytes | None) -> None:
  210. if self.closed:
  211. return
  212. if event == "status":
  213. self.status_count += 1
  214. if payload == b"ok":
  215. self.opened_count += 1
  216. elif payload is not None:
  217. self.errors.append(payload.decode("utf-8", errors="replace"))
  218. if self.opened_count > 0 or self.status_count == len(self.paths):
  219. self.open_event.set()
  220. return
  221. if event == "data":
  222. if self.winner is None:
  223. self.winner = path
  224. self._record_win(path)
  225. self.winner_event.set()
  226. if self.loser_grace_ms > 0:
  227. self.loser_close_task = asyncio.create_task(self._close_losers_after_grace(path))
  228. else:
  229. await self._close_losers(path)
  230. if path is self.winner and payload is not None:
  231. self.writer.write(payload)
  232. await self.writer.drain()
  233. return
  234. if event == "close":
  235. path.closed = True
  236. if self.winner is None:
  237. remaining = [candidate for candidate in self.paths if candidate.opened and not candidate.closed]
  238. if not remaining:
  239. await self.close()
  240. elif path is self.winner:
  241. await self.close()
  242. async def _close_losers(self, winner: BasePath) -> None:
  243. await asyncio.gather(*(path.close() for path in self.paths if path is not winner), return_exceptions=True)
  244. async def _close_losers_after_grace(self, winner: BasePath) -> None:
  245. await asyncio.sleep(self.loser_grace_ms / 1000)
  246. if not self.closed:
  247. await self._close_losers(winner)
  248. async def close(self) -> None:
  249. if self.closed:
  250. return
  251. self.closed = True
  252. print(f"[edge] session={self.session_id} closed target={self.target.host}:{self.target.port}")
  253. if self.pump_task and self.pump_task is not asyncio.current_task():
  254. self.pump_task.cancel()
  255. with contextlib.suppress(Exception):
  256. await self.pump_task
  257. if self.loser_close_task and self.loser_close_task is not asyncio.current_task():
  258. self.loser_close_task.cancel()
  259. with contextlib.suppress(Exception):
  260. await self.loser_close_task
  261. await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True)
  262. self.writer.close()
  263. with contextlib.suppress(Exception):
  264. await self.writer.wait_closed()
  265. class DirectUdpPath(BasePath):
  266. def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], target: TargetAddress) -> None:
  267. super().__init__(name, on_frame)
  268. self.target = target
  269. self.socket: socket.socket | None = None
  270. self.read_task: asyncio.Task | None = None
  271. async def open(self, _target: TargetAddress) -> None:
  272. try:
  273. family = socket.AF_INET6 if self.target.family == socket.AF_INET6 else socket.AF_INET
  274. self.socket = socket.socket(family, socket.SOCK_DGRAM)
  275. self.socket.setblocking(False)
  276. await asyncio.get_running_loop().sock_connect(self.socket, (self.target.host, self.target.port))
  277. self.opened = True
  278. self.read_task = asyncio.create_task(self._pump())
  279. await self.on_frame(self, "status", b"ok")
  280. except Exception as exc:
  281. await self.on_frame(self, "status", str(exc).encode())
  282. async def _pump(self) -> None:
  283. assert self.socket is not None
  284. loop = asyncio.get_running_loop()
  285. try:
  286. while True:
  287. data = await loop.sock_recv(self.socket, 65535)
  288. if not data:
  289. break
  290. await self.on_frame(self, "data", data)
  291. except Exception:
  292. pass
  293. finally:
  294. await self.on_frame(self, "close", None)
  295. async def send(self, data: bytes) -> None:
  296. if self.closed or self.socket is None:
  297. return
  298. await asyncio.get_running_loop().sock_sendall(self.socket, data)
  299. async def close(self) -> None:
  300. if self.closed:
  301. return
  302. self.closed = True
  303. if self.read_task and self.read_task is not asyncio.current_task():
  304. self.read_task.cancel()
  305. with contextlib.suppress(Exception):
  306. await self.read_task
  307. if self.socket:
  308. self.socket.close()
  309. class RelayUdpPath(BasePath):
  310. def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], connection: RelayConnection, session_id: int, stream_id: int, target: TargetAddress) -> None:
  311. super().__init__(name, on_frame)
  312. self.connection = connection
  313. self.session_id = session_id
  314. self.stream_id = stream_id
  315. self.target = target
  316. async def open(self, _target: TargetAddress) -> None:
  317. if self.connection.closed:
  318. await self.on_frame(self, "status", b"relay unavailable")
  319. return
  320. self.connection.bind(self.session_id, self.stream_id, self._handle_frame)
  321. self.opened = True
  322. await self.on_frame(self, "status", b"ok")
  323. async def _handle_frame(self, _conn: RelayConnection, frame: Frame) -> None:
  324. if frame.kind == UDP_RECV:
  325. await self.on_frame(self, "data", frame.payload)
  326. async def send(self, data: bytes) -> None:
  327. if self.closed or self.connection.closed:
  328. return
  329. meta = encode_json({"host": self.target.host, "port": self.target.port, "family": self.target.family})
  330. payload = meta + data
  331. await self.connection.send(Frame(UDP_SEND, self.session_id, self.stream_id, 0, len(meta), payload))
  332. async def close(self) -> None:
  333. if self.closed:
  334. return
  335. self.closed = True
  336. self.connection.unbind(self.session_id, self.stream_id)
  337. @dataclass
  338. class UdpFlow:
  339. flow_id: int
  340. source: PeerAddress
  341. target: TargetAddress
  342. send_response: Callable[[PeerAddress, bytes], Awaitable[None]]
  343. paths: list[BasePath]
  344. winner: BasePath | None = None
  345. closed: bool = False
  346. last_activity: float = 0.0
  347. async def start(self) -> None:
  348. await asyncio.gather(*(path.open(self.target) for path in self.paths), return_exceptions=True)
  349. async def send(self, payload: bytes) -> None:
  350. self.last_activity = asyncio.get_running_loop().time()
  351. active = [path for path in self.paths if path.opened and not path.closed]
  352. if self.winner is None:
  353. await asyncio.gather(*(path.send(payload) for path in active), return_exceptions=True)
  354. elif not self.winner.closed:
  355. await self.winner.send(payload)
  356. async def handle_path(self, path: BasePath, event: str, payload: bytes | None) -> None:
  357. self.last_activity = asyncio.get_running_loop().time()
  358. if event == "data" and payload is not None:
  359. if self.winner is None:
  360. self.winner = path
  361. print(f"[edge] udp flow={self.flow_id} winner={path.name} target={self.target.host}:{self.target.port}")
  362. if path is self.winner:
  363. await self.send_response(self.source, payload)
  364. if event == "close":
  365. path.closed = True
  366. async def close(self) -> None:
  367. if self.closed:
  368. return
  369. self.closed = True
  370. await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True)
  371. class TransparentUdpListener:
  372. def __init__(self, edge: "TransparentEdge", family: int, bind_host: str, port: int) -> None:
  373. self.edge = edge
  374. self.family = family
  375. self.bind_host = bind_host
  376. self.port = port
  377. self.socket: socket.socket | None = None
  378. def start(self) -> None:
  379. sock = socket.socket(self.family, socket.SOCK_DGRAM)
  380. sock.setblocking(False)
  381. if self.family == socket.AF_INET:
  382. sock.setsockopt(socket.SOL_IP, IP_RECVORIGDSTADDR, 1)
  383. sock.bind((self.bind_host, self.port))
  384. else:
  385. sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
  386. sock.setsockopt(socket.IPPROTO_IPV6, IPV6_RECVORIGDSTADDR, 1)
  387. sock.bind((self.bind_host, self.port, 0, 0))
  388. self.socket = sock
  389. asyncio.get_running_loop().add_reader(sock.fileno(), self._on_readable)
  390. print(f"[edge] transparent udp listening on {sock.getsockname()}")
  391. def _on_readable(self) -> None:
  392. assert self.socket is not None
  393. try:
  394. data, ancdata, _flags, src = self.socket.recvmsg(65535, 512)
  395. except BlockingIOError:
  396. return
  397. except Exception:
  398. return
  399. original = None
  400. for level, ctype, cdata in ancdata:
  401. if self.family == socket.AF_INET and level == socket.SOL_IP and ctype == IP_RECVORIGDSTADDR:
  402. original = parse_sockaddr(cdata)
  403. break
  404. if self.family == socket.AF_INET6 and level == socket.IPPROTO_IPV6 and ctype == IPV6_RECVORIGDSTADDR:
  405. original = parse_sockaddr(cdata)
  406. break
  407. if original is None:
  408. return
  409. if self.family == socket.AF_INET:
  410. source = PeerAddress(host=src[0], port=src[1], family=socket.AF_INET)
  411. else:
  412. source = PeerAddress(host=src[0], port=src[1], family=socket.AF_INET6)
  413. if original.port == self.port and (original.host in ("127.0.0.1", "::1") or original.host == self.bind_host):
  414. return
  415. asyncio.create_task(self.edge.handle_udp_datagram(source, original, data, self))
  416. async def send_response(self, source: PeerAddress, payload: bytes) -> None:
  417. assert self.socket is not None
  418. if source.family == socket.AF_INET:
  419. self.socket.sendto(payload, (source.host, source.port))
  420. else:
  421. self.socket.sendto(payload, (source.host, source.port, 0, 0))
  422. async def close(self) -> None:
  423. if self.socket is None:
  424. return
  425. asyncio.get_running_loop().remove_reader(self.socket.fileno())
  426. self.socket.close()
  427. self.socket = None
  428. class TransparentEdge:
  429. def __init__(self, listen_host: str, listen_port: int, config: Config, enable_udp: bool = False, kernel_mode: str = "auto") -> None:
  430. self.listen_host = listen_host
  431. self.listen_port = listen_port
  432. self.config = config
  433. self.enable_udp = enable_udp
  434. self.kernel_mode = self._resolve_kernel_mode(kernel_mode, config.kernel_mode)
  435. self.manager = RelayManager(config)
  436. self.session_ids = itertools.count(1)
  437. self.stream_ids = itertools.count(1)
  438. self.udp_listeners: list[TransparentUdpListener] = []
  439. self.udp_flows: dict[tuple[PeerAddress, TargetAddress], UdpFlow] = {}
  440. self.udp_flow_ids = itertools.count(1)
  441. self.udp_gc_task: asyncio.Task | None = None
  442. self.tcp_win_counts: dict[str, int] = {}
  443. self.tcp_target_wins: dict[tuple[str, int], dict[str, int]] = {}
  444. def _resolve_kernel_mode(self, cli_kernel_mode: str, config_kernel_mode: str) -> str:
  445. mode = cli_kernel_mode if cli_kernel_mode != "auto" else config_kernel_mode
  446. if mode != "auto":
  447. return mode
  448. try:
  449. if Path("/etc/os-release").exists() and 'VERSION_ID="24' in Path("/etc/os-release").read_text(errors="ignore"):
  450. return "24"
  451. except Exception:
  452. pass
  453. try:
  454. release = os.uname().release
  455. if release.startswith("6."):
  456. return "24"
  457. except Exception:
  458. pass
  459. return "20"
  460. async def start(self) -> None:
  461. if self.kernel_mode == "24":
  462. if self.config.direct_open_timeout == 10.0:
  463. self.config.direct_open_timeout = 6.0
  464. if self.config.relay_open_timeout == 10.0:
  465. self.config.relay_open_timeout = 6.0
  466. if self.config.tcp_connect_happy_eyeballs_delay is None:
  467. self.config.tcp_connect_happy_eyeballs_delay = 0.25
  468. await self.manager.start()
  469. print(f"[edge] kernel_mode={self.kernel_mode} relay snapshot: {self.manager.snapshot()}")
  470. server4 = await asyncio.start_server(self._accept, self.listen_host, self.listen_port, family=socket.AF_INET)
  471. sockets = [str(sock.getsockname()) for sock in server4.sockets or []]
  472. server6 = None
  473. if self.listen_host in ("::", "::1", "0.0.0.0", "127.0.0.1"):
  474. host6 = "::1" if self.listen_host == "127.0.0.1" else "::"
  475. try:
  476. server6 = await asyncio.start_server(self._accept, host6, self.listen_port, family=socket.AF_INET6)
  477. sockets.extend(str(sock.getsockname()) for sock in server6.sockets or [])
  478. except Exception as exc:
  479. print(f"[edge] ipv6 tcp listener skipped: {exc!r}")
  480. if self.enable_udp:
  481. self._start_udp_listeners()
  482. self.udp_gc_task = asyncio.create_task(self._gc_udp_flows())
  483. print(f"[edge] transparent tcp listening on {', '.join(sockets)}")
  484. if server6 is None:
  485. async with server4:
  486. await server4.serve_forever()
  487. else:
  488. async with server4, server6:
  489. await asyncio.gather(server4.serve_forever(), server6.serve_forever())
  490. def _start_udp_listeners(self) -> None:
  491. binds = []
  492. if self.listen_host == "127.0.0.1":
  493. binds = [(socket.AF_INET, "127.0.0.1"), (socket.AF_INET6, "::1")]
  494. elif self.listen_host == "0.0.0.0":
  495. binds = [(socket.AF_INET, "0.0.0.0"), (socket.AF_INET6, "::")]
  496. else:
  497. family = socket.AF_INET6 if ":" in self.listen_host else socket.AF_INET
  498. binds = [(family, self.listen_host)]
  499. for family, host in binds:
  500. try:
  501. listener = TransparentUdpListener(self, family, host, self.listen_port)
  502. listener.start()
  503. self.udp_listeners.append(listener)
  504. except Exception as exc:
  505. print(f"[edge] udp listener skipped family={family} host={host} error={exc!r}")
  506. async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  507. peer = writer.get_extra_info("peername")
  508. try:
  509. target = self._get_original_dst(writer)
  510. session_id = next(self.session_ids)
  511. session = TransparentSession(session_id=session_id, target=target, reader=reader, writer=writer, paths=[], warmup_bytes=self.config.tcp_warmup_bytes, loser_grace_ms=self.config.tcp_loser_grace_ms, stats=self.tcp_win_counts, target_stats=self.tcp_target_wins)
  512. paths: list[BasePath] = [DirectTcpPath(name="direct", on_frame=lambda path, event, payload, s=session: self._handle_tcp_session(s, path, event, payload), open_timeout=self.config.direct_open_timeout, happy_eyeballs_delay=self.config.tcp_connect_happy_eyeballs_delay, tcp_nodelay=self.config.relay_tcp_nodelay)]
  513. for connection in self.manager.available():
  514. stream_id = next(self.stream_ids)
  515. paths.append(RelayTcpPath(name=connection.node.name, on_frame=lambda path, event, payload, s=session: self._handle_tcp_session(s, path, event, payload), connection=connection, session_id=session_id, stream_id=stream_id))
  516. session.paths = paths
  517. print(f"[edge] accept peer={peer} session={session_id} target={target.host}:{target.port} candidates={[path.name for path in paths]}")
  518. await session.start()
  519. except Exception as exc:
  520. print(f"[edge] accept failed peer={peer} error={exc!r}")
  521. writer.close()
  522. with contextlib.suppress(Exception):
  523. await writer.wait_closed()
  524. async def _handle_tcp_session(self, session: TransparentSession, path: BasePath, event: str, payload: bytes | None) -> None:
  525. await session.handle_path(path, event, payload)
  526. def _get_original_dst(self, writer: asyncio.StreamWriter) -> TargetAddress:
  527. sock = writer.get_extra_info("socket")
  528. if sock is None:
  529. raise RuntimeError("socket unavailable")
  530. family = sock.family
  531. if family == socket.AF_INET:
  532. raw = sock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, 16)
  533. return parse_sockaddr(raw)
  534. if family == socket.AF_INET6:
  535. raw = sock.getsockopt(socket.IPPROTO_IPV6, IP6T_SO_ORIGINAL_DST, 128)
  536. return parse_sockaddr(raw)
  537. raise RuntimeError(f"unsupported socket family={family}")
  538. async def handle_udp_datagram(self, source: PeerAddress, target: TargetAddress, payload: bytes, listener: TransparentUdpListener) -> None:
  539. if not self.enable_udp:
  540. return
  541. if target.port == self.listen_port and target.host in ("127.0.0.1", "::1", self.listen_host):
  542. return
  543. key = (source, target)
  544. flow = self.udp_flows.get(key)
  545. if flow is None:
  546. flow_id = next(self.udp_flow_ids)
  547. paths: list[BasePath] = [DirectUdpPath(name="direct", on_frame=lambda path, event, data, fid=flow_id: self._handle_udp_path(fid, path, event, data), target=target)]
  548. for connection in self.manager.available():
  549. stream_id = next(self.stream_ids)
  550. paths.append(RelayUdpPath(name=connection.node.name, on_frame=lambda path, event, data, fid=flow_id: self._handle_udp_path(fid, path, event, data), connection=connection, session_id=flow_id, stream_id=stream_id, target=target))
  551. flow = UdpFlow(flow_id=flow_id, source=source, target=target, send_response=listener.send_response, paths=paths)
  552. self.udp_flows[key] = flow
  553. print(f"[edge] udp flow={flow_id} target={target.host}:{target.port} candidates={[path.name for path in paths]}")
  554. await flow.start()
  555. await flow.send(payload)
  556. async def _handle_udp_path(self, flow_id: int, path: BasePath, event: str, payload: bytes | None) -> None:
  557. for flow in list(self.udp_flows.values()):
  558. if flow.flow_id == flow_id:
  559. await flow.handle_path(path, event, payload)
  560. break
  561. async def _gc_udp_flows(self) -> None:
  562. loop = asyncio.get_running_loop()
  563. while True:
  564. await asyncio.sleep(30)
  565. now = loop.time()
  566. stale = [key for key, flow in self.udp_flows.items() if flow.last_activity and now - flow.last_activity > 120]
  567. for key in stale:
  568. flow = self.udp_flows.pop(key, None)
  569. if flow:
  570. await flow.close()