| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589 |
- from __future__ import annotations
- import asyncio
- import contextlib
- import itertools
- import socket
- import struct
- from dataclasses import dataclass, field
- from typing import Dict
- from .config import Config, RelayNode
- from .protocol import AUTH, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, decode_json, encode_json, read_frame, write_frame
- from .scheduler import Scheduler
# SOCKS protocol version byte; SocksEdge._handshake rejects any other value.
SOCKS_VERSION = 5
async def read_exact(reader: asyncio.StreamReader, size: int) -> bytes:
    """Read exactly ``size`` bytes from ``reader``.

    Thin wrapper over ``StreamReader.readexactly``; that call raises
    ``asyncio.IncompleteReadError`` when the stream ends early, and the
    error is propagated unchanged.
    """
    data = await reader.readexactly(size)
    return data
@dataclass(eq=False)
class RelayLink:
    """One authenticated stream connection to a relay node.

    Frames for every TCP race session and for the UDP associate server are
    multiplexed over this single connection.  ``eq=False`` keeps the default
    identity-based comparison, which the rest of the file relies on when it
    compares links with ``is``.

    ``reader``/``writer`` may be ``None`` until the owner has connected the
    link (``SocksEdge._connect_relays`` constructs links that way), so every
    method here tolerates a missing writer.
    """

    node: RelayNode
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    # Background task running ``_pump``; created by ``start``.
    pump: asyncio.Task | None = None
    # Set when the link is closed, cleared again by ``start``.
    closed_event: asyncio.Event = field(default_factory=asyncio.Event)
    maintain_task: asyncio.Task | None = None
    # Keyed by (session_id, stream_id) of incoming frames.
    tcp_sessions: Dict[tuple[int, int], "TcpRaceSession"] = field(default_factory=dict)
    udp_server: "UdpAssociateServer | None" = None
    closed: bool = False

    async def start(self) -> None:
        """Authenticate with the relay and start the frame pump.

        Raises:
            ConnectionError: if the relay does not answer with an ``AUTH``
                frame carrying ``STATUS_OK``.
        """
        await write_frame(self.writer, Frame(AUTH, 0, 0, 0, 0, encode_json({"token": self.node.token})))
        frame = await read_frame(self.reader)
        if frame.kind != AUTH or frame.packet_id != STATUS_OK:
            raise ConnectionError(f"relay auth failed: {self.node.name}")
        self.closed = False
        self.closed_event.clear()
        self.pump = asyncio.create_task(self._pump())

    async def _pump(self) -> None:
        """Read frames until the connection fails, then close the link."""
        try:
            while True:
                frame = await read_frame(self.reader)
                await self._dispatch(frame)
        except Exception:
            # Any read failure (EOF, reset, malformed frame) is treated as a
            # disconnect; the ``finally`` below performs the actual cleanup.
            pass
        finally:
            await self.close()

    async def _dispatch(self, frame: Frame) -> None:
        """Route one frame to its TCP session or to the UDP server.

        Handler failures are contained here: an error caused by a single
        session (for example its local client socket being reset) must not
        take down the relay connection shared by every other session.
        """
        if frame.kind in (TCP_STATUS, TCP_DATA, TCP_CLOSE):
            session = self.tcp_sessions.get((frame.session_id, frame.stream_id))
            if session is None:
                return
            try:
                await session.handle_frame(self, frame)
            except Exception as exc:
                print(f"[edge] relay dispatch error relay={self.node.name} error={exc!r}")
                # The session is no longer usable; release it so later frames
                # for the same key are simply dropped.
                with contextlib.suppress(Exception):
                    await session.close()
        elif frame.kind == UDP_RECV and self.udp_server:
            try:
                await self.udp_server.handle_from_relay(frame, self)
            except Exception as exc:
                print(f"[edge] relay dispatch error relay={self.node.name} error={exc!r}")

    async def send(self, frame: Frame) -> None:
        """Write ``frame`` to the relay.

        Raises:
            ConnectionError: if the link is closed/unconnected or the write
                fails (the link is closed in that case).
            asyncio.CancelledError: re-raised unchanged when the calling task
                is cancelled, so cancellation is never swallowed.
        """
        if self.closed or self.writer is None:
            raise ConnectionError(f"relay closed: {self.node.name}")
        try:
            await write_frame(self.writer, frame)
        except asyncio.CancelledError:
            # The original code closed the link on cancellation, presumably
            # because an interrupted write_frame may leave a partial frame on
            # the wire.  Keep that, but let the cancellation propagate.
            await self.close()
            raise
        except (OSError, RuntimeError) as exc:
            # BrokenPipeError / ConnectionResetError are OSError subclasses.
            await self.close()
            raise ConnectionError(f"relay closed: {self.node.name}") from exc

    async def close(self) -> None:
        """Close the link once: stop the pump and close the writer.

        Safe to call repeatedly, from inside the pump task itself, and on a
        link that was never connected.
        """
        if self.closed:
            return
        self.closed = True
        self.closed_event.set()
        pump = self.pump
        try:
            if pump is not None and pump is not asyncio.current_task():
                pump.cancel()
                # Awaiting a cancelled task raises CancelledError, which is a
                # BaseException and would escape ``suppress(Exception)``.
                # gather(return_exceptions=True) collects it instead, while
                # still letting a cancellation of *this* task propagate.
                await asyncio.gather(pump, return_exceptions=True)
        finally:
            writer = self.writer
            if writer is not None:
                writer.close()
                with contextlib.suppress(Exception):
                    await writer.wait_closed()
