"""Relay server that tunnels TCP streams and UDP datagrams over one framed connection.

Each accepted client connection is served by a :class:`RelayChannel`. The
client first sends an AUTH frame whose JSON payload carries ``token`` (and an
optional ``purpose``), then issues TCP_OPEN / TCP_DATA / TCP_CLOSE / UDP_SEND
frames keyed by ``(session_id, stream_id)``. The relay opens the real sockets
and sends traffic back as TCP_STATUS, TCP_DATA, TCP_CLOSE and UDP_RECV frames.
"""

from __future__ import annotations

import asyncio
import contextlib
import hmac
import json
import time
from dataclasses import dataclass, field
from typing import Dict

from .protocol import (
    AUTH,
    PING,
    PONG,
    STATUS_ERR,
    STATUS_OK,
    TCP_CLOSE,
    TCP_DATA,
    TCP_OPEN,
    TCP_STATUS,
    UDP_RECV,
    UDP_SEND,
    Frame,
    decode_json,
    read_frame,
    write_frame,
)


def _encode_json(obj: object) -> bytes:
    """Serialize *obj* as compact UTF-8 JSON bytes for a frame payload.

    NOTE(review): assumed to be readable by the peer's ``decode_json`` (any
    valid JSON should be) — confirm against ``.protocol``.
    """
    return json.dumps(obj, separators=(",", ":")).encode("utf-8")


@dataclass
class TcpSession:
    """An outbound TCP connection bound to one client stream."""

    session_id: int
    stream_id: int
    writer: asyncio.StreamWriter  # write side of the remote connection
    task: asyncio.Task  # _tcp_pump task copying remote bytes to the client


@dataclass
class UdpSession:
    """A connected UDP endpoint bound to one client stream."""

    session_id: int
    stream_id: int
    transport: asyncio.DatagramTransport | None = None
    protocol: "RelayUdpProtocol | None" = None
    host: str = ""
    port: int = 0
    family: int = 0


class RelayUdpProtocol(asyncio.DatagramProtocol):
    """Datagram protocol that forwards every received datagram to the client as UDP_RECV."""

    def __init__(self, channel: "RelayChannel", session_id: int, stream_id: int) -> None:
        self.channel = channel
        self.session_id = session_id
        self.stream_id = stream_id

    def datagram_received(self, data: bytes, _addr) -> None:
        if self.channel.closed:
            return
        # This callback is synchronous, so the send is scheduled as a task.
        self.channel.spawn_send(Frame(UDP_RECV, self.session_id, self.stream_id, 0, 0, data))


