scheduler_udp.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from __future__ import annotations
  2. import asyncio
  3. import time
  4. from dataclasses import dataclass
  5. from .config_udp import UdpConfig, UdpRelayNode
  6. from .protocol import AUTH, PING, PONG, STATUS_OK, Frame, encode_json, read_frame, write_frame
  7. @dataclass
  8. class UdpRelayScore:
  9. node: UdpRelayNode
  10. latency_ms: float = 9999.0
  11. failures: int = 0
  12. last_ok: float = 0.0
  13. @property
  14. def score(self) -> float:
  15. return self.latency_ms + (self.failures * 250.0) - (self.node.weight * 0.1)
  16. class UdpScheduler:
  17. def __init__(self, config: UdpConfig) -> None:
  18. self.config = config
  19. self.scores = {node.name: UdpRelayScore(node=node) for node in config.relays}
  20. self._task: asyncio.Task | None = None
  21. async def start(self) -> None:
  22. if self._task is None and self.config.relays:
  23. self._task = asyncio.create_task(self._probe_loop())
  24. async def _probe_loop(self) -> None:
  25. while True:
  26. await asyncio.gather(*(self._probe(node) for node in self.config.relays), return_exceptions=True)
  27. await asyncio.sleep(self.config.probe_interval)
  28. async def _probe(self, node: UdpRelayNode) -> None:
  29. started = time.perf_counter()
  30. try:
  31. reader, writer = await asyncio.wait_for(asyncio.open_connection(node.host, node.port), timeout=3)
  32. await write_frame(writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": node.token, "purpose": "probe"})))
  33. auth = await asyncio.wait_for(read_frame(reader), timeout=3)
  34. if auth.kind != AUTH or auth.packet_id != STATUS_OK:
  35. raise ConnectionError(node.name)
  36. await write_frame(writer, Frame(PING, 0, 0, 1, 0, b""))
  37. pong = await asyncio.wait_for(read_frame(reader), timeout=3)
  38. if pong.kind != PONG:
  39. raise ConnectionError(node.name)
  40. writer.close()
  41. await writer.wait_closed()
  42. elapsed = (time.perf_counter() - started) * 1000
  43. score = self.scores[node.name]
  44. score.latency_ms = elapsed
  45. score.last_ok = time.time()
  46. score.failures = max(0, score.failures - 1)
  47. except Exception:
  48. self.scores[node.name].failures += 1
  49. def choose(self) -> list[UdpRelayNode]:
  50. if not self.scores:
  51. return []
  52. ordered = sorted(self.scores.values(), key=lambda item: item.score)
  53. if self.config.strategy == "broadcast":
  54. limit = min(len(ordered), max(1, self.config.redundancy))
  55. elif self.config.strategy == "backup":
  56. limit = 1
  57. elif self.config.strategy == "top4":
  58. limit = min(len(ordered), max(1, self.config.redundancy, 4))
  59. elif self.config.strategy == "top3":
  60. limit = min(len(ordered), max(1, self.config.redundancy, 3))
  61. else:
  62. limit = min(len(ordered), max(1, self.config.redundancy, 2))
  63. return [item.node for item in ordered[:limit]]
  64. def snapshot(self) -> list[dict[str, object]]:
  65. ordered = sorted(self.scores.values(), key=lambda item: item.score)
  66. return [{"name": item.node.name, "host": item.node.host, "port": item.node.port, "latency_ms": round(item.latency_ms, 2), "failures": item.failures, "score": round(item.score, 2)} for item in ordered]