@dataclass
class UdpFlowState:
    """Mutable bookkeeping for one UDP flow (client address + target).

    Instances are created by ``UdpAssociateServer.datagram_received`` and
    updated by the forwarding code in ``SocksEdge``.
    """

    # Identity of the flow.
    flow_id: int
    client_addr: tuple[str, int]
    target_host: str
    target_port: int

    # Event-loop timestamps.
    created_at: float
    last_activity: float

    # Counters shown in the periodic summary log line.
    packets_sent: int = 0
    packets_received: int = 0
    duplicate_responses: int = 0

    # Name of the first path that delivered a response, if any.
    winner_name: str | None = None
    candidate_names: tuple[str, ...] = ()

    # Relay paths: link name -> stream id, plus the links that already got
    # the target metadata.
    link_streams: dict[str, int] = field(default_factory=dict)
    initialized_links: set[str] = field(default_factory=set)

    # Direct paths: path name -> socket / receive task, and failed path names.
    direct_sockets: dict[str, socket.socket] = field(default_factory=dict)
    direct_tasks: dict[str, asyncio.Task] = field(default_factory=dict)
    direct_failures: set[str] = field(default_factory=set)

    # Per-path failure counts and the paths whose first error was logged.
    relay_failures: dict[str, int] = field(default_factory=dict)
    relay_error_seen: set[str] = field(default_factory=set)

    def touch(self, now: float) -> None:
        """Record ``now`` as the time of the most recent activity."""
        self.last_activity = now
