relay_client.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. import socket
  5. from dataclasses import dataclass
  6. import time
  7. from typing import Awaitable, Callable, Dict
  8. from .config import Config, RelayNode
  9. from .protocol import AUTH, PING, PONG, STATUS_OK, TCP_CLOSE, Frame, encode_json, read_frame, write_frame
  10. from .scheduler import Scheduler
  11. FrameHandler = Callable[["RelayConnection", Frame], Awaitable[None]]
  12. @dataclass
  13. class RelayConnection:
  14. node: RelayNode
  15. manager: "RelayManager"
  16. reader: asyncio.StreamReader
  17. writer: asyncio.StreamWriter
  18. closed: bool = False
  19. handlers: Dict[tuple[int, int], FrameHandler] = None
  20. dispatch_tasks: Dict[tuple[int, int], asyncio.Task] = None
  21. pump_task: asyncio.Task | None = None
  22. keepalive_task: asyncio.Task | None = None
  23. last_pong_at: float = 0.0
  24. send_lock: asyncio.Lock | None = None
  25. def __post_init__(self) -> None:
  26. if self.handlers is None:
  27. self.handlers = {}
  28. if self.dispatch_tasks is None:
  29. self.dispatch_tasks = {}
  30. if self.send_lock is None:
  31. self.send_lock = asyncio.Lock()
  32. async def start(self) -> None:
  33. print(f"[edge] connecting relay name={self.node.name} addr={self.node.host}:{self.node.port}")
  34. await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
  35. frame = await read_frame(self.reader)
  36. if frame.kind != AUTH or frame.packet_id != STATUS_OK:
  37. raise ConnectionError(f"relay auth failed: {self.node.name}")
  38. print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port}")
  39. self.last_pong_at = time.monotonic()
  40. self.keepalive_task = asyncio.create_task(self._keepalive())
  41. self.pump_task = asyncio.create_task(self._pump())
  42. async def _keepalive(self) -> None:
  43. try:
  44. while not self.closed:
  45. await asyncio.sleep(self.manager.config.relay_ping_interval)
  46. if self.closed:
  47. break
  48. if self.last_pong_at and time.monotonic() - self.last_pong_at > (self.manager.config.relay_ping_interval + self.manager.config.relay_ping_timeout):
  49. print(f"[edge] relay health timeout name={self.node.name} addr={self.node.host}:{self.node.port} timeout={self.manager.config.relay_ping_timeout}")
  50. await self.close()
  51. break
  52. await self.send(Frame(PING, 0, 0, 0, 0, b""))
  53. except asyncio.CancelledError:
  54. pass
  55. except Exception:
  56. await self.close()
  57. async def _pump(self) -> None:
  58. try:
  59. while True:
  60. frame = await read_frame(self.reader)
  61. if frame.kind == PONG:
  62. self.last_pong_at = time.monotonic()
  63. continue
  64. handler = self.handlers.get((frame.session_id, frame.stream_id))
  65. if handler:
  66. self._dispatch_frame(frame, handler)
  67. else:
  68. print(f"[edge] relay frame dropped name={self.node.name} session={frame.session_id} stream={frame.stream_id} kind={frame.kind}")
  69. except asyncio.IncompleteReadError:
  70. print(f"[edge] relay disconnected name={self.node.name} eof=true")
  71. except Exception as exc:
  72. print(f"[edge] relay pump error name={self.node.name} error={exc!r}")
  73. finally:
  74. await self.close()
  75. def _dispatch_frame(self, frame: Frame, handler: FrameHandler) -> None:
  76. key = (frame.session_id, frame.stream_id)
  77. previous = self.dispatch_tasks.get(key)
  78. task = asyncio.create_task(self._run_handler(key, frame, handler, previous))
  79. self.dispatch_tasks[key] = task
  80. async def _run_handler(self, key: tuple[int, int], frame: Frame, handler: FrameHandler, previous: asyncio.Task | None) -> None:
  81. try:
  82. if previous is not None:
  83. with contextlib.suppress(Exception):
  84. await previous
  85. if self.closed:
  86. return
  87. await handler(self, frame)
  88. except asyncio.CancelledError:
  89. pass
  90. except Exception:
  91. if not self.closed:
  92. await self.close()
  93. finally:
  94. if self.dispatch_tasks.get(key) is asyncio.current_task():
  95. self.dispatch_tasks.pop(key, None)
  96. async def send(self, frame: Frame) -> None:
  97. if self.closed:
  98. raise ConnectionError(f"relay closed: {self.node.name}")
  99. assert self.send_lock is not None
  100. async with self.send_lock:
  101. if self.closed:
  102. raise ConnectionError(f"relay closed: {self.node.name}")
  103. await write_frame(self.writer, frame)
  104. def bind(self, session_id: int, stream_id: int, handler: FrameHandler) -> None:
  105. self.handlers[(session_id, stream_id)] = handler
  106. def unbind(self, session_id: int, stream_id: int) -> None:
  107. self.handlers.pop((session_id, stream_id), None)
  108. task = self.dispatch_tasks.pop((session_id, stream_id), None)
  109. if task is not None:
  110. task.cancel()
  111. async def close(self) -> None:
  112. if self.closed:
  113. return
  114. self.closed = True
  115. handlers = list(self.handlers.items())
  116. self.handlers.clear()
  117. dispatch_tasks = list(self.dispatch_tasks.values())
  118. self.dispatch_tasks.clear()
  119. self.manager.on_closed(self)
  120. for (session_id, stream_id), handler in handlers:
  121. with contextlib.suppress(Exception):
  122. await handler(self, Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
  123. for task in dispatch_tasks:
  124. task.cancel()
  125. for task in dispatch_tasks:
  126. with contextlib.suppress(Exception):
  127. await task
  128. if self.pump_task and self.pump_task is not asyncio.current_task():
  129. self.pump_task.cancel()
  130. with contextlib.suppress(Exception):
  131. await self.pump_task
  132. if self.keepalive_task and self.keepalive_task is not asyncio.current_task():
  133. self.keepalive_task.cancel()
  134. with contextlib.suppress(Exception):
  135. await self.keepalive_task
  136. self.writer.close()
  137. with contextlib.suppress(Exception):
  138. await self.writer.wait_closed()
  139. class RelayManager:
  140. def __init__(self, config: Config) -> None:
  141. self.config = config
  142. self.scheduler = Scheduler(config)
  143. self.connections: Dict[str, RelayConnection] = {}
  144. self.tasks: list[asyncio.Task] = []
  145. async def start(self) -> None:
  146. await self.scheduler.start()
  147. for node in self.config.relays:
  148. self.tasks.append(asyncio.create_task(self._maintain(node)))
  149. async def _maintain(self, node: RelayNode) -> None:
  150. while True:
  151. if node.name in self.connections and not self.connections[node.name].closed:
  152. await asyncio.sleep(2)
  153. continue
  154. connected = False
  155. for attempt in range(1, self.config.relay_reconnect_attempts + 1):
  156. try:
  157. print(f"[edge] relay reconnect attempt name={node.name} addr={node.host}:{node.port} attempt={attempt}/{self.config.relay_reconnect_attempts}")
  158. reader, writer = await asyncio.wait_for(asyncio.open_connection(node.host, node.port), timeout=self.config.relay_open_timeout)
  159. connection = RelayConnection(node=node, manager=self, reader=reader, writer=writer)
  160. sock = writer.get_extra_info("socket")
  161. if sock is not None and self.config.relay_tcp_nodelay:
  162. with contextlib.suppress(OSError):
  163. sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  164. await connection.start()
  165. self.connections[node.name] = connection
  166. connected = True
  167. await connection.pump_task
  168. break
  169. except Exception as exc:
  170. print(f"[edge] relay connect failed name={node.name} addr={node.host}:{node.port} attempt={attempt}/{self.config.relay_reconnect_attempts} error={exc!r}")
  171. if attempt < self.config.relay_reconnect_attempts:
  172. await asyncio.sleep(self.config.relay_reconnect_delay)
  173. if not connected:
  174. print(f"[edge] relay reconnect exhausted name={node.name} addr={node.host}:{node.port} attempts={self.config.relay_reconnect_attempts}")
  175. await asyncio.sleep(self.config.relay_reconnect_delay)
  176. def on_closed(self, connection: RelayConnection) -> None:
  177. current = self.connections.get(connection.node.name)
  178. if current is connection:
  179. self.connections.pop(connection.node.name, None)
  180. def available(self) -> list[RelayConnection]:
  181. chosen = {node.name for node in self.scheduler.choose()}
  182. preferred = [self.connections[name] for name in chosen if name in self.connections and not self.connections[name].closed]
  183. if preferred:
  184. return preferred
  185. return [conn for conn in self.connections.values() if not conn.closed]
  186. def snapshot(self) -> list[dict[str, object]]:
  187. data = self.scheduler.snapshot()
  188. online = {name for name, conn in self.connections.items() if not conn.closed}
  189. for item in data:
  190. item["online"] = item["name"] in online
  191. return data