relay_client.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. import socket
  5. from dataclasses import dataclass
  6. from typing import Awaitable, Callable, Dict
  7. from .config import Config, RelayNode
  8. from .protocol import AUTH, STATUS_OK, Frame, encode_json, read_frame, write_frame
  9. from .scheduler import Scheduler
  10. FrameHandler = Callable[["RelayConnection", Frame], Awaitable[None]]
  11. @dataclass
  12. class RelayConnection:
  13. node: RelayNode
  14. manager: "RelayManager"
  15. reader: asyncio.StreamReader
  16. writer: asyncio.StreamWriter
  17. closed: bool = False
  18. handlers: Dict[tuple[int, int], FrameHandler] = None
  19. pump_task: asyncio.Task | None = None
  20. def __post_init__(self) -> None:
  21. if self.handlers is None:
  22. self.handlers = {}
  23. async def start(self) -> None:
  24. print(f"[edge] connecting relay name={self.node.name} addr={self.node.host}:{self.node.port}")
  25. await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
  26. frame = await read_frame(self.reader)
  27. if frame.kind != AUTH or frame.packet_id != STATUS_OK:
  28. raise ConnectionError(f"relay auth failed: {self.node.name}")
  29. print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port}")
  30. self.pump_task = asyncio.create_task(self._pump())
  31. async def _pump(self) -> None:
  32. try:
  33. while True:
  34. frame = await read_frame(self.reader)
  35. handler = self.handlers.get((frame.session_id, frame.stream_id))
  36. if handler:
  37. await handler(self, frame)
  38. except asyncio.IncompleteReadError:
  39. print(f"[edge] relay disconnected name={self.node.name} eof=true")
  40. except Exception as exc:
  41. print(f"[edge] relay pump error name={self.node.name} error={exc!r}")
  42. finally:
  43. await self.close()
  44. async def send(self, frame: Frame) -> None:
  45. if self.closed:
  46. raise ConnectionError(f"relay closed: {self.node.name}")
  47. await write_frame(self.writer, frame)
  48. def bind(self, session_id: int, stream_id: int, handler: FrameHandler) -> None:
  49. self.handlers[(session_id, stream_id)] = handler
  50. def unbind(self, session_id: int, stream_id: int) -> None:
  51. self.handlers.pop((session_id, stream_id), None)
  52. async def close(self) -> None:
  53. if self.closed:
  54. return
  55. self.closed = True
  56. self.manager.on_closed(self)
  57. self.writer.close()
  58. with contextlib.suppress(Exception):
  59. await self.writer.wait_closed()
  60. class RelayManager:
  61. def __init__(self, config: Config) -> None:
  62. self.config = config
  63. self.scheduler = Scheduler(config)
  64. self.connections: Dict[str, RelayConnection] = {}
  65. self.tasks: list[asyncio.Task] = []
  66. async def start(self) -> None:
  67. await self.scheduler.start()
  68. for node in self.config.relays:
  69. self.tasks.append(asyncio.create_task(self._maintain(node)))
  70. async def _maintain(self, node: RelayNode) -> None:
  71. while True:
  72. if node.name in self.connections and not self.connections[node.name].closed:
  73. await asyncio.sleep(2)
  74. continue
  75. try:
  76. reader, writer = await asyncio.wait_for(asyncio.open_connection(node.host, node.port), timeout=self.config.relay_open_timeout)
  77. connection = RelayConnection(node=node, manager=self, reader=reader, writer=writer)
  78. sock = writer.get_extra_info("socket")
  79. if sock is not None and self.config.relay_tcp_nodelay:
  80. with contextlib.suppress(OSError):
  81. sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  82. await connection.start()
  83. self.connections[node.name] = connection
  84. await connection.pump_task
  85. except Exception as exc:
  86. print(f"[edge] relay connect failed name={node.name} addr={node.host}:{node.port} error={exc!r}")
  87. await asyncio.sleep(self.config.relay_reconnect_delay)
  88. def on_closed(self, connection: RelayConnection) -> None:
  89. current = self.connections.get(connection.node.name)
  90. if current is connection:
  91. self.connections.pop(connection.node.name, None)
  92. def available(self) -> list[RelayConnection]:
  93. chosen = {node.name for node in self.scheduler.choose()}
  94. preferred = [self.connections[name] for name in chosen if name in self.connections and not self.connections[name].closed]
  95. if preferred:
  96. return preferred
  97. return [conn for conn in self.connections.values() if not conn.closed]
  98. def snapshot(self) -> list[dict[str, object]]:
  99. data = self.scheduler.snapshot()
  100. online = {name for name, conn in self.connections.items() if not conn.closed}
  101. for item in data:
  102. item["online"] = item["name"] in online
  103. return data