| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211 |
from __future__ import annotations

import asyncio
import contextlib
import hmac
import json
import time
from dataclasses import dataclass, field
from typing import Dict

from .protocol import (
    AUTH,
    PING,
    PONG,
    STATUS_ERR,
    STATUS_OK,
    TCP_CLOSE,
    TCP_DATA,
    TCP_OPEN,
    TCP_STATUS,
    UDP_RECV,
    UDP_SEND,
    Frame,
    decode_json,
    read_frame,
    write_frame,
)
@dataclass(init=True, repr=True, eq=True)
class TcpSession:
    """Bookkeeping for one relayed TCP stream opened by a TCP_OPEN frame."""

    session_id: int  # client-chosen session id (first half of the lookup key)
    stream_id: int  # client-chosen stream id (second half of the lookup key)
    writer: asyncio.StreamWriter  # writer for the upstream connection
    task: asyncio.Task  # task copying upstream bytes back to the client
@dataclass(init=True, repr=True, eq=True)
class UdpSession:
    """Bookkeeping for one relayed UDP flow created by the first UDP_SEND frame.

    Only the two ids are required; the endpoint and destination fields are
    filled in once the upstream datagram endpoint has been created.
    """

    session_id: int
    stream_id: int
    transport: asyncio.DatagramTransport | None = field(default=None)  # upstream socket
    protocol: "RelayUdpProtocol | None" = field(default=None)  # receives upstream replies
    host: str = field(default="")  # destination host from the frame metadata
    port: int = field(default=0)  # destination port from the frame metadata
    family: int = field(default=0)  # address family; 0 lets the resolver choose
class RelayUdpProtocol(asyncio.DatagramProtocol):
    """Datagram protocol that relays upstream UDP replies back to the client.

    Each datagram received on the upstream endpoint is wrapped in a UDP_RECV
    frame tagged with the session/stream ids of the UDP session that created
    this endpoint. The frame is sent through the owning channel's ``safe_send``.
    """

    def __init__(self, channel: "RelayChannel", session_id: int, stream_id: int) -> None:
        self.channel = channel
        self.session_id = session_id
        self.stream_id = stream_id
        # Strong references to in-flight send tasks. The event loop only keeps
        # weak references to tasks, so a task nobody holds can be
        # garbage-collected before it finishes sending.
        self._pending: set[asyncio.Task] = set()

    def datagram_received(self, data: bytes, _addr) -> None:
        """Schedule forwarding of one upstream datagram; drop it if the channel is closed."""
        if self.channel.closed:
            return
        frame = Frame(UDP_RECV, self.session_id, self.stream_id, 0, 0, data)
        task = asyncio.create_task(self.channel.safe_send(frame))
        self._pending.add(task)
        # Forget the task once it finishes so the set does not grow without bound.
        task.add_done_callback(self._pending.discard)
