scheduler.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from __future__ import annotations
  2. import asyncio
  3. import time
  4. from dataclasses import dataclass
  5. from .config import Config, RelayNode
  6. @dataclass
  7. class RelayScore:
  8. node: RelayNode
  9. latency_ms: float = 9999.0
  10. failures: int = 0
  11. last_ok: float = 0.0
  12. @property
  13. def score(self) -> float:
  14. penalty = self.failures * 250.0
  15. return self.latency_ms + penalty - (self.node.weight * 0.1)
  16. class Scheduler:
  17. def __init__(self, config: Config) -> None:
  18. self.config = config
  19. self.scores = {node.name: RelayScore(node=node) for node in config.relays}
  20. self._task: asyncio.Task | None = None
  21. async def start(self) -> None:
  22. if self._task is None:
  23. self._task = asyncio.create_task(self._probe_loop())
  24. async def _probe_loop(self) -> None:
  25. while True:
  26. await asyncio.gather(*(self._probe(node) for node in self.config.relays), return_exceptions=True)
  27. await asyncio.sleep(self.config.probe_interval)
  28. async def _probe(self, node: RelayNode) -> None:
  29. started = time.perf_counter()
  30. try:
  31. reader, writer = await asyncio.wait_for(asyncio.open_connection(node.host, node.port), timeout=3)
  32. writer.close()
  33. await writer.wait_closed()
  34. elapsed = (time.perf_counter() - started) * 1000
  35. score = self.scores[node.name]
  36. score.latency_ms = elapsed
  37. score.last_ok = time.time()
  38. score.failures = max(0, score.failures - 1)
  39. except Exception:
  40. self.scores[node.name].failures += 1
  41. def choose(self) -> list[RelayNode]:
  42. ordered = sorted(self.scores.values(), key=lambda item: item.score)
  43. if self.config.strategy == "broadcast":
  44. limit = min(len(ordered), max(1, self.config.redundancy))
  45. return [item.node for item in ordered[:limit]]
  46. if self.config.strategy == "backup":
  47. return [item.node for item in ordered[:1]]
  48. limit = min(len(ordered), max(1, self.config.redundancy, 2))
  49. return [item.node for item in ordered[:limit]]
  50. def snapshot(self) -> list[dict[str, object]]:
  51. ordered = sorted(self.scores.values(), key=lambda item: item.score)
  52. return [
  53. {
  54. "name": item.node.name,
  55. "host": item.node.host,
  56. "port": item.node.port,
  57. "latency_ms": round(item.latency_ms, 2),
  58. "failures": item.failures,
  59. "score": round(item.score, 2),
  60. }
  61. for item in ordered
  62. ]