|
|
@@ -22,13 +22,19 @@ class RelayConnection:
|
|
|
writer: asyncio.StreamWriter
|
|
|
closed: bool = False
|
|
|
handlers: Dict[tuple[int, int], FrameHandler] = None
|
|
|
+ dispatch_tasks: Dict[tuple[int, int], asyncio.Task] = None
|
|
|
pump_task: asyncio.Task | None = None
|
|
|
keepalive_task: asyncio.Task | None = None
|
|
|
last_pong_at: float = 0.0
|
|
|
+ send_lock: asyncio.Lock | None = None
|
|
|
|
|
|
def __post_init__(self) -> None:
|
|
|
if self.handlers is None:
|
|
|
self.handlers = {}
|
|
|
+ if self.dispatch_tasks is None:
|
|
|
+ self.dispatch_tasks = {}
|
|
|
+ if self.send_lock is None:
|
|
|
+ self.send_lock = asyncio.Lock()
|
|
|
|
|
|
async def start(self) -> None:
|
|
|
print(f"[edge] connecting relay name={self.node.name} addr={self.node.host}:{self.node.port}")
|
|
|
@@ -66,7 +72,9 @@ class RelayConnection:
|
|
|
continue
|
|
|
handler = self.handlers.get((frame.session_id, frame.stream_id))
|
|
|
if handler:
|
|
|
- await handler(self, frame)
|
|
|
+ self._dispatch_frame(frame, handler)
|
|
|
+ else:
|
|
|
+ print(f"[edge] relay frame dropped name={self.node.name} session={frame.session_id} stream={frame.stream_id} kind={frame.kind}")
|
|
|
except asyncio.IncompleteReadError:
|
|
|
print(f"[edge] relay disconnected name={self.node.name} eof=true")
|
|
|
except Exception as exc:
|
|
|
@@ -74,16 +82,46 @@ class RelayConnection:
|
|
|
finally:
|
|
|
await self.close()
|
|
|
|
|
|
+ def _dispatch_frame(self, frame: Frame, handler: FrameHandler) -> None:
|
|
|
+ key = (frame.session_id, frame.stream_id)
|
|
|
+ previous = self.dispatch_tasks.get(key)
|
|
|
+ task = asyncio.create_task(self._run_handler(key, frame, handler, previous))
|
|
|
+ self.dispatch_tasks[key] = task
|
|
|
+
|
|
|
+ async def _run_handler(self, key: tuple[int, int], frame: Frame, handler: FrameHandler, previous: asyncio.Task | None) -> None:
|
|
|
+ try:
|
|
|
+ if previous is not None:
|
|
|
+ with contextlib.suppress(Exception):
|
|
|
+ await previous
|
|
|
+ if self.closed:
|
|
|
+ return
|
|
|
+ await handler(self, frame)
|
|
|
+ except asyncio.CancelledError:
|
|
|
+ pass
|
|
|
+ except Exception:
|
|
|
+ if not self.closed:
|
|
|
+ await self.close()
|
|
|
+ finally:
|
|
|
+ if self.dispatch_tasks.get(key) is asyncio.current_task():
|
|
|
+ self.dispatch_tasks.pop(key, None)
|
|
|
+
|
|
|
async def send(self, frame: Frame) -> None:
|
|
|
if self.closed:
|
|
|
raise ConnectionError(f"relay closed: {self.node.name}")
|
|
|
- await write_frame(self.writer, frame)
|
|
|
+ assert self.send_lock is not None
|
|
|
+ async with self.send_lock:
|
|
|
+ if self.closed:
|
|
|
+ raise ConnectionError(f"relay closed: {self.node.name}")
|
|
|
+ await write_frame(self.writer, frame)
|
|
|
|
|
|
def bind(self, session_id: int, stream_id: int, handler: FrameHandler) -> None:
|
|
|
self.handlers[(session_id, stream_id)] = handler
|
|
|
|
|
|
def unbind(self, session_id: int, stream_id: int) -> None:
|
|
|
self.handlers.pop((session_id, stream_id), None)
|
|
|
+ task = self.dispatch_tasks.pop((session_id, stream_id), None)
|
|
|
+ if task is not None:
|
|
|
+ task.cancel()
|
|
|
|
|
|
async def close(self) -> None:
|
|
|
if self.closed:
|
|
|
@@ -91,10 +129,17 @@ class RelayConnection:
|
|
|
self.closed = True
|
|
|
handlers = list(self.handlers.items())
|
|
|
self.handlers.clear()
|
|
|
+ dispatch_tasks = list(self.dispatch_tasks.values())
|
|
|
+ self.dispatch_tasks.clear()
|
|
|
self.manager.on_closed(self)
|
|
|
for (session_id, stream_id), handler in handlers:
|
|
|
with contextlib.suppress(Exception):
|
|
|
await handler(self, Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
|
|
|
+ for task in dispatch_tasks:
|
|
|
+ task.cancel()
|
|
|
+ for task in dispatch_tasks:
|
|
|
+ with contextlib.suppress(Exception):
|
|
|
+ await task
|
|
|
if self.pump_task and self.pump_task is not asyncio.current_task():
|
|
|
self.pump_task.cancel()
|
|
|
with contextlib.suppress(Exception):
|