@dataclass
class RelayChannel:
    """Server side of one client connection to the relay.

    After an AUTH handshake, the client multiplexes TCP and UDP sessions over
    this single stream. Each session is keyed by ``(session_id, stream_id)``.
    :meth:`run` reads frames and dispatches them to :meth:`handle`. Every frame
    sent back to the client goes through :meth:`safe_send`, which serialises
    writes with ``send_lock``.
    """

    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    token: str  # shared secret the AUTH payload's "token" must match
    tcp_sessions: Dict[tuple[int, int], TcpSession] = field(default_factory=dict)
    udp_sessions: Dict[tuple[int, int], UdpSession] = field(default_factory=dict)
    closed: bool = False  # set once by close(); makes safe_send a no-op
    send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    authed_at: float = 0.0  # time.monotonic() when the handshake succeeded
    frame_count: int = 0  # frames read after the handshake (used in log lines)
    authed_kind: str = "normal"  # "purpose" declared in the AUTH payload

    async def run(self) -> None:
        """Authenticate the client, then handle frames until the stream ends.

        When the stream ends mid-read (IncompleteReadError), the session is
        logged only if it lived at least 15 s or handled more than 20 frames.
        Other errors are logged for every authenticated session. Sessions whose
        declared purpose is "probe" are never logged. :meth:`close` always runs
        on exit.
        """
        peer = self.writer.get_extra_info("peername")
        authed = False
        try:
            purpose = await self._authenticate()
            authed = True
            self.authed_at = time.monotonic()
            self.authed_kind = purpose
            ack_payload = {"status": "ok", "kind": self.authed_kind}
            # JSON-encode with the stdlib: no encode_json helper is imported from .protocol.
            ack = json.dumps(ack_payload).encode("utf-8")
            await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, ack))
            while True:
                frame = await read_frame(self.reader)
                self.frame_count += 1
                await self.handle(frame)
        except asyncio.IncompleteReadError:
            if authed and self.authed_kind != "probe":
                lived = time.monotonic() - self.authed_at if self.authed_at else 0.0
                if lived >= 15 or self.frame_count > 20:
                    print(f"[relay] session closed peer={peer} kind={self.authed_kind} lived={lived:.1f}s frames={self.frame_count}")
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            if authed and self.authed_kind != "probe":
                lived = time.monotonic() - self.authed_at if self.authed_at else 0.0
                print(f"[relay] session error peer={peer} kind={self.authed_kind} lived={lived:.1f}s frames={self.frame_count} error={exc!r}")
        finally:
            await self.close()

    async def _authenticate(self) -> str:
        """Read and validate the AUTH frame, returning the client's declared purpose.

        Raises:
            PermissionError: wrong frame kind, undecodable or non-object
                payload, or a token that does not match ``self.token``.
        """
        auth = await read_frame(self.reader)
        if auth.kind != AUTH:
            raise PermissionError("invalid handshake kind")
        try:
            payload = decode_json(auth.payload) if auth.payload else {}
        except Exception as exc:
            raise PermissionError(f"invalid auth payload: {exc!r}") from exc
        if not isinstance(payload, dict):
            raise PermissionError("invalid auth payload: expected an object")
        presented = payload.get("token")
        # Constant-time comparison so response timing does not reveal how much
        # of the token matched; non-string tokens can never match.
        if not isinstance(presented, str) or not hmac.compare_digest(
            presented.encode("utf-8", "surrogatepass"),
            self.token.encode("utf-8", "surrogatepass"),
        ):
            raise PermissionError("invalid token")
        return payload.get("purpose", "normal")

    async def safe_send(self, frame: Frame) -> bool:
        """Write one frame to the client under ``send_lock``.

        Returns True if the frame was written, or False if the channel is
        closed or the write failed with an OS/transport error. Cancellation of
        the calling task is propagated, not swallowed, so a cancelled task
        actually stops.
        """
        if self.closed:
            return False
        try:
            async with self.send_lock:
                # Re-check: close() may have run while we waited for the lock.
                if self.closed:
                    return False
                await write_frame(self.writer, frame)
                return True
        except (OSError, RuntimeError):  # includes BrokenPipeError / ConnectionResetError
            return False

    async def handle(self, frame: Frame) -> None:
        """Dispatch one post-handshake frame by kind; unknown kinds are ignored."""
        key = (frame.session_id, frame.stream_id)
        if frame.kind == PING:
            await self.safe_send(Frame(PONG, 0, 0, frame.seq, 0, b"pong"))
        elif frame.kind == TCP_OPEN:
            await self._open_tcp(frame, key)
        elif frame.kind == TCP_DATA:
            await self._write_tcp(frame, key)
        elif frame.kind == TCP_CLOSE:
            await self._close_tcp(key)
        elif frame.kind == UDP_SEND:
            await self._send_udp(frame, key)

    async def _open_tcp(self, frame: Frame, key: tuple[int, int]) -> None:
        """Connect upstream for a TCP_OPEN frame and answer with TCP_STATUS.

        The payload is JSON with "host", "port" and an optional "family".
        Any failure (undecodable metadata, missing fields, connect error) is
        reported as STATUS_ERR with the error text, instead of ending the
        whole channel.
        """
        try:
            meta = decode_json(frame.payload)
            family = int(meta.get("family") or 0)
            reader, writer = await asyncio.open_connection(meta["host"], int(meta["port"]), family=family)
        except Exception as exc:
            await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
            return
        task = asyncio.create_task(self._tcp_pump(frame.session_id, frame.stream_id, reader))
        self.tcp_sessions[key] = TcpSession(frame.session_id, frame.stream_id, writer, task)
        await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_OK, b"ok"))

    async def _write_tcp(self, frame: Frame, key: tuple[int, int]) -> None:
        """Forward a TCP_DATA payload upstream; close the session if the write fails."""
        session = self.tcp_sessions.get(key)
        if session is None:
            return
        try:
            session.writer.write(frame.payload)
            await session.writer.drain()
        except Exception:
            await self._close_tcp(key)

    async def _send_udp(self, frame: Frame, key: tuple[int, int]) -> None:
        """Forward a UDP_SEND payload, creating the upstream endpoint on first use.

        When ``packet_id`` is positive, the first ``packet_id`` bytes of the
        payload are JSON metadata ("host", "port", optional "family") and the
        rest is the datagram. UDP is best-effort: a datagram is dropped if
        its metadata is malformed, if the session does not exist yet and no
        metadata was given, or if the endpoint cannot be created. In none of
        these cases is the whole channel torn down.
        """
        session = self.udp_sessions.get(key)
        meta = None
        payload = frame.payload
        if frame.packet_id > 0:
            try:
                meta = decode_json(frame.payload[: frame.packet_id])
            except Exception:
                return
            payload = frame.payload[frame.packet_id :]
        if session is None:
            if meta is None:
                return
            try:
                host, port = meta["host"], int(meta["port"])
                family = int(meta.get("family") or 0)
                transport, protocol = await asyncio.get_running_loop().create_datagram_endpoint(
                    lambda: RelayUdpProtocol(self, frame.session_id, frame.stream_id),
                    remote_addr=(host, port),
                    family=family,
                )
            except Exception:
                return
            session = UdpSession(frame.session_id, frame.stream_id, transport, protocol, host, port, family)
            self.udp_sessions[key] = session
        with contextlib.suppress(Exception):
            session.transport.sendto(payload)

    async def _tcp_pump(self, session_id: int, stream_id: int, reader: asyncio.StreamReader) -> None:
        """Copy upstream bytes to the client as TCP_DATA frames until the stream stops.

        The loop stops on EOF, an upstream read error, a failed send, or
        cancellation. On exit, the client is sent TCP_CLOSE (unless the whole
        channel is already closed) and the session is removed.
        """
        try:
            while chunk := await reader.read(65536):
                if not await self.safe_send(Frame(TCP_DATA, session_id, stream_id, 0, 0, chunk)):
                    break
        except (asyncio.CancelledError, Exception):
            # Best effort: cancellation (e.g. from _close_tcp) and upstream read
            # errors both end the stream and are reported as TCP_CLOSE below.
            pass
        finally:
            try:
                if not self.closed:
                    await self.safe_send(Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
            finally:
                # Always drop the session, even if the final send is interrupted.
                await self._close_tcp((session_id, stream_id), from_task=True)

    async def _close_tcp(self, key: tuple[int, int], from_task: bool = False) -> None:
        """Remove a TCP session, stop its pump task, and close the upstream writer.

        ``from_task=True`` means the caller is the session's own pump task,
        which must not cancel or wait for itself. Unknown keys are ignored.
        """
        session = self.tcp_sessions.pop(key, None)
        if session is None:
            return
        if not from_task and session.task is not asyncio.current_task():
            session.task.cancel()
            # asyncio.wait never re-raises the task's outcome. A pump cancelled
            # before its first step would otherwise leak CancelledError out of
            # `await task` (BaseException, not caught by suppress(Exception)).
            # Cancellation of *this* caller still propagates normally.
            await asyncio.wait({session.task})
        session.writer.close()
        with contextlib.suppress(Exception):
            await session.writer.wait_closed()

    async def close(self) -> None:
        """Tear down all TCP/UDP sessions and the client connection (idempotent)."""
        if self.closed:
            return
        self.closed = True
        for key in list(self.tcp_sessions):
            await self._close_tcp(key)
        for session in self.udp_sessions.values():
            if session.transport:
                session.transport.close()
        self.udp_sessions.clear()
        self.writer.close()
        with contextlib.suppress(Exception):
            await self.writer.wait_closed()
class RelayServer:
    """TCP listener that hands every accepted connection to its own RelayChannel."""

    def __init__(self, token: str) -> None:
        # Shared secret passed to each RelayChannel for the AUTH check.
        self.token = token

    async def start(self, host: str, port: int) -> None:
        """Bind to ``host:port``, print the bound addresses, and serve until cancelled."""
        server = await asyncio.start_server(self._accept, host, port)
        bound = [str(sock.getsockname()) for sock in (server.sockets or [])]
        listing = ", ".join(bound)
        print(f"[relay] listening on {listing}")
        async with server:
            await server.serve_forever()

    async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Drive one client connection to completion."""
        channel = RelayChannel(reader, writer, self.token)
        await channel.run()
|