Gogs 4 dni temu
rodzic
commit
03cf78497e
3 zmienionych plików z 53 dodań i 4 usunięć
  1. 1 1
      config.json
  2. 47 2
      relay_client.py
  3. 5 1
      relay_server.py

+ 1 - 1
config.json

@@ -1,7 +1,7 @@
 {
   "strategy": "top3",
   "redundancy": 3,
-  "direct_redundancy": 2,
+  "direct_redundancy": 3,
   "direct_max_redundancy": 3,
   "direct_redundancy_v6": 3,
   "tcp_warmup_bytes": 1048576,

+ 47 - 2
relay_client.py

@@ -22,13 +22,19 @@ class RelayConnection:
     writer: asyncio.StreamWriter
     closed: bool = False
     handlers: Dict[tuple[int, int], FrameHandler] = None
+    dispatch_tasks: Dict[tuple[int, int], asyncio.Task] = None
     pump_task: asyncio.Task | None = None
     keepalive_task: asyncio.Task | None = None
     last_pong_at: float = 0.0
+    send_lock: asyncio.Lock | None = None
 
     def __post_init__(self) -> None:
         if self.handlers is None:
             self.handlers = {}
+        if self.dispatch_tasks is None:
+            self.dispatch_tasks = {}
+        if self.send_lock is None:
+            self.send_lock = asyncio.Lock()
 
     async def start(self) -> None:
         print(f"[edge] connecting relay name={self.node.name} addr={self.node.host}:{self.node.port}")
@@ -66,7 +72,9 @@ class RelayConnection:
                     continue
                 handler = self.handlers.get((frame.session_id, frame.stream_id))
                 if handler:
-                    await handler(self, frame)
+                    self._dispatch_frame(frame, handler)
+                else:
+                    print(f"[edge] relay frame dropped name={self.node.name} session={frame.session_id} stream={frame.stream_id} kind={frame.kind}")
         except asyncio.IncompleteReadError:
             print(f"[edge] relay disconnected name={self.node.name} eof=true")
         except Exception as exc:
@@ -74,16 +82,46 @@ class RelayConnection:
         finally:
             await self.close()
 
+    def _dispatch_frame(self, frame: Frame, handler: FrameHandler) -> None:
+        key = (frame.session_id, frame.stream_id)
+        previous = self.dispatch_tasks.get(key)
+        task = asyncio.create_task(self._run_handler(key, frame, handler, previous))
+        self.dispatch_tasks[key] = task
+
+    async def _run_handler(self, key: tuple[int, int], frame: Frame, handler: FrameHandler, previous: asyncio.Task | None) -> None:
+        try:
+            if previous is not None:
+                with contextlib.suppress(Exception):
+                    await previous
+            if self.closed:
+                return
+            await handler(self, frame)
+        except asyncio.CancelledError:
+            pass
+        except Exception:
+            if not self.closed:
+                await self.close()
+        finally:
+            if self.dispatch_tasks.get(key) is asyncio.current_task():
+                self.dispatch_tasks.pop(key, None)
+
     async def send(self, frame: Frame) -> None:
         if self.closed:
             raise ConnectionError(f"relay closed: {self.node.name}")
-        await write_frame(self.writer, frame)
+        assert self.send_lock is not None
+        async with self.send_lock:
+            if self.closed:
+                raise ConnectionError(f"relay closed: {self.node.name}")
+            await write_frame(self.writer, frame)
 
     def bind(self, session_id: int, stream_id: int, handler: FrameHandler) -> None:
         self.handlers[(session_id, stream_id)] = handler
 
     def unbind(self, session_id: int, stream_id: int) -> None:
         self.handlers.pop((session_id, stream_id), None)
+        task = self.dispatch_tasks.pop((session_id, stream_id), None)
+        if task is not None:
+            task.cancel()
 
     async def close(self) -> None:
         if self.closed:
@@ -91,10 +129,17 @@ class RelayConnection:
         self.closed = True
         handlers = list(self.handlers.items())
         self.handlers.clear()
+        dispatch_tasks = list(self.dispatch_tasks.values())
+        self.dispatch_tasks.clear()
         self.manager.on_closed(self)
         for (session_id, stream_id), handler in handlers:
             with contextlib.suppress(Exception):
                 await handler(self, Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
+        for task in dispatch_tasks:
+            task.cancel()
+        for task in dispatch_tasks:
+            with contextlib.suppress(Exception):
+                await task
         if self.pump_task and self.pump_task is not asyncio.current_task():
             self.pump_task.cancel()
             with contextlib.suppress(Exception):

+ 5 - 1
relay_server.py

@@ -44,6 +44,7 @@ class RelayChannel:
     tcp_sessions: Dict[tuple[int, int], TcpSession] = field(default_factory=dict)
     udp_sessions: Dict[tuple[int, int], UdpSession] = field(default_factory=dict)
     closed: bool = False
+    send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
 
     async def run(self) -> None:
         peer = self.writer.get_extra_info("peername")
@@ -73,7 +74,10 @@ class RelayChannel:
         if self.closed:
             return False
         try:
-            await write_frame(self.writer, frame)
+            async with self.send_lock:
+                if self.closed:
+                    return False
+                await write_frame(self.writer, frame)
             return True
         except (BrokenPipeError, ConnectionResetError, RuntimeError, OSError, asyncio.CancelledError):
             return False