from __future__ import annotations

import asyncio
import contextlib
import itertools
import os
import socket
import struct
from dataclasses import dataclass, field
from pathlib import Path
from typing import Awaitable, Callable

from .config_tcp import TcpConfig
from .protocol import STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, Frame, encode_json
from .relay_client_tcp import TcpRelayConnection, TcpRelayManager

# getsockopt option numbers used by TcpEdge._get_original_dst to read the
# original destination of an intercepted connection (SOL_IP level for IPv4,
# IPPROTO_IPV6 level for IPv6).
SO_ORIGINAL_DST = 80
IP6T_SO_ORIGINAL_DST = 80

# Exceptions swallowed while awaiting best-effort shutdown steps.
SUPPRESSED_CLOSE_EXCEPTIONS = (Exception, asyncio.CancelledError)


@dataclass(frozen=True)
class TargetAddress:
    """Destination endpoint an intercepted connection was originally headed to."""

    host: str
    port: int
    family: int  # socket.AF_INET or socket.AF_INET6, as decoded by parse_sockaddr


def parse_sockaddr(raw: bytes) -> TargetAddress:
    """Decode a raw IPv4/IPv6 socket address buffer into a TargetAddress.

    The family field is read in native byte order and the port in network
    byte order. IPv6 buffers are laid out as family(2) port(2) flowinfo(4)
    addr(16) scope_id(4), hence the 28-byte minimum and the 8:24 slice.

    Raises:
        ValueError: if the buffer is too short or the family is unsupported.
    """
    if len(raw) < 8:
        raise ValueError("invalid transparent destination payload")
    family = struct.unpack_from("=H", raw, 0)[0]
    port = struct.unpack_from("!H", raw, 2)[0]
    if family == socket.AF_INET:
        return TargetAddress(host=socket.inet_ntoa(raw[4:8]), port=port, family=family)
    if family == socket.AF_INET6:
        if len(raw) < 28:
            raise ValueError("invalid IPv6 transparent destination payload")
        return TargetAddress(host=socket.inet_ntop(socket.AF_INET6, raw[8:24]), port=port, family=family)
    raise ValueError(f"unsupported family={family}")


def winner_group(name: str) -> str:
    """Collapse every ``direct*`` path name into the single group ``"direct"``."""
    return "direct" if name.startswith("direct") else name


def grouped_total(stats: dict[str, int], group: str) -> int:
    """Sum the win counts in *stats* whose path name maps to *group*."""
    return sum(count for name, count in stats.items() if winner_group(name) == group)


def _relay_total(stats: dict[str, int]) -> int:
    """Sum the win counts in *stats* for every non-direct (relay) path."""
    return sum(count for name, count in stats.items() if winner_group(name) != "direct")


def _relay_breakdown(stats: dict[str, int]) -> str:
    """Format the relay entries of *stats* as ``name=count, ...`` (or ``"none"``)."""
    return ", ".join(f"{name}={count}" for name, count in sorted(stats.items()) if winner_group(name) != "direct") or "none"


def _family_key(family: int) -> str:
    """Map an address family to the key used in the per-family win tables."""
    return "ipv6" if family == socket.AF_INET6 else "ipv4"


def _error_text(exc: BaseException) -> str:
    """Return a non-empty description of *exc*.

    ``str()`` of some exceptions (e.g. the TimeoutError raised by
    ``asyncio.wait_for``) is empty, which would surface as a blank error.
    """
    return str(exc) or type(exc).__name__


class BasePath:
    """One candidate route to the target that takes part in a session's race.

    Implementations report progress through *on_frame* with one of three
    events: ``"status"`` (payload ``b"ok"`` or an error message), ``"data"``
    (bytes received from the target) and ``"close"`` (payload ``None``).
    """

    def __init__(self, name: str, on_frame: Callable[["BasePath", str, bytes | None], Awaitable[None]]) -> None:
        self.name = name
        self.on_frame = on_frame
        self.opened = False  # set once the path reports a successful open
        self.closed = False  # set once the path is closed or ended; close() is then a no-op

    async def open(self, target: TargetAddress) -> None:
        raise NotImplementedError

    async def send(self, data: bytes) -> None:
        raise NotImplementedError

    async def close(self) -> None:
        raise NotImplementedError

    def mark_remote_closed(self) -> None:
        """Record that the far end ended this stream (the path emitted ``"close"``).

        The base implementation only sets ``closed``. Subclasses also release
        whatever ``close()`` would have released, because ``close()`` returns
        early once ``closed`` is set.
        """
        self.closed = True


