relay_server_udp.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. from dataclasses import dataclass, field
  5. from .logging_utils import log_print as print
  6. from .protocol import AUTH, PING, PONG, STATUS_ERR, STATUS_OK, UDP_RECV, UDP_SEND, Frame, decode_json, encode_json, read_frame, write_frame
  7. class UdpRelayProtocol(asyncio.DatagramProtocol):
  8. def __init__(self, channel: "UdpRelayChannel", session_id: int, stream_id: int) -> None:
  9. self.channel = channel
  10. self.session_id = session_id
  11. self.stream_id = stream_id
  12. def datagram_received(self, data: bytes, _addr) -> None:
  13. self.channel.log_udp_reply(self.session_id, self.stream_id, len(data))
  14. self.channel.enqueue_send(Frame(UDP_RECV, self.session_id, self.stream_id, 0, 0, data))
  15. @dataclass
  16. class UdpRelaySession:
  17. session_id: int
  18. stream_id: int
  19. transport: asyncio.DatagramTransport | None = None
  20. host: str = ""
  21. port: int = 0
  22. family: int = 0
  23. @dataclass
  24. class UdpRelayChannel:
  25. reader: asyncio.StreamReader
  26. writer: asyncio.StreamWriter
  27. token: str
  28. udp_sessions: dict[tuple[int, int], UdpRelaySession] = field(default_factory=dict)
  29. closed: bool = False
  30. send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
  31. send_queue: asyncio.Queue[Frame | None] = field(default_factory=asyncio.Queue)
  32. send_task: asyncio.Task | None = None
  33. _logged_sessions: set[tuple[int, int]] = field(default_factory=set)
  34. async def run(self) -> None:
  35. try:
  36. self.send_task = asyncio.create_task(self._send_loop())
  37. auth = await read_frame(self.reader)
  38. if auth.kind != AUTH:
  39. return
  40. payload = decode_json(auth.payload) if auth.payload else {}
  41. if payload.get("token") != self.token:
  42. return
  43. await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, encode_json({"status": "ok", "kind": payload.get("purpose", "normal"), "udp_only": True})))
  44. while True:
  45. frame = await read_frame(self.reader)
  46. await self.handle(frame)
  47. finally:
  48. await self.close()
  49. def enqueue_send(self, frame: Frame) -> None:
  50. if self.closed:
  51. return
  52. with contextlib.suppress(asyncio.QueueFull):
  53. self.send_queue.put_nowait(frame)
  54. async def _send_loop(self) -> None:
  55. try:
  56. while True:
  57. frame = await self.send_queue.get()
  58. if frame is None:
  59. break
  60. ok = await self.safe_send(frame)
  61. if not ok:
  62. break
  63. except asyncio.CancelledError:
  64. pass
  65. async def safe_send(self, frame: Frame) -> bool:
  66. if self.closed:
  67. return False
  68. try:
  69. async with self.send_lock:
  70. if self.closed:
  71. return False
  72. await write_frame(self.writer, frame)
  73. return True
  74. except Exception:
  75. return False
  76. async def handle(self, frame: Frame) -> None:
  77. key = (frame.session_id, frame.stream_id)
  78. if frame.kind == PING:
  79. await self.safe_send(Frame(PONG, 0, 0, frame.seq, 0, b"pong"))
  80. return
  81. if frame.kind != UDP_SEND:
  82. return
  83. session = self.udp_sessions.get(key)
  84. meta = None
  85. payload = frame.payload
  86. if frame.packet_id > 0 and frame.packet_id <= len(frame.payload):
  87. try:
  88. meta = decode_json(frame.payload[: frame.packet_id])
  89. payload = frame.payload[frame.packet_id :]
  90. except Exception:
  91. if session is None:
  92. return
  93. if session is None:
  94. if meta is None:
  95. return
  96. try:
  97. family = int(meta.get("family", 0)) or 0
  98. transport, _protocol = await asyncio.get_running_loop().create_datagram_endpoint(
  99. lambda: UdpRelayProtocol(self, frame.session_id, frame.stream_id),
  100. remote_addr=(meta["host"], int(meta["port"])),
  101. family=family or 0,
  102. )
  103. session = UdpRelaySession(frame.session_id, frame.stream_id, transport, meta["host"], int(meta["port"]), family)
  104. self.udp_sessions[key] = session
  105. print(f"[relay] udp session opened session={frame.session_id} stream={frame.stream_id} target={meta['host']}:{int(meta['port'])}")
  106. except Exception as exc:
  107. await self.safe_send(Frame(UDP_RECV, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
  108. return
  109. if session.transport is not None:
  110. with contextlib.suppress(Exception):
  111. session.transport.sendto(payload)
  112. def log_udp_reply(self, session_id: int, stream_id: int, size: int) -> None:
  113. key = (session_id, stream_id)
  114. if key in self._logged_sessions:
  115. return
  116. self._logged_sessions.add(key)
  117. print(f"[relay] udp reply session={session_id} stream={stream_id} bytes={size}")
  118. async def close(self) -> None:
  119. if self.closed:
  120. return
  121. self.closed = True
  122. self.send_queue.put_nowait(None)
  123. for session in self.udp_sessions.values():
  124. if session.transport:
  125. session.transport.close()
  126. self.udp_sessions.clear()
  127. if self.send_task and self.send_task is not asyncio.current_task():
  128. self.send_task.cancel()
  129. with contextlib.suppress(Exception):
  130. await self.send_task
  131. self.writer.close()
  132. with contextlib.suppress(Exception):
  133. await self.writer.wait_closed()
  134. class UdpRelayServer:
  135. def __init__(self, token: str) -> None:
  136. self.token = token
  137. async def start(self, host: str, port: int) -> None:
  138. async def accept(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  139. try:
  140. await UdpRelayChannel(reader, writer, self.token).run()
  141. except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, OSError):
  142. pass
  143. print(f"[relay] udp server listening on {host}:{port}")
  144. server = await asyncio.start_server(accept, host, port)
  145. async with server:
  146. await server.serve_forever()