| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317 |
- from __future__ import annotations
- import asyncio
- import contextlib
- import itertools
- import socket
- import struct
- from dataclasses import dataclass, field
- from typing import Dict
- from .config import Config, RelayNode
- from .protocol import AUTH, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, decode_json, encode_json, read_frame, write_frame
- from .scheduler import Scheduler
# Protocol version byte the SOCKS handshake compares against (SOCKS5).
SOCKS_VERSION = 5
async def read_exact(reader: asyncio.StreamReader, size: int) -> bytes:
    """Read exactly ``size`` bytes from ``reader``.

    Thin wrapper over ``reader.readexactly``; whatever that call raises
    (for example on a short read) propagates unchanged to the caller.
    """
    data = await reader.readexactly(size)
    return data
@dataclass(eq=False)
class RelayLink:
    """One authenticated connection to a relay node, shared by many sessions.

    After ``start`` succeeds, a background task reads frames from the relay
    and routes them to the registered TCP sessions or to the UDP server.
    """

    node: RelayNode
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    # Background task running ``_pump``; created by ``start``.
    pump: asyncio.Task | None = None
    # Sessions multiplexed over this link, keyed by (session_id, stream_id).
    tcp_sessions: Dict[tuple[int, int], "TcpRaceSession"] = field(default_factory=dict)
    # Receiver for UDP_RECV frames; None until one is attached.
    udp_server: "UdpAssociateServer | None" = None
    closed: bool = False

    async def start(self) -> None:
        """Authenticate with the relay and start the frame pump.

        Sends an AUTH frame carrying the node token and expects an AUTH
        frame whose ``packet_id`` equals ``STATUS_OK`` in response.

        Raises:
            ConnectionError: if the relay answers with anything else.
        """
        try:
            auth_body = encode_json({"token": self.node.token})
            await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, auth_body))
            frame = await read_frame(self.reader)
            if frame.kind != AUTH or frame.packet_id != STATUS_OK:
                raise ConnectionError(f"relay auth failed: {self.node.name}")
        except BaseException:
            # Do not leak the socket when the handshake fails or is cancelled.
            # Closed synchronously so nothing is awaited during cancellation.
            self.closed = True
            self.writer.close()
            raise
        self.pump = asyncio.create_task(self._pump())

    async def _pump(self) -> None:
        """Read frames until the relay connection ends, then close the link."""
        try:
            while True:
                frame = await read_frame(self.reader)
                if frame.kind in (TCP_STATUS, TCP_DATA, TCP_CLOSE):
                    key = (frame.session_id, frame.stream_id)
                    session = self.tcp_sessions.get(key)
                    if session is not None:
                        await self._dispatch_tcp(session, frame)
                elif frame.kind == UDP_RECV and self.udp_server:
                    await self.udp_server.handle_from_relay(frame, self)
        except (asyncio.IncompleteReadError, OSError):
            # EOF or a transport error (reset, timeout, ...) ends the link;
            # the actual teardown happens in ``finally``.
            pass
        finally:
            await self.close()

    async def _dispatch_tcp(self, session: "TcpRaceSession", frame: Frame) -> None:
        """Deliver one frame to ``session``, isolating per-session I/O errors.

        An OSError here comes from the session's own local client (for
        example a reset while writing to it); it must end that session only,
        not the relay link that every other session shares.
        """
        try:
            await session.handle_frame(self, frame)
        except OSError:
            with contextlib.suppress(Exception):
                await session.close()

    async def send(self, frame: Frame) -> None:
        """Write ``frame`` to the relay.

        Raises:
            ConnectionError: if the link has already been closed.
        """
        if self.closed:
            raise ConnectionError(f"relay closed: {self.node.name}")
        await write_frame(self.writer, frame)

    async def close(self) -> None:
        """Close the relay connection; repeated calls are no-ops."""
        if self.closed:
            return
        self.closed = True
        self.writer.close()
        with contextlib.suppress(Exception):
            await self.writer.wait_closed()
