| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- from __future__ import annotations
- import asyncio
- import time
- from dataclasses import dataclass
- from .config import Config, RelayNode
- from .protocol import AUTH, PING, PONG, STATUS_OK, Frame, encode_json, read_frame, write_frame
@dataclass
class RelayScore:
    """Mutable health record for one relay; a lower ``score`` ranks the relay higher."""

    # The relay this record describes.
    node: RelayNode
    # Latest probe latency in milliseconds; 9999.0 is a large placeholder used
    # until a probe succeeds.
    latency_ms: float = 9999.0
    # Failure counter; the scheduler bumps it on a failed probe and decays it on success.
    failures: int = 0
    # time.time() stamp of the last successful probe; 0.0 means "never".
    last_ok: float = 0.0

    @property
    def score(self) -> float:
        """Ranking key: latency, plus 250 per recorded failure, minus a tenth of the node weight."""
        return self.latency_ms + self.failures * 250.0 - self.node.weight * 0.1
class Scheduler:
    """Track relay health with periodic probes and pick relays to send through.

    One ``RelayScore`` is kept per relay name from ``config.relays``. A
    background task (started with :meth:`start`) probes every relay, then
    sleeps ``config.probe_interval`` seconds, forever until :meth:`stop`.
    :meth:`choose` ranks relays by ascending ``score`` and returns a prefix
    whose length depends on ``config.strategy``.
    """

    # Seconds allowed for each network step of a probe: connect, AUTH reply, PONG reply.
    PROBE_TIMEOUT: float = 3.0

    # Minimum relay count for "at least N" strategies; config.redundancy may raise
    # the count but never lower it. Strategies not listed here (other than
    # "backup" and "broadcast") fall back to _DEFAULT_MIN_FANOUT.
    _MIN_FANOUT = {"top4": 4, "top3": 3}
    _DEFAULT_MIN_FANOUT = 2

    def __init__(self, config: Config) -> None:
        self.config = config
        # Keyed by relay name; duplicate names in config.relays collapse to the last entry.
        self.scores = {node.name: RelayScore(node=node) for node in config.relays}
        self._task: asyncio.Task | None = None

    async def start(self) -> None:
        """Launch the background probe loop unless one is already running.

        A no-op while the loop is alive. If a previous loop has finished
        (stopped, or ended by an error), a fresh one is created.
        """
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._probe_loop())

    async def stop(self) -> None:
        """Cancel the background probe loop, if any, and wait for it to wind down."""
        task, self._task = self._task, None
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def _probe_loop(self) -> None:
        """Probe all configured relays concurrently, sleep, and repeat until cancelled."""
        while True:
            # return_exceptions=True keeps one probe's unexpected error from killing the loop.
            await asyncio.gather(
                *(self._probe(node) for node in self.config.relays),
                return_exceptions=True,
            )
            await asyncio.sleep(self.config.probe_interval)

    async def _probe(self, node: RelayNode) -> None:
        """Run one health probe against *node* and update its score.

        The probe connects, sends an AUTH frame carrying the node's token,
        expects an AUTH reply whose ``packet_id`` is ``STATUS_OK``, then sends a
        PING and expects a PONG. On success ``latency_ms`` becomes the
        wall-clock time from the start of the connect until the PONG arrived,
        ``last_ok`` is stamped with ``time.time()`` and ``failures`` decays by
        one (never below zero). Any exception during the exchange counts as one
        failure and is not re-raised. The connection is closed on every path.
        """
        started = time.perf_counter()
        writer: asyncio.StreamWriter | None = None
        try:
            reader, writer = await asyncio.wait_for(
                asyncio.open_connection(node.host, node.port),
                timeout=self.PROBE_TIMEOUT,
            )
            hello = encode_json({"token": node.token, "purpose": "probe"})
            await write_frame(writer, Frame(AUTH, 0, 0, 0, 0, hello))
            auth = await asyncio.wait_for(read_frame(reader), timeout=self.PROBE_TIMEOUT)
            if auth.kind != AUTH or auth.packet_id != STATUS_OK:
                raise ConnectionError(f"relay auth probe failed: {node.name}")
            await write_frame(writer, Frame(PING, 0, 0, 1, 0, b""))
            pong = await asyncio.wait_for(read_frame(reader), timeout=self.PROBE_TIMEOUT)
            if pong.kind != PONG:
                raise ConnectionError(f"relay ping probe failed: {node.name}")
            # Stop the clock as soon as the PONG arrives so teardown is not counted as latency.
            elapsed = (time.perf_counter() - started) * 1000
        except Exception:
            # Best-effort probe: errors and timeouts are recorded, never propagated.
            self.scores[node.name].failures += 1
            return
        finally:
            # Release the socket on success, failure, timeout and cancellation alike.
            if writer is not None:
                await self._close_quietly(writer)
        score = self.scores[node.name]
        score.latency_ms = elapsed
        score.last_ok = time.time()
        # NOTE(review): failures is unbounded and only decays by one per success, so a relay
        # that was down for a long time needs as many good probes to recover; consider a cap.
        score.failures = max(0, score.failures - 1)

    @staticmethod
    async def _close_quietly(writer: asyncio.StreamWriter) -> None:
        """Close *writer* and wait for its transport to shut down, ignoring teardown errors."""
        writer.close()
        try:
            await writer.wait_closed()
        except Exception:
            # The probe outcome is already decided by now; a reset during close
            # must not turn a successful probe into a recorded failure.
            pass

    def _ranked(self) -> list[RelayScore]:
        """Return every relay score ordered best (lowest ``score``) first."""
        return sorted(self.scores.values(), key=lambda item: item.score)

    def choose(self) -> list[RelayNode]:
        """Return the best-ranked relays for the configured strategy.

        * ``"backup"``: only the single best relay.
        * ``"broadcast"``: the best ``max(1, config.redundancy)`` relays.
        * ``"top3"`` / ``"top4"``: at least 3 / 4 relays, more if
          ``config.redundancy`` is larger.
        * any other strategy: at least 2 relays, more if ``config.redundancy``
          is larger.

        Fewer relays are returned when fewer are configured (none -> ``[]``).
        """
        strategy = self.config.strategy
        if strategy == "backup":
            count = 1
        elif strategy == "broadcast":
            count = max(1, self.config.redundancy)
        else:
            floor = self._MIN_FANOUT.get(strategy, self._DEFAULT_MIN_FANOUT)
            count = max(floor, self.config.redundancy)
        # Slicing past the end is safe, so no explicit min(len(...), count) clamp is needed.
        return [item.node for item in self._ranked()[:count]]

    def snapshot(self) -> list[dict[str, object]]:
        """Describe every relay's current health, best-ranked first.

        ``latency_ms`` and ``score`` are rounded to two decimals for display.
        """
        return [
            {
                "name": item.node.name,
                "host": item.node.host,
                "port": item.node.port,
                "latency_ms": round(item.latency_ms, 2),
                "failures": item.failures,
                "score": round(item.score, 2),
            }
            for item in self._ranked()
        ]
|