# relay_client_tcp.py

  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. import json
  5. import random
  6. import socket
  7. from dataclasses import dataclass
  8. import time
  9. from typing import Awaitable, Callable, Dict
  10. from .config_tcp import TcpConfig, TcpRelayNode
  11. from .logging_utils import log_print as print
  12. from .protocol import AUTH, PING, PONG, STATUS_OK, TCP_CLOSE, Frame, encode_json, read_frame, write_frame
  13. from .scheduler_tcp import TcpScheduler
  14. FrameHandler = Callable[["TcpRelayConnection", Frame], Awaitable[None]]
  15. @dataclass
  16. class TcpRelayConnection:
  17. node: TcpRelayNode
  18. manager: "TcpRelayManager"
  19. reader: asyncio.StreamReader
  20. writer: asyncio.StreamWriter
  21. closed: bool = False
  22. handlers: Dict[tuple[int, int], FrameHandler] = None
  23. dispatch_tasks: Dict[tuple[int, int], asyncio.Task] = None
  24. pump_task: asyncio.Task | None = None
  25. keepalive_task: asyncio.Task | None = None
  26. last_pong_at: float = 0.0
  27. send_lock: asyncio.Lock | None = None
  28. closed_event: asyncio.Event | None = None
  29. dropped_frames: Dict[int, int] = None
  30. dropped_report_task: asyncio.Task | None = None
  31. def __post_init__(self) -> None:
  32. if self.handlers is None:
  33. self.handlers = {}
  34. if self.dispatch_tasks is None:
  35. self.dispatch_tasks = {}
  36. if self.send_lock is None:
  37. self.send_lock = asyncio.Lock()
  38. if self.closed_event is None:
  39. self.closed_event = asyncio.Event()
  40. if self.dropped_frames is None:
  41. self.dropped_frames = {}
  42. async def start(self) -> None:
  43. print(f"[edge] connecting relay name={self.node.name} addr={self.node.host}:{self.node.port}")
  44. await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
  45. frame = await read_frame(self.reader)
  46. if frame.kind != AUTH or frame.packet_id != STATUS_OK:
  47. raise ConnectionError(f"relay auth failed: {self.node.name}")
  48. if frame.payload:
  49. with contextlib.suppress(Exception):
  50. json.loads(frame.payload.decode("utf-8"))
  51. self.last_pong_at = time.monotonic()
  52. self.keepalive_task = asyncio.create_task(self._keepalive())
  53. self.pump_task = asyncio.create_task(self._pump())
  54. print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port} mode=tcp")
  55. async def _keepalive(self) -> None:
  56. try:
  57. while not self.closed:
  58. await asyncio.sleep(self.manager.config.relay_ping_interval)
  59. if self.closed:
  60. break
  61. if self.last_pong_at and time.monotonic() - self.last_pong_at > (self.manager.config.relay_ping_interval + self.manager.config.relay_ping_timeout):
  62. print(f"[edge] relay health timeout name={self.node.name} addr={self.node.host}:{self.node.port} timeout={self.manager.config.relay_ping_timeout}")
  63. await self.close()
  64. break
  65. await self.send(Frame(PING, 0, 0, 0, 0, b""))
  66. except asyncio.CancelledError:
  67. pass
  68. except Exception:
  69. await self.close()
  70. async def _pump(self) -> None:
  71. try:
  72. while True:
  73. frame = await read_frame(self.reader)
  74. if frame.kind == PONG:
  75. self.last_pong_at = time.monotonic()
  76. continue
  77. handler = self.handlers.get((frame.session_id, frame.stream_id))
  78. if handler:
  79. self._dispatch_frame(frame, handler)
  80. else:
  81. self._record_dropped_frame(frame.kind)
  82. except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, OSError):
  83. pass
  84. except Exception:
  85. pass
  86. finally:
  87. await self.close()
  88. def _dispatch_frame(self, frame: Frame, handler: FrameHandler) -> None:
  89. key = (frame.session_id, frame.stream_id)
  90. previous = self.dispatch_tasks.get(key)
  91. task = asyncio.create_task(self._run_handler(key, frame, handler, previous))
  92. self.dispatch_tasks[key] = task
  93. async def _run_handler(self, key: tuple[int, int], frame: Frame, handler: FrameHandler, previous: asyncio.Task | None) -> None:
  94. try:
  95. if previous is not None:
  96. with contextlib.suppress(Exception):
  97. await previous
  98. if self.closed:
  99. return
  100. await handler(self, frame)
  101. except asyncio.CancelledError:
  102. pass
  103. except Exception:
  104. if not self.closed:
  105. await self.close()
  106. finally:
  107. if self.dispatch_tasks.get(key) is asyncio.current_task():
  108. self.dispatch_tasks.pop(key, None)
  109. def _record_dropped_frame(self, kind: int) -> None:
  110. self.dropped_frames[kind] = self.dropped_frames.get(kind, 0) + 1
  111. if self.dropped_report_task is None or self.dropped_report_task.done():
  112. self.dropped_report_task = asyncio.create_task(self._report_dropped_frames())
  113. async def _report_dropped_frames(self) -> None:
  114. try:
  115. await asyncio.sleep(5)
  116. dropped = self.dropped_frames
  117. self.dropped_frames = {}
  118. if dropped:
  119. detail = ", ".join(f"kind={kind} count={count}" for kind, count in sorted(dropped.items()))
  120. print(f"[edge] relay frame dropped summary name={self.node.name} {detail}")
  121. except asyncio.CancelledError:
  122. pass
  123. async def send(self, frame: Frame) -> None:
  124. if self.closed:
  125. raise ConnectionError(f"relay closed: {self.node.name}")
  126. assert self.send_lock is not None
  127. async with self.send_lock:
  128. if self.closed:
  129. raise ConnectionError(f"relay closed: {self.node.name}")
  130. await write_frame(self.writer, frame)
  131. def bind(self, session_id: int, stream_id: int, handler: FrameHandler) -> None:
  132. self.handlers[(session_id, stream_id)] = handler
  133. def unbind(self, session_id: int, stream_id: int) -> None:
  134. self.handlers.pop((session_id, stream_id), None)
  135. task = self.dispatch_tasks.pop((session_id, stream_id), None)
  136. if task is not None:
  137. task.cancel()
  138. async def close(self) -> None:
  139. if self.closed:
  140. return
  141. self.closed = True
  142. assert self.closed_event is not None
  143. self.closed_event.set()
  144. handlers = list(self.handlers.items())
  145. self.handlers.clear()
  146. dispatch_tasks = list(self.dispatch_tasks.values())
  147. self.dispatch_tasks.clear()
  148. self.manager.on_closed(self)
  149. for (session_id, stream_id), handler in handlers:
  150. with contextlib.suppress(Exception):
  151. await handler(self, Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
  152. for task in dispatch_tasks:
  153. task.cancel()
  154. for task in dispatch_tasks:
  155. with contextlib.suppress(Exception):
  156. await task
  157. if self.dropped_report_task and self.dropped_report_task is not asyncio.current_task():
  158. self.dropped_report_task.cancel()
  159. with contextlib.suppress(Exception):
  160. await self.dropped_report_task
  161. if self.pump_task and self.pump_task is not asyncio.current_task():
  162. self.pump_task.cancel()
  163. with contextlib.suppress(Exception):
  164. await self.pump_task
  165. if self.keepalive_task and self.keepalive_task is not asyncio.current_task():
  166. self.keepalive_task.cancel()
  167. with contextlib.suppress(Exception):
  168. await self.keepalive_task
  169. self.writer.close()
  170. with contextlib.suppress(Exception):
  171. await self.writer.wait_closed()
  172. class TcpRelayManager:
  173. def __init__(self, config: TcpConfig) -> None:
  174. self.config = config
  175. self.scheduler = TcpScheduler(config)
  176. self.connections: Dict[str, TcpRelayConnection] = {}
  177. self.tasks: list[asyncio.Task] = []
  178. self._logged_attempts: set[str] = set()
  179. async def start(self) -> None:
  180. await self.scheduler.start()
  181. for node in self.config.relays:
  182. self.tasks.append(asyncio.create_task(self._maintain(node)))
  183. async def _maintain(self, node: TcpRelayNode) -> None:
  184. backoff = self.config.relay_reconnect_delay
  185. reconnect_attempt = 1
  186. failure_streak = 0
  187. healthy_since: float | None = None
  188. while True:
  189. current = self.connections.get(node.name)
  190. if current is not None and not current.closed:
  191. assert current.closed_event is not None
  192. await current.closed_event.wait()
  193. continue
  194. while True:
  195. try:
  196. marker = f"{node.name}:{reconnect_attempt}:{round(backoff, 1)}"
  197. if marker not in self._logged_attempts:
  198. self._logged_attempts.add(marker)
  199. print(f"[edge] relay reconnect attempt name={node.name} addr={node.host}:{node.port} attempt={reconnect_attempt} backoff={backoff:.1f}s")
  200. reader, writer = await asyncio.wait_for(asyncio.open_connection(node.host, node.port), timeout=self.config.relay_open_timeout)
  201. connection = TcpRelayConnection(node=node, manager=self, reader=reader, writer=writer)
  202. sock = writer.get_extra_info("socket")
  203. if sock is not None and self.config.relay_tcp_nodelay:
  204. with contextlib.suppress(OSError):
  205. sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  206. await connection.start()
  207. self.connections[node.name] = connection
  208. healthy_since = time.monotonic()
  209. failure_streak = 0
  210. reconnect_attempt = 1
  211. assert connection.closed_event is not None
  212. await connection.closed_event.wait()
  213. if healthy_since is not None:
  214. healthy_runtime = time.monotonic() - healthy_since
  215. if healthy_runtime >= self.config.relay_ping_interval + self.config.relay_ping_timeout:
  216. backoff = self.config.relay_reconnect_delay
  217. else:
  218. backoff = min(self.config.relay_reconnect_max_delay, max(self.config.relay_reconnect_delay, backoff * 2))
  219. else:
  220. backoff = min(self.config.relay_reconnect_max_delay, max(self.config.relay_reconnect_delay, backoff * 2))
  221. break
  222. except asyncio.CancelledError:
  223. raise
  224. except Exception as exc:
  225. failure_streak += 1
  226. if failure_streak >= self.config.relay_reconnect_attempts:
  227. cooldown = min(
  228. self.config.relay_reconnect_max_delay,
  229. max(self.config.relay_reconnect_delay, backoff * 2),
  230. )
  231. print(f"[edge] relay cooldown name={node.name} addr={node.host}:{node.port} failures={failure_streak} cooldown={cooldown:.1f}s")
  232. await asyncio.sleep(cooldown)
  233. backoff = min(self.config.relay_reconnect_max_delay, cooldown * 2)
  234. failure_streak = 0
  235. reconnect_attempt = 1
  236. continue
  237. if reconnect_attempt == 1 or failure_streak >= self.config.relay_reconnect_attempts:
  238. print(f"[edge] relay connect failed name={node.name} addr={node.host}:{node.port} attempt={reconnect_attempt} error={exc!r}")
  239. jitter = random.uniform(0, min(1.0, backoff * 0.2))
  240. await asyncio.sleep(backoff + jitter)
  241. backoff = min(self.config.relay_reconnect_max_delay, max(self.config.relay_reconnect_delay, backoff * 2))
  242. reconnect_attempt += 1
  243. def on_closed(self, connection: TcpRelayConnection) -> None:
  244. current = self.connections.get(connection.node.name)
  245. if current is connection:
  246. self.connections.pop(connection.node.name, None)
  247. def available(self) -> list[TcpRelayConnection]:
  248. chosen = {node.name for node in self.scheduler.choose()}
  249. preferred = [self.connections[name] for name in chosen if name in self.connections and not self.connections[name].closed]
  250. if preferred:
  251. return preferred
  252. return [conn for conn in self.connections.values() if not conn.closed]
  253. def snapshot(self) -> list[dict[str, object]]:
  254. data = self.scheduler.snapshot()
  255. online = {name for name, conn in self.connections.items() if not conn.closed}
  256. for item in data:
  257. item["online"] = item["name"] in online
  258. return data