@dataclass
class RelayChannel:
    """One authenticated client connection multiplexing many TCP/UDP sessions."""

    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    token: str
    tcp_sessions: Dict[tuple[int, int], TcpSession] = field(default_factory=dict)
    udp_sessions: Dict[tuple[int, int], UdpSession] = field(default_factory=dict)
    closed: bool = False
    # Serializes whole frames onto the shared client writer.
    send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    authed_at: float = 0.0  # time.monotonic() at successful auth; 0.0 before
    frame_count: int = 0  # frames read after the handshake
    authed_kind: str = "normal"  # client-supplied "purpose"; "probe" sessions are not logged
    # Strong references to fire-and-forget send tasks (the loop holds only weak ones).
    _background: set[asyncio.Task] = field(
        default_factory=set, init=False, repr=False, compare=False
    )

    async def run(self) -> None:
        """Serve this connection until the client disconnects, an error occurs, or cancellation.

        Unauthenticated peers and "probe" sessions are closed silently. Other
        sessions are logged on error, and on disconnect only when they lived at
        least 15 s or handled more than 20 frames. All resources are released
        via :meth:`close` on exit.
        """
        peer = self.writer.get_extra_info("peername")
        authed = False
        try:
            await self._authenticate()
            authed = True
            ack_payload = {"status": "ok", "kind": self.authed_kind}
            await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, _encode_json(ack_payload)))
            while True:
                frame = await read_frame(self.reader)
                self.frame_count += 1
                await self.handle(frame)
        except asyncio.IncompleteReadError:
            if authed and self.authed_kind != "probe":
                lived = self._lived()
                if lived >= 15 or self.frame_count > 20:
                    print(f"[relay] session closed peer={peer} kind={self.authed_kind} lived={lived:.1f}s frames={self.frame_count}")
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            if authed and self.authed_kind != "probe":
                lived = self._lived()
                print(f"[relay] session error peer={peer} kind={self.authed_kind} lived={lived:.1f}s frames={self.frame_count} error={exc!r}")
        finally:
            await self.close()

    def _lived(self) -> float:
        """Seconds since successful authentication (0.0 if never authenticated)."""
        return time.monotonic() - self.authed_at if self.authed_at else 0.0

    async def _authenticate(self) -> None:
        """Read the AUTH frame, verify the token, and record auth time and purpose.

        Raises:
            PermissionError: wrong first frame kind, undecodable or non-object
                payload, or token mismatch.

        NOTE(review): the read has no timeout, so a peer that connects and never
        sends AUTH holds the connection open indefinitely; consider
        ``asyncio.wait_for``.
        """
        auth = await read_frame(self.reader)
        if auth.kind != AUTH:
            raise PermissionError("invalid handshake kind")
        try:
            payload = decode_json(auth.payload) if auth.payload else {}
        except Exception as exc:
            raise PermissionError(f"invalid auth payload: {exc!r}") from exc
        if not isinstance(payload, dict):
            raise PermissionError("invalid auth payload: expected a JSON object")
        if not self._token_matches(payload.get("token")):
            raise PermissionError("invalid token")
        self.authed_at = time.monotonic()
        self.authed_kind = payload.get("purpose", "normal")

    def _token_matches(self, candidate: object) -> bool:
        """Compare *candidate* with the configured token in constant time."""
        if not isinstance(candidate, str):
            return False
        # surrogatepass keeps the encoding injective for any str JSON can produce.
        return hmac.compare_digest(
            candidate.encode("utf-8", "surrogatepass"),
            self.token.encode("utf-8", "surrogatepass"),
        )

    async def safe_send(self, frame: Frame) -> bool:
        """Write *frame* to the client; return True on success, False otherwise.

        Never raises for a closed channel or a failed write. It also swallows
        CancelledError, so teardown paths (e.g. the pump's ``finally``) keep
        running. NOTE(review): this can hide a cancellation of the caller.
        """
        if self.closed:
            return False
        try:
            async with self.send_lock:
                # Re-check: the channel may have closed while waiting for the lock.
                if self.closed:
                    return False
                await write_frame(self.writer, frame)
                return True
        except (BrokenPipeError, ConnectionResetError, RuntimeError, OSError, asyncio.CancelledError):
            return False

    def spawn_send(self, frame: Frame) -> None:
        """Schedule ``safe_send(frame)`` in the background, keeping the task referenced until done."""
        task = asyncio.create_task(self.safe_send(frame))
        self._background.add(task)
        task.add_done_callback(self._background.discard)

    async def handle(self, frame: Frame) -> None:
        """Dispatch one client frame by kind; unknown kinds are ignored.

        NOTE: ``run`` awaits this for one frame at a time, so a slow
        ``open_connection`` or a remote that applies backpressure in ``drain()``
        stalls every other stream on this channel, including PING.
        """
        if frame.kind == PING:
            await self.safe_send(Frame(PONG, 0, 0, frame.seq, 0, b"pong"))
        elif frame.kind == TCP_OPEN:
            await self._handle_tcp_open(frame)
        elif frame.kind == TCP_DATA:
            await self._handle_tcp_data(frame)
        elif frame.kind == TCP_CLOSE:
            await self._close_tcp((frame.session_id, frame.stream_id))
        elif frame.kind == UDP_SEND:
            await self._handle_udp_send(frame)

    async def _handle_tcp_open(self, frame: Frame) -> None:
        """Connect to the JSON-described host/port and reply with TCP_STATUS ok/err."""
        key = (frame.session_id, frame.stream_id)
        if key in self.tcp_sessions:
            # Overwriting would orphan the live connection and its pump, and two
            # pumps would then emit TCP_DATA under the same key.
            await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, b"stream already open"))
            return
        try:
            # Malformed metadata is reported to the client, not fatal to the channel.
            meta = decode_json(frame.payload)
            family = int(meta.get("family", 0) or 0)
            reader, writer = await asyncio.open_connection(meta["host"], int(meta["port"]), family=family)
        except Exception as exc:
            await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
            return
        task = asyncio.create_task(self._tcp_pump(frame.session_id, frame.stream_id, reader))
        self.tcp_sessions[key] = TcpSession(frame.session_id, frame.stream_id, writer, task)
        await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_OK, b"ok"))

    async def _handle_tcp_data(self, frame: Frame) -> None:
        """Write client bytes to the remote connection; close the session if that fails."""
        key = (frame.session_id, frame.stream_id)
        session = self.tcp_sessions.get(key)
        if session is None:
            return
        try:
            session.writer.write(frame.payload)
            await session.writer.drain()
        except Exception:
            await self._close_tcp(key)

    async def _handle_udp_send(self, frame: Frame) -> None:
        """Forward one datagram, opening the UDP endpoint on the stream's first send.

        A positive ``frame.packet_id`` is the byte length of a JSON metadata
        prefix (``host``, ``port``, optional ``family``) in the payload; the rest
        is the datagram. With no existing session and no metadata the datagram is
        dropped. Bad metadata or a failed endpoint creation drops only this
        datagram instead of tearing down the whole channel.
        """
        key = (frame.session_id, frame.stream_id)
        session = self.udp_sessions.get(key)
        payload = frame.payload
        try:
            meta = None
            if frame.packet_id > 0:
                meta = decode_json(frame.payload[: frame.packet_id])
                payload = frame.payload[frame.packet_id :]
            if session is None:
                if meta is None:
                    return
                session = await self._open_udp(frame.session_id, frame.stream_id, meta)
        except Exception:
            # UDP is best-effort and the protocol has no UDP status frame to report with.
            return
        with contextlib.suppress(Exception):
            session.transport.sendto(payload)

    async def _open_udp(self, session_id: int, stream_id: int, meta: dict) -> UdpSession:
        """Create a connected datagram endpoint described by *meta* and register it."""
        host = meta["host"]
        port = int(meta["port"])
        family = int(meta.get("family", 0) or 0)
        transport, protocol = await asyncio.get_running_loop().create_datagram_endpoint(
            lambda: RelayUdpProtocol(self, session_id, stream_id),
            remote_addr=(host, port),
            family=family,
        )
        session = UdpSession(session_id, stream_id, transport, protocol, host, port, family)
        self.udp_sessions[(session_id, stream_id)] = session
        return session

    async def _tcp_pump(self, session_id: int, stream_id: int, reader: asyncio.StreamReader) -> None:
        """Copy remote bytes to the client as TCP_DATA frames of up to 64 KiB.

        Stops on remote EOF, a read error, cancellation, or a failed send. It
        then tells the client with TCP_CLOSE (unless the channel is closed) and
        tears the session down.
        """
        try:
            while True:
                chunk = await reader.read(65536)
                if not chunk:
                    break
                if not await self.safe_send(Frame(TCP_DATA, session_id, stream_id, 0, 0, chunk)):
                    break
        except asyncio.CancelledError:
            # Cancelled by _close_tcp; fall through to the cleanup below.
            pass
        except Exception:
            # A remote read failure just ends this stream; the client learns via TCP_CLOSE.
            pass
        finally:
            if not self.closed:
                await self.safe_send(Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
            await self._close_tcp((session_id, stream_id), from_task=True)

    async def _close_tcp(self, key: tuple[int, int], from_task: bool = False) -> None:
        """Forget the TCP session for *key*, stop its pump, and close the remote writer.

        ``from_task`` is True when the session's own pump calls this; the pump
        must not cancel or wait for itself.
        """
        session = self.tcp_sessions.pop(key, None)
        if session is None:
            return
        try:
            if not from_task and session.task is not asyncio.current_task():
                session.task.cancel()
                # asyncio.wait does not re-raise the task's outcome. A pump
                # cancelled before its first step ends *cancelled*, and a plain
                # ``await task`` would raise CancelledError here, tearing down
                # the whole channel. A cancellation of this coroutine still
                # propagates.
                await asyncio.wait({session.task})
                if not session.task.cancelled():
                    session.task.exception()  # mark any failure as retrieved
        finally:
            session.writer.close()
        with contextlib.suppress(Exception):
            await session.writer.wait_closed()

    async def close(self) -> None:
        """Close every TCP/UDP session and the client connection; later calls are no-ops."""
        if self.closed:
            return
        self.closed = True
        for key in list(self.tcp_sessions):
            await self._close_tcp(key)
        for session in self.udp_sessions.values():
            if session.transport is not None:
                session.transport.close()
        self.udp_sessions.clear()
        self.writer.close()
        with contextlib.suppress(Exception):
            await self.writer.wait_closed()


class RelayServer:
    """TCP listener that runs one :class:`RelayChannel` per accepted connection."""

    def __init__(self, token: str) -> None:
        self.token = token

    async def start(self, host: str, port: int) -> None:
        """Bind *host*:*port*, print the listening sockets, and serve until cancelled."""
        server = await asyncio.start_server(self._accept, host, port)
        sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
        print(f"[relay] listening on {sockets}")
        async with server:
            await server.serve_forever()

    async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        await RelayChannel(reader, writer, self.token).run()