@dataclass(eq=False)
class TcpRaceSession:
    """A client TCP connection raced across several relay links.

    The target is opened on every link. Early uplink data (up to
    ``warmup_bytes``) is sent to all links; the first link to deliver
    downlink data becomes the winner and carries the rest of the stream.
    """

    session_id: int
    stream_id: int
    target_host: str
    target_port: int
    local_reader: asyncio.StreamReader
    local_writer: asyncio.StreamWriter
    links: list[RelayLink]
    warmup_bytes: int
    winning_link: RelayLink | None = None
    # Number of links that reported a successful open.
    opened: int = 0
    open_errors: list[str] = field(default_factory=list)
    # Total bytes read from the local client so far.
    uplink_bytes: int = 0
    closed: bool = False
    # Set once at least one link opened, or every link reported failure.
    open_event: asyncio.Event = field(default_factory=asyncio.Event)
    # Set once ``winning_link`` has been chosen.
    winner_event: asyncio.Event = field(default_factory=asyncio.Event)
    pump_task: asyncio.Task | None = None

    async def start(self) -> None:
        """Open the target on every link and start pumping local data.

        Raises:
            ConnectionError: if there are no links, every link failed to
                open the target, or the session was closed while opening.
            asyncio.TimeoutError: if no link answers within 10 seconds.

        On any failure the session is closed first, so it does not stay
        registered on the links and relays that did open are told to close.
        """
        if not self.links:
            # Nothing can ever answer; fail now instead of waiting for the
            # open timeout.
            await self.close()
            raise ConnectionError("no relay links available")
        key = (self.session_id, self.stream_id)
        try:
            meta = encode_json({"host": self.target_host, "port": self.target_port})
            for link in self.links:
                link.tcp_sessions[key] = self
                try:
                    await link.send(
                        Frame(TCP_OPEN, self.session_id, self.stream_id, 0, 0, meta)
                    )
                except OSError as exc:
                    # A single dead link must not abort the whole race; count
                    # it as a failed open so the wait below can still finish.
                    self.open_errors.append(str(exc) or type(exc).__name__)
            if len(self.open_errors) >= len(self.links):
                self.open_event.set()
            await asyncio.wait_for(self.open_event.wait(), timeout=10)
            if self.opened == 0:
                raise ConnectionError(
                    self.open_errors[0] if self.open_errors else "all relays failed"
                )
            if self.closed:
                raise ConnectionError("session closed before open completed")
        except BaseException:
            # Re-raised unchanged after cleanup (also covers cancellation).
            await self.close()
            raise
        self.pump_task = asyncio.create_task(self._pump_local())

    async def _pump_local(self) -> None:
        """Forward data from the local client to the relays until EOF."""
        try:
            while True:
                chunk = await self.local_reader.read(65536)
                if not chunk:
                    break
                self.uplink_bytes += len(chunk)
                racing = self.winning_link is None
                if racing and self.uplink_bytes <= self.warmup_bytes:
                    # Still racing and within the warm-up budget: duplicate
                    # the chunk to every open link.
                    await asyncio.gather(
                        *(
                            link.send(
                                Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk)
                            )
                            for link in self.links
                            if not link.closed
                        ),
                        return_exceptions=True,
                    )
                    continue
                if self.winning_link is None:
                    # Past the warm-up budget: hold data until a winner exists.
                    await self.winner_event.wait()
                if self.winning_link:
                    await self.winning_link.send(
                        Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk)
                    )
        except Exception:
            # Best effort: any failure on either side just ends the session;
            # teardown happens in ``finally``.
            pass
        finally:
            await self.close()

    async def handle_frame(self, link: RelayLink, frame: Frame) -> None:
        """Process a TCP_STATUS, TCP_DATA or TCP_CLOSE frame from ``link``."""
        if self.closed:
            return
        if frame.kind == TCP_STATUS:
            if frame.packet_id == STATUS_OK:
                self.opened += 1
            else:
                self.open_errors.append(frame.payload.decode("utf-8", errors="replace"))
            if self.opened > 0 or len(self.open_errors) == len(self.links):
                self.open_event.set()
            return
        if frame.kind == TCP_DATA:
            if self.winning_link is None:
                # First link to deliver data wins; the others are told to stop.
                self.winning_link = link
                self.winner_event.set()
                await self._close_losers(except_link=link)
            if link is self.winning_link:
                try:
                    self.local_writer.write(frame.payload)
                    await self.local_writer.drain()
                except OSError:
                    # The local client went away: end this session here rather
                    # than raising into the relay link's read loop.
                    await self.close()
            return
        if frame.kind == TCP_CLOSE:
            if self.winning_link is None:
                self.winning_link = link
                self.winner_event.set()
            if link is self.winning_link:
                await self.close()

    async def _close_losers(self, except_link: RelayLink) -> None:
        """Send TCP_CLOSE to every open link other than ``except_link``."""
        await asyncio.gather(
            *(
                link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b""))
                for link in self.links
                if link is not except_link and not link.closed
            ),
            return_exceptions=True,
        )

    async def close(self) -> None:
        """Tear the session down; repeated calls are no-ops.

        Cancels the local pump, sends TCP_CLOSE on every open link, removes
        the session from each link's routing table and closes the local
        connection.
        """
        if self.closed:
            return
        self.closed = True
        if self.pump_task and self.pump_task is not asyncio.current_task():
            self.pump_task.cancel()
            # Awaiting a cancelled task raises CancelledError, which is a
            # BaseException on Python 3.8+; it must be suppressed explicitly
            # or the remaining teardown below would be skipped.
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await self.pump_task
        await asyncio.gather(
            *(
                link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b""))
                for link in self.links
                if not link.closed
            ),
            return_exceptions=True,
        )
        for link in self.links:
            link.tcp_sessions.pop((self.session_id, self.stream_id), None)
        self.local_writer.close()
        with contextlib.suppress(Exception):
            await self.local_writer.wait_closed()
