# socks_edge.py
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. import itertools
  5. import socket
  6. import struct
  7. from dataclasses import dataclass, field
  8. from typing import Dict
  9. from .config import Config, RelayNode
  10. from .protocol import AUTH, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, decode_json, encode_json, read_frame, write_frame
  11. from .scheduler import Scheduler
  12. SOCKS_VERSION = 5
  13. async def read_exact(reader: asyncio.StreamReader, size: int) -> bytes:
  14. return await reader.readexactly(size)
  15. @dataclass(eq=False)
  16. class RelayLink:
  17. node: RelayNode
  18. reader: asyncio.StreamReader
  19. writer: asyncio.StreamWriter
  20. pump: asyncio.Task | None = None
  21. tcp_sessions: Dict[tuple[int, int], "TcpRaceSession"] = field(default_factory=dict)
  22. udp_server: "UdpAssociateServer | None" = None
  23. closed: bool = False
  24. async def start(self) -> None:
  25. await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
  26. frame = await read_frame(self.reader)
  27. if frame.kind != AUTH or frame.packet_id != STATUS_OK:
  28. raise ConnectionError(f"relay auth failed: {self.node.name}")
  29. self.pump = asyncio.create_task(self._pump())
  30. async def _pump(self) -> None:
  31. try:
  32. while True:
  33. frame = await read_frame(self.reader)
  34. key = (frame.session_id, frame.stream_id)
  35. if frame.kind in (TCP_STATUS, TCP_DATA, TCP_CLOSE):
  36. session = self.tcp_sessions.get(key)
  37. if session:
  38. await session.handle_frame(self, frame)
  39. elif frame.kind == UDP_RECV and self.udp_server:
  40. await self.udp_server.handle_from_relay(frame, self)
  41. except asyncio.IncompleteReadError:
  42. pass
  43. finally:
  44. await self.close()
  45. async def send(self, frame: Frame) -> None:
  46. if self.closed:
  47. raise ConnectionError(f"relay closed: {self.node.name}")
  48. await write_frame(self.writer, frame)
  49. async def close(self) -> None:
  50. if self.closed:
  51. return
  52. self.closed = True
  53. self.writer.close()
  54. with contextlib.suppress(Exception):
  55. await self.writer.wait_closed()
  56. @dataclass
  57. class UdpFlowState:
  58. flow_id: int
  59. client_addr: tuple[str, int]
  60. target_host: str
  61. target_port: int
  62. created_at: float
  63. last_activity: float
  64. packets_sent: int = 0
  65. packets_received: int = 0
  66. duplicate_responses: int = 0
  67. winner_name: str | None = None
  68. candidate_names: tuple[str, ...] = ()
  69. link_streams: dict[str, int] = field(default_factory=dict)
  70. initialized_links: set[str] = field(default_factory=set)
  71. def touch(self, now: float) -> None:
  72. self.last_activity = now
  73. @dataclass
  74. class TcpRaceSession:
  75. session_id: int
  76. stream_id: int
  77. target_host: str
  78. target_port: int
  79. local_reader: asyncio.StreamReader
  80. local_writer: asyncio.StreamWriter
  81. links: list[RelayLink]
  82. warmup_bytes: int
  83. winning_link: RelayLink | None = None
  84. winner_name: str | None = None
  85. opened: int = 0
  86. open_errors: list[str] = field(default_factory=list)
  87. uplink_bytes: int = 0
  88. closed: bool = False
  89. open_event: asyncio.Event = field(default_factory=asyncio.Event)
  90. winner_event: asyncio.Event = field(default_factory=asyncio.Event)
  91. pump_task: asyncio.Task | None = None
  92. win_counts: Dict[str, int] = field(default_factory=dict)
  93. async def start(self) -> None:
  94. meta = encode_json({"host": self.target_host, "port": self.target_port})
  95. for link in self.links:
  96. link.tcp_sessions[(self.session_id, self.stream_id)] = self
  97. await link.send(Frame(TCP_OPEN, self.session_id, self.stream_id, 0, 0, meta))
  98. await asyncio.wait_for(self.open_event.wait(), timeout=10)
  99. if self.opened == 0:
  100. raise ConnectionError(self.open_errors[0] if self.open_errors else "all relays failed")
  101. self.pump_task = asyncio.create_task(self._pump_local())
  102. async def _pump_local(self) -> None:
  103. try:
  104. while True:
  105. chunk = await self.local_reader.read(65536)
  106. if not chunk:
  107. break
  108. self.uplink_bytes += len(chunk)
  109. if self.winning_link is None and self.uplink_bytes <= self.warmup_bytes:
  110. await asyncio.gather(*(link.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk)) for link in self.links if not link.closed), return_exceptions=True)
  111. else:
  112. if self.winning_link is None:
  113. await self.winner_event.wait()
  114. if self.winning_link:
  115. await self.winning_link.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk))
  116. except Exception:
  117. pass
  118. finally:
  119. await self.close()
  120. async def handle_frame(self, link: RelayLink, frame: Frame) -> None:
  121. if self.closed:
  122. return
  123. if frame.kind == TCP_STATUS:
  124. if frame.packet_id == STATUS_OK:
  125. self.opened += 1
  126. else:
  127. self.open_errors.append(frame.payload.decode("utf-8", errors="replace"))
  128. if self.opened > 0 or len(self.open_errors) == len(self.links):
  129. self.open_event.set()
  130. return
  131. if frame.kind == TCP_DATA:
  132. if self.winning_link is None:
  133. self.winning_link = link
  134. self.winner_name = link.node.name
  135. self.win_counts[link.node.name] = self.win_counts.get(link.node.name, 0) + 1
  136. node_total = self.win_counts[link.node.name]
  137. relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.win_counts.items())) or "none"
  138. print(f"[edge] tcp win session={self.session_id} target={self.target_host}:{self.target_port} winner={link.node.name} node_total={node_total} win_breakdown={relay_detail}")
  139. self.winner_event.set()
  140. await self._close_losers(except_link=link)
  141. if link is self.winning_link:
  142. self.local_writer.write(frame.payload)
  143. await self.local_writer.drain()
  144. return
  145. if frame.kind == TCP_CLOSE:
  146. if self.winning_link is None:
  147. self.winning_link = link
  148. self.winner_event.set()
  149. if link is self.winning_link:
  150. await self.close()
  151. async def _close_losers(self, except_link: RelayLink) -> None:
  152. await asyncio.gather(*(link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b"")) for link in self.links if link is not except_link and not link.closed), return_exceptions=True)
  153. async def close(self) -> None:
  154. if self.closed:
  155. return
  156. self.closed = True
  157. if self.pump_task and self.pump_task is not asyncio.current_task():
  158. self.pump_task.cancel()
  159. with contextlib.suppress(Exception):
  160. await self.pump_task
  161. await asyncio.gather(*(link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b"")) for link in self.links if not link.closed), return_exceptions=True)
  162. for link in self.links:
  163. link.tcp_sessions.pop((self.session_id, self.stream_id), None)
  164. self.local_writer.close()
  165. with contextlib.suppress(Exception):
  166. await self.local_writer.wait_closed()
  167. class UdpAssociateServer(asyncio.DatagramProtocol):
  168. def __init__(self, edge: "SocksEdge") -> None:
  169. self.edge = edge
  170. self.transport: asyncio.DatagramTransport | None = None
  171. self.client_addr = None
  172. self.packet_counter = itertools.count(1)
  173. self.client_flows: dict[tuple[tuple[str, int], str, int], UdpFlowState] = {}
  174. self.flow_counter = itertools.count(1)
  175. self.last_summary_at = 0.0
  176. self.win_counts: Dict[str, int] = {}
  177. def connection_made(self, transport) -> None:
  178. self.transport = transport
  179. def datagram_received(self, data: bytes, addr) -> None:
  180. if len(data) < 10:
  181. return
  182. if self.client_addr is None:
  183. self.client_addr = addr
  184. print(f"[edge] udp client bound addr={addr[0]}:{addr[1]}")
  185. if addr != self.client_addr:
  186. return
  187. host, port, payload = self._parse_socks_udp(data)
  188. loop = asyncio.get_running_loop()
  189. now = loop.time()
  190. flow_key = ((addr[0], addr[1]), host, port)
  191. flow = self.client_flows.get(flow_key)
  192. if flow is None:
  193. flow = UdpFlowState(
  194. flow_id=next(self.flow_counter),
  195. client_addr=(addr[0], addr[1]),
  196. target_host=host,
  197. target_port=port,
  198. created_at=now,
  199. last_activity=now,
  200. )
  201. self.client_flows[flow_key] = flow
  202. flow.touch(now)
  203. flow.packets_sent += 1
  204. packet_id = next(self.packet_counter)
  205. print(f"[edge] udp recv flow={flow.flow_id} packet_id={packet_id} target={host}:{port} size={len(payload)}")
  206. asyncio.create_task(self.edge.forward_udp(flow, payload, packet_id, self))
  207. self._log_udp_summary()
  208. async def handle_from_relay(self, frame: Frame, link: RelayLink) -> None:
  209. if self.transport is None or self.client_addr is None:
  210. return
  211. flow = self.edge.udp_flow_sessions.get((frame.session_id, frame.stream_id))
  212. if flow is None:
  213. return
  214. flow_id = flow.flow_id
  215. host = flow.target_host
  216. port = flow.target_port
  217. packet = self._build_socks_udp(host, port, frame.payload)
  218. winner_log = ""
  219. now = asyncio.get_running_loop().time()
  220. flow.touch(now)
  221. flow.packets_received += 1
  222. if flow.winner_name is None:
  223. flow.winner_name = link.node.name
  224. self.win_counts[link.node.name] = self.win_counts.get(link.node.name, 0) + 1
  225. relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.win_counts.items())) or "none"
  226. mode = "redundant" if self.edge.config.udp_redundancy > 0 else "single"
  227. print(
  228. f"[edge] udp flow={flow.flow_id} winner={link.node.name} "
  229. f"target={flow.target_host}:{flow.target_port} mode={mode} candidates={len(flow.candidate_names) or len(self.edge.links)}"
  230. )
  231. print(f"[edge] udp win relay_breakdown={relay_detail}")
  232. elif flow.winner_name != link.node.name:
  233. flow.duplicate_responses += 1
  234. winner_log = f" duplicate=1 winner={flow.winner_name} from={link.node.name}"
  235. print(
  236. f"[edge] udp send flow={flow_id or 'unknown'} packet_id={frame.packet_id} "
  237. f"target={host}:{port} size={len(frame.payload)} relay={link.node.name}{winner_log}"
  238. )
  239. self.transport.sendto(packet, self.client_addr)
  240. self._log_udp_summary()
  241. def set_flow_candidates(self, flow: UdpFlowState, candidate_names: tuple[str, ...]) -> None:
  242. if not flow.candidate_names:
  243. flow.candidate_names = candidate_names
  244. def note_unsent(self, flow: UdpFlowState, packet_id: int) -> None:
  245. flow.touch(asyncio.get_running_loop().time())
  246. print(f"[edge] udp drop flow={flow.flow_id} packet_id={packet_id} reason=no_available_links")
  247. self._log_udp_summary(force=True)
  248. def _log_udp_summary(self, force: bool = False) -> None:
  249. now = asyncio.get_running_loop().time()
  250. if not force and now - self.last_summary_at < 10:
  251. return
  252. self.last_summary_at = now
  253. active_flows = len(self.client_flows)
  254. winners = sum(1 for flow in self.client_flows.values() if flow.winner_name)
  255. packets_sent = sum(flow.packets_sent for flow in self.client_flows.values())
  256. packets_received = sum(flow.packets_received for flow in self.client_flows.values())
  257. duplicates = sum(flow.duplicate_responses for flow in self.client_flows.values())
  258. print(
  259. f"[edge] udp summary bind={self.client_addr[0]}:{self.client_addr[1]} active_flows={active_flows} "
  260. f"winner_flows={winners} packets_sent={packets_sent} packets_received={packets_received} dup={duplicates}"
  261. if self.client_addr
  262. else f"[edge] udp summary bind=unbound active_flows={active_flows} winner_flows={winners} packets_sent={packets_sent} packets_received={packets_received} dup={duplicates}"
  263. )
  264. def _parse_socks_udp(self, packet: bytes) -> tuple[str, int, bytes]:
  265. atyp = packet[3]
  266. offset = 4
  267. if atyp == 1:
  268. host = socket.inet_ntoa(packet[offset:offset + 4])
  269. offset += 4
  270. elif atyp == 3:
  271. size = packet[offset]
  272. offset += 1
  273. host = packet[offset:offset + size].decode()
  274. offset += size
  275. else:
  276. raise ValueError("unsupported udp atyp")
  277. port = struct.unpack("!H", packet[offset:offset + 2])[0]
  278. offset += 2
  279. return host, port, packet[offset:]
  280. def _build_socks_udp(self, host: str, port: int, payload: bytes) -> bytes:
  281. try:
  282. addr = socket.inet_aton(host)
  283. header = b"\x00\x00\x00\x01" + addr + struct.pack("!H", port)
  284. except OSError:
  285. raw = host.encode()
  286. header = b"\x00\x00\x00\x03" + bytes([len(raw)]) + raw + struct.pack("!H", port)
  287. return header + payload
  288. class SocksEdge:
  289. def __init__(self, listen_host: str, listen_port: int, config: Config) -> None:
  290. self.listen_host = listen_host
  291. self.listen_port = listen_port
  292. self.config = config
  293. self.scheduler = Scheduler(config)
  294. self.links: list[RelayLink] = []
  295. self.session_ids = itertools.count(1)
  296. self.udp_stream_ids = itertools.count(1)
  297. self.udp_flow_sessions: dict[tuple[int, int], UdpFlowState] = {}
  298. self.udp_server: UdpAssociateServer | None = None
  299. async def start(self) -> None:
  300. await self.scheduler.start()
  301. await self._connect_relays()
  302. server = await asyncio.start_server(self._accept, self.listen_host, self.listen_port)
  303. sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
  304. print(f"[edge] socks5 listening on {sockets}")
  305. async with server:
  306. await server.serve_forever()
  307. async def _connect_relays(self) -> None:
  308. for node in self.config.relays:
  309. reader, writer = await asyncio.open_connection(node.host, node.port)
  310. link = RelayLink(node, reader, writer)
  311. await link.start()
  312. self.links.append(link)
  313. loop = asyncio.get_running_loop()
  314. transport, protocol = await loop.create_datagram_endpoint(lambda: UdpAssociateServer(self), local_addr=(self.listen_host, 0))
  315. self.udp_server = protocol
  316. for link in self.links:
  317. link.udp_server = protocol
  318. self.udp_transport = transport
  319. async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  320. try:
  321. peer = writer.get_extra_info("peername")
  322. host, port, udp_mode = await self._handshake(reader, writer, peer)
  323. if udp_mode:
  324. return
  325. links = self._selected_links()
  326. session = TcpRaceSession(
  327. session_id=next(self.session_ids),
  328. stream_id=0,
  329. target_host=host,
  330. target_port=port,
  331. local_reader=reader,
  332. local_writer=writer,
  333. links=links,
  334. warmup_bytes=self.config.tcp_warmup_bytes,
  335. )
  336. await session.start()
  337. except Exception:
  338. writer.close()
  339. with contextlib.suppress(Exception):
  340. await writer.wait_closed()
  341. def _selected_links(self) -> list[RelayLink]:
  342. chosen = {node.name for node in self.scheduler.choose()}
  343. links = [link for link in self.links if link.node.name in chosen and not link.closed]
  344. return links or [link for link in self.links if not link.closed][:1]
  345. def _selected_udp_links(self) -> list[RelayLink]:
  346. online = [link for link in self.links if not link.closed]
  347. if not online:
  348. return []
  349. ordered = sorted(online, key=lambda link: self.scheduler.scores.get(link.node.name).score if link.node.name in self.scheduler.scores else 999999.0)
  350. return ordered
  351. async def forward_udp(self, flow: UdpFlowState, payload: bytes, packet_id: int, udp_server: UdpAssociateServer) -> None:
  352. meta = encode_json({"host": flow.target_host, "port": flow.target_port})
  353. links = self._selected_udp_links()
  354. link_names = ",".join(link.node.name for link in links) or "none"
  355. udp_server.set_flow_candidates(flow, tuple(link.node.name for link in links))
  356. print(f"[edge] udp forward packet_id={packet_id} target={flow.target_host}:{flow.target_port} size={len(payload)} links={link_names}")
  357. if not links:
  358. udp_server.note_unsent(flow, packet_id)
  359. return
  360. active_links = links if self.config.udp_always_broadcast or flow.winner_name is None else [link for link in links if link.node.name == flow.winner_name]
  361. active_links = active_links or links[:1]
  362. copies = max(1, self.config.udp_redundancy + 1)
  363. for attempt in range(copies):
  364. for link in active_links:
  365. stream_id = flow.link_streams.get(link.node.name)
  366. if stream_id is None:
  367. stream_id = next(self.udp_stream_ids)
  368. flow.link_streams[link.node.name] = stream_id
  369. self.udp_flow_sessions[(flow.flow_id, stream_id)] = flow
  370. include_meta = link.node.name not in flow.initialized_links
  371. body = (meta + payload) if include_meta else payload
  372. meta_len = len(meta) if include_meta else 0
  373. await link.send(Frame(UDP_SEND, flow.flow_id, stream_id, 0, meta_len, body))
  374. flow.initialized_links.add(link.node.name)
  375. if attempt + 1 < copies and self.config.udp_copy_interval_ms > 0:
  376. await asyncio.sleep(self.config.udp_copy_interval_ms / 1000)
  377. async def _handshake(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer) -> tuple[str, int, bool]:
  378. version, methods_len = (await read_exact(reader, 2))
  379. if version != SOCKS_VERSION:
  380. raise ValueError("unsupported socks version")
  381. await read_exact(reader, methods_len)
  382. writer.write(b"\x05\x00")
  383. await writer.drain()
  384. version, command, _, atyp = await read_exact(reader, 4)
  385. if version != SOCKS_VERSION:
  386. raise ValueError("unsupported socks version")
  387. if atyp == 1:
  388. host = socket.inet_ntoa(await read_exact(reader, 4))
  389. elif atyp == 3:
  390. size = (await read_exact(reader, 1))[0]
  391. host = (await read_exact(reader, size)).decode()
  392. else:
  393. raise ValueError("unsupported atyp")
  394. port = struct.unpack("!H", await read_exact(reader, 2))[0]
  395. peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer)
  396. if command == 1:
  397. print(f"[edge] socks handshake peer={peer_text} command=connect target={host}:{port}")
  398. writer.write(b"\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00")
  399. await writer.drain()
  400. return host, port, False
  401. if command == 3 and self.udp_server and self.udp_server.transport:
  402. bind_host, bind_port = self.udp_server.transport.get_extra_info("sockname")[:2]
  403. print(f"[edge] socks handshake peer={peer_text} command=udp_associate target={host}:{port} bind={bind_host}:{bind_port}")
  404. writer.write(b"\x05\x00\x00\x01" + socket.inet_aton(bind_host) + struct.pack("!H", bind_port))
  405. await writer.drain()
  406. return host, port, True
  407. raise ValueError("unsupported socks command")