|
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
|
import asyncio
|
|
|
import contextlib
|
|
|
import itertools
|
|
|
+from collections import deque
|
|
|
import socket
|
|
|
import struct
|
|
|
from dataclasses import dataclass, field
|
|
|
@@ -13,6 +14,9 @@ from .protocol import AUTH, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS
|
|
|
from .scheduler import Scheduler
|
|
|
|
|
|
SOCKS_VERSION = 5
|
|
|
+UDP_WARMUP_BROADCAST_PACKETS = 6
|
|
|
+UDP_SHADOW_PROBE_INTERVAL_SEC = 0.25
|
|
|
+UDP_FAST_FAILOVER_MISSES = 3
|
|
|
|
|
|
|
|
|
async def read_exact(reader: asyncio.StreamReader, size: int) -> bytes:
|
|
|
@@ -102,6 +106,10 @@ class UdpFlowState:
|
|
|
relay_failures: dict[str, int] = field(default_factory=dict)
|
|
|
relay_error_seen: set[str] = field(default_factory=set)
|
|
|
path_last_seen: dict[str, float] = field(default_factory=dict)
|
|
|
+ packet_client_addrs: dict[int, tuple[str, int]] = field(default_factory=dict)
|
|
|
+ direct_pending_clients: dict[str, deque[tuple[int, tuple[str, int]]]] = field(default_factory=dict)
|
|
|
+ last_probe_at: float = 0.0
|
|
|
+ winner_miss_streak: int = 0
|
|
|
|
|
|
def touch(self, now: float) -> None:
|
|
|
self.last_activity = now
|
|
|
@@ -263,9 +271,11 @@ class UdpAssociateServer(asyncio.DatagramProtocol):
|
|
|
)
|
|
|
self.client_flows[flow_key] = flow
|
|
|
flow.touch(now)
|
|
|
+ flow.client_addr = (addr[0], addr[1])
|
|
|
flow.packets_sent += 1
|
|
|
packet_id = next(self.packet_counter)
|
|
|
- asyncio.create_task(self.edge.forward_udp(flow, payload, packet_id, self))
|
|
|
+ flow.packet_client_addrs[packet_id] = (addr[0], addr[1])
|
|
|
+ asyncio.create_task(self.edge.forward_udp(flow, payload, packet_id, (addr[0], addr[1]), self))
|
|
|
self._log_udp_summary()
|
|
|
|
|
|
def _reset_client_state(self, addr) -> None:
|
|
|
@@ -286,12 +296,12 @@ class UdpAssociateServer(asyncio.DatagramProtocol):
|
|
|
return
|
|
|
await self._deliver_flow_packet(flow, frame.packet_id, frame.payload, link.node.name)
|
|
|
|
|
|
- async def handle_from_direct(self, flow: UdpFlowState, path_name: str, payload: bytes) -> None:
|
|
|
+ async def handle_from_direct(self, flow: UdpFlowState, path_name: str, payload: bytes, packet_id: int = 0, client_addr: tuple[str, int] | None = None) -> None:
|
|
|
if self.transport is None or self.client_addr is None:
|
|
|
return
|
|
|
- await self._deliver_flow_packet(flow, 0, payload, path_name)
|
|
|
+ await self._deliver_flow_packet(flow, packet_id, payload, path_name, client_addr)
|
|
|
|
|
|
- async def _deliver_flow_packet(self, flow: UdpFlowState, packet_id: int, payload: bytes, source_name: str) -> None:
|
|
|
+ async def _deliver_flow_packet(self, flow: UdpFlowState, packet_id: int, payload: bytes, source_name: str, client_addr: tuple[str, int] | None = None) -> None:
|
|
|
if self.transport is None or self.client_addr is None:
|
|
|
return
|
|
|
packet = self._build_socks_udp(flow.target_host, flow.target_port, payload)
|
|
|
@@ -299,8 +309,10 @@ class UdpAssociateServer(asyncio.DatagramProtocol):
|
|
|
flow.touch(now)
|
|
|
flow.path_last_seen[source_name] = now
|
|
|
flow.packets_received += 1
|
|
|
+ target_addr = client_addr or flow.packet_client_addrs.pop(packet_id, None) or flow.client_addr
|
|
|
if flow.winner_name is None:
|
|
|
flow.winner_name = source_name
|
|
|
+ flow.winner_miss_streak = 0
|
|
|
self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1
|
|
|
self._log_udp_summary(force=True)
|
|
|
elif flow.winner_name != source_name:
|
|
|
@@ -308,10 +320,14 @@ class UdpAssociateServer(asyncio.DatagramProtocol):
|
|
|
winner_last_seen = flow.path_last_seen.get(flow.winner_name, 0.0)
|
|
|
if winner_last_seen and now - winner_last_seen >= (self.edge.config.udp_failover_idle_ms / 1000):
|
|
|
flow.winner_name = source_name
|
|
|
+ flow.winner_miss_streak = 0
|
|
|
self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1
|
|
|
self._log_udp_summary(force=True)
|
|
|
+ else:
|
|
|
+ flow.winner_miss_streak = 0
|
|
|
if flow.winner_name == source_name:
|
|
|
- self.transport.sendto(packet, self.client_addr)
|
|
|
+ if target_addr is not None:
|
|
|
+ self.transport.sendto(packet, target_addr)
|
|
|
|
|
|
def set_flow_candidates(self, flow: UdpFlowState, candidate_names: tuple[str, ...]) -> None:
|
|
|
if not flow.candidate_names:
|
|
|
@@ -506,7 +522,12 @@ class SocksEdge:
|
|
|
data = await loop.sock_recv(sock, 65535)
|
|
|
if not data:
|
|
|
break
|
|
|
- await udp_server.handle_from_direct(flow, path_name, data)
|
|
|
+ pending = flow.direct_pending_clients.get(path_name)
|
|
|
+ packet_id = 0
|
|
|
+ client_addr = flow.client_addr
|
|
|
+ if pending:
|
|
|
+ packet_id, client_addr = pending.popleft()
|
|
|
+ await udp_server.handle_from_direct(flow, path_name, data, packet_id, client_addr)
|
|
|
except Exception:
|
|
|
pass
|
|
|
finally:
|
|
|
@@ -515,7 +536,7 @@ class SocksEdge:
|
|
|
with contextlib.suppress(Exception):
|
|
|
sock.close()
|
|
|
|
|
|
- async def forward_udp(self, flow: UdpFlowState, payload: bytes, packet_id: int, udp_server: UdpAssociateServer) -> None:
|
|
|
+ async def forward_udp(self, flow: UdpFlowState, payload: bytes, packet_id: int, client_addr: tuple[str, int], udp_server: UdpAssociateServer) -> None:
|
|
|
await self._ensure_udp_direct_paths(flow, udp_server)
|
|
|
meta = encode_json({"host": flow.target_host, "port": flow.target_port})
|
|
|
links = self._selected_udp_links()
|
|
|
@@ -528,12 +549,27 @@ class SocksEdge:
|
|
|
return
|
|
|
active_direct_names = list(direct_names)
|
|
|
active_links = links
|
|
|
- if not (self.config.udp_always_broadcast or flow.winner_name is None):
|
|
|
+ now = asyncio.get_running_loop().time()
|
|
|
+ warmup_mode = flow.packets_sent <= UDP_WARMUP_BROADCAST_PACKETS
|
|
|
+ shadow_probe = (
|
|
|
+ flow.winner_name is not None
|
|
|
+ and now - flow.last_probe_at >= UDP_SHADOW_PROBE_INTERVAL_SEC
|
|
|
+ )
|
|
|
+ if shadow_probe:
|
|
|
+ flow.last_probe_at = now
|
|
|
+ broadcast_mode = self.config.udp_always_broadcast or flow.winner_name is None or warmup_mode or shadow_probe
|
|
|
+ if not broadcast_mode:
|
|
|
winner_last_seen = flow.path_last_seen.get(flow.winner_name, 0.0) if flow.winner_name else 0.0
|
|
|
- if winner_last_seen and asyncio.get_running_loop().time() - winner_last_seen >= (self.config.udp_failover_idle_ms / 1000):
|
|
|
+ winner_stale = bool(winner_last_seen and now - winner_last_seen >= (self.config.udp_failover_idle_ms / 1000))
|
|
|
+ if not winner_stale:
|
|
|
+ flow.winner_miss_streak += 1
|
|
|
+ if winner_stale or flow.winner_miss_streak >= UDP_FAST_FAILOVER_MISSES:
|
|
|
flow.winner_name = None
|
|
|
- active_direct_names = [name for name in active_direct_names if name == flow.winner_name]
|
|
|
- active_links = [link for link in active_links if link.node.name == flow.winner_name]
|
|
|
+ flow.winner_miss_streak = 0
|
|
|
+ broadcast_mode = True
|
|
|
+ else:
|
|
|
+ active_direct_names = [name for name in active_direct_names if name == flow.winner_name]
|
|
|
+ active_links = [link for link in active_links if link.node.name == flow.winner_name]
|
|
|
if not active_direct_names and not active_links:
|
|
|
if direct_names:
|
|
|
active_direct_names = [direct_names[0]]
|
|
|
@@ -547,9 +583,14 @@ class SocksEdge:
|
|
|
if sock is None:
|
|
|
continue
|
|
|
try:
|
|
|
+ flow.direct_pending_clients.setdefault(path_name, deque()).append((packet_id, client_addr))
|
|
|
await asyncio.get_running_loop().sock_sendall(sock, payload)
|
|
|
sent_any = True
|
|
|
except Exception as exc:
|
|
|
+ pending = flow.direct_pending_clients.get(path_name)
|
|
|
+ if pending:
|
|
|
+ with contextlib.suppress(Exception):
|
|
|
+ pending.pop()
|
|
|
flow.direct_failures.add(path_name)
|
|
|
flow.direct_sockets.pop(path_name, None)
|
|
|
task = flow.direct_tasks.pop(path_name, None)
|