class UdpAssociateServer(asyncio.DatagramProtocol):
    """Datagram endpoint for SOCKS5 UDP ASSOCIATE traffic.

    Datagrams from the client are unwrapped and handed to the edge for
    forwarding; replies arriving from a relay are re-wrapped in a SOCKS UDP
    header and sent back to the client. The first address that sends a
    datagram becomes the only accepted client address.
    """

    def __init__(self, edge: "SocksEdge") -> None:
        self.edge = edge
        self.transport: asyncio.DatagramTransport | None = None
        self.client_addr = None
        self.packet_counter = itertools.count(1)
        # Packet ids forwarded to relays that have not been answered yet.
        # NOTE(review): ids of datagrams that never get a reply are never
        # evicted; consider a TTL or size bound.
        self.pending: set[int] = set()
        # Strong references to in-flight forward tasks; the event loop only
        # keeps weak references to tasks.
        self._tasks: set[asyncio.Task] = set()

    def connection_made(self, transport) -> None:
        self.transport = transport

    def datagram_received(self, data: bytes, addr) -> None:
        """Unwrap one client datagram and schedule it for forwarding.

        Datagrams that are too short, come from another address, are
        fragments, or cannot be parsed are dropped silently.
        """
        if len(data) < 10:
            return
        if self.client_addr is None:
            self.client_addr = addr
        if addr != self.client_addr:
            return
        if data[2] != 0:
            # FRAG != 0: fragmentation is not implemented, and RFC 1928
            # requires such datagrams to be dropped in that case.
            return
        try:
            host, port, payload = self._parse_socks_udp(data)
        except ValueError:
            return
        packet_id = next(self.packet_counter)
        self.pending.add(packet_id)
        task = asyncio.create_task(
            self.edge.forward_udp(host, port, payload, packet_id, self)
        )
        self._tasks.add(task)
        task.add_done_callback(
            lambda done, pid=packet_id: self._forward_done(pid, done)
        )

    def _forward_done(self, packet_id: int, task: asyncio.Task) -> None:
        """Release a finished forward task and clean up if it failed."""
        self._tasks.discard(task)
        # Reading the exception also marks it as retrieved.
        failed = task.cancelled() or task.exception() is not None
        if failed:
            self.pending.discard(packet_id)
            self.edge.udp_targets.pop(packet_id, None)

    async def handle_from_relay(self, frame: Frame, _link: RelayLink) -> None:
        """Send the first relay reply for a pending packet to the client.

        Frames whose ``packet_id`` is not pending are ignored, so later
        replies for the same id are dropped.
        """
        if (
            frame.packet_id not in self.pending
            or self.transport is None
            or self.client_addr is None
        ):
            return
        self.pending.discard(frame.packet_id)
        # The id can never match again, so the target entry is removed too.
        host, port = self.edge.udp_targets.pop(frame.packet_id, ("0.0.0.0", 0))
        packet = self._build_socks_udp(host, port, frame.payload)
        self.transport.sendto(packet, self.client_addr)

    def _parse_socks_udp(self, packet: bytes) -> tuple[str, int, bytes]:
        """Split a SOCKS5 UDP datagram into ``(host, port, payload)``.

        Supports IPv4 (ATYP 1) and domain-name (ATYP 3) addresses.

        Raises:
            ValueError: if the header is truncated, the domain is not valid
                text, or the address type is unsupported.
        """
        if len(packet) < 5:
            raise ValueError("truncated socks udp header")
        atyp = packet[3]
        start = 4
        if atyp == 1:
            end = start + 4
            if len(packet) < end + 2:
                raise ValueError("truncated socks udp header")
            host = socket.inet_ntoa(packet[start:end])
        elif atyp == 3:
            size = packet[start]
            start += 1
            end = start + size
            if len(packet) < end + 2:
                raise ValueError("truncated socks udp header")
            # UnicodeDecodeError is a ValueError subclass.
            host = packet[start:end].decode()
        else:
            raise ValueError("unsupported udp atyp")
        port = struct.unpack("!H", packet[end:end + 2])[0]
        return host, port, packet[end + 2:]

    def _build_socks_udp(self, host: str, port: int, payload: bytes) -> bytes:
        """Prefix ``payload`` with a SOCKS5 UDP header for ``host:port``.

        ``host`` is encoded as IPv4 when ``socket.inet_aton`` accepts it and
        as a length-prefixed domain name otherwise.
        """
        try:
            addr = socket.inet_aton(host)
            header = b"\x00\x00\x00\x01" + addr + struct.pack("!H", port)
        except OSError:
            raw = host.encode()
            header = b"\x00\x00\x00\x03" + bytes([len(raw)]) + raw + struct.pack("!H", port)
        return header + payload
