Ver código fonte

继续修改UDP多节点

Gogs 3 dias atrás
pai
commit
82f8f7acdf
6 arquivos alterados com 163 adições e 25 exclusões
  1. 5 1
      README.md
  2. 1 0
      config.json
  3. 4 0
      config.py
  4. 9 2
      relay_server.py
  5. 140 22
      socks_edge.py
  6. 4 0
      transparent_edge.py

+ 5 - 1
README.md

@@ -291,6 +291,8 @@ sudo /home/mynetspeeder/scripts/start-transparent.sh --enable-udp --capture-uid
 {
   "udp_redundancy": 1,
   "udp_direct_redundancy": 2,
+  "udp_direct_redundancy_v4": 2,
+  "udp_direct_redundancy_v6": 2,
   "udp_always_broadcast": true,
   "udp_copy_interval_ms": 8
 }
@@ -299,7 +301,9 @@ sudo /home/mynetspeeder/scripts/start-transparent.sh --enable-udp --capture-uid
 说明:
 
 - `udp_redundancy`:每个 UDP 包额外重复发送的次数
-- `udp_direct_redundancy`:本地 direct UDP 并发副本数
+- `udp_direct_redundancy`:UDP 默认本地 direct 并发副本数
+- `udp_direct_redundancy_v4`:可单独指定 IPv4 目标的 UDP direct 副本数
+- `udp_direct_redundancy_v6`:可单独指定 IPv6 目标的 UDP direct 副本数
 - `udp_always_broadcast`:即使已有 winner,后续包仍持续并发发往所有可用路径
 - `udp_copy_interval_ms`:多副本之间的间隔,单位毫秒
 

+ 1 - 0
config.json

@@ -4,6 +4,7 @@
   "direct_redundancy": 3,
   "direct_max_redundancy": 3,
   "direct_redundancy_v6": 3,
+  "udp_direct_redundancy": 2,
   "tcp_warmup_bytes": 1048576,
   "tcp_loser_grace_ms": 1500,
   "probe_interval": 3,

+ 4 - 0
config.py

@@ -42,6 +42,8 @@ class Config:
     direct_max_redundancy: int = 3
     udp_redundancy: int = 1
     udp_direct_redundancy: int = 2
+    udp_direct_redundancy_v4: int | None = None
+    udp_direct_redundancy_v6: int | None = None
     udp_always_broadcast: bool = True
     udp_copy_interval_ms: int = 8
     socks_host: str = "127.0.0.1"
@@ -74,6 +76,8 @@ class Config:
             direct_max_redundancy=max(1, raw.get("direct_max_redundancy", 3)),
             udp_redundancy=max(0, raw.get("udp_redundancy", 1)),
             udp_direct_redundancy=max(1, raw.get("udp_direct_redundancy", 2)),
+            udp_direct_redundancy_v4=raw.get("udp_direct_redundancy_v4"),
+            udp_direct_redundancy_v6=raw.get("udp_direct_redundancy_v6"),
             udp_always_broadcast=raw.get("udp_always_broadcast", True),
             udp_copy_interval_ms=max(0, raw.get("udp_copy_interval_ms", 8)),
             socks_host=raw.get("socks_host", "127.0.0.1"),

+ 9 - 2
relay_server.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import contextlib
+import time
 from dataclasses import dataclass, field
 from typing import Dict
 
@@ -48,6 +49,7 @@ class RelayChannel:
     udp_sessions: Dict[tuple[int, int], UdpSession] = field(default_factory=dict)
     closed: bool = False
     send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+    authed_at: float = 0.0
 
     async def run(self) -> None:
         peer = self.writer.get_extra_info("peername")
@@ -57,6 +59,7 @@ class RelayChannel:
             if auth.kind != AUTH or decode_json(auth.payload).get("token") != self.token:
                 raise PermissionError("invalid token")
             authed = True
+            self.authed_at = time.monotonic()
             print(f"[relay] auth ok peer={peer}")
             await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, b"ok"))
             while True:
@@ -64,12 +67,16 @@ class RelayChannel:
                 await self.handle(frame)
         except asyncio.IncompleteReadError:
             if authed:
-                print(f"[relay] disconnected peer={peer}")
+                lived = time.monotonic() - self.authed_at if self.authed_at else 0.0
+                if lived >= 5:
+                    print(f"[relay] disconnected peer={peer} lived={lived:.1f}s")
         except asyncio.CancelledError:
             pass
         except Exception as exc:
             if authed:
-                print(f"[relay] channel error peer={peer} error={exc!r}")
+                lived = time.monotonic() - self.authed_at if self.authed_at else 0.0
+                if lived >= 5:
+                    print(f"[relay] channel error peer={peer} lived={lived:.1f}s error={exc!r}")
         finally:
             await self.close()
 

+ 140 - 22
socks_edge.py

@@ -81,6 +81,9 @@ class UdpFlowState:
     candidate_names: tuple[str, ...] = ()
     link_streams: dict[str, int] = field(default_factory=dict)
     initialized_links: set[str] = field(default_factory=set)
+    direct_sockets: dict[str, socket.socket] = field(default_factory=dict)
+    direct_tasks: dict[str, asyncio.Task] = field(default_factory=dict)
+    direct_failures: set[str] = field(default_factory=set)
 
     def touch(self, now: float) -> None:
         self.last_activity = now
@@ -192,6 +195,7 @@ class UdpAssociateServer(asyncio.DatagramProtocol):
         self.edge = edge
         self.transport: asyncio.DatagramTransport | None = None
         self.client_addr = None
+        self.associate_peer = None
         self.packet_counter = itertools.count(1)
         self.client_flows: dict[tuple[tuple[str, int], str, int], UdpFlowState] = {}
         self.flow_counter = itertools.count(1)
@@ -201,14 +205,21 @@ class UdpAssociateServer(asyncio.DatagramProtocol):
     def connection_made(self, transport) -> None:
         self.transport = transport
 
+    def register_associate(self, peer) -> None:
+        peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer)
+        if self.associate_peer != peer_text:
+            print(f"[edge] udp associate peer={peer_text}")
+        self.associate_peer = peer_text
+
     def datagram_received(self, data: bytes, addr) -> None:
         if len(data) < 10:
             return
         if self.client_addr is None:
             self.client_addr = addr
             print(f"[edge] udp client bound addr={addr[0]}:{addr[1]}")
-        if addr != self.client_addr:
-            return
+        elif addr != self.client_addr:
+            print(f"[edge] udp client rebound old={self.client_addr[0]}:{self.client_addr[1]} new={addr[0]}:{addr[1]}")
+            self._reset_client_state(addr)
         host, port, payload = self._parse_socks_udp(data)
         loop = asyncio.get_running_loop()
         now = loop.time()
@@ -231,38 +242,60 @@ class UdpAssociateServer(asyncio.DatagramProtocol):
         asyncio.create_task(self.edge.forward_udp(flow, payload, packet_id, self))
         self._log_udp_summary()
 
+    def _reset_client_state(self, addr) -> None:
+        for flow in list(self.client_flows.values()):
+            for task in list(flow.direct_tasks.values()):
+                task.cancel()
+            for sock in list(flow.direct_sockets.values()):
+                with contextlib.suppress(Exception):
+                    sock.close()
+            for stream_id in list(flow.link_streams.values()):
+                self.edge.udp_flow_sessions.pop((flow.flow_id, stream_id), None)
+        self.client_flows.clear()
+        self.client_addr = addr
+        self.win_counts.clear()
+        print(f"[edge] udp client bound addr={addr[0]}:{addr[1]}")
+
     async def handle_from_relay(self, frame: Frame, link: RelayLink) -> None:
         if self.transport is None or self.client_addr is None:
             return
         flow = self.edge.udp_flow_sessions.get((frame.session_id, frame.stream_id))
         if flow is None:
             return
-        flow_id = flow.flow_id
-        host = flow.target_host
-        port = flow.target_port
-        packet = self._build_socks_udp(host, port, frame.payload)
+        await self._deliver_flow_packet(flow, frame.packet_id, frame.payload, link.node.name)
+
+    async def handle_from_direct(self, flow: UdpFlowState, path_name: str, payload: bytes) -> None:
+        if self.transport is None or self.client_addr is None:
+            return
+        await self._deliver_flow_packet(flow, 0, payload, path_name)
+
+    async def _deliver_flow_packet(self, flow: UdpFlowState, packet_id: int, payload: bytes, source_name: str) -> None:
+        if self.transport is None or self.client_addr is None:
+            return
+        packet = self._build_socks_udp(flow.target_host, flow.target_port, payload)
         winner_log = ""
         now = asyncio.get_running_loop().time()
         flow.touch(now)
         flow.packets_received += 1
         if flow.winner_name is None:
-            flow.winner_name = link.node.name
-            self.win_counts[link.node.name] = self.win_counts.get(link.node.name, 0) + 1
+            flow.winner_name = source_name
+            self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1
             relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.win_counts.items())) or "none"
             mode = "redundant" if self.edge.config.udp_redundancy > 0 else "single"
             print(
-                f"[edge] udp flow={flow.flow_id} winner={link.node.name} "
-                f"target={flow.target_host}:{flow.target_port} mode={mode} candidates={len(flow.candidate_names) or len(self.edge.links)}"
+                f"[edge] udp flow={flow.flow_id} winner={source_name} "
+                f"target={flow.target_host}:{flow.target_port} mode={mode} candidates={len(flow.candidate_names)}"
             )
             print(f"[edge] udp win relay_breakdown={relay_detail}")
-        elif flow.winner_name != link.node.name:
+        elif flow.winner_name != source_name:
             flow.duplicate_responses += 1
-            winner_log = f" duplicate=1 winner={flow.winner_name} from={link.node.name}"
+            winner_log = f" duplicate=1 winner={flow.winner_name} from={source_name}"
         print(
-            f"[edge] udp send flow={flow_id or 'unknown'} packet_id={frame.packet_id} "
-            f"target={host}:{port} size={len(frame.payload)} relay={link.node.name}{winner_log}"
+            f"[edge] udp send flow={flow.flow_id} packet_id={packet_id or 'direct'} "
+            f"target={flow.target_host}:{flow.target_port} size={len(payload)} relay={source_name}{winner_log}"
         )
-        self.transport.sendto(packet, self.client_addr)
+        if flow.winner_name == source_name:
+            self.transport.sendto(packet, self.client_addr)
         self._log_udp_summary()
 
     def set_flow_candidates(self, flow: UdpFlowState, candidate_names: tuple[str, ...]) -> None:
@@ -387,19 +420,92 @@ class SocksEdge:
         ordered = sorted(online, key=lambda link: self.scheduler.scores.get(link.node.name).score if link.node.name in self.scheduler.scores else 999999.0)
         return ordered
 
+    def _udp_direct_redundancy_for_target(self, target_host: str) -> int:
+        base = self.config.udp_direct_redundancy
+        if ":" in target_host and self.config.udp_direct_redundancy_v6 is not None:
+            base = self.config.udp_direct_redundancy_v6
+        elif ":" not in target_host and self.config.udp_direct_redundancy_v4 is not None:
+            base = self.config.udp_direct_redundancy_v4
+        return max(1, base)
+
+    async def _ensure_udp_direct_paths(self, flow: UdpFlowState, udp_server: UdpAssociateServer) -> None:
+        target_count = self._udp_direct_redundancy_for_target(flow.target_host)
+        for index in range(target_count):
+            name = f"direct-{index + 1}" if target_count > 1 else "direct"
+            if name in flow.direct_sockets or name in flow.direct_failures:
+                continue
+            try:
+                family = socket.AF_INET6 if ":" in flow.target_host else socket.AF_INET
+                sock = socket.socket(family, socket.SOCK_DGRAM)
+                sock.setblocking(False)
+                await asyncio.get_running_loop().sock_connect(sock, (flow.target_host, flow.target_port))
+                flow.direct_sockets[name] = sock
+                flow.direct_tasks[name] = asyncio.create_task(self._pump_udp_direct(flow, name, sock, udp_server))
+            except Exception as exc:
+                flow.direct_failures.add(name)
+                print(f"[edge] udp direct open error flow={flow.flow_id} path={name} target={flow.target_host}:{flow.target_port} error={exc!r}")
+
+    async def _pump_udp_direct(self, flow: UdpFlowState, path_name: str, sock: socket.socket, udp_server: UdpAssociateServer) -> None:
+        loop = asyncio.get_running_loop()
+        try:
+            while True:
+                data = await loop.sock_recv(sock, 65535)
+                if not data:
+                    break
+                await udp_server.handle_from_direct(flow, path_name, data)
+        except Exception:
+            pass
+        finally:
+            flow.direct_tasks.pop(path_name, None)
+            flow.direct_sockets.pop(path_name, None)
+            with contextlib.suppress(Exception):
+                sock.close()
+
     async def forward_udp(self, flow: UdpFlowState, payload: bytes, packet_id: int, udp_server: UdpAssociateServer) -> None:
+        await self._ensure_udp_direct_paths(flow, udp_server)
         meta = encode_json({"host": flow.target_host, "port": flow.target_port})
         links = self._selected_udp_links()
-        link_names = ",".join(link.node.name for link in links) or "none"
-        udp_server.set_flow_candidates(flow, tuple(link.node.name for link in links))
+        direct_names = tuple(name for name in sorted(flow.direct_sockets))
+        relay_names = tuple(link.node.name for link in links)
+        candidate_names = direct_names + relay_names
+        link_names = ",".join(candidate_names) or "none"
+        udp_server.set_flow_candidates(flow, candidate_names)
         print(f"[edge] udp forward packet_id={packet_id} target={flow.target_host}:{flow.target_port} size={len(payload)} links={link_names}")