class DirectTcpPath(BasePath):
    """Race candidate that connects straight to the target over plain TCP."""

    def __init__(
        self,
        name: str,
        on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]],
        open_timeout: float,
        happy_eyeballs_delay: float | None,
        tcp_nodelay: bool = True,
    ) -> None:
        super().__init__(name, on_frame)
        self.reader: asyncio.StreamReader | None = None
        self.writer: asyncio.StreamWriter | None = None
        self.pump_task: asyncio.Task | None = None
        self.open_timeout = open_timeout
        self.happy_eyeballs_delay = happy_eyeballs_delay
        self.tcp_nodelay = tcp_nodelay

    async def open(self, target: TargetAddress) -> None:
        """Connect to *target* and report the outcome as a ``"status"`` event.

        On success, starts the read pump and reports ``b"ok"``. On failure,
        reports the error text. Connection errors are reported, not raised.
        """
        family = socket.AF_INET6 if target.family == socket.AF_INET6 else socket.AF_INET
        kwargs = {"host": target.host, "port": target.port, "family": family}
        if self.happy_eyeballs_delay is not None:
            kwargs["happy_eyeballs_delay"] = self.happy_eyeballs_delay
        try:
            reader, writer = await asyncio.wait_for(asyncio.open_connection(**kwargs), timeout=self.open_timeout)
        except Exception as exc:
            await self.on_frame(self, "status", _error_text(exc).encode())
            return
        if self.closed:
            # close() ran while the connect was in flight (e.g. this path lost
            # the race) and had no writer to close. close() is a no-op from now
            # on, so release the new socket here instead of leaking it.
            writer.close()
            with contextlib.suppress(*SUPPRESSED_CLOSE_EXCEPTIONS):
                await writer.wait_closed()
            return
        self.reader, self.writer = reader, writer
        sock = writer.get_extra_info("socket")
        if sock is not None and self.tcp_nodelay:
            with contextlib.suppress(OSError):
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.opened = True
        self.pump_task = asyncio.create_task(self._pump())
        await self.on_frame(self, "status", b"ok")

    async def _pump(self) -> None:
        """Forward bytes read from the target as ``"data"`` events, then emit ``"close"``."""
        assert self.reader is not None
        try:
            while True:
                try:
                    chunk = await self.reader.read(65536)
                except (ConnectionResetError, BrokenPipeError, OSError):
                    break
                if not chunk:
                    break
                await self.on_frame(self, "data", chunk)
        finally:
            await self.on_frame(self, "close", None)

    async def send(self, data: bytes) -> None:
        """Write *data* to the target. Does nothing once the path is closed or not yet open.

        Raises:
            ConnectionError: if the write fails. The path is closed first.
        """
        if self.closed or self.writer is None:
            return
        try:
            self.writer.write(data)
            await self.writer.drain()
        except (RuntimeError, OSError) as exc:
            # CancelledError is deliberately not caught here, so cancelling the
            # sender still cancels it.
            await self.close()
            raise ConnectionError("relay closed") from exc

    async def close(self) -> None:
        """Close the target connection. Idempotent."""
        if self.closed:
            return
        self.closed = True
        if self.writer is not None:
            self.writer.close()
            with contextlib.suppress(*SUPPRESSED_CLOSE_EXCEPTIONS):
                await self.writer.wait_closed()

    def mark_remote_closed(self) -> None:
        """Record the target's EOF/error and close our side of the socket.

        asyncio keeps a plain-TCP stream transport open after EOF, and close()
        is a no-op once ``closed`` is set. Close the writer here so the socket
        is not left half-open.
        """
        if self.closed:
            return
        self.closed = True
        if self.writer is not None:
            self.writer.close()