@dataclass
class TcpRaceSession:
    """One local TCP connection raced across several relay links.

    ``start`` opens the target on every link.  Early uplink data (up to
    ``warmup_bytes``) is sent to all links; the first link that returns data
    becomes the winner and the remaining links are told to close.
    """

    session_id: int
    stream_id: int
    target_host: str
    target_port: int
    local_reader: asyncio.StreamReader
    local_writer: asyncio.StreamWriter
    links: list[RelayLink]
    warmup_bytes: int
    winning_link: RelayLink | None = None
    winner_name: str | None = None
    # Number of links that reported a successful open.
    opened: int = 0
    open_errors: list[str] = field(default_factory=list)
    uplink_bytes: int = 0
    closed: bool = False
    # Set once at least one link opened, or every link reported an error.
    open_event: asyncio.Event = field(default_factory=asyncio.Event)
    winner_event: asyncio.Event = field(default_factory=asyncio.Event)
    pump_task: asyncio.Task | None = None
    win_counts: Dict[str, int] = field(default_factory=dict)

    async def start(self) -> None:
        """Open the target on every link and start pumping local data.

        Raises:
            ConnectionError: if a link cannot be written to, or no link
                reports a successful open.
            asyncio.TimeoutError: if no link answers within 10 seconds.

        On any failure the session is deregistered from all links (and the
        links are asked to close the stream) so a failed open does not leave
        a stale entry behind.  The local writer is left for the caller.
        """
        meta = encode_json({"host": self.target_host, "port": self.target_port})
        key = (self.session_id, self.stream_id)
        try:
            for link in self.links:
                link.tcp_sessions[key] = self
                await link.send(Frame(TCP_OPEN, self.session_id, self.stream_id, 0, 0, meta))
            await asyncio.wait_for(self.open_event.wait(), timeout=10)
            if self.opened == 0:
                raise ConnectionError(self.open_errors[0] if self.open_errors else "all relays failed")
        except BaseException:
            try:
                await self._send_close_to_links()
            finally:
                self._deregister()
            raise
        self.pump_task = asyncio.create_task(self._pump_local())

    async def _pump_local(self) -> None:
        """Forward data read from the local client to the relay link(s)."""
        try:
            while True:
                chunk = await self.local_reader.read(65536)
                if not chunk:
                    break
                self.uplink_bytes += len(chunk)
                if self.winning_link is None and self.uplink_bytes <= self.warmup_bytes:
                    # Warm-up phase: duplicate the data to every live link.
                    await asyncio.gather(*(link.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk)) for link in self.links if not link.closed), return_exceptions=True)
                else:
                    if self.winning_link is None:
                        await self.winner_event.wait()
                    if self.winning_link:
                        await self.winning_link.send(Frame(TCP_DATA, self.session_id, self.stream_id, 0, 0, chunk))
        except Exception:
            # Any local read or relay write failure ends the session; the
            # ``finally`` below does the cleanup.
            pass
        finally:
            await self.close()

    async def handle_frame(self, link: RelayLink, frame: Frame) -> None:
        """Process a ``TCP_STATUS``/``TCP_DATA``/``TCP_CLOSE`` frame from ``link``."""
        if self.closed:
            return
        if frame.kind == TCP_STATUS:
            if frame.packet_id == STATUS_OK:
                self.opened += 1
            else:
                self.open_errors.append(frame.payload.decode("utf-8", errors="replace"))
            if self.opened > 0 or len(self.open_errors) == len(self.links):
                self.open_event.set()
            return
        if frame.kind == TCP_DATA:
            if self.winning_link is None:
                # First link to deliver data wins the race.
                self.winning_link = link
                self.winner_name = link.node.name
                self.win_counts[link.node.name] = self.win_counts.get(link.node.name, 0) + 1
                node_total = self.win_counts[link.node.name]
                relay_detail = ", ".join(f"{name}={count}" for name, count in sorted(self.win_counts.items())) or "none"
                print(f"[edge] tcp win session={self.session_id} target={self.target_host}:{self.target_port} winner={link.node.name} node_total={node_total} win_breakdown={relay_detail}")
                self.winner_event.set()
                await self._close_losers(except_link=link)
            if link is self.winning_link:
                # Data from losing links is dropped.
                self.local_writer.write(frame.payload)
                await self.local_writer.drain()
            return
        if frame.kind == TCP_CLOSE:
            if self.winning_link is None:
                self.winning_link = link
                self.winner_event.set()
            if link is self.winning_link:
                await self.close()

    async def _close_losers(self, except_link: RelayLink) -> None:
        """Ask every live link other than ``except_link`` to close the stream."""
        await asyncio.gather(*(link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b"")) for link in self.links if link is not except_link and not link.closed), return_exceptions=True)

    async def _send_close_to_links(self) -> None:
        """Best-effort ``TCP_CLOSE`` to every live link; send errors are ignored."""
        await asyncio.gather(*(link.send(Frame(TCP_CLOSE, self.session_id, self.stream_id, 0, 0, b"")) for link in self.links if not link.closed), return_exceptions=True)

    def _deregister(self) -> None:
        """Remove this session from every link's dispatch table."""
        key = (self.session_id, self.stream_id)
        for link in self.links:
            link.tcp_sessions.pop(key, None)

    async def close(self) -> None:
        """Tear the session down once: stop the pump, notify links, close the client.

        Safe to call repeatedly and from inside the pump task itself.
        """
        if self.closed:
            return
        self.closed = True
        pump = self.pump_task
        try:
            if pump is not None and pump is not asyncio.current_task():
                pump.cancel()
                # Awaiting the cancelled task directly raises CancelledError
                # (a BaseException), which would abort the cleanup below;
                # gather(return_exceptions=True) absorbs it.
                await asyncio.gather(pump, return_exceptions=True)
            await self._send_close_to_links()
        finally:
            # Synchronous cleanup always runs, even if this task is cancelled.
            self._deregister()
            self.local_writer.close()
        with contextlib.suppress(Exception):
            await self.local_writer.wait_closed()
