# transparent_edge.py
  1. from __future__ import annotations
  2. from pathlib import Path
  3. import asyncio
  4. import contextlib
  5. import itertools
  6. import os
  7. import socket
  8. import struct
  9. from dataclasses import dataclass, field
  10. from typing import Awaitable, Callable
  11. from .config import Config
  12. from .protocol import STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, encode_json
  13. from .relay_client import RelayConnection, RelayManager
# Numeric socket-option values that the socket module does not export by name.
# NOTE(review): these look like the Linux netfilter / IP_ORIGDSTADDR numbers —
# confirm against the kernel headers of the deployment target.
SO_ORIGINAL_DST = 80
IP6T_SO_ORIGINAL_DST = 80
# Enabled on the IPv4 UDP listener socket and matched against the cmsg type
# of the ancillary data returned by recvmsg().
IP_RECVORIGDSTADDR = 20
# IPv6 counterpart of IP_RECVORIGDSTADDR, used the same way.
IPV6_RECVORIGDSTADDR = 74
@dataclass(frozen=True)
class TargetAddress:
    """Immutable, hashable description of the destination a client wanted to reach."""

    # Textual IP address (parse_sockaddr fills it via inet_ntoa / inet_ntop).
    host: str
    port: int
    # Socket address family; the visible code compares it with socket.AF_INET / AF_INET6.
    family: int
@dataclass(frozen=True)
class PeerAddress:
    """Immutable, hashable description of the client endpoint a datagram came from."""

    host: str
    port: int
    # Socket address family; the UDP listener sets socket.AF_INET or socket.AF_INET6.
    family: int
  28. def parse_sockaddr(raw: bytes) -> TargetAddress:
  29. if len(raw) < 8:
  30. raise ValueError("invalid transparent destination payload")
  31. family = struct.unpack_from("=H", raw, 0)[0]
  32. port = struct.unpack_from("!H", raw, 2)[0]
  33. if family == socket.AF_INET:
  34. host = socket.inet_ntoa(raw[4:8])
  35. return TargetAddress(host=host, port=port, family=family)
  36. if family == socket.AF_INET6:
  37. if len(raw) < 28:
  38. raise ValueError("invalid IPv6 transparent destination payload")
  39. host = socket.inet_ntop(socket.AF_INET6, raw[8:24])
  40. return TargetAddress(host=host, port=port, family=family)
  41. raise ValueError(f"unsupported family={family}")
  42. def winner_group(name: str) -> str:
  43. return "direct" if name.startswith("direct") else name
  44. def grouped_total(stats: dict[str, int], group: str) -> int:
  45. return sum(count for name, count in stats.items() if winner_group(name) == group)
