scheduler.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. from __future__ import annotations
  2. import asyncio
  3. import time
  4. from dataclasses import dataclass
  5. from .config import Config, RelayNode
  6. from .protocol import AUTH, PING, PONG, STATUS_OK, Frame, encode_json, read_frame, write_frame
  7. @dataclass
  8. class RelayScore:
  9. node: RelayNode
  10. latency_ms: float = 9999.0
  11. failures: int = 0
  12. last_ok: float = 0.0
  13. @property
  14. def score(self) -> float:
  15. penalty = self.failures * 250.0
  16. return self.latency_ms + penalty - (self.node.weight * 0.1)
  17. class Scheduler:
  18. def __init__(self, config: Config) -> None:
  19. self.config = config
  20. self.scores = {node.name: RelayScore(node=node) for node in config.relays}
  21. self._task: asyncio.Task | None = None
  22. async def start(self) -> None:
  23. if self._task is None:
  24. self._task = asyncio.create_task(self._probe_loop())
  25. async def _probe_loop(self) -> None:
  26. while True:
  27. await asyncio.gather(*(self._probe(node) for node in self.config.relays), return_exceptions=True)
  28. await asyncio.sleep(self.config.probe_interval)
  29. async def _probe(self, node: RelayNode) -> None:
  30. started = time.perf_counter()
  31. try:
  32. reader, writer = await asyncio.wait_for(asyncio.open_connection(node.host, node.port), timeout=3)
  33. await write_frame(writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": node.token})))
  34. auth = await asyncio.wait_for(read_frame(reader), timeout=3)
  35. if auth.kind != AUTH or auth.packet_id != STATUS_OK:
  36. raise ConnectionError(f"relay auth probe failed: {node.name}")
  37. await write_frame(writer, Frame(PING, 0, 0, 1, 0, b""))
  38. pong = await asyncio.wait_for(read_frame(reader), timeout=3)
  39. if pong.kind != PONG:
  40. raise ConnectionError(f"relay ping probe failed: {node.name}")
  41. writer.close()
  42. await writer.wait_closed()
  43. elapsed = (time.perf_counter() - started) * 1000
  44. score = self.scores[node.name]
  45. score.latency_ms = elapsed
  46. score.last_ok = time.time()
  47. score.failures = max(0, score.failures - 1)
  48. except Exception:
  49. self.scores[node.name].failures += 1
  50. def choose(self) -> list[RelayNode]:
  51. ordered = sorted(self.scores.values(), key=lambda item: item.score)
  52. if self.config.strategy == "broadcast":
  53. limit = min(len(ordered), max(1, self.config.redundancy))
  54. return [item.node for item in ordered[:limit]]
  55. if self.config.strategy == "backup":
  56. return [item.node for item in ordered[:1]]
  57. if self.config.strategy == "top4":
  58. limit = min(len(ordered), max(1, self.config.redundancy, 4))
  59. return [item.node for item in ordered[:limit]]
  60. if self.config.strategy == "top3":
  61. limit = min(len(ordered), max(1, self.config.redundancy, 3))
  62. return [item.node for item in ordered[:limit]]
  63. limit = min(len(ordered), max(1, self.config.redundancy, 2))
  64. return [item.node for item in ordered[:limit]]
  65. def snapshot(self) -> list[dict[str, object]]:
  66. ordered = sorted(self.scores.values(), key=lambda item: item.score)
  67. return [
  68. {
  69. "name": item.node.name,
  70. "host": item.node.host,
  71. "port": item.node.port,
  72. "latency_ms": round(item.latency_ms, 2),
  73. "failures": item.failures,
  74. "score": round(item.score, 2),
  75. }
  76. for item in ordered
  77. ]