socks_edge.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. import itertools
  5. import socket
  6. import struct
  7. from dataclasses import dataclass, field
  8. from typing import Dict
  9. from .config import Config, RelayNode
  10. from .protocol import AUTH, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, decode_json, encode_json, read_frame, write_frame
  11. from .scheduler import Scheduler
  12. SOCKS_VERSION = 5
  13. async def read_exact(reader: asyncio.StreamReader, size: int) -> bytes:
  14. return await reader.readexactly(size)
  15. @dataclass(eq=False)
  16. class RelayLink:
  17. node: RelayNode
  18. reader: asyncio.StreamReader
  19. writer: asyncio.StreamWriter
  20. pump: asyncio.Task | None = None
  21. tcp_sessions: Dict[tuple[int, int], "TcpRaceSession"] = field(default_factory=dict)
  22. udp_server: "UdpAssociateServer | None" = None
  23. closed: bool = False
  24. async def start(self) -> None:
  25. await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
  26. frame = await read_frame(self.reader)
  27. if frame.kind != AUTH or frame.packet_id != STATUS_OK:
  28. raise ConnectionError(f"relay auth failed: {self.node.name}")
  29. self.pump = asyncio.create_task(self._pump())
  30. async def _pump(self) -> None:
  31. try:
  32. while True:
  33. frame = await read_frame(self.reader)
  34. key = (frame.session_id, frame.stream_id)
  35. if frame.kind in (TCP_STATUS, TCP_DATA, TCP_CLOSE):
  36. session = self.tcp_sessions.get(key)
  37. if session:
  38. await session.handle_frame(self, frame)
  39. elif frame.kind == UDP_RECV and self.udp_server:
  40. await self.udp_server.handle_from_relay(frame, self)
  41. except asyncio.IncompleteReadError:
  42. pass
  43. finally:
  44. await self.close()
  45. async def send(self, frame: Frame) -> None:
  46. if self.closed:
  47. raise ConnectionError(f"relay closed: {self.node.name}")
  48. await write_frame(self.writer, frame)
  49. async def close(self) -> None:
  50. if self.closed:
  51. return
  52. self.closed = True
  53. self.writer.close()
  54. with contextlib.suppress(Exception):
  55. await self.writer.wait_closed()
  56. @dataclass
  57. class TcpRaceSession:
  58. session_id: int
  59. stream_id: int
  60. target_host: str
  61. target_port: int
  62. local_reader: asyncio.StreamReader
  63. local_writer: asyncio.StreamWriter
  64. links: list[RelayLink]
  65. warmup_bytes: int
  66. winning_link: RelayLink | None = None
  67. winner_name: str | None = None
  68. opened: int = 0
  69. open_errors: list[str] = field(default_factory=list)
  70. uplink_bytes: int = 0
  71. closed: bool = False
  72. open_event: asyncio.Event = field(default_factory=asyncio.Event)
  73. winner_event: asyncio.Event = field(default_factory=asyncio.Event)
  74. pump_task: asyncio.Task | None = None
  75. win_counts: Dict[str, int] = field(default_factory=dict)
  76. async def start(self) -> None:
  77. meta = encode_json({"host": self.target_host, "port": self.target_port})
  78. for link in self.links:
  79. link.tcp_sessions[(self.session_id, self.stream_id)] = self
  80. await link.send(Frame(TCP_OPEN, self.session_id, self.stream_id, 0, 0, meta))
  81. await asyncio.wait_for(self.open_event.wait(), timeout=10)
  82. if self.opened == 0:
  83. raise ConnectionError(self.open_errors[0] if self.open_errors else "all relays failed")
  84. self.pump_task = asyncio.create_task(self._pump_local())
  85. async def _pump_local(self) -> None:
  86. try:
  87. while True:
  88. chunk = await self.local_reader.read(65536)
  89. if not chunk:
  90. break
  91. self.uplink_bytes += len(chunk)
  92. if self.winning_link is None and self.uplink_bytes <= self.warmup_bytes:
  93. await asyncio.gather(*(link.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk)) for link in self.links if not link.closed), return_exceptions=True)
  94. else:
  95. if self.winning_link is None:
  96. await self.winner_event.wait()
  97. if self.winning_link:
  98. await self.winning_link.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk))
  99. except Exception:
  100. pass
  101. finally:
  102. await self.close()
  103. async def handle_frame(self, link: RelayLink, frame: Frame) -> None:
  104. if self.closed:
  105. return
  106. if frame.kind == TCP_STATUS:
  107. if frame.packet_id == STATUS_OK:
  108. self.opened += 1
  109. else:
  110. self.open_errors.append(frame.payload.decode("utf-8", errors="replace"))
  111. if self.opened > 0 or len(self.open_errors) == len(self.links):
  112. self.open_event.set()
  113. return
  114. if frame.kind == TCP_DATA:
  115. if self.winning_link is None:
  116. self.winning_link = link
  117. self.winner_name = link.node.name
  118. self.win_counts[link.node.name] = self.win_counts.get(link.node.name, 0) + 1
  119. node_total = self.win_counts[link.node.name]
  120. relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.win_counts.items())) or "none"
  121. print(f"[edge] tcp win session={self.session_id} target={self.target_host}:{self.target_port} winner={link.node.name} node_total={node_total} win_breakdown={relay_detail}")
  122. self.winner_event.set()
  123. await self._close_losers(except_link=link)
  124. if link is self.winning_link:
  125. self.local_writer.write(frame.payload)
  126. await self.local_writer.drain()
  127. return
  128. if frame.kind == TCP_CLOSE:
  129. if self.winning_link is None:
  130. self.winning_link = link
  131. self.winner_event.set()
  132. if link is self.winning_link:
  133. await self.close()
  134. async def _close_losers(self, except_link: RelayLink) -> None:
  135. await asyncio.gather(*(link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b"")) for link in self.links if link is not except_link and not link.closed), return_exceptions=True)
  136. async def close(self) -> None:
  137. if self.closed:
  138. return
  139. self.closed = True
  140. if self.pump_task and self.pump_task is not asyncio.current_task():
  141. self.pump_task.cancel()
  142. with contextlib.suppress(Exception):
  143. await self.pump_task
  144. await asyncio.gather(*(link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b"")) for link in self.links if not link.closed), return_exceptions=True)
  145. for link in self.links:
  146. link.tcp_sessions.pop((self.session_id, self.stream_id), None)
  147. self.local_writer.close()
  148. with contextlib.suppress(Exception):
  149. await self.local_writer.wait_closed()
  150. class UdpAssociateServer(asyncio.DatagramProtocol):
  151. def __init__(self, edge: "SocksEdge") -> None:
  152. self.edge = edge
  153. self.transport: asyncio.DatagramTransport | None = None
  154. self.client_addr = None
  155. self.packet_counter = itertools.count(1)
  156. self.pending: set[int] = set()
  157. def connection_made(self, transport) -> None:
  158. self.transport = transport
  159. def datagram_received(self, data: bytes, addr) -> None:
  160. if len(data) < 10:
  161. return
  162. if self.client_addr is None:
  163. self.client_addr = addr
  164. if addr != self.client_addr:
  165. return
  166. host, port, payload = self._parse_socks_udp(data)
  167. packet_id = next(self.packet_counter)
  168. self.pending.add(packet_id)
  169. asyncio.create_task(self.edge.forward_udp(host, port, payload, packet_id, self))
  170. async def handle_from_relay(self, frame: Frame, _link: RelayLink) -> None:
  171. if frame.packet_id not in self.pending or self.transport is None or self.client_addr is None:
  172. return
  173. self.pending.discard(frame.packet_id)
  174. host = self.edge.udp_targets.get(frame.packet_id, ("0.0.0.0", 0))[0]
  175. port = self.edge.udp_targets.get(frame.packet_id, ("0.0.0.0", 0))[1]
  176. packet = self._build_socks_udp(host, port, frame.payload)
  177. self.transport.sendto(packet, self.client_addr)
  178. def _parse_socks_udp(self, packet: bytes) -> tuple[str, int, bytes]:
  179. atyp = packet[3]
  180. offset = 4
  181. if atyp == 1:
  182. host = socket.inet_ntoa(packet[offset:offset + 4])
  183. offset += 4
  184. elif atyp == 3:
  185. size = packet[offset]
  186. offset += 1
  187. host = packet[offset:offset + size].decode()
  188. offset += size
  189. else:
  190. raise ValueError("unsupported udp atyp")
  191. port = struct.unpack("!H", packet[offset:offset + 2])[0]
  192. offset += 2
  193. return host, port, packet[offset:]
  194. def _build_socks_udp(self, host: str, port: int, payload: bytes) -> bytes:
  195. try:
  196. addr = socket.inet_aton(host)
  197. header = b"\x00\x00\x00\x01" + addr + struct.pack("!H", port)
  198. except OSError:
  199. raw = host.encode()
  200. header = b"\x00\x00\x00\x03" + bytes([len(raw)]) + raw + struct.pack("!H", port)
  201. return header + payload
  202. class SocksEdge:
  203. def __init__(self, listen_host: str, listen_port: int, config: Config) -> None:
  204. self.listen_host = listen_host
  205. self.listen_port = listen_port
  206. self.config = config
  207. self.scheduler = Scheduler(config)
  208. self.links: list[RelayLink] = []
  209. self.session_ids = itertools.count(1)
  210. self.udp_targets: dict[int, tuple[str, int]] = {}
  211. self.udp_server: UdpAssociateServer | None = None
  212. async def start(self) -> None:
  213. await self.scheduler.start()
  214. await self._connect_relays()
  215. server = await asyncio.start_server(self._accept, self.listen_host, self.listen_port)
  216. sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
  217. print(f"[edge] socks5 listening on {sockets}")
  218. async with server:
  219. await server.serve_forever()
  220. async def _connect_relays(self) -> None:
  221. for node in self.config.relays:
  222. reader, writer = await asyncio.open_connection(node.host, node.port)
  223. link = RelayLink(node, reader, writer)
  224. await link.start()
  225. self.links.append(link)
  226. loop = asyncio.get_running_loop()
  227. transport, protocol = await loop.create_datagram_endpoint(lambda: UdpAssociateServer(self), local_addr=(self.listen_host, 0))
  228. self.udp_server = protocol
  229. for link in self.links:
  230. link.udp_server = protocol
  231. self.udp_transport = transport
  232. async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  233. try:
  234. host, port, udp_mode = await self._handshake(reader, writer)
  235. if udp_mode:
  236. return
  237. links = self._selected_links()
  238. session = TcpRaceSession(
  239. session_id=next(self.session_ids),
  240. stream_id=0,
  241. target_host=host,
  242. target_port=port,
  243. local_reader=reader,
  244. local_writer=writer,
  245. links=links,
  246. warmup_bytes=self.config.tcp_warmup_bytes,
  247. )
  248. await session.start()
  249. except Exception:
  250. writer.close()
  251. with contextlib.suppress(Exception):
  252. await writer.wait_closed()
  253. def _selected_links(self) -> list[RelayLink]:
  254. chosen = {node.name for node in self.scheduler.choose()}
  255. links = [link for link in self.links if link.node.name in chosen and not link.closed]
  256. return links or [link for link in self.links if not link.closed][:1]
  257. async def forward_udp(self, host: str, port: int, payload: bytes, packet_id: int, udp_server: UdpAssociateServer) -> None:
  258. self.udp_targets[packet_id] = (host, port)
  259. meta = encode_json({"host": host, "port": port})
  260. links = self._selected_links()
  261. for index, link in enumerate(links):
  262. body = meta + payload if index == 0 else payload
  263. await link.send(Frame(UDP_SEND, 1, index, 0, packet_id if index == 0 else 0, body))
  264. async def _handshake(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> tuple[str, int, bool]:
  265. version, methods_len = (await read_exact(reader, 2))
  266. if version != SOCKS_VERSION:
  267. raise ValueError("unsupported socks version")
  268. await read_exact(reader, methods_len)
  269. writer.write(b"\x05\x00")
  270. await writer.drain()
  271. version, command, _, atyp = await read_exact(reader, 4)
  272. if version != SOCKS_VERSION:
  273. raise ValueError("unsupported socks version")
  274. if atyp == 1:
  275. host = socket.inet_ntoa(await read_exact(reader, 4))
  276. elif atyp == 3:
  277. size = (await read_exact(reader, 1))[0]
  278. host = (await read_exact(reader, size)).decode()
  279. else:
  280. raise ValueError("unsupported atyp")
  281. port = struct.unpack("!H", await read_exact(reader, 2))[0]
  282. if command == 1:
  283. writer.write(b"\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00")
  284. await writer.drain()
  285. return host, port, False
  286. if command == 3 and self.udp_server and self.udp_server.transport:
  287. bind_host, bind_port = self.udp_server.transport.get_extra_info("sockname")[:2]
  288. writer.write(b"\x05\x00\x00\x01" + socket.inet_aton(bind_host) + struct.pack("!H", bind_port))
  289. await writer.drain()
  290. return host, port, True
  291. raise ValueError("unsupported socks command")