class RelayTcpPath(BasePath):
    """Race candidate carried as a multiplexed stream over a TcpRelayConnection."""

    def __init__(
        self,
        name: str,
        on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]],
        connection: TcpRelayConnection,
        session_id: int,
        stream_id: int,
    ) -> None:
        super().__init__(name, on_frame)
        self.connection = connection
        self.session_id = session_id
        self.stream_id = stream_id
        self.unbind_task: asyncio.Task | None = None

    async def open(self, target: TargetAddress) -> None:
        """Bind this stream on the relay connection and send TCP_OPEN for *target*.

        Success arrives later as a TCP_STATUS frame (see _handle_frame). Local
        failures are reported immediately as a ``"status"`` event.
        """
        if self.connection.closed:
            await self.on_frame(self, "status", b"relay unavailable")
            return
        self.connection.bind(self.session_id, self.stream_id, self._handle_frame)
        try:
            await self.connection.send(
                Frame(
                    TCP_OPEN,
                    self.session_id,
                    self.stream_id,
                    0,
                    0,
                    encode_json({"host": target.host, "port": target.port, "family": target.family}),
                )
            )
        except Exception as exc:
            self.connection.unbind(self.session_id, self.stream_id)
            await self.on_frame(self, "status", _error_text(exc).encode())

    async def _handle_frame(self, _conn: TcpRelayConnection, frame: Frame) -> None:
        """Translate relay frames for this stream into status/data/close events."""
        if frame.kind == TCP_STATUS:
            if frame.packet_id == STATUS_OK:
                self.opened = True
                await self.on_frame(self, "status", b"ok")
            else:
                await self.on_frame(self, "status", frame.payload)
            return
        if frame.kind == TCP_DATA:
            await self.on_frame(self, "data", frame.payload)
            return
        if frame.kind == TCP_CLOSE:
            await self.on_frame(self, "close", None)

    async def send(self, data: bytes) -> None:
        """Send *data* as a TCP_DATA frame. Does nothing if this path or the connection is closed."""
        if self.closed or self.connection.closed:
            return
        await self.connection.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, data))

    async def close(self) -> None:
        """Send TCP_CLOSE (best effort) and schedule a delayed unbind. Idempotent."""
        if self.closed:
            return
        self.closed = True
        self._schedule_unbind()
        if not self.connection.closed:
            with contextlib.suppress(Exception):
                await self.connection.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b""))

    def mark_remote_closed(self) -> None:
        """Record the relay's TCP_CLOSE for this stream and drop our handler binding.

        close() is a no-op once ``closed`` is set, so the unbind must be
        scheduled here. NOTE(review): assumes a repeated unbind of the same ids
        is harmless if TcpRelayConnection also unbinds on TCP_CLOSE; confirm in
        relay_client_tcp.
        """
        if self.closed:
            return
        self.closed = True
        self._schedule_unbind()

    def _schedule_unbind(self) -> None:
        """Start the delayed unbind unless one is already pending."""
        if self.unbind_task is None or self.unbind_task.done():
            self.unbind_task = asyncio.create_task(self._delayed_unbind())

    async def _delayed_unbind(self) -> None:
        # Presumably delayed so frames still in flight for this stream find a
        # handler instead of an unknown id; the 0.5 s value is not explained here.
        await asyncio.sleep(0.5)
        self.connection.unbind(self.session_id, self.stream_id)


