from __future__ import annotations import asyncio import time from dataclasses import dataclass from .config import Config, RelayNode from .protocol import AUTH, PING, PONG, STATUS_OK, Frame, encode_json, read_frame, write_frame @dataclass class RelayScore: node: RelayNode latency_ms: float = 9999.0 failures: int = 0 last_ok: float = 0.0 @property def score(self) -> float: penalty = self.failures * 250.0 return self.latency_ms + penalty - (self.node.weight * 0.1) class Scheduler: def __init__(self, config: Config) -> None: self.config = config self.scores = {node.name: RelayScore(node=node) for node in config.relays} self._task: asyncio.Task | None = None async def start(self) -> None: if self._task is None: self._task = asyncio.create_task(self._probe_loop()) async def _probe_loop(self) -> None: while True: await asyncio.gather(*(self._probe(node) for node in self.config.relays), return_exceptions=True) await asyncio.sleep(self.config.probe_interval) async def _probe(self, node: RelayNode) -> None: started = time.perf_counter() try: reader, writer = await asyncio.wait_for(asyncio.open_connection(node.host, node.port), timeout=3) await write_frame(writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": node.token}))) auth = await asyncio.wait_for(read_frame(reader), timeout=3) if auth.kind != AUTH or auth.packet_id != STATUS_OK: raise ConnectionError(f"relay auth probe failed: {node.name}") await write_frame(writer, Frame(PING, 0, 0, 1, 0, b"")) pong = await asyncio.wait_for(read_frame(reader), timeout=3) if pong.kind != PONG: raise ConnectionError(f"relay ping probe failed: {node.name}") writer.close() await writer.wait_closed() elapsed = (time.perf_counter() - started) * 1000 score = self.scores[node.name] score.latency_ms = elapsed score.last_ok = time.time() score.failures = max(0, score.failures - 1) except Exception: self.scores[node.name].failures += 1 def choose(self) -> list[RelayNode]: ordered = sorted(self.scores.values(), key=lambda item: item.score) if self.config.strategy == "broadcast": limit = min(len(ordered), max(1, self.config.redundancy)) return [item.node for item in ordered[:limit]] if self.config.strategy == "backup": return [item.node for item in ordered[:1]] if self.config.strategy == "top4": limit = min(len(ordered), max(1, self.config.redundancy, 4)) return [item.node for item in ordered[:limit]] if self.config.strategy == "top3": limit = min(len(ordered), max(1, self.config.redundancy, 3)) return [item.node for item in ordered[:limit]] limit = min(len(ordered), max(1, self.config.redundancy, 2)) return [item.node for item in ordered[:limit]] def snapshot(self) -> list[dict[str, object]]: ordered = sorted(self.scores.values(), key=lambda item: item.score) return [ { "name": item.node.name, "host": item.node.host, "port": item.node.port, "latency_ms": round(item.latency_ms, 2), "failures": item.failures, "score": round(item.score, 2), } for item in ordered ]