relay_client.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. from dataclasses import dataclass
  5. from typing import Awaitable, Callable, Dict
  6. from .config import Config, RelayNode
  7. from .protocol import AUTH, STATUS_OK, Frame, encode_json, read_frame, write_frame
  8. from .scheduler import Scheduler
  9. FrameHandler = Callable[["RelayConnection", Frame], Awaitable[None]]
  10. @dataclass
  11. class RelayConnection:
  12. node: RelayNode
  13. manager: "RelayManager"
  14. reader: asyncio.StreamReader
  15. writer: asyncio.StreamWriter
  16. closed: bool = False
  17. handlers: Dict[tuple[int, int], FrameHandler] = None
  18. pump_task: asyncio.Task | None = None
  19. def __post_init__(self) -> None:
  20. if self.handlers is None:
  21. self.handlers = {}
  22. async def start(self) -> None:
  23. print(f"[edge] connecting relay name={self.node.name} addr={self.node.host}:{self.node.port}")
  24. await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
  25. frame = await read_frame(self.reader)
  26. if frame.kind != AUTH or frame.packet_id != STATUS_OK:
  27. raise ConnectionError(f"relay auth failed: {self.node.name}")
  28. print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port}")
  29. self.pump_task = asyncio.create_task(self._pump())
  30. async def _pump(self) -> None:
  31. try:
  32. while True:
  33. frame = await read_frame(self.reader)
  34. handler = self.handlers.get((frame.session_id, frame.stream_id))
  35. if handler:
  36. await handler(self, frame)
  37. except asyncio.IncompleteReadError:
  38. print(f"[edge] relay disconnected name={self.node.name} eof=true")
  39. except Exception as exc:
  40. print(f"[edge] relay pump error name={self.node.name} error={exc!r}")
  41. finally:
  42. await self.close()
  43. async def send(self, frame: Frame) -> None:
  44. if self.closed:
  45. raise ConnectionError(f"relay closed: {self.node.name}")
  46. await write_frame(self.writer, frame)
  47. def bind(self, session_id: int, stream_id: int, handler: FrameHandler) -> None:
  48. self.handlers[(session_id, stream_id)] = handler
  49. def unbind(self, session_id: int, stream_id: int) -> None:
  50. self.handlers.pop((session_id, stream_id), None)
  51. async def close(self) -> None:
  52. if self.closed:
  53. return
  54. self.closed = True
  55. self.manager.on_closed(self)
  56. self.writer.close()
  57. with contextlib.suppress(Exception):
  58. await self.writer.wait_closed()
  59. class RelayManager:
  60. def __init__(self, config: Config) -> None:
  61. self.config = config
  62. self.scheduler = Scheduler(config)
  63. self.connections: Dict[str, RelayConnection] = {}
  64. self.tasks: list[asyncio.Task] = []
  65. async def start(self) -> None:
  66. await self.scheduler.start()
  67. for node in self.config.relays:
  68. self.tasks.append(asyncio.create_task(self._maintain(node)))
  69. async def _maintain(self, node: RelayNode) -> None:
  70. while True:
  71. if node.name in self.connections and not self.connections[node.name].closed:
  72. await asyncio.sleep(2)
  73. continue
  74. try:
  75. reader, writer = await asyncio.open_connection(node.host, node.port)
  76. connection = RelayConnection(node=node, manager=self, reader=reader, writer=writer)
  77. await connection.start()
  78. self.connections[node.name] = connection
  79. await connection.pump_task
  80. except Exception as exc:
  81. print(f"[edge] relay connect failed name={node.name} addr={node.host}:{node.port} error={exc!r}")
  82. await asyncio.sleep(3)
  83. def on_closed(self, connection: RelayConnection) -> None:
  84. current = self.connections.get(connection.node.name)
  85. if current is connection:
  86. self.connections.pop(connection.node.name, None)
  87. def available(self) -> list[RelayConnection]:
  88. chosen = {node.name for node in self.scheduler.choose()}
  89. preferred = [self.connections[name] for name in chosen if name in self.connections and not self.connections[name].closed]
  90. if preferred:
  91. return preferred
  92. return [conn for conn in self.connections.values() if not conn.closed]
  93. def snapshot(self) -> list[dict[str, object]]:
  94. data = self.scheduler.snapshot()
  95. online = {name for name, conn in self.connections.items() if not conn.closed}
  96. for item in data:
  97. item["online"] = item["name"] in online
  98. return data