from __future__ import annotations

import asyncio
import contextlib
import random
import socket
import time
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict

from .config import Config, RelayNode
from .protocol import AUTH, PING, PONG, STATUS_OK, TCP_CLOSE, Frame, encode_json, read_frame, write_frame
from .scheduler import Scheduler

FrameHandler = Callable[["RelayConnection", Frame], Awaitable[None]]
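

# Illustrative sketch (hypothetical helper, not called anywhere in this module):
# the health rule that RelayConnection._keepalive applies at each ping tick.
# A relay is treated as dead once more than ping_interval + ping_timeout seconds
# have passed since its last PONG. That window allows one full interval for the
# next PONG to arrive, plus ping_timeout of slack. A last_pong_at of 0.0 (the
# dataclass default, before start() runs) disables the check.
def _sketch_pong_overdue(last_pong_at: float, now: float, ping_interval: float, ping_timeout: float) -> bool:
    if not last_pong_at:
        return False  # mirrors the `self.last_pong_at and ...` guard in _keepalive
    return now - last_pong_at > ping_interval + ping_timeout


# Example: last PONG at t=100 s, ping_interval=10 s, ping_timeout=5 s.
# At t=114 s the relay is still healthy (14 <= 15); at t=120 s it is overdue (20 > 15).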


@dataclass
class RelayConnection:
    """One authenticated TCP connection to a relay node, carrying many (session_id, stream_id) streams."""

    node: RelayNode
    manager: "RelayManager"
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    closed: bool = False
    handlers: Dict[tuple[int, int], FrameHandler] = None
    dispatch_tasks: Dict[tuple[int, int], asyncio.Task] = None
    pump_task: asyncio.Task | None = None
    keepalive_task: asyncio.Task | None = None
    last_pong_at: float = 0.0
    send_lock: asyncio.Lock | None = None
    closed_event: asyncio.Event | None = None
    dropped_frames: Dict[int, int] = None
    dropped_report_task: asyncio.Task | None = None
    close_task: asyncio.Task | None = None

    def __post_init__(self) -> None:
        # Create the mutable defaults per instance instead of sharing one object.
        if self.handlers is None:
            self.handlers = {}
        if self.dispatch_tasks is None:
            self.dispatch_tasks = {}
        if self.send_lock is None:
            self.send_lock = asyncio.Lock()
        if self.closed_event is None:
            self.closed_event = asyncio.Event()
        if self.dropped_frames is None:
            self.dropped_frames = {}

    async def start(self) -> None:
        """Authenticate with the relay, then start the keepalive and read-pump tasks."""
        print(f"[edge] connecting relay name={self.node.name} addr={self.node.host}:{self.node.port}")
        await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
        frame = await read_frame(self.reader)
        # The relay's AUTH reply carries its status code in packet_id.
        if frame.kind != AUTH or frame.packet_id != STATUS_OK:
            raise ConnectionError(f"relay auth failed: {self.node.name}")
        print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port}")
        self.last_pong_at = time.monotonic()
        self.keepalive_task = asyncio.create_task(self._keepalive())
        self.pump_task = asyncio.create_task(self._pump())

    async def _keepalive(self) -> None:
        # Send a PING every relay_ping_interval seconds. Close the connection once no
        # PONG has arrived for relay_ping_interval + relay_ping_timeout seconds.
        try:
            while not self.closed:
                await asyncio.sleep(self.manager.config.relay_ping_interval)
                if self.closed:
                    break
                timeout = self.manager.config.relay_ping_timeout
                deadline = self.manager.config.relay_ping_interval + timeout
                if self.last_pong_at and time.monotonic() - self.last_pong_at > deadline:
                    print(f"[edge] relay health timeout name={self.node.name} addr={self.node.host}:{self.node.port} timeout={timeout}")
                    await self.close()
                    break
                await self.send(Frame(PING, 0, 0, 0, 0, b""))
        except asyncio.CancelledError:
            pass
        except Exception:
            # A PING that cannot be written means the connection is unusable.
            await self.close()

    async def _pump(self) -> None:
        # Read frames until EOF or an error. PONGs refresh liveness. Every other frame
        # goes to the handler bound to its (session_id, stream_id), or is counted as dropped.
        try:
            while True:
                frame = await read_frame(self.reader)
                if frame.kind == PONG:
                    self.last_pong_at = time.monotonic()
                    continue
                handler = self.handlers.get((frame.session_id, frame.stream_id))
                if handler:
                    self._dispatch_frame(frame, handler)
                else:
                    self._record_dropped_frame(frame.kind)
        except asyncio.IncompleteReadError:
            print(f"[edge] relay disconnected name={self.node.name} eof=true")
        except Exception as exc:
            print(f"[edge] relay pump error name={self.node.name} error={exc!r}")
        finally:
            await self.close()

    def _dispatch_frame(self, frame: Frame, handler: FrameHandler) -> None:
        # Chain each frame behind the newest task for the same stream. Frames of one
        # stream are then handled in arrival order, while different streams run concurrently.
        key = (frame.session_id, frame.stream_id)
        previous = self.dispatch_tasks.get(key)
        task = asyncio.create_task(self._run_handler(key, frame, handler, previous))
        self.dispatch_tasks[key] = task

    async def _run_handler(self, key: tuple[int, int], frame: Frame, handler: FrameHandler, previous: asyncio.Task | None) -> None:
        try:
            if previous is not None:
                # Errors raised by the earlier frame are handled by that frame's own task.
                with contextlib.suppress(Exception):
                    await previous
            if self.closed:
                return
            await handler(self, frame)
        except asyncio.CancelledError:
            # Cancellation from unbind() or close() ends this frame's handling quietly.
            pass
        except Exception:
            if not self.closed and self.close_task is None:
                # Run close() in its own task. close() cancels and awaits the dispatch
                # tasks, and a later frame for this stream may be awaiting this very task.
                # Closing inline would therefore make this task cancel itself.
                self.close_task = asyncio.create_task(self.close())
        finally:
            if self.dispatch_tasks.get(key) is asyncio.current_task():
                self.dispatch_tasks.pop(key, None)

    def _record_dropped_frame(self, kind: int) -> None:
        # Frames for streams without a bound handler are counted per kind and
        # reported as one summary line per 5-second window.
        self.dropped_frames[kind] = self.dropped_frames.get(kind, 0) + 1
        if self.dropped_report_task is None or self.dropped_report_task.done():
            self.dropped_report_task = asyncio.create_task(self._report_dropped_frames())

    async def _report_dropped_frames(self) -> None:
        try:
            await asyncio.sleep(5)
            dropped = self.dropped_frames
            self.dropped_frames = {}
            if dropped:
                detail = ", ".join(f"kind={kind} count={count}" for kind, count in sorted(dropped.items()))
                print(f"[edge] relay frame dropped summary name={self.node.name} {detail}")
        except asyncio.CancelledError:
            pass

    async def send(self, frame: Frame) -> None:
        if self.closed:
            raise ConnectionError(f"relay closed: {self.node.name}")
        assert self.send_lock is not None
        async with self.send_lock:
            # Check again: the connection may have closed while this call waited for the lock.
            if self.closed:
                raise ConnectionError(f"relay closed: {self.node.name}")
            await write_frame(self.writer, frame)

    def bind(self, session_id: int, stream_id: int, handler: FrameHandler) -> None:
        self.handlers[(session_id, stream_id)] = handler

    def unbind(self, session_id: int, stream_id: int) -> None:
        self.handlers.pop((session_id, stream_id), None)
        task = self.dispatch_tasks.pop((session_id, stream_id), None)
        if task is not None:
            task.cancel()

    async def close(self) -> None:
        """Close once: notify bound streams, stop the helper tasks, and close the socket."""
        if self.closed:
            return
        self.closed = True
        assert self.closed_event is not None
        self.closed_event.set()
        handlers = list(self.handlers.items())
        self.handlers.clear()
        # close() can run inside the pump, keepalive, or a dispatch task. That task
        # must not cancel or await itself.
        current = asyncio.current_task()
        dispatch_tasks = [task for task in self.dispatch_tasks.values() if task is not current]
        self.dispatch_tasks.clear()
        self.manager.on_closed(self)
        # Deliver a synthetic TCP_CLOSE to every stream that was bound to this relay.
        for (session_id, stream_id), handler in handlers:
            with contextlib.suppress(Exception):
                await handler(self, Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
        for task in dispatch_tasks:
            task.cancel()
        # gather(..., return_exceptions=True) absorbs the CancelledError of a cancelled
        # task. contextlib.suppress(Exception) would not, because CancelledError
        # derives from BaseException; that would abort close() before the socket is closed.
        await asyncio.gather(*dispatch_tasks, return_exceptions=True)
        helpers = [
            task
            for task in (self.dropped_report_task, self.pump_task, self.keepalive_task)
            if task is not None and task is not current
        ]
        for task in helpers:
            task.cancel()
        await asyncio.gather(*helpers, return_exceptions=True)
        self.writer.close()
        with contextlib.suppress(Exception):
            await self.writer.wait_closed()
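

# Illustrative sketch (hypothetical and standalone; RelayConnection does not use it):
# the ordering idea behind _dispatch_frame/_run_handler, reduced to plain strings.
# Each new task awaits the previous task registered for the same key. Items that
# share a key are therefore handled in arrival order, while other keys are not held up.
# The keys, item names, and delays below are made up for the demonstration.
import asyncio
from typing import Dict, List, Optional, Tuple


class _SketchOrderedDispatcher:
    def __init__(self) -> None:
        self.tails: Dict[Tuple[int, int], asyncio.Task] = {}  # newest task per key
        self.log: List[str] = []

    def dispatch(self, key: Tuple[int, int], item: str, delay: float) -> None:
        previous = self.tails.get(key)
        self.tails[key] = asyncio.create_task(self._run(key, item, delay, previous))

    async def _run(self, key: Tuple[int, int], item: str, delay: float, previous: Optional[asyncio.Task]) -> None:
        if previous is not None:
            await previous  # wait for the earlier item with the same key
        await asyncio.sleep(delay)  # stand-in for a real frame handler
        self.log.append(f"{key}:{item}")
        if self.tails.get(key) is asyncio.current_task():
            del self.tails[key]  # nothing newer is queued behind this item

    async def drain(self) -> None:
        while self.tails:
            await asyncio.gather(*self.tails.values())


async def _sketch_ordered_dispatch_demo() -> List[str]:
    dispatcher = _SketchOrderedDispatcher()
    dispatcher.dispatch((1, 1), "a", 0.02)  # slow first item for key (1, 1)
    dispatcher.dispatch((1, 1), "b", 0.0)   # queued behind "a"
    dispatcher.dispatch((2, 1), "c", 0.0)   # different key, not blocked by "a"
    await dispatcher.drain()
    return dispatcher.log


# asyncio.run(_sketch_ordered_dispatch_demo()) returns
# ["(2, 1):c", "(1, 1):a", "(1, 1):b"]: "c" is not blocked by the slow "a",
# and "b" still waits for "a".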


class RelayManager:
    def __init__(self, config: Config) -> None:
        self.config = config
        self.scheduler = Scheduler(config)
        self.connections: Dict[str, RelayConnection] = {}
        self.tasks: list[asyncio.Task] = []

    async def start(self) -> None:
        await self.scheduler.start()
        for node in self.config.relays:
            self.tasks.append(asyncio.create_task(self._maintain(node)))

    async def _maintain(self, node: RelayNode) -> None:
        # Supervisor loop for one relay node: connect, wait until the connection
        # closes, then reconnect with capped exponential backoff plus jitter.
        backoff = self.config.relay_reconnect_delay
        while True:
            current = self.connections.get(node.name)
            if current is not None and not current.closed:
                assert current.closed_event is not None
                await current.closed_event.wait()
                continue
            attempt = 1
            while True:
                writer: asyncio.StreamWriter | None = None
                try:
                    print(f"[edge] relay reconnect attempt name={node.name} addr={node.host}:{node.port} attempt={attempt} backoff={backoff:.1f}s")
                    reader, writer = await asyncio.wait_for(
                        asyncio.open_connection(node.host, node.port),
                        timeout=self.config.relay_open_timeout,
                    )
                    connection = RelayConnection(node=node, manager=self, reader=reader, writer=writer)
                    sock = writer.get_extra_info("socket")
                    if sock is not None and self.config.relay_tcp_nodelay:
                        with contextlib.suppress(OSError):
                            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                    await connection.start()
                    self.connections[node.name] = connection
                    backoff = self.config.relay_reconnect_delay
                    assert connection.closed_event is not None
                    await connection.closed_event.wait()
                    print(f"[edge] relay supervisor noticed close name={node.name} addr={node.host}:{node.port}")
                    break
                except asyncio.CancelledError:
                    raise
                except Exception as exc:
                    print(f"[edge] relay connect failed name={node.name} addr={node.host}:{node.port} attempt={attempt} error={exc!r}")
                    if writer is not None:
                        # The TCP connection opened but setup or authentication failed:
                        # release the socket before retrying.
                        writer.close()
                        with contextlib.suppress(Exception):
                            await writer.wait_closed()
                    jitter = random.uniform(0, min(1.0, backoff * 0.2))
                    await asyncio.sleep(backoff + jitter)
                    backoff = min(self.config.relay_reconnect_max_delay, max(self.config.relay_reconnect_delay, backoff * 2))
                    attempt += 1

    def on_closed(self, connection: RelayConnection) -> None:
        current = self.connections.get(connection.node.name)
        if current is connection:
            self.connections.pop(connection.node.name, None)
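
    def available(self) -> list[RelayConnection]:
        # Prefer open connections to the relays the scheduler chose, kept in the order
        # it returned them. Otherwise fall back to every open connection.
        preferred: list[RelayConnection] = []
        for name in dict.fromkeys(node.name for node in self.scheduler.choose()):
            connection = self.connections.get(name)
            if connection is not None and not connection.closed:
                preferred.append(connection)
        if preferred:
            return preferred
        return [conn for conn in self.connections.values() if not conn.closed]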

    def snapshot(self) -> list[dict[str, object]]:
        data = self.scheduler.snapshot()
        online = {name for name, conn in self.connections.items() if not conn.closed}
        for item in data:
            item["online"] = item["name"] in online
        return data
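

# Illustrative sketch (hypothetical helper, not called by RelayManager): the retry
# delays that RelayManager._maintain sleeps through after consecutive failed
# connection attempts. `base` and `max_delay` stand in for
# config.relay_reconnect_delay and config.relay_reconnect_max_delay. `rng` is a
# parameter only so that the jitter can be made reproducible.
import random
from typing import List, Optional


def _sketch_reconnect_delays(base: float, max_delay: float, failures: int, rng: Optional[random.Random] = None) -> List[float]:
    if rng is None:
        rng = random.Random()
    delays: List[float] = []
    backoff = base
    for _ in range(failures):
        # Same formula as _maintain: jitter is at most 20% of the backoff, capped at 1 s.
        jitter = rng.uniform(0, min(1.0, backoff * 0.2))
        delays.append(backoff + jitter)
        # Double for the next failure, never below base and never above max_delay.
        backoff = min(max_delay, max(base, backoff * 2))
    return delays


# With base=1.0 and max_delay=8.0, five failures give un-jittered delays of
# 1, 2, 4, 8 and 8 seconds. _maintain resets the backoff to base after a
# successful connection.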