"""SOCKS5 edge proxy that fans client traffic out over authenticated relay links.

TCP CONNECT sessions are opened on every selected relay at once; the first
relay that returns data (or a close) becomes the winner and the other relays
are told to close. UDP ASSOCIATE datagrams are forwarded through the relays
and the replies are re-wrapped as SOCKS5 UDP packets for the bound client.
"""

from __future__ import annotations

import asyncio
import contextlib
import itertools
import socket
import struct
from dataclasses import dataclass, field
from typing import Dict

from .config import Config, RelayNode
from .protocol import (
    AUTH,
    STATUS_OK,
    TCP_CLOSE,
    TCP_DATA,
    TCP_OPEN,
    TCP_STATUS,
    UDP_RECV,
    UDP_SEND,
    Frame,
    decode_json,
    encode_json,
    read_frame,
    write_frame,
)
from .scheduler import Scheduler

SOCKS_VERSION = 5


async def read_exact(reader: asyncio.StreamReader, size: int) -> bytes:
    """Read exactly ``size`` bytes from ``reader``.

    Raises:
        asyncio.IncompleteReadError: if the stream ends before ``size`` bytes.
    """
    return await reader.readexactly(size)


@dataclass(eq=False)
class RelayLink:
    """One authenticated connection to a relay node, with its receive pump."""

    node: RelayNode
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    # Background task running ``_pump``; created by ``start``.
    pump: asyncio.Task | None = None
    # Sessions registered on this link, keyed by (session_id, stream_id).
    tcp_sessions: Dict[tuple[int, int], "TcpRaceSession"] = field(default_factory=dict)
    udp_server: "UdpAssociateServer | None" = None
    closed: bool = False

    async def start(self) -> None:
        """Authenticate with the relay and start the receive pump.

        Raises:
            ConnectionError: if the relay does not answer with AUTH/STATUS_OK.
        """
        await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
        frame = await read_frame(self.reader)
        if frame.kind != AUTH or frame.packet_id != STATUS_OK:
            raise ConnectionError(f"relay auth failed: {self.node.name}")
        self.pump = asyncio.create_task(self._pump())

    async def _pump(self) -> None:
        """Dispatch frames from the relay until the connection ends."""
        try:
            while True:
                frame = await read_frame(self.reader)
                key = (frame.session_id, frame.stream_id)
                if frame.kind in (TCP_STATUS, TCP_DATA, TCP_CLOSE):
                    session = self.tcp_sessions.get(key)
                    if session is not None:
                        await session.handle_frame(self, frame)
                elif frame.kind == UDP_RECV and self.udp_server:
                    await self.udp_server.handle_from_relay(frame, self)
        except asyncio.IncompleteReadError:
            pass
        except OSError as exc:
            # Resets and broken pipes are ordinary link loss, not a task crash.
            print(f"[edge] relay link lost name={self.node.name} error={exc!r}")
        finally:
            await self.close()
            await self._release_sessions()

    async def _release_sessions(self) -> None:
        """Close sessions that cannot make progress now that this link is gone."""
        for session in list(self.tcp_sessions.values()):
            has_live_link = any(not peer.closed for peer in session.links)
            if session.winning_link is self or not has_live_link:
                # Without this the local client would wait forever on a dead relay.
                with contextlib.suppress(Exception):
                    await session.close()

    async def send(self, frame: Frame) -> None:
        """Write ``frame`` to the relay.

        Raises:
            ConnectionError: if the link has already been closed.
        """
        if self.closed:
            raise ConnectionError(f"relay closed: {self.node.name}")
        await write_frame(self.writer, frame)

    async def close(self) -> None:
        """Close the relay connection; safe to call more than once."""
        if self.closed:
            return
        self.closed = True
        self.writer.close()
        with contextlib.suppress(Exception):
            await self.writer.wait_closed()


@dataclass
class UdpFlowState:
    """Counters and winner bookkeeping for one (client, target) UDP flow."""

    flow_id: int
    client_addr: tuple[str, int]
    target_host: str
    target_port: int
    created_at: float
    last_activity: float
    packets_sent: int = 0
    packets_received: int = 0
    duplicate_responses: int = 0
    winner_name: str | None = None
    candidate_names: tuple[str, ...] = ()

    def touch(self, now: float) -> None:
        """Record ``now`` as the time of the latest activity on this flow."""
        self.last_activity = now


