from __future__ import annotations import asyncio import contextlib import itertools import socket import struct from dataclasses import dataclass, field from typing import Dict from .config import Config, RelayNode from .protocol import AUTH, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, decode_json, encode_json, read_frame, write_frame from .scheduler import Scheduler SOCKS_VERSION = 5 async def read_exact(reader: asyncio.StreamReader, size: int) -> bytes: return await reader.readexactly(size) @dataclass(eq=False) class RelayLink: node: RelayNode reader: asyncio.StreamReader writer: asyncio.StreamWriter pump: asyncio.Task | None = None closed_event: asyncio.Event = field(default_factory=asyncio.Event) maintain_task: asyncio.Task | None = None tcp_sessions: Dict[tuple[int, int], "TcpRaceSession"] = field(default_factory=dict) udp_server: "UdpAssociateServer | None" = None closed: bool = False async def start(self) -> None: await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token}))) frame = await read_frame(self.reader) if frame.kind != AUTH or frame.packet_id != STATUS_OK: raise ConnectionError(f"relay auth failed: {self.node.name}") self.closed = False self.closed_event.clear() self.pump = asyncio.create_task(self._pump()) async def _pump(self) -> None: try: while True: frame = await read_frame(self.reader) key = (frame.session_id, frame.stream_id) if frame.kind in (TCP_STATUS, TCP_DATA, TCP_CLOSE): session = self.tcp_sessions.get(key) if session: await session.handle_frame(self, frame) elif frame.kind == UDP_RECV and self.udp_server: await self.udp_server.handle_from_relay(frame, self) except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, OSError): pass except Exception: pass finally: await self.close() async def send(self, frame: Frame) -> None: if self.closed: raise ConnectionError(f"relay closed: {self.node.name}") try: await write_frame(self.writer, frame) except (BrokenPipeError, ConnectionResetError, RuntimeError, OSError, asyncio.CancelledError) as exc: await self.close() raise ConnectionError(f"relay closed: {self.node.name}") from exc async def close(self) -> None: if self.closed: return self.closed = True self.closed_event.set() if self.pump and self.pump is not asyncio.current_task(): self.pump.cancel() with contextlib.suppress(Exception): await self.pump self.writer.close() with contextlib.suppress(Exception): await self.writer.wait_closed() @dataclass class UdpFlowState: flow_id: int client_addr: tuple[str, int] target_host: str target_port: int created_at: float last_activity: float packets_sent: int = 0 packets_received: int = 0 duplicate_responses: int = 0 winner_name: str | None = None candidate_names: tuple[str, ...] = () link_streams: dict[str, int] = field(default_factory=dict) initialized_links: set[str] = field(default_factory=set) direct_sockets: dict[str, socket.socket] = field(default_factory=dict) direct_tasks: dict[str, asyncio.Task] = field(default_factory=dict) direct_failures: set[str] = field(default_factory=set) relay_failures: dict[str, int] = field(default_factory=dict) relay_error_seen: set[str] = field(default_factory=set) path_last_seen: dict[str, float] = field(default_factory=dict) def touch(self, now: float) -> None: self.last_activity = now @dataclass class TcpRaceSession: session_id: int stream_id: int target_host: str target_port: int local_reader: asyncio.StreamReader local_writer: asyncio.StreamWriter links: list[RelayLink] warmup_bytes: int winning_link: RelayLink | None = None winner_name: str | None = None opened: int = 0 open_errors: list[str] = field(default_factory=list) uplink_bytes: int = 0 closed: bool = False open_event: asyncio.Event = field(default_factory=asyncio.Event) winner_event: asyncio.Event = field(default_factory=asyncio.Event) pump_task: asyncio.Task | None = None win_counts: Dict[str, int] = field(default_factory=dict) async def start(self) -> None: meta = encode_json({"host": self.target_host, "port": self.target_port}) for link in self.links: link.tcp_sessions[(self.session_id, self.stream_id)] = self await link.send(Frame(TCP_OPEN, self.session_id, self.stream_id, 0, 0, meta)) await asyncio.wait_for(self.open_event.wait(), timeout=10) if self.opened == 0: raise ConnectionError(self.open_errors[0] if self.open_errors else "all relays failed") self.pump_task = asyncio.create_task(self._pump_local()) async def _pump_local(self) -> None: try: while True: chunk = await self.local_reader.read(65536) if not chunk: break self.uplink_bytes += len(chunk) if self.winning_link is None and self.uplink_bytes <= self.warmup_bytes: await asyncio.gather(*(link.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk)) for link in self.links if not link.closed), return_exceptions=True) else: if self.winning_link is None: await self.winner_event.wait() if self.winning_link: await self.winning_link.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk)) except Exception: pass finally: await self.close() async def handle_frame(self, link: RelayLink, frame: Frame) -> None: if self.closed: return if frame.kind == TCP_STATUS: if frame.packet_id == STATUS_OK: self.opened += 1 else: self.open_errors.append(frame.payload.decode("utf-8", errors="replace")) if self.opened > 0 or len(self.open_errors) == len(self.links): self.open_event.set() return if frame.kind == TCP_DATA: if self.winning_link is None: self.winning_link = link self.winner_name = link.node.name self.win_counts[link.node.name] = self.win_counts.get(link.node.name, 0) + 1 node_total = self.win_counts[link.node.name] relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.win_counts.items())) or "none" print(f"[edge] tcp win session={self.session_id} target={self.target_host}:{self.target_port} winner={link.node.name} node_total={node_total} win_breakdown={relay_detail}") self.winner_event.set() await self._close_losers(except_link=link) if link is self.winning_link: self.local_writer.write(frame.payload) await self.local_writer.drain() return if frame.kind == TCP_CLOSE: if self.winning_link is None: self.winning_link = link self.winner_event.set() if link is self.winning_link: await self.close() async def _close_losers(self, except_link: RelayLink) -> None: await asyncio.gather(*(link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b"")) for link in self.links if link is not except_link and not link.closed), return_exceptions=True) async def close(self) -> None: if self.closed: return self.closed = True if self.pump_task and self.pump_task is not asyncio.current_task(): self.pump_task.cancel() with contextlib.suppress(Exception): await self.pump_task await asyncio.gather(*(link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b"")) for link in self.links if not link.closed), return_exceptions=True) for link in self.links: link.tcp_sessions.pop((self.session_id, self.stream_id), None) self.local_writer.close() with contextlib.suppress(Exception): await self.local_writer.wait_closed() class UdpAssociateServer(asyncio.DatagramProtocol): def __init__(self, edge: "SocksEdge") -> None: self.edge = edge self.transport: asyncio.DatagramTransport | None = None self.client_addr = None self.associate_peer = None self.packet_counter = itertools.count(1) self.client_flows: dict[tuple[tuple[str, int], str, int], UdpFlowState] = {} self.flow_counter = itertools.count(1) self.last_summary_at = 0.0 self.win_counts: Dict[str, int] = {} self.relay_error_counts: Dict[str, int] = {} def connection_made(self, transport) -> None: self.transport = transport def register_associate(self, peer) -> None: peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer) if self.associate_peer != peer_text: print(f"[edge] udp associate peer={peer_text}") self.associate_peer = peer_text def datagram_received(self, data: bytes, addr) -> None: if len(data) < 10: return if self.client_addr is None: self.client_addr = addr print(f"[edge] udp client bound addr={addr[0]}:{addr[1]}") elif addr != self.client_addr: print(f"[edge] udp client rebound old={self.client_addr[0]}:{self.client_addr[1]} new={addr[0]}:{addr[1]}") self._reset_client_state(addr) host, port, payload = self._parse_socks_udp(data) loop = asyncio.get_running_loop() now = loop.time() flow_key = ((addr[0], addr[1]), host, port) flow = self.client_flows.get(flow_key) if flow is None: flow = UdpFlowState( flow_id=next(self.flow_counter), client_addr=(addr[0], addr[1]), target_host=host, target_port=port, created_at=now, last_activity=now, ) self.client_flows[flow_key] = flow flow.touch(now) flow.packets_sent += 1 packet_id = next(self.packet_counter) asyncio.create_task(self.edge.forward_udp(flow, payload, packet_id, self)) self._log_udp_summary() def _reset_client_state(self, addr) -> None: old_addr = self.client_addr remapped_flows: dict[tuple[tuple[str, int], str, int], UdpFlowState] = {} for flow in list(self.client_flows.values()): flow.client_addr = (addr[0], addr[1]) remapped_flows[((addr[0], addr[1]), flow.target_host, flow.target_port)] = flow self.client_flows = remapped_flows self.client_addr = addr print(f"[edge] udp client rebound migrated old={old_addr[0]}:{old_addr[1]} new={addr[0]}:{addr[1]} flows={len(self.client_flows)}") async def handle_from_relay(self, frame: Frame, link: RelayLink) -> None: if self.transport is None or self.client_addr is None: return flow = self.edge.udp_flow_sessions.get((frame.session_id, frame.stream_id)) if flow is None: return await self._deliver_flow_packet(flow, frame.packet_id, frame.payload, link.node.name) async def handle_from_direct(self, flow: UdpFlowState, path_name: str, payload: bytes) -> None: if self.transport is None or self.client_addr is None: return await self._deliver_flow_packet(flow, 0, payload, path_name) async def _deliver_flow_packet(self, flow: UdpFlowState, packet_id: int, payload: bytes, source_name: str) -> None: if self.transport is None or self.client_addr is None: return packet = self._build_socks_udp(flow.target_host, flow.target_port, payload) now = asyncio.get_running_loop().time() flow.touch(now) flow.path_last_seen[source_name] = now flow.packets_received += 1 if flow.winner_name is None: flow.winner_name = source_name self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1 self._log_udp_summary(force=True) elif flow.winner_name != source_name: flow.duplicate_responses += 1 winner_last_seen = flow.path_last_seen.get(flow.winner_name, 0.0) if winner_last_seen and now - winner_last_seen >= (self.edge.config.udp_failover_idle_ms / 1000): flow.winner_name = source_name self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1 self._log_udp_summary(force=True) if flow.winner_name == source_name: self.transport.sendto(packet, self.client_addr) def set_flow_candidates(self, flow: UdpFlowState, candidate_names: tuple[str, ...]) -> None: if not flow.candidate_names: flow.candidate_names = candidate_names def note_unsent(self, flow: UdpFlowState, packet_id: int) -> None: flow.touch(asyncio.get_running_loop().time()) flow.relay_failures["unsent"] = flow.relay_failures.get("unsent", 0) + 1 self._log_udp_summary(force=True) def _log_udp_summary(self, force: bool = False) -> None: now = asyncio.get_running_loop().time() if not force and now - self.last_summary_at < 10: return self.last_summary_at = now active_flows = len(self.client_flows) winners = sum(1 for flow in self.client_flows.values() if flow.winner_name) packets_sent = sum(flow.packets_sent for flow in self.client_flows.values()) packets_received = sum(flow.packets_received for flow in self.client_flows.values()) duplicates = sum(flow.duplicate_responses for flow in self.client_flows.values()) direct_paths = sum(len(flow.direct_sockets) for flow in self.client_flows.values()) relay_candidates = sum(len(flow.link_streams) for flow in self.client_flows.values()) candidate_names: list[str] = [] seen_candidates: set[str] = set() for flow in sorted(self.client_flows.values(), key=lambda item: item.flow_id): for name in flow.candidate_names: if name in seen_candidates: continue seen_candidates.add(name) candidate_names.append(name) direct_wins = sum(1 for flow in self.client_flows.values() if flow.winner_name and flow.winner_name.startswith("direct")) relay_wins = winners - direct_wins sample_flows = [ f"{flow.flow_id}:{flow.winner_name or 'pending'}" for flow in sorted(self.client_flows.values(), key=lambda item: item.flow_id) if flow.winner_name ][:5] winner_detail = ", ".join(sample_flows) or "none" relay_errors: list[str] = [] for flow in self.client_flows.values(): for name, count in flow.relay_failures.items(): relay_errors.append(f"{name}={count}") relay_error_detail = ", ".join(sorted(relay_errors)) or "none" if self.client_addr: print( f"[edge] udp summary bind={self.client_addr[0]}:{self.client_addr[1]} flows={active_flows} winners={winners} " f"winner_breakdown=direct={direct_wins},relay={relay_wins} sample={winner_detail} " f"candidates={candidate_names or ['none']} " f"sent={packets_sent} recv={packets_received} dup={duplicates} " f"direct_paths={direct_paths} relay_paths={relay_candidates} relay_errors={relay_error_detail}" ) else: print( f"[edge] udp summary bind=unbound flows={active_flows} winners={winners} " f"winner_breakdown=direct={direct_wins},relay={relay_wins} sample={winner_detail} " f"candidates={candidate_names or ['none']} " f"sent={packets_sent} recv={packets_received} dup={duplicates} " f"direct_paths={direct_paths} relay_paths={relay_candidates} relay_errors={relay_error_detail}" ) def _parse_socks_udp(self, packet: bytes) -> tuple[str, int, bytes]: atyp = packet[3] offset = 4 if atyp == 1: host = socket.inet_ntoa(packet[offset:offset + 4]) offset += 4 elif atyp == 3: size = packet[offset] offset += 1 host = packet[offset:offset + size].decode() offset += size else: raise ValueError("unsupported udp atyp") port = struct.unpack("!H", packet[offset:offset + 2])[0] offset += 2 return host, port, packet[offset:] def _build_socks_udp(self, host: str, port: int, payload: bytes) -> bytes: try: addr = socket.inet_aton(host) header = b"\x00\x00\x00\x01" + addr + struct.pack("!H", port) except OSError: raw = host.encode() header = b"\x00\x00\x00\x03" + bytes([len(raw)]) + raw + struct.pack("!H", port) return header + payload class SocksEdge: def __init__(self, listen_host: str, listen_port: int, config: Config) -> None: self.listen_host = listen_host self.listen_port = listen_port self.config = config self.scheduler = Scheduler(config) self.links: list[RelayLink] = [] self.session_ids = itertools.count(1) self.udp_stream_ids = itertools.count(1) self.udp_flow_sessions: dict[tuple[int, int], UdpFlowState] = {} self.udp_server: UdpAssociateServer | None = None async def start(self) -> None: await self.scheduler.start() await self._connect_relays() server = await asyncio.start_server(self._accept, self.listen_host, self.listen_port) sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or []) print(f"[edge] socks5 listening on {sockets}") async with server: await server.serve_forever() async def _connect_relays(self) -> None: loop = asyncio.get_running_loop() transport, protocol = await loop.create_datagram_endpoint(lambda: UdpAssociateServer(self), local_addr=(self.listen_host, 0)) self.udp_server = protocol self.udp_transport = transport for node in self.config.relays: link = RelayLink(node=node, reader=None, writer=None) # type: ignore[arg-type] link.udp_server = protocol self.links.append(link) link.maintain_task = asyncio.create_task(self._maintain_link(link)) async def _maintain_link(self, link: RelayLink) -> None: backoff = 1.0 while True: try: reader, writer = await asyncio.open_connection(link.node.host, link.node.port) sock = writer.get_extra_info("socket") if sock is not None: with contextlib.suppress(OSError): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) link.reader = reader link.writer = writer await link.start() backoff = 1.0 await link.closed_event.wait() except asyncio.CancelledError: raise except Exception: await asyncio.sleep(backoff) backoff = min(10.0, backoff * 2) async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: try: peer = writer.get_extra_info("peername") _host, _port, udp_mode = await self._handshake(reader, writer, peer) if udp_mode: return except Exception: writer.close() with contextlib.suppress(Exception): await writer.wait_closed() def _selected_links(self) -> list[RelayLink]: chosen = {node.name for node in self.scheduler.choose()} links = [link for link in self.links if link.node.name in chosen and not link.closed] return links or [link for link in self.links if not link.closed][:1] def _selected_udp_links(self) -> list[RelayLink]: online = [link for link in self.links if not link.closed and link.writer is not None] if not online: return [] ordered = sorted(online, key=lambda link: self.scheduler.scores.get(link.node.name).score if link.node.name in self.scheduler.scores else 999999.0) return ordered def _udp_direct_redundancy_for_target(self, target_host: str) -> int: base = self.config.udp_direct_redundancy if ":" in target_host and self.config.udp_direct_redundancy_v6 is not None: base = self.config.udp_direct_redundancy_v6 elif ":" not in target_host and self.config.udp_direct_redundancy_v4 is not None: base = self.config.udp_direct_redundancy_v4 return max(1, base) async def _ensure_udp_direct_paths(self, flow: UdpFlowState, udp_server: UdpAssociateServer) -> None: target_count = self._udp_direct_redundancy_for_target(flow.target_host) for index in range(target_count): name = f"direct-{index + 1}" if target_count > 1 else "direct" if name in flow.direct_sockets or name in flow.direct_failures: continue try: family = socket.AF_INET6 if ":" in flow.target_host else socket.AF_INET sock = socket.socket(family, socket.SOCK_DGRAM) sock.setblocking(False) await asyncio.get_running_loop().sock_connect(sock, (flow.target_host, flow.target_port)) flow.direct_sockets[name] = sock flow.direct_tasks[name] = asyncio.create_task(self._pump_udp_direct(flow, name, sock, udp_server)) except Exception as exc: flow.direct_failures.add(name) print(f"[edge] udp direct open error flow={flow.flow_id} path={name} target={flow.target_host}:{flow.target_port} error={exc!r}") async def _pump_udp_direct(self, flow: UdpFlowState, path_name: str, sock: socket.socket, udp_server: UdpAssociateServer) -> None: loop = asyncio.get_running_loop() try: while True: data = await loop.sock_recv(sock, 65535) if not data: break await udp_server.handle_from_direct(flow, path_name, data) except Exception: pass finally: flow.direct_tasks.pop(path_name, None) flow.direct_sockets.pop(path_name, None) with contextlib.suppress(Exception): sock.close() async def forward_udp(self, flow: UdpFlowState, payload: bytes, packet_id: int, udp_server: UdpAssociateServer) -> None: await self._ensure_udp_direct_paths(flow, udp_server) meta = encode_json({"host": flow.target_host, "port": flow.target_port}) links = self._selected_udp_links() direct_names = tuple(name for name in sorted(flow.direct_sockets)) relay_names = tuple(link.node.name for link in links) candidate_names = direct_names + relay_names udp_server.set_flow_candidates(flow, candidate_names) if not candidate_names: udp_server.note_unsent(flow, packet_id) return active_direct_names = list(direct_names) active_links = links if not (self.config.udp_always_broadcast or flow.winner_name is None): winner_last_seen = flow.path_last_seen.get(flow.winner_name, 0.0) if flow.winner_name else 0.0 if winner_last_seen and asyncio.get_running_loop().time() - winner_last_seen >= (self.config.udp_failover_idle_ms / 1000): flow.winner_name = None active_direct_names = [name for name in active_direct_names if name == flow.winner_name] active_links = [link for link in active_links if link.node.name == flow.winner_name] if not active_direct_names and not active_links: if direct_names: active_direct_names = [direct_names[0]] elif links: active_links = links[:1] copies = max(1, self.config.udp_redundancy + 1) sent_any = False for attempt in range(copies): for path_name in active_direct_names: sock = flow.direct_sockets.get(path_name) if sock is None: continue try: await asyncio.get_running_loop().sock_sendall(sock, payload) sent_any = True except Exception as exc: flow.direct_failures.add(path_name) flow.direct_sockets.pop(path_name, None) task = flow.direct_tasks.pop(path_name, None) if task is not None: task.cancel() with contextlib.suppress(Exception): sock.close() flow.relay_failures[path_name] = flow.relay_failures.get(path_name, 0) + 1 if path_name not in flow.relay_error_seen: flow.relay_error_seen.add(path_name) print( f"[edge] udp relay error flow={flow.flow_id} relay={path_name} error={exc!r}" ) for link in active_links: stream_id = flow.link_streams.get(link.node.name) if stream_id is None: stream_id = next(self.udp_stream_ids) flow.link_streams[link.node.name] = stream_id self.udp_flow_sessions[(flow.flow_id, stream_id)] = flow include_meta = link.node.name not in flow.initialized_links body = (meta + payload) if include_meta else payload meta_len = len(meta) if include_meta else 0 try: await link.send(Frame(UDP_SEND, flow.flow_id, stream_id, 0, meta_len, body)) flow.initialized_links.add(link.node.name) sent_any = True except Exception as exc: flow.link_streams.pop(link.node.name, None) self.udp_flow_sessions.pop((flow.flow_id, stream_id), None) flow.relay_failures[link.node.name] = flow.relay_failures.get(link.node.name, 0) + 1 if link.node.name not in flow.relay_error_seen: flow.relay_error_seen.add(link.node.name) print( f"[edge] udp relay error flow={flow.flow_id} relay={link.node.name} error={exc!r}" ) if attempt + 1 < copies and self.config.udp_copy_interval_ms > 0: await asyncio.sleep(self.config.udp_copy_interval_ms / 1000) if not sent_any: udp_server.note_unsent(flow, packet_id) async def _handshake(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer) -> tuple[str, int, bool]: version, methods_len = (await read_exact(reader, 2)) if version != SOCKS_VERSION: raise ValueError("unsupported socks version") await read_exact(reader, methods_len) writer.write(b"\x05\x00") await writer.drain() version, command, _, atyp = await read_exact(reader, 4) if version != SOCKS_VERSION: raise ValueError("unsupported socks version") if atyp == 1: host = socket.inet_ntoa(await read_exact(reader, 4)) elif atyp == 3: size = (await read_exact(reader, 1))[0] host = (await read_exact(reader, size)).decode() else: raise ValueError("unsupported atyp") port = struct.unpack("!H", await read_exact(reader, 2))[0] peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer) if command == 1: print(f"[edge] socks handshake peer={peer_text} command=connect target={host}:{port}") writer.write(b"\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00") await writer.drain() return host, port, False if command == 3 and self.udp_server and self.udp_server.transport: bind_host, bind_port = self.udp_server.transport.get_extra_info("sockname")[:2] self.udp_server.register_associate(peer) print(f"[edge] socks handshake peer={peer_text} command=udp_associate target={host}:{port} bind={bind_host}:{bind_port}") writer.write(b"\x05\x00\x00\x01" + socket.inet_aton(bind_host) + struct.pack("!H", bind_port)) await writer.drain() return host, port, True raise ValueError("unsupported socks command")