"""Relay connections and the manager that keeps them connected.

``RelayConnection`` wraps one authenticated, framed stream to a relay node.
``RelayManager`` runs one supervisor task per configured relay that
(re)connects with exponential backoff and exposes the live connections.
"""

from __future__ import annotations

import asyncio
import contextlib
import random
import socket
import time
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict

from .config import Config, RelayNode
from .protocol import (
    AUTH,
    PING,
    PONG,
    STATUS_OK,
    TCP_CLOSE,
    Frame,
    encode_json,
    read_frame,
    write_frame,
)
from .scheduler import Scheduler

# Coroutine callback invoked with the owning connection and one received frame.
FrameHandler = Callable[["RelayConnection", Frame], Awaitable[None]]


@dataclass
class RelayConnection:
    """One authenticated, framed connection to a relay node.

    After :meth:`start` succeeds, a pump task reads frames and routes them to
    the handler bound for the frame's ``(session_id, stream_id)``, and a
    keepalive task sends PING frames and closes the connection when no PONG
    has been seen within the configured interval + timeout.
    """

    node: RelayNode
    manager: "RelayManager"
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    closed: bool = False
    # The ``None`` defaults on the container/lock/event fields are placeholders
    # that ``__post_init__`` replaces with fresh per-instance objects.
    handlers: Dict[tuple[int, int], FrameHandler] = None
    # Most recent dispatch task per (session_id, stream_id); each task waits
    # for its predecessor so frames of one stream are handled in order.
    dispatch_tasks: Dict[tuple[int, int], asyncio.Task] = None
    pump_task: asyncio.Task | None = None
    keepalive_task: asyncio.Task | None = None
    # ``time.monotonic()`` of the last PONG (or of successful auth).
    last_pong_at: float = 0.0
    send_lock: asyncio.Lock | None = None
    closed_event: asyncio.Event | None = None
    # Count of frames received with no bound handler, keyed by frame kind.
    dropped_frames: Dict[int, int] = None
    dropped_report_task: asyncio.Task | None = None

    def __post_init__(self) -> None:
        """Replace ``None`` placeholders with per-instance objects."""
        if self.handlers is None:
            self.handlers = {}
        if self.dispatch_tasks is None:
            self.dispatch_tasks = {}
        if self.send_lock is None:
            self.send_lock = asyncio.Lock()
        if self.closed_event is None:
            self.closed_event = asyncio.Event()
        if self.dropped_frames is None:
            self.dropped_frames = {}

    async def start(self) -> None:
        """Authenticate with the relay and start the background tasks.

        Raises:
            ConnectionError: if the relay's reply is not an AUTH frame whose
                ``packet_id`` equals ``STATUS_OK``.
        """
        print(f"[edge] connecting relay name={self.node.name} addr={self.node.host}:{self.node.port}")
        # NOTE(review): the handshake has no timeout of its own; a relay that
        # accepts the TCP connection but never answers will block here.
        await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
        frame = await read_frame(self.reader)
        if frame.kind != AUTH or frame.packet_id != STATUS_OK:
            raise ConnectionError(f"relay auth failed: {self.node.name}")
        self.last_pong_at = time.monotonic()
        self.keepalive_task = asyncio.create_task(self._keepalive())
        self.pump_task = asyncio.create_task(self._pump())
        print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port}")

    async def _keepalive(self) -> None:
        """Send periodic PINGs; close the connection when PONGs stop arriving."""
        try:
            while not self.closed:
                config = self.manager.config
                await asyncio.sleep(config.relay_ping_interval)
                if self.closed:
                    break
                config = self.manager.config
                allowed_silence = config.relay_ping_interval + config.relay_ping_timeout
                if self.last_pong_at and time.monotonic() - self.last_pong_at > allowed_silence:
                    print(f"[edge] relay health timeout name={self.node.name} addr={self.node.host}:{self.node.port} timeout={config.relay_ping_timeout}")
                    await self.close()
                    break
                await self.send(Frame(PING, 0, 0, 0, 0, b""))
        except asyncio.CancelledError:
            pass
        except Exception:
            # Any send failure means the link is unusable.
            await self.close()

    async def _pump(self) -> None:
        """Read frames until the stream ends, routing each to its handler."""
        try:
            while True:
                frame = await read_frame(self.reader)
                if frame.kind == PONG:
                    self.last_pong_at = time.monotonic()
                    continue
                handler = self.handlers.get((frame.session_id, frame.stream_id))
                if handler:
                    self._dispatch_frame(frame, handler)
                else:
                    self._record_dropped_frame(frame.kind)
        except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, OSError):
            # Ordinary disconnect; the ``finally`` below tears the link down.
            pass
        except Exception as exc:
            # Unexpected failure: still close, but do not hide the cause.
            print(f"[edge] relay pump error name={self.node.name} addr={self.node.host}:{self.node.port} error={exc!r}")
        finally:
            await self.close()

    def _dispatch_frame(self, frame: Frame, handler: FrameHandler) -> None:
        """Schedule *handler* for *frame*, chained after the stream's previous task."""
        key = (frame.session_id, frame.stream_id)
        previous = self.dispatch_tasks.get(key)
        task = asyncio.create_task(self._run_handler(key, frame, handler, previous))
        self.dispatch_tasks[key] = task

    async def _run_handler(
        self,
        key: tuple[int, int],
        frame: Frame,
        handler: FrameHandler,
        previous: asyncio.Task | None,
    ) -> None:
        """Wait for *previous* (if any), then run *handler* on *frame*.

        A handler exception closes the whole connection. On exit the task
        removes itself from ``dispatch_tasks`` if it is still the latest task
        registered for *key*.
        """
        try:
            if previous is not None:
                with contextlib.suppress(Exception):
                    await previous
            if self.closed:
                return
            await handler(self, frame)
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            if not self.closed:
                print(f"[edge] relay handler failed name={self.node.name} session={key[0]} stream={key[1]} error={exc!r}")
                await self.close()
        finally:
            if self.dispatch_tasks.get(key) is asyncio.current_task():
                self.dispatch_tasks.pop(key, None)

    def _record_dropped_frame(self, kind: int) -> None:
        """Count a frame that had no handler and ensure a summary is scheduled."""
        self.dropped_frames[kind] = self.dropped_frames.get(kind, 0) + 1
        if self.dropped_report_task is None or self.dropped_report_task.done():
            self.dropped_report_task = asyncio.create_task(self._report_dropped_frames())

    async def _report_dropped_frames(self) -> None:
        """After 5 seconds, print and reset the dropped-frame counters."""
        try:
            await asyncio.sleep(5)
            dropped = self.dropped_frames
            self.dropped_frames = {}
            if dropped:
                detail = ", ".join(f"kind={kind} count={count}" for kind, count in sorted(dropped.items()))
                print(f"[edge] relay frame dropped summary name={self.node.name} {detail}")
        except asyncio.CancelledError:
            pass

    async def send(self, frame: Frame) -> None:
        """Write *frame* to the relay, serialised with other senders.

        Raises:
            ConnectionError: if the connection is closed, either before or
                while waiting for the send lock.
        """
        if self.closed:
            raise ConnectionError(f"relay closed: {self.node.name}")
        assert self.send_lock is not None
        async with self.send_lock:
            # Re-check: the connection may have closed while we waited.
            if self.closed:
                raise ConnectionError(f"relay closed: {self.node.name}")
            await write_frame(self.writer, frame)

    def bind(self, session_id: int, stream_id: int, handler: FrameHandler) -> None:
        """Register *handler* for frames addressed to the given stream."""
        self.handlers[(session_id, stream_id)] = handler

    def unbind(self, session_id: int, stream_id: int) -> None:
        """Remove the stream's handler and cancel its latest dispatch task."""
        self.handlers.pop((session_id, stream_id), None)
        task = self.dispatch_tasks.pop((session_id, stream_id), None)
        if task is not None:
            task.cancel()

    async def close(self) -> None:
        """Tear the connection down; safe to call more than once.

        Marks the connection closed, notifies the manager, delivers a
        synthetic ``TCP_CLOSE`` frame to every bound handler, cancels the
        background tasks and closes the writer.

        The writer is closed in a ``finally`` so the socket is released even
        if this coroutine is itself cancelled part-way through.
        """
        if self.closed:
            return
        self.closed = True
        assert self.closed_event is not None
        self.closed_event.set()
        handlers = list(self.handlers.items())
        self.handlers.clear()
        dispatch_tasks = list(self.dispatch_tasks.values())
        self.dispatch_tasks.clear()
        self.manager.on_closed(self)
        try:
            for (session_id, stream_id), handler in handlers:
                with contextlib.suppress(Exception):
                    await handler(self, Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))

            # close() may be running inside one of these tasks (pump,
            # keepalive, or a failing handler); never cancel or await that one.
            current = asyncio.current_task()
            candidates = [
                *dispatch_tasks,
                self.dropped_report_task,
                self.pump_task,
                self.keepalive_task,
            ]
            background = [task for task in candidates if task is not None and task is not current]
            for task in background:
                task.cancel()
            if background:
                # ``suppress(Exception)`` would not catch CancelledError (a
                # BaseException on Python 3.8+); gather with
                # return_exceptions=True collects it as a result instead.
                await asyncio.gather(*background, return_exceptions=True)
        finally:
            self.writer.close()
            with contextlib.suppress(Exception):
                await self.writer.wait_closed()


