from __future__ import annotations import asyncio import contextlib from dataclasses import dataclass, field from .logging_utils import log_print as print from .protocol import AUTH, PING, PONG, STATUS_ERR, STATUS_OK, UDP_RECV, UDP_SEND, Frame, decode_json, encode_json, read_frame, write_frame class UdpRelayProtocol(asyncio.DatagramProtocol): def __init__(self, channel: "UdpRelayChannel", session_id: int, stream_id: int) -> None: self.channel = channel self.session_id = session_id self.stream_id = stream_id def datagram_received(self, data: bytes, _addr) -> None: self.channel.log_udp_reply(self.session_id, self.stream_id, len(data)) self.channel.enqueue_send(Frame(UDP_RECV, self.session_id, self.stream_id, 0, 0, data)) @dataclass class UdpRelaySession: session_id: int stream_id: int transport: asyncio.DatagramTransport | None = None host: str = "" port: int = 0 family: int = 0 @dataclass class UdpRelayChannel: reader: asyncio.StreamReader writer: asyncio.StreamWriter token: str udp_sessions: dict[tuple[int, int], UdpRelaySession] = field(default_factory=dict) closed: bool = False send_lock: asyncio.Lock = field(default_factory=asyncio.Lock) send_queue: asyncio.Queue[Frame | None] = field(default_factory=asyncio.Queue) send_task: asyncio.Task | None = None _logged_sessions: set[tuple[int, int]] = field(default_factory=set) async def run(self) -> None: try: self.send_task = asyncio.create_task(self._send_loop()) auth = await read_frame(self.reader) if auth.kind != AUTH: return payload = decode_json(auth.payload) if auth.payload else {} if payload.get("token") != self.token: return await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, encode_json({"status": "ok", "kind": payload.get("purpose", "normal"), "udp_only": True}))) while True: frame = await read_frame(self.reader) await self.handle(frame) finally: await self.close() def enqueue_send(self, frame: Frame) -> None: if self.closed: return with contextlib.suppress(asyncio.QueueFull): self.send_queue.put_nowait(frame) async def _send_loop(self) -> None: try: while True: frame = await self.send_queue.get() if frame is None: break ok = await self.safe_send(frame) if not ok: break except asyncio.CancelledError: pass async def safe_send(self, frame: Frame) -> bool: if self.closed: return False try: async with self.send_lock: if self.closed: return False await write_frame(self.writer, frame) return True except Exception: return False async def handle(self, frame: Frame) -> None: key = (frame.session_id, frame.stream_id) if frame.kind == PING: await self.safe_send(Frame(PONG, 0, 0, frame.seq, 0, b"pong")) return if frame.kind != UDP_SEND: return session = self.udp_sessions.get(key) meta = None payload = frame.payload if frame.packet_id > 0 and frame.packet_id <= len(frame.payload): try: meta = decode_json(frame.payload[: frame.packet_id]) payload = frame.payload[frame.packet_id :] except Exception: if session is None: return if session is None: if meta is None: return try: family = int(meta.get("family", 0)) or 0 transport, _protocol = await asyncio.get_running_loop().create_datagram_endpoint( lambda: UdpRelayProtocol(self, frame.session_id, frame.stream_id), remote_addr=(meta["host"], int(meta["port"])), family=family or 0, ) session = UdpRelaySession(frame.session_id, frame.stream_id, transport, meta["host"], int(meta["port"]), family) self.udp_sessions[key] = session print(f"[relay] udp session opened session={frame.session_id} stream={frame.stream_id} target={meta['host']}:{int(meta['port'])}") except Exception as exc: await self.safe_send(Frame(UDP_RECV, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode())) return if session.transport is not None: with contextlib.suppress(Exception): session.transport.sendto(payload) def log_udp_reply(self, session_id: int, stream_id: int, size: int) -> None: key = (session_id, stream_id) if key in self._logged_sessions: return self._logged_sessions.add(key) print(f"[relay] udp reply session={session_id} stream={stream_id} bytes={size}") async def close(self) -> None: if self.closed: return self.closed = True self.send_queue.put_nowait(None) for session in self.udp_sessions.values(): if session.transport: session.transport.close() self.udp_sessions.clear() if self.send_task and self.send_task is not asyncio.current_task(): self.send_task.cancel() with contextlib.suppress(Exception): await self.send_task self.writer.close() with contextlib.suppress(Exception): await self.writer.wait_closed() class UdpRelayServer: def __init__(self, token: str) -> None: self.token = token async def start(self, host: str, port: int) -> None: async def accept(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: try: await UdpRelayChannel(reader, writer, self.token).run() except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, OSError): pass print(f"[relay] udp server listening on {host}:{port}") server = await asyncio.start_server(accept, host, port) async with server: await server.serve_forever()