class UdpAssociateServer(asyncio.DatagramProtocol):
    """Datagram endpoint for SOCKS5 UDP ASSOCIATE traffic.

    Datagrams from the client are parsed, attributed to a ``UdpFlowState``
    and handed to ``SocksEdge.forward_udp``.  Responses arriving from relays
    or direct sockets are wrapped in a SOCKS5 UDP header and sent back to the
    client; only the first responding path per flow (the "winner") is
    relayed, later responses from other paths are counted as duplicates.
    """

    def __init__(self, edge: "SocksEdge") -> None:
        self.edge = edge
        self.transport: asyncio.DatagramTransport | None = None
        self.client_addr = None
        self.associate_peer = None
        self.packet_counter = itertools.count(1)
        # Keyed by ((client_ip, client_port), target_host, target_port).
        self.client_flows: dict[tuple[tuple[str, int], str, int], UdpFlowState] = {}
        self.flow_counter = itertools.count(1)
        self.last_summary_at = 0.0
        self.win_counts: Dict[str, int] = {}
        self.relay_error_counts: Dict[str, int] = {}
        # Strong references to in-flight forward tasks; the event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._forward_tasks: set[asyncio.Task] = set()

    def connection_made(self, transport) -> None:
        self.transport = transport

    def register_associate(self, peer) -> None:
        """Remember (and log once per change) the peer that issued UDP ASSOCIATE."""
        peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer)
        if self.associate_peer != peer_text:
            print(f"[edge] udp associate peer={peer_text}")
        self.associate_peer = peer_text

    def datagram_received(self, data: bytes, addr) -> None:
        """Handle one client datagram: parse it, track its flow, forward it.

        Malformed, truncated, fragmented or otherwise unsupported datagrams
        are dropped silently *before* any client binding is touched, so junk
        from an unrelated address cannot reset the current client's flows.
        """
        try:
            host, port, payload = self._parse_socks_udp(data)
        except ValueError:
            return
        if self.client_addr is None:
            self.client_addr = addr
            print(f"[edge] udp client bound addr={addr[0]}:{addr[1]}")
        elif addr != self.client_addr:
            print(f"[edge] udp client rebound old={self.client_addr[0]}:{self.client_addr[1]} new={addr[0]}:{addr[1]}")
            self._reset_client_state(addr)
        now = asyncio.get_running_loop().time()
        flow_key = ((addr[0], addr[1]), host, port)
        flow = self.client_flows.get(flow_key)
        if flow is None:
            flow = UdpFlowState(
                flow_id=next(self.flow_counter),
                client_addr=(addr[0], addr[1]),
                target_host=host,
                target_port=port,
                created_at=now,
                last_activity=now,
            )
            self.client_flows[flow_key] = flow
        flow.touch(now)
        flow.packets_sent += 1
        packet_id = next(self.packet_counter)
        task = asyncio.create_task(self.edge.forward_udp(flow, payload, packet_id, self))
        self._forward_tasks.add(task)
        task.add_done_callback(self._forward_tasks.discard)
        self._log_udp_summary()

    def _reset_client_state(self, addr) -> None:
        """Drop every flow of the previous client and bind to ``addr``."""
        for flow in list(self.client_flows.values()):
            for task in list(flow.direct_tasks.values()):
                task.cancel()
            for sock in list(flow.direct_sockets.values()):
                with contextlib.suppress(Exception):
                    sock.close()
            for stream_id in list(flow.link_streams.values()):
                self.edge.udp_flow_sessions.pop((flow.flow_id, stream_id), None)
        self.client_flows.clear()
        self.client_addr = addr
        self.win_counts.clear()
        print(f"[edge] udp client bound addr={addr[0]}:{addr[1]}")

    async def handle_from_relay(self, frame: Frame, link: RelayLink) -> None:
        """Deliver a ``UDP_RECV`` frame from ``link`` to the flow it belongs to."""
        if self.transport is None or self.client_addr is None:
            return
        flow = self.edge.udp_flow_sessions.get((frame.session_id, frame.stream_id))
        if flow is None:
            return
        await self._deliver_flow_packet(flow, frame.packet_id, frame.payload, link.node.name)

    async def handle_from_direct(self, flow: UdpFlowState, path_name: str, payload: bytes) -> None:
        """Deliver a datagram received on one of the flow's direct sockets."""
        if self.transport is None or self.client_addr is None:
            return
        await self._deliver_flow_packet(flow, 0, payload, path_name)

    async def _deliver_flow_packet(self, flow: UdpFlowState, packet_id: int, payload: bytes, source_name: str) -> None:
        """Send ``payload`` to the client if ``source_name`` is the flow's winner.

        The first source to respond becomes the winner; responses from any
        other source are only counted as duplicates.
        """
        if self.transport is None or self.client_addr is None:
            return
        packet = self._build_socks_udp(flow.target_host, flow.target_port, payload)
        flow.touch(asyncio.get_running_loop().time())
        flow.packets_received += 1
        if flow.winner_name is None:
            flow.winner_name = source_name
            self.win_counts[source_name] = self.win_counts.get(source_name, 0) + 1
            self._log_udp_summary(force=True)
        elif flow.winner_name != source_name:
            flow.duplicate_responses += 1
        if flow.winner_name == source_name:
            self.transport.sendto(packet, self.client_addr)

    def set_flow_candidates(self, flow: UdpFlowState, candidate_names: tuple[str, ...]) -> None:
        """Record the candidate path names the first time they are known."""
        if not flow.candidate_names:
            flow.candidate_names = candidate_names

    def note_unsent(self, flow: UdpFlowState, packet_id: int) -> None:
        """Count a packet that could not be sent on any path."""
        flow.touch(asyncio.get_running_loop().time())
        flow.relay_failures["unsent"] = flow.relay_failures.get("unsent", 0) + 1
        self._log_udp_summary(force=True)

    def _log_udp_summary(self, force: bool = False) -> None:
        """Print a one-line summary, at most every 10 seconds unless ``force``."""
        now = asyncio.get_running_loop().time()
        if not force and now - self.last_summary_at < 10:
            return
        self.last_summary_at = now
        flows = list(self.client_flows.values())
        active_flows = len(flows)
        winners = sum(1 for flow in flows if flow.winner_name)
        packets_sent = sum(flow.packets_sent for flow in flows)
        packets_received = sum(flow.packets_received for flow in flows)
        duplicates = sum(flow.duplicate_responses for flow in flows)
        direct_paths = sum(len(flow.direct_sockets) for flow in flows)
        relay_candidates = sum(len(flow.link_streams) for flow in flows)
        winner_detail = ", ".join(f"{flow.flow_id}:{flow.winner_name}" for flow in flows if flow.winner_name) or "none"
        relay_errors = [f"{name}={count}" for flow in flows for name, count in flow.relay_failures.items()]
        relay_error_detail = ", ".join(sorted(relay_errors)) or "none"
        bind = f"{self.client_addr[0]}:{self.client_addr[1]}" if self.client_addr else "unbound"
        print(
            f"[edge] udp summary bind={bind} active_flows={active_flows} "
            f"winner_flows={winners} winner_detail={winner_detail} packets_sent={packets_sent} packets_received={packets_received} dup={duplicates} "
            f"direct_paths={direct_paths} relay_paths={relay_candidates} relay_errors={relay_error_detail}"
        )

    def _parse_socks_udp(self, packet: bytes) -> tuple[str, int, bytes]:
        """Split a SOCKS5 UDP request into ``(host, port, payload)``.

        Layout (RFC 1928 section 7): RSV(2) FRAG(1) ATYP(1) DST.ADDR DST.PORT DATA.
        ATYP 1 (IPv4), 3 (domain name) and 4 (IPv6) are supported.

        Raises:
            ValueError: if the header is truncated, the datagram is a
                fragment (FRAG != 0; reassembly is not implemented), the
                address type is unknown, or a domain name is not valid text.
        """
        if len(packet) < 4:
            raise ValueError("truncated udp header")
        if packet[2] != 0:
            raise ValueError("fragmented udp datagrams are not supported")
        atyp = packet[3]
        offset = 4
        if atyp == 1:
            addr_len = 4
        elif atyp == 4:
            addr_len = 16
        elif atyp == 3:
            if len(packet) < 5:
                raise ValueError("truncated udp header")
            addr_len = packet[4]
            offset = 5
        else:
            raise ValueError("unsupported udp atyp")
        end = offset + addr_len
        if len(packet) < end + 2:
            raise ValueError("truncated udp header")
        raw = packet[offset:end]
        if atyp == 1:
            host = socket.inet_ntoa(raw)
        elif atyp == 4:
            host = socket.inet_ntop(socket.AF_INET6, raw)
        else:
            host = raw.decode()  # UnicodeDecodeError is a ValueError
        port = struct.unpack_from("!H", packet, end)[0]
        return host, port, packet[end + 2:]

    def _build_socks_udp(self, host: str, port: int, payload: bytes) -> bytes:
        """Wrap ``payload`` in a SOCKS5 UDP header addressed from ``host:port``.

        ``host`` is encoded as IPv4 (ATYP 1) or IPv6 (ATYP 4) when it is an
        address literal, otherwise as a domain name (ATYP 3).
        """
        port_bytes = struct.pack("!H", port)
        try:
            return b"\x00\x00\x00\x01" + socket.inet_aton(host) + port_bytes + payload
        except (OSError, ValueError):
            pass
        try:
            return b"\x00\x00\x00\x04" + socket.inet_pton(socket.AF_INET6, host) + port_bytes + payload
        except (OSError, ValueError):
            pass
        raw = host.encode()
        return b"\x00\x00\x00\x03" + bytes([len(raw)]) + raw + port_bytes + payload