class BasePath:
    """Common state and interface for one candidate route to a target.

    The base class only stores bookkeeping attributes; ``open``, ``send`` and
    ``close`` raise ``NotImplementedError`` and are meant to be overridden.
    """

    def __init__(self, name: str, on_frame: Callable[["BasePath", str, bytes | None], Awaitable[None]]) -> None:
        self.name = name
        # Subclasses in this file invoke it as on_frame(path, event, payload)
        # with event "status", "data" or "close".
        self.on_frame = on_frame
        # Set by subclasses once the route is usable.
        self.opened = False
        # Set when the route has been shut down.
        self.closed = False

    async def open(self, target: TargetAddress) -> None:
        """Establish the route to *target* (must be overridden)."""
        raise NotImplementedError

    async def send(self, data: bytes) -> None:
        """Forward *data* towards the target (must be overridden)."""
        raise NotImplementedError

    async def close(self) -> None:
        """Release the route's resources (must be overridden)."""
        raise NotImplementedError
  58. class DirectTcpPath(BasePath):
  59. def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], open_timeout: float, happy_eyeballs_delay: float | None, tcp_nodelay: bool = True) -> None:
  60. super().__init__(name, on_frame)
  61. self.reader: asyncio.StreamReader | None = None
  62. self.writer: asyncio.StreamWriter | None = None
  63. self.pump_task: asyncio.Task | None = None
  64. self.open_timeout = open_timeout
  65. self.happy_eyeballs_delay = happy_eyeballs_delay
  66. self.tcp_nodelay = tcp_nodelay
  67. async def open(self, target: TargetAddress) -> None:
  68. try:
  69. family = socket.AF_INET6 if target.family == socket.AF_INET6 else socket.AF_INET
  70. kwargs = {"host": target.host, "port": target.port, "family": family}
  71. if self.happy_eyeballs_delay is not None:
  72. kwargs["happy_eyeballs_delay"] = self.happy_eyeballs_delay
  73. self.reader, self.writer = await asyncio.wait_for(asyncio.open_connection(**kwargs), timeout=self.open_timeout)
  74. sock = self.writer.get_extra_info("socket")
  75. if sock is not None and self.tcp_nodelay:
  76. with contextlib.suppress(OSError):
  77. sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  78. self.opened = True
  79. self.pump_task = asyncio.create_task(self._pump())
  80. await self.on_frame(self, "status", b"ok")
  81. except Exception as exc:
  82. await self.on_frame(self, "status", str(exc).encode())
  83. async def _pump(self) -> None:
  84. assert self.reader is not None
  85. try:
  86. while True:
  87. chunk = await self.reader.read(65536)
  88. if not chunk:
  89. break
  90. await self.on_frame(self, "data", chunk)
  91. except Exception:
  92. pass
  93. finally:
  94. await self.on_frame(self, "close", None)
  95. async def send(self, data: bytes) -> None:
  96. if self.closed or self.writer is None:
  97. return
  98. try:
  99. self.writer.write(data)
  100. await self.writer.drain()
  101. except (BrokenPipeError, ConnectionResetError, RuntimeError, OSError, asyncio.CancelledError) as exc:
  102. await self.close()
  103. raise ConnectionError("relay closed") from exc
  104. async def close(self) -> None:
  105. if self.closed:
  106. return
  107. self.closed = True
  108. if self.pump_task and self.pump_task is not asyncio.current_task():
  109. self.pump_task.cancel()
  110. with contextlib.suppress(Exception):
  111. await self.pump_task
  112. if self.writer:
  113. self.writer.close()
  114. with contextlib.suppress(Exception):
  115. await self.writer.wait_closed()
  116. class RelayTcpPath(BasePath):
  117. def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], connection: RelayConnection, session_id: int, stream_id: int) -> None:
  118. super().__init__(name, on_frame)
  119. self.connection = connection
  120. self.session_id = session_id
  121. self.stream_id = stream_id
  122. self.unbind_task: asyncio.Task | None = None
  123. async def open(self, target: TargetAddress) -> None:
  124. if self.connection.closed:
  125. await self.on_frame(self, "status", b"relay unavailable")
  126. return
  127. self.connection.bind(self.session_id, self.stream_id, self._handle_frame)
  128. try:
  129. await self.connection.send(Frame(TCP_OPEN, self.session_id, self.stream_id, 0, 0, encode_json({"host": target.host, "port": target.port, "family": target.family})))
  130. except Exception as exc:
  131. self.connection.unbind(self.session_id, self.stream_id)
  132. await self.on_frame(self, "status", str(exc).encode())
  133. async def _handle_frame(self, _conn: RelayConnection, frame: Frame) -> None:
  134. if frame.kind == TCP_STATUS:
  135. if frame.packet_id == STATUS_OK:
  136. self.opened = True
  137. await self.on_frame(self, "status", b"ok")
  138. else:
  139. await self.on_frame(self, "status", frame.payload)
  140. return
  141. if frame.kind == TCP_DATA:
  142. await self.on_frame(self, "data", frame.payload)
  143. return
  144. if frame.kind == TCP_CLOSE:
  145. await self.on_frame(self, "close", None)
  146. async def send(self, data: bytes) -> None:
  147. if self.closed or self.connection.closed:
  148. return
  149. await self.connection.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, data))
  150. async def close(self) -> None:
  151. if self.closed:
  152. return
  153. self.closed = True
  154. if self.unbind_task is None or self.unbind_task.done():
  155. self.unbind_task = asyncio.create_task(self._delayed_unbind())
  156. if not self.connection.closed:
  157. with contextlib.suppress(Exception):
  158. await self.connection.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b""))
  159. async def _delayed_unbind(self) -> None:
  160. await asyncio.sleep(0.5)
  161. self.connection.unbind(self.session_id, self.stream_id)
  162. @dataclass
  163. class TransparentSession:
  164. session_id: int
  165. target: TargetAddress
  166. reader: asyncio.StreamReader
  167. writer: asyncio.StreamWriter
  168. paths: list[BasePath]
  169. warmup_bytes: int
  170. loser_grace_ms: int
  171. tcp_failover_idle_ms: int
  172. stats: dict[str, int]
  173. target_stats: dict[tuple[str, int], dict[str, int]]
  174. family_stats: dict[str, dict[str, int]]
  175. opened_count: int = 0
  176. status_count: int = 0
  177. errors: list[str] = field(default_factory=list)
  178. winner: BasePath | None = None
  179. uplink_bytes: int = 0
  180. open_event: asyncio.Event = field(default_factory=asyncio.Event)
  181. winner_event: asyncio.Event = field(default_factory=asyncio.Event)
  182. closed: bool = False
  183. pump_task: asyncio.Task | None = None
  184. loser_close_task: asyncio.Task | None = None
  185. open_tasks: list[asyncio.Task] = field(default_factory=list)
  186. backup_path: BasePath | None = None
  187. last_winner_data_at: float = 0.0
  188. failover_task: asyncio.Task | None = None
  189. def _select_backup_path(self, winner: BasePath) -> BasePath | None:
  190. candidates = [path for path in self.paths if path is not winner and path.opened and not path.closed]
  191. if not candidates:
  192. return None
  193. winner_is_direct = winner_group(winner.name) == "direct"
  194. # Prefer the opposite group to increase failover diversity.
  195. opposite = [path for path in candidates if (winner_group(path.name) == "direct") != winner_is_direct]
  196. pool = opposite or candidates
  197. # Keep the first eligible path as a synchronized backup.
  198. return pool[0]
  199. def _record_win(self, winner: BasePath) -> None:
  200. self.stats[winner.name] = self.stats.get(winner.name, 0) + 1
  201. key = (self.target.host, self.target.port)
  202. target_stats = self.target_stats.setdefault(key, {})
  203. target_stats[winner.name] = target_stats.get(winner.name, 0) + 1
  204. family_key = "ipv6" if self.target.family == socket.AF_INET6 else "ipv4"
  205. family_stats = self.family_stats.setdefault(family_key, {})
  206. family_stats[winner.name] = family_stats.get(winner.name, 0) + 1
  207. direct_wins = grouped_total(self.stats, "direct")
  208. relay_wins = sum(count for name, count in self.stats.items() if winner_group(name) != "direct")
  209. target_direct = grouped_total(target_stats, "direct")
  210. target_relay = sum(count for name, count in target_stats.items() if winner_group(name) != "direct")
  211. family_direct = grouped_total(family_stats, "direct")
  212. family_relay = sum(count for name, count in family_stats.items() if winner_group(name) != "direct")
  213. relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.stats.items()) if winner_group(name) != "direct") or "none"
  214. target_detail = ", ".join(f"{name}={count}" for name, count in sorted(target_stats.items()) if winner_group(name) != "direct") or "none"
  215. target_pref = "relay" if target_relay > target_direct else "direct"
  216. family_pref = "relay" if family_relay > family_direct else "direct"
  217. print(f"[edge] tcp win session={self.session_id} target={self.target.host}:{self.target.port} winner={winner.name} direct={direct_wins} relay={relay_wins} relay_breakdown={relay_detail} target_pref={target_pref} target_direct={target_direct} target_relay={target_relay} target_breakdown={target_detail} family_pref={family_pref} family={family_key} family_direct={family_direct} family_relay={family_relay}")
  218. async def start(self) -> None:
  219. self.open_tasks = [asyncio.create_task(path.open(self.target)) for path in self.paths]
  220. await asyncio.wait_for(self.open_event.wait(), timeout=8)
  221. if self.opened_count == 0:
  222. raise ConnectionError(self.errors[0] if self.errors else "all paths failed")
  223. self.pump_task = asyncio.create_task(self._pump_local())
  224. self.failover_task = asyncio.create_task(self._watch_failover())
  225. async def _pump_local(self) -> None:
  226. try:
  227. while True:
  228. chunk = await self.reader.read(65536)
  229. if not chunk:
  230. break
  231. self.uplink_bytes += len(chunk)
  232. active = [path for path in self.paths if path.opened and not path.closed]
  233. if not active:
  234. break
  235. if self.uplink_bytes <= self.warmup_bytes:
  236. await asyncio.gather(*(path.send(chunk) for path in active), return_exceptions=True)
  237. else:
  238. if self.winner is None:
  239. await self.winner_event.wait()
  240. if self.winner:
  241. send_targets = [self.winner]
  242. if self.backup_path and self.backup_path.opened and not self.backup_path.closed and self.backup_path is not self.winner:
  243. send_targets.append(self.backup_path)
  244. await asyncio.gather(*(path.send(chunk) for path in send_targets), return_exceptions=True)
  245. except Exception:
  246. pass
  247. finally:
  248. await self.close()
  249. async def _watch_failover(self) -> None:
  250. try:
  251. while not self.closed:
  252. await asyncio.sleep(0.2)
  253. if self.winner is None:
  254. continue
  255. if self.last_winner_data_at <= 0:
  256. continue
  257. idle_ms = (asyncio.get_running_loop().time() - self.last_winner_data_at) * 1000
  258. if idle_ms < 0:
  259. continue
  260. if idle_ms >= self.tcp_failover_idle_ms and self.backup_path and self.backup_path.opened and not self.backup_path.closed:
  261. old = self.winner
  262. self.winner = self.backup_path
  263. self.backup_path = self._select_backup_path(self.winner)
  264. self.last_winner_data_at = asyncio.get_running_loop().time()
  265. self._record_win(self.winner)
  266. print(
  267. f"[edge] tcp failover session={self.session_id} target={self.target.host}:{self.target.port} "
  268. f"old_winner={old.name if old else 'none'} new_winner={self.winner.name} idle_ms={int(idle_ms)}"
  269. )
  270. except Exception:
  271. pass
  272. async def handle_path(self, path: BasePath, event: str, payload: bytes | None) -> None:
  273. if self.closed:
  274. return
  275. if event == "status":
  276. self.status_count += 1
  277. if payload == b"ok":
  278. self.opened_count += 1
  279. elif payload is not None:
  280. self.errors.append(payload.decode("utf-8", errors="replace"))
  281. if self.opened_count > 0 or self.status_count == len(self.paths):
  282. self.open_event.set()
  283. return
  284. if event == "data":
  285. if self.winner is None:
  286. self.winner = path
  287. self._record_win(path)
  288. self.backup_path = self._select_backup_path(path)
  289. self.winner_event.set()
  290. if self.loser_grace_ms > 0:
  291. self.loser_close_task = asyncio.create_task(self._close_losers_after_grace(path))
  292. else:
  293. await self._close_losers(path)
  294. self.last_winner_data_at = asyncio.get_running_loop().time()
  295. if path is self.winner and payload is not None:
  296. self.writer.write(payload)
  297. await self.writer.drain()
  298. return
  299. if event == "close":
  300. path.closed = True
  301. if self.winner is None:
  302. remaining = [candidate for candidate in self.paths if candidate.opened and not candidate.closed]
  303. if not remaining:
  304. await self.close()
  305. elif path is self.winner:
  306. await self.close()
  307. async def _close_losers(self, winner: BasePath) -> None:
  308. await asyncio.gather(*(path.close() for path in self.paths if path is not winner and path is not self.backup_path), return_exceptions=True)
  309. async def _close_losers_after_grace(self, winner: BasePath) -> None:
  310. await asyncio.sleep(self.loser_grace_ms / 1000)
  311. if not self.closed:
  312. await self._close_losers(winner)
  313. async def close(self) -> None:
  314. if self.closed:
  315. return
  316. self.closed = True
  317. if self.errors:
  318. detail = ", ".join(self.errors[:3])
  319. print(
  320. f"[edge] session={self.session_id} closed target={self.target.host}:{self.target.port} "
  321. f"errors={len(self.errors)} detail={detail}"
  322. )
  323. if self.pump_task and self.pump_task is not asyncio.current_task():
  324. self.pump_task.cancel()
  325. with contextlib.suppress(Exception):
  326. await self.pump_task
  327. if self.loser_close_task and self.loser_close_task is not asyncio.current_task():
  328. self.loser_close_task.cancel()
  329. with contextlib.suppress(Exception):
  330. await self.loser_close_task
  331. if self.failover_task and self.failover_task is not asyncio.current_task():
  332. self.failover_task.cancel()
  333. with contextlib.suppress(Exception):
  334. await self.failover_task
  335. for task in self.open_tasks:
  336. if task is not asyncio.current_task():
  337. task.cancel()
  338. for task in self.open_tasks:
  339. if task is not asyncio.current_task():
  340. with contextlib.suppress(Exception):
  341. await task
  342. await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True)
  343. self.writer.close()
  344. with contextlib.suppress(Exception):
  345. await self.writer.wait_closed()
  346. class DirectUdpPath(BasePath):
  347. def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], target: TargetAddress) -> None:
  348. super().__init__(name, on_frame)
  349. self.target = target
  350. self.socket: socket.socket | None = None
  351. self.read_task: asyncio.Task | None = None
  352. async def open(self, _target: TargetAddress) -> None:
  353. try:
  354. family = socket.AF_INET6 if self.target.family == socket.AF_INET6 else socket.AF_INET
  355. self.socket = socket.socket(family, socket.SOCK_DGRAM)
  356. self.socket.setblocking(False)
  357. await asyncio.get_running_loop().sock_connect(self.socket, (self.target.host, self.target.port))
  358. self.opened = True
  359. self.read_task = asyncio.create_task(self._pump())
  360. await self.on_frame(self, "status", b"ok")
  361. except Exception as exc:
  362. await self.on_frame(self, "status", str(exc).encode())
  363. async def _pump(self) -> None:
  364. assert self.socket is not None
  365. loop = asyncio.get_running_loop()
  366. try:
  367. while True:
  368. data = await loop.sock_recv(self.socket, 65535)
  369. if not data:
  370. break
  371. await self.on_frame(self, "data", data)
  372. except Exception:
  373. pass
  374. finally:
  375. await self.on_frame(self, "close", None)
  376. async def send(self, data: bytes) -> None:
  377. if self.closed or self.socket is None:
  378. return
  379. await asyncio.get_running_loop().sock_sendall(self.socket, data)
  380. async def close(self) -> None:
  381. if self.closed:
  382. return
  383. self.closed = True
  384. if self.read_task and self.read_task is not asyncio.current_task():
  385. self.read_task.cancel()
  386. with contextlib.suppress(Exception):
  387. await self.read_task
  388. if self.socket:
  389. self.socket.close()
  390. class RelayUdpPath(BasePath):
  391. def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], connection: RelayConnection, session_id: int, stream_id: int, target: TargetAddress) -> None:
  392. super().__init__(name, on_frame)
  393. self.connection = connection
  394. self.session_id = session_id
  395. self.stream_id = stream_id
  396. self.target = target
  397. self.unbind_task: asyncio.Task | None = None
  398. async def open(self, _target: TargetAddress) -> None:
  399. if self.connection.closed:
  400. await self.on_frame(self, "status", b"relay unavailable")
  401. return
  402. self.connection.bind(self.session_id, self.stream_id, self._handle_frame)
  403. try:
  404. self.opened = True
  405. await self.on_frame(self, "status", b"ok")
  406. except Exception:
  407. self.connection.unbind(self.session_id, self.stream_id)
  408. self.closed = True
  409. raise
  410. async def _handle_frame(self, _conn: RelayConnection, frame: Frame) -> None:
  411. if frame.kind == UDP_RECV:
  412. await self.on_frame(self, "data", frame.payload)
  413. async def send(self, data: bytes) -> None:
  414. if self.closed or self.connection.closed:
  415. return
  416. meta = encode_json({"host": self.target.host, "port": self.target.port, "family": self.target.family})
  417. payload = meta + data
  418. try:
  419. await self.connection.send(Frame(UDP_SEND, self.session_id, self.stream_id, 0, len(meta), payload))
  420. except Exception:
  421. self.closed = True
  422. raise
  423. async def close(self) -> None:
  424. if self.closed:
  425. return
  426. self.closed = True
  427. if self.unbind_task is None or self.unbind_task.done():
  428. self.unbind_task = asyncio.create_task(self._delayed_unbind())
  429. async def _delayed_unbind(self) -> None:
  430. await asyncio.sleep(0.5)
  431. self.connection.unbind(self.session_id, self.stream_id)
  432. @dataclass
  433. class UdpFlow:
  434. flow_id: int
  435. source: PeerAddress
  436. target: TargetAddress
  437. send_response: Callable[[PeerAddress, bytes], Awaitable[None]]
  438. paths: list[BasePath]
  439. redundancy: int = 0
  440. always_broadcast: bool = True
  441. copy_interval_ms: int = 0
  442. winner: BasePath | None = None
  443. closed: bool = False
  444. last_activity: float = 0.0
  445. packets_sent: int = 0
  446. packets_received: int = 0
  447. duplicate_responses: int = 0
  448. send_task: asyncio.Task | None = None
  449. winner_burst_sent: int = 0
  450. converged: bool = False
  451. async def start(self) -> None:
  452. await asyncio.gather(*(path.open(self.target) for path in self.paths), return_exceptions=True)
  453. async def send(self, payload: bytes) -> None:
  454. self.last_activity = asyncio.get_running_loop().time()
  455. self.packets_sent += 1
  456. active = [path for path in self.paths if path.opened and not path.closed]
  457. if not active:
  458. return
  459. copies = max(1, self.redundancy + 1)
  460. if self.winner is None or self.winner.closed:
  461. self.converged = False
  462. self.winner_burst_sent = 0
  463. targets = active
  464. elif not self.converged:
  465. # 先并发、后收敛:winner 刚出现时保留短暂重叠,随后快速收敛到单路径。
  466. self.winner_burst_sent += 1
  467. backup = [path for path in active if path is not self.winner][:1]
  468. targets = [self.winner, *backup] if self.winner_burst_sent <= 2 else [self.winner]
  469. if self.winner_burst_sent > 2:
  470. self.converged = True
  471. else:
  472. targets = [self.winner]
  473. for attempt in range(copies):
  474. await asyncio.gather(*(path.send(payload) for path in targets), return_exceptions=True)
  475. if attempt + 1 < copies and self.copy_interval_ms > 0:
  476. await asyncio.sleep(self.copy_interval_ms / 1000)
  477. async def handle_path(self, path: BasePath, event: str, payload: bytes | None) -> None:
  478. self.last_activity = asyncio.get_running_loop().time()
  479. if event == "data" and payload is not None:
  480. self.packets_received += 1
  481. if self.winner is None:
  482. self.winner = path
  483. self.converged = False
  484. self.winner_burst_sent = 0
  485. mode = "redundant" if self.redundancy > 0 else "single"
  486. print(f"[edge] udp flow={self.flow_id} winner={path.name} target={self.target.host}:{self.target.port} mode={mode} candidates={len(self.paths)}")
  487. elif path is not self.winner:
  488. self.duplicate_responses += 1
  489. if path is self.winner:
  490. await self.send_response(self.source, payload)
  491. if event == "close":
  492. path.closed = True
  493. if path is self.winner:
  494. remaining = [candidate for candidate in self.paths if candidate.opened and not candidate.closed]
  495. self.winner = remaining[0] if remaining else None
  496. self.converged = False
  497. self.winner_burst_sent = 0
  498. async def close(self) -> None:
  499. if self.closed:
  500. return
  501. self.closed = True
  502. if self.send_task and self.send_task is not asyncio.current_task():
  503. self.send_task.cancel()
  504. with contextlib.suppress(Exception):
  505. await self.send_task
  506. await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True)
  507. class TransparentUdpListener:
  508. def __init__(self, edge: "TransparentEdge", family: int, bind_host: str, port: int) -> None:
  509. self.edge = edge
  510. self.family = family
  511. self.bind_host = bind_host
  512. self.port = port
  513. self.socket: socket.socket | None = None
  514. self.udp_packets_received = 0
  515. self.udp_recv_errors = 0
  516. self.udp_parse_errors = 0
  517. self.udp_missing_original = 0
  518. self.udp_self_loop_skipped = 0
  519. self.udp_flows_created = 0
  520. self.last_summary_at = 0.0
  521. def start(self) -> None:
  522. sock = socket.socket(self.family, socket.SOCK_DGRAM)
  523. sock.setblocking(False)
  524. if self.family == socket.AF_INET:
  525. sock.setsockopt(socket.SOL_IP, IP_RECVORIGDSTADDR, 1)
  526. sock.bind((self.bind_host, self.port))
  527. else:
  528. sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
  529. sock.setsockopt(socket.IPPROTO_IPV6, IPV6_RECVORIGDSTADDR, 1)
  530. sock.bind((self.bind_host, self.port, 0, 0))
  531. self.socket = sock
  532. asyncio.get_running_loop().add_reader(sock.fileno(), self._on_readable)
  533. print(f"[edge] transparent udp listening on {sock.getsockname()}")
  534. def _log_udp_summary(self, force: bool = False) -> None:
  535. now = asyncio.get_running_loop().time()
  536. if not force and now - self.last_summary_at < 10:
  537. return
  538. self.last_summary_at = now
  539. print(
  540. f"[edge] udp summary family={self.family} bind={self.bind_host}:{self.port} "
  541. f"received={self.udp_packets_received} flows={self.udp_flows_created} "
  542. f"self_loop={self.udp_self_loop_skipped} missing_original={self.udp_missing_original} "
  543. f"parse_error={self.udp_parse_errors} recv_error={self.udp_recv_errors}"
  544. )
  545. def _on_readable(self) -> None:
  546. assert self.socket is not None
  547. try:
  548. data, ancdata, _flags, src = self.socket.recvmsg(65535, 512)
  549. except BlockingIOError:
  550. return
  551. except Exception as exc:
  552. self.udp_recv_errors += 1
  553. print(f"[edge] udp recvmsg error family={self.family} error={exc!r}")
  554. self._log_udp_summary(force=True)
  555. return
  556. self.udp_packets_received += 1
  557. original = None
  558. for level, ctype, cdata in ancdata:
  559. if self.family == socket.AF_INET and level == socket.SOL_IP and ctype == IP_RECVORIGDSTADDR:
  560. try:
  561. original = parse_sockaddr(cdata)
  562. except Exception as exc:
  563. self.udp_parse_errors += 1
  564. print(f"[edge] udp parse original dst error family={self.family} src={src} error={exc!r} raw_len={len(cdata)}")
  565. self._log_udp_summary(force=True)
  566. return
  567. break
  568. if self.family == socket.AF_INET6 and level == socket.IPPROTO_IPV6 and ctype == IPV6_RECVORIGDSTADDR:
  569. try:
  570. original = parse_sockaddr(cdata)
  571. except Exception as exc:
  572. self.udp_parse_errors += 1
  573. print(f"[edge] udp parse original dst error family={self.family} src={src} error={exc!r} raw_len={len(cdata)}")
  574. self._log_udp_summary(force=True)
  575. return
  576. break
  577. if original is None:
  578. self.udp_missing_original += 1
  579. self._log_udp_summary()
  580. return
  581. if self.family == socket.AF_INET:
  582. source = PeerAddress(host=src[0], port=src[1], family=socket.AF_INET)
  583. else:
  584. source = PeerAddress(host=src[0], port=src[1], family=socket.AF_INET6)
  585. if original.port == self.port and (original.host in ("127.0.0.1", "::1") or original.host == self.bind_host):
  586. self.udp_self_loop_skipped += 1
  587. print(
  588. f"[edge] udp self_loop family={self.family} src={source.host}:{source.port} "
  589. f"original={original.host}:{original.port} size={len(data)}"
  590. )
  591. self._log_udp_summary()
  592. return
  593. asyncio.create_task(self.edge.handle_udp_datagram(source, original, data, self))
  594. async def send_response(self, source: PeerAddress, payload: bytes) -> None:
  595. assert self.socket is not None
  596. if source.family == socket.AF_INET:
  597. self.socket.sendto(payload, (source.host, source.port))
  598. else:
  599. self.socket.sendto(payload, (source.host, source.port, 0, 0))
  600. async def close(self) -> None:
  601. if self.socket is None:
  602. return
  603. asyncio.get_running_loop().remove_reader(self.socket.fileno())
  604. self.socket.close()
  605. self.socket = None
  606. class TransparentEdge:
  607. def __init__(self, listen_host: str, listen_port: int, config: Config, enable_udp: bool = False, kernel_mode: str = "auto") -> None:
  608. self.listen_host = listen_host
  609. self.listen_port = listen_port
  610. self.config = config
  611. self.enable_udp = enable_udp
  612. self.kernel_mode = self._resolve_kernel_mode(kernel_mode, config.kernel_mode)
  613. self.manager = RelayManager(config)
  614. self.session_ids = itertools.count(1)
  615. self.stream_ids = itertools.count(1)
  616. self.udp_listeners: list[TransparentUdpListener] = []
  617. self.udp_flows: dict[tuple[PeerAddress, TargetAddress], UdpFlow] = {}
  618. self.udp_flow_ids = itertools.count(1)
  619. self.udp_gc_task: asyncio.Task | None = None
  620. self.tcp_win_counts: dict[str, int] = {}
  621. self.tcp_target_wins: dict[tuple[str, int], dict[str, int]] = {}
  622. self.tcp_family_wins: dict[str, dict[str, int]] = {"ipv4": {}, "ipv6": {}}
  623. def _resolve_kernel_mode(self, cli_kernel_mode: str, config_kernel_mode: str) -> str:
  624. mode = cli_kernel_mode if cli_kernel_mode != "auto" else config_kernel_mode
  625. if mode != "auto":
  626. return mode
  627. try:
  628. if Path("/etc/os-release").exists() and 'VERSION_ID="24' in Path("/etc/os-release").read_text(errors="ignore"):
  629. return "24"
  630. except Exception:
  631. pass
  632. try:
  633. release = os.uname().release
  634. if release.startswith("6."):
  635. return "24"
  636. except Exception:
  637. pass
  638. return "20"
  async def start(self) -> None:
      """Start the relay manager and listeners, then serve TCP forever.

      In kernel mode ``"24"`` the open timeouts are lowered from 10.0 to 6.0
      (only when still exactly 10.0) and an unset happy-eyeballs delay becomes
      0.25. An IPv6 TCP listener is attempted for a small set of well-known
      listen hosts; failure to create it is logged and ignored. UDP listeners
      and the UDP flow GC task are started only when UDP is enabled.
      """
      if self.kernel_mode == "24":
          cfg = self.config
          if cfg.direct_open_timeout == 10.0:
              cfg.direct_open_timeout = 6.0
          if cfg.relay_open_timeout == 10.0:
              cfg.relay_open_timeout = 6.0
          if cfg.tcp_connect_happy_eyeballs_delay is None:
              cfg.tcp_connect_happy_eyeballs_delay = 0.25

      await self.manager.start()
      print(f"[edge] kernel_mode={self.kernel_mode} relay snapshot: {self.manager.snapshot()}")

      # NOTE(review): the IPv4 listener is always created with
      # family=AF_INET, even when listen_host is "::" or "::1" — confirm
      # that IPv6 literals are expected to work here.
      server4 = await asyncio.start_server(
          self._accept,
          self.listen_host,
          self.listen_port,
          family=socket.AF_INET,
      )
      sockets = [str(sock.getsockname()) for sock in server4.sockets or []]

      server6 = None
      if self.listen_host in ("::", "::1", "0.0.0.0", "127.0.0.1"):
          # Loopback v4 pairs with loopback v6; everything else with "::".
          if self.listen_host == "127.0.0.1":
              host6 = "::1"
          else:
              host6 = "::"
          try:
              server6 = await asyncio.start_server(
                  self._accept,
                  host6,
                  self.listen_port,
                  family=socket.AF_INET6,
              )
              sockets.extend(str(sock.getsockname()) for sock in server6.sockets or [])
          except Exception as exc:
              print(f"[edge] ipv6 tcp listener skipped: {exc!r}")

      if self.enable_udp:
          self._start_udp_listeners()
          self.udp_gc_task = asyncio.create_task(self._gc_udp_flows())

      print(f"[edge] transparent tcp listening on {', '.join(sockets)}")

      if server6 is not None:
          async with server4, server6:
              await asyncio.gather(server4.serve_forever(), server6.serve_forever())
          return
      async with server4:
          await server4.serve_forever()
  def _direct_redundancy_for_target(self, target: TargetAddress) -> int:
      """Return how many direct TCP paths to open for ``target``.

      Returns 0 for an IPv6 target when direct IPv6 is disabled. Otherwise the
      configured redundancy (the per-family override when set) is capped at
      ``direct_max_redundancy`` and floored at 1, then adjusted from the win
      history: one less (never below 1) when relays win more often than
      direct paths, or one less when direct paths win more often and the
      base is above 2. History only counts once it has at least 4 samples
      for the target or 8 for the address family.
      """
      cfg = self.config
      is_v6 = target.family == socket.AF_INET6
      if is_v6 and not cfg.direct_ipv6_enabled:
          return 0

      configured = cfg.direct_redundancy
      if is_v6:
          override = cfg.direct_redundancy_v6
      elif target.family == socket.AF_INET:
          override = cfg.direct_redundancy_v4
      else:
          override = None
      if override is not None:
          configured = override
      base = max(1, min(configured, cfg.direct_max_redundancy))

      def leaning(stats, min_samples):
          """Return (prefers_relay, prefers_direct) for one tally dict."""
          relay_wins = sum(
              count for name, count in stats.items() if winner_group(name) != "direct"
          )
          if sum(stats.values()) < min_samples:
              return False, False
          direct_wins = grouped_total(stats, "direct")
          return relay_wins > direct_wins, direct_wins > relay_wins

      per_target = self.tcp_target_wins.get((target.host, target.port), {})
      per_family = self.tcp_family_wins.get("ipv6" if is_v6 else "ipv4", {})
      target_relay, target_direct = leaning(per_target, 4)
      family_relay, family_direct = leaning(per_family, 8)

      if target_relay or family_relay:
          return max(1, base - 1)
      if (target_direct or family_direct) and base > 2:
          return base - 1
      return base
  def _build_direct_paths(self, session: TransparentSession) -> list[BasePath]:
      """Create the direct TCP path candidates for ``session``.

      The number of paths comes from ``_direct_redundancy_for_target``; an
      empty list is returned when that is zero or negative. A single path is
      named ``direct``; several are named ``direct-1``, ``direct-2``, ...
      Every path reports its events to ``_handle_tcp_session`` for this
      session.
      """
      count = self._direct_redundancy_for_target(session.target)
      if count <= 0:
          return []

      if count > 1:
          names = [f"direct-{index + 1}" for index in range(count)]
      else:
          names = ["direct"]

      cfg = self.config
      return [
          DirectTcpPath(
              name=name,
              on_frame=lambda path, event, payload, s=session: self._handle_tcp_session(s, path, event, payload),
              open_timeout=cfg.direct_open_timeout,
              happy_eyeballs_delay=cfg.tcp_connect_happy_eyeballs_delay,
              # The direct paths take their nodelay flag from the relay setting.
              tcp_nodelay=cfg.relay_tcp_nodelay,
          )
          for name in names
      ]
  def _build_udp_direct_paths(self, target: TargetAddress, flow_id: int) -> list[BasePath]:
      """Create the direct UDP path candidates for one flow.

      Returns an empty list for an IPv6 target when direct IPv6 is disabled.
      Otherwise at least one path is built; the count comes from
      ``udp_direct_redundancy`` or the matching per-family override when that
      is set. A single path is named ``direct``; several are named
      ``direct-1``, ``direct-2``, ... Path events are routed to
      ``_handle_udp_path`` with ``flow_id``.
      """
      cfg = self.config
      is_v6 = target.family == socket.AF_INET6
      if is_v6 and not cfg.direct_ipv6_enabled:
          return []

      count = max(1, cfg.udp_direct_redundancy)
      if is_v6:
          override = cfg.udp_direct_redundancy_v6
      elif target.family == socket.AF_INET:
          override = cfg.udp_direct_redundancy_v4
      else:
          override = None
      if override is not None:
          count = max(1, override)

      if count > 1:
          names = [f"direct-{index + 1}" for index in range(count)]
      else:
          names = ["direct"]

      return [
          DirectUdpPath(
              name=name,
              on_frame=lambda path, event, data, fid=flow_id: self._handle_udp_path(fid, path, event, data),
              target=target,
          )
          for name in names
      ]
  def _start_udp_listeners(self) -> None:
      """Start one UDP listener per bind address derived from ``listen_host``.

      ``127.0.0.1`` and ``0.0.0.0`` each expand to an IPv4 plus the matching
      IPv6 address; any other host gets a single listener whose family is
      chosen by whether the host contains a colon. A listener that fails to
      construct or start is logged and skipped.
      """
      dual_stack = {
          "127.0.0.1": [(socket.AF_INET, "127.0.0.1"), (socket.AF_INET6, "::1")],
          "0.0.0.0": [(socket.AF_INET, "0.0.0.0"), (socket.AF_INET6, "::")],
      }
      binds = dual_stack.get(self.listen_host)
      if binds is None:
          if ":" in self.listen_host:
              single_family = socket.AF_INET6
          else:
              single_family = socket.AF_INET
          binds = [(single_family, self.listen_host)]

      for family, host in binds:
          try:
              listener = TransparentUdpListener(self, family, host, self.listen_port)
              listener.start()
          except Exception as exc:
              print(f"[edge] udp listener skipped family={family} host={host} error={exc!r}")
          else:
              self.udp_listeners.append(listener)
  async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
      """Handle one accepted TCP connection.

      Looks up the connection's original destination, builds a
      ``TransparentSession`` with direct paths plus one relay path per
      available relay connection, and starts the session. Any exception is
      logged and the client writer is closed.
      """
      peer = writer.get_extra_info("peername")
      try:
          target = self._get_original_dst(writer)
          session_id = next(self.session_ids)
          session = TransparentSession(
              session_id=session_id,
              target=target,
              reader=reader,
              writer=writer,
              paths=[],
              warmup_bytes=self.config.tcp_warmup_bytes,
              loser_grace_ms=self.config.tcp_loser_grace_ms,
              tcp_failover_idle_ms=self.config.tcp_failover_idle_ms,
              stats=self.tcp_win_counts,
              target_stats=self.tcp_target_wins,
              family_stats=self.tcp_family_wins,
          )

          def on_frame(path, event, payload, s=session):
              # Route every relay path event back to this session.
              return self._handle_tcp_session(s, path, event, payload)

          paths = self._build_direct_paths(session)
          for connection in self.manager.available():
              stream_id = next(self.stream_ids)
              relay_path = RelayTcpPath(
                  name=connection.node.name,
                  on_frame=on_frame,
                  connection=connection,
                  session_id=session_id,
                  stream_id=stream_id,
              )
              paths.append(relay_path)
          session.paths = paths

          print(f"[edge] accept peer={peer} session={session_id} target={target.host}:{target.port} candidates={[path.name for path in paths]}")
          await session.start()
      except Exception as exc:
          print(f"[edge] accept failed peer={peer} error={exc!r}")
          writer.close()
          # Errors while waiting for the close to complete are ignored.
          with contextlib.suppress(Exception):
              await writer.wait_closed()
  async def _handle_tcp_session(
      self,
      session: TransparentSession,
      path: BasePath,
      event: str,
      payload: bytes | None,
  ) -> None:
      """Forward a path event to ``session.handle_path``.

      This is the ``on_frame`` target installed on the TCP paths; it adds no
      logic of its own.
      """
      await session.handle_path(path, event, payload)
  def _get_original_dst(self, writer: asyncio.StreamWriter) -> TargetAddress:
      """Return the original destination of an accepted connection.

      Reads the raw address with ``getsockopt`` (``SO_ORIGINAL_DST`` for
      IPv4, ``IP6T_SO_ORIGINAL_DST`` for IPv6) and parses it with
      ``parse_sockaddr``.

      Raises:
          RuntimeError: If the writer exposes no socket, or the socket family
              is neither ``AF_INET`` nor ``AF_INET6``.
      """
      sock = writer.get_extra_info("socket")
      if sock is None:
          raise RuntimeError("socket unavailable")

      family = sock.family
      # Constants are resolved only in the branch that needs them.
      if family == socket.AF_INET:
          level, option, buffer_size = socket.SOL_IP, SO_ORIGINAL_DST, 16
      elif family == socket.AF_INET6:
          level, option, buffer_size = socket.IPPROTO_IPV6, IP6T_SO_ORIGINAL_DST, 128
      else:
          raise RuntimeError(f"unsupported socket family={family}")

      raw = sock.getsockopt(level, option, buffer_size)
      return parse_sockaddr(raw)
  async def handle_udp_datagram(
      self,
      source: PeerAddress,
      target: TargetAddress,
      payload: bytes,
      listener: TransparentUdpListener,
  ) -> None:
      """Send one intercepted UDP datagram through its flow, creating it if new.

      Nothing happens when UDP is disabled or when the datagram is addressed
      to this edge's own listen port on ``127.0.0.1``, ``::1`` or the listen
      host. Flows are keyed by ``(source, target)``. A new flow gets direct
      paths plus one relay path per available relay connection, is
      registered and started, and then the payload is sent through it.
      """
      if not self.enable_udp:
          return
      own_hosts = ("127.0.0.1", "::1", self.listen_host)
      if target.port == self.listen_port and target.host in own_hosts:
          return

      key = (source, target)
      flow = self.udp_flows.get(key)
      if flow is not None:
          await flow.send(payload)
          return

      flow_id = next(self.udp_flow_ids)

      def on_frame(path, event, data, fid=flow_id):
          # Route every relay path event back to this flow by id.
          return self._handle_udp_path(fid, path, event, data)

      paths = self._build_udp_direct_paths(target, flow_id)
      for connection in self.manager.available():
          stream_id = next(self.stream_ids)
          relay_path = RelayUdpPath(
              name=connection.node.name,
              on_frame=on_frame,
              connection=connection,
              session_id=flow_id,
              stream_id=stream_id,
              target=target,
          )
          paths.append(relay_path)

      flow = UdpFlow(
          flow_id=flow_id,
          source=source,
          target=target,
          send_response=listener.send_response,
          paths=paths,
          redundancy=self.config.udp_redundancy,
          always_broadcast=self.config.udp_always_broadcast,
          copy_interval_ms=self.config.udp_copy_interval_ms,
      )
      self.udp_flows[key] = flow
      listener.udp_flows_created += 1
      listener._log_udp_summary(force=True)
      print(f"[edge] udp flow={flow_id} source={source.host}:{source.port} target={target.host}:{target.port} redundancy={self.config.udp_redundancy} direct_redundancy={self.config.udp_direct_redundancy} always_broadcast={self.config.udp_always_broadcast} candidates={[path.name for path in paths]}")

      await flow.start()
      await flow.send(payload)
  async def _handle_udp_path(
      self,
      flow_id: int,
      path: BasePath,
      event: str,
      payload: bytes | None,
  ) -> None:
      """Deliver a path event to the registered flow with this ``flow_id``.

      The first matching flow in ``self.udp_flows`` receives the event; if
      none matches, the event is dropped.
      """
      # The lookup completes before the await, so the dict cannot change
      # underneath the scan.
      owner = next(
          (flow for flow in self.udp_flows.values() if flow.flow_id == flow_id),
          None,
      )
      if owner is not None:
          await owner.handle_path(path, event, payload)
  async def _gc_udp_flows(self) -> None:
      """Periodically evict and close idle UDP flows; runs until cancelled.

      Every 30 seconds, flows whose ``last_activity`` is set and more than
      120 seconds older than the loop clock are removed from
      ``self.udp_flows`` and closed. Flows with a falsy ``last_activity`` are
      never evicted here.

      A failure while closing one flow is logged and does not stop the loop:
      this coroutine runs as a background task, so an escaping exception
      would silently end all further collection.
      """
      loop = asyncio.get_running_loop()
      while True:
          await asyncio.sleep(30)
          now = loop.time()
          stale = [
              key
              for key, flow in self.udp_flows.items()
              if flow.last_activity and now - flow.last_activity > 120
          ]
          for key in stale:
              flow = self.udp_flows.pop(key, None)
              if not flow:
                  continue
              try:
                  await flow.close()
              except Exception as exc:
                  # CancelledError is not an Exception subclass, so task
                  # cancellation still propagates out of the loop.
                  print(f"[edge] udp flow={flow.flow_id} close failed error={exc!r}")