"""Edge-side relay links: authenticate to relay nodes, route frames, reconnect."""

from __future__ import annotations

import asyncio
import contextlib
from dataclasses import dataclass, field
from typing import Awaitable, Callable, Dict

from .config import Config, RelayNode
from .protocol import AUTH, STATUS_OK, Frame, encode_json, read_frame, write_frame
from .scheduler import Scheduler

# Coroutine invoked for each incoming frame bound to its (session_id, stream_id).
FrameHandler = Callable[["RelayConnection", Frame], Awaitable[None]]


@dataclass
class RelayConnection:
    """A single stream connection from this edge to one relay node.

    Incoming frames are routed by ``(session_id, stream_id)`` to handlers
    registered with :meth:`bind`. When the connection closes it reports
    itself to ``manager.on_closed``.
    """

    node: RelayNode
    manager: "RelayManager"
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    closed: bool = False
    # Routing table: (session_id, stream_id) -> handler coroutine.
    handlers: Dict[tuple[int, int], FrameHandler] = field(default_factory=dict)
    # Background task running _pump(); assigned by start().
    pump_task: asyncio.Task | None = None

    def __post_init__(self) -> None:
        # Backward compatibility: tolerate callers passing handlers=None explicitly.
        if self.handlers is None:
            self.handlers = {}

    async def start(self) -> None:
        """Send the AUTH frame, verify the reply, then start the receive pump.

        Raises:
            ConnectionError: if the reply is not an AUTH frame whose
                ``packet_id`` equals ``STATUS_OK``.
            Any exception from ``write_frame``/``read_frame`` during the
            handshake is propagated unchanged. On every failure path the
            connection is closed first so the socket is not leaked.
        """
        print(f"[edge] connecting relay name={self.node.name} addr={self.node.host}:{self.node.port}")
        try:
            await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
            frame = await read_frame(self.reader)
            if frame.kind != AUTH or frame.packet_id != STATUS_OK:
                raise ConnectionError(f"relay auth failed: {self.node.name}")
        except BaseException:
            # Without this the writer stays open and the caller's retry loop
            # would leak one transport per failed attempt.
            await self.close()
            raise
        print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port}")
        self.pump_task = asyncio.create_task(self._pump())

    async def _pump(self) -> None:
        """Read frames until EOF or error, dispatching each to its bound handler.

        Frames with no bound handler are dropped. Handlers are awaited inline,
        so a slow handler delays later frames and an exception raised by a
        handler ends the pump. The connection is always closed on exit.
        """
        try:
            while True:
                frame = await read_frame(self.reader)
                handler = self.handlers.get((frame.session_id, frame.stream_id))
                if handler:
                    await handler(self, frame)
        except asyncio.IncompleteReadError:
            print(f"[edge] relay disconnected name={self.node.name} eof=true")
        except Exception as exc:
            print(f"[edge] relay pump error name={self.node.name} error={exc!r}")
        finally:
            await self.close()

    async def send(self, frame: Frame) -> None:
        """Write *frame* to the relay.

        Raises:
            ConnectionError: if the connection has already been closed.
        """
        if self.closed:
            raise ConnectionError(f"relay closed: {self.node.name}")
        await write_frame(self.writer, frame)

    def bind(self, session_id: int, stream_id: int, handler: FrameHandler) -> None:
        """Route future frames for ``(session_id, stream_id)`` to *handler*, replacing any previous one."""
        self.handlers[(session_id, stream_id)] = handler

    def unbind(self, session_id: int, stream_id: int) -> None:
        """Remove the handler for ``(session_id, stream_id)``; a missing key is ignored."""
        self.handlers.pop((session_id, stream_id), None)

    async def close(self) -> None:
        """Mark closed, notify the manager, and close the writer (idempotent)."""
        if self.closed:
            return
        self.closed = True
        self.manager.on_closed(self)
        self.writer.close()
        # Best-effort: errors while waiting for the transport to finish closing are ignored.
        with contextlib.suppress(Exception):
            await self.writer.wait_closed()


class RelayManager:
    """Keeps one connection per configured relay node alive and picks usable ones."""

    def __init__(self, config: Config) -> None:
        self.config = config
        self.scheduler = Scheduler(config)
        # Node name -> its current connection (entries removed by on_closed).
        self.connections: Dict[str, RelayConnection] = {}
        # One _maintain task per configured relay node.
        self.tasks: list[asyncio.Task] = []

    async def start(self) -> None:
        """Start the scheduler, then spawn a maintenance task for each relay in ``config.relays``."""
        await self.scheduler.start()
        for node in self.config.relays:
            self.tasks.append(asyncio.create_task(self._maintain(node)))

    async def _maintain(self, node: RelayNode) -> None:
        """Forever: connect and authenticate to *node*, wait until it drops, repeat.

        A failed attempt is logged and retried after 3 seconds. While a live
        connection is registered for the node, the loop only polls every
        2 seconds.
        """
        while True:
            if node.name in self.connections and not self.connections[node.name].closed:
                await asyncio.sleep(2)
                continue
            try:
                reader, writer = await asyncio.open_connection(node.host, node.port)
                connection = RelayConnection(node=node, manager=self, reader=reader, writer=writer)
                await connection.start()
                # No await between start() returning and this assignment, so the
                # pump cannot run (and close) before the connection is registered.
                self.connections[node.name] = connection
                await connection.pump_task
            except Exception as exc:
                print(f"[edge] relay connect failed name={node.name} addr={node.host}:{node.port} error={exc!r}")
                await asyncio.sleep(3)

    def on_closed(self, connection: RelayConnection) -> None:
        """Unregister *connection* if it is still the current one for its node."""
        # Identity check: a stale connection must not evict a newer replacement.
        current = self.connections.get(connection.node.name)
        if current is connection:
            self.connections.pop(connection.node.name, None)

    def available(self) -> list[RelayConnection]:
        """Return open connections, preferring the nodes picked by the scheduler.

        Preferred connections are returned in the order ``scheduler.choose()``
        yields their nodes, with duplicates removed. If none of them is open,
        every open connection is returned instead (possibly an empty list).
        """
        preferred: list[RelayConnection] = []
        seen: set[str] = set()
        for node in self.scheduler.choose():
            if node.name in seen:
                continue
            seen.add(node.name)
            conn = self.connections.get(node.name)
            if conn is not None and not conn.closed:
                preferred.append(conn)
        if preferred:
            return preferred
        return [conn for conn in self.connections.values() if not conn.closed]

    def snapshot(self) -> list[dict[str, object]]:
        """Return ``scheduler.snapshot()`` with an ``online`` flag set on each entry.

        NOTE(review): the entries returned by the scheduler are mutated in
        place; this assumes they are fresh copies, which is not confirmed here.
        """
        data = self.scheduler.snapshot()
        online = {name for name, conn in self.connections.items() if not conn.closed}
        for item in data:
            item["online"] = item["name"] in online
        return data