socks_edge.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. import itertools
  5. import socket
  6. import struct
  7. from dataclasses import dataclass, field
  8. from typing import Dict
  9. from .config import Config, RelayNode
  10. from .protocol import AUTH, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, decode_json, encode_json, read_frame, write_frame
  11. from .scheduler import Scheduler
  12. SOCKS_VERSION = 5
  13. async def read_exact(reader: asyncio.StreamReader, size: int) -> bytes:
  14. return await reader.readexactly(size)
  15. @dataclass(eq=False)
  16. class RelayLink:
  17. node: RelayNode
  18. reader: asyncio.StreamReader
  19. writer: asyncio.StreamWriter
  20. pump: asyncio.Task | None = None
  21. closed_event: asyncio.Event = field(default_factory=asyncio.Event)
  22. maintain_task: asyncio.Task | None = None
  23. send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
  24. tcp_sessions: Dict[tuple[int, int], "TcpRaceSession"] = field(default_factory=dict)
  25. udp_server: "UdpAssociateServer | None" = None
  26. closed: bool = False
  27. async def start(self) -> None:
  28. await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
  29. frame = await read_frame(self.reader)
  30. if frame.kind != AUTH or frame.packet_id != STATUS_OK:
  31. raise ConnectionError(f"relay auth failed: {self.node.name}")
  32. self.closed = False
  33. self.closed_event.clear()
  34. self.pump = asyncio.create_task(self._pump())
  35. async def _pump(self) -> None:
  36. try:
  37. while True:
  38. frame = await read_frame(self.reader)
  39. key = (frame.session_id, frame.stream_id)
  40. if frame.kind in (TCP_STATUS, TCP_DATA, TCP_CLOSE):
  41. session = self.tcp_sessions.get(key)
  42. if session:
  43. await session.handle_frame(self, frame)
  44. elif self.udp_server:
  45. if frame.kind == TCP_STATUS:
  46. await self.udp_server.handle_relay_status(frame, self)
  47. elif frame.kind == TCP_CLOSE:
  48. await self.udp_server.handle_relay_close(frame, self)
  49. elif frame.kind == UDP_RECV and self.udp_server:
  50. await self.udp_server.handle_from_relay(frame, self)
  51. except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, OSError):
  52. pass
  53. except Exception:
  54. pass
  55. finally:
  56. await self.close()
  57. async def send(self, frame: Frame) -> None:
  58. if self.closed:
  59. raise ConnectionError(f"relay closed: {self.node.name}")
  60. try:
  61. async with self.send_lock:
  62. if self.closed:
  63. raise ConnectionError(f"relay closed: {self.node.name}")
  64. await write_frame(self.writer, frame)
  65. except (BrokenPipeError, ConnectionResetError, RuntimeError, OSError, asyncio.CancelledError) as exc:
  66. await self.close()
  67. raise ConnectionError(f"relay closed: {self.node.name}") from exc
  68. async def close(self) -> None:
  69. if self.closed:
  70. return
  71. self.closed = True
  72. self.closed_event.set()
  73. if self.udp_server:
  74. self.udp_server.handle_relay_closed(self.node.name)
  75. if self.pump and self.pump is not asyncio.current_task():
  76. self.pump.cancel()
  77. with contextlib.suppress(Exception):
  78. await self.pump
  79. self.writer.close()
  80. with contextlib.suppress(Exception):
  81. await self.writer.wait_closed()
  82. @dataclass
  83. class UdpFlowState:
  84. flow_id: int
  85. client_addr: tuple[str, int]
  86. target_host: str
  87. target_port: int
  88. created_at: float
  89. last_activity: float
  90. packets_sent: int = 0
  91. packets_received: int = 0
  92. duplicate_responses: int = 0
  93. winner_name: str | None = None
  94. candidate_names: tuple[str, ...] = ()
  95. link_streams: dict[str, int] = field(default_factory=dict)
  96. initialized_links: set[str] = field(default_factory=set)
  97. direct_sockets: dict[str, socket.socket] = field(default_factory=dict)
  98. direct_tasks: dict[str, asyncio.Task] = field(default_factory=dict)
  99. direct_failures: set[str] = field(default_factory=set)
  100. relay_failures: dict[str, int] = field(default_factory=dict)
  101. relay_error_seen: set[str] = field(default_factory=set)
  102. target_family: int = 0
  103. def touch(self, now: float) -> None:
  104. self.last_activity = now
  105. @dataclass
  106. class TcpRaceSession:
  107. session_id: int
  108. stream_id: int
  109. target_host: str
  110. target_port: int
  111. local_reader: asyncio.StreamReader
  112. local_writer: asyncio.StreamWriter
  113. links: list[RelayLink]
  114. warmup_initial_bytes: int
  115. warmup_max_bytes: int
  116. stats: Dict[str, int]
  117. winning_link: RelayLink | None = None
  118. winner_name: str | None = None
  119. opened: int = 0
  120. open_errors: list[str] = field(default_factory=list)
  121. uplink_bytes: int = 0
  122. closed: bool = False
  123. open_event: asyncio.Event = field(default_factory=asyncio.Event)
  124. winner_event: asyncio.Event = field(default_factory=asyncio.Event)
  125. pump_task: asyncio.Task | None = None
  126. warmup_bytes: int = 0
  127. win_counts: Dict[str, int] = field(default_factory=dict)
  128. @staticmethod
  129. def _stability_score(stats: Dict[str, int]) -> float:
  130. total = sum(stats.values())
  131. if total < 2:
  132. return 0.0
  133. dominant = max(stats.values())
  134. confidence = min(1.0, total / 20.0)
  135. return max(0.0, ((dominant / total) - 0.5) * 2.0 * confidence)
  136. def _effective_warmup_bytes(self) -> int:
  137. low = max(0, self.warmup_initial_bytes)
  138. high = max(low, self.warmup_max_bytes)
  139. score = self._stability_score(self.stats)
  140. return low + int((high - low) * score)
  141. async def start(self) -> None:
  142. meta = encode_json({"host": self.target_host, "port": self.target_port})
  143. self.warmup_bytes = self._effective_warmup_bytes()
  144. for link in self.links:
  145. link.tcp_sessions[(self.session_id, self.stream_id)] = self
  146. await link.send(Frame(TCP_OPEN, self.session_id, self.stream_id, 0, 0, meta))
  147. await asyncio.wait_for(self.open_event.wait(), timeout=10)
  148. if self.opened == 0:
  149. raise ConnectionError(self.open_errors[0] if self.open_errors else "all relays failed")
  150. self.pump_task = asyncio.create_task(self._pump_local())
  151. async def _pump_local(self) -> None:
  152. try:
  153. while True:
  154. chunk = await self.local_reader.read(65536)
  155. if not chunk:
  156. break
  157. self.uplink_bytes += len(chunk)
  158. if self.winning_link is None and self.uplink_bytes <= self.warmup_bytes:
  159. await asyncio.gather(*(link.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk)) for link in self.links if not link.closed), return_exceptions=True)
  160. else:
  161. if self.winning_link is None:
  162. await self.winner_event.wait()
  163. if self.winning_link:
  164. await self.winning_link.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk))
  165. except Exception:
  166. pass
  167. finally:
  168. await self.close()
  169. async def handle_frame(self, link: RelayLink, frame: Frame) -> None:
  170. if self.closed:
  171. return
  172. if frame.kind == TCP_STATUS:
  173. if frame.packet_id == STATUS_OK:
  174. self.opened += 1
  175. else:
  176. self.open_errors.append(frame.payload.decode("utf-8", errors="replace"))
  177. if self.opened > 0 or len(self.open_errors) == len(self.links):
  178. self.open_event.set()
  179. return
  180. if frame.kind == TCP_DATA:
  181. if self.winning_link is None:
  182. self.winning_link = link
  183. self.winner_name = link.node.name
  184. self.win_counts[link.node.name] = self.win_counts.get(link.node.name, 0) + 1
  185. node_total = self.win_counts[link.node.name]
  186. relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.win_counts.items())) or "none"
  187. print(f"[edge] tcp win session={self.session_id} target={self.target_host}:{self.target_port} winner={link.node.name} node_total={node_total} win_breakdown={relay_detail}")
  188. self.winner_event.set()
  189. await self._close_losers(except_link=link)
  190. if link is self.winning_link:
  191. self.local_writer.write(frame.payload)
  192. await self.local_writer.drain()
  193. return
  194. if frame.kind == TCP_CLOSE:
  195. if self.winning_link is None:
  196. self.winning_link = link
  197. self.winner_event.set()
  198. if link is self.winning_link:
  199. await self.close()
  200. async def _close_losers(self, except_link: RelayLink) -> None:
  201. await asyncio.gather(*(link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b"")) for link in self.links if link is not except_link and not link.closed), return_exceptions=True)
  202. async def close(self) -> None:
  203. if self.closed:
  204. return
  205. self.closed = True
  206. if self.pump_task and self.pump_task is not asyncio.current_task():
  207. self.pump_task.cancel()
  208. with contextlib.suppress(Exception):
  209. await self.pump_task
  210. await asyncio.gather(*(link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b"")) for link in self.links if not link.closed), return_exceptions=True)
  211. for link in self.links:
  212. link.tcp_sessions.pop((self.session_id, self.stream_id), None)
  213. self.local_writer.close()
  214. with contextlib.suppress(Exception):
  215. await self.local_writer.wait_closed()
  216. class UdpAssociateServer(asyncio.DatagramProtocol):
  217. def __init__(self, edge: "SocksEdge") -> None:
  218. self.edge = edge
  219. self.transport: asyncio.DatagramTransport | None = None
  220. self.client_addr = None
  221. self.associate_peer = None
  222. self.packet_counter = itertools.count(1)
  223. self.client_flows: dict[tuple[tuple[str, int], str, int], UdpFlowState] = {}
  224. self.flow_counter = itertools.count(1)
  225. self.last_summary_at = 0.0
  226. self.win_counts: Dict[str, int] = {}
  227. self.relay_error_counts: Dict[str, int] = {}
  228. def connection_made(self, transport) -> None:
  229. self.transport = transport
  230. def register_associate(self, peer) -> None:
  231. peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer)
  232. if self.associate_peer != peer_text:
  233. print(f"[edge] udp associate peer={peer_text}")
  234. self.associate_peer = peer_text
  235. def datagram_received(self, data: bytes, addr) -> None:
  236. if len(data) < 10:
  237. return
  238. if self.client_addr is None:
  239. self.client_addr = addr
  240. print(f"[edge] udp client bound addr={addr[0]}:{addr[1]}")
  241. elif addr != self.client_addr:
  242. print(f"[edge] udp client rebound old={self.client_addr[0]}:{self.client_addr[1]} new={addr[0]}:{addr[1]}")
  243. self._reset_client_state(addr)
  244. host, port, payload = self._parse_socks_udp(data)
  245. family = socket.AF_INET6 if ":" in host else socket.AF_INET
  246. loop = asyncio.get_running_loop()
  247. now = loop.time()
  248. flow_key = ((addr[0], addr[1]), host, port)
  249. flow = self.client_flows.get(flow_key)
  250. if flow is None:
  251. flow = UdpFlowState(
  252. flow_id=next(self.flow_counter),
  253. client_addr=(addr[0], addr[1]),
  254. target_host=host,
  255. target_port=port,
  256. target_family=family,
  257. created_at=now,
  258. last_activity=now,
  259. )
  260. self.client_flows[flow_key] = flow
  261. flow.touch(now)
  262. flow.packets_sent += 1
  263. packet_id = next(self.packet_counter)
  264. asyncio.create_task(self.edge.forward_udp(flow, payload, packet_id, self))
  265. self._log_udp_summary()
  266. def _reset_client_state(self, addr) -> None:
  267. for flow in list(self.client_flows.values()):
  268. for task in list(flow.direct_tasks.values()):
  269. task.cancel()
  270. for sock in list(flow.direct_sockets.values()):
  271. with contextlib.suppress(Exception):
  272. sock.close()
  273. for stream_id in list(flow.link_streams.values()):
  274. self.edge.udp_flow_sessions.pop((flow.flow_id, stream_id), None)
  275. self.client_flows.clear()
  276. self.client_addr = addr
  277. self.win_counts.clear()
  278. print(f"[edge] udp client bound addr={addr[0]}:{addr[1]}")
  279. async def handle_from_relay(self, frame: Frame, link: RelayLink) -> None:
  280. if self.transport is None or self.client_addr is None:
  281. return
  282. flow = self.edge.udp_flow_sessions.get((frame.session_id, frame.stream_id))
  283. if flow is None:
  284. return
  285. await self._deliver_flow_packet(flow, frame.packet_id, frame.payload, link.node.name)
  286. async def handle_relay_status(self, frame: Frame, link: RelayLink) -> None:
  287. if frame.packet_id == STATUS_OK:
  288. return
  289. flow = self.edge.udp_flow_sessions.get((frame.session_id, frame.stream_id))
  290. if flow is None:
  291. return
  292. detail = frame.payload.decode("utf-8", errors="replace") if frame.payload else "unknown"
  293. current_stream = flow.link_streams.get(link.node.name)
  294. if current_stream == frame.stream_id:
  295. flow.link_streams.pop(link.node.name, None)
  296. self.edge.udp_flow_sessions.pop((frame.session_id, frame.stream_id), None)
  297. flow.initialized_links.discard(link.node.name)
  298. flow.relay_failures[link.node.name] = flow.relay_failures.get(link.node.name, 0) + 1
  299. print(f"[edge] udp relay status error flow={flow.flow_id} relay={link.node.name} detail={detail}")
  300. async def handle_relay_close(self, frame: Frame, link: RelayLink) -> None:
  301. flow = self.edge.udp_flow_sessions.get((frame.session_id, frame.stream_id))
  302. if flow is None:
  303. return
  304. current_stream = flow.link_streams.get(link.node.name)
  305. if current_stream == frame.stream_id:
  306. flow.link_streams.pop(link.node.name, None)
  307. self.edge.udp_flow_sessions.pop((frame.session_id, frame.stream_id), None)
  308. flow.initialized_links.discard(link.node.name)
  309. flow.relay_failures[link.node.name] = flow.relay_failures.get(link.node.name, 0) + 1
  310. print(f"[edge] udp relay close flow={flow.flow_id} relay={link.node.name}")
  311. async def handle_from_direct(self, flow: UdpFlowState, path_name: str, payload: bytes) -> None:
  312. if self.transport is None or self.client_addr is None:
  313. return
  314. await self._deliver_flow_packet(flow, 0, payload, path_name)
  315. async def _deliver_flow_packet(self, flow: UdpFlowState, packet_id: int, payload: bytes, source_name: str) -> None:
  316. if self.transport is None or self.client_addr is None:
  317. return
  318. packet = self._build_socks_udp(flow.target_host, flow.target_port, payload)
  319. now = asyncio.get_running_loop().time()
  320. flow.touch(now)
  321. flow.packets_received += 1
  322. if flow.winner_name is None:
  323. flow.winner_name = source_name
  324. self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1
  325. self._log_udp_summary(force=True)
  326. elif flow.winner_name != source_name:
  327. flow.duplicate_responses += 1
  328. if flow.winner_name == source_name:
  329. self.transport.sendto(packet, self.client_addr)
  330. def set_flow_candidates(self, flow: UdpFlowState, candidate_names: tuple[str, ...]) -> None:
  331. flow.candidate_names = candidate_names
  332. def handle_relay_closed(self, relay_name: str) -> None:
  333. for flow in self.client_flows.values():
  334. stream_id = flow.link_streams.pop(relay_name, None)
  335. if stream_id is not None:
  336. self.edge.udp_flow_sessions.pop((flow.flow_id, stream_id), None)
  337. flow.initialized_links.discard(relay_name)
  338. def note_unsent(self, flow: UdpFlowState, packet_id: int) -> None:
  339. flow.touch(asyncio.get_running_loop().time())
  340. flow.relay_failures["unsent"] = flow.relay_failures.get("unsent", 0) + 1
  341. self._log_udp_summary(force=True)
  342. def _log_udp_summary(self, force: bool = False) -> None:
  343. now = asyncio.get_running_loop().time()
  344. if not force and now - self.last_summary_at < 10:
  345. return
  346. self.last_summary_at = now
  347. active_flows = len(self.client_flows)
  348. winners = sum(1 for flow in self.client_flows.values() if flow.winner_name)
  349. packets_sent = sum(flow.packets_sent for flow in self.client_flows.values())
  350. packets_received = sum(flow.packets_received for flow in self.client_flows.values())
  351. duplicates = sum(flow.duplicate_responses for flow in self.client_flows.values())
  352. direct_paths = sum(len(flow.direct_sockets) for flow in self.client_flows.values())
  353. relay_candidates = sum(len(flow.link_streams) for flow in self.client_flows.values())
  354. winner_detail = ", ".join(f"{flow.flow_id}:{flow.winner_name}" for flow in self.client_flows.values() if flow.winner_name) or "none"
  355. candidate_detail_parts: list[str] = []
  356. relay_errors: list[str] = []
  357. for flow in self.client_flows.values():
  358. states: list[str] = []
  359. for name in flow.candidate_names:
  360. if name in flow.direct_sockets:
  361. state = "direct"
  362. elif name in flow.link_streams:
  363. state = "relay"
  364. elif name in flow.direct_failures:
  365. state = "direct_down"
  366. elif name in flow.relay_failures:
  367. state = "relay_down"
  368. else:
  369. state = "pending"
  370. if flow.winner_name == name:
  371. state += ":win"
  372. states.append(f"{name}={state}")
  373. candidate_detail_parts.append(f"{flow.flow_id}[{', '.join(states) or 'none'}]")
  374. for name, count in flow.relay_failures.items():
  375. relay_errors.append(f"{name}={count}")
  376. candidate_detail = "; ".join(candidate_detail_parts) or "none"
  377. relay_error_detail = ", ".join(sorted(relay_errors)) or "none"
  378. if self.client_addr:
  379. print(
  380. f"[edge] udp summary bind={self.client_addr[0]}:{self.client_addr[1]} active_flows={active_flows} "
  381. f"winner_flows={winners} winner_detail={winner_detail} packets_sent={packets_sent} packets_received={packets_received} dup={duplicates} "
  382. f"direct_paths={direct_paths} relay_paths={relay_candidates} candidates={candidate_detail} relay_errors={relay_error_detail}"
  383. )
  384. else:
  385. print(
  386. f"[edge] udp summary bind=unbound active_flows={active_flows} winner_flows={winners} winner_detail={winner_detail} packets_sent={packets_sent} packets_received={packets_received} dup={duplicates} "
  387. f"direct_paths={direct_paths} relay_paths={relay_candidates} candidates={candidate_detail} relay_errors={relay_error_detail}"
  388. )
  389. def _parse_socks_udp(self, packet: bytes) -> tuple[str, int, bytes]:
  390. atyp = packet[3]
  391. offset = 4
  392. if atyp == 1:
  393. host = socket.inet_ntoa(packet[offset:offset + 4])
  394. offset += 4
  395. elif atyp == 3:
  396. size = packet[offset]
  397. offset += 1
  398. host = packet[offset:offset + size].decode()
  399. offset += size
  400. else:
  401. raise ValueError("unsupported udp atyp")
  402. port = struct.unpack("!H", packet[offset:offset + 2])[0]
  403. offset += 2
  404. return host, port, packet[offset:]
  405. def _build_socks_udp(self, host: str, port: int, payload: bytes) -> bytes:
  406. try:
  407. addr = socket.inet_aton(host)
  408. header = b"\x00\x00\x00\x01" + addr + struct.pack("!H", port)
  409. except OSError:
  410. raw = host.encode()
  411. header = b"\x00\x00\x00\x03" + bytes([len(raw)]) + raw + struct.pack("!H", port)
  412. return header + payload
  413. class SocksEdge:
  414. def __init__(self, listen_host: str, listen_port: int, config: Config) -> None:
  415. self.listen_host = listen_host
  416. self.listen_port = listen_port
  417. self.config = config
  418. self.scheduler = Scheduler(config)
  419. self.links: list[RelayLink] = []
  420. self.session_ids = itertools.count(1)
  421. self.udp_stream_ids = itertools.count(1)
  422. self.udp_flow_sessions: dict[tuple[int, int], UdpFlowState] = {}
  423. self.udp_server: UdpAssociateServer | None = None
  424. self.tcp_win_counts: Dict[str, int] = {}
  425. async def start(self) -> None:
  426. await self.scheduler.start()
  427. await self._connect_relays()
  428. server = await asyncio.start_server(self._accept, self.listen_host, self.listen_port)
  429. sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
  430. print(f"[edge] socks5 listening on {sockets}")
  431. async with server:
  432. await server.serve_forever()
  433. async def _connect_relays(self) -> None:
  434. loop = asyncio.get_running_loop()
  435. transport, protocol = await loop.create_datagram_endpoint(lambda: UdpAssociateServer(self), local_addr=(self.listen_host, 0))
  436. self.udp_server = protocol
  437. self.udp_transport = transport
  438. for node in self.config.relays:
  439. link = RelayLink(node=node, reader=None, writer=None) # type: ignore[arg-type]
  440. link.udp_server = protocol
  441. self.links.append(link)
  442. link.maintain_task = asyncio.create_task(self._maintain_link(link))
  443. async def _maintain_link(self, link: RelayLink) -> None:
  444. backoff = 1.0
  445. while True:
  446. try:
  447. reader, writer = await asyncio.open_connection(link.node.host, link.node.port)
  448. sock = writer.get_extra_info("socket")
  449. if sock is not None:
  450. with contextlib.suppress(OSError):
  451. sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  452. link.reader = reader
  453. link.writer = writer
  454. await link.start()
  455. backoff = 1.0
  456. await link.closed_event.wait()
  457. except asyncio.CancelledError:
  458. raise
  459. except Exception:
  460. await asyncio.sleep(backoff)
  461. backoff = min(10.0, backoff * 2)
  462. async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  463. try:
  464. peer = writer.get_extra_info("peername")
  465. _host, _port, udp_mode = await self._handshake(reader, writer, peer)
  466. if udp_mode:
  467. return
  468. except Exception:
  469. writer.close()
  470. with contextlib.suppress(Exception):
  471. await writer.wait_closed()
  472. def _selected_links(self) -> list[RelayLink]:
  473. chosen = {node.name for node in self.scheduler.choose()}
  474. links = [link for link in self.links if link.node.name in chosen and not link.closed]
  475. return links or [link for link in self.links if not link.closed][:1]
  476. def _selected_udp_links(self) -> list[RelayLink]:
  477. online = [link for link in self.links if not link.closed and link.writer is not None]
  478. if not online:
  479. return []
  480. ordered = sorted(online, key=lambda link: self.scheduler.scores.get(link.node.name).score if link.node.name in self.scheduler.scores else 999999.0)
  481. return ordered
  482. def _udp_direct_redundancy_for_target(self, target_host: str) -> int:
  483. base = self.config.udp_direct_redundancy
  484. if ":" in target_host and self.config.udp_direct_redundancy_v6 is not None:
  485. base = self.config.udp_direct_redundancy_v6
  486. elif ":" not in target_host and self.config.udp_direct_redundancy_v4 is not None:
  487. base = self.config.udp_direct_redundancy_v4
  488. return max(1, base)
  489. async def _ensure_udp_direct_paths(self, flow: UdpFlowState, udp_server: UdpAssociateServer) -> None:
  490. target_count = self._udp_direct_redundancy_for_target(flow.target_host)
  491. for index in range(target_count):
  492. name = f"direct-{index + 1}" if target_count > 1 else "direct"
  493. if name in flow.direct_sockets or name in flow.direct_failures:
  494. continue
  495. try:
  496. family = socket.AF_INET6 if ":" in flow.target_host else socket.AF_INET
  497. sock = socket.socket(family, socket.SOCK_DGRAM)
  498. sock.setblocking(False)
  499. await asyncio.get_running_loop().sock_connect(sock, (flow.target_host, flow.target_port))
  500. flow.direct_sockets[name] = sock
  501. flow.direct_tasks[name] = asyncio.create_task(self._pump_udp_direct(flow, name, sock, udp_server))
  502. except Exception as exc:
  503. flow.direct_failures.add(name)
  504. print(f"[edge] udp direct open error flow={flow.flow_id} path={name} target={flow.target_host}:{flow.target_port} error={exc!r}")
  505. async def _pump_udp_direct(self, flow: UdpFlowState, path_name: str, sock: socket.socket, udp_server: UdpAssociateServer) -> None:
  506. loop = asyncio.get_running_loop()
  507. try:
  508. while True:
  509. data = await loop.sock_recv(sock, 65535)
  510. if not data:
  511. break
  512. await udp_server.handle_from_direct(flow, path_name, data)
  513. except Exception:
  514. pass
  515. finally:
  516. flow.direct_tasks.pop(path_name, None)
  517. flow.direct_sockets.pop(path_name, None)
  518. with contextlib.suppress(Exception):
  519. sock.close()
  520. async def forward_udp(self, flow: UdpFlowState, payload: bytes, packet_id: int, udp_server: UdpAssociateServer) -> None:
  521. await self._ensure_udp_direct_paths(flow, udp_server)
  522. meta = encode_json({"host": flow.target_host, "port": flow.target_port, "family": flow.target_family})
  523. links = self._selected_udp_links()
  524. direct_names = tuple(name for name in sorted(flow.direct_sockets))
  525. relay_names = tuple(link.node.name for link in links)
  526. candidate_names = direct_names + relay_names
  527. udp_server.set_flow_candidates(flow, candidate_names)
  528. if not candidate_names:
  529. udp_server.note_unsent(flow, packet_id)
  530. return
  531. active_direct_names = list(direct_names)
  532. active_links = links
  533. if not (self.config.udp_always_broadcast or flow.winner_name is None):
  534. active_direct_names = [name for name in active_direct_names if name == flow.winner_name]
  535. active_links = [link for link in active_links if link.node.name == flow.winner_name]
  536. if not active_direct_names and not active_links:
  537. if direct_names:
  538. active_direct_names = [direct_names[0]]
  539. elif links:
  540. active_links = links[:1]
  541. copies = max(1, self.config.udp_redundancy + 1)
  542. sent_any = False
  543. for attempt in range(copies):
  544. for path_name in active_direct_names:
  545. sock = flow.direct_sockets.get(path_name)
  546. if sock is None:
  547. continue
  548. try:
  549. await asyncio.get_running_loop().sock_sendall(sock, payload)
  550. sent_any = True
  551. except Exception as exc:
  552. flow.direct_failures.add(path_name)
  553. flow.direct_sockets.pop(path_name, None)
  554. task = flow.direct_tasks.pop(path_name, None)
  555. if task is not None:
  556. task.cancel()
  557. with contextlib.suppress(Exception):
  558. sock.close()
  559. flow.relay_failures[path_name] = flow.relay_failures.get(path_name, 0) + 1
  560. if path_name not in flow.relay_error_seen:
  561. flow.relay_error_seen.add(path_name)
  562. print(
  563. f"[edge] udp relay error flow={flow.flow_id} relay={path_name} error={exc!r}"
  564. )
  565. for link in active_links:
  566. stream_id = flow.link_streams.get(link.node.name)
  567. if stream_id is None:
  568. stream_id = next(self.udp_stream_ids)
  569. flow.link_streams[link.node.name] = stream_id
  570. self.udp_flow_sessions[(flow.flow_id, stream_id)] = flow
  571. body = meta + payload
  572. meta_len = len(meta)
  573. try:
  574. await link.send(Frame(UDP_SEND, flow.flow_id, stream_id, 0, meta_len, body))
  575. sent_any = True
  576. except Exception as exc:
  577. flow.link_streams.pop(link.node.name, None)
  578. self.udp_flow_sessions.pop((flow.flow_id, stream_id), None)
  579. flow.initialized_links.discard(link.node.name)
  580. flow.relay_failures[link.node.name] = flow.relay_failures.get(link.node.name, 0) + 1
  581. if link.node.name not in flow.relay_error_seen:
  582. flow.relay_error_seen.add(link.node.name)
  583. print(
  584. f"[edge] udp relay error flow={flow.flow_id} relay={link.node.name} error={exc!r}"
  585. )
  586. if attempt + 1 < copies and self.config.udp_copy_interval_ms > 0:
  587. await asyncio.sleep(self.config.udp_copy_interval_ms / 1000)
  588. if not sent_any:
  589. udp_server.note_unsent(flow, packet_id)
  590. async def _handshake(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer) -> tuple[str, int, bool]:
  591. version, methods_len = (await read_exact(reader, 2))
  592. if version != SOCKS_VERSION:
  593. raise ValueError("unsupported socks version")
  594. await read_exact(reader, methods_len)
  595. writer.write(b"\x05\x00")
  596. await writer.drain()
  597. version, command, _, atyp = await read_exact(reader, 4)
  598. if version != SOCKS_VERSION:
  599. raise ValueError("unsupported socks version")
  600. if atyp == 1:
  601. host = socket.inet_ntoa(await read_exact(reader, 4))
  602. elif atyp == 3:
  603. size = (await read_exact(reader, 1))[0]
  604. host = (await read_exact(reader, size)).decode()
  605. else:
  606. raise ValueError("unsupported atyp")
  607. port = struct.unpack("!H", await read_exact(reader, 2))[0]
  608. peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer)
  609. if command == 1:
  610. writer.write(b"\x05\x07\x00\x01\x00\x00\x00\x00\x00\x00")
  611. await writer.drain()
  612. raise ValueError("tcp connect disabled in socks udp-only mode")
  613. if command == 3 and self.udp_server and self.udp_server.transport:
  614. bind_host, bind_port = self.udp_server.transport.get_extra_info("sockname")[:2]
  615. self.udp_server.register_associate(peer)
  616. print(f"[edge] socks handshake peer={peer_text} command=udp_associate target={host}:{port} bind={bind_host}:{bind_port}")
  617. writer.write(b"\x05\x00\x00\x01" + socket.inet_aton(bind_host) + struct.pack("!H", bind_port))
  618. await writer.drain()
  619. return host, port, True
  620. raise ValueError("unsupported socks command")