from __future__ import annotations import asyncio import contextlib import json import random import socket from dataclasses import dataclass import time from typing import Awaitable, Callable, Dict from .config_udp import UdpConfig, UdpRelayNode from .logging_utils import log_print as print from .protocol import AUTH, PING, PONG, STATUS_OK, TCP_CLOSE, Frame, encode_json, read_frame, write_frame from .scheduler_udp import UdpScheduler FrameHandler = Callable[["UdpRelayConnection", Frame], Awaitable[None]] @dataclass class UdpRelayConnection: node: UdpRelayNode manager: "UdpRelayManager" reader: asyncio.StreamReader writer: asyncio.StreamWriter closed: bool = False handlers: Dict[tuple[int, int], FrameHandler] = None dispatch_tasks: Dict[tuple[int, int], asyncio.Task] = None pump_task: asyncio.Task | None = None keepalive_task: asyncio.Task | None = None last_pong_at: float = 0.0 send_lock: asyncio.Lock | None = None closed_event: asyncio.Event | None = None dropped_frames: Dict[int, int] = None dropped_report_task: asyncio.Task | None = None def __post_init__(self) -> None: if self.handlers is None: self.handlers = {} if self.dispatch_tasks is None: self.dispatch_tasks = {} if self.send_lock is None: self.send_lock = asyncio.Lock() if self.closed_event is None: self.closed_event = asyncio.Event() if self.dropped_frames is None: self.dropped_frames = {} async def start(self) -> None: print(f"[edge] connecting relay name={self.node.name} addr={self.node.host}:{self.node.port}") await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token}))) frame = await read_frame(self.reader) if frame.kind != AUTH or frame.packet_id != STATUS_OK: raise ConnectionError(f"relay auth failed: {self.node.name}") if frame.payload: with contextlib.suppress(Exception): json.loads(frame.payload.decode("utf-8")) self.last_pong_at = time.monotonic() self.keepalive_task = asyncio.create_task(self._keepalive()) self.pump_task = asyncio.create_task(self._pump()) print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port} mode=udp") async def _keepalive(self) -> None: try: while not self.closed: await asyncio.sleep(self.manager.config.relay_ping_interval) if self.closed: break if self.last_pong_at and time.monotonic() - self.last_pong_at > (self.manager.config.relay_ping_interval + self.manager.config.relay_ping_timeout): print(f"[edge] relay health timeout name={self.node.name} addr={self.node.host}:{self.node.port} timeout={self.manager.config.relay_ping_timeout}") await self.close() break await self.send(Frame(PING, 0, 0, 0, 0, b"")) except asyncio.CancelledError: pass except Exception: await self.close() async def _pump(self) -> None: try: while True: frame = await read_frame(self.reader) if frame.kind == PONG: self.last_pong_at = time.monotonic() continue handler = self.handlers.get((frame.session_id, frame.stream_id)) if handler: self._dispatch_frame(frame, handler) else: self._record_dropped_frame(frame.kind) except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, OSError): pass except Exception: pass finally: await self.close() def _dispatch_frame(self, frame: Frame, handler: FrameHandler) -> None: key = (frame.session_id, frame.stream_id) previous = self.dispatch_tasks.get(key) task = asyncio.create_task(self._run_handler(key, frame, handler, previous)) self.dispatch_tasks[key] = task async def _run_handler(self, key: tuple[int, int], frame: Frame, handler: FrameHandler, previous: asyncio.Task | None) -> None: try: if previous is not None: with contextlib.suppress(Exception): await previous if self.closed: return await handler(self, frame) except asyncio.CancelledError: pass except Exception: if not self.closed: await self.close() finally: if self.dispatch_tasks.get(key) is asyncio.current_task(): self.dispatch_tasks.pop(key, None) def _record_dropped_frame(self, kind: int) -> None: self.dropped_frames[kind] = self.dropped_frames.get(kind, 0) + 1 if self.dropped_report_task is None or self.dropped_report_task.done(): self.dropped_report_task = asyncio.create_task(self._report_dropped_frames()) async def _report_dropped_frames(self) -> None: try: await asyncio.sleep(5) dropped = self.dropped_frames self.dropped_frames = {} if dropped: detail = ", ".join(f"kind={kind} count={count}" for kind, count in sorted(dropped.items())) print(f"[edge] relay frame dropped summary name={self.node.name} {detail}") except asyncio.CancelledError: pass async def send(self, frame: Frame) -> None: if self.closed: raise ConnectionError(f"relay closed: {self.node.name}") assert self.send_lock is not None async with self.send_lock: if self.closed: raise ConnectionError(f"relay closed: {self.node.name}") await write_frame(self.writer, frame) def bind(self, session_id: int, stream_id: int, handler: FrameHandler) -> None: self.handlers[(session_id, stream_id)] = handler def unbind(self, session_id: int, stream_id: int) -> None: self.handlers.pop((session_id, stream_id), None) task = self.dispatch_tasks.pop((session_id, stream_id), None) if task is not None: task.cancel() async def close(self) -> None: if self.closed: return self.closed = True assert self.closed_event is not None self.closed_event.set() handlers = list(self.handlers.items()) self.handlers.clear() dispatch_tasks = list(self.dispatch_tasks.values()) self.dispatch_tasks.clear() self.manager.on_closed(self) for (session_id, stream_id), handler in handlers: with contextlib.suppress(Exception): await handler(self, Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b"")) for task in dispatch_tasks: task.cancel() for task in dispatch_tasks: with contextlib.suppress(Exception): await task if self.dropped_report_task and self.dropped_report_task is not asyncio.current_task(): self.dropped_report_task.cancel() with contextlib.suppress(Exception): await self.dropped_report_task if self.pump_task and self.pump_task is not asyncio.current_task(): self.pump_task.cancel() with contextlib.suppress(Exception): await self.pump_task if self.keepalive_task and self.keepalive_task is not asyncio.current_task(): self.keepalive_task.cancel() with contextlib.suppress(Exception): await self.keepalive_task self.writer.close() with contextlib.suppress(Exception): await self.writer.wait_closed() class UdpRelayManager: def __init__(self, config: UdpConfig) -> None: self.config = config self.scheduler = UdpScheduler(config) self.connections: Dict[str, UdpRelayConnection] = {} self.tasks: list[asyncio.Task] = [] async def start(self) -> None: await self.scheduler.start() for node in self.config.relays: self.tasks.append(asyncio.create_task(self._maintain(node))) async def _maintain(self, node: UdpRelayNode) -> None: backoff = self.config.relay_reconnect_delay while True: current = self.connections.get(node.name) if current is not None and not current.closed: assert current.closed_event is not None await current.closed_event.wait() continue attempt = 1 while True: try: print(f"[edge] relay reconnect attempt name={node.name} addr={node.host}:{node.port} attempt={attempt} backoff={backoff:.1f}s") reader, writer = await asyncio.wait_for(asyncio.open_connection(node.host, node.port), timeout=self.config.relay_open_timeout) connection = UdpRelayConnection(node=node, manager=self, reader=reader, writer=writer) sock = writer.get_extra_info("socket") if sock is not None and self.config.relay_tcp_nodelay: with contextlib.suppress(OSError): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) await connection.start() self.connections[node.name] = connection backoff = self.config.relay_reconnect_delay assert connection.closed_event is not None await connection.closed_event.wait() print(f"[edge] relay supervisor noticed close name={node.name} addr={node.host}:{node.port}") break except asyncio.CancelledError: raise except Exception as exc: print(f"[edge] relay connect failed name={node.name} addr={node.host}:{node.port} attempt={attempt} error={exc!r}") jitter = random.uniform(0, min(1.0, backoff * 0.2)) await asyncio.sleep(backoff + jitter) backoff = min(self.config.relay_reconnect_max_delay, max(self.config.relay_reconnect_delay, backoff * 2)) attempt += 1 def on_closed(self, connection: UdpRelayConnection) -> None: current = self.connections.get(connection.node.name) if current is connection: self.connections.pop(connection.node.name, None) def available_udp(self) -> list[UdpRelayConnection]: chosen = {node.name for node in self.scheduler.choose()} preferred = [self.connections[name] for name in chosen if name in self.connections and not self.connections[name].closed] if preferred: return preferred return [conn for conn in self.connections.values() if not conn.closed] def snapshot(self) -> list[dict[str, object]]: data = self.scheduler.snapshot() online = {name for name, conn in self.connections.items() if not conn.closed} for item in data: item["online"] = item["name"] in online return data