class SocksEdge:
    """SOCKS5 listener that forwards client traffic through relay links."""

    def __init__(self, listen_host: str, listen_port: int, config: Config) -> None:
        self.listen_host = listen_host
        self.listen_port = listen_port
        self.config = config
        self.scheduler = Scheduler(config)
        self.links: list[RelayLink] = []
        self.session_ids = itertools.count(1)
        # Target (host, port) of each forwarded UDP packet, keyed by packet id.
        self.udp_targets: dict[int, tuple[str, int]] = {}
        self.udp_server: UdpAssociateServer | None = None
        # Set by ``_connect_relays``; declared here so it always exists.
        self.udp_transport: asyncio.DatagramTransport | None = None

    async def start(self) -> None:
        """Start the scheduler, connect the relays and serve SOCKS5 forever."""
        await self.scheduler.start()
        await self._connect_relays()
        server = await asyncio.start_server(self._accept, self.listen_host, self.listen_port)
        sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
        print(f"[edge] socks5 listening on {sockets}")
        async with server:
            await server.serve_forever()

    async def _connect_relays(self) -> None:
        """Connect and authenticate every configured relay, then open the UDP endpoint."""
        for node in self.config.relays:
            reader, writer = await asyncio.open_connection(node.host, node.port)
            link = RelayLink(node, reader, writer)
            try:
                await link.start()
            except BaseException:
                # Do not leak the socket of a relay that failed to authenticate.
                writer.close()
                raise
            self.links.append(link)
        loop = asyncio.get_running_loop()
        transport, protocol = await loop.create_datagram_endpoint(
            lambda: UdpAssociateServer(self), local_addr=(self.listen_host, 0)
        )
        self.udp_server = protocol
        for link in self.links:
            link.udp_server = protocol
        self.udp_transport = transport

    async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Handle one SOCKS5 client connection."""
        session: TcpRaceSession | None = None
        try:
            host, port, udp_mode = await self._handshake(reader, writer)
            if udp_mode:
                # RFC 1928: a UDP association lasts as long as the TCP
                # connection that requested it, so keep the connection open
                # until the client closes it instead of dropping the streams.
                while await reader.read(4096):
                    pass
                writer.close()
                with contextlib.suppress(Exception):
                    await writer.wait_closed()
                return
            links = self._selected_links()
            session = TcpRaceSession(
                session_id=next(self.session_ids),
                stream_id=0,
                target_host=host,
                target_port=port,
                local_reader=reader,
                local_writer=writer,
                links=links,
                warmup_bytes=self.config.tcp_warmup_bytes,
            )
            await session.start()
        except Exception:
            if session is not None:
                # Unregister the half-started session from the relay links.
                with contextlib.suppress(Exception):
                    await session.close()
            writer.close()
            with contextlib.suppress(Exception):
                await writer.wait_closed()

    def _selected_links(self) -> list[RelayLink]:
        """Return the open links for the scheduler's chosen nodes.

        Falls back to the first open link (if any) when none of the chosen
        nodes has an open link.
        """
        chosen = {node.name for node in self.scheduler.choose()}
        links = [link for link in self.links if link.node.name in chosen and not link.closed]
        return links or [link for link in self.links if not link.closed][:1]

    async def forward_udp(self, host: str, port: int, payload: bytes, packet_id: int, udp_server: UdpAssociateServer) -> None:
        """Send one client UDP payload to the selected relay links.

        When no link is available or every send fails, the bookkeeping for
        ``packet_id`` is removed again because no reply can arrive.
        """
        links = self._selected_links()
        sent = False
        if links:
            self.udp_targets[packet_id] = (host, port)
            meta = encode_json({"host": host, "port": port})
            for index, link in enumerate(links):
                # NOTE(review): only the first link's frame carries the target
                # metadata and the packet id; the others get the bare payload
                # with packet id 0. Kept exactly as before - confirm against
                # the relay side of the protocol.
                body = meta + payload if index == 0 else payload
                try:
                    await link.send(Frame(UDP_SEND, 1, index, 0, packet_id if index == 0 else 0, body))
                except OSError:
                    # One dead link must not stop the remaining sends.
                    continue
                sent = True
        if not sent:
            self.udp_targets.pop(packet_id, None)
            udp_server.pending.discard(packet_id)

    async def _handshake(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> tuple[str, int, bool]:
        """Run the SOCKS5 greeting and request exchange.

        Returns:
            ``(host, port, udp_mode)``: the requested target, and whether the
            client asked for UDP ASSOCIATE (True) or CONNECT (False).

        Raises:
            ValueError: on an unsupported version, auth method, address type
                or command. Where the protocol defines an error reply, it is
                sent before raising.
        """
        version, methods_len = await read_exact(reader, 2)
        if version != SOCKS_VERSION:
            raise ValueError("unsupported socks version")
        methods = await read_exact(reader, methods_len)
        if 0 not in methods:
            # Only "no authentication" (0x00) is implemented; 0xFF tells the
            # client that none of its offered methods is acceptable.
            writer.write(b"\x05\xff")
            with contextlib.suppress(OSError):
                await writer.drain()
            raise ValueError("no acceptable socks auth method")
        writer.write(b"\x05\x00")
        await writer.drain()
        version, command, _, atyp = await read_exact(reader, 4)
        if version != SOCKS_VERSION:
            raise ValueError("unsupported socks version")
        if atyp == 1:
            host = socket.inet_ntoa(await read_exact(reader, 4))
        elif atyp == 3:
            size = (await read_exact(reader, 1))[0]
            host = (await read_exact(reader, size)).decode()
        else:
            await self._send_failure(writer, 0x08)  # address type not supported
            raise ValueError("unsupported atyp")
        port = struct.unpack("!H", await read_exact(reader, 2))[0]
        if command == 1:
            # NOTE(review): success is reported before any relay has opened
            # the target; the client cannot observe a later connect failure.
            writer.write(b"\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00")
            await writer.drain()
            return host, port, False
        if command == 3 and self.udp_server and self.udp_server.transport:
            bind_host, bind_port = self.udp_server.transport.get_extra_info("sockname")[:2]
            writer.write(b"\x05\x00\x00" + self._encode_bind_address(bind_host) + struct.pack("!H", bind_port))
            await writer.drain()
            return host, port, True
        await self._send_failure(writer, 0x07)  # command not supported
        raise ValueError("unsupported socks command")

    @staticmethod
    def _encode_bind_address(host: str) -> bytes:
        """Encode ``host`` as a SOCKS5 ATYP byte followed by the packed address."""
        try:
            return b"\x01" + socket.inet_aton(host)
        except OSError:
            # Not IPv4 (for example when listening on an IPv6 address).
            return b"\x04" + socket.inet_pton(socket.AF_INET6, host)

    @staticmethod
    async def _send_failure(writer: asyncio.StreamWriter, code: int) -> None:
        """Best-effort SOCKS5 failure reply with an all-zero bind address."""
        writer.write(bytes((5, code, 0, 1)) + bytes(6))
        with contextlib.suppress(OSError):
            await writer.drain()
|