transparent_edge.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694
  1. from __future__ import annotations
  2. from pathlib import Path
  3. import asyncio
  4. import contextlib
  5. import itertools
  6. import os
  7. import socket
  8. import struct
  9. from dataclasses import dataclass, field
  10. from typing import Awaitable, Callable
  11. from .config import Config
  12. from .protocol import STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, encode_json
  13. from .relay_client import RelayConnection, RelayManager
  # Linux netfilter getsockopt options that return the pre-NAT destination of a
  # redirected TCP connection (read at level SOL_IP for IPv4, IPPROTO_IPV6 for IPv6).
  SO_ORIGINAL_DST = IP6T_SO_ORIGINAL_DST = 80
  # Socket options asking the kernel to attach each UDP datagram's original
  # destination as ancillary data (consumed via recvmsg by the UDP listener).
  IP_RECVORIGDSTADDR, IPV6_RECVORIGDSTADDR = 20, 74
  @dataclass(frozen=True)
  class TargetAddress:
      """Original destination of an intercepted connection or datagram."""

      host: str
      port: int
      family: int


  @dataclass(frozen=True)
  class PeerAddress:
      """Address of the local client that sent intercepted traffic."""

      host: str
      port: int
      family: int


  def parse_sockaddr(raw: bytes) -> TargetAddress:
      """Decode a raw ``sockaddr_in`` / ``sockaddr_in6`` buffer into a TargetAddress.

      The family field is read in native byte order and the port in network
      byte order. IPv4 needs at least 8 bytes, IPv6 at least 28.

      Raises:
          ValueError: if the buffer is too short or the family is neither
              AF_INET nor AF_INET6.
      """
      if len(raw) < 8:
          raise ValueError("invalid transparent destination payload")
      (family,) = struct.unpack_from("=H", raw)
      (port,) = struct.unpack_from("!H", raw, 2)
      if family == socket.AF_INET6:
          if len(raw) < 28:
              raise ValueError("invalid IPv6 transparent destination payload")
          # sockaddr_in6: family(2) port(2) flowinfo(4) addr(16) scope_id(4)
          address = socket.inet_ntop(socket.AF_INET6, raw[8:24])
          return TargetAddress(host=address, port=port, family=family)
      if family == socket.AF_INET:
          # sockaddr_in: family(2) port(2) addr(4) padding(8)
          address = socket.inet_ntoa(raw[4:8])
          return TargetAddress(host=address, port=port, family=family)
      raise ValueError(f"unsupported family={family}")
  def winner_group(name: str) -> str:
      """Collapse every path name starting with ``direct`` into the ``"direct"`` group."""
      if name.startswith("direct"):
          return "direct"
      return name


  def grouped_total(stats: dict[str, int], group: str) -> int:
      """Sum the win counts in *stats* whose path name maps to *group*."""
      total = 0
      for name, count in stats.items():
          if winner_group(name) == group:
              total += count
      return total
  class BasePath:
      """Interface shared by every candidate route a session or flow races.

      Implementations report progress by awaiting
      ``on_frame(path, event, payload)``; the subclasses in this module emit
      the events ``"status"``, ``"data"`` and ``"close"``.
      """

      def __init__(self, name: str, on_frame: Callable[["BasePath", str, bytes | None], Awaitable[None]]) -> None:
          self.name = name
          self.on_frame = on_frame
          # Lifecycle flags; subclasses flip them as the path opens and closes.
          self.opened, self.closed = False, False

      async def open(self, target: TargetAddress) -> None:
          """Begin connecting to *target* (implemented by subclasses)."""
          raise NotImplementedError

      async def send(self, data: bytes) -> None:
          """Forward *data* over this path (implemented by subclasses)."""
          raise NotImplementedError

      async def close(self) -> None:
          """Release this path's resources (implemented by subclasses)."""
          raise NotImplementedError
  class DirectTcpPath(BasePath):
      """Candidate path that opens a plain TCP connection straight to the target."""

      def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], open_timeout: float, happy_eyeballs_delay: float | None, tcp_nodelay: bool = True) -> None:
          super().__init__(name, on_frame)
          self.reader: asyncio.StreamReader | None = None
          self.writer: asyncio.StreamWriter | None = None
          self.pump_task: asyncio.Task | None = None
          self.open_timeout = open_timeout
          self.happy_eyeballs_delay = happy_eyeballs_delay
          self.tcp_nodelay = tcp_nodelay

      async def open(self, target: TargetAddress) -> None:
          """Connect to *target* within ``open_timeout`` and report a "status" event.

          On success the status payload is ``b"ok"`` and a read pump is started.
          On failure the error text is reported as the status payload instead
          of raising.
          """
          try:
              family = socket.AF_INET6 if target.family == socket.AF_INET6 else socket.AF_INET
              kwargs = {"host": target.host, "port": target.port, "family": family}
              if self.happy_eyeballs_delay is not None:
                  kwargs["happy_eyeballs_delay"] = self.happy_eyeballs_delay
              self.reader, self.writer = await asyncio.wait_for(asyncio.open_connection(**kwargs), timeout=self.open_timeout)
              sock = self.writer.get_extra_info("socket")
              if sock is not None and self.tcp_nodelay:
                  with contextlib.suppress(OSError):
                      sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
              self.opened = True
              self.pump_task = asyncio.create_task(self._pump())
              await self.on_frame(self, "status", b"ok")
          except Exception as exc:
              # asyncio.TimeoutError stringifies to "", which would leave the
              # session with an empty error message; use the class name instead.
              reason = str(exc) or type(exc).__name__
              await self.on_frame(self, "status", reason.encode())

      async def _pump(self) -> None:
          """Relay bytes from the target as "data" events; always ends with a "close" event."""
          assert self.reader is not None
          try:
              while True:
                  chunk = await self.reader.read(65536)
                  if not chunk:
                      break
                  await self.on_frame(self, "data", chunk)
          except Exception:
              pass  # best-effort: any read/dispatch error just ends the stream
          finally:
              await self.on_frame(self, "close", None)

      async def send(self, data: bytes) -> None:
          """Write *data* to the target; silently ignored once closed or before opening."""
          if self.closed or self.writer is None:
              return
          self.writer.write(data)
          await self.writer.drain()

      async def close(self) -> None:
          """Idempotently stop the read pump and close the TCP connection."""
          if self.closed:
              return
          self.closed = True
          # close() may be invoked from inside the pump's own callback chain;
          # a task must not cancel/await itself.
          if self.pump_task and self.pump_task is not asyncio.current_task():
              self.pump_task.cancel()
              with contextlib.suppress(Exception):
                  await self.pump_task
          if self.writer:
              self.writer.close()
              with contextlib.suppress(Exception):
                  await self.writer.wait_closed()
  class RelayTcpPath(BasePath):
      """Candidate path that tunnels the TCP stream through one relay connection.

      The stream is addressed on the relay by ``(session_id, stream_id)``;
      inbound frames for it are delivered to ``_handle_frame`` once bound.
      """

      def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], connection: RelayConnection, session_id: int, stream_id: int) -> None:
          super().__init__(name, on_frame)
          self.connection = connection
          self.session_id = session_id
          self.stream_id = stream_id

      def _frame(self, kind, payload: bytes) -> Frame:
          """Build a frame of *kind* addressed to this path's relay stream."""
          return Frame(kind, self.session_id, self.stream_id, 0, 0, payload)

      async def open(self, target: TargetAddress) -> None:
          """Ask the relay to connect to *target*; the answer arrives later as TCP_STATUS."""
          conn = self.connection
          if conn.closed:
              await self.on_frame(self, "status", b"relay unavailable")
              return
          conn.bind(self.session_id, self.stream_id, self._handle_frame)
          try:
              request = encode_json({"host": target.host, "port": target.port, "family": target.family})
              await conn.send(self._frame(TCP_OPEN, request))
          except Exception as exc:
              conn.unbind(self.session_id, self.stream_id)
              await self.on_frame(self, "status", str(exc).encode())

      async def _handle_frame(self, _conn: RelayConnection, frame: Frame) -> None:
          """Translate relay frames into "status" / "data" / "close" events; others are ignored."""
          kind = frame.kind
          if kind == TCP_STATUS:
              succeeded = frame.packet_id == STATUS_OK
              if succeeded:
                  self.opened = True
              await self.on_frame(self, "status", b"ok" if succeeded else frame.payload)
          elif kind == TCP_DATA:
              await self.on_frame(self, "data", frame.payload)
          elif kind == TCP_CLOSE:
              await self.on_frame(self, "close", None)

      async def send(self, data: bytes) -> None:
          """Forward *data* as a TCP_DATA frame unless this path or its relay is closed."""
          if not (self.closed or self.connection.closed):
              await self.connection.send(self._frame(TCP_DATA, data))

      async def close(self) -> None:
          """Idempotently unbind the stream and, if the relay is still up, send TCP_CLOSE."""
          if self.closed:
              return
          self.closed = True
          conn = self.connection
          conn.unbind(self.session_id, self.stream_id)
          if conn.closed:
              return
          with contextlib.suppress(Exception):
              await conn.send(self._frame(TCP_CLOSE, b""))
  @dataclass
  class TransparentSession:
      """One intercepted TCP connection raced over several candidate paths.

      The first ``warmup_bytes`` of client data are copied to every path that
      has seen the whole uplink stream. The first such path to deliver
      response data becomes the winner. The other paths are then closed
      (after ``loser_grace_ms`` when positive), and later client data goes to
      the winner only.
      """

      session_id: int
      target: TargetAddress
      reader: asyncio.StreamReader
      writer: asyncio.StreamWriter
      paths: list[BasePath]
      warmup_bytes: int
      loser_grace_ms: int
      stats: dict[str, int]  # win counts per path name (shared, mutated in place)
      target_stats: dict[tuple[str, int], dict[str, int]]  # win counts per (host, port)
      family_stats: dict[str, dict[str, int]]  # win counts per "ipv4" / "ipv6"
      opened_count: int = 0
      status_count: int = 0
      errors: list[str] = field(default_factory=list)
      winner: BasePath | None = None
      uplink_bytes: int = 0
      open_event: asyncio.Event = field(default_factory=asyncio.Event)
      winner_event: asyncio.Event = field(default_factory=asyncio.Event)
      closed: bool = False
      pump_task: asyncio.Task | None = None
      loser_close_task: asyncio.Task | None = None
      # Paths that were open when the first client byte was forwarded. Only
      # these have received the complete uplink stream. None until then.
      in_sync_paths: set[BasePath] | None = None

      def _record_win(self, winner: BasePath) -> None:
          """Count *winner* in the global, per-target and per-family tallies and log a summary."""
          def relay_total(counts: dict[str, int]) -> int:
              return sum(count for name, count in counts.items() if winner_group(name) != "direct")

          def relay_breakdown(counts: dict[str, int]) -> str:
              return ", ".join(f"{name}={count}" for name, count in sorted(counts.items()) if winner_group(name) != "direct") or "none"

          self.stats[winner.name] = self.stats.get(winner.name, 0) + 1
          key = (self.target.host, self.target.port)
          target_stats = self.target_stats.setdefault(key, {})
          target_stats[winner.name] = target_stats.get(winner.name, 0) + 1
          family_key = "ipv6" if self.target.family == socket.AF_INET6 else "ipv4"
          family_stats = self.family_stats.setdefault(family_key, {})
          family_stats[winner.name] = family_stats.get(winner.name, 0) + 1
          direct_wins = grouped_total(self.stats, "direct")
          relay_wins = relay_total(self.stats)
          target_direct = grouped_total(target_stats, "direct")
          target_relay = relay_total(target_stats)
          family_direct = grouped_total(family_stats, "direct")
          family_relay = relay_total(family_stats)
          relay_detail = relay_breakdown(self.stats)
          target_detail = relay_breakdown(target_stats)
          target_pref = "relay" if target_relay > target_direct else "direct"
          family_pref = "relay" if family_relay > family_direct else "direct"
          print(f"[edge] tcp win session={self.session_id} target={self.target.host}:{self.target.port} winner={winner.name} direct={direct_wins} relay={relay_wins} relay_breakdown={relay_detail} target_pref={target_pref} target_direct={target_direct} target_relay={target_relay} target_breakdown={target_detail} family_pref={family_pref} family={family_key} family_direct={family_direct} family_relay={family_relay}")

      def _is_in_sync(self, path: BasePath) -> bool:
          """Return True if *path* has received every client byte forwarded so far."""
          return self.in_sync_paths is None or path in self.in_sync_paths

      async def start(self) -> None:
          """Open every path and start pumping client data once at least one path is usable.

          Raises:
              ConnectionError: no path opened (carries the first reported error).
              asyncio.TimeoutError: not every path reported and none opened within 15 s.

          On any failure the session is closed before the exception propagates,
          so relay bindings and connected paths are released.
          """
          try:
              await asyncio.gather(*(path.open(self.target) for path in self.paths), return_exceptions=True)
              await asyncio.wait_for(self.open_event.wait(), timeout=15)
              if self.opened_count == 0:
                  raise ConnectionError(self.errors[0] if self.errors else "all paths failed")
          except BaseException:
              await self.close()
              raise
          if self.closed:
              # Every opened path already went away while we were starting.
              return
          self.pump_task = asyncio.create_task(self._pump_local())

      async def _pump_local(self) -> None:
          """Forward client bytes to the in-sync paths during warm-up, then to the winner only."""
          try:
              while True:
                  chunk = await self.reader.read(65536)
                  if not chunk:
                      break
                  if self.in_sync_paths is None:
                      # Freeze the participant set at the first byte: a path that
                      # opens later would only ever see a truncated stream.
                      self.in_sync_paths = {path for path in self.paths if path.opened and not path.closed}
                  self.uplink_bytes += len(chunk)
                  active = [path for path in self.paths if path.opened and not path.closed and path in self.in_sync_paths]
                  if not active:
                      break
                  if self.uplink_bytes <= self.warmup_bytes:
                      await asyncio.gather(*(path.send(chunk) for path in active), return_exceptions=True)
                  else:
                      if self.winner is None:
                          await self.winner_event.wait()
                      if self.winner:
                          await self.winner.send(chunk)
          except Exception:
              pass  # best-effort: any failure simply ends the session below
          finally:
              await self.close()

      async def handle_path(self, path: BasePath, event: str, payload: bytes | None) -> None:
          """React to a "status", "data" or "close" event reported by *path*."""
          if self.closed:
              return
          if event == "status":
              self.status_count += 1
              if payload == b"ok":
                  self.opened_count += 1
                  if not self._is_in_sync(path):
                      # Opened after client data started flowing; it can never
                      # carry a consistent stream, so release it immediately.
                      await path.close()
              elif payload is not None:
                  self.errors.append(payload.decode("utf-8", errors="replace"))
              if self.opened_count > 0 or self.status_count == len(self.paths):
                  self.open_event.set()
              return
          if event == "data":
              if self.winner is None:
                  if not self._is_in_sync(path):
                      return  # an out-of-sync path must never win the race
                  self.winner = path
                  self._record_win(path)
                  self.winner_event.set()
                  if self.loser_grace_ms > 0:
                      self.loser_close_task = asyncio.create_task(self._close_losers_after_grace(path))
                  else:
                      await self._close_losers(path)
              if path is self.winner and payload is not None:
                  self.writer.write(payload)
                  await self.writer.drain()
              return
          if event == "close":
              path.closed = True
              if self.winner is None:
                  remaining = [candidate for candidate in self.paths if candidate.opened and not candidate.closed]
                  if not remaining:
                      await self.close()
              elif path is self.winner:
                  await self.close()

      async def _close_losers(self, winner: BasePath) -> None:
          """Close every path except *winner*, ignoring individual close errors."""
          await asyncio.gather(*(path.close() for path in self.paths if path is not winner), return_exceptions=True)

      async def _close_losers_after_grace(self, winner: BasePath) -> None:
          """Wait ``loser_grace_ms`` and then close the losers unless the session already closed."""
          await asyncio.sleep(self.loser_grace_ms / 1000)
          if not self.closed:
              await self._close_losers(winner)

      async def close(self) -> None:
          """Idempotently stop background tasks, close every path and the client writer."""
          if self.closed:
              return
          self.closed = True
          print(f"[edge] session={self.session_id} closed target={self.target.host}:{self.target.port}")
          if self.pump_task and self.pump_task is not asyncio.current_task():
              self.pump_task.cancel()
              with contextlib.suppress(Exception):
                  await self.pump_task
          if self.loser_close_task and self.loser_close_task is not asyncio.current_task():
              self.loser_close_task.cancel()
              with contextlib.suppress(Exception):
                  await self.loser_close_task
          await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True)
          self.writer.close()
          with contextlib.suppress(Exception):
              await self.writer.wait_closed()
  class DirectUdpPath(BasePath):
      """Candidate path that exchanges datagrams with the target over a connected UDP socket."""

      def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], target: TargetAddress) -> None:
          super().__init__(name, on_frame)
          self.target = target
          self.socket: socket.socket | None = None
          self.read_task: asyncio.Task | None = None

      async def open(self, _target: TargetAddress) -> None:
          """Create and connect the UDP socket to ``self.target``, then report a "status" event.

          The argument is ignored; the destination fixed at construction is used.
          """
          try:
              family = socket.AF_INET6 if self.target.family == socket.AF_INET6 else socket.AF_INET
              self.socket = socket.socket(family, socket.SOCK_DGRAM)
              self.socket.setblocking(False)
              await asyncio.get_running_loop().sock_connect(self.socket, (self.target.host, self.target.port))
              self.opened = True
              self.read_task = asyncio.create_task(self._pump())
              await self.on_frame(self, "status", b"ok")
          except Exception as exc:
              if not self.opened and self.socket is not None:
                  # Release the never-used socket now instead of when the flow closes.
                  self.socket.close()
                  self.socket = None
              # Some exceptions (e.g. TimeoutError) stringify to "".
              reason = str(exc) or type(exc).__name__
              await self.on_frame(self, "status", reason.encode())

      async def _pump(self) -> None:
          """Report every received datagram as a "data" event; always ends with "close"."""
          assert self.socket is not None
          loop = asyncio.get_running_loop()
          try:
              while True:
                  data = await loop.sock_recv(self.socket, 65535)
                  # A zero-length datagram is a valid UDP payload, not end of
                  # stream, so it is forwarded like any other datagram.
                  await self.on_frame(self, "data", data)
          except Exception:
              pass  # best-effort: socket errors (e.g. ICMP unreachable) end the path
          finally:
              await self.on_frame(self, "close", None)

      async def send(self, data: bytes) -> None:
          """Send one datagram to the target; ignored once closed or before opening."""
          if self.closed or self.socket is None:
              return
          await asyncio.get_running_loop().sock_sendall(self.socket, data)

      async def close(self) -> None:
          """Idempotently stop the receive task and close the socket."""
          if self.closed:
              return
          self.closed = True
          if self.read_task and self.read_task is not asyncio.current_task():
              self.read_task.cancel()
              with contextlib.suppress(Exception):
                  await self.read_task
          if self.socket:
              self.socket.close()
  class RelayUdpPath(BasePath):
      """Candidate path that forwards datagrams through one relay connection.

      Each outbound datagram is prefixed with a JSON header naming the
      destination; the header length travels in the frame's fifth field.
      """

      def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], connection: RelayConnection, session_id: int, stream_id: int, target: TargetAddress) -> None:
          super().__init__(name, on_frame)
          self.connection = connection
          self.session_id = session_id
          self.stream_id = stream_id
          self.target = target

      async def open(self, _target: TargetAddress) -> None:
          """Bind this stream on the relay and report "ok" right away (UDP needs no handshake)."""
          conn = self.connection
          if conn.closed:
              await self.on_frame(self, "status", b"relay unavailable")
              return
          conn.bind(self.session_id, self.stream_id, self._handle_frame)
          try:
              self.opened = True
              await self.on_frame(self, "status", b"ok")
          except Exception:
              conn.unbind(self.session_id, self.stream_id)
              self.closed = True
              raise

      async def _handle_frame(self, _conn: RelayConnection, frame: Frame) -> None:
          """Report UDP_RECV frames as "data" events; every other frame kind is ignored."""
          if frame.kind != UDP_RECV:
              return
          await self.on_frame(self, "data", frame.payload)

      async def send(self, data: bytes) -> None:
          """Send *data* to the target via the relay unless this path or the relay is closed."""
          if self.closed or self.connection.closed:
              return
          header = encode_json({"host": self.target.host, "port": self.target.port, "family": self.target.family})
          frame = Frame(UDP_SEND, self.session_id, self.stream_id, 0, len(header), header + data)
          await self.connection.send(frame)

      async def close(self) -> None:
          """Idempotently unbind this stream from the relay connection."""
          if not self.closed:
              self.closed = True
              self.connection.unbind(self.session_id, self.stream_id)
  @dataclass
  class UdpFlow:
      """One (client, original destination) UDP association raced over several paths.

      Until some path returns a datagram, client datagrams go to every opened
      path. The first responding path becomes the winner and alone carries
      traffic until it closes, after which the race starts again.
      """

      flow_id: int
      source: PeerAddress
      target: TargetAddress
      send_response: Callable[[PeerAddress, bytes], Awaitable[None]]
      paths: list[BasePath]
      winner: BasePath | None = None
      closed: bool = False
      last_activity: float = 0.0  # event-loop clock of the most recent send/receive

      def _touch(self) -> None:
          """Stamp the flow as active (read by the idle-flow collector)."""
          self.last_activity = asyncio.get_running_loop().time()

      async def start(self) -> None:
          """Open every candidate path concurrently, ignoring individual failures."""
          await asyncio.gather(*(path.open(self.target) for path in self.paths), return_exceptions=True)

      async def send(self, payload: bytes) -> None:
          """Forward a client datagram to the winner, or to every opened path if there is none."""
          self._touch()
          chosen = self.winner
          if chosen is not None:
              if not chosen.closed:
                  await chosen.send(payload)
              return
          live = [candidate for candidate in self.paths if candidate.opened and not candidate.closed]
          await asyncio.gather(*(candidate.send(payload) for candidate in live), return_exceptions=True)

      async def handle_path(self, path: BasePath, event: str, payload: bytes | None) -> None:
          """Elect a winner on first response, relay its datagrams, and reset it when it closes."""
          self._touch()
          if event == "data" and payload is not None:
              if self.winner is None:
                  self.winner = path
                  print(f"[edge] udp flow={self.flow_id} winner={path.name} target={self.target.host}:{self.target.port}")
              if self.winner is path:
                  await self.send_response(self.source, payload)
          elif event == "close":
              path.closed = True
              if self.winner is path:
                  self.winner = None

      async def close(self) -> None:
          """Idempotently close every path of this flow."""
          if not self.closed:
              self.closed = True
              await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True)
  class TransparentUdpListener:
      """UDP socket that receives redirected datagrams together with their original destination."""

      def __init__(self, edge: "TransparentEdge", family: int, bind_host: str, port: int) -> None:
          self.edge = edge
          self.family = family
          self.bind_host = bind_host
          self.port = port
          self.socket: socket.socket | None = None
          # Strong references to in-flight dispatch tasks: the event loop keeps
          # only weak references, so an unreferenced task may be collected mid-run.
          self._tasks: set[asyncio.Task] = set()

      def start(self) -> None:
          """Bind the socket, enable original-destination ancillary data and register the reader."""
          sock = socket.socket(self.family, socket.SOCK_DGRAM)
          sock.setblocking(False)
          if self.family == socket.AF_INET:
              sock.setsockopt(socket.SOL_IP, IP_RECVORIGDSTADDR, 1)
              sock.bind((self.bind_host, self.port))
          else:
              sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
              sock.setsockopt(socket.IPPROTO_IPV6, IPV6_RECVORIGDSTADDR, 1)
              sock.bind((self.bind_host, self.port, 0, 0))
          self.socket = sock
          asyncio.get_running_loop().add_reader(sock.fileno(), self._on_readable)
          print(f"[edge] transparent udp listening on {sock.getsockname()}")

      def _on_readable(self) -> None:
          """Read one datagram and dispatch it to the edge with its original destination."""
          if self.socket is None:
              return
          try:
              data, ancdata, _flags, src = self.socket.recvmsg(65535, 512)
          except Exception:
              return  # best-effort: EAGAIN or a transient socket error drops this read
          if self.family == socket.AF_INET:
              want_level, want_type, peer_family = socket.SOL_IP, IP_RECVORIGDSTADDR, socket.AF_INET
          else:
              want_level, want_type, peer_family = socket.IPPROTO_IPV6, IPV6_RECVORIGDSTADDR, socket.AF_INET6
          original = None
          for level, ctype, cdata in ancdata:
              if level == want_level and ctype == want_type:
                  try:
                      original = parse_sockaddr(cdata)
                  except ValueError:
                      return  # malformed ancillary data; drop the datagram
                  break
          if original is None:
              return
          source = PeerAddress(host=src[0], port=src[1], family=peer_family)
          # Traffic addressed to this listener itself would loop back into the edge.
          if original.port == self.port and (original.host in ("127.0.0.1", "::1") or original.host == self.bind_host):
              return
          task = asyncio.create_task(self.edge.handle_udp_datagram(source, original, data, self))
          self._tasks.add(task)
          task.add_done_callback(self._tasks.discard)

      async def send_response(self, source: PeerAddress, payload: bytes) -> None:
          """Send *payload* back to the client *source*; a no-op once the listener is closed.

          A full socket send buffer drops the datagram (as UDP may) instead of
          raising into the caller's receive loop.
          """
          sock = self.socket
          if sock is None:
              return
          if source.family == socket.AF_INET:
              address = (source.host, source.port)
          else:
              address = (source.host, source.port, 0, 0)
          try:
              sock.sendto(payload, address)
          except (BlockingIOError, InterruptedError):
              pass

      async def close(self) -> None:
          """Unregister the reader and close the socket (idempotent)."""
          if self.socket is None:
              return
          asyncio.get_running_loop().remove_reader(self.socket.fileno())
          self.socket.close()
          self.socket = None
  class TransparentEdge:
      """Transparent proxy edge that races redirected traffic over direct and relay paths.

      TCP connections accepted on ``listen_host:listen_port`` have their
      original destination recovered with ``SO_ORIGINAL_DST``. When
      ``enable_udp`` is set, redirected UDP datagrams arrive through
      ``TransparentUdpListener`` instances. Every session or flow races direct
      paths against each connection returned by ``self.manager.available()``.
      """

      def __init__(self, listen_host: str, listen_port: int, config: Config, enable_udp: bool = False, kernel_mode: str = "auto") -> None:
          self.listen_host = listen_host
          self.listen_port = listen_port
          self.config = config
          self.enable_udp = enable_udp
          self.kernel_mode = self._resolve_kernel_mode(kernel_mode, config.kernel_mode)
          self.manager = RelayManager(config)
          self.session_ids = itertools.count(1)
          self.stream_ids = itertools.count(1)
          self.udp_listeners: list[TransparentUdpListener] = []
          self.udp_flows: dict[tuple[PeerAddress, TargetAddress], UdpFlow] = {}
          # flow_id -> flow, so each path event is dispatched in O(1) instead of
          # scanning every live flow.
          self.udp_flows_by_id: dict[int, UdpFlow] = {}
          self.udp_flow_ids = itertools.count(1)
          self.udp_gc_task: asyncio.Task | None = None
          # Win tallies shared by every TransparentSession (global / per target / per family).
          self.tcp_win_counts: dict[str, int] = {}
          self.tcp_target_wins: dict[tuple[str, int], dict[str, int]] = {}
          self.tcp_family_wins: dict[str, dict[str, int]] = {"ipv4": {}, "ipv6": {}}

      def _resolve_kernel_mode(self, cli_kernel_mode: str, config_kernel_mode: str) -> str:
          """Pick the kernel profile: CLI value, else config value, else auto-detect ("24" or "20")."""
          mode = cli_kernel_mode if cli_kernel_mode != "auto" else config_kernel_mode
          if mode != "auto":
              return mode
          try:
              os_release = Path("/etc/os-release")
              if os_release.exists() and 'VERSION_ID="24' in os_release.read_text(errors="ignore"):
                  return "24"
          except Exception:
              pass  # detection is best-effort
          try:
              if os.uname().release.startswith("6."):
                  return "24"
          except Exception:
              pass  # os.uname is unavailable on some platforms
          return "20"

      async def start(self) -> None:
          """Start the relay manager and all listeners, then serve until cancelled."""
          if self.kernel_mode == "24":
              # Only values still at 10.0 (presumably the Config default) are
              # tightened; explicitly configured timeouts are left alone.
              if self.config.direct_open_timeout == 10.0:
                  self.config.direct_open_timeout = 6.0
              if self.config.relay_open_timeout == 10.0:
                  self.config.relay_open_timeout = 6.0
              if self.config.tcp_connect_happy_eyeballs_delay is None:
                  self.config.tcp_connect_happy_eyeballs_delay = 0.25
          await self.manager.start()
          print(f"[edge] kernel_mode={self.kernel_mode} relay snapshot: {self.manager.snapshot()}")
          # An AF_INET listener cannot bind an IPv6 literal, so the IPv6
          # wildcard/loopback hosts map to their IPv4 counterparts here.
          host4 = {"::": "0.0.0.0", "::1": "127.0.0.1"}.get(self.listen_host, self.listen_host)
          server4 = await asyncio.start_server(self._accept, host4, self.listen_port, family=socket.AF_INET)
          sockets = [str(sock.getsockname()) for sock in server4.sockets or []]
          server6 = None
          if self.listen_host in ("::", "::1", "0.0.0.0", "127.0.0.1"):
              # Loopback listen hosts must stay on loopback for IPv6 as well.
              host6 = "::1" if self.listen_host in ("127.0.0.1", "::1") else "::"
              try:
                  server6 = await asyncio.start_server(self._accept, host6, self.listen_port, family=socket.AF_INET6)
                  sockets.extend(str(sock.getsockname()) for sock in server6.sockets or [])
              except Exception as exc:
                  print(f"[edge] ipv6 tcp listener skipped: {exc!r}")
          if self.enable_udp:
              self._start_udp_listeners()
              self.udp_gc_task = asyncio.create_task(self._gc_udp_flows())
          print(f"[edge] transparent tcp listening on {', '.join(sockets)}")
          if server6 is None:
              async with server4:
                  await server4.serve_forever()
          else:
              async with server4, server6:
                  await asyncio.gather(server4.serve_forever(), server6.serve_forever())

      def _direct_redundancy_for_target(self, target: TargetAddress) -> int:
          """Return how many parallel direct TCP paths to race for *target*.

          Starts from the per-family override (or ``direct_redundancy``),
          clamped with ``max(1, min(base, direct_max_redundancy))``. One extra
          direct path is added (still capped) when relays have won more often
          than direct for this target or its address family.
          """
          base = self.config.direct_redundancy
          if target.family == socket.AF_INET6 and self.config.direct_redundancy_v6 is not None:
              base = self.config.direct_redundancy_v6
          elif target.family == socket.AF_INET and self.config.direct_redundancy_v4 is not None:
              base = self.config.direct_redundancy_v4
          base = max(1, min(base, self.config.direct_max_redundancy))
          target_stats = self.tcp_target_wins.get((target.host, target.port), {})
          family_key = "ipv6" if target.family == socket.AF_INET6 else "ipv4"
          family_stats = self.tcp_family_wins.get(family_key, {})
          target_prefers_relay = sum(count for name, count in target_stats.items() if winner_group(name) != "direct") > grouped_total(target_stats, "direct")
          family_prefers_relay = sum(count for name, count in family_stats.items() if winner_group(name) != "direct") > grouped_total(family_stats, "direct")
          if target_prefers_relay or family_prefers_relay:
              return min(self.config.direct_max_redundancy, base + 1)
          return base

      def _build_direct_paths(self, session: TransparentSession) -> list[BasePath]:
          """Create the direct TCP candidates for *session* (named "direct" or "direct-N")."""
          count = self._direct_redundancy_for_target(session.target)
          return [
              DirectTcpPath(
                  name=f"direct-{index + 1}" if count > 1 else "direct",
                  on_frame=lambda path, event, payload, s=session: self._handle_tcp_session(s, path, event, payload),
                  open_timeout=self.config.direct_open_timeout,
                  happy_eyeballs_delay=self.config.tcp_connect_happy_eyeballs_delay,
                  tcp_nodelay=self.config.relay_tcp_nodelay,
              )
              for index in range(count)
          ]

      def _start_udp_listeners(self) -> None:
          """Start one UDP listener per address family implied by ``listen_host``.

          Host selection mirrors the TCP listeners: loopback hosts listen on both
          loopbacks, wildcard hosts on both wildcards, anything else on itself.
          """
          if self.listen_host in ("127.0.0.1", "::1"):
              binds = [(socket.AF_INET, "127.0.0.1"), (socket.AF_INET6, "::1")]
          elif self.listen_host in ("0.0.0.0", "::"):
              binds = [(socket.AF_INET, "0.0.0.0"), (socket.AF_INET6, "::")]
          else:
              family = socket.AF_INET6 if ":" in self.listen_host else socket.AF_INET
              binds = [(family, self.listen_host)]
          for family, host in binds:
              try:
                  listener = TransparentUdpListener(self, family, host, self.listen_port)
                  listener.start()
                  self.udp_listeners.append(listener)
              except Exception as exc:
                  print(f"[edge] udp listener skipped family={family} host={host} error={exc!r}")

      def _is_self_target(self, target: TargetAddress, local: tuple | None = None) -> bool:
          """Return True when *target* is this edge's own listener (a redirect loop).

          ``local`` is the accepted socket's own address, when known.
          """
          if target.port != self.listen_port:
              return False
          if target.host in ("127.0.0.1", "::1", self.listen_host):
              return True
          return bool(local) and target.host == local[0]

      async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
          """Resolve the original destination of an accepted connection and race it."""
          peer = writer.get_extra_info("peername")
          session: TransparentSession | None = None
          try:
              target = self._get_original_dst(writer)
              if self._is_self_target(target, writer.get_extra_info("sockname")):
                  # A connection made straight to the edge can report the edge
                  # itself as its original destination; proxying it would
                  # reconnect to this listener endlessly.
                  raise ConnectionError(f"refusing to proxy to the edge listener itself target={target.host}:{target.port}")
              session_id = next(self.session_ids)
              session = TransparentSession(session_id=session_id, target=target, reader=reader, writer=writer, paths=[], warmup_bytes=self.config.tcp_warmup_bytes, loser_grace_ms=self.config.tcp_loser_grace_ms, stats=self.tcp_win_counts, target_stats=self.tcp_target_wins, family_stats=self.tcp_family_wins)
              paths: list[BasePath] = self._build_direct_paths(session)
              for connection in self.manager.available():
                  stream_id = next(self.stream_ids)
                  paths.append(RelayTcpPath(name=connection.node.name, on_frame=lambda path, event, payload, s=session: self._handle_tcp_session(s, path, event, payload), connection=connection, session_id=session_id, stream_id=stream_id))
              session.paths = paths
              print(f"[edge] accept peer={peer} session={session_id} target={target.host}:{target.port} candidates={[path.name for path in paths]}")
              await session.start()
          except Exception as exc:
              print(f"[edge] accept failed peer={peer} error={exc!r}")
              if session is not None:
                  # Release relay bindings / connected paths of the failed session.
                  await session.close()
              writer.close()
              with contextlib.suppress(Exception):
                  await writer.wait_closed()

      async def _handle_tcp_session(self, session: TransparentSession, path: BasePath, event: str, payload: bytes | None) -> None:
          """Route a path event to the session that owns the path."""
          await session.handle_path(path, event, payload)

      def _get_original_dst(self, writer: asyncio.StreamWriter) -> TargetAddress:
          """Read the pre-redirect destination of an accepted socket via SO_ORIGINAL_DST."""
          sock = writer.get_extra_info("socket")
          if sock is None:
              raise RuntimeError("socket unavailable")
          family = sock.family
          if family == socket.AF_INET:
              raw = sock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, 16)
              return parse_sockaddr(raw)
          if family == socket.AF_INET6:
              raw = sock.getsockopt(socket.IPPROTO_IPV6, IP6T_SO_ORIGINAL_DST, 128)
              return parse_sockaddr(raw)
          raise RuntimeError(f"unsupported socket family={family}")

      async def handle_udp_datagram(self, source: PeerAddress, target: TargetAddress, payload: bytes, listener: TransparentUdpListener) -> None:
          """Forward one redirected datagram, creating its (source, target) flow on first use."""
          if not self.enable_udp:
              return
          if self._is_self_target(target):
              return
          key = (source, target)
          flow = self.udp_flows.get(key)
          if flow is None:
              flow_id = next(self.udp_flow_ids)

              def on_frame(path: BasePath, event: str, data: bytes | None, fid: int = flow_id) -> Awaitable[None]:
                  return self._handle_udp_path(fid, path, event, data)

              paths: list[BasePath] = [DirectUdpPath(name="direct", on_frame=on_frame, target=target)]
              for connection in self.manager.available():
                  stream_id = next(self.stream_ids)
                  paths.append(RelayUdpPath(name=connection.node.name, on_frame=on_frame, connection=connection, session_id=flow_id, stream_id=stream_id, target=target))
              flow = UdpFlow(flow_id=flow_id, source=source, target=target, send_response=listener.send_response, paths=paths)
              self.udp_flows[key] = flow
              self.udp_flows_by_id[flow_id] = flow
              print(f"[edge] udp flow={flow_id} target={target.host}:{target.port} candidates={[path.name for path in paths]}")
              await flow.start()
          await flow.send(payload)

      async def _handle_udp_path(self, flow_id: int, path: BasePath, event: str, payload: bytes | None) -> None:
          """Route a path event to its flow; events for already-collected flows are dropped."""
          flow = self.udp_flows_by_id.get(flow_id)
          if flow is not None:
              await flow.handle_path(path, event, payload)

      async def _gc_udp_flows(self) -> None:
          """Every 30 s, close flows idle for more than 120 s."""
          loop = asyncio.get_running_loop()
          while True:
              await asyncio.sleep(30)
              now = loop.time()
              stale = [key for key, flow in self.udp_flows.items() if flow.last_activity and now - flow.last_activity > 120]
              for key in stale:
                  flow = self.udp_flows.pop(key, None)
                  if flow is not None:
                      self.udp_flows_by_id.pop(flow.flow_id, None)
                      await flow.close()