from __future__ import annotations import asyncio import contextlib import socket from dataclasses import dataclass import time from typing import Awaitable, Callable, Dict from .config import Config, RelayNode from .protocol import AUTH, PING, PONG, STATUS_OK, Frame, encode_json, read_frame, write_frame from .scheduler import Scheduler FrameHandler = Callable[["RelayConnection", Frame], Awaitable[None]] @dataclass class RelayConnection: node: RelayNode manager: "RelayManager" reader: asyncio.StreamReader writer: asyncio.StreamWriter closed: bool = False handlers: Dict[tuple[int, int], FrameHandler] = None pump_task: asyncio.Task | None = None keepalive_task: asyncio.Task | None = None last_pong_at: float = 0.0 def __post_init__(self) -> None: if self.handlers is None: self.handlers = {} async def start(self) -> None: print(f"[edge] connecting relay name={self.node.name} addr={self.node.host}:{self.node.port}") await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token}))) frame = await read_frame(self.reader) if frame.kind != AUTH or frame.packet_id != STATUS_OK: raise ConnectionError(f"relay auth failed: {self.node.name}") print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port}") self.last_pong_at = time.monotonic() self.keepalive_task = asyncio.create_task(self._keepalive()) self.pump_task = asyncio.create_task(self._pump()) async def _keepalive(self) -> None: try: while not self.closed: await asyncio.sleep(self.manager.config.relay_ping_interval) if self.closed: break if time.monotonic() - self.last_pong_at > self.manager.config.relay_ping_timeout: print(f"[edge] relay health timeout name={self.node.name} addr={self.node.host}:{self.node.port} timeout={self.manager.config.relay_ping_timeout}") await self.close() break await self.send(Frame(PING, 0, 0, 0, 0, b"")) except asyncio.CancelledError: pass except Exception: await self.close() async def _pump(self) -> None: try: while True: frame = await read_frame(self.reader) if frame.kind == PONG: self.last_pong_at = time.monotonic() continue handler = self.handlers.get((frame.session_id, frame.stream_id)) if handler: await handler(self, frame) except asyncio.IncompleteReadError: print(f"[edge] relay disconnected name={self.node.name} eof=true") except Exception as exc: print(f"[edge] relay pump error name={self.node.name} error={exc!r}") finally: await self.close() async def send(self, frame: Frame) -> None: if self.closed: raise ConnectionError(f"relay closed: {self.node.name}") await write_frame(self.writer, frame) def bind(self, session_id: int, stream_id: int, handler: FrameHandler) -> None: self.handlers[(session_id, stream_id)] = handler def unbind(self, session_id: int, stream_id: int) -> None: self.handlers.pop((session_id, stream_id), None) async def close(self) -> None: if self.closed: return self.closed = True self.manager.on_closed(self) if self.keepalive_task and self.keepalive_task is not asyncio.current_task(): self.keepalive_task.cancel() with contextlib.suppress(Exception): await self.keepalive_task self.writer.close() with contextlib.suppress(Exception): await self.writer.wait_closed() class RelayManager: def __init__(self, config: Config) -> None: self.config = config self.scheduler = Scheduler(config) self.connections: Dict[str, RelayConnection] = {} self.tasks: list[asyncio.Task] = [] async def start(self) -> None: await self.scheduler.start() for node in self.config.relays: self.tasks.append(asyncio.create_task(self._maintain(node))) async def _maintain(self, node: RelayNode) -> None: while True: if node.name in self.connections and not self.connections[node.name].closed: await asyncio.sleep(2) continue connected = False for attempt in range(1, self.config.relay_reconnect_attempts + 1): try: print(f"[edge] relay reconnect attempt name={node.name} addr={node.host}:{node.port} attempt={attempt}/{self.config.relay_reconnect_attempts}") reader, writer = await asyncio.wait_for(asyncio.open_connection(node.host, node.port), timeout=self.config.relay_open_timeout) connection = RelayConnection(node=node, manager=self, reader=reader, writer=writer) sock = writer.get_extra_info("socket") if sock is not None and self.config.relay_tcp_nodelay: with contextlib.suppress(OSError): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) await connection.start() self.connections[node.name] = connection connected = True await connection.pump_task break except Exception as exc: print(f"[edge] relay connect failed name={node.name} addr={node.host}:{node.port} attempt={attempt}/{self.config.relay_reconnect_attempts} error={exc!r}") if attempt < self.config.relay_reconnect_attempts: await asyncio.sleep(self.config.relay_reconnect_delay) if not connected: print(f"[edge] relay reconnect exhausted name={node.name} addr={node.host}:{node.port} attempts={self.config.relay_reconnect_attempts}") await asyncio.sleep(self.config.relay_reconnect_delay) def on_closed(self, connection: RelayConnection) -> None: current = self.connections.get(connection.node.name) if current is connection: self.connections.pop(connection.node.name, None) def available(self) -> list[RelayConnection]: chosen = {node.name for node in self.scheduler.choose()} preferred = [self.connections[name] for name in chosen if name in self.connections and not self.connections[name].closed] if preferred: return preferred return [conn for conn in self.connections.values() if not conn.closed] def snapshot(self) -> list[dict[str, object]]: data = self.scheduler.snapshot() online = {name for name, conn in self.connections.items() if not conn.closed} for item in data: item["online"] = item["name"] in online return data