@dataclass
class TcpSession:
    """One intercepted client connection racing several paths to the same target.

    All live paths receive the first ``warmup_bytes`` of client data. The
    first path to return data becomes the winner and the others are closed,
    optionally after ``loser_grace_ms``.
    """

    session_id: int
    target: TargetAddress
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    paths: list[BasePath]
    warmup_bytes: int
    loser_grace_ms: int
    stats: dict[str, int]  # shared global win counts, keyed by path name
    target_stats: dict[tuple[str, int], dict[str, int]]  # shared per-(host, port) win counts
    family_stats: dict[str, dict[str, int]]  # shared per-"ipv4"/"ipv6" win counts
    opened_count: int = 0
    status_count: int = 0
    errors: list[str] = field(default_factory=list)
    winner: BasePath | None = None
    uplink_bytes: int = 0
    open_event: asyncio.Event = field(default_factory=asyncio.Event)
    winner_event: asyncio.Event = field(default_factory=asyncio.Event)
    close_event: asyncio.Event = field(default_factory=asyncio.Event)
    closed: bool = False
    closing: bool = False
    close_task: asyncio.Task | None = None
    pump_task: asyncio.Task | None = None
    loser_close_task: asyncio.Task | None = None
    open_tasks: list[asyncio.Task] = field(default_factory=list)
    # Paths that opened after client bytes were already broadcast. They missed
    # part of the stream, so they may neither win nor receive further data.
    stale_paths: list[BasePath] = field(default_factory=list)
    # Background close tasks for stale paths, kept so they are not garbage collected.
    aux_tasks: list[asyncio.Task] = field(default_factory=list)

    def _is_live(self, path: BasePath) -> bool:
        """True if *path* is open, not closed and still eligible for the race."""
        return path.opened and not path.closed and path not in self.stale_paths

    def _choose_winner(self, winner: BasePath) -> None:
        if self.winner is not None:
            return
        self.winner = winner
        self._record_win(winner)
        self.winner_event.set()

    def _record_win(self, winner: BasePath) -> None:
        """Count *winner* in the global, per-target and per-family tables and log the tallies."""
        self.stats[winner.name] = self.stats.get(winner.name, 0) + 1
        key = (self.target.host, self.target.port)
        target_stats = self.target_stats.setdefault(key, {})
        target_stats[winner.name] = target_stats.get(winner.name, 0) + 1
        family_key = _family_key(self.target.family)
        family_stats = self.family_stats.setdefault(family_key, {})
        family_stats[winner.name] = family_stats.get(winner.name, 0) + 1

        direct_wins = grouped_total(self.stats, "direct")
        relay_wins = _relay_total(self.stats)
        target_direct = grouped_total(target_stats, "direct")
        target_relay = _relay_total(target_stats)
        family_direct = grouped_total(family_stats, "direct")
        family_relay = _relay_total(family_stats)
        relay_detail = _relay_breakdown(self.stats)
        target_detail = _relay_breakdown(target_stats)
        target_pref = "relay" if target_relay > target_direct else "direct"
        family_pref = "relay" if family_relay > family_direct else "direct"
        print(
            f"[edge] tcp win session={self.session_id} target={self.target.host}:{self.target.port} "
            f"winner={winner.name} direct={direct_wins} relay={relay_wins} relay_breakdown={relay_detail} "
            f"target_pref={target_pref} target_direct={target_direct} target_relay={target_relay} "
            f"target_breakdown={target_detail} family_pref={family_pref} family={family_key} "
            f"family_direct={family_direct} family_relay={family_relay}"
        )

    async def start(self) -> None:
        """Open every path and wait up to 8 s for the first success, then start the uplink pump.

        Raises:
            ConnectionError: if every path failed (first error text as message).
            TimeoutError: (asyncio's) if no path reported status in time.
        On any failure the session is finalized before the exception propagates,
        so pending opens and bound relay streams do not outlive it.
        """
        self.open_tasks = [asyncio.create_task(path.open(self.target)) for path in self.paths]
        try:
            await asyncio.wait_for(self.open_event.wait(), timeout=8)
            if self.opened_count == 0:
                raise ConnectionError(self.errors[0] if self.errors else "all paths failed")
        except (Exception, asyncio.CancelledError):
            await self.close()
            raise
        self.pump_task = asyncio.create_task(self._pump_local())

    async def _pump_local(self) -> None:
        """Copy client bytes upstream.

        Client bytes are broadcast during warm-up and sent only to the winner
        afterwards. The session is closed when the loop exits.
        """
        try:
            while True:
                chunk = await self.reader.read(65536)
                if not chunk:
                    break
                self.uplink_bytes += len(chunk)
                live = [path for path in self.paths if self._is_live(path)]
                if not live:
                    break
                if self.winner is None and self.uplink_bytes <= self.warmup_bytes:
                    await asyncio.gather(*(path.send(chunk) for path in live), return_exceptions=True)
                    continue
                if self.winner is None:
                    # Past the warm-up budget: hold further bytes until a path wins.
                    await self.winner_event.wait()
                winner = self.winner
                if winner is None or not winner.opened or winner.closed:
                    break
                await winner.send(chunk)
        finally:
            self._request_close()

    async def handle_path(self, path: BasePath, event: str, payload: bytes | None) -> None:
        """Dispatch a status/data/close event reported by *path*. Ignored once the session is closed."""
        if self.closed:
            return
        if event == "status":
            self._handle_status(path, payload)
        elif event == "data":
            await self._handle_data(path, payload)
        elif event == "close":
            self._handle_close(path)

    def _handle_status(self, path: BasePath, payload: bytes | None) -> None:
        """Count an open result and release start() once any path is open or all have answered."""
        self.status_count += 1
        if payload == b"ok":
            self.opened_count += 1
            if self.winner is None and self.uplink_bytes > 0:
                # Uplink bytes were already broadcast while this path was not
                # open, so its target saw a truncated stream. It must not win.
                self._retire_stale_path(path)
        elif payload is not None:
            self.errors.append(payload.decode("utf-8", errors="replace"))
        if self.opened_count > 0 or self.status_count == len(self.paths):
            self.open_event.set()

    async def _handle_data(self, path: BasePath, payload: bytes | None) -> None:
        """Pick the first eligible responder as winner and relay its bytes to the client."""
        if self.winner is None and path not in self.stale_paths:
            self._choose_winner(path)
            if self.loser_grace_ms > 0:
                self.loser_close_task = asyncio.create_task(self._close_losers_after_grace(path))
            else:
                self.loser_close_task = asyncio.create_task(self._close_losers(path))
        if path is not self.winner or payload is None:
            return
        try:
            self.writer.write(payload)
            await self.writer.drain()
        except (RuntimeError, OSError):
            # The local client went away. Shut the session down rather than
            # raising into the path's reader; for relay paths that reader is
            # presumably the shared relay connection's dispatch loop.
            self._request_close()

    def _handle_close(self, path: BasePath) -> None:
        """React to *path* ending: close the session if it was the winner or the last live candidate."""
        path.mark_remote_closed()
        if self.winner is None:
            if not any(self._is_live(candidate) for candidate in self.paths):
                self._request_close()
        elif path is self.winner:
            self._request_close()

    def _retire_stale_path(self, path: BasePath) -> None:
        """Exclude *path* from the race and close it in the background."""
        self.stale_paths.append(path)
        self.aux_tasks.append(asyncio.create_task(path.close()))

    async def _close_losers(self, winner: BasePath) -> None:
        await asyncio.gather(*(path.close() for path in self.paths if path is not winner), return_exceptions=True)

    async def _close_losers_after_grace(self, winner: BasePath) -> None:
        await asyncio.sleep(self.loser_grace_ms / 1000)
        if not self.closed:
            await self._close_losers(winner)

    def _request_close(self) -> None:
        """Start finalization once. Safe to call from any task."""
        if self.closing:
            return
        self.closing = True
        self.close_task = asyncio.create_task(self._finalize())

    async def _finalize(self) -> None:
        """Cancel session tasks, close every path and the client socket, then set close_event."""
        if self.closed:
            self.close_event.set()
            return
        self.closed = True
        if self.pump_task and not self.pump_task.done():
            self.pump_task.cancel()
        if self.loser_close_task and not self.loser_close_task.done():
            self.loser_close_task.cancel()
        for task in self.open_tasks:
            if not task.done():
                task.cancel()
        if self.pump_task:
            with contextlib.suppress(*SUPPRESSED_CLOSE_EXCEPTIONS):
                await self.pump_task
        for task in self.open_tasks:
            with contextlib.suppress(*SUPPRESSED_CLOSE_EXCEPTIONS):
                await task
        await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True)
        if self.aux_tasks:
            await asyncio.gather(*self.aux_tasks, return_exceptions=True)
        self.writer.close()
        with contextlib.suppress(*SUPPRESSED_CLOSE_EXCEPTIONS):
            await self.writer.wait_closed()
        self.close_event.set()

    async def close(self) -> None:
        """Request shutdown and wait for it, except from the pump task (which finalize cancels)."""
        self._request_close()
        if asyncio.current_task() is self.pump_task:
            return
        await self.close_event.wait()


