|
|
@@ -81,6 +81,9 @@ class UdpFlowState:
|
|
|
candidate_names: tuple[str, ...] = ()
|
|
|
link_streams: dict[str, int] = field(default_factory=dict)
|
|
|
initialized_links: set[str] = field(default_factory=set)
|
|
|
+ direct_sockets: dict[str, socket.socket] = field(default_factory=dict)
|
|
|
+ direct_tasks: dict[str, asyncio.Task] = field(default_factory=dict)
|
|
|
+ direct_failures: set[str] = field(default_factory=set)
|
|
|
|
|
|
def touch(self, now: float) -> None:
|
|
|
self.last_activity = now
|
|
|
@@ -192,6 +195,7 @@ class UdpAssociateServer(asyncio.DatagramProtocol):
|
|
|
self.edge = edge
|
|
|
self.transport: asyncio.DatagramTransport | None = None
|
|
|
self.client_addr = None
|
|
|
+ self.associate_peer = None
|
|
|
self.packet_counter = itertools.count(1)
|
|
|
self.client_flows: dict[tuple[tuple[str, int], str, int], UdpFlowState] = {}
|
|
|
self.flow_counter = itertools.count(1)
|
|
|
@@ -201,14 +205,21 @@ class UdpAssociateServer(asyncio.DatagramProtocol):
|
|
|
def connection_made(self, transport) -> None:
|
|
|
self.transport = transport
|
|
|
|
|
|
+ def register_associate(self, peer) -> None:
|
|
|
+ peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer)
|
|
|
+ if self.associate_peer != peer_text:
|
|
|
+ print(f"[edge] udp associate peer={peer_text}")
|
|
|
+ self.associate_peer = peer_text
|
|
|
+
|
|
|
def datagram_received(self, data: bytes, addr) -> None:
|
|
|
if len(data) < 10:
|
|
|
return
|
|
|
if self.client_addr is None:
|
|
|
self.client_addr = addr
|
|
|
print(f"[edge] udp client bound addr={addr[0]}:{addr[1]}")
|
|
|
- if addr != self.client_addr:
|
|
|
- return
|
|
|
+ elif addr != self.client_addr:
|
|
|
+ print(f"[edge] udp client rebound old={self.client_addr[0]}:{self.client_addr[1]} new={addr[0]}:{addr[1]}")
|
|
|
+ self._reset_client_state(addr)
|
|
|
host, port, payload = self._parse_socks_udp(data)
|
|
|
loop = asyncio.get_running_loop()
|
|
|
now = loop.time()
|
|
|
@@ -231,38 +242,60 @@ class UdpAssociateServer(asyncio.DatagramProtocol):
|
|
|
asyncio.create_task(self.edge.forward_udp(flow, payload, packet_id, self))
|
|
|
self._log_udp_summary()
|
|
|
|
|
|
+ def _reset_client_state(self, addr) -> None:
|
|
|
+ for flow in list(self.client_flows.values()):
|
|
|
+ for task in list(flow.direct_tasks.values()):
|
|
|
+ task.cancel()
|
|
|
+ for sock in list(flow.direct_sockets.values()):
|
|
|
+ with contextlib.suppress(Exception):
|
|
|
+ sock.close()
|
|
|
+ for stream_id in list(flow.link_streams.values()):
|
|
|
+ self.edge.udp_flow_sessions.pop((flow.flow_id, stream_id), None)
|
|
|
+ self.client_flows.clear()
|
|
|
+ self.client_addr = addr
|
|
|
+ self.win_counts.clear()
|
|
|
+ print(f"[edge] udp client bound addr={addr[0]}:{addr[1]}")
|
|
|
+
|
|
|
async def handle_from_relay(self, frame: Frame, link: RelayLink) -> None:
|
|
|
if self.transport is None or self.client_addr is None:
|
|
|
return
|
|
|
flow = self.edge.udp_flow_sessions.get((frame.session_id, frame.stream_id))
|
|
|
if flow is None:
|
|
|
return
|
|
|
- flow_id = flow.flow_id
|
|
|
- host = flow.target_host
|
|
|
- port = flow.target_port
|
|
|
- packet = self._build_socks_udp(host, port, frame.payload)
|
|
|
+ await self._deliver_flow_packet(flow, frame.packet_id, frame.payload, link.node.name)
|
|
|
+
|
|
|
+ async def handle_from_direct(self, flow: UdpFlowState, path_name: str, payload: bytes) -> None:
|
|
|
+ if self.transport is None or self.client_addr is None:
|
|
|
+ return
|
|
|
+ await self._deliver_flow_packet(flow, 0, payload, path_name)
|
|
|
+
|
|
|
+ async def _deliver_flow_packet(self, flow: UdpFlowState, packet_id: int, payload: bytes, source_name: str) -> None:
|
|
|
+ if self.transport is None or self.client_addr is None:
|
|
|
+ return
|
|
|
+ packet = self._build_socks_udp(flow.target_host, flow.target_port, payload)
|
|
|
winner_log = ""
|
|
|
now = asyncio.get_running_loop().time()
|
|
|
flow.touch(now)
|
|
|
flow.packets_received += 1
|
|
|
if flow.winner_name is None:
|
|
|
- flow.winner_name = link.node.name
|
|
|
- self.win_counts[link.node.name] = self.win_counts.get(link.node.name, 0) + 1
|
|
|
+ flow.winner_name = source_name
|
|
|
+ self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1
|
|
|
relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.win_counts.items())) or "none"
|
|
|
mode = "redundant" if self.edge.config.udp_redundancy > 0 else "single"
|
|
|
print(
|
|
|
- f"[edge] udp flow={flow.flow_id} winner={link.node.name} "
|
|
|
- f"target={flow.target_host}:{flow.target_port} mode={mode} candidates={len(flow.candidate_names) or len(self.edge.links)}"
|
|
|
+ f"[edge] udp flow={flow.flow_id} winner={source_name} "
|
|
|
+ f"target={flow.target_host}:{flow.target_port} mode={mode} candidates={len(flow.candidate_names)}"
|
|
|
)
|
|
|
print(f"[edge] udp win relay_breakdown={relay_detail}")
|
|
|
- elif flow.winner_name != link.node.name:
|
|
|
+ elif flow.winner_name != source_name:
|
|
|
flow.duplicate_responses += 1
|
|
|
- winner_log = f" duplicate=1 winner={flow.winner_name} from={link.node.name}"
|
|
|
+ winner_log = f" duplicate=1 winner={flow.winner_name} from={source_name}"
|
|
|
print(
|
|
|
- f"[edge] udp send flow={flow_id or 'unknown'} packet_id={frame.packet_id} "
|
|
|
- f"target={host}:{port} size={len(frame.payload)} relay={link.node.name}{winner_log}"
|
|
|
+ f"[edge] udp send flow={flow.flow_id} packet_id={packet_id or 'direct'} "
|
|
|
+ f"target={flow.target_host}:{flow.target_port} size={len(payload)} relay={source_name}{winner_log}"
|
|
|
)
|
|
|
- self.transport.sendto(packet, self.client_addr)
|
|
|
+ if flow.winner_name == source_name:
|
|
|
+ self.transport.sendto(packet, self.client_addr)
|
|
|
self._log_udp_summary()
|
|
|
|
|
|
def set_flow_candidates(self, flow: UdpFlowState, candidate_names: tuple[str, ...]) -> None:
|
|
|
@@ -387,19 +420,92 @@ class SocksEdge:
|
|
|
ordered = sorted(online, key=lambda link: self.scheduler.scores.get(link.node.name).score if link.node.name in self.scheduler.scores else 999999.0)
|
|
|
return ordered
|
|
|
|
|
|
+ def _udp_direct_redundancy_for_target(self, target_host: str) -> int:
|
|
|
+ base = self.config.udp_direct_redundancy
|
|
|
+ if ":" in target_host and self.config.udp_direct_redundancy_v6 is not None:
|
|
|
+ base = self.config.udp_direct_redundancy_v6
|
|
|
+ elif ":" not in target_host and self.config.udp_direct_redundancy_v4 is not None:
|
|
|
+ base = self.config.udp_direct_redundancy_v4
|
|
|
+ return max(1, base)
|
|
|
+
|
|
|
+ async def _ensure_udp_direct_paths(self, flow: UdpFlowState, udp_server: UdpAssociateServer) -> None:
|
|
|
+ target_count = self._udp_direct_redundancy_for_target(flow.target_host)
|
|
|
+ for index in range(target_count):
|
|
|
+ name = f"direct-{index + 1}" if target_count > 1 else "direct"
|
|
|
+ if name in flow.direct_sockets or name in flow.direct_failures:
|
|
|
+ continue
|
|
|
+ try:
|
|
|
+ family = socket.AF_INET6 if ":" in flow.target_host else socket.AF_INET
|
|
|
+ sock = socket.socket(family, socket.SOCK_DGRAM)
|
|
|
+ sock.setblocking(False)
|
|
|
+ await asyncio.get_running_loop().sock_connect(sock, (flow.target_host, flow.target_port))
|
|
|
+ flow.direct_sockets[name] = sock
|
|
|
+ flow.direct_tasks[name] = asyncio.create_task(self._pump_udp_direct(flow, name, sock, udp_server))
|
|
|
+ except Exception as exc:
|
|
|
+ flow.direct_failures.add(name)
|
|
|
+ print(f"[edge] udp direct open error flow={flow.flow_id} path={name} target={flow.target_host}:{flow.target_port} error={exc!r}")
|
|
|
+
|
|
|
+ async def _pump_udp_direct(self, flow: UdpFlowState, path_name: str, sock: socket.socket, udp_server: UdpAssociateServer) -> None:
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
+ try:
|
|
|
+ while True:
|
|
|
+ data = await loop.sock_recv(sock, 65535)
|
|
|
+ if not data:
|
|
|
+ break
|
|
|
+ await udp_server.handle_from_direct(flow, path_name, data)
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+ finally:
|
|
|
+ flow.direct_tasks.pop(path_name, None)
|
|
|
+ flow.direct_sockets.pop(path_name, None)
|
|
|
+ with contextlib.suppress(Exception):
|
|
|
+ sock.close()
|
|
|
+
|
|
|
async def forward_udp(self, flow: UdpFlowState, payload: bytes, packet_id: int, udp_server: UdpAssociateServer) -> None:
|
|
|
+ await self._ensure_udp_direct_paths(flow, udp_server)
|
|
|
meta = encode_json({"host": flow.target_host, "port": flow.target_port})
|
|
|
links = self._selected_udp_links()
|
|
|
- link_names = ",".join(link.node.name for link in links) or "none"
|
|
|
- udp_server.set_flow_candidates(flow, tuple(link.node.name for link in links))
|
|
|
+ direct_names = tuple(name for name in sorted(flow.direct_sockets))
|
|
|
+ relay_names = tuple(link.node.name for link in links)
|
|
|
+ candidate_names = direct_names + relay_names
|
|
|
+ link_names = ",".join(candidate_names) or "none"
|
|
|
+ udp_server.set_flow_candidates(flow, candidate_names)
|
|
|
print(f"[edge] udp forward packet_id={packet_id} target={flow.target_host}:{flow.target_port} size={len(payload)} links={link_names}")
|
|
|
- if not links:
|
|
|
+ if not candidate_names:
|
|
|
udp_server.note_unsent(flow, packet_id)
|
|
|
return
|
|
|
- active_links = links if self.config.udp_always_broadcast or flow.winner_name is None else [link for link in links if link.node.name == flow.winner_name]
|
|
|
- active_links = active_links or links[:1]
|
|
|
+ active_direct_names = list(direct_names)
|
|
|
+ active_links = links
|
|
|
+ if not (self.config.udp_always_broadcast or flow.winner_name is None):
|
|
|
+ active_direct_names = [name for name in active_direct_names if name == flow.winner_name]
|
|
|
+ active_links = [link for link in active_links if link.node.name == flow.winner_name]
|
|
|
+ if not active_direct_names and not active_links:
|
|
|
+ if direct_names:
|
|
|
+ active_direct_names = [direct_names[0]]
|
|
|
+ elif links:
|
|
|
+ active_links = links[:1]
|
|
|
copies = max(1, self.config.udp_redundancy + 1)
|
|
|
+ sent_any = False
|
|
|
for attempt in range(copies):
|
|
|
+ for path_name in active_direct_names:
|
|
|
+ sock = flow.direct_sockets.get(path_name)
|
|
|
+ if sock is None:
|
|
|
+ continue
|
|
|
+ try:
|
|
|
+ await asyncio.get_running_loop().sock_sendall(sock, payload)
|
|
|
+ sent_any = True
|
|
|
+ except Exception as exc:
|
|
|
+ flow.direct_failures.add(path_name)
|
|
|
+ flow.direct_sockets.pop(path_name, None)
|
|
|
+ task = flow.direct_tasks.pop(path_name, None)
|
|
|
+ if task is not None:
|
|
|
+ task.cancel()
|
|
|
+ with contextlib.suppress(Exception):
|
|
|
+ sock.close()
|
|
|
+ print(
|
|
|
+ f"[edge] udp send error flow={flow.flow_id} packet_id={packet_id} "
|
|
|
+ f"relay={path_name} error={exc!r}"
|
|
|
+ )
|
|
|
for link in active_links:
|
|
|
stream_id = flow.link_streams.get(link.node.name)
|
|
|
if stream_id is None:
|
|
|
@@ -409,10 +515,21 @@ class SocksEdge:
|
|
|
include_meta = link.node.name not in flow.initialized_links
|
|
|
body = (meta + payload) if include_meta else payload
|
|
|
meta_len = len(meta) if include_meta else 0
|
|
|
- await link.send(Frame(UDP_SEND, flow.flow_id, stream_id, 0, meta_len, body))
|
|
|
- flow.initialized_links.add(link.node.name)
|
|
|
+ try:
|
|
|
+ await link.send(Frame(UDP_SEND, flow.flow_id, stream_id, 0, meta_len, body))
|
|
|
+ flow.initialized_links.add(link.node.name)
|
|
|
+ sent_any = True
|
|
|
+ except Exception as exc:
|
|
|
+ flow.link_streams.pop(link.node.name, None)
|
|
|
+ self.udp_flow_sessions.pop((flow.flow_id, stream_id), None)
|
|
|
+ print(
|
|
|
+ f"[edge] udp send error flow={flow.flow_id} packet_id={packet_id} "
|
|
|
+ f"relay={link.node.name} error={exc!r}"
|
|
|
+ )
|
|
|
if attempt + 1 < copies and self.config.udp_copy_interval_ms > 0:
|
|
|
await asyncio.sleep(self.config.udp_copy_interval_ms / 1000)
|
|
|
+ if not sent_any:
|
|
|
+ udp_server.note_unsent(flow, packet_id)
|
|
|
|
|
|
async def _handshake(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer) -> tuple[str, int, bool]:
|
|
|
version, methods_len = (await read_exact(reader, 2))
|
|
|
@@ -440,6 +557,7 @@ class SocksEdge:
|
|
|
return host, port, False
|
|
|
if command == 3 and self.udp_server and self.udp_server.transport:
|
|
|
bind_host, bind_port = self.udp_server.transport.get_extra_info("sockname")[:2]
|
|
|
+ self.udp_server.register_associate(peer)
|
|
|
print(f"[edge] socks handshake peer={peer_text} command=udp_associate target={host}:{port} bind={bind_host}:{bind_port}")
|
|
|
writer.write(b"\x05\x00\x00\x01" + socket.inet_aton(bind_host) + struct.pack("!H", bind_port))
|
|
|
await writer.drain()
|