class RelayManager:
    """Keeps one connection per configured relay alive and lists the usable ones."""

    def __init__(self, config: Config) -> None:
        self.config = config
        self.scheduler = Scheduler(config)
        # Live connections keyed by relay node name.
        self.connections: Dict[str, RelayConnection] = {}
        # One supervisor task per relay, created by start().
        self.tasks: list[asyncio.Task] = []

    async def start(self) -> None:
        """Start the scheduler and spawn a supervisor task for every relay."""
        await self.scheduler.start()
        for node in self.config.relays:
            self.tasks.append(asyncio.create_task(self._maintain(node)))

    async def _connect(self, node: RelayNode) -> RelayConnection:
        """Open a TCP connection to *node* and authenticate it.

        The writer is closed if anything after the TCP connect fails or is
        cancelled, so a rejected handshake does not leak the socket.
        """
        reader, writer = await asyncio.wait_for(
            asyncio.open_connection(node.host, node.port),
            timeout=self.config.relay_open_timeout,
        )
        try:
            connection = RelayConnection(node=node, manager=self, reader=reader, writer=writer)
            sock = writer.get_extra_info("socket")
            if sock is not None and self.config.relay_tcp_nodelay:
                with contextlib.suppress(OSError):
                    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            await connection.start()
        except BaseException:
            writer.close()
            raise
        return connection

    async def _maintain(self, node: RelayNode) -> None:
        """Supervise *node* forever: connect, wait for close, reconnect.

        Failed attempts back off exponentially (with jitter) from
        ``relay_reconnect_delay`` up to ``relay_reconnect_max_delay``; the
        backoff resets after a successful connect.
        """
        backoff = self.config.relay_reconnect_delay
        while True:
            current = self.connections.get(node.name)
            if current is not None and not current.closed:
                assert current.closed_event is not None
                await current.closed_event.wait()
                continue
            attempt = 1
            while True:
                try:
                    print(f"[edge] relay reconnect attempt name={node.name} addr={node.host}:{node.port} attempt={attempt} backoff={backoff:.1f}s")
                    connection = await self._connect(node)
                    self.connections[node.name] = connection
                    backoff = self.config.relay_reconnect_delay
                    assert connection.closed_event is not None
                    await connection.closed_event.wait()
                    print(f"[edge] relay supervisor noticed close name={node.name} addr={node.host}:{node.port}")
                    break
                except asyncio.CancelledError:
                    raise
                except Exception as exc:
                    print(f"[edge] relay connect failed name={node.name} addr={node.host}:{node.port} attempt={attempt} error={exc!r}")
                    jitter = random.uniform(0, min(1.0, backoff * 0.2))
                    await asyncio.sleep(backoff + jitter)
                    backoff = min(
                        self.config.relay_reconnect_max_delay,
                        max(self.config.relay_reconnect_delay, backoff * 2),
                    )
                    attempt += 1

    def on_closed(self, connection: RelayConnection) -> None:
        """Forget *connection* if it is still the registered one for its node."""
        current = self.connections.get(connection.node.name)
        if current is connection:
            self.connections.pop(connection.node.name, None)

    def available(self) -> list[RelayConnection]:
        """Return open connections, preferring the scheduler's chosen nodes.

        Chosen nodes that are connected are returned in the order the
        scheduler listed them; if none are connected, every open connection
        is returned instead.
        """
        # dict.fromkeys de-duplicates while keeping the scheduler's order
        # (a set would make the result order arbitrary).
        chosen = dict.fromkeys(node.name for node in self.scheduler.choose())
        preferred = []
        for name in chosen:
            conn = self.connections.get(name)
            if conn is not None and not conn.closed:
                preferred.append(conn)
        if preferred:
            return preferred
        return [conn for conn in self.connections.values() if not conn.closed]

    def snapshot(self) -> list[dict[str, object]]:
        """Return the scheduler snapshot with an ``online`` flag added per item."""
        data = self.scheduler.snapshot()
        online = {name for name, conn in self.connections.items() if not conn.closed}
        for item in data:
            item["online"] = item["name"] in online
        return data