@dataclass
class TcpRaceSession:
    """A client TCP stream opened on several relays; the first to answer wins."""

    session_id: int
    stream_id: int
    target_host: str
    target_port: int
    local_reader: asyncio.StreamReader
    local_writer: asyncio.StreamWriter
    links: list[RelayLink]
    # Uplink bytes are mirrored to every link until this budget is exceeded.
    warmup_bytes: int
    winning_link: RelayLink | None = None
    winner_name: str | None = None
    opened: int = 0
    open_errors: list[str] = field(default_factory=list)
    uplink_bytes: int = 0
    closed: bool = False
    open_event: asyncio.Event = field(default_factory=asyncio.Event)
    winner_event: asyncio.Event = field(default_factory=asyncio.Event)
    pump_task: asyncio.Task | None = None
    win_counts: Dict[str, int] = field(default_factory=dict)

    async def start(self) -> None:
        """Open the target on every link and start pumping client data.

        Raises:
            ConnectionError: if there are no links or every relay reported an
                error (the first reported error is used as the message).
            asyncio.TimeoutError: if no relay answered within 10 seconds.
        """
        if not self.links:
            # Nothing could ever set open_event; fail now instead of after 10 s.
            raise ConnectionError("all relays failed")
        meta = encode_json({"host": self.target_host, "port": self.target_port})
        key = (self.session_id, self.stream_id)
        try:
            for link in self.links:
                link.tcp_sessions[key] = self
                await link.send(Frame(TCP_OPEN, self.session_id, self.stream_id, 0, 0, meta))
            await asyncio.wait_for(self.open_event.wait(), timeout=10)
            if self.opened == 0:
                raise ConnectionError(self.open_errors[0] if self.open_errors else "all relays failed")
        except BaseException:
            # Unregister from the links and tell the relays to drop the stream;
            # otherwise a failed open stays in every link's session table.
            await self.close()
            raise
        self.pump_task = asyncio.create_task(self._pump_local())

    async def _pump_local(self) -> None:
        """Forward client bytes upstream until EOF or an error, then close."""
        try:
            while True:
                chunk = await self.local_reader.read(65536)
                if not chunk:
                    break
                self.uplink_bytes += len(chunk)
                if self.winning_link is None and self.uplink_bytes <= self.warmup_bytes:
                    # No winner yet and still inside the warm-up budget:
                    # mirror the bytes to every live relay.
                    await asyncio.gather(
                        *(
                            link.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk))
                            for link in self.links
                            if not link.closed
                        ),
                        return_exceptions=True,
                    )
                else:
                    if self.winning_link is None:
                        await self.winner_event.wait()
                    if self.winning_link:
                        await self.winning_link.send(
                            Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk)
                        )
        except OSError:
            # Client or relay went away; the finally clause tears the session down.
            pass
        except Exception as exc:
            print(f"[edge] tcp uplink error session={self.session_id} error={exc!r}")
        finally:
            await self.close()

    async def handle_frame(self, link: RelayLink, frame: Frame) -> None:
        """Process a TCP_STATUS, TCP_DATA or TCP_CLOSE frame received on ``link``."""
        if self.closed:
            return
        if frame.kind == TCP_STATUS:
            if frame.packet_id == STATUS_OK:
                self.opened += 1
            else:
                self.open_errors.append(frame.payload.decode("utf-8", errors="replace"))
            if self.opened > 0 or len(self.open_errors) == len(self.links):
                self.open_event.set()
            return
        if frame.kind == TCP_DATA:
            if self.winning_link is None:
                self.winning_link = link
                self.winner_name = link.node.name
                self.win_counts[link.node.name] = self.win_counts.get(link.node.name, 0) + 1
                node_total = self.win_counts[link.node.name]
                relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.win_counts.items())) or "none"
                print(f"[edge] tcp win session={self.session_id} target={self.target_host}:{self.target_port} winner={link.node.name} node_total={node_total} win_breakdown={relay_detail}")
                self.winner_event.set()
                await self._close_losers(except_link=link)
            if link is self.winning_link and not self.closed:
                try:
                    self.local_writer.write(frame.payload)
                    await self.local_writer.drain()
                except OSError:
                    # A dead local client must end this session only; letting
                    # the error escape would take down the shared relay pump.
                    await self.close()
            return
        if frame.kind == TCP_CLOSE:
            if self.winning_link is None:
                self.winning_link = link
                self.winner_event.set()
            if link is self.winning_link:
                await self.close()

    async def _close_losers(self, except_link: RelayLink) -> None:
        """Send TCP_CLOSE to every live link other than ``except_link``."""
        await asyncio.gather(
            *(
                link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b""))
                for link in self.links
                if link is not except_link and not link.closed
            ),
            return_exceptions=True,
        )

    async def close(self) -> None:
        """Stop the uplink pump, notify the relays and close the client socket."""
        if self.closed:
            return
        self.closed = True
        if self.pump_task and self.pump_task is not asyncio.current_task():
            self.pump_task.cancel()
            # Awaiting a cancelled task raises CancelledError, which is a
            # BaseException and therefore not caught by suppress(Exception);
            # gather(return_exceptions=True) collects it so cleanup continues.
            await asyncio.gather(self.pump_task, return_exceptions=True)
        await asyncio.gather(
            *(
                link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b""))
                for link in self.links
                if not link.closed
            ),
            return_exceptions=True,
        )
        for link in self.links:
            link.tcp_sessions.pop((self.session_id, self.stream_id), None)
        self.local_writer.close()
        with contextlib.suppress(Exception):
            await self.local_writer.wait_closed()


