relay_server_udp.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
from __future__ import annotations

import asyncio
import contextlib
import hmac
import time
from dataclasses import dataclass, field

from .logging_utils import log_print as print
from .protocol import (
    AUTH,
    PING,
    PONG,
    STATUS_ERR,
    STATUS_OK,
    UDP_RECV,
    UDP_SEND,
    Frame,
    decode_json,
    encode_json,
    read_frame,
    write_frame,
)
  8. class UdpRelayProtocol(asyncio.DatagramProtocol):
  9. def __init__(self, channel: "UdpRelayChannel", session_id: int, stream_id: int) -> None:
  10. self.channel = channel
  11. self.session_id = session_id
  12. self.stream_id = stream_id
  13. def datagram_received(self, data: bytes, _addr) -> None:
  14. self.channel.log_udp_reply(self.session_id, self.stream_id, len(data))
  15. self.channel.enqueue_send(Frame(UDP_RECV, self.session_id, self.stream_id, 0, 0, data))
  16. @dataclass
  17. class UdpRelaySession:
  18. session_id: int
  19. stream_id: int
  20. transport: asyncio.DatagramTransport | None = None
  21. host: str = ""
  22. port: int = 0
  23. family: int = 0
  24. last_activity: float = 0.0
  25. @dataclass
  26. class UdpRelayChannel:
  27. reader: asyncio.StreamReader
  28. writer: asyncio.StreamWriter
  29. token: str
  30. udp_sessions: dict[tuple[int, int], UdpRelaySession] = field(default_factory=dict)
  31. closed: bool = False
  32. send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
  33. send_queue: asyncio.Queue[tuple[int, int] | None] = field(default_factory=lambda: asyncio.Queue(maxsize=1024))
  34. pending_frames: dict[tuple[int, int], Frame] = field(default_factory=dict)
  35. queued_keys: set[tuple[int, int]] = field(default_factory=set)
  36. send_task: asyncio.Task | None = None
  37. cleanup_task: asyncio.Task | None = None
  38. _logged_sessions: set[tuple[int, int]] = field(default_factory=set)
  39. async def run(self) -> None:
  40. try:
  41. self.send_task = asyncio.create_task(self._send_loop())
  42. self.cleanup_task = asyncio.create_task(self._cleanup_loop())
  43. auth = await read_frame(self.reader)
  44. if auth.kind != AUTH:
  45. return
  46. payload = decode_json(auth.payload) if auth.payload else {}
  47. if payload.get("token") != self.token:
  48. return
  49. await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, encode_json({"status": "ok", "kind": payload.get("purpose", "normal"), "udp_only": True})))
  50. while True:
  51. frame = await read_frame(self.reader)
  52. await self.handle(frame)
  53. finally:
  54. await self.close()
  55. def enqueue_send(self, frame: Frame) -> None:
  56. if self.closed:
  57. return
  58. key = (frame.session_id, frame.stream_id)
  59. self.pending_frames[key] = frame
  60. if key in self.queued_keys:
  61. return
  62. if self.send_queue.full():
  63. with contextlib.suppress(asyncio.QueueEmpty):
  64. dropped_key = self.send_queue.get_nowait()
  65. if dropped_key is not None:
  66. self.queued_keys.discard(dropped_key)
  67. self.pending_frames.pop(dropped_key, None)
  68. with contextlib.suppress(asyncio.QueueFull):
  69. self.send_queue.put_nowait(key)
  70. self.queued_keys.add(key)
  71. async def _send_loop(self) -> None:
  72. try:
  73. while True:
  74. key = await self.send_queue.get()
  75. if key is None:
  76. break
  77. self.queued_keys.discard(key)
  78. frame = self.pending_frames.pop(key, None)
  79. if frame is None:
  80. continue
  81. ok = await self.safe_send(frame)
  82. if not ok:
  83. break
  84. except asyncio.CancelledError:
  85. pass
  86. async def safe_send(self, frame: Frame) -> bool:
  87. if self.closed:
  88. return False
  89. try:
  90. async with self.send_lock:
  91. if self.closed:
  92. return False
  93. await write_frame(self.writer, frame)
  94. return True
  95. except Exception:
  96. return False
  97. async def handle(self, frame: Frame) -> None:
  98. key = (frame.session_id, frame.stream_id)
  99. if frame.kind == PING:
  100. await self.safe_send(Frame(PONG, 0, 0, frame.seq, 0, b"pong"))
  101. return
  102. if frame.kind != UDP_SEND:
  103. return
  104. session = self.udp_sessions.get(key)
  105. meta = None
  106. payload = frame.payload
  107. if frame.packet_id > 0 and frame.packet_id <= len(frame.payload):
  108. try:
  109. meta = decode_json(frame.payload[: frame.packet_id])
  110. payload = frame.payload[frame.packet_id :]
  111. except Exception:
  112. if session is None:
  113. return
  114. if session is None:
  115. if meta is None:
  116. return
  117. try:
  118. family = int(meta.get("family", 0)) or 0
  119. transport, _protocol = await asyncio.get_running_loop().create_datagram_endpoint(
  120. lambda: UdpRelayProtocol(self, frame.session_id, frame.stream_id),
  121. remote_addr=(meta["host"], int(meta["port"])),
  122. family=family or 0,
  123. )
  124. session = UdpRelaySession(frame.session_id, frame.stream_id, transport, meta["host"], int(meta["port"]), family)
  125. self.udp_sessions[key] = session
  126. print(f"[relay] udp session opened session={frame.session_id} stream={frame.stream_id} target={meta['host']}:{int(meta['port'])}")
  127. except Exception as exc:
  128. await self.safe_send(Frame(UDP_RECV, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
  129. return
  130. if session.transport is not None:
  131. session.last_activity = time.monotonic()
  132. with contextlib.suppress(Exception):
  133. session.transport.sendto(payload)
  134. def log_udp_reply(self, session_id: int, stream_id: int, size: int) -> None:
  135. key = (session_id, stream_id)
  136. if key in self._logged_sessions:
  137. return
  138. self._logged_sessions.add(key)
  139. print(f"[relay] udp reply session={session_id} stream={stream_id} bytes={size}")
  140. async def _cleanup_loop(self) -> None:
  141. try:
  142. while True:
  143. await asyncio.sleep(30)
  144. if self.closed:
  145. return
  146. now = time.monotonic()
  147. expired = [
  148. key
  149. for key, session in self.udp_sessions.items()
  150. if session.last_activity and now - session.last_activity >= 120
  151. ]
  152. for key in expired:
  153. session = self.udp_sessions.pop(key, None)
  154. if session and session.transport:
  155. session.transport.close()
  156. except asyncio.CancelledError:
  157. pass
  158. async def close(self) -> None:
  159. if self.closed:
  160. return
  161. self.closed = True
  162. with contextlib.suppress(asyncio.QueueFull):
  163. self.send_queue.put_nowait(None)
  164. if self.cleanup_task and self.cleanup_task is not asyncio.current_task():
  165. self.cleanup_task.cancel()
  166. with contextlib.suppress(Exception):
  167. await self.cleanup_task
  168. self.pending_frames.clear()
  169. self.queued_keys.clear()
  170. for session in self.udp_sessions.values():
  171. if session.transport:
  172. session.transport.close()
  173. self.udp_sessions.clear()
  174. if self.send_task and self.send_task is not asyncio.current_task():
  175. self.send_task.cancel()
  176. with contextlib.suppress(Exception):
  177. await self.send_task
  178. self.writer.close()
  179. with contextlib.suppress(Exception):
  180. await self.writer.wait_closed()
  181. class UdpRelayServer:
  182. def __init__(self, token: str) -> None:
  183. self.token = token
  184. async def start(self, host: str, port: int) -> None:
  185. async def accept(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  186. try:
  187. await UdpRelayChannel(reader, writer, self.token).run()
  188. except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, OSError):
  189. pass
  190. print(f"[relay] udp server listening on {host}:{port}")
  191. server = await asyncio.start_server(accept, host, port)
  192. async with server:
  193. await server.serve_forever()