| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- from __future__ import annotations
- import asyncio
- import time
- from dataclasses import dataclass
- from .config_udp import UdpConfig, UdpRelayNode
- from .protocol import AUTH, PING, PONG, STATUS_OK, Frame, encode_json, read_frame, write_frame
@dataclass
class UdpRelayScore:
    """Mutable health record for a single relay node.

    Lower ``score`` values sort first when relays are ranked.
    """

    # Relay this record describes; only ``node.weight`` is read here.
    node: UdpRelayNode
    # Most recent latency figure in milliseconds; starts at a large
    # placeholder so a relay with no measurement yet ranks poorly.
    latency_ms: float = 9999.0
    # Running failure counter.
    failures: int = 0
    # Timestamp slot for the last success; 0.0 until one is recorded.
    last_ok: float = 0.0

    @property
    def score(self) -> float:
        """Return the ranking value: latency plus failure penalty minus weight bonus."""
        failure_penalty = self.failures * 250.0
        weight_bonus = self.node.weight * 0.1
        # Same left-to-right evaluation as a single expression would give.
        return self.latency_ms + failure_penalty - weight_bonus
class UdpScheduler:
    """Rank the configured relay nodes by probing them in the background.

    One ``UdpRelayScore`` is kept per relay, keyed by ``node.name``.
    ``start`` launches a task that refreshes the scores; ``choose`` and
    ``snapshot`` read them.
    """

    # Seconds allowed for each awaited network step of a probe
    # (connect, each reply read, and the final socket close).
    _PROBE_TIMEOUT = 3

    # Minimum number of relays requested per strategy. "backup" is handled
    # separately (always exactly one); unknown strategies fall back to 2.
    _STRATEGY_FLOOR = {"broadcast": 1, "top4": 4, "top3": 3}
    _DEFAULT_FLOOR = 2

    def __init__(self, config: UdpConfig) -> None:
        """Create one score record for every relay listed in ``config.relays``."""
        self.config = config
        self.scores = {node.name: UdpRelayScore(node=node) for node in config.relays}
        self._task: asyncio.Task | None = None

    async def start(self) -> None:
        """Launch the probe loop task unless it exists or there are no relays."""
        if self._task is None and self.config.relays:
            self._task = asyncio.create_task(self._probe_loop())

    async def _probe_loop(self) -> None:
        """Probe every relay concurrently, sleep ``probe_interval``, repeat forever."""
        while True:
            # return_exceptions=True keeps one probe's error from aborting
            # the rest of the round.
            await asyncio.gather(
                *(self._probe(node) for node in self.config.relays),
                return_exceptions=True,
            )
            await asyncio.sleep(self.config.probe_interval)

    async def _probe(self, node: UdpRelayNode) -> None:
        """Run one AUTH + PING/PONG exchange with ``node`` and update its score.

        On success the measured round-trip replaces ``latency_ms``,
        ``last_ok`` is set to the current wall-clock time and ``failures`` is
        decremented (never below zero). Any ``Exception`` during the exchange
        increments ``failures`` instead. The connection is always closed,
        whether the probe succeeded, failed or was cancelled.
        """
        timeout = self._PROBE_TIMEOUT
        started = time.perf_counter()
        writer = None
        try:
            reader, writer = await asyncio.wait_for(
                asyncio.open_connection(node.host, node.port), timeout=timeout
            )
            await write_frame(
                writer,
                Frame(AUTH, 0, 0, 0, 0, encode_json({"token": node.token, "purpose": "probe"})),
            )
            auth = await asyncio.wait_for(read_frame(reader), timeout=timeout)
            if auth.kind != AUTH or auth.packet_id != STATUS_OK:
                raise ConnectionError(node.name)
            await write_frame(writer, Frame(PING, 0, 0, 1, 0, b""))
            pong = await asyncio.wait_for(read_frame(reader), timeout=timeout)
            if pong.kind != PONG:
                raise ConnectionError(node.name)
            # Stop the clock once the reply is validated so socket teardown
            # time is not counted as relay latency.
            elapsed = (time.perf_counter() - started) * 1000
        except Exception:
            self.scores[node.name].failures += 1
        else:
            score = self.scores[node.name]
            score.latency_ms = elapsed
            score.last_ok = time.time()
            score.failures = max(0, score.failures - 1)
        finally:
            # Close on every path: previously a failed handshake or a timed-out
            # read left the socket open, leaking one connection per probe.
            if writer is not None:
                writer.close()
                try:
                    await asyncio.wait_for(writer.wait_closed(), timeout=timeout)
                except Exception:
                    # Best-effort teardown: the probe outcome is already
                    # recorded and a close error must not change it.
                    pass

    def choose(self) -> list[UdpRelayNode]:
        """Return the best-scoring relays for the configured strategy.

        Relays are ordered by ascending ``score``. ``"backup"`` yields exactly
        one relay; every other strategy yields
        ``max(floor, config.redundancy)`` relays, capped at the number
        available, where the floor is 1 for ``"broadcast"``, 4 for ``"top4"``,
        3 for ``"top3"`` and 2 for anything else. Returns an empty list when
        there are no scores.
        """
        if not self.scores:
            return []
        ordered = sorted(self.scores.values(), key=lambda item: item.score)
        strategy = self.config.strategy
        if strategy == "backup":
            limit = 1
        else:
            floor = self._STRATEGY_FLOOR.get(strategy, self._DEFAULT_FLOOR)
            limit = min(len(ordered), max(floor, self.config.redundancy))
        return [item.node for item in ordered[:limit]]

    def snapshot(self) -> list[dict[str, object]]:
        """Return one summary dict per relay, ordered by ascending score.

        ``latency_ms`` and ``score`` are rounded to two decimal places.
        """
        ordered = sorted(self.scores.values(), key=lambda item: item.score)
        return [
            {
                "name": item.node.name,
                "host": item.node.host,
                "port": item.node.port,
                "latency_ms": round(item.latency_ms, 2),
                "failures": item.failures,
                "score": round(item.score, 2),
            }
            for item in ordered
        ]
|