|
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
|
|
|
|
import asyncio
|
|
|
import contextlib
|
|
|
+import time
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
|
from .logging_utils import log_print as print
|
|
|
@@ -27,6 +28,7 @@ class UdpRelaySession:
|
|
|
host: str = ""
|
|
|
port: int = 0
|
|
|
family: int = 0
|
|
|
+ last_activity: float = 0.0
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
@@ -37,13 +39,17 @@ class UdpRelayChannel:
|
|
|
udp_sessions: dict[tuple[int, int], UdpRelaySession] = field(default_factory=dict)
|
|
|
closed: bool = False
|
|
|
send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
|
|
- send_queue: asyncio.Queue[Frame | None] = field(default_factory=asyncio.Queue)
|
|
|
+ send_queue: asyncio.Queue[tuple[int, int] | None] = field(default_factory=lambda: asyncio.Queue(maxsize=1024))
|
|
|
+ pending_frames: dict[tuple[int, int], Frame] = field(default_factory=dict)
|
|
|
+ queued_keys: set[tuple[int, int]] = field(default_factory=set)
|
|
|
send_task: asyncio.Task | None = None
|
|
|
+ cleanup_task: asyncio.Task | None = None
|
|
|
_logged_sessions: set[tuple[int, int]] = field(default_factory=set)
|
|
|
|
|
|
async def run(self) -> None:
|
|
|
try:
|
|
|
self.send_task = asyncio.create_task(self._send_loop())
|
|
|
+ self.cleanup_task = asyncio.create_task(self._cleanup_loop())
|
|
|
auth = await read_frame(self.reader)
|
|
|
if auth.kind != AUTH:
|
|
|
return
|
|
|
@@ -60,15 +66,30 @@ class UdpRelayChannel:
|
|
|
def enqueue_send(self, frame: Frame) -> None:
|
|
|
if self.closed:
|
|
|
return
|
|
|
+ key = (frame.session_id, frame.stream_id)
|
|
|
+ self.pending_frames[key] = frame
|
|
|
+ if key in self.queued_keys:
|
|
|
+ return
|
|
|
+ if self.send_queue.full():
|
|
|
+ with contextlib.suppress(asyncio.QueueEmpty):
|
|
|
+ dropped_key = self.send_queue.get_nowait()
|
|
|
+ if dropped_key is not None:
|
|
|
+ self.queued_keys.discard(dropped_key)
|
|
|
+ self.pending_frames.pop(dropped_key, None)
|
|
|
with contextlib.suppress(asyncio.QueueFull):
|
|
|
- self.send_queue.put_nowait(frame)
|
|
|
+ self.send_queue.put_nowait(key)
|
|
|
+ self.queued_keys.add(key)
|
|
|
|
|
|
async def _send_loop(self) -> None:
|
|
|
try:
|
|
|
while True:
|
|
|
- frame = await self.send_queue.get()
|
|
|
- if frame is None:
|
|
|
+ key = await self.send_queue.get()
|
|
|
+ if key is None:
|
|
|
break
|
|
|
+ self.queued_keys.discard(key)
|
|
|
+ frame = self.pending_frames.pop(key, None)
|
|
|
+ if frame is None:
|
|
|
+ continue
|
|
|
ok = await self.safe_send(frame)
|
|
|
if not ok:
|
|
|
break
|
|
|
@@ -121,6 +142,7 @@ class UdpRelayChannel:
|
|
|
await self.safe_send(Frame(UDP_RECV, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
|
|
|
return
|
|
|
if session.transport is not None:
|
|
|
+ session.last_activity = time.monotonic()
|
|
|
with contextlib.suppress(Exception):
|
|
|
session.transport.sendto(payload)
|
|
|
|
|
|
@@ -131,11 +153,37 @@ class UdpRelayChannel:
|
|
|
self._logged_sessions.add(key)
|
|
|
print(f"[relay] udp reply session={session_id} stream={stream_id} bytes={size}")
|
|
|
|
|
|
+ async def _cleanup_loop(self) -> None:
|
|
|
+ try:
|
|
|
+ while True:
|
|
|
+ await asyncio.sleep(30)
|
|
|
+ if self.closed:
|
|
|
+ return
|
|
|
+ now = time.monotonic()
|
|
|
+ expired = [
|
|
|
+ key
|
|
|
+ for key, session in self.udp_sessions.items()
|
|
|
+ if session.last_activity and now - session.last_activity >= 120
|
|
|
+ ]
|
|
|
+ for key in expired:
|
|
|
+ session = self.udp_sessions.pop(key, None)
|
|
|
+ if session and session.transport:
|
|
|
+ session.transport.close()
|
|
|
+ except asyncio.CancelledError:
|
|
|
+ pass
|
|
|
+
|
|
|
async def close(self) -> None:
|
|
|
if self.closed:
|
|
|
return
|
|
|
self.closed = True
|
|
|
- self.send_queue.put_nowait(None)
|
|
|
+ with contextlib.suppress(asyncio.QueueFull):
|
|
|
+ self.send_queue.put_nowait(None)
|
|
|
+ if self.cleanup_task and self.cleanup_task is not asyncio.current_task():
|
|
|
+ self.cleanup_task.cancel()
|
|
|
+ with contextlib.suppress(Exception):
|
|
|
+ await self.cleanup_task
|
|
|
+ self.pending_frames.clear()
|
|
|
+ self.queued_keys.clear()
|
|
|
for session in self.udp_sessions.values():
|
|
|
if session.transport:
|
|
|
session.transport.close()
|