from __future__ import annotations import asyncio import contextlib import itertools import socket import struct from dataclasses import dataclass, field from typing import Awaitable, Callable from .config import Config from .protocol import STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, encode_json from .relay_client import RelayConnection, RelayManager SO_ORIGINAL_DST = 80 IP6T_SO_ORIGINAL_DST = 80 IP_RECVORIGDSTADDR = 20 IPV6_RECVORIGDSTADDR = 74 @dataclass(frozen=True) class TargetAddress: host: str port: int family: int @dataclass(frozen=True) class PeerAddress: host: str port: int family: int def parse_sockaddr(raw: bytes) -> TargetAddress: if len(raw) < 8: raise ValueError("invalid transparent destination payload") family = struct.unpack_from("=H", raw, 0)[0] port = struct.unpack_from("!H", raw, 2)[0] if family == socket.AF_INET: host = socket.inet_ntoa(raw[4:8]) return TargetAddress(host=host, port=port, family=family) if family == socket.AF_INET6: if len(raw) < 28: raise ValueError("invalid IPv6 transparent destination payload") host = socket.inet_ntop(socket.AF_INET6, raw[8:24]) return TargetAddress(host=host, port=port, family=family) raise ValueError(f"unsupported family={family}") class BasePath: def __init__(self, name: str, on_frame: Callable[["BasePath", str, bytes | None], Awaitable[None]]) -> None: self.name = name self.on_frame = on_frame self.opened = False self.closed = False async def open(self, target: TargetAddress) -> None: raise NotImplementedError async def send(self, data: bytes) -> None: raise NotImplementedError async def close(self) -> None: raise NotImplementedError class DirectTcpPath(BasePath): def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]]) -> None: super().__init__(name, on_frame) self.reader: asyncio.StreamReader | None = None self.writer: asyncio.StreamWriter | None = None self.pump_task: asyncio.Task | None = None async def open(self, target: TargetAddress) -> None: try: family = socket.AF_INET6 if target.family == socket.AF_INET6 else socket.AF_INET self.reader, self.writer = await asyncio.open_connection(host=target.host, port=target.port, family=family) self.opened = True self.pump_task = asyncio.create_task(self._pump()) await self.on_frame(self, "status", b"ok") except Exception as exc: await self.on_frame(self, "status", str(exc).encode()) async def _pump(self) -> None: assert self.reader is not None try: while True: chunk = await self.reader.read(65536) if not chunk: break await self.on_frame(self, "data", chunk) except Exception: pass finally: await self.on_frame(self, "close", None) async def send(self, data: bytes) -> None: if self.closed or self.writer is None: return self.writer.write(data) await self.writer.drain() async def close(self) -> None: if self.closed: return self.closed = True if self.pump_task and self.pump_task is not asyncio.current_task(): self.pump_task.cancel() with contextlib.suppress(Exception): await self.pump_task if self.writer: self.writer.close() with contextlib.suppress(Exception): await self.writer.wait_closed() class RelayTcpPath(BasePath): def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], connection: RelayConnection, session_id: int, stream_id: int) -> None: super().__init__(name, on_frame) self.connection = connection self.session_id = session_id self.stream_id = stream_id async def open(self, target: TargetAddress) -> None: if self.connection.closed: await self.on_frame(self, "status", b"relay unavailable") return self.connection.bind(self.session_id, self.stream_id, self._handle_frame) try: await self.connection.send(Frame(TCP_OPEN, self.session_id, self.stream_id, 0, 0, encode_json({"host": target.host, "port": target.port, "family": target.family}))) except Exception as exc: await self.on_frame(self, "status", str(exc).encode()) async def _handle_frame(self, _conn: RelayConnection, frame: Frame) -> None: if frame.kind == TCP_STATUS: if frame.packet_id == STATUS_OK: self.opened = True await self.on_frame(self, "status", b"ok") else: await self.on_frame(self, "status", frame.payload) return if frame.kind == TCP_DATA: await self.on_frame(self, "data", frame.payload) return if frame.kind == TCP_CLOSE: await self.on_frame(self, "close", None) async def send(self, data: bytes) -> None: if self.closed or self.connection.closed: return await self.connection.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, data)) async def close(self) -> None: if self.closed: return self.closed = True self.connection.unbind(self.session_id, self.stream_id) if not self.connection.closed: with contextlib.suppress(Exception): await self.connection.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b"")) @dataclass class TransparentSession: session_id: int target: TargetAddress reader: asyncio.StreamReader writer: asyncio.StreamWriter paths: list[BasePath] warmup_bytes: int opened_count: int = 0 status_count: int = 0 errors: list[str] = field(default_factory=list) winner: BasePath | None = None uplink_bytes: int = 0 open_event: asyncio.Event = field(default_factory=asyncio.Event) winner_event: asyncio.Event = field(default_factory=asyncio.Event) closed: bool = False pump_task: asyncio.Task | None = None async def start(self) -> None: await asyncio.gather(*(path.open(self.target) for path in self.paths), return_exceptions=True) await asyncio.wait_for(self.open_event.wait(), timeout=10) if self.opened_count == 0: raise ConnectionError(self.errors[0] if self.errors else "all paths failed") self.pump_task = asyncio.create_task(self._pump_local()) async def _pump_local(self) -> None: try: while True: chunk = await self.reader.read(65536) if not chunk: break self.uplink_bytes += len(chunk) active = [path for path in self.paths if path.opened and not path.closed] if not active: break if self.winner is None and self.uplink_bytes <= self.warmup_bytes: await asyncio.gather(*(path.send(chunk) for path in active), return_exceptions=True) else: if self.winner is None: await self.winner_event.wait() if self.winner: await self.winner.send(chunk) except Exception: pass finally: await self.close() async def handle_path(self, path: BasePath, event: str, payload: bytes | None) -> None: if self.closed: return if event == "status": self.status_count += 1 if payload == b"ok": self.opened_count += 1 elif payload is not None: self.errors.append(payload.decode("utf-8", errors="replace")) if self.opened_count > 0 or self.status_count == len(self.paths): self.open_event.set() return if event == "data": if self.winner is None: self.winner = path print(f"[edge] session={self.session_id} winner={path.name} target={self.target.host}:{self.target.port}") self.winner_event.set() await self._close_losers(path) if path is self.winner and payload is not None: self.writer.write(payload) await self.writer.drain() return if event == "close": path.closed = True if self.winner is None: remaining = [candidate for candidate in self.paths if candidate.opened and not candidate.closed] if not remaining: await self.close() elif path is self.winner: await self.close() async def _close_losers(self, winner: BasePath) -> None: await asyncio.gather(*(path.close() for path in self.paths if path is not winner), return_exceptions=True) async def close(self) -> None: if self.closed: return self.closed = True print(f"[edge] session={self.session_id} closed target={self.target.host}:{self.target.port}") if self.pump_task and self.pump_task is not asyncio.current_task(): self.pump_task.cancel() with contextlib.suppress(Exception): await self.pump_task await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True) self.writer.close() with contextlib.suppress(Exception): await self.writer.wait_closed() class DirectUdpPath(BasePath): def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], target: TargetAddress) -> None: super().__init__(name, on_frame) self.target = target self.socket: socket.socket | None = None self.read_task: asyncio.Task | None = None async def open(self, _target: TargetAddress) -> None: try: family = socket.AF_INET6 if self.target.family == socket.AF_INET6 else socket.AF_INET self.socket = socket.socket(family, socket.SOCK_DGRAM) self.socket.setblocking(False) await asyncio.get_running_loop().sock_connect(self.socket, (self.target.host, self.target.port)) self.opened = True self.read_task = asyncio.create_task(self._pump()) await self.on_frame(self, "status", b"ok") except Exception as exc: await self.on_frame(self, "status", str(exc).encode()) async def _pump(self) -> None: assert self.socket is not None loop = asyncio.get_running_loop() try: while True: data = await loop.sock_recv(self.socket, 65535) if not data: break await self.on_frame(self, "data", data) except Exception: pass finally: await self.on_frame(self, "close", None) async def send(self, data: bytes) -> None: if self.closed or self.socket is None: return await asyncio.get_running_loop().sock_sendall(self.socket, data) async def close(self) -> None: if self.closed: return self.closed = True if self.read_task and self.read_task is not asyncio.current_task(): self.read_task.cancel() with contextlib.suppress(Exception): await self.read_task if self.socket: self.socket.close() class RelayUdpPath(BasePath): def __init__(self, name: str, on_frame: Callable[[BasePath, str, bytes | None], Awaitable[None]], connection: RelayConnection, session_id: int, stream_id: int, target: TargetAddress) -> None: super().__init__(name, on_frame) self.connection = connection self.session_id = session_id self.stream_id = stream_id self.target = target async def open(self, _target: TargetAddress) -> None: if self.connection.closed: await self.on_frame(self, "status", b"relay unavailable") return self.connection.bind(self.session_id, self.stream_id, self._handle_frame) self.opened = True await self.on_frame(self, "status", b"ok") async def _handle_frame(self, _conn: RelayConnection, frame: Frame) -> None: if frame.kind == UDP_RECV: await self.on_frame(self, "data", frame.payload) async def send(self, data: bytes) -> None: if self.closed or self.connection.closed: return meta = encode_json({"host": self.target.host, "port": self.target.port, "family": self.target.family}) payload = meta + data await self.connection.send(Frame(UDP_SEND, self.session_id, self.stream_id, 0, len(meta), payload)) async def close(self) -> None: if self.closed: return self.closed = True self.connection.unbind(self.session_id, self.stream_id) @dataclass class UdpFlow: flow_id: int source: PeerAddress target: TargetAddress send_response: Callable[[PeerAddress, bytes], Awaitable[None]] paths: list[BasePath] winner: BasePath | None = None closed: bool = False last_activity: float = 0.0 async def start(self) -> None: await asyncio.gather(*(path.open(self.target) for path in self.paths), return_exceptions=True) async def send(self, payload: bytes) -> None: self.last_activity = asyncio.get_running_loop().time() active = [path for path in self.paths if path.opened and not path.closed] if self.winner is None: await asyncio.gather(*(path.send(payload) for path in active), return_exceptions=True) elif not self.winner.closed: await self.winner.send(payload) async def handle_path(self, path: BasePath, event: str, payload: bytes | None) -> None: self.last_activity = asyncio.get_running_loop().time() if event == "data" and payload is not None: if self.winner is None: self.winner = path print(f"[edge] udp flow={self.flow_id} winner={path.name} target={self.target.host}:{self.target.port}") if path is self.winner: await self.send_response(self.source, payload) if event == "close": path.closed = True async def close(self) -> None: if self.closed: return self.closed = True await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True) class TransparentUdpListener: def __init__(self, edge: "TransparentEdge", family: int, bind_host: str, port: int) -> None: self.edge = edge self.family = family self.bind_host = bind_host self.port = port self.socket: socket.socket | None = None def start(self) -> None: sock = socket.socket(self.family, socket.SOCK_DGRAM) sock.setblocking(False) if self.family == socket.AF_INET: sock.setsockopt(socket.SOL_IP, IP_RECVORIGDSTADDR, 1) sock.bind((self.bind_host, self.port)) else: sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) sock.setsockopt(socket.IPPROTO_IPV6, IPV6_RECVORIGDSTADDR, 1) sock.bind((self.bind_host, self.port, 0, 0)) self.socket = sock asyncio.get_running_loop().add_reader(sock.fileno(), self._on_readable) print(f"[edge] transparent udp listening on {sock.getsockname()}") def _on_readable(self) -> None: assert self.socket is not None try: data, ancdata, _flags, src = self.socket.recvmsg(65535, 512) except BlockingIOError: return except Exception as exc: print(f"[edge] udp recv failed family={self.family} error={exc!r}") return original = None for level, ctype, cdata in ancdata: if self.family == socket.AF_INET and level == socket.SOL_IP and ctype == IP_RECVORIGDSTADDR: original = parse_sockaddr(cdata) break if self.family == socket.AF_INET6 and level == socket.IPPROTO_IPV6 and ctype == IPV6_RECVORIGDSTADDR: original = parse_sockaddr(cdata) break if original is None: print(f"[edge] udp missing original dst family={self.family} src={src}") return if self.family == socket.AF_INET: source = PeerAddress(host=src[0], port=src[1], family=socket.AF_INET) else: source = PeerAddress(host=src[0], port=src[1], family=socket.AF_INET6) asyncio.create_task(self.edge.handle_udp_datagram(source, original, data, self)) async def send_response(self, source: PeerAddress, payload: bytes) -> None: assert self.socket is not None if source.family == socket.AF_INET: self.socket.sendto(payload, (source.host, source.port)) else: self.socket.sendto(payload, (source.host, source.port, 0, 0)) async def close(self) -> None: if self.socket is None: return asyncio.get_running_loop().remove_reader(self.socket.fileno()) self.socket.close() self.socket = None class TransparentEdge: def __init__(self, listen_host: str, listen_port: int, config: Config) -> None: self.listen_host = listen_host self.listen_port = listen_port self.config = config self.manager = RelayManager(config) self.session_ids = itertools.count(1) self.stream_ids = itertools.count(1) self.udp_listeners: list[TransparentUdpListener] = [] self.udp_flows: dict[tuple[PeerAddress, TargetAddress], UdpFlow] = {} self.udp_flow_ids = itertools.count(1) self.udp_gc_task: asyncio.Task | None = None async def start(self) -> None: await self.manager.start() print(f"[edge] relay snapshot: {self.manager.snapshot()}") server4 = await asyncio.start_server(self._accept, self.listen_host, self.listen_port, family=socket.AF_INET) sockets = [str(sock.getsockname()) for sock in server4.sockets or []] server6 = None if self.listen_host in ("::", "::1", "0.0.0.0", "127.0.0.1"): host6 = "::1" if self.listen_host == "127.0.0.1" else "::" try: server6 = await asyncio.start_server(self._accept, host6, self.listen_port, family=socket.AF_INET6) sockets.extend(str(sock.getsockname()) for sock in server6.sockets or []) except Exception as exc: print(f"[edge] ipv6 tcp listener skipped: {exc!r}") self._start_udp_listeners() self.udp_gc_task = asyncio.create_task(self._gc_udp_flows()) print(f"[edge] transparent tcp listening on {', '.join(sockets)}") if server6 is None: async with server4: await server4.serve_forever() else: async with server4, server6: await asyncio.gather(server4.serve_forever(), server6.serve_forever()) def _start_udp_listeners(self) -> None: binds = [] if self.listen_host == "127.0.0.1": binds = [(socket.AF_INET, "127.0.0.1"), (socket.AF_INET6, "::1")] elif self.listen_host == "0.0.0.0": binds = [(socket.AF_INET, "0.0.0.0"), (socket.AF_INET6, "::")] else: family = socket.AF_INET6 if ":" in self.listen_host else socket.AF_INET binds = [(family, self.listen_host)] for family, host in binds: try: listener = TransparentUdpListener(self, family, host, self.listen_port) listener.start() self.udp_listeners.append(listener) except Exception as exc: print(f"[edge] udp listener skipped family={family} host={host} error={exc!r}") async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: peer = writer.get_extra_info("peername") try: target = self._get_original_dst(writer) session_id = next(self.session_ids) session = TransparentSession(session_id=session_id, target=target, reader=reader, writer=writer, paths=[], warmup_bytes=self.config.tcp_warmup_bytes) paths: list[BasePath] = [DirectTcpPath(name="direct", on_frame=lambda path, event, payload, s=session: self._handle_tcp_session(s, path, event, payload))] for connection in self.manager.available(): stream_id = next(self.stream_ids) paths.append(RelayTcpPath(name=connection.node.name, on_frame=lambda path, event, payload, s=session: self._handle_tcp_session(s, path, event, payload), connection=connection, session_id=session_id, stream_id=stream_id)) session.paths = paths print(f"[edge] accept peer={peer} session={session_id} target={target.host}:{target.port} candidates={[path.name for path in paths]}") await session.start() except Exception as exc: print(f"[edge] accept failed peer={peer} error={exc!r}") writer.close() with contextlib.suppress(Exception): await writer.wait_closed() async def _handle_tcp_session(self, session: TransparentSession, path: BasePath, event: str, payload: bytes | None) -> None: await session.handle_path(path, event, payload) def _get_original_dst(self, writer: asyncio.StreamWriter) -> TargetAddress: sock = writer.get_extra_info("socket") if sock is None: raise RuntimeError("socket unavailable") family = sock.family if family == socket.AF_INET: raw = sock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, 16) return parse_sockaddr(raw) if family == socket.AF_INET6: raw = sock.getsockopt(socket.IPPROTO_IPV6, IP6T_SO_ORIGINAL_DST, 128) return parse_sockaddr(raw) raise RuntimeError(f"unsupported socket family={family}") async def handle_udp_datagram(self, source: PeerAddress, target: TargetAddress, payload: bytes, listener: TransparentUdpListener) -> None: key = (source, target) flow = self.udp_flows.get(key) if flow is None: flow_id = next(self.udp_flow_ids) paths: list[BasePath] = [DirectUdpPath(name="direct", on_frame=lambda path, event, data, fid=flow_id: self._handle_udp_path(fid, path, event, data), target=target)] for connection in self.manager.available(): stream_id = next(self.stream_ids) paths.append(RelayUdpPath(name=connection.node.name, on_frame=lambda path, event, data, fid=flow_id: self._handle_udp_path(fid, path, event, data), connection=connection, session_id=flow_id, stream_id=stream_id, target=target)) flow = UdpFlow(flow_id=flow_id, source=source, target=target, send_response=listener.send_response, paths=paths) self.udp_flows[key] = flow print(f"[edge] udp flow={flow_id} target={target.host}:{target.port} candidates={[path.name for path in paths]}") await flow.start() await flow.send(payload) async def _handle_udp_path(self, flow_id: int, path: BasePath, event: str, payload: bytes | None) -> None: for flow in list(self.udp_flows.values()): if flow.flow_id == flow_id: await flow.handle_path(path, event, payload) break async def _gc_udp_flows(self) -> None: loop = asyncio.get_running_loop() while True: await asyncio.sleep(30) now = loop.time() stale = [key for key, flow in self.udp_flows.items() if flow.last_activity and now - flow.last_activity > 120] for key in stale: flow = self.udp_flows.pop(key, None) if flow: await flow.close()