class TcpEdge:
    """Transparent TCP listener that races direct and relay paths for each intercepted connection."""

    def __init__(self, listen_host: str, listen_port: int, config: TcpConfig, kernel_mode: str = "auto") -> None:
        self.listen_host = listen_host
        self.listen_port = listen_port
        self.config = config
        self.kernel_mode = self._resolve_kernel_mode(kernel_mode, config.kernel_mode)
        self.manager = TcpRelayManager(config)
        self.session_ids = itertools.count(1)
        self.stream_ids = itertools.count(1)
        # Win tables shared by every session (see TcpSession._record_win).
        self.tcp_win_counts: dict[str, int] = {}
        self.tcp_target_wins: dict[tuple[str, int], dict[str, int]] = {}
        self.tcp_family_wins: dict[str, dict[str, int]] = {"ipv4": {}, "ipv6": {}}
        self._accept_log_every = 25
        # Target ports that use the ssh_* race profile (see _session_race_profile).
        self._interactive_ports = {22, 29765}

    def _resolve_kernel_mode(self, cli_kernel_mode: str, config_kernel_mode: str) -> str:
        """Return the explicit mode, else guess ``"24"`` or ``"20"`` from os-release / the kernel release."""
        mode = cli_kernel_mode if cli_kernel_mode != "auto" else config_kernel_mode
        if mode != "auto":
            return mode
        os_release = Path("/etc/os-release")
        with contextlib.suppress(Exception):
            if os_release.exists() and 'VERSION_ID="24' in os_release.read_text(errors="ignore"):
                return "24"
        with contextlib.suppress(Exception):
            if os.uname().release.startswith("6."):
                return "24"
        return "20"

    def _listener_hosts(self) -> tuple[str, str | None]:
        """Return (IPv4 bind host, IPv6 bind host or None) for ``listen_host``.

        Wildcard and loopback hosts of either family get a listener on both
        families (wildcard pairs with wildcard, loopback with loopback). Any
        other host is bound for IPv4 only, as before.
        """
        if self.listen_host in ("127.0.0.1", "::1"):
            return "127.0.0.1", "::1"
        if self.listen_host in ("0.0.0.0", "::"):
            return "0.0.0.0", "::"
        return self.listen_host, None

    async def start(self) -> None:
        """Start the relay manager and serve the transparent listeners forever."""
        if self.kernel_mode == "24":
            # Only override values still at 10.0 / None. These look like the
            # TcpConfig defaults (NOTE(review): confirm), so explicit settings win.
            if self.config.direct_open_timeout == 10.0:
                self.config.direct_open_timeout = 6.0
            if self.config.relay_open_timeout == 10.0:
                self.config.relay_open_timeout = 6.0
            if self.config.tcp_connect_happy_eyeballs_delay is None:
                self.config.tcp_connect_happy_eyeballs_delay = 0.25
        await self.manager.start()
        relay_mode = "direct-only" if not self.config.relays else "direct+relay"
        print(f"[edge] kernel_mode={self.kernel_mode} relay_mode={relay_mode} relay snapshot: {self.manager.snapshot()}")
        host4, host6 = self._listener_hosts()
        server4 = await asyncio.start_server(self._accept, host4, self.listen_port, family=socket.AF_INET)
        sockets = [str(sock.getsockname()) for sock in server4.sockets or []]
        server6 = None
        if host6 is not None:
            try:
                server6 = await asyncio.start_server(self._accept, host6, self.listen_port, family=socket.AF_INET6)
                sockets.extend(str(sock.getsockname()) for sock in server6.sockets or [])
            except Exception as exc:
                print(f"[edge] ipv6 tcp listener skipped: {exc!r}")
        print(f"[edge] transparent tcp listening on {', '.join(sockets)}")
        if server6 is None:
            async with server4:
                await server4.serve_forever()
        else:
            async with server4, server6:
                await asyncio.gather(server4.serve_forever(), server6.serve_forever())

    def _direct_redundancy_for_target(self, target: TargetAddress) -> int:
        """Number of parallel direct connections to race for *target*.

        Starts from the per-family configured redundancy, clamped to
        [1, direct_max_redundancy]. It is lowered by one when past wins favour
        relays, or when direct clearly dominates and base > 2. Thresholds:
        at least 4 target samples or at least 8 family samples.
        """
        if target.family == socket.AF_INET6 and not self.config.direct_ipv6_enabled:
            return 0
        base = self.config.direct_redundancy
        if target.family == socket.AF_INET6 and self.config.direct_redundancy_v6 is not None:
            base = self.config.direct_redundancy_v6
        elif target.family == socket.AF_INET and self.config.direct_redundancy_v4 is not None:
            base = self.config.direct_redundancy_v4
        base = max(1, min(base, self.config.direct_max_redundancy))

        target_stats = self.tcp_target_wins.get((target.host, target.port), {})
        family_stats = self.tcp_family_wins.get(_family_key(target.family), {})
        target_total = sum(target_stats.values())
        family_total = sum(family_stats.values())
        target_direct = grouped_total(target_stats, "direct")
        target_relay = _relay_total(target_stats)
        family_direct = grouped_total(family_stats, "direct")
        family_relay = _relay_total(family_stats)

        if target_total >= 4 and target_relay > target_direct:
            return max(1, base - 1)
        if family_total >= 8 and family_relay > family_direct:
            return max(1, base - 1)
        if target_total >= 4 and target_direct > target_relay and base > 2:
            return base - 1
        if family_total >= 8 and family_direct > family_relay and base > 2:
            return base - 1
        return base

    def _session_callback(self, session: TcpSession) -> Callable[[BasePath, str, bytes | None], Awaitable[None]]:
        """Return the on_frame callback that routes a path's events to *session*."""
        return lambda path, event, payload: self._handle_tcp_session(session, path, event, payload)

    def _build_direct_paths(self, session: TcpSession) -> list[BasePath]:
        """Create the direct-connection candidates for *session* (possibly none)."""
        count = self._direct_redundancy_for_target(session.target)
        if count <= 0:
            return []
        on_frame = self._session_callback(session)
        # NOTE(review): direct paths take their TCP_NODELAY setting from
        # relay_tcp_nodelay; confirm that sharing the relay option is intended.
        return [
            DirectTcpPath(
                name=f"direct-{index + 1}" if count > 1 else "direct",
                on_frame=on_frame,
                open_timeout=self.config.direct_open_timeout,
                happy_eyeballs_delay=self.config.tcp_connect_happy_eyeballs_delay,
                tcp_nodelay=self.config.relay_tcp_nodelay,
            )
            for index in range(count)
        ]

    def _tcp_relay_connections(self) -> list[TcpRelayConnection]:
        return self.manager.available()

    def _session_race_profile(self, target: TargetAddress) -> tuple[int, int]:
        """Return (warmup_bytes, loser_grace_ms) for *target*'s port."""
        if target.port in self._interactive_ports:
            return self.config.ssh_warmup_bytes, self.config.ssh_loser_grace_ms
        return self.config.tcp_warmup_bytes, self.config.tcp_loser_grace_ms

    async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Resolve the original destination, build the candidate paths and start a session."""
        peer = writer.get_extra_info("peername")
        try:
            target = self._get_original_dst(writer)
            session_id = next(self.session_ids)
            warmup_bytes, loser_grace_ms = self._session_race_profile(target)
            session = TcpSession(
                session_id=session_id,
                target=target,
                reader=reader,
                writer=writer,
                paths=[],
                warmup_bytes=warmup_bytes,
                loser_grace_ms=loser_grace_ms,
                stats=self.tcp_win_counts,
                target_stats=self.tcp_target_wins,
                family_stats=self.tcp_family_wins,
            )
            paths: list[BasePath] = self._build_direct_paths(session)
            on_frame = self._session_callback(session)
            for connection in self._tcp_relay_connections():
                stream_id = next(self.stream_ids)
                paths.append(RelayTcpPath(connection.node.name, on_frame, connection, session_id, stream_id))
            if not paths:
                raise RuntimeError("no tcp candidates available")
            session.paths = paths
            if session_id == 1 or session_id % self._accept_log_every == 0:
                print(f"[edge] accept peer={peer} session={session_id} target={target.host}:{target.port} candidates={[path.name for path in paths]}")
            await session.start()
        except Exception as exc:
            print(f"[edge] accept failed peer={peer} error={exc!r}")
            writer.close()
            with contextlib.suppress(Exception):
                await writer.wait_closed()

    async def _handle_tcp_session(self, session: TcpSession, path: BasePath, event: str, payload: bytes | None) -> None:
        await session.handle_path(path, event, payload)

    def _get_original_dst(self, writer: asyncio.StreamWriter) -> TargetAddress:
        """Read the intercepted connection's original destination from its socket.

        Raises:
            RuntimeError: if the socket is unavailable or of an unsupported family.
        """
        sock = writer.get_extra_info("socket")
        if sock is None:
            raise RuntimeError("socket unavailable")
        if sock.family == socket.AF_INET:
            return parse_sockaddr(sock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, 16))
        if sock.family == socket.AF_INET6:
            return parse_sockaddr(sock.getsockopt(socket.IPPROTO_IPV6, IP6T_SO_ORIGINAL_DST, 128))
        raise RuntimeError(f"unsupported socket family={sock.family}")