| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
from __future__ import annotations

import asyncio
import contextlib
from dataclasses import dataclass, field
from typing import Awaitable, Callable, Dict

from .config import Config, RelayNode
from .protocol import AUTH, STATUS_OK, Frame, encode_json, read_frame, write_frame
from .scheduler import Scheduler
# Async callback that RelayConnection._pump awaits for every inbound frame whose
# (session_id, stream_id) it was bound to via RelayConnection.bind(); it is called
# as handler(connection, frame). "RelayConnection" is a string forward reference
# because that class is defined below this alias.
FrameHandler = Callable[["RelayConnection", Frame], Awaitable[None]]
@dataclass(eq=False)
class RelayConnection:
    """One authenticated TCP link to a single relay node.

    Inbound frames are dispatched by ``(session_id, stream_id)`` to handlers
    registered with :meth:`bind`; outbound frames go through :meth:`send`.

    ``eq=False`` keeps identity-based equality and hashing: two connection
    objects are never interchangeable (the manager compares them with ``is``),
    and this also makes instances usable in sets / as dict keys.
    """

    node: RelayNode  # relay this link targets; name/host/port/token are read below
    manager: "RelayManager"  # receives on_closed(self) when the link closes
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    closed: bool = False  # set once by close(); never reset
    # Frame handlers keyed by (session_id, stream_id); managed by bind()/unbind().
    handlers: Dict[tuple[int, int], FrameHandler] = field(default_factory=dict)
    pump_task: asyncio.Task | None = None  # read loop created by start()

    def __post_init__(self) -> None:
        # Backward compatibility: an explicit ``handlers=None`` is still accepted.
        if self.handlers is None:
            self.handlers = {}

    async def start(self) -> None:
        """Authenticate with the relay and start the background read loop.

        Sends an AUTH frame carrying the node token and expects an AUTH reply
        whose ``packet_id`` equals ``STATUS_OK``.

        Raises:
            ConnectionError: if the relay rejects the authentication.
            Any exception from the frame I/O is propagated unchanged.

        On any failure the connection is closed before the exception
        propagates, so the already-open socket is not leaked.
        """
        print(f"[edge] connecting relay name={self.node.name} addr={self.node.host}:{self.node.port}")
        try:
            await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
            frame = await read_frame(self.reader)
            if frame.kind != AUTH or frame.packet_id != STATUS_OK:
                raise ConnectionError(f"relay auth failed: {self.node.name}")
        except BaseException:
            # Includes cancellation: release the half-open socket, then re-raise.
            # close() is idempotent, so a later close by the caller is harmless.
            await self.close()
            raise
        print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port}")
        self.pump_task = asyncio.create_task(self._pump())

    async def _pump(self) -> None:
        """Read frames until EOF, a read error, or close(), dispatching each one.

        Frames with no bound handler are dropped. An exception raised by a
        single handler is logged and does not tear down the relay link, since
        other sessions share it. The connection is always closed on exit.
        """
        try:
            while not self.closed:
                frame = await read_frame(self.reader)
                handler = self.handlers.get((frame.session_id, frame.stream_id))
                if handler is None:
                    continue
                try:
                    await handler(self, frame)
                except Exception as exc:
                    # One broken stream must not drop every other session on this relay.
                    print(
                        f"[edge] relay handler error name={self.node.name} "
                        f"session={frame.session_id} stream={frame.stream_id} error={exc!r}"
                    )
        except asyncio.IncompleteReadError:
            print(f"[edge] relay disconnected name={self.node.name} eof=true")
        except Exception as exc:
            print(f"[edge] relay pump error name={self.node.name} error={exc!r}")
        finally:
            await self.close()

    async def send(self, frame: Frame) -> None:
        """Write *frame* to the relay.

        Raises:
            ConnectionError: if the connection is already closed.
            OSError: if the socket write fails; the connection is closed first.
        """
        if self.closed:
            raise ConnectionError(f"relay closed: {self.node.name}")
        try:
            await write_frame(self.writer, frame)
        except OSError:
            # A failed socket write leaves the link unusable: close it so the
            # manager forgets it and the read loop exits.
            await self.close()
            raise

    def bind(self, session_id: int, stream_id: int, handler: FrameHandler) -> None:
        """Route inbound frames for (session_id, stream_id) to *handler*, replacing any previous one."""
        self.handlers[(session_id, stream_id)] = handler

    def unbind(self, session_id: int, stream_id: int) -> None:
        """Stop routing frames for (session_id, stream_id); a no-op if nothing is bound."""
        self.handlers.pop((session_id, stream_id), None)

    async def close(self) -> None:
        """Close the link once; subsequent calls return immediately.

        Marks the connection closed, notifies the manager via ``on_closed``,
        closes the writer, and waits for the close to finish, ignoring any
        error raised by that wait.
        """
        if self.closed:
            return
        self.closed = True
        self.manager.on_closed(self)
        self.writer.close()
        with contextlib.suppress(Exception):
            await self.writer.wait_closed()
