# edge_tcp.py (22 KB)
  1. from __future__ import annotations
  2. from pathlib import Path
  3. import asyncio
  4. import contextlib
  5. import itertools
  6. import os
  7. import socket
  8. import struct
  9. from dataclasses import dataclass, field
  10. from typing import Awaitable, Callable
  11. from .config_tcp import TcpConfig
  12. from .protocol import STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, Frame, encode_json
  13. from .relay_client_tcp import TcpRelayConnection, TcpRelayManager
  # Socket option number passed to getsockopt() at the SOL_IP level to read the
  # original destination of an accepted IPv4 connection.
  SO_ORIGINAL_DST: int = 80
  # Counterpart passed at the IPPROTO_IPV6 level; it has the same numeric value.
  IP6T_SO_ORIGINAL_DST: int = 80
  # Swallowed while tearing connections down: ordinary errors and task
  # cancellation alike, so cleanup code keeps going.
  SUPPRESSED_CLOSE_EXCEPTIONS = (Exception, asyncio.CancelledError)
  @dataclass(frozen=True)
  class TargetAddress:
      """Immutable record of the endpoint a client connection was aimed at."""

      # Textual IP address (IPv4 dotted quad or IPv6 notation).
      host: str
      # TCP port number.
      port: int
      # Address family as a ``socket.AF_*`` integer.
      family: int
  22. def parse_sockaddr(raw: bytes) -> TargetAddress:
  23. if len(raw) < 8:
  24. raise ValueError("invalid transparent destination payload")
  25. family = struct.unpack_from("=H", raw, 0)[0]
  26. port = struct.unpack_from("!H", raw, 2)[0]
  27. if family == socket.AF_INET:
  28. return TargetAddress(host=socket.inet_ntoa(raw[4:8]), port=port, family=family)
  29. if family == socket.AF_INET6:
  30. if len(raw) < 28:
  31. raise ValueError("invalid IPv6 transparent destination payload")
  32. return TargetAddress(host=socket.inet_ntop(socket.AF_INET6, raw[8:24]), port=port, family=family)
  33. raise ValueError(f"unsupported family={family}")
  def winner_group(name: str) -> str:
      """Map a path name to its statistics group.

      Any name beginning with ``"direct"`` collapses into the single group
      ``"direct"``; every other name is its own group.
      """
      if name.startswith("direct"):
          return "direct"
      return name
  def grouped_total(stats: dict[str, int], group: str) -> int:
      """Add up the counters in ``stats`` whose name falls into ``group``.

      Group membership is decided by ``winner_group``.  An empty table yields 0.
      """
      total = 0
      for path_name, wins in stats.items():
          if winner_group(path_name) == group:
              total += wins
      return total
  class BasePath:
      """Interface and shared state for one candidate upstream path.

      Attributes:
          name: label used for this path in logs and win statistics.
          on_frame: coroutine callback invoked as ``on_frame(path, event, payload)``.
          opened: starts ``False``; subclasses set it once the path is usable.
          closed: starts ``False``; set once the path has been shut down.
      """

      def __init__(self, name: str, on_frame: Callable[["BasePath", str, bytes | None], Awaitable[None]]) -> None:
          self.name, self.on_frame = name, on_frame
          self.opened = self.closed = False

      async def open(self, target: TargetAddress) -> None:
          """Establish the path towards ``target``; subclasses must override."""
          raise NotImplementedError

      async def send(self, data: bytes) -> None:
          """Forward ``data`` upstream; subclasses must override."""
          raise NotImplementedError

      async def close(self) -> None:
          """Shut the path down; subclasses must override."""
          raise NotImplementedError
  class DirectTcpPath(BasePath):
      """Path that opens its own TCP connection straight to the target.

      Events are reported through ``on_frame``: ``"status"`` once after the
      connect attempt (payload ``b"ok"`` or an error text), ``"data"`` for every
      chunk read from the target, and ``"close"`` when the read loop ends.
      """

      def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], open_timeout: float, happy_eyeballs_delay: float | None, tcp_nodelay: bool = True) -> None:
          super().__init__(name, on_frame)
          self.reader: asyncio.StreamReader | None = None
          self.writer: asyncio.StreamWriter | None = None
          self.pump_task: asyncio.Task | None = None
          self.open_timeout = open_timeout
          self.happy_eyeballs_delay = happy_eyeballs_delay
          self.tcp_nodelay = tcp_nodelay

      async def open(self, target: TargetAddress) -> None:
          """Connect to ``target`` and report the outcome as a ``"status"`` event.

          Never raises for ordinary failures: any ``Exception`` is turned into a
          status payload carrying the error text.
          """
          try:
              family = socket.AF_INET6 if target.family == socket.AF_INET6 else socket.AF_INET
              kwargs = {"host": target.host, "port": target.port, "family": family}
              if self.happy_eyeballs_delay is not None:
                  kwargs["happy_eyeballs_delay"] = self.happy_eyeballs_delay
              self.reader, self.writer = await asyncio.wait_for(asyncio.open_connection(**kwargs), timeout=self.open_timeout)
              sock = self.writer.get_extra_info("socket")
              if sock is not None and self.tcp_nodelay:
                  # Best effort: a socket that rejects the option is still usable.
                  with contextlib.suppress(OSError):
                      sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
              self.opened = True
              self.pump_task = asyncio.create_task(self._pump())
              await self.on_frame(self, "status", b"ok")
          except Exception as exc:
              # Some exceptions (for instance the timeout raised by
              # asyncio.wait_for) carry no message; fall back to the class name
              # so the reported reason is never an empty string.
              reason = str(exc) or type(exc).__name__
              await self.on_frame(self, "status", reason.encode())

      async def _pump(self) -> None:
          """Forward target-to-client chunks until EOF or a socket error."""
          assert self.reader is not None
          try:
              while True:
                  try:
                      chunk = await self.reader.read(65536)
                  except OSError:
                      # Covers ConnectionResetError and BrokenPipeError as well.
                      break
                  if not chunk:
                      break
                  await self.on_frame(self, "data", chunk)
          finally:
              await self.on_frame(self, "close", None)

      async def send(self, data: bytes) -> None:
          """Write ``data`` to the target; a closed or unopened path drops it.

          Raises:
              ConnectionError: if the write fails; the path is closed first.
              asyncio.CancelledError: re-raised unchanged (after closing the
                  path) so that task cancellation is not swallowed.
          """
          writer = self.writer
          if self.closed or writer is None:
              return
          try:
              writer.write(data)
              await writer.drain()
          except asyncio.CancelledError:
              await self.close()
              raise
          except (OSError, RuntimeError) as exc:
              await self.close()
              raise ConnectionError("relay closed") from exc

      async def close(self) -> None:
          """Close the target connection; safe to call repeatedly.

          Idempotence is keyed on ownership of the writer, not on ``closed``:
          that flag is a public attribute and ``TcpSession.handle_path`` sets it
          itself when this path reports ``"close"``.  Guarding on the flag would
          make the later ``close()`` a no-op and leave the socket open.
          """
          writer, self.writer = self.writer, None
          self.closed = True
          if writer is None:
              return
          writer.close()
          with contextlib.suppress(*SUPPRESSED_CLOSE_EXCEPTIONS):
              await writer.wait_closed()
  class RelayTcpPath(BasePath):
      """Path that tunnels the stream through an established relay connection.

      The stream is identified on the relay connection by
      ``(session_id, stream_id)``; frames arriving for that pair are translated
      into ``"status"``, ``"data"`` and ``"close"`` events for ``on_frame``.
      """

      def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], connection: TcpRelayConnection, session_id: int, stream_id: int) -> None:
          super().__init__(name, on_frame)
          self.connection = connection
          self.session_id = session_id
          self.stream_id = stream_id
          self.unbind_task: asyncio.Task | None = None
          # Private teardown guard.  ``closed`` cannot serve this purpose because
          # code outside this class may set it (TcpSession.handle_path does).
          self._teardown_started = False

      async def open(self, target: TargetAddress) -> None:
          """Bind the stream on the relay connection and send a TCP_OPEN frame.

          Failures are reported as a ``"status"`` event rather than raised.  A
          successful open is only reported later, when the relay answers with a
          TCP_STATUS frame.
          """
          if self.connection.closed:
              await self.on_frame(self, "status", b"relay unavailable")
              return
          self.connection.bind(self.session_id, self.stream_id, self._handle_frame)
          try:
              await self.connection.send(Frame(TCP_OPEN, self.session_id, self.stream_id, 0, 0, encode_json({"host": target.host, "port": target.port, "family": target.family})))
          except Exception as exc:
              self.connection.unbind(self.session_id, self.stream_id)
              # Never report an empty reason for message-less exceptions.
              reason = str(exc) or type(exc).__name__
              await self.on_frame(self, "status", reason.encode())

      async def _handle_frame(self, _conn: TcpRelayConnection, frame: Frame) -> None:
          """Translate one relay frame for this stream into a path event."""
          if frame.kind == TCP_STATUS:
              if frame.packet_id == STATUS_OK:
                  self.opened = True
                  await self.on_frame(self, "status", b"ok")
              else:
                  await self.on_frame(self, "status", frame.payload)
              return
          if frame.kind == TCP_DATA:
              await self.on_frame(self, "data", frame.payload)
              return
          if frame.kind == TCP_CLOSE:
              await self.on_frame(self, "close", None)

      async def send(self, data: bytes) -> None:
          """Send ``data`` as a TCP_DATA frame; dropped if path or relay is closed."""
          if self.closed or self.connection.closed:
              return
          await self.connection.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, data))

      async def close(self) -> None:
          """Release the stream; safe to call repeatedly.

          The binding on the relay connection is always scheduled for removal.
          A TCP_CLOSE frame is sent only when ``closed`` was not already set on
          entry.  In this file the only outside writer of that flag is
          ``TcpSession.handle_path``, reacting to a ``"close"`` event, so a flag
          that is already set is taken to mean the relay closed the stream first.
          Previously that case skipped ``close()`` entirely and the binding was
          never removed.
          """
          if self._teardown_started:
              return
          self._teardown_started = True
          already_marked_closed = self.closed
          self.closed = True
          if self.unbind_task is None or self.unbind_task.done():
              self.unbind_task = asyncio.create_task(self._delayed_unbind())
          if already_marked_closed or self.connection.closed:
              return
          with contextlib.suppress(Exception):
              await self.connection.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b""))

      async def _delayed_unbind(self) -> None:
          """Drop the stream binding after a half-second delay."""
          await asyncio.sleep(0.5)
          self.connection.unbind(self.session_id, self.stream_id)
  @dataclass
  class TcpSession:
      """One accepted client connection raced across several upstream paths.

      All paths are opened concurrently.  While no winner exists, client bytes
      are duplicated to every open path as long as the running uplink total
      stays within ``warmup_bytes``.  The first path that delivers downstream
      data becomes the winner; the other paths are closed (after
      ``loser_grace_ms`` milliseconds when that is positive) and only the
      winner's data is written back to the client.
      """

      session_id: int
      target: TargetAddress
      reader: asyncio.StreamReader
      writer: asyncio.StreamWriter
      paths: list[BasePath]
      warmup_bytes: int
      loser_grace_ms: int
      # Win counters supplied by the caller and updated in place:
      # per path name, per (host, port) target, and per "ipv4"/"ipv6" family.
      stats: dict[str, int]
      target_stats: dict[tuple[str, int], dict[str, int]]
      family_stats: dict[str, dict[str, int]]
      opened_count: int = 0
      status_count: int = 0
      errors: list[str] = field(default_factory=list)
      winner: BasePath | None = None
      uplink_bytes: int = 0
      open_event: asyncio.Event = field(default_factory=asyncio.Event)
      winner_event: asyncio.Event = field(default_factory=asyncio.Event)
      close_event: asyncio.Event = field(default_factory=asyncio.Event)
      closed: bool = False
      closing: bool = False
      close_task: asyncio.Task | None = None
      pump_task: asyncio.Task | None = None
      loser_close_task: asyncio.Task | None = None
      open_tasks: list[asyncio.Task] = field(default_factory=list)

      def _choose_winner(self, winner: BasePath) -> None:
          """Record ``winner`` as the winning path unless one is already set."""
          if self.winner is not None:
              return
          self.winner = winner
          self._record_win(winner)
          self.winner_event.set()

      @staticmethod
      def _split_wins(table: dict[str, int]) -> tuple[int, int]:
          """Return ``(direct_wins, relay_wins)`` for one win-count table."""
          direct = grouped_total(table, "direct")
          relay = sum(count for name, count in table.items() if winner_group(name) != "direct")
          return direct, relay

      @staticmethod
      def _relay_breakdown(table: dict[str, int]) -> str:
          """Render the non-direct counters as ``name=count, ...`` or ``none``."""
          return ", ".join(f"{name}={count}" for name, count in sorted(table.items()) if winner_group(name) != "direct") or "none"

      def _record_win(self, winner: BasePath) -> None:
          """Bump the global, per-target and per-family counters and log a summary."""
          name = winner.name
          self.stats[name] = self.stats.get(name, 0) + 1
          target_stats = self.target_stats.setdefault((self.target.host, self.target.port), {})
          target_stats[name] = target_stats.get(name, 0) + 1
          family_key = "ipv6" if self.target.family == socket.AF_INET6 else "ipv4"
          family_stats = self.family_stats.setdefault(family_key, {})
          family_stats[name] = family_stats.get(name, 0) + 1

          direct_wins, relay_wins = self._split_wins(self.stats)
          target_direct, target_relay = self._split_wins(target_stats)
          family_direct, family_relay = self._split_wins(family_stats)
          relay_detail = self._relay_breakdown(self.stats)
          target_detail = self._relay_breakdown(target_stats)
          target_pref = "relay" if target_relay > target_direct else "direct"
          family_pref = "relay" if family_relay > family_direct else "direct"
          print(f"[edge] tcp win session={self.session_id} target={self.target.host}:{self.target.port} winner={winner.name} direct={direct_wins} relay={relay_wins} relay_breakdown={relay_detail} target_pref={target_pref} target_direct={target_direct} target_relay={target_relay} target_breakdown={target_detail} family_pref={family_pref} family={family_key} family_direct={family_direct} family_relay={family_relay}")

      async def start(self) -> None:
          """Open every path and start pumping client data once one is usable.

          Raises:
              asyncio.TimeoutError: no path reported within 8 seconds.
              ConnectionError: every path reported failure.

          On either failure the session is torn down before the exception is
          re-raised, so pending open tasks are cancelled and every path is
          closed instead of being left running.
          """
          self.open_tasks = [asyncio.create_task(path.open(self.target)) for path in self.paths]
          try:
              await asyncio.wait_for(self.open_event.wait(), timeout=8)
              if self.opened_count == 0:
                  raise ConnectionError(self.errors[0] if self.errors else "all paths failed")
          except Exception:
              await self.close()
              raise
          self.pump_task = asyncio.create_task(self._pump_local())

      async def _pump_local(self) -> None:
          """Forward client-to-upstream data until EOF or no usable path remains."""
          try:
              while True:
                  chunk = await self.reader.read(65536)
                  if not chunk:
                      break
                  self.uplink_bytes += len(chunk)
                  active = [path for path in self.paths if path.opened and not path.closed]
                  if not active:
                      break
                  if self.winner is None and self.uplink_bytes <= self.warmup_bytes:
                      # Race phase: every open path receives a copy.
                      await asyncio.gather(*(path.send(chunk) for path in active), return_exceptions=True)
                  else:
                      if self.winner is None:
                          # Warm-up budget exhausted: hold the chunk for the winner.
                          await self.winner_event.wait()
                      if self.winner and self.winner.opened and not self.winner.closed:
                          await self.winner.send(chunk)
                      else:
                          break
          finally:
              self._request_close()

      async def handle_path(self, path: BasePath, event: str, payload: bytes | None) -> None:
          """React to a ``"status"``, ``"data"`` or ``"close"`` event from ``path``."""
          if self.closed:
              return
          if event == "status":
              self.status_count += 1
              if payload == b"ok":
                  self.opened_count += 1
              elif payload is not None:
                  self.errors.append(payload.decode("utf-8", errors="replace"))
              # ">=" rather than "==" so a duplicated report cannot step past
              # the completion condition.
              if self.opened_count > 0 or self.status_count >= len(self.paths):
                  self.open_event.set()
              return
          if event == "data":
              if self.winner is None:
                  self._choose_winner(path)
                  if self.loser_grace_ms > 0:
                      self.loser_close_task = asyncio.create_task(self._close_losers_after_grace(path))
                  else:
                      self.loser_close_task = asyncio.create_task(self._close_losers(path))
              if path is self.winner and payload is not None:
                  try:
                      self.writer.write(payload)
                      await self.writer.drain()
                  except OSError:
                      # The client side is gone: end the session instead of
                      # letting the error escape into the path's reader.
                      self._request_close()
              return
          if event == "close":
              path.closed = True
              if self.winner is None:
                  remaining = [candidate for candidate in self.paths if candidate.opened and not candidate.closed]
                  if not remaining:
                      self._request_close()
              elif path is self.winner:
                  self._request_close()

      async def _close_losers(self, winner: BasePath) -> None:
          """Close every path except ``winner``, ignoring individual failures."""
          await asyncio.gather(*(path.close() for path in self.paths if path is not winner), return_exceptions=True)

      async def _close_losers_after_grace(self, winner: BasePath) -> None:
          """Wait ``loser_grace_ms`` and then close the losers if still running."""
          await asyncio.sleep(self.loser_grace_ms / 1000)
          if not self.closed:
              await self._close_losers(winner)

      def _request_close(self) -> None:
          """Schedule teardown exactly once; later calls are no-ops."""
          if self.closing:
              return
          self.closing = True
          self.close_task = asyncio.create_task(self._finalize())

      async def _finalize(self) -> None:
          """Cancel helper tasks, close all paths and the client writer."""
          if self.closed:
              self.close_event.set()
              return
          self.closed = True
          if self.pump_task and not self.pump_task.done():
              self.pump_task.cancel()
          if self.loser_close_task and not self.loser_close_task.done():
              self.loser_close_task.cancel()
          for task in self.open_tasks:
              if not task.done():
                  task.cancel()
          if self.pump_task:
              with contextlib.suppress(*SUPPRESSED_CLOSE_EXCEPTIONS):
                  await self.pump_task
          for task in self.open_tasks:
              with contextlib.suppress(*SUPPRESSED_CLOSE_EXCEPTIONS):
                  await task
          await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True)
          self.writer.close()
          with contextlib.suppress(*SUPPRESSED_CLOSE_EXCEPTIONS):
              await self.writer.wait_closed()
          self.close_event.set()

      async def close(self) -> None:
          """Request teardown and wait until it has finished.

          When called from the pump task itself the wait is skipped, because
          teardown awaits that very task.
          """
          self._request_close()
          if asyncio.current_task() is self.pump_task:
              return
          await self.close_event.wait()
  class TcpEdge:
      """Transparent TCP listener that races each connection over several paths.

      For every accepted connection the original destination is read from the
      socket, a ``TcpSession`` is built with direct and relay candidate paths,
      and the session is started.  Win statistics are shared across sessions
      and feed back into how many direct paths are attempted.
      """

      def __init__(self, listen_host: str, listen_port: int, config: TcpConfig, kernel_mode: str = "auto") -> None:
          self.listen_host = listen_host
          self.listen_port = listen_port
          self.config = config
          self.kernel_mode = self._resolve_kernel_mode(kernel_mode, config.kernel_mode)
          self.manager = TcpRelayManager(config)
          self.session_ids = itertools.count(1)
          self.stream_ids = itertools.count(1)
          self.tcp_win_counts: dict[str, int] = {}
          # NOTE(review): gains one entry per distinct (host, port) that wins a
          # race; nothing in this class prunes it.
          self.tcp_target_wins: dict[tuple[str, int], dict[str, int]] = {}
          self.tcp_family_wins: dict[str, dict[str, int]] = {"ipv4": {}, "ipv6": {}}
          self._accept_log_every = 25
          self._interactive_ports = {22, 29765}

      def _resolve_kernel_mode(self, cli_kernel_mode: str, config_kernel_mode: str) -> str:
          """Pick the kernel mode: CLI value, then config value, then a probe.

          When both are ``"auto"`` the result is ``"24"`` if ``/etc/os-release``
          contains ``VERSION_ID="24`` or the kernel release starts with ``6.``,
          otherwise ``"20"``.
          """
          mode = cli_kernel_mode if cli_kernel_mode != "auto" else config_kernel_mode
          if mode != "auto":
              return mode
          # Both probes are best effort: any failure just falls through.
          try:
              os_release = Path("/etc/os-release")
              if os_release.exists() and 'VERSION_ID="24' in os_release.read_text(errors="ignore"):
                  return "24"
          except Exception:
              pass
          try:
              if os.uname().release.startswith("6."):
                  return "24"
          except Exception:
              pass
          return "20"

      @staticmethod
      def _listener_hosts(listen_host: str) -> tuple[str, str | None]:
          """Return the ``(ipv4_host, ipv6_host)`` bind addresses for ``listen_host``.

          Wildcard and loopback addresses of either family map to the matching
          pair, so an IPv6 spelling no longer reaches the AF_INET listener and
          ``"::1"`` stays on loopback instead of widening to ``"::"``.  Any
          other host gets an IPv4 listener only (``ipv6_host`` is ``None``).
          """
          wildcard = ("0.0.0.0", "::")
          loopback = ("127.0.0.1", "::1")
          pairs = {"0.0.0.0": wildcard, "::": wildcard, "127.0.0.1": loopback, "::1": loopback}
          return pairs.get(listen_host, (listen_host, None))

      async def start(self) -> None:
          """Apply kernel-mode defaults, start the relay manager and serve forever."""
          if self.kernel_mode == "24":
              # Only override values that still equal 10.0 / None.
              if self.config.direct_open_timeout == 10.0:
                  self.config.direct_open_timeout = 6.0
              if self.config.relay_open_timeout == 10.0:
                  self.config.relay_open_timeout = 6.0
              if self.config.tcp_connect_happy_eyeballs_delay is None:
                  self.config.tcp_connect_happy_eyeballs_delay = 0.25
          await self.manager.start()
          relay_mode = "direct-only" if not self.config.relays else "direct+relay"
          print(f"[edge] kernel_mode={self.kernel_mode} relay_mode={relay_mode} relay snapshot: {self.manager.snapshot()}")

          host4, host6 = self._listener_hosts(self.listen_host)
          server4 = await asyncio.start_server(self._accept, host4, self.listen_port, family=socket.AF_INET)
          sockets = [str(sock.getsockname()) for sock in server4.sockets or []]
          server6 = None
          if host6 is not None:
              # The IPv6 listener is optional; a failure is only logged.
              try:
                  server6 = await asyncio.start_server(self._accept, host6, self.listen_port, family=socket.AF_INET6)
                  sockets.extend(str(sock.getsockname()) for sock in server6.sockets or [])
              except Exception as exc:
                  print(f"[edge] ipv6 tcp listener skipped: {exc!r}")
          print(f"[edge] transparent tcp listening on {', '.join(sockets)}")
          if server6 is None:
              async with server4:
                  await server4.serve_forever()
          else:
              async with server4, server6:
                  await asyncio.gather(server4.serve_forever(), server6.serve_forever())

      def _direct_redundancy_for_target(self, target: TargetAddress) -> int:
          """Return how many parallel direct paths to open for ``target``.

          Returns 0 for IPv6 targets when direct IPv6 is disabled.  Otherwise the
          configured value (family-specific override first) is clamped to
          ``1..direct_max_redundancy`` and lowered by one when the recorded wins
          (at least 4 for the target, or at least 8 for the family) favour the
          relay, or favour direct while the base is above 2.
          """
          is_v6 = target.family == socket.AF_INET6
          if is_v6 and not self.config.direct_ipv6_enabled:
              return 0
          base = self.config.direct_redundancy
          if is_v6 and self.config.direct_redundancy_v6 is not None:
              base = self.config.direct_redundancy_v6
          elif target.family == socket.AF_INET and self.config.direct_redundancy_v4 is not None:
              base = self.config.direct_redundancy_v4
          base = max(1, min(base, self.config.direct_max_redundancy))

          target_stats = self.tcp_target_wins.get((target.host, target.port), {})
          family_stats = self.tcp_family_wins.get("ipv6" if is_v6 else "ipv4", {})
          target_total = sum(target_stats.values())
          family_total = sum(family_stats.values())
          target_relay = sum(count for name, count in target_stats.items() if winner_group(name) != "direct")
          family_relay = sum(count for name, count in family_stats.items() if winner_group(name) != "direct")
          # Every counter is either direct or not, so direct = total - relay.
          target_direct = target_total - target_relay
          family_direct = family_total - family_relay

          if target_total >= 4 and target_relay > target_direct:
              return max(1, base - 1)
          if family_total >= 8 and family_relay > family_direct:
              return max(1, base - 1)
          if target_total >= 4 and target_direct > target_relay and base > 2:
              return base - 1
          if family_total >= 8 and family_direct > family_relay and base > 2:
              return base - 1
          return base

      def _build_direct_paths(self, session: TcpSession) -> list[BasePath]:
          """Create the direct candidate paths for ``session`` (possibly none)."""
          count = self._direct_redundancy_for_target(session.target)
          if count <= 0:
              return []
          return [
              DirectTcpPath(
                  name=f"direct-{index + 1}" if count > 1 else "direct",
                  on_frame=lambda path, event, payload, s=session: self._handle_tcp_session(s, path, event, payload),
                  open_timeout=self.config.direct_open_timeout,
                  happy_eyeballs_delay=self.config.tcp_connect_happy_eyeballs_delay,
                  tcp_nodelay=self.config.relay_tcp_nodelay,
              )
              for index in range(count)
          ]

      def _tcp_relay_connections(self) -> list[TcpRelayConnection]:
          """Return the relay connections the manager reports as available."""
          return self.manager.available()

      def _session_race_profile(self, target: TargetAddress) -> tuple[int, int]:
          """Return ``(warmup_bytes, loser_grace_ms)`` for ``target``.

          Ports listed in ``_interactive_ports`` use the ``ssh_*`` settings,
          every other port the ``tcp_*`` settings.
          """
          if target.port in self._interactive_ports:
              return self.config.ssh_warmup_bytes, self.config.ssh_loser_grace_ms
          return self.config.tcp_warmup_bytes, self.config.tcp_loser_grace_ms

      async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
          """Handle one accepted client: build the session and start it.

          Any ``Exception`` is logged and the client connection is closed.
          """
          peer = writer.get_extra_info("peername")
          try:
              target = self._get_original_dst(writer)
              session_id = next(self.session_ids)
              warmup_bytes, loser_grace_ms = self._session_race_profile(target)
              session = TcpSession(
                  session_id=session_id,
                  target=target,
                  reader=reader,
                  writer=writer,
                  paths=[],
                  warmup_bytes=warmup_bytes,
                  loser_grace_ms=loser_grace_ms,
                  stats=self.tcp_win_counts,
                  target_stats=self.tcp_target_wins,
                  family_stats=self.tcp_family_wins,
              )
              paths: list[BasePath] = self._build_direct_paths(session)
              for connection in self._tcp_relay_connections():
                  stream_id = next(self.stream_ids)
                  paths.append(RelayTcpPath(connection.node.name, lambda path, event, payload, s=session: self._handle_tcp_session(s, path, event, payload), connection, session_id, stream_id))
              if not paths:
                  raise RuntimeError("no tcp candidates available")
              session.paths = paths
              # Log the first session and then every ``_accept_log_every``-th.
              if session_id == 1 or session_id % self._accept_log_every == 0:
                  print(f"[edge] accept peer={peer} session={session_id} target={target.host}:{target.port} candidates={[path.name for path in paths]}")
              await session.start()
          except Exception as exc:
              print(f"[edge] accept failed peer={peer} error={exc!r}")
              writer.close()
              with contextlib.suppress(Exception):
                  await writer.wait_closed()

      async def _handle_tcp_session(self, session: TcpSession, path: BasePath, event: str, payload: bytes | None) -> None:
          """Forward a path event to the session that owns the path."""
          await session.handle_path(path, event, payload)

      def _get_original_dst(self, writer: asyncio.StreamWriter) -> TargetAddress:
          """Read the original destination of the accepted connection.

          Raises:
              RuntimeError: if the transport exposes no socket or the socket
                  family is neither AF_INET nor AF_INET6.
              ValueError: from ``parse_sockaddr`` when the payload is malformed.
          """
          sock = writer.get_extra_info("socket")
          if sock is None:
              raise RuntimeError("socket unavailable")
          if sock.family == socket.AF_INET:
              return parse_sockaddr(sock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, 16))
          if sock.family == socket.AF_INET6:
              return parse_sockaddr(sock.getsockopt(socket.IPPROTO_IPV6, IP6T_SO_ORIGINAL_DST, 128))
          raise RuntimeError(f"unsupported socket family={sock.family}")