class UdpAssociateServer(asyncio.DatagramProtocol):
    """Datagram endpoint for SOCKS5 UDP ASSOCIATE traffic of one bound client."""

    def __init__(self, edge: "SocksEdge") -> None:
        self.edge = edge
        self.transport: asyncio.DatagramTransport | None = None
        self.client_addr = None
        self.packet_counter = itertools.count(1)
        self.pending: set[int] = set()
        self.packet_flows: dict[int, int] = {}
        self.client_flows: dict[tuple[tuple[str, int], str, int], UdpFlowState] = {}
        self.flow_counter = itertools.count(1)
        self.last_summary_at = 0.0
        self.win_counts: Dict[str, int] = {}
        # flow_id -> flow, so per-packet lookups do not scan client_flows.
        self._flows_by_id: dict[int, UdpFlowState] = {}
        # Strong references keep in-flight forward tasks from being collected.
        self._forward_tasks: set[asyncio.Task] = set()

    def connection_made(self, transport) -> None:
        """Remember the datagram transport used to answer the client."""
        self.transport = transport

    def datagram_received(self, data: bytes, addr) -> None:
        """Parse a SOCKS5 UDP request from the client and schedule its forward.

        The first sender becomes the bound client; datagrams from any other
        address, fragmented datagrams and malformed datagrams are dropped.
        """
        if len(data) < 10:
            return
        if self.client_addr is None:
            self.client_addr = addr
            print(f"[edge] udp client bound addr={addr[0]}:{addr[1]}")
        if addr != self.client_addr:
            return
        if data[2] != 0:
            # FRAG != 0: fragmentation is not implemented, so drop the datagram.
            return
        try:
            host, port, payload = self._parse_socks_udp(data)
        except ValueError as exc:
            # Bad input must not raise out of the protocol callback.
            print(f"[edge] udp drop reason=malformed size={len(data)} error={exc}")
            return
        loop = asyncio.get_running_loop()
        now = loop.time()
        flow_key = ((addr[0], addr[1]), host, port)
        flow = self.client_flows.get(flow_key)
        if flow is None:
            flow = UdpFlowState(
                flow_id=next(self.flow_counter),
                client_addr=(addr[0], addr[1]),
                target_host=host,
                target_port=port,
                created_at=now,
                last_activity=now,
            )
            self.client_flows[flow_key] = flow
            self._flows_by_id[flow.flow_id] = flow
        flow.touch(now)
        flow.packets_sent += 1
        packet_id = next(self.packet_counter)
        self.pending.add(packet_id)
        self.packet_flows[packet_id] = flow.flow_id
        print(f"[edge] udp recv flow={flow.flow_id} packet_id={packet_id} target={host}:{port} size={len(payload)}")
        task = asyncio.create_task(self.edge.forward_udp(host, port, payload, packet_id, self))
        self._forward_tasks.add(task)
        task.add_done_callback(self._forward_tasks.discard)
        self._log_udp_summary()

    async def handle_from_relay(self, frame: Frame, link: RelayLink) -> None:
        """Deliver the first relay reply for a pending packet to the client.

        Replies for unknown or already-answered packet ids are ignored.
        """
        if frame.packet_id not in self.pending or self.transport is None or self.client_addr is None:
            return
        self.pending.discard(frame.packet_id)
        flow_id = self.packet_flows.pop(frame.packet_id, 0)
        # pop, not get: the packet is answered, so its target entry is finished.
        host, port = self.edge.udp_targets.pop(frame.packet_id, ("0.0.0.0", 0))
        packet = self._build_socks_udp(host, port, frame.payload)
        winner_log = ""
        flow = self._find_flow(flow_id)
        if flow is not None:
            now = asyncio.get_running_loop().time()
            flow.touch(now)
            flow.packets_received += 1
            if flow.winner_name is None:
                flow.winner_name = link.node.name
                self.win_counts[link.node.name] = self.win_counts.get(link.node.name, 0) + 1
                relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.win_counts.items())) or "none"
                print(
                    f"[edge] udp flow={flow.flow_id} winner={link.node.name} "
                    f"target={flow.target_host}:{flow.target_port} mode=single candidates={len(flow.candidate_names) or len(self.edge.links)}"
                )
                print(f"[edge] udp win relay_breakdown={relay_detail}")
            elif flow.winner_name != link.node.name:
                flow.duplicate_responses += 1
                winner_log = f" duplicate=1 winner={flow.winner_name} from={link.node.name}"
        print(
            f"[edge] udp send flow={flow_id or 'unknown'} packet_id={frame.packet_id} "
            f"target={host}:{port} size={len(frame.payload)} relay={link.node.name}{winner_log}"
        )
        self.transport.sendto(packet, self.client_addr)
        self._log_udp_summary()

    def set_flow_candidates(self, packet_id: int, candidate_names: tuple[str, ...]) -> None:
        """Record the relay names first chosen for the flow owning ``packet_id``."""
        flow_id = self.packet_flows.get(packet_id)
        flow = self._find_flow(flow_id)
        if flow is not None and not flow.candidate_names:
            flow.candidate_names = candidate_names

    def note_unsent(self, packet_id: int) -> None:
        """Forget a packet that could not be forwarded and log the drop."""
        flow_id = self.packet_flows.pop(packet_id, 0)
        self.pending.discard(packet_id)
        # No reply can arrive for this packet, so release its target entry too.
        self.edge.udp_targets.pop(packet_id, None)
        flow = self._find_flow(flow_id)
        if flow is not None:
            flow.touch(asyncio.get_running_loop().time())
        print(f"[edge] udp drop flow={flow_id or 'unknown'} packet_id={packet_id} reason=no_available_links")
        self._log_udp_summary(force=True)

    def _find_flow(self, flow_id: int | None) -> UdpFlowState | None:
        """Return the flow with ``flow_id``, or None for a falsy or unknown id."""
        if not flow_id:
            return None
        return self._flows_by_id.get(flow_id)

    def _log_udp_summary(self, force: bool = False) -> None:
        """Print aggregate flow counters, at most once per 10 s unless forced."""
        now = asyncio.get_running_loop().time()
        if not force and now - self.last_summary_at < 10:
            return
        self.last_summary_at = now
        flows = list(self.client_flows.values())
        active_flows = len(flows)
        winners = sum(1 for flow in flows if flow.winner_name)
        packets_sent = sum(flow.packets_sent for flow in flows)
        packets_received = sum(flow.packets_received for flow in flows)
        duplicates = sum(flow.duplicate_responses for flow in flows)
        bind = f"{self.client_addr[0]}:{self.client_addr[1]}" if self.client_addr else "unbound"
        print(
            f"[edge] udp summary bind={bind} active_flows={active_flows} "
            f"winner_flows={winners} packets_sent={packets_sent} packets_received={packets_received} dup={duplicates}"
        )

    def _parse_socks_udp(self, packet: bytes) -> tuple[str, int, bytes]:
        """Split a SOCKS5 UDP request into ``(host, port, payload)``.

        Only IPv4 (atyp 1) and domain-name (atyp 3) addresses are handled.

        Raises:
            ValueError: if the address type is unsupported, the packet is
                truncated, or a domain name cannot be decoded.
        """
        if len(packet) < 4:
            raise ValueError("truncated socks udp packet")
        atyp = packet[3]
        offset = 4
        if atyp == 1:
            size = 4
        elif atyp == 3:
            if len(packet) < 5:
                raise ValueError("truncated socks udp packet")
            size = packet[offset]
            offset += 1
        else:
            raise ValueError("unsupported udp atyp")
        end = offset + size
        # Address plus the two port bytes must be fully present.
        if len(packet) < end + 2:
            raise ValueError("truncated socks udp packet")
        raw_host = packet[offset:end]
        host = socket.inet_ntoa(raw_host) if atyp == 1 else raw_host.decode()
        (port,) = struct.unpack_from("!H", packet, end)
        return host, port, packet[end + 2:]

    def _build_socks_udp(self, host: str, port: int, payload: bytes) -> bytes:
        """Wrap ``payload`` in a SOCKS5 UDP header naming ``host:port``.

        ``host`` is encoded as IPv4 when ``socket.inet_aton`` accepts it and as
        a length-prefixed domain name otherwise.
        """
        try:
            addr = socket.inet_aton(host)
            header = b"\x00\x00\x00\x01" + addr + struct.pack("!H", port)
        except OSError:
            raw = host.encode()
            header = b"\x00\x00\x00\x03" + bytes([len(raw)]) + raw + struct.pack("!H", port)
        return header + payload