class RelayManager:
    """Maintain one authenticated RelayConnection per configured relay node.

    ``start()`` spawns a reconnect loop for every entry in ``config.relays``.
    ``available()`` and ``snapshot()`` expose the currently open connections.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        self.scheduler = Scheduler(config)
        # Open connections keyed by relay node name; on_closed() removes entries.
        self.connections: Dict[str, RelayConnection] = {}
        # One reconnect-loop task per relay node, created by start().
        self.tasks: list[asyncio.Task] = []

    async def start(self) -> None:
        """Start the scheduler, then spawn a reconnect loop for every relay."""
        await self.scheduler.start()
        for node in self.config.relays:
            self.tasks.append(asyncio.create_task(self._maintain(node)))

    async def _maintain(self, node: RelayNode) -> None:
        """Keep *node* connected indefinitely.

        A failed connect/authenticate attempt is logged and retried after
        3 seconds. When an established connection's read loop ends, a new
        connection is attempted immediately.
        """
        while True:
            current = self.connections.get(node.name)
            if current is not None and not current.closed:
                await asyncio.sleep(2)
                continue
            writer: asyncio.StreamWriter | None = None
            try:
                reader, writer = await asyncio.open_connection(node.host, node.port)
                connection = RelayConnection(node=node, manager=self, reader=reader, writer=writer)
                await connection.start()
                self.connections[node.name] = connection
                await connection.pump_task
            except Exception as exc:
                print(f"[edge] relay connect failed name={node.name} addr={node.host}:{node.port} error={exc!r}")
                if writer is not None:
                    # The socket may have opened before the failure (e.g. auth
                    # rejected); close it rather than leak one socket per retry.
                    writer.close()
                await asyncio.sleep(3)

    def on_closed(self, connection: RelayConnection) -> None:
        """Forget *connection* if it is still the registered one for its node."""
        current = self.connections.get(connection.node.name)
        # Identity check: a newer connection for the same node must not be removed.
        if current is connection:
            self.connections.pop(connection.node.name, None)

    def available(self) -> list[RelayConnection]:
        """Return open connections, preferring the relays the scheduler chose.

        Chosen relays are returned in the order ``scheduler.choose()`` yields
        them (duplicates removed). If none of them is connected, every open
        connection is returned instead.
        """
        preferred: list[RelayConnection] = []
        # dict.fromkeys dedupes while keeping the scheduler's order (a set would not).
        for name in dict.fromkeys(node.name for node in self.scheduler.choose()):
            conn = self.connections.get(name)
            if conn is not None and not conn.closed:
                preferred.append(conn)
        if preferred:
            return preferred
        return [conn for conn in self.connections.values() if not conn.closed]

    def snapshot(self) -> list[dict[str, object]]:
        """Return the scheduler's snapshot with an ``"online"`` flag per relay.

        New dicts are returned, so dicts handed out by ``scheduler.snapshot()``
        are never modified (whether those are copies is not visible here).
        """
        online = {name for name, conn in self.connections.items() if not conn.closed}
        return [{**item, "online": item["name"] in online} for item in self.scheduler.snapshot()]
|