Gogs 3 днів тому
батько
коміт
53bfb59f37
2 змінених файлів з 74 додано та 10 видалено
  1. 21 5
      edge_udp.py
  2. 53 5
      relay_server_udp.py

+ 21 - 5
edge_udp.py

@@ -23,6 +23,10 @@ UDP_FLOW_IDLE_CLEANUP_SEC = 30.0
 UDP_PACKET_CLIENT_MAP_LIMIT = 4096
 UDP_DIRECT_PENDING_LIMIT = 128
 UDP_SOCKET_BUFFER_BYTES = 1 << 20
+UDP_WARMUP_WINDOW_PACKETS = 8
+UDP_STABLE_WINNER_SWITCH_MISSES = 4
+UDP_WINNER_SWITCH_GRACE_SEC = 1.0
+UDP_WINNER_STALE_SEC = 1.5
 
 
 async def read_exact(reader: asyncio.StreamReader, size: int) -> bytes:
@@ -119,6 +123,7 @@ class UdpFlowState:
     direct_pending_clients: dict[str, deque[tuple[int, tuple[str, int]]]] = field(default_factory=dict)
     last_probe_at: float = 0.0
     winner_miss_streak: int = 0
+    winner_stable_since: float = 0.0
     target_family: int = 0
     last_cleanup_at: float = 0.0
 
@@ -228,18 +233,24 @@ class UdpAssociateServer(asyncio.DatagramProtocol):
         if flow.winner_name is None:
             flow.winner_name = source_name
             flow.winner_miss_streak = 0
+            flow.winner_stable_since = now
             self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1
             self._log_udp_summary(force=True)
         elif flow.winner_name != source_name:
             flow.duplicate_responses += 1
             winner_last_seen = flow.path_last_seen.get(flow.winner_name, 0.0)
-            if winner_last_seen and now - winner_last_seen >= (self.edge.config.udp_failover_idle_ms / 1000):
+            if winner_last_seen and now - winner_last_seen >= max(
+                self.edge.config.udp_failover_idle_ms / 1000,
+                UDP_WINNER_SWITCH_GRACE_SEC,
+            ):
                 flow.winner_name = source_name
                 flow.winner_miss_streak = 0
+                flow.winner_stable_since = now
                 self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1
                 self._log_udp_summary(force=True)
         else:
             flow.winner_miss_streak = 0