class SocksEdge:
    """SOCKS5 front end that forwards UDP traffic over direct sockets and relays.

    The TCP listener only performs the SOCKS5 handshake; CONNECT is rejected
    and UDP ASSOCIATE points the client at one shared datagram endpoint
    (``UdpAssociateServer``).  Each client datagram is then sent over the
    flow's direct sockets and/or relay links by ``forward_udp``.
    """

    def __init__(self, listen_host: str, listen_port: int, config: Config) -> None:
        self.listen_host = listen_host
        self.listen_port = listen_port
        self.config = config
        self.scheduler = Scheduler(config)
        self.links: list[RelayLink] = []
        self.session_ids = itertools.count(1)
        self.udp_stream_ids = itertools.count(1)
        # (flow_id, stream_id) -> flow, used to route UDP_RECV frames.
        self.udp_flow_sessions: dict[tuple[int, int], UdpFlowState] = {}
        self.udp_server: UdpAssociateServer | None = None
        self.udp_transport: asyncio.DatagramTransport | None = None

    async def start(self) -> None:
        """Start the scheduler, relay links and SOCKS5 listener; serve forever."""
        await self.scheduler.start()
        await self._connect_relays()
        server = await asyncio.start_server(self._accept, self.listen_host, self.listen_port)
        sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
        print(f"[edge] socks5 listening on {sockets}")
        async with server:
            await server.serve_forever()

    async def _connect_relays(self) -> None:
        """Create the UDP endpoint and one self-reconnecting link per relay."""
        loop = asyncio.get_running_loop()
        transport, protocol = await loop.create_datagram_endpoint(lambda: UdpAssociateServer(self), local_addr=(self.listen_host, 0))
        self.udp_server = protocol
        self.udp_transport = transport
        for node in self.config.relays:
            # A link starts out closed: it has no writer until _maintain_link
            # connects it, and RelayLink.start() flips ``closed`` to False
            # once authentication succeeds.  This keeps _selected_links from
            # picking a link that was never connected.
            link = RelayLink(node=node, reader=None, writer=None, closed=True)  # type: ignore[arg-type]
            link.udp_server = protocol
            self.links.append(link)
            link.maintain_task = asyncio.create_task(self._maintain_link(link))

    async def _maintain_link(self, link: RelayLink) -> None:
        """Keep ``link`` connected, retrying with exponential backoff (1s..10s)."""
        backoff = 1.0
        while True:
            writer = None
            try:
                reader, writer = await asyncio.open_connection(link.node.host, link.node.port)
                sock = writer.get_extra_info("socket")
                if sock is not None:
                    with contextlib.suppress(OSError):
                        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                link.reader = reader
                link.writer = writer
                await link.start()
                backoff = 1.0
                await link.closed_event.wait()
            except asyncio.CancelledError:
                raise
            except Exception:
                # Reaching here with a writer means the connection was made
                # but start() (authentication) failed; close it so the socket
                # is not leaked on every retry.
                if writer is not None:
                    with contextlib.suppress(Exception):
                        writer.close()
                await asyncio.sleep(backoff)
                backoff = min(10.0, backoff * 2)

    async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Serve one SOCKS5 control connection.

        After a successful UDP ASSOCIATE the TCP connection is kept open and
        drained until the client closes it (RFC 1928 ties the association's
        lifetime to this connection).  The writer is always closed on exit so
        the descriptor is released.
        """
        try:
            peer = writer.get_extra_info("peername")
            _host, _port, udp_mode = await self._handshake(reader, writer, peer)
            if udp_mode:
                # Nothing is expected on the control connection; discard any
                # bytes and wait for EOF.
                while await reader.read(4096):
                    pass
        except Exception:
            # Handshake failures and connection errors just end this client.
            pass
        finally:
            writer.close()
            with contextlib.suppress(Exception):
                await writer.wait_closed()

    def _selected_links(self) -> list[RelayLink]:
        """Return the open links chosen by the scheduler, or one open link as fallback."""
        chosen = {node.name for node in self.scheduler.choose()}
        links = [link for link in self.links if link.node.name in chosen and not link.closed]
        return links or [link for link in self.links if not link.closed][:1]

    def _selected_udp_links(self) -> list[RelayLink]:
        """Return every connected link ordered by scheduler score (lowest first).

        Links without a score entry sort last.
        """
        online = [link for link in self.links if not link.closed and link.writer is not None]
        scores = self.scheduler.scores

        def rank(link: RelayLink) -> float:
            entry = scores.get(link.node.name)
            return entry.score if entry is not None else 999999.0

        return sorted(online, key=rank)

    def _udp_direct_redundancy_for_target(self, target_host: str) -> int:
        """Return how many direct sockets to open for ``target_host`` (at least 1).

        The per-family override is used when configured; a host containing
        ``:`` is treated as IPv6.
        """
        base = self.config.udp_direct_redundancy
        if ":" in target_host and self.config.udp_direct_redundancy_v6 is not None:
            base = self.config.udp_direct_redundancy_v6
        elif ":" not in target_host and self.config.udp_direct_redundancy_v4 is not None:
            base = self.config.udp_direct_redundancy_v4
        return max(1, base)

    async def _ensure_udp_direct_paths(self, flow: UdpFlowState, udp_server: UdpAssociateServer) -> None:
        """Open the flow's missing direct UDP sockets and start their receive pumps.

        A path whose open fails is recorded in ``flow.direct_failures`` and
        not retried.
        """
        loop = asyncio.get_running_loop()
        target_count = self._udp_direct_redundancy_for_target(flow.target_host)
        for index in range(target_count):
            name = f"direct-{index + 1}" if target_count > 1 else "direct"
            if name in flow.direct_sockets or name in flow.direct_failures:
                continue
            sock = None
            try:
                family = socket.AF_INET6 if ":" in flow.target_host else socket.AF_INET
                sock = socket.socket(family, socket.SOCK_DGRAM)
                sock.setblocking(False)
                await loop.sock_connect(sock, (flow.target_host, flow.target_port))
            except Exception as exc:
                if sock is not None:
                    with contextlib.suppress(Exception):
                        sock.close()
                flow.direct_failures.add(name)
                print(f"[edge] udp direct open error flow={flow.flow_id} path={name} target={flow.target_host}:{flow.target_port} error={exc!r}")
                continue
            except BaseException:
                # Cancelled mid-connect: do not leak the descriptor.
                if sock is not None:
                    with contextlib.suppress(Exception):
                        sock.close()
                raise
            if name in flow.direct_sockets:
                # forward_udp runs as one task per datagram; another task may
                # have opened this path while we awaited sock_connect.
                with contextlib.suppress(Exception):
                    sock.close()
                continue
            flow.direct_sockets[name] = sock
            flow.direct_tasks[name] = asyncio.create_task(self._pump_udp_direct(flow, name, sock, udp_server))

    async def _pump_udp_direct(self, flow: UdpFlowState, path_name: str, sock: socket.socket, udp_server: UdpAssociateServer) -> None:
        """Forward datagrams received on ``sock`` to the client until it fails."""
        loop = asyncio.get_running_loop()
        try:
            while True:
                data = await loop.sock_recv(sock, 65535)
                # NOTE(review): an empty datagram ends the pump; UDP has no
                # EOF, so this also fires on a legitimate zero-length reply.
                if not data:
                    break
                await udp_server.handle_from_direct(flow, path_name, data)
        except Exception:
            # Receive errors simply end this path; cleanup happens below.
            pass
        finally:
            # Only remove the entries that still belong to this pump, so a
            # replacement socket registered under the same name survives.
            if flow.direct_tasks.get(path_name) is asyncio.current_task():
                flow.direct_tasks.pop(path_name, None)
            if flow.direct_sockets.get(path_name) is sock:
                flow.direct_sockets.pop(path_name, None)
            with contextlib.suppress(Exception):
                sock.close()

    def _record_udp_failure(self, flow: UdpFlowState, name: str, exc: Exception) -> None:
        """Count a send failure on path ``name`` and log its first occurrence."""
        flow.relay_failures[name] = flow.relay_failures.get(name, 0) + 1
        if name not in flow.relay_error_seen:
            flow.relay_error_seen.add(name)
            print(
                f"[edge] udp relay error flow={flow.flow_id} relay={name} error={exc!r}"
            )

    async def _send_udp_direct(self, flow: UdpFlowState, path_name: str, payload: bytes) -> bool:
        """Send ``payload`` on one direct socket; retire the path on failure."""
        sock = flow.direct_sockets.get(path_name)
        if sock is None:
            return False
        try:
            await asyncio.get_running_loop().sock_sendall(sock, payload)
        except Exception as exc:
            flow.direct_failures.add(path_name)
            flow.direct_sockets.pop(path_name, None)
            task = flow.direct_tasks.pop(path_name, None)
            if task is not None:
                task.cancel()
            with contextlib.suppress(Exception):
                sock.close()
            self._record_udp_failure(flow, path_name, exc)
            return False
        return True

    async def _send_udp_relay(self, flow: UdpFlowState, link: RelayLink, meta: bytes, payload: bytes) -> bool:
        """Send ``payload`` through ``link``, prefixing target metadata on first use."""
        name = link.node.name
        stream_id = flow.link_streams.get(name)
        if stream_id is None:
            stream_id = next(self.udp_stream_ids)
            flow.link_streams[name] = stream_id
            self.udp_flow_sessions[(flow.flow_id, stream_id)] = flow
        include_meta = name not in flow.initialized_links
        body = (meta + payload) if include_meta else payload
        meta_len = len(meta) if include_meta else 0
        try:
            await link.send(Frame(UDP_SEND, flow.flow_id, stream_id, 0, meta_len, body))
        except Exception as exc:
            flow.link_streams.pop(name, None)
            self.udp_flow_sessions.pop((flow.flow_id, stream_id), None)
            # The next attempt allocates a fresh stream id, so it must carry
            # the target metadata again.
            flow.initialized_links.discard(name)
            self._record_udp_failure(flow, name, exc)
            return False
        flow.initialized_links.add(name)
        return True

    async def forward_udp(self, flow: UdpFlowState, payload: bytes, packet_id: int, udp_server: UdpAssociateServer) -> None:
        """Send one client datagram towards the flow's target.

        Until the flow has a winner (or when ``udp_always_broadcast`` is set)
        the payload goes out on every direct socket and every connected
        relay; afterwards only on the winning path, falling back to the first
        available path if the winner is gone.  Each send is repeated
        ``udp_redundancy`` extra times, spaced by ``udp_copy_interval_ms``.
        """
        await self._ensure_udp_direct_paths(flow, udp_server)
        meta = encode_json({"host": flow.target_host, "port": flow.target_port})
        links = self._selected_udp_links()
        direct_names = tuple(sorted(flow.direct_sockets))
        relay_names = tuple(link.node.name for link in links)
        candidate_names = direct_names + relay_names
        udp_server.set_flow_candidates(flow, candidate_names)
        if not candidate_names:
            udp_server.note_unsent(flow, packet_id)
            return
        active_direct_names = list(direct_names)
        active_links = links
        if not (self.config.udp_always_broadcast or flow.winner_name is None):
            active_direct_names = [name for name in active_direct_names if name == flow.winner_name]
            active_links = [link for link in active_links if link.node.name == flow.winner_name]
            if not active_direct_names and not active_links:
                if direct_names:
                    active_direct_names = [direct_names[0]]
                elif links:
                    active_links = links[:1]
        copies = max(1, self.config.udp_redundancy + 1)
        sent_any = False
        for attempt in range(copies):
            for path_name in active_direct_names:
                if await self._send_udp_direct(flow, path_name, payload):
                    sent_any = True
            for link in active_links:
                if await self._send_udp_relay(flow, link, meta, payload):
                    sent_any = True
            if attempt + 1 < copies and self.config.udp_copy_interval_ms > 0:
                await asyncio.sleep(self.config.udp_copy_interval_ms / 1000)
        if not sent_any:
            udp_server.note_unsent(flow, packet_id)

    @staticmethod
    def _pack_bind_address(host: str, port: int) -> bytes:
        """Encode ``host:port`` as SOCKS5 ATYP + BND.ADDR + BND.PORT (IPv4 or IPv6)."""
        try:
            return b"\x01" + socket.inet_aton(host) + struct.pack("!H", port)
        except OSError:
            return b"\x04" + socket.inet_pton(socket.AF_INET6, host) + struct.pack("!H", port)

    async def _handshake(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer) -> tuple[str, int, bool]:
        """Run the SOCKS5 greeting and request exchange.

        Always selects the "no authentication" method.  Only UDP ASSOCIATE
        (command 3) succeeds; it replies with the UDP endpoint's address.

        Returns:
            ``(host, port, True)`` with the address the client put in its
            request.

        Raises:
            ValueError: for a wrong protocol version, an unsupported address
                type, CONNECT (answered with reply code 0x07 first), or any
                other command.
        """
        version, methods_len = (await read_exact(reader, 2))
        if version != SOCKS_VERSION:
            raise ValueError("unsupported socks version")
        await read_exact(reader, methods_len)
        writer.write(b"\x05\x00")
        await writer.drain()
        version, command, _, atyp = await read_exact(reader, 4)
        if version != SOCKS_VERSION:
            raise ValueError("unsupported socks version")
        if atyp == 1:
            host = socket.inet_ntoa(await read_exact(reader, 4))
        elif atyp == 3:
            size = (await read_exact(reader, 1))[0]
            host = (await read_exact(reader, size)).decode()
        elif atyp == 4:
            host = socket.inet_ntop(socket.AF_INET6, await read_exact(reader, 16))
        else:
            raise ValueError("unsupported atyp")
        port = struct.unpack("!H", await read_exact(reader, 2))[0]
        peer_text = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) and len(peer) >= 2 else str(peer)
        if command == 1:
            writer.write(b"\x05\x07\x00\x01\x00\x00\x00\x00\x00\x00")
            await writer.drain()
            raise ValueError("tcp connect disabled in socks udp-only mode")
        if command == 3 and self.udp_server and self.udp_server.transport:
            bind_host, bind_port = self.udp_server.transport.get_extra_info("sockname")[:2]
            self.udp_server.register_associate(peer)
            print(f"[edge] socks handshake peer={peer_text} command=udp_associate target={host}:{port} bind={bind_host}:{bind_port}")
            writer.write(b"\x05\x00\x00" + self._pack_bind_address(bind_host, bind_port))
            await writer.drain()
            return host, port, True
        raise ValueError("unsupported socks command")
|