from __future__ import annotations import asyncio import contextlib import json import random import socket from dataclasses import dataclass import time from typing import Awaitable, Callable, Dict from .config_tcp import TcpConfig, TcpRelayNode from .logging_utils import log_print as print from .protocol import AUTH, PING, PONG, STATUS_OK, TCP_CLOSE, Frame, encode_json, read_frame, write_frame from .scheduler_tcp import TcpScheduler FrameHandler = Callable[["TcpRelayConnection", Frame], Awaitable[None]] @dataclass class TcpRelayConnection: node: TcpRelayNode manager: "TcpRelayManager" reader: asyncio.StreamReader writer: asyncio.StreamWriter closed: bool = False handlers: Dict[tuple[int, int], FrameHandler] = None dispatch_tasks: Dict[tuple[int, int], asyncio.Task] = None pump_task: asyncio.Task | None = None keepalive_task: asyncio.Task | None = None last_pong_at: float = 0.0 send_lock: asyncio.Lock | None = None closed_event: asyncio.Event | None = None dropped_frames: Dict[int, int] = None dropped_report_task: asyncio.Task | None = None def __post_init__(self) -> None: if self.handlers is None: self.handlers = {} if self.dispatch_tasks is None: self.dispatch_tasks = {} if self.send_lock is None: self.send_lock = asyncio.Lock() if self.closed_event is None: self.closed_event = asyncio.Event() if self.dropped_frames is None: self.dropped_frames = {} async def start(self) -> None: print(f"[edge] connecting relay name={self.node.name} addr={self.node.host}:{self.node.port}") await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token}))) frame = await read_frame(self.reader) if frame.kind != AUTH or frame.packet_id != STATUS_OK: raise ConnectionError(f"relay auth failed: {self.node.name}") if frame.payload: with contextlib.suppress(Exception): json.loads(frame.payload.decode("utf-8")) self.last_pong_at = time.monotonic() self.keepalive_task = asyncio.create_task(self._keepalive()) self.pump_task = asyncio.create_task(self._pump()) print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port} mode=tcp") async def _keepalive(self) -> None: try: while not self.closed: await asyncio.sleep(self.manager.config.relay_ping_interval) if self.closed: break if self.last_pong_at and time.monotonic() - self.last_pong_at > (self.manager.config.relay_ping_interval + self.manager.config.relay_ping_timeout): print(f"[edge] relay health timeout name={self.node.name} addr={self.node.host}:{self.node.port} timeout={self.manager.config.relay_ping_timeout}") await self.close() break await self.send(Frame(PING, 0, 0, 0, 0, b"")) except asyncio.CancelledError: pass except Exception: await self.close() async def _pump(self) -> None: try: while True: frame = await read_frame(self.reader) if frame.kind == PONG: self.last_pong_at = time.monotonic() continue handler = self.handlers.get((frame.session_id, frame.stream_id)) if handler: self._dispatch_frame(frame, handler) else: self._record_dropped_frame(frame.kind) except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, OSError): pass except Exception: pass finally: await self.close() def _dispatch_frame(self, frame: Frame, handler: FrameHandler) -> None: key = (frame.session_id, frame.stream_id) previous = self.dispatch_tasks.get(key) task = asyncio.create_task(self._run_handler(key, frame, handler, previous)) self.dispatch_tasks[key] = task async def _run_handler(self, key: tuple[int, int], frame: Frame, handler: FrameHandler, previous: asyncio.Task | None) -> None: try: if previous is not None: with contextlib.suppress(Exception): await previous if self.closed: return await handler(self, frame) except asyncio.CancelledError: pass except Exception: if not self.closed: await self.close() finally: if self.dispatch_tasks.get(key) is asyncio.current_task(): self.dispatch_tasks.pop(key, None) def _record_dropped_frame(self, kind: int) -> None: self.dropped_frames[kind] = self.dropped_frames.get(kind, 0) + 1 if self.dropped_report_task is None or self.dropped_report_task.done(): self.dropped_report_task = asyncio.create_task(self._report_dropped_frames()) async def _report_dropped_frames(self) -> None: try: await asyncio.sleep(5) dropped = self.dropped_frames self.dropped_frames = {} if dropped: detail = ", ".join(f"kind={kind} count={count}" for kind, count in sorted(dropped.items())) print(f"[edge] relay frame dropped summary name={self.node.name} {detail}") except asyncio.CancelledError: pass async def send(self, frame: Frame) -> None: if self.closed: raise ConnectionError(f"relay closed: {self.node.name}") assert self.send_lock is not None async with self.send_lock: if self.closed: raise ConnectionError(f"relay closed: {self.node.name}") await write_frame(self.writer, frame) def bind(self, session_id: int, stream_id: int, handler: FrameHandler) -> None: self.handlers[(session_id, stream_id)] = handler def unbind(self, session_id: int, stream_id: int) -> None: self.handlers.pop((session_id, stream_id), None) task = self.dispatch_tasks.pop((session_id, stream_id), None) if task is not None: task.cancel() async def close(self) -> None: if self.closed: return self.closed = True assert self.closed_event is not None self.closed_event.set() handlers = list(self.handlers.items()) self.handlers.clear() dispatch_tasks = list(self.dispatch_tasks.values()) self.dispatch_tasks.clear() self.manager.on_closed(self) for (session_id, stream_id), handler in handlers: with contextlib.suppress(Exception): await handler(self, Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b"")) for task in dispatch_tasks: task.cancel() for task in dispatch_tasks: with contextlib.suppress(Exception): await task if self.dropped_report_task and self.dropped_report_task is not asyncio.current_task(): self.dropped_report_task.cancel() with contextlib.suppress(Exception): await self.dropped_report_task if self.pump_task and self.pump_task is not asyncio.current_task(): self.pump_task.cancel() with contextlib.suppress(Exception): await self.pump_task if self.keepalive_task and self.keepalive_task is not asyncio.current_task(): self.keepalive_task.cancel() with contextlib.suppress(Exception): await self.keepalive_task self.writer.close() with contextlib.suppress(Exception): await self.writer.wait_closed() class TcpRelayManager: def __init__(self, config: TcpConfig) -> None: self.config = config self.scheduler = TcpScheduler(config) self.connections: Dict[str, TcpRelayConnection] = {} self.tasks: list[asyncio.Task] = [] self._logged_attempts: set[str] = set() async def start(self) -> None: await self.scheduler.start() for node in self.config.relays: self.tasks.append(asyncio.create_task(self._maintain(node))) async def _maintain(self, node: TcpRelayNode) -> None: backoff = self.config.relay_reconnect_delay reconnect_attempt = 1 failure_streak = 0 healthy_since: float | None = None while True: current = self.connections.get(node.name) if current is not None and not current.closed: assert current.closed_event is not None await current.closed_event.wait() continue while True: try: marker = f"{node.name}:{reconnect_attempt}:{round(backoff, 1)}" if marker not in self._logged_attempts: self._logged_attempts.add(marker) print(f"[edge] relay reconnect attempt name={node.name} addr={node.host}:{node.port} attempt={reconnect_attempt} backoff={backoff:.1f}s") reader, writer = await asyncio.wait_for(asyncio.open_connection(node.host, node.port), timeout=self.config.relay_open_timeout) connection = TcpRelayConnection(node=node, manager=self, reader=reader, writer=writer) sock = writer.get_extra_info("socket") if sock is not None and self.config.relay_tcp_nodelay: with contextlib.suppress(OSError): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) await connection.start() self.connections[node.name] = connection healthy_since = time.monotonic() failure_streak = 0 reconnect_attempt = 1 assert connection.closed_event is not None await connection.closed_event.wait() if healthy_since is not None: healthy_runtime = time.monotonic() - healthy_since if healthy_runtime >= self.config.relay_ping_interval + self.config.relay_ping_timeout: backoff = self.config.relay_reconnect_delay else: backoff = min(self.config.relay_reconnect_max_delay, max(self.config.relay_reconnect_delay, backoff * 2)) else: backoff = min(self.config.relay_reconnect_max_delay, max(self.config.relay_reconnect_delay, backoff * 2)) break except asyncio.CancelledError: raise except Exception as exc: failure_streak += 1 if failure_streak >= self.config.relay_reconnect_attempts: cooldown = min( self.config.relay_reconnect_max_delay, max(self.config.relay_reconnect_delay, backoff * 2), ) print(f"[edge] relay cooldown name={node.name} addr={node.host}:{node.port} failures={failure_streak} cooldown={cooldown:.1f}s") await asyncio.sleep(cooldown) backoff = min(self.config.relay_reconnect_max_delay, cooldown * 2) failure_streak = 0 reconnect_attempt = 1 continue if reconnect_attempt == 1 or failure_streak >= self.config.relay_reconnect_attempts: print(f"[edge] relay connect failed name={node.name} addr={node.host}:{node.port} attempt={reconnect_attempt} error={exc!r}") jitter = random.uniform(0, min(1.0, backoff * 0.2)) await asyncio.sleep(backoff + jitter) backoff = min(self.config.relay_reconnect_max_delay, max(self.config.relay_reconnect_delay, backoff * 2)) reconnect_attempt += 1 def on_closed(self, connection: TcpRelayConnection) -> None: current = self.connections.get(connection.node.name) if current is connection: self.connections.pop(connection.node.name, None) def available(self) -> list[TcpRelayConnection]: chosen = {node.name for node in self.scheduler.choose()} preferred = [self.connections[name] for name in chosen if name in self.connections and not self.connections[name].closed] if preferred: return preferred return [conn for conn in self.connections.values() if not conn.closed] def snapshot(self) -> list[dict[str, object]]: data = self.scheduler.snapshot() online = {name for name, conn in self.connections.items() if not conn.closed} for item in data: item["online"] = item["name"] in online return data