# socks_edge.py
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. import itertools
  5. from collections import deque
  6. import json
  7. import socket
  8. import struct
  9. from dataclasses import dataclass, field
  10. from typing import Dict
  11. from .config import Config, RelayNode
  12. from .protocol import AUTH, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, decode_json, encode_json, read_frame, write_frame
  13. from .scheduler import Scheduler
# SOCKS protocol version byte; the handshake rejects any other version.
SOCKS_VERSION = 5
# A UDP flow is sent on every candidate path while its sent-packet count is
# less than or equal to this value.
UDP_WARMUP_BROADCAST_PACKETS = 6
# Once a flow has a winner, a broadcast "shadow probe" is sent whenever at
# least this many seconds have passed since the previous probe.
UDP_SHADOW_PROBE_INTERVAL_SEC = 0.25
# Number of winner-only sends without an intervening winner reply after which
# the winner is dropped and the flow goes back to broadcasting.
UDP_FAST_FAILOVER_MISSES = 3
  18. async def read_exact(reader: asyncio.StreamReader, size: int) -> bytes:
  19. return await reader.readexactly(size)
  20. @dataclass(eq=False)
  21. class RelayLink:
  22. node: RelayNode
  23. reader: asyncio.StreamReader
  24. writer: asyncio.StreamWriter
  25. pump: asyncio.Task | None = None
  26. closed_event: asyncio.Event = field(default_factory=asyncio.Event)
  27. maintain_task: asyncio.Task | None = None
  28. tcp_sessions: Dict[tuple[int, int], "TcpRaceSession"] = field(default_factory=dict)
  29. udp_server: "UdpAssociateServer | None" = None
  30. closed: bool = False
  31. supports_tcp: bool = True
  32. supports_udp: bool = True
  33. async def start(self) -> None:
  34. await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
  35. frame = await read_frame(self.reader)
  36. if frame.kind != AUTH or frame.packet_id != STATUS_OK:
  37. raise ConnectionError(f"relay auth failed: {self.node.name}")
  38. ack = {}
  39. if frame.payload:
  40. try:
  41. ack = json.loads(frame.payload.decode("utf-8"))
  42. except Exception:
  43. ack = {}
  44. self.supports_tcp = not bool(ack.get("udp_only", False))
  45. self.supports_udp = True
  46. self.closed = False
  47. self.closed_event.clear()
  48. self.pump = asyncio.create_task(self._pump())
  49. async def _pump(self) -> None:
  50. try:
  51. while True:
  52. frame = await read_frame(self.reader)
  53. key = (frame.session_id, frame.stream_id)
  54. if frame.kind in (TCP_STATUS, TCP_DATA, TCP_CLOSE):
  55. session = self.tcp_sessions.get(key)
  56. if session:
  57. await session.handle_frame(self, frame)
  58. elif frame.kind == UDP_RECV and self.udp_server:
  59. await self.udp_server.handle_from_relay(frame, self)
  60. except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, OSError):
  61. pass
  62. except Exception:
  63. pass
  64. finally:
  65. await self.close()
  66. async def send(self, frame: Frame) -> None:
  67. if self.closed:
  68. raise ConnectionError(f"relay closed: {self.node.name}")
  69. try:
  70. await write_frame(self.writer, frame)
  71. except (BrokenPipeError, ConnectionResetError, RuntimeError, OSError, asyncio.CancelledError) as exc:
  72. await self.close()
  73. raise ConnectionError(f"relay closed: {self.node.name}") from exc
  74. async def close(self) -> None:
  75. if self.closed:
  76. return
  77. self.closed = True
  78. self.closed_event.set()
  79. if self.pump and self.pump is not asyncio.current_task():
  80. self.pump.cancel()
  81. with contextlib.suppress(Exception):
  82. await self.pump
  83. self.writer.close()
  84. with contextlib.suppress(Exception):
  85. await self.writer.wait_closed()
  86. @dataclass
  87. class UdpFlowState:
  88. flow_id: int
  89. client_addr: tuple[str, int]
  90. target_host: str
  91. target_port: int
  92. created_at: float
  93. last_activity: float
  94. packets_sent: int = 0
  95. packets_received: int = 0
  96. duplicate_responses: int = 0
  97. winner_name: str | None = None
  98. candidate_names: tuple[str, ...] = ()
  99. link_streams: dict[str, int] = field(default_factory=dict)
  100. initialized_links: set[str] = field(default_factory=set)
  101. direct_sockets: dict[str, socket.socket] = field(default_factory=dict)
  102. direct_tasks: dict[str, asyncio.Task] = field(default_factory=dict)
  103. direct_failures: set[str] = field(default_factory=set)
  104. relay_failures: dict[str, int] = field(default_factory=dict)
  105. relay_error_seen: set[str] = field(default_factory=set)
  106. path_last_seen: dict[str, float] = field(default_factory=dict)
  107. packet_client_addrs: dict[int, tuple[str, int]] = field(default_factory=dict)
  108. direct_pending_clients: dict[str, deque[tuple[int, tuple[str, int]]]] = field(default_factory=dict)
  109. last_probe_at: float = 0.0
  110. winner_miss_streak: int = 0
  111. def touch(self, now: float) -> None:
  112. self.last_activity = now
  113. @dataclass
  114. class TcpRaceSession:
  115. session_id: int
  116. stream_id: int
  117. target_host: str
  118. target_port: int
  119. local_reader: asyncio.StreamReader
  120. local_writer: asyncio.StreamWriter
  121. links: list[RelayLink]
  122. warmup_bytes: int
  123. winning_link: RelayLink | None = None
  124. winner_name: str | None = None
  125. opened: int = 0
  126. open_errors: list[str] = field(default_factory=list)
  127. uplink_bytes: int = 0
  128. closed: bool = False
  129. open_event: asyncio.Event = field(default_factory=asyncio.Event)
  130. winner_event: asyncio.Event = field(default_factory=asyncio.Event)
  131. pump_task: asyncio.Task | None = None
  132. win_counts: Dict[str, int] = field(default_factory=dict)
  133. async def start(self) -> None:
  134. meta = encode_json({"host": self.target_host, "port": self.target_port})
  135. for link in self.links:
  136. link.tcp_sessions[(self.session_id, self.stream_id)] = self
  137. await link.send(Frame(TCP_OPEN, self.session_id, self.stream_id, 0, 0, meta))
  138. await asyncio.wait_for(self.open_event.wait(), timeout=10)
  139. if self.opened == 0:
  140. raise ConnectionError(self.open_errors[0] if self.open_errors else "all relays failed")
  141. self.pump_task = asyncio.create_task(self._pump_local())
  142. async def _pump_local(self) -> None:
  143. try:
  144. while True:
  145. chunk = await self.local_reader.read(65536)
  146. if not chunk:
  147. break
  148. self.uplink_bytes += len(chunk)
  149. if self.winning_link is None and self.uplink_bytes <= self.warmup_bytes:
  150. await asyncio.gather(*(link.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk)) for link in self.links if not link.closed), return_exceptions=True)
  151. else:
  152. if self.winning_link is None:
  153. await self.winner_event.wait()
  154. if self.winning_link:
  155. await self.winning_link.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk))
  156. except Exception:
  157. pass
  158. finally:
  159. await self.close()
  160. async def handle_frame(self, link: RelayLink, frame: Frame) -> None:
  161. if self.closed:
  162. return
  163. if frame.kind == TCP_STATUS:
  164. if frame.packet_id == STATUS_OK:
  165. self.opened += 1
  166. else:
  167. self.open_errors.append(frame.payload.decode("utf-8", errors="replace"))
  168. if self.opened > 0 or len(self.open_errors) == len(self.links):
  169. self.open_event.set()
  170. return
  171. if frame.kind == TCP_DATA:
  172. if self.winning_link is None:
  173. self.winning_link = link
  174. self.winner_name = link.node.name
  175. self.win_counts[link.node.name] = self.win_counts.get(link.node.name, 0) + 1
  176. node_total = self.win_counts[link.node.name]
  177. relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.win_counts.items())) or "none"
  178. print(f"[edge] tcp win session={self.session_id} target={self.target_host}:{self.target_port} winner={link.node.name} node_total={node_total} win_breakdown={relay_detail}")
  179. self.winner_event.set()
  180. await self._close_losers(except_link=link)
  181. if link is self.winning_link:
  182. self.local_writer.write(frame.payload)
  183. await self.local_writer.drain()
  184. return
  185. if frame.kind == TCP_CLOSE:
  186. if self.winning_link is None:
  187. self.winning_link = link
  188. self.winner_event.set()
  189. if link is self.winning_link:
  190. await self.close()
  191. async def _close_losers(self, except_link: RelayLink) -> None:
  192. await asyncio.gather(*(link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b"")) for link in self.links if link is not except_link and not link.closed), return_exceptions=True)
  193. async def close(self) -> None:
  194. if self.closed:
  195. return
  196. self.closed = True
  197. if self.pump_task and self.pump_task is not asyncio.current_task():
  198. self.pump_task.cancel()
  199. with contextlib.suppress(Exception):
  200. await self.pump_task
  201. await asyncio.gather(*(link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b"")) for link in self.links if not link.closed), return_exceptions=True)
  202. for link in self.links:
  203. link.tcp_sessions.pop((self.session_id, self.stream_id), None)
  204. self.local_writer.close()
  205. with contextlib.suppress(Exception):
  206. await self.local_writer.wait_closed()
  207. class UdpAssociateServer(asyncio.DatagramProtocol):
  208. def __init__(self, edge: "SocksEdge") -> None:
  209. self.edge = edge
  210. self.transport: asyncio.DatagramTransport | None = None
  211. self.client_addr = None
  212. self.associate_peer = None
  213. self.packet_counter = itertools.count(1)
  214. self.client_flows: dict[tuple[tuple[str, int], str, int], UdpFlowState] = {}
  215. self.flow_counter = itertools.count(1)
  216. self.last_summary_at = 0.0
  217. self.win_counts: Dict[str, int] = {}
  218. self.relay_error_counts: Dict[str, int] = {}
  219. def connection_made(self, transport) -> None:
  220. self.transport = transport
  221. def register_associate(self, peer) -> None:
  222. peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer)
  223. if self.associate_peer != peer_text:
  224. print(f"[edge] udp associate peer={peer_text}")
  225. self.associate_peer = peer_text
  226. def _client_flow_key(self, addr, host: str, port: int) -> tuple[tuple[str, int], str, int]:
  227. # Treat port rebinding from the same client IP as the same logical UDP client.
  228. return ((addr[0], 0), host, port)
  229. def datagram_received(self, data: bytes, addr) -> None:
  230. if len(data) < 10:
  231. return
  232. if self.client_addr is None:
  233. self.client_addr = addr
  234. print(f"[edge] udp client bound addr={addr[0]}:{addr[1]}")
  235. elif addr != self.client_addr:
  236. if addr[0] == self.client_addr[0]:
  237. print(f"[edge] udp client port update host={addr[0]} old_port={self.client_addr[1]} new_port={addr[1]}")
  238. self.client_addr = addr
  239. else:
  240. print(f"[edge] udp client rebound old={self.client_addr[0]}:{self.client_addr[1]} new={addr[0]}:{addr[1]}")
  241. self._reset_client_state(addr)
  242. host, port, payload = self._parse_socks_udp(data)
  243. loop = asyncio.get_running_loop()
  244. now = loop.time()
  245. flow_key = self._client_flow_key(addr, host, port)
  246. flow = self.client_flows.get(flow_key)
  247. if flow is None:
  248. flow = UdpFlowState(
  249. flow_id=next(self.flow_counter),
  250. client_addr=(addr[0], addr[1]),
  251. target_host=host,
  252. target_port=port,
  253. created_at=now,
  254. last_activity=now,
  255. )
  256. self.client_flows[flow_key] = flow
  257. flow.touch(now)
  258. flow.client_addr = (addr[0], addr[1])
  259. flow.packets_sent += 1
  260. packet_id = next(self.packet_counter)
  261. flow.packet_client_addrs[packet_id] = (addr[0], addr[1])
  262. asyncio.create_task(self.edge.forward_udp(flow, payload, packet_id, (addr[0], addr[1]), self))
  263. self._log_udp_summary()
  264. def _reset_client_state(self, addr) -> None:
  265. old_addr = self.client_addr
  266. remapped_flows: dict[tuple[tuple[str, int], str, int], UdpFlowState] = {}
  267. for flow in list(self.client_flows.values()):
  268. flow.client_addr = (addr[0], addr[1])
  269. remapped_flows[self._client_flow_key(addr, flow.target_host, flow.target_port)] = flow
  270. self.client_flows = remapped_flows
  271. self.client_addr = addr
  272. print(f"[edge] udp client rebound migrated old={old_addr[0]}:{old_addr[1]} new={addr[0]}:{addr[1]} flows={len(self.client_flows)}")
  273. async def handle_from_relay(self, frame: Frame, link: RelayLink) -> None:
  274. if self.transport is None or self.client_addr is None:
  275. return
  276. flow = self.edge.udp_flow_sessions.get((frame.session_id, frame.stream_id))
  277. if flow is None:
  278. return
  279. await self._deliver_flow_packet(flow, frame.packet_id, frame.payload, link.node.name)
  280. async def handle_from_direct(self, flow: UdpFlowState, path_name: str, payload: bytes, packet_id: int = 0, client_addr: tuple[str, int] | None = None) -> None:
  281. if self.transport is None or self.client_addr is None:
  282. return
  283. await self._deliver_flow_packet(flow, packet_id, payload, path_name, client_addr)
  284. async def _deliver_flow_packet(self, flow: UdpFlowState, packet_id: int, payload: bytes, source_name: str, client_addr: tuple[str, int] | None = None) -> None:
  285. if self.transport is None or self.client_addr is None:
  286. return
  287. packet = self._build_socks_udp(flow.target_host, flow.target_port, payload)
  288. now = asyncio.get_running_loop().time()
  289. flow.touch(now)
  290. flow.path_last_seen[source_name] = now
  291. flow.packets_received += 1
  292. target_addr = client_addr or flow.packet_client_addrs.pop(packet_id, None) or flow.client_addr
  293. if flow.winner_name is None:
  294. flow.winner_name = source_name
  295. flow.winner_miss_streak = 0
  296. self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1
  297. self._log_udp_summary(force=True)
  298. elif flow.winner_name != source_name:
  299. flow.duplicate_responses += 1
  300. winner_last_seen = flow.path_last_seen.get(flow.winner_name, 0.0)
  301. if winner_last_seen and now - winner_last_seen >= (self.edge.config.udp_failover_idle_ms / 1000):
  302. flow.winner_name = source_name
  303. flow.winner_miss_streak = 0
  304. self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1
  305. self._log_udp_summary(force=True)
  306. else:
  307. flow.winner_miss_streak = 0
  308. if flow.winner_name == source_name:
  309. if target_addr is not None:
  310. self.transport.sendto(packet, target_addr)
  311. def set_flow_candidates(self, flow: UdpFlowState, candidate_names: tuple[str, ...]) -> None:
  312. if not flow.candidate_names:
  313. flow.candidate_names = candidate_names
  314. def note_unsent(self, flow: UdpFlowState, packet_id: int) -> None:
  315. flow.touch(asyncio.get_running_loop().time())
  316. flow.relay_failures["unsent"] = flow.relay_failures.get("unsent", 0) + 1
  317. self._log_udp_summary(force=True)
  318. def _log_udp_summary(self, force: bool = False) -> None:
  319. now = asyncio.get_running_loop().time()
  320. if not force and now - self.last_summary_at < 10:
  321. return
  322. self.last_summary_at = now
  323. active_flows = len(self.client_flows)
  324. winners = sum(1 for flow in self.client_flows.values() if flow.winner_name)
  325. packets_sent = sum(flow.packets_sent for flow in self.client_flows.values())
  326. packets_received = sum(flow.packets_received for flow in self.client_flows.values())
  327. duplicates = sum(flow.duplicate_responses for flow in self.client_flows.values())
  328. direct_paths = sum(len(flow.direct_sockets) for flow in self.client_flows.values())
  329. relay_candidates = sum(len(flow.link_streams) for flow in self.client_flows.values())
  330. candidate_names: list[str] = []
  331. seen_candidates: set[str] = set()
  332. for flow in sorted(self.client_flows.values(), key=lambda item: item.flow_id):
  333. for name in flow.candidate_names:
  334. if name in seen_candidates:
  335. continue
  336. seen_candidates.add(name)
  337. candidate_names.append(name)
  338. direct_wins = sum(1 for flow in self.client_flows.values() if flow.winner_name and flow.winner_name.startswith("direct"))
  339. relay_wins = winners - direct_wins
  340. sample_flows = [
  341. f"{flow.flow_id}:{flow.winner_name or 'pending'}"
  342. for flow in sorted(self.client_flows.values(), key=lambda item: item.flow_id)
  343. if flow.winner_name
  344. ][:5]
  345. winner_detail = ", ".join(sample_flows) or "none"
  346. relay_errors: list[str] = []
  347. for flow in self.client_flows.values():
  348. for name, count in flow.relay_failures.items():
  349. relay_errors.append(f"{name}={count}")
  350. relay_error_detail = ", ".join(sorted(relay_errors)) or "none"
  351. if self.client_addr:
  352. print(
  353. f"[edge] udp summary bind={self.client_addr[0]}:{self.client_addr[1]} flows={active_flows} winners={winners} "
  354. f"winner_breakdown=direct={direct_wins},relay={relay_wins} sample={winner_detail} "
  355. f"candidates={candidate_names or ['none']} "
  356. f"sent={packets_sent} recv={packets_received} dup={duplicates} "
  357. f"direct_paths={direct_paths} relay_paths={relay_candidates} relay_errors={relay_error_detail}"
  358. )
  359. else:
  360. print(
  361. f"[edge] udp summary bind=unbound flows={active_flows} winners={winners} "
  362. f"winner_breakdown=direct={direct_wins},relay={relay_wins} sample={winner_detail} "
  363. f"candidates={candidate_names or ['none']} "
  364. f"sent={packets_sent} recv={packets_received} dup={duplicates} "
  365. f"direct_paths={direct_paths} relay_paths={relay_candidates} relay_errors={relay_error_detail}"
  366. )
  367. def _parse_socks_udp(self, packet: bytes) -> tuple[str, int, bytes]:
  368. atyp = packet[3]
  369. offset = 4
  370. if atyp == 1:
  371. host = socket.inet_ntoa(packet[offset:offset + 4])
  372. offset += 4
  373. elif atyp == 3:
  374. size = packet[offset]
  375. offset += 1
  376. host = packet[offset:offset + size].decode()
  377. offset += size
  378. else:
  379. raise ValueError("unsupported udp atyp")
  380. port = struct.unpack("!H", packet[offset:offset + 2])[0]
  381. offset += 2
  382. return host, port, packet[offset:]
  383. def _build_socks_udp(self, host: str, port: int, payload: bytes) -> bytes:
  384. try:
  385. addr = socket.inet_aton(host)
  386. header = b"\x00\x00\x00\x01" + addr + struct.pack("!H", port)
  387. except OSError:
  388. raw = host.encode()
  389. header = b"\x00\x00\x00\x03" + bytes([len(raw)]) + raw + struct.pack("!H", port)
  390. return header + payload
  391. class SocksEdge:
  392. def __init__(self, listen_host: str, listen_port: int, config: Config) -> None:
  393. self.listen_host = listen_host
  394. self.listen_port = listen_port
  395. self.config = config
  396. self.scheduler = Scheduler(config)
  397. self.links: list[RelayLink] = []
  398. self.session_ids = itertools.count(1)
  399. self.udp_stream_ids = itertools.count(1)
  400. self.udp_flow_sessions: dict[tuple[int, int], UdpFlowState] = {}
  401. self.udp_server: UdpAssociateServer | None = None
  402. async def start(self) -> None:
  403. await self.scheduler.start()
  404. await self._connect_relays()
  405. server = await asyncio.start_server(self._accept, self.listen_host, self.listen_port)
  406. sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
  407. print(f"[edge] socks5 listening on {sockets}")
  408. async with server:
  409. await server.serve_forever()
  410. async def _connect_relays(self) -> None:
  411. loop = asyncio.get_running_loop()
  412. transport, protocol = await loop.create_datagram_endpoint(lambda: UdpAssociateServer(self), local_addr=(self.listen_host, 0))
  413. self.udp_server = protocol
  414. self.udp_transport = transport
  415. for node in self.config.relays:
  416. link = RelayLink(node=node, reader=None, writer=None) # type: ignore[arg-type]
  417. link.udp_server = protocol
  418. self.links.append(link)
  419. link.maintain_task = asyncio.create_task(self._maintain_link(link))
  420. async def _maintain_link(self, link: RelayLink) -> None:
  421. backoff = 1.0
  422. while True:
  423. try:
  424. reader, writer = await asyncio.open_connection(link.node.host, link.node.port)
  425. sock = writer.get_extra_info("socket")
  426. if sock is not None:
  427. with contextlib.suppress(OSError):
  428. sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  429. link.reader = reader
  430. link.writer = writer
  431. await link.start()
  432. backoff = 1.0
  433. await link.closed_event.wait()
  434. except asyncio.CancelledError:
  435. raise
  436. except Exception:
  437. await asyncio.sleep(backoff)
  438. backoff = min(10.0, backoff * 2)
  439. async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  440. try:
  441. peer = writer.get_extra_info("peername")
  442. _host, _port, udp_mode = await self._handshake(reader, writer, peer)
  443. if udp_mode:
  444. return
  445. except Exception:
  446. writer.close()
  447. with contextlib.suppress(Exception):
  448. await writer.wait_closed()
  449. def _selected_links(self) -> list[RelayLink]:
  450. chosen = {node.name for node in self.scheduler.choose()}
  451. links = [link for link in self.links if link.node.name in chosen and not link.closed and link.supports_tcp]
  452. return links or [link for link in self.links if not link.closed][:1]
  453. def _selected_udp_links(self) -> list[RelayLink]:
  454. online = [link for link in self.links if not link.closed and link.writer is not None and link.supports_udp]
  455. if not online:
  456. return []
  457. ordered = sorted(online, key=lambda link: self.scheduler.scores.get(link.node.name).score if link.node.name in self.scheduler.scores else 999999.0)
  458. return ordered
  459. def _udp_direct_redundancy_for_target(self, target_host: str) -> int:
  460. base = self.config.udp_direct_redundancy
  461. if ":" in target_host and self.config.udp_direct_redundancy_v6 is not None:
  462. base = self.config.udp_direct_redundancy_v6
  463. elif ":" not in target_host and self.config.udp_direct_redundancy_v4 is not None:
  464. base = self.config.udp_direct_redundancy_v4
  465. return max(1, base)
  466. async def _ensure_udp_direct_paths(self, flow: UdpFlowState, udp_server: UdpAssociateServer) -> None:
  467. target_count = self._udp_direct_redundancy_for_target(flow.target_host)
  468. for index in range(target_count):
  469. name = f"direct-{index + 1}" if target_count > 1 else "direct"
  470. if name in flow.direct_sockets or name in flow.direct_failures:
  471. continue
  472. try:
  473. family = socket.AF_INET6 if ":" in flow.target_host else socket.AF_INET
  474. sock = socket.socket(family, socket.SOCK_DGRAM)
  475. sock.setblocking(False)
  476. await asyncio.get_running_loop().sock_connect(sock, (flow.target_host, flow.target_port))
  477. flow.direct_sockets[name] = sock
  478. flow.direct_tasks[name] = asyncio.create_task(self._pump_udp_direct(flow, name, sock, udp_server))
  479. except Exception as exc:
  480. flow.direct_failures.add(name)
  481. print(f"[edge] udp direct open error flow={flow.flow_id} path={name} target={flow.target_host}:{flow.target_port} error={exc!r}")
  482. async def _pump_udp_direct(self, flow: UdpFlowState, path_name: str, sock: socket.socket, udp_server: UdpAssociateServer) -> None:
  483. loop = asyncio.get_running_loop()
  484. try:
  485. while True:
  486. data = await loop.sock_recv(sock, 65535)
  487. if not data:
  488. break
  489. pending = flow.direct_pending_clients.get(path_name)
  490. packet_id = 0
  491. client_addr = flow.client_addr
  492. if pending:
  493. packet_id, client_addr = pending.popleft()
  494. await udp_server.handle_from_direct(flow, path_name, data, packet_id, client_addr)
  495. except Exception:
  496. pass
  497. finally:
  498. flow.direct_tasks.pop(path_name, None)
  499. flow.direct_sockets.pop(path_name, None)
  500. with contextlib.suppress(Exception):
  501. sock.close()
  502. async def forward_udp(self, flow: UdpFlowState, payload: bytes, packet_id: int, client_addr: tuple[str, int], udp_server: UdpAssociateServer) -> None:
  503. await self._ensure_udp_direct_paths(flow, udp_server)
  504. meta = encode_json({"host": flow.target_host, "port": flow.target_port})
  505. links = self._selected_udp_links()
  506. direct_names = tuple(name for name in sorted(flow.direct_sockets))
  507. relay_names = tuple(link.node.name for link in links)
  508. candidate_names = direct_names + relay_names
  509. udp_server.set_flow_candidates(flow, candidate_names)
  510. if not candidate_names:
  511. udp_server.note_unsent(flow, packet_id)
  512. return
  513. active_direct_names = list(direct_names)
  514. active_links = links
  515. now = asyncio.get_running_loop().time()
  516. warmup_mode = flow.packets_sent <= UDP_WARMUP_BROADCAST_PACKETS
  517. shadow_probe = (
  518. flow.winner_name is not None
  519. and now - flow.last_probe_at >= UDP_SHADOW_PROBE_INTERVAL_SEC
  520. )
  521. if shadow_probe:
  522. flow.last_probe_at = now
  523. broadcast_mode = self.config.udp_always_broadcast or flow.winner_name is None or warmup_mode or shadow_probe
  524. if not broadcast_mode:
  525. winner_last_seen = flow.path_last_seen.get(flow.winner_name, 0.0) if flow.winner_name else 0.0
  526. winner_stale = bool(winner_last_seen and now - winner_last_seen >= (self.config.udp_failover_idle_ms / 1000))
  527. if not winner_stale:
  528. flow.winner_miss_streak += 1
  529. if winner_stale or flow.winner_miss_streak >= UDP_FAST_FAILOVER_MISSES:
  530. flow.winner_name = None
  531. flow.winner_miss_streak = 0
  532. broadcast_mode = True
  533. else:
  534. active_direct_names = [name for name in active_direct_names if name == flow.winner_name]
  535. active_links = [link for link in active_links if link.node.name == flow.winner_name]
  536. if not active_direct_names and not active_links:
  537. if direct_names:
  538. active_direct_names = [direct_names[0]]
  539. elif links:
  540. active_links = links[:1]
  541. copies = max(1, self.config.udp_redundancy + 1)
  542. sent_any = False
  543. for attempt in range(copies):
  544. for path_name in active_direct_names:
  545. sock = flow.direct_sockets.get(path_name)
  546. if sock is None:
  547. continue
  548. try:
  549. flow.direct_pending_clients.setdefault(path_name, deque()).append((packet_id, client_addr))
  550. await asyncio.get_running_loop().sock_sendall(sock, payload)
  551. sent_any = True
  552. except Exception as exc:
  553. pending = flow.direct_pending_clients.get(path_name)
  554. if pending:
  555. with contextlib.suppress(Exception):
  556. pending.pop()
  557. flow.direct_failures.add(path_name)
  558. flow.direct_sockets.pop(path_name, None)
  559. task = flow.direct_tasks.pop(path_name, None)
  560. if task is not None:
  561. task.cancel()
  562. with contextlib.suppress(Exception):
  563. sock.close()
  564. flow.relay_failures[path_name] = flow.relay_failures.get(path_name, 0) + 1
  565. if path_name not in flow.relay_error_seen:
  566. flow.relay_error_seen.add(path_name)
  567. print(
  568. f"[edge] udp relay error flow={flow.flow_id} relay={path_name} error={exc!r}"
  569. )
  570. for link in active_links:
  571. stream_id = flow.link_streams.get(link.node.name)
  572. if stream_id is None:
  573. stream_id = next(self.udp_stream_ids)
  574. flow.link_streams[link.node.name] = stream_id
  575. self.udp_flow_sessions[(flow.flow_id, stream_id)] = flow
  576. include_meta = link.node.name not in flow.initialized_links
  577. body = (meta + payload) if include_meta else payload
  578. meta_len = len(meta) if include_meta else 0
  579. try:
  580. await link.send(Frame(UDP_SEND, flow.flow_id, stream_id, 0, meta_len, body))
  581. flow.initialized_links.add(link.node.name)
  582. sent_any = True
  583. except Exception as exc:
  584. flow.link_streams.pop(link.node.name, None)
  585. self.udp_flow_sessions.pop((flow.flow_id, stream_id), None)
  586. flow.relay_failures[link.node.name] = flow.relay_failures.get(link.node.name, 0) + 1
  587. if link.node.name not in flow.relay_error_seen:
  588. flow.relay_error_seen.add(link.node.name)
  589. print(
  590. f"[edge] udp relay error flow={flow.flow_id} relay={link.node.name} error={exc!r}"
  591. )
  592. if attempt + 1 < copies and self.config.udp_copy_interval_ms > 0:
  593. await asyncio.sleep(self.config.udp_copy_interval_ms / 1000)
  594. if not sent_any:
  595. udp_server.note_unsent(flow, packet_id)
  596. async def _handshake(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer) -> tuple[str, int, bool]:
  597. version, methods_len = (await read_exact(reader, 2))
  598. if version != SOCKS_VERSION:
  599. raise ValueError("unsupported socks version")
  600. await read_exact(reader, methods_len)
  601. writer.write(b"\x05\x00")
  602. await writer.drain()
  603. version, command, _, atyp = await read_exact(reader, 4)
  604. if version != SOCKS_VERSION:
  605. raise ValueError("unsupported socks version")
  606. if atyp == 1:
  607. host = socket.inet_ntoa(await read_exact(reader, 4))
  608. elif atyp == 3:
  609. size = (await read_exact(reader, 1))[0]
  610. host = (await read_exact(reader, size)).decode()
  611. else:
  612. raise ValueError("unsupported atyp")
  613. port = struct.unpack("!H", await read_exact(reader, 2))[0]
  614. peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer)
  615. if command == 1:
  616. print(f"[edge] socks handshake peer={peer_text} command=connect target={host}:{port}")
  617. writer.write(b"\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00")
  618. await writer.drain()
  619. return host, port, False
  620. if command == 3 and self.udp_server and self.udp_server.transport:
  621. bind_host, bind_port = self.udp_server.transport.get_extra_info("sockname")[:2]
  622. self.udp_server.register_associate(peer)
  623. print(f"[edge] socks handshake peer={peer_text} command=udp_associate target={host}:{port} bind={bind_host}:{bind_port}")
  624. writer.write(b"\x05\x00\x00\x01" + socket.inet_aton(bind_host) + struct.pack("!H", bind_port))
  625. await writer.drain()
  626. return host, port, True
  627. raise ValueError("unsupported socks command")