class SocksEdge:
    """SOCKS5 listener that maps client requests onto the configured relays."""

    def __init__(self, listen_host: str, listen_port: int, config: Config) -> None:
        self.listen_host = listen_host
        self.listen_port = listen_port
        self.config = config
        self.scheduler = Scheduler(config)
        self.links: list[RelayLink] = []
        self.session_ids = itertools.count(1)
        # packet_id -> (host, port) of UDP packets still awaiting a reply.
        self.udp_targets: dict[int, tuple[str, int]] = {}
        self.udp_server: UdpAssociateServer | None = None
        self.udp_transport: asyncio.DatagramTransport | None = None

    async def start(self) -> None:
        """Start the scheduler, connect the relays and serve SOCKS5 forever."""
        await self.scheduler.start()
        await self._connect_relays()
        server = await asyncio.start_server(self._accept, self.listen_host, self.listen_port)
        sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
        print(f"[edge] socks5 listening on {sockets}")
        async with server:
            await server.serve_forever()

    async def _connect_relays(self) -> None:
        """Connect and authenticate every configured relay, then bind the UDP endpoint."""
        for node in self.config.relays:
            reader, writer = await asyncio.open_connection(node.host, node.port)
            link = RelayLink(node, reader, writer)
            await link.start()
            self.links.append(link)
        loop = asyncio.get_running_loop()
        transport, protocol = await loop.create_datagram_endpoint(
            lambda: UdpAssociateServer(self), local_addr=(self.listen_host, 0)
        )
        self.udp_server = protocol
        for link in self.links:
            link.udp_server = protocol
        self.udp_transport = transport

    async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Handle one SOCKS5 client connection.

        CONNECT requests become a ``TcpRaceSession``; UDP ASSOCIATE requests
        keep the control connection open until the client closes it. Any
        failure closes the client connection.
        """
        try:
            peer = writer.get_extra_info("peername")
            host, port, udp_mode = await self._handshake(reader, writer, peer)
            if udp_mode:
                await self._hold_udp_association(reader, writer)
                return
            links = self._selected_links()
            session = TcpRaceSession(
                session_id=next(self.session_ids),
                stream_id=0,
                target_host=host,
                target_port=port,
                local_reader=reader,
                local_writer=writer,
                links=links,
                warmup_bytes=self.config.tcp_warmup_bytes,
            )
            await session.start()
        except Exception:
            writer.close()
            with contextlib.suppress(Exception):
                await writer.wait_closed()

    async def _hold_udp_association(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Keep a UDP ASSOCIATE control connection until the client ends it.

        The association lives as long as its TCP connection. When that ends,
        the socket is closed (it would otherwise stay open after the peer's
        EOF) and the UDP client binding is released so a later client can bind.
        """
        try:
            while await reader.read(4096):
                pass  # Nothing is expected on the control channel; discard it.
        except OSError:
            pass
        finally:
            if self.udp_server is not None:
                self.udp_server.client_addr = None
            writer.close()
            with contextlib.suppress(Exception):
                await writer.wait_closed()

    def _selected_links(self) -> list[RelayLink]:
        """Return the open links the scheduler chose, else the first open link."""
        chosen = {node.name for node in self.scheduler.choose()}
        links = [link for link in self.links if link.node.name in chosen and not link.closed]
        return links or [link for link in self.links if not link.closed][:1]

    async def forward_udp(self, host: str, port: int, payload: bytes, packet_id: int, udp_server: UdpAssociateServer) -> None:
        """Send one client UDP payload to the selected relays.

        If no relay is available, or every send fails, the packet is reported
        to ``udp_server`` as unsent.
        """
        self.udp_targets[packet_id] = (host, port)
        meta = encode_json({"host": host, "port": port})
        links = self._selected_links()
        link_names = ",".join(link.node.name for link in links) or "none"
        udp_server.set_flow_candidates(packet_id, tuple(link.node.name for link in links))
        print(f"[edge] udp forward packet_id={packet_id} target={host}:{port} size={len(payload)} links={link_names}")
        if not links:
            udp_server.note_unsent(packet_id)
            return
        sent = 0
        for index, link in enumerate(links):
            # NOTE(review): only the first link gets the target metadata and the
            # packet id; later links get the bare payload with packet id 0.
            # Kept as-is - confirm against the relay's UDP_SEND handling.
            body = meta + payload if index == 0 else payload
            try:
                await link.send(Frame(UDP_SEND, 1, index, 0, packet_id if index == 0 else 0, body))
            except OSError as exc:
                # One dead link must not stop the remaining sends.
                print(f"[edge] udp forward failed packet_id={packet_id} relay={link.node.name} error={exc!r}")
            else:
                sent += 1
        if sent == 0:
            udp_server.note_unsent(packet_id)

    async def _handshake(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer) -> tuple[str, int, bool]:
        """Run the SOCKS5 greeting and request exchange.

        Returns:
            ``(host, port, udp_mode)`` where ``udp_mode`` is True for an
            accepted UDP ASSOCIATE and False for CONNECT.

        Raises:
            ValueError: on an unsupported version, address type or command.
            asyncio.IncompleteReadError: if the client disconnects mid-handshake.
        """
        version, methods_len = await read_exact(reader, 2)
        if version != SOCKS_VERSION:
            raise ValueError("unsupported socks version")
        await read_exact(reader, methods_len)
        # Always selects "no authentication required" (method 0x00).
        writer.write(b"\x05\x00")
        await writer.drain()
        version, command, _, atyp = await read_exact(reader, 4)
        if version != SOCKS_VERSION:
            raise ValueError("unsupported socks version")
        if atyp == 1:
            host = socket.inet_ntoa(await read_exact(reader, 4))
        elif atyp == 3:
            size = (await read_exact(reader, 1))[0]
            host = (await read_exact(reader, size)).decode()
        else:
            # Reply 0x08: address type not supported.
            writer.write(b"\x05\x08\x00\x01\x00\x00\x00\x00\x00\x00")
            with contextlib.suppress(OSError):
                await writer.drain()
            raise ValueError("unsupported atyp")
        port = struct.unpack("!H", await read_exact(reader, 2))[0]
        peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer)
        if command == 1:
            print(f"[edge] socks handshake peer={peer_text} command=connect target={host}:{port}")
            writer.write(b"\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00")
            await writer.drain()
            return host, port, False
        if command == 3 and self.udp_server and self.udp_server.transport:
            bind_host, bind_port = self.udp_server.transport.get_extra_info("sockname")[:2]
            print(f"[edge] socks handshake peer={peer_text} command=udp_associate target={host}:{port} bind={bind_host}:{bind_port}")
            writer.write(b"\x05\x00\x00\x01" + socket.inet_aton(bind_host) + struct.pack("!H", bind_port))
            await writer.drain()
            return host, port, True
        # Reply 0x07: command not supported.
        writer.write(b"\x05\x07\x00\x01\x00\x00\x00\x00\x00\x00")
        with contextlib.suppress(OSError):
            await writer.drain()
        raise ValueError("unsupported socks command")