+            flow.winner_stable_since = flow.winner_stable_since or now
         if flow.winner_name == source_name and target_addr is not None:
             if flow.packets_received == 1:
                 print(
@@ -507,17 +518,22 @@ class UdpEdge:
         active_direct_names = list(direct_names)
         active_links = links
         now = asyncio.get_running_loop().time()
-        warmup_mode = flow.packets_sent <= UDP_WARMUP_BROADCAST_PACKETS
+        warmup_mode = flow.packets_sent <= UDP_WARMUP_WINDOW_PACKETS
         shadow_probe = flow.winner_name is not None and now - flow.last_probe_at >= UDP_SHADOW_PROBE_INTERVAL_SEC
         if shadow_probe:
             flow.last_probe_at = now
         broadcast_mode = self.config.udp_always_broadcast or flow.winner_name is None or warmup_mode or shadow_probe
         if not broadcast_mode:
             winner_last_seen = flow.path_last_seen.get(flow.winner_name, 0.0) if flow.winner_name else 0.0
-            winner_stale = bool(winner_last_seen and now - winner_last_seen >= (self.config.udp_failover_idle_ms / 1000))
-            if not winner_stale:
+            winner_stale = bool(
+                winner_last_seen
+                and now - winner_last_seen >= max(self.config.udp_failover_idle_ms / 1000, UDP_WINNER_STALE_SEC)
+            )
+            if winner_stale:
                 flow.winner_miss_streak += 1
-            if winner_stale or flow.winner_miss_streak >= UDP_FAST_FAILOVER_MISSES:
+            else:
+                flow.winner_miss_streak = 0
+            if winner_stale and flow.winner_miss_streak >= UDP_STABLE_WINNER_SWITCH_MISSES:
                 flow.winner_name = None
                 flow.winner_miss_streak = 0
                 broadcast_mode = True

+ 53 - 5
relay_server_udp.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import contextlib
+import time
 from dataclasses import dataclass, field
 
 from .logging_utils import log_print as print
@@ -27,6 +28,7 @@ class UdpRelaySession:
     host: str = ""
     port: int = 0
     family: int = 0
+    last_activity: float = 0.0
 
 
 @dataclass
@@ -37,13 +39,17 @@ class UdpRelayChannel:
     udp_sessions: dict[tuple[int, int], UdpRelaySession] = field(default_factory=dict)
     closed: bool = False
     send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
-    send_queue: asyncio.Queue[Frame | None] = field(default_factory=asyncio.Queue)
+    send_queue: asyncio.Queue[tuple[int, int] | None] = field(default_factory=lambda: asyncio.Queue(maxsize=1024))
+    pending_frames: dict[tuple[int, int], Frame] = field(default_factory=dict)
+    queued_keys: set[tuple[int, int]] = field(default_factory=set)
     send_task: asyncio.Task | None = None
+    cleanup_task: asyncio.Task | None = None
     _logged_sessions: set[tuple[int, int]] = field(default_factory=set)
 
     async def run(self) -> None:
         try:
             self.send_task = asyncio.create_task(self._send_loop())
+            self.cleanup_task = asyncio.create_task(self._cleanup_loop())
             auth = await read_frame(self.reader)
             if auth.kind != AUTH:
                 return
@@ -60,15 +66,30 @@ class UdpRelayChannel:
     def enqueue_send(self, frame: Frame) -> None:
         if self.closed:
             return
+        key = (frame.session_id, frame.stream_id)
+        self.pending_frames[key] = frame
+        if key in self.queued_keys:
+            return
+        if self.send_queue.full():
+            with contextlib.suppress(asyncio.QueueEmpty):
+                dropped_key = self.send_queue.get_nowait()
+            if dropped_key is not None:
+                self.queued_keys.discard(dropped_key)
+                self.pending_frames.pop(dropped_key, None)
         with contextlib.suppress(asyncio.QueueFull):
-            self.send_queue.put_nowait(frame)
+            self.send_queue.put_nowait(key)
+            self.queued_keys.add(key)
 
     async def _send_loop(self) -> None:
         try:
             while True:
-                frame = await self.send_queue.get()
-                if frame is None:
+                key = await self.send_queue.get()
+                if key is None:
                     break
+                self.queued_keys.discard(key)
+                frame = self.pending_frames.pop(key, None)
+                if frame is None:
+                    continue
                 ok = await self.safe_send(frame)
                 if not ok:
                     break
@@ -121,6 +142,7 @@ class UdpRelayChannel:
                 await self.safe_send(Frame(UDP_RECV, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
                 return
         if session.transport is not None:
+            session.last_activity = time.monotonic()
             with contextlib.suppress(Exception):
                 session.transport.sendto(payload)
 
@@ -131,11 +153,37 @@ class UdpRelayChannel:
         self._logged_sessions.add(key)
         print(f"[relay] udp reply session={session_id} stream={stream_id} bytes={size}")
 
+    async def _cleanup_loop(self) -> None:
+        try:
+            while True:
+                await asyncio.sleep(30)
+                if self.closed:
+                    return
+                now = time.monotonic()
+                expired = [
+                    key
+                    for key, session in self.udp_sessions.items()
+                    if session.last_activity and now - session.last_activity >= 120
+                ]
+                for key in expired:
+                    session = self.udp_sessions.pop(key, None)
+                    if session and session.transport:
+                        session.transport.close()
+        except asyncio.CancelledError:
+            pass
+
     async def close(self) -> None:
         if self.closed:
             return
         self.closed = True
-        self.send_queue.put_nowait(None)
+        with contextlib.suppress(asyncio.QueueFull):
+            self.send_queue.put_nowait(None)
+        if self.cleanup_task and self.cleanup_task is not asyncio.current_task():
+            self.cleanup_task.cancel()
+            with contextlib.suppress(Exception):
+                await self.cleanup_task
+        self.pending_frames.clear()
+        self.queued_keys.clear()
         for session in self.udp_sessions.values():
             if session.transport:
                 session.transport.close()