edge_udp.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. import itertools
  5. import json
  6. import socket
  7. import struct
  8. from collections import deque
  9. from dataclasses import dataclass, field
  10. from typing import Dict
  11. from .config_udp import UdpConfig, UdpRelayNode
  12. from .logging_utils import log_print as print
  13. from .scheduler_udp import UdpScheduler
  14. from .protocol import AUTH, STATUS_ERR, STATUS_OK, UDP_RECV, UDP_SEND, Frame, read_frame, write_frame, encode_json
  15. SOCKS_VERSION = 5
  16. UDP_WARMUP_BROADCAST_PACKETS = 6
  17. UDP_SHADOW_PROBE_INTERVAL_SEC = 0.25
  18. UDP_FAST_FAILOVER_MISSES = 3
  19. UDP_FLOW_IDLE_CLEANUP_SEC = 30.0
  20. UDP_PACKET_CLIENT_MAP_LIMIT = 4096
  21. UDP_DIRECT_PENDING_LIMIT = 128
  22. UDP_SOCKET_BUFFER_BYTES = 1 << 20
  23. async def read_exact(reader: asyncio.StreamReader, size: int) -> bytes:
  24. return await reader.readexactly(size)
  25. @dataclass(eq=False)
  26. class RelayLink:
  27. node: UdpRelayNode
  28. reader: asyncio.StreamReader
  29. writer: asyncio.StreamWriter
  30. pump: asyncio.Task | None = None
  31. closed_event: asyncio.Event = field(default_factory=asyncio.Event)
  32. maintain_task: asyncio.Task | None = None
  33. udp_server: "UdpAssociateServer | None" = None
  34. closed: bool = False
  35. supports_udp: bool = True
  36. async def start(self) -> None:
  37. await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
  38. frame = await read_frame(self.reader)
  39. if frame.kind != AUTH or frame.packet_id != STATUS_OK:
  40. raise ConnectionError(f"relay auth failed: {self.node.name}")
  41. ack = {}
  42. if frame.payload:
  43. try:
  44. ack = json.loads(frame.payload.decode("utf-8"))
  45. except Exception:
  46. ack = {}
  47. self.supports_udp = True
  48. self.closed = False
  49. self.closed_event.clear()
  50. self.pump = asyncio.create_task(self._pump())
  51. if ack:
  52. print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port} udp_only={ack.get('udp_only', True)}")
  53. async def _pump(self) -> None:
  54. try:
  55. while True:
  56. try:
  57. frame = await read_frame(self.reader)
  58. except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, OSError):
  59. break
  60. if frame.kind == UDP_RECV and self.udp_server:
  61. await self.udp_server.handle_from_relay(frame, self)
  62. finally:
  63. await self.close()
  64. async def send(self, frame: Frame) -> None:
  65. if self.closed:
  66. raise ConnectionError(f"relay closed: {self.node.name}")
  67. try:
  68. await write_frame(self.writer, frame)
  69. except (BrokenPipeError, ConnectionResetError, RuntimeError, OSError, asyncio.CancelledError) as exc:
  70. await self.close()
  71. raise ConnectionError(f"relay closed: {self.node.name}") from exc
  72. async def close(self) -> None:
  73. if self.closed:
  74. return
  75. self.closed = True
  76. self.closed_event.set()
  77. if self.pump and self.pump is not asyncio.current_task():
  78. self.pump.cancel()
  79. with contextlib.suppress(Exception):
  80. await self.pump
  81. self.writer.close()
  82. with contextlib.suppress(Exception):
  83. await self.writer.wait_closed()
  84. @dataclass
  85. class UdpFlowState:
  86. flow_id: int
  87. client_addr: tuple[str, int]
  88. target_host: str
  89. target_port: int
  90. created_at: float
  91. last_activity: float
  92. packets_sent: int = 0
  93. packets_received: int = 0
  94. duplicate_responses: int = 0
  95. winner_name: str | None = None
  96. candidate_names: tuple[str, ...] = ()
  97. link_streams: dict[str, int] = field(default_factory=dict)
  98. initialized_links: set[str] = field(default_factory=set)
  99. direct_sockets: dict[str, socket.socket] = field(default_factory=dict)
  100. direct_tasks: dict[str, asyncio.Task] = field(default_factory=dict)
  101. direct_failures: set[str] = field(default_factory=set)
  102. relay_failures: dict[str, int] = field(default_factory=dict)
  103. relay_error_seen: set[str] = field(default_factory=set)
  104. path_last_seen: dict[str, float] = field(default_factory=dict)
  105. packet_client_addrs: dict[int, tuple[str, int]] = field(default_factory=dict)
  106. direct_pending_clients: dict[str, deque[tuple[int, tuple[str, int]]]] = field(default_factory=dict)
  107. last_probe_at: float = 0.0
  108. winner_miss_streak: int = 0
  109. target_family: int = 0
  110. last_cleanup_at: float = 0.0
  111. def touch(self, now: float) -> None:
  112. self.last_activity = now
  113. class UdpAssociateServer(asyncio.DatagramProtocol):
  114. def __init__(self, edge: "UdpEdge") -> None:
  115. self.edge = edge
  116. self.transport: asyncio.DatagramTransport | None = None
  117. self.client_addr = None
  118. self.associate_peer = None
  119. self.packet_counter = itertools.count(1)
  120. self.last_packet_id = 0
  121. self.client_flows: dict[tuple[tuple[str, int], str, int], UdpFlowState] = {}
  122. self.flow_counter = itertools.count(1)
  123. self.last_summary_at = 0.0
  124. self.win_counts: Dict[str, int] = {}
  125. self._last_client_port_log_at = 0.0
  126. self._last_flow_cleanup_at = 0.0
  127. def connection_made(self, transport) -> None:
  128. self.transport = transport
  129. def register_associate(self, peer) -> None:
  130. peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer)
  131. if self.associate_peer != peer_text:
  132. print(f"[edge] udp associate peer={peer_text}")
  133. self.associate_peer = peer_text
  134. def _client_flow_key(self, addr, host: str, port: int) -> tuple[tuple[str, int], str, int]:
  135. return ((addr[0], 0), host, port)
  136. def datagram_received(self, data: bytes, addr) -> None:
  137. if len(data) < 10:
  138. return
  139. if self.client_addr is None:
  140. self.client_addr = addr
  141. print(f"[edge] udp client bound addr={addr[0]}:{addr[1]}")
  142. elif addr != self.client_addr:
  143. if addr[0] == self.client_addr[0]:
  144. now = asyncio.get_running_loop().time()
  145. if now - self._last_client_port_log_at >= 30:
  146. self._last_client_port_log_at = now
  147. print(f"[edge] udp client port update host={addr[0]} old_port={self.client_addr[1]} new_port={addr[1]}")
  148. self.client_addr = addr
  149. else:
  150. print(f"[edge] udp client rebound old={self.client_addr[0]}:{self.client_addr[1]} new={addr[0]}:{addr[1]}")
  151. self._reset_client_state(addr)
  152. host, port, payload = self._parse_socks_udp(data)
  153. now = asyncio.get_running_loop().time()
  154. flow_key = self._client_flow_key(addr, host, port)
  155. flow = self.client_flows.get(flow_key)
  156. if flow is None:
  157. family = socket.AF_INET6 if ":" in host else socket.AF_INET
  158. flow = UdpFlowState(next(self.flow_counter), (addr[0], addr[1]), host, port, now, now, target_family=family)
  159. self.client_flows[flow_key] = flow
  160. flow.touch(now)
  161. flow.client_addr = (addr[0], addr[1])
  162. flow.packets_sent += 1
  163. packet_id = next(self.packet_counter)
  164. self.last_packet_id = packet_id
  165. flow.packet_client_addrs[packet_id] = (addr[0], addr[1])
  166. self._cleanup_packet_state(flow, now)
  167. asyncio.create_task(self.edge.forward_udp(flow, payload, packet_id, (addr[0], addr[1]), self))
  168. self._cleanup_inactive_flows(now)
  169. self._log_udp_summary()
  170. def _reset_client_state(self, addr) -> None:
  171. remapped: dict[tuple[tuple[str, int], str, int], UdpFlowState] = {}
  172. for flow in list(self.client_flows.values()):
  173. flow.client_addr = (addr[0], addr[1])
  174. remapped[self._client_flow_key(addr, flow.target_host, flow.target_port)] = flow
  175. self.client_flows = remapped
  176. self.client_addr = addr
  177. async def handle_from_relay(self, frame: Frame, link: RelayLink) -> None:
  178. if self.transport is None or self.client_addr is None:
  179. return
  180. flow = self.edge.udp_flow_sessions.get((frame.session_id, frame.stream_id))
  181. if flow is None:
  182. return
  183. if frame.packet_id == STATUS_ERR:
  184. flow.relay_failures[link.node.name] = flow.relay_failures.get(link.node.name, 0) + 1
  185. if link.node.name not in flow.relay_error_seen:
  186. flow.relay_error_seen.add(link.node.name)
  187. detail = frame.payload.decode("utf-8", errors="replace")
  188. print(f"[edge] udp relay error flow={flow.flow_id} relay={link.node.name} error={detail}")
  189. return
  190. await self._deliver_flow_packet(flow, frame.packet_id, frame.payload, link.node.name)
  191. async def handle_from_direct(self, flow: UdpFlowState, path_name: str, payload: bytes, packet_id: int = 0, client_addr: tuple[str, int] | None = None) -> None:
  192. if self.transport is None or self.client_addr is None:
  193. return
  194. await self._deliver_flow_packet(flow, packet_id, payload, path_name, client_addr)
  195. async def _deliver_flow_packet(self, flow: UdpFlowState, packet_id: int, payload: bytes, source_name: str, client_addr: tuple[str, int] | None = None) -> None:
  196. if self.transport is None or self.client_addr is None:
  197. return
  198. packet = self._build_socks_udp(flow.target_host, flow.target_port, payload)
  199. now = asyncio.get_running_loop().time()
  200. flow.touch(now)
  201. flow.path_last_seen[source_name] = now
  202. flow.packets_received += 1
  203. target_addr = client_addr or flow.packet_client_addrs.pop(packet_id, None) or flow.client_addr
  204. if flow.winner_name is None:
  205. flow.winner_name = source_name
  206. flow.winner_miss_streak = 0
  207. self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1
  208. self._log_udp_summary(force=True)
  209. elif flow.winner_name != source_name:
  210. flow.duplicate_responses += 1
  211. winner_last_seen = flow.path_last_seen.get(flow.winner_name, 0.0)
  212. if winner_last_seen and now - winner_last_seen >= (self.edge.config.udp_failover_idle_ms / 1000):
  213. flow.winner_name = source_name
  214. flow.winner_miss_streak = 0
  215. self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1
  216. self._log_udp_summary(force=True)
  217. else:
  218. flow.winner_miss_streak = 0
  219. if flow.winner_name == source_name and target_addr is not None:
  220. if flow.packets_received == 1:
  221. print(
  222. f"[edge] udp relay reply flow={flow.flow_id} relay={source_name} "
  223. f"target={flow.target_host}:{flow.target_port} bytes={len(payload)}"
  224. )
  225. self.transport.sendto(packet, target_addr)
  226. def _cleanup_packet_state(self, flow: UdpFlowState, now: float) -> None:
  227. if flow.last_cleanup_at and now - flow.last_cleanup_at < 1.0:
  228. return
  229. flow.last_cleanup_at = now
  230. expired_packet_ids = [
  231. packet_id
  232. for packet_id in flow.packet_client_addrs
  233. if packet_id <= (self.last_packet_id - UDP_PACKET_CLIENT_MAP_LIMIT)
  234. ]
  235. for packet_id in expired_packet_ids:
  236. flow.packet_client_addrs.pop(packet_id, None)
  237. for path_name, pending in list(flow.direct_pending_clients.items()):
  238. while len(pending) > UDP_DIRECT_PENDING_LIMIT:
  239. pending.popleft()
  240. if not pending:
  241. flow.direct_pending_clients.pop(path_name, None)
  242. def _cleanup_inactive_flows(self, now: float) -> None:
  243. if self._last_flow_cleanup_at and now - self._last_flow_cleanup_at < 5.0:
  244. return
  245. self._last_flow_cleanup_at = now
  246. expired_keys = [
  247. key
  248. for key, flow in self.client_flows.items()
  249. if now - flow.last_activity >= UDP_FLOW_IDLE_CLEANUP_SEC
  250. ]
  251. for key in expired_keys:
  252. flow = self.client_flows.pop(key, None)
  253. if flow is None:
  254. continue
  255. self.edge.release_udp_flow(flow)
  256. def set_flow_candidates(self, flow: UdpFlowState, candidate_names: tuple[str, ...]) -> None:
  257. if not flow.candidate_names:
  258. flow.candidate_names = candidate_names
  259. def note_unsent(self, flow: UdpFlowState, _packet_id: int) -> None:
  260. flow.touch(asyncio.get_running_loop().time())
  261. flow.relay_failures["unsent"] = flow.relay_failures.get("unsent", 0) + 1
  262. self._log_udp_summary(force=True)
  263. def _log_udp_summary(self, force: bool = False) -> None:
  264. now = asyncio.get_running_loop().time()
  265. if not force and now - self.last_summary_at < 10:
  266. return
  267. self.last_summary_at = now
  268. active_flows = len(self.client_flows)
  269. winners = sum(1 for flow in self.client_flows.values() if flow.winner_name)
  270. packets_sent = sum(flow.packets_sent for flow in self.client_flows.values())
  271. packets_received = sum(flow.packets_received for flow in self.client_flows.values())
  272. duplicates = sum(flow.duplicate_responses for flow in self.client_flows.values())
  273. direct_paths = sum(len(flow.direct_sockets) for flow in self.client_flows.values())
  274. relay_candidates = sum(len(flow.link_streams) for flow in self.client_flows.values())
  275. candidate_names: list[str] = []
  276. seen_candidates: set[str] = set()
  277. for flow in sorted(self.client_flows.values(), key=lambda item: item.flow_id):
  278. for name in flow.candidate_names:
  279. if name in seen_candidates:
  280. continue
  281. seen_candidates.add(name)
  282. candidate_names.append(name)
  283. direct_wins = sum(1 for flow in self.client_flows.values() if flow.winner_name and flow.winner_name.startswith("direct"))
  284. relay_wins = winners - direct_wins
  285. sample_flows = [f"{flow.flow_id}:{flow.winner_name or 'pending'}" for flow in sorted(self.client_flows.values(), key=lambda item: item.flow_id) if flow.winner_name][:5]
  286. relay_errors: list[str] = []
  287. for flow in self.client_flows.values():
  288. for name, count in flow.relay_failures.items():
  289. relay_errors.append(f"{name}={count}")
  290. bind = f"{self.client_addr[0]}:{self.client_addr[1]}" if self.client_addr else "unbound"
  291. print(
  292. f"[edge] udp summary bind={bind} flows={active_flows} winners={winners} "
  293. f"winner_breakdown=direct={direct_wins},relay={relay_wins} sample={', '.join(sample_flows) or 'none'} "
  294. f"candidates={candidate_names or ['none']} sent={packets_sent} recv={packets_received} dup={duplicates} "
  295. f"direct_paths={direct_paths} relay_paths={relay_candidates} relay_errors={', '.join(sorted(relay_errors)) or 'none'}"
  296. )
  297. def _parse_socks_udp(self, packet: bytes) -> tuple[str, int, bytes]:
  298. atyp = packet[3]
  299. offset = 4
  300. if atyp == 1:
  301. host = socket.inet_ntoa(packet[offset:offset + 4])
  302. offset += 4
  303. elif atyp == 3:
  304. size = packet[offset]
  305. offset += 1
  306. host = packet[offset:offset + size].decode()
  307. offset += size
  308. else:
  309. raise ValueError("unsupported udp atyp")
  310. port = struct.unpack("!H", packet[offset:offset + 2])[0]
  311. offset += 2
  312. return host, port, packet[offset:]
  313. def _build_socks_udp(self, host: str, port: int, payload: bytes) -> bytes:
  314. try:
  315. addr = socket.inet_aton(host)
  316. header = b"\x00\x00\x00\x01" + addr + struct.pack("!H", port)
  317. except OSError:
  318. raw = host.encode()
  319. header = b"\x00\x00\x00\x03" + bytes([len(raw)]) + raw + struct.pack("!H", port)
  320. return header + payload
  321. class UdpEdge:
  322. def __init__(self, listen_host: str, listen_port: int, config: UdpConfig) -> None:
  323. self.listen_host = listen_host
  324. self.listen_port = listen_port
  325. self.config = config
  326. self.scheduler = UdpScheduler(config)
  327. self.links: list[RelayLink] = []
  328. self.udp_stream_ids = itertools.count(1)
  329. self.udp_flow_sessions: dict[tuple[int, int], UdpFlowState] = {}
  330. self.udp_server: UdpAssociateServer | None = None
  331. def _udp_direct_copies(self) -> int:
  332. if self.config.udp_direct_copies is not None:
  333. return max(1, self.config.udp_direct_copies)
  334. return max(1, self.config.udp_redundancy + 1)
  335. def _udp_relay_copies(self) -> int:
  336. if self.config.udp_relay_copies is not None:
  337. return max(1, self.config.udp_relay_copies)
  338. return max(1, self.config.udp_redundancy + 1)
  339. def release_udp_flow(self, flow: UdpFlowState) -> None:
  340. for stream_id in list(flow.link_streams.values()):
  341. self.udp_flow_sessions.pop((flow.flow_id, stream_id), None)
  342. flow.link_streams.clear()
  343. flow.initialized_links.clear()
  344. flow.packet_client_addrs.clear()
  345. for task in list(flow.direct_tasks.values()):
  346. task.cancel()
  347. flow.direct_tasks.clear()
  348. for sock in list(flow.direct_sockets.values()):
  349. with contextlib.suppress(Exception):
  350. sock.close()
  351. flow.direct_sockets.clear()
  352. flow.direct_pending_clients.clear()
  353. async def start(self) -> None:
  354. await self.scheduler.start()
  355. await self._connect_relays()
  356. server = await asyncio.start_server(self._accept, self.listen_host, self.listen_port)
  357. sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
  358. relay_mode = "direct-only" if not self.config.relays else "direct+relay"
  359. print(f"[edge] socks5 listening on {sockets} relay_mode={relay_mode}")
  360. async with server:
  361. await server.serve_forever()
  362. async def _connect_relays(self) -> None:
  363. loop = asyncio.get_running_loop()
  364. transport, protocol = await loop.create_datagram_endpoint(lambda: UdpAssociateServer(self), local_addr=(self.listen_host, 0))
  365. self.udp_server = protocol
  366. self.udp_transport = transport
  367. for node in self.config.relays:
  368. link = RelayLink(node=node, reader=None, writer=None) # type: ignore[arg-type]
  369. link.udp_server = protocol
  370. self.links.append(link)
  371. link.maintain_task = asyncio.create_task(self._maintain_link(link))
  372. async def _maintain_link(self, link: RelayLink) -> None:
  373. backoff = 1.0
  374. while True:
  375. try:
  376. reader, writer = await asyncio.open_connection(link.node.host, link.node.port)
  377. sock = writer.get_extra_info("socket")
  378. if sock is not None:
  379. with contextlib.suppress(OSError):
  380. sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  381. link.reader = reader
  382. link.writer = writer
  383. await link.start()
  384. backoff = 1.0
  385. await link.closed_event.wait()
  386. except asyncio.CancelledError:
  387. raise
  388. except Exception:
  389. await asyncio.sleep(backoff)
  390. backoff = min(10.0, backoff * 2)
  391. async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  392. try:
  393. peer = writer.get_extra_info("peername")
  394. _host, _port, udp_mode = await self._handshake(reader, writer, peer)
  395. if udp_mode:
  396. return
  397. except Exception:
  398. writer.close()
  399. with contextlib.suppress(Exception):
  400. await writer.wait_closed()
  401. def _selected_udp_links(self) -> list[RelayLink]:
  402. online = [link for link in self.links if not link.closed and link.writer is not None and link.supports_udp]
  403. if not online:
  404. return []
  405. return sorted(online, key=lambda link: self.scheduler.scores.get(link.node.name).score if link.node.name in self.scheduler.scores else 999999.0)
  406. def _udp_direct_redundancy_for_target(self, target_host: str) -> int:
  407. base = self.config.udp_direct_redundancy
  408. if ":" in target_host and self.config.udp_direct_redundancy_v6 is not None:
  409. base = self.config.udp_direct_redundancy_v6
  410. elif ":" not in target_host and self.config.udp_direct_redundancy_v4 is not None:
  411. base = self.config.udp_direct_redundancy_v4
  412. return max(1, base)
  413. async def _ensure_udp_direct_paths(self, flow: UdpFlowState, udp_server: UdpAssociateServer) -> None:
  414. target_count = self._udp_direct_redundancy_for_target(flow.target_host)
  415. for index in range(target_count):
  416. name = f"direct-{index + 1}" if target_count > 1 else "direct"
  417. if name in flow.direct_sockets or name in flow.direct_failures:
  418. continue
  419. try:
  420. family = socket.AF_INET6 if ":" in flow.target_host else socket.AF_INET
  421. sock = socket.socket(family, socket.SOCK_DGRAM)
  422. sock.setblocking(False)
  423. with contextlib.suppress(OSError):
  424. sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, UDP_SOCKET_BUFFER_BYTES)
  425. with contextlib.suppress(OSError):
  426. sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, UDP_SOCKET_BUFFER_BYTES)
  427. await asyncio.get_running_loop().sock_connect(sock, (flow.target_host, flow.target_port))
  428. flow.direct_sockets[name] = sock
  429. flow.direct_tasks[name] = asyncio.create_task(self._pump_udp_direct(flow, name, sock, udp_server))
  430. except Exception as exc:
  431. flow.direct_failures.add(name)
  432. print(f"[edge] udp direct open error flow={flow.flow_id} path={name} target={flow.target_host}:{flow.target_port} error={exc!r}")
  433. async def _pump_udp_direct(self, flow: UdpFlowState, path_name: str, sock: socket.socket, udp_server: UdpAssociateServer) -> None:
  434. loop = asyncio.get_running_loop()
  435. try:
  436. while True:
  437. data = await loop.sock_recv(sock, 65535)
  438. if not data:
  439. break
  440. pending = flow.direct_pending_clients.get(path_name)
  441. packet_id = 0
  442. client_addr = flow.client_addr
  443. if pending:
  444. packet_id, client_addr = pending.popleft()
  445. await udp_server.handle_from_direct(flow, path_name, data, packet_id, client_addr)
  446. finally:
  447. flow.direct_tasks.pop(path_name, None)
  448. flow.direct_sockets.pop(path_name, None)
  449. with contextlib.suppress(Exception):
  450. sock.close()
  451. async def forward_udp(self, flow: UdpFlowState, payload: bytes, packet_id: int, client_addr: tuple[str, int], udp_server: UdpAssociateServer) -> None:
  452. await self._ensure_udp_direct_paths(flow, udp_server)
  453. meta = encode_json({"host": flow.target_host, "port": flow.target_port, "family": flow.target_family})
  454. links = self._selected_udp_links()
  455. direct_names = tuple(name for name in sorted(flow.direct_sockets))
  456. relay_names = tuple(link.node.name for link in links)
  457. candidate_names = direct_names + relay_names
  458. udp_server.set_flow_candidates(flow, candidate_names)
  459. if not candidate_names:
  460. udp_server.note_unsent(flow, packet_id)
  461. return
  462. active_direct_names = list(direct_names)
  463. active_links = links
  464. now = asyncio.get_running_loop().time()
  465. warmup_mode = flow.packets_sent <= UDP_WARMUP_BROADCAST_PACKETS
  466. shadow_probe = flow.winner_name is not None and now - flow.last_probe_at >= UDP_SHADOW_PROBE_INTERVAL_SEC
  467. if shadow_probe:
  468. flow.last_probe_at = now
  469. broadcast_mode = self.config.udp_always_broadcast or flow.winner_name is None or warmup_mode or shadow_probe
  470. if not broadcast_mode:
  471. winner_last_seen = flow.path_last_seen.get(flow.winner_name, 0.0) if flow.winner_name else 0.0
  472. winner_stale = bool(winner_last_seen and now - winner_last_seen >= (self.config.udp_failover_idle_ms / 1000))
  473. if not winner_stale:
  474. flow.winner_miss_streak += 1
  475. if winner_stale or flow.winner_miss_streak >= UDP_FAST_FAILOVER_MISSES:
  476. flow.winner_name = None
  477. flow.winner_miss_streak = 0
  478. broadcast_mode = True
  479. else:
  480. active_direct_names = [name for name in active_direct_names if name == flow.winner_name]
  481. active_links = [link for link in active_links if link.node.name == flow.winner_name]
  482. if not active_direct_names and not active_links:
  483. if direct_names:
  484. active_direct_names = [direct_names[0]]
  485. elif links:
  486. active_links = links[:1]
  487. direct_copies = self._udp_direct_copies()
  488. relay_copies = self._udp_relay_copies()
  489. sent_any = False
  490. for attempt in range(max(direct_copies, relay_copies)):
  491. for path_name in active_direct_names if attempt < direct_copies else ():
  492. sock = flow.direct_sockets.get(path_name)
  493. if sock is None:
  494. continue
  495. try:
  496. flow.direct_pending_clients.setdefault(path_name, deque()).append((packet_id, client_addr))
  497. await asyncio.get_running_loop().sock_sendall(sock, payload)
  498. sent_any = True
  499. except Exception as exc:
  500. pending = flow.direct_pending_clients.get(path_name)
  501. if pending:
  502. with contextlib.suppress(Exception):
  503. pending.pop()
  504. flow.direct_failures.add(path_name)
  505. flow.direct_sockets.pop(path_name, None)
  506. task = flow.direct_tasks.pop(path_name, None)
  507. if task is not None:
  508. task.cancel()
  509. with contextlib.suppress(Exception):
  510. sock.close()
  511. flow.relay_failures[path_name] = flow.relay_failures.get(path_name, 0) + 1
  512. if path_name not in flow.relay_error_seen:
  513. flow.relay_error_seen.add(path_name)
  514. print(f"[edge] udp relay error flow={flow.flow_id} relay={path_name} error={exc!r}")
  515. for link in active_links if attempt < relay_copies else ():
  516. stream_id = flow.link_streams.get(link.node.name)
  517. if stream_id is None:
  518. stream_id = next(self.udp_stream_ids)
  519. flow.link_streams[link.node.name] = stream_id
  520. self.udp_flow_sessions[(flow.flow_id, stream_id)] = flow
  521. include_meta = link.node.name not in flow.initialized_links
  522. body = (meta + payload) if include_meta else payload
  523. meta_len = len(meta) if include_meta else 0
  524. try:
  525. await link.send(Frame(UDP_SEND, flow.flow_id, stream_id, 0, meta_len, body))
  526. flow.initialized_links.add(link.node.name)
  527. sent_any = True
  528. except Exception as exc:
  529. flow.link_streams.pop(link.node.name, None)
  530. self.udp_flow_sessions.pop((flow.flow_id, stream_id), None)
  531. flow.relay_failures[link.node.name] = flow.relay_failures.get(link.node.name, 0) + 1
  532. if link.node.name not in flow.relay_error_seen:
  533. flow.relay_error_seen.add(link.node.name)
  534. print(f"[edge] udp relay error flow={flow.flow_id} relay={link.node.name} error={exc!r}")
  535. if attempt + 1 < max(direct_copies, relay_copies) and self.config.udp_copy_interval_ms > 0:
  536. await asyncio.sleep(self.config.udp_copy_interval_ms / 1000)
  537. if not sent_any:
  538. udp_server.note_unsent(flow, packet_id)
  539. async def _handshake(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer) -> tuple[str, int, bool]:
  540. version, methods_len = (await read_exact(reader, 2))
  541. if version != SOCKS_VERSION:
  542. raise ValueError("unsupported socks version")
  543. await read_exact(reader, methods_len)
  544. writer.write(b"\x05\x00")
  545. await writer.drain()
  546. version, command, _, atyp = await read_exact(reader, 4)
  547. if version != SOCKS_VERSION:
  548. raise ValueError("unsupported socks version")
  549. if atyp == 1:
  550. host = socket.inet_ntoa(await read_exact(reader, 4))
  551. elif atyp == 3:
  552. size = (await read_exact(reader, 1))[0]
  553. host = (await read_exact(reader, size)).decode()
  554. else:
  555. raise ValueError("unsupported atyp")
  556. port = struct.unpack("!H", await read_exact(reader, 2))[0]
  557. peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer)
  558. if command == 3 and self.udp_server and self.udp_server.transport:
  559. bind_host, bind_port = self.udp_server.transport.get_extra_info("sockname")[:2]
  560. self.udp_server.register_associate(peer)
  561. print(f"[edge] socks handshake peer={peer_text} command=udp_associate target={host}:{port} bind={bind_host}:{bind_port}")
  562. writer.write(b"\x05\x00\x00\x01" + socket.inet_aton(bind_host) + struct.pack("!H", bind_port))
  563. await writer.drain()
  564. return host, port, True
  565. raise ValueError("unsupported socks command")