-        if not links:
+        if not candidate_names:
             udp_server.note_unsent(flow, packet_id)
             return
-        active_links = links if self.config.udp_always_broadcast or flow.winner_name is None else [link for link in links if link.node.name == flow.winner_name]
-        active_links = active_links or links[:1]
+        active_direct_names = list(direct_names)
+        active_links = links
+        if not (self.config.udp_always_broadcast or flow.winner_name is None):
+            active_direct_names = [name for name in active_direct_names if name == flow.winner_name]
+            active_links = [link for link in active_links if link.node.name == flow.winner_name]
+        if not active_direct_names and not active_links:
+            if direct_names:
+                active_direct_names = [direct_names[0]]
+            elif links:
+                active_links = links[:1]
         copies = max(1, self.config.udp_redundancy + 1)
+        sent_any = False
         for attempt in range(copies):
+            for path_name in active_direct_names:
+                sock = flow.direct_sockets.get(path_name)
+                if sock is None:
+                    continue
+                try:
+                    await asyncio.get_running_loop().sock_sendall(sock, payload)
+                    sent_any = True
+                except Exception as exc:
+                    flow.direct_failures.add(path_name)
+                    flow.direct_sockets.pop(path_name, None)
+                    task = flow.direct_tasks.pop(path_name, None)
+                    if task is not None:
+                        task.cancel()
+                    with contextlib.suppress(Exception):
+                        sock.close()
+                    print(
+                        f"[edge] udp send error flow={flow.flow_id} packet_id={packet_id} "
+                        f"relay={path_name} error={exc!r}"
+                    )
             for link in active_links:
                 stream_id = flow.link_streams.get(link.node.name)
                 if stream_id is None:
@@ -409,10 +515,21 @@ class SocksEdge:
                 include_meta = link.node.name not in flow.initialized_links
                 body = (meta + payload) if include_meta else payload
                 meta_len = len(meta) if include_meta else 0
-                await link.send(Frame(UDP_SEND, flow.flow_id, stream_id, 0, meta_len, body))
-                flow.initialized_links.add(link.node.name)
+                try:
+                    await link.send(Frame(UDP_SEND, flow.flow_id, stream_id, 0, meta_len, body))
+                    flow.initialized_links.add(link.node.name)
+                    sent_any = True
+                except Exception as exc:
+                    flow.link_streams.pop(link.node.name, None)
+                    self.udp_flow_sessions.pop((flow.flow_id, stream_id), None)
+                    print(
+                        f"[edge] udp send error flow={flow.flow_id} packet_id={packet_id} "
+                        f"relay={link.node.name} error={exc!r}"
+                    )
             if attempt + 1 < copies and self.config.udp_copy_interval_ms > 0:
                 await asyncio.sleep(self.config.udp_copy_interval_ms / 1000)
+        if not sent_any:
+            udp_server.note_unsent(flow, packet_id)
 
     async def _handshake(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer) -> tuple[str, int, bool]:
         version, methods_len = (await read_exact(reader, 2))
@@ -440,6 +557,7 @@ class SocksEdge:
             return host, port, False
         if command == 3 and self.udp_server and self.udp_server.transport:
             bind_host, bind_port = self.udp_server.transport.get_extra_info("sockname")[:2]
+            self.udp_server.register_associate(peer)
             print(f"[edge] socks handshake peer={peer_text} command=udp_associate target={host}:{port} bind={bind_host}:{bind_port}")
             writer.write(b"\x05\x00\x00\x01" + socket.inet_aton(bind_host) + struct.pack("!H", bind_port))
             await writer.drain()

+ 4 - 0
transparent_edge.py

@@ -687,6 +687,10 @@ class TransparentEdge:
 
     def _build_udp_direct_paths(self, target: TargetAddress, flow_id: int) -> list[BasePath]:
         count = max(1, self.config.udp_direct_redundancy)
+        if target.family == socket.AF_INET6 and self.config.udp_direct_redundancy_v6 is not None:
+            count = max(1, self.config.udp_direct_redundancy_v6)
+        elif target.family == socket.AF_INET and self.config.udp_direct_redundancy_v4 is not None:
+            count = max(1, self.config.udp_direct_redundancy_v4)
         return [
             DirectUdpPath(
                 name=f"direct-{index + 1}" if count > 1 else "direct",