relay_server.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
from __future__ import annotations

import asyncio
import contextlib
import hmac
import time
from dataclasses import dataclass, field
from typing import Dict

from .protocol import (
    AUTH,
    PING,
    PONG,
    STATUS_ERR,
    STATUS_OK,
    TCP_CLOSE,
    TCP_DATA,
    TCP_OPEN,
    TCP_STATUS,
    UDP_RECV,
    UDP_SEND,
    Frame,
    decode_json,
    encode_json,
    read_frame,
    write_frame,
)
  8. @dataclass
  9. class TcpSession:
  10. session_id: int
  11. stream_id: int
  12. writer: asyncio.StreamWriter
  13. task: asyncio.Task
  14. @dataclass
  15. class UdpSession:
  16. session_id: int
  17. stream_id: int
  18. transport: asyncio.DatagramTransport | None = None
  19. protocol: "RelayUdpProtocol | None" = None
  20. host: str = ""
  21. port: int = 0
  22. family: int = 0
  23. created_at: float = 0.0
  24. last_seen_at: float = 0.0
  25. packets_sent: int = 0
  26. packets_received: int = 0
  27. first_seen_logged: bool = False
  28. class RelayUdpProtocol(asyncio.DatagramProtocol):
  29. def __init__(self, channel: "RelayChannel", session_id: int, stream_id: int) -> None:
  30. self.channel = channel
  31. self.session_id = session_id
  32. self.stream_id = stream_id
  33. def datagram_received(self, data: bytes, _addr) -> None:
  34. if self.channel.closed:
  35. return
  36. session = self.channel.udp_sessions.get((self.session_id, self.stream_id))
  37. if session is not None:
  38. session.packets_received += 1
  39. session.last_seen_at = time.monotonic()
  40. if not session.first_seen_logged:
  41. session.first_seen_logged = True
  42. lived = session.last_seen_at - session.created_at if session.created_at else 0.0
  43. print(
  44. f"[relay] udp session first_packet peer={self.channel.writer.get_extra_info('peername')} "
  45. f"target={session.host}:{session.port} session={session.session_id} stream={session.stream_id} "
  46. f"family={session.family} lived={lived:.1f}s recv={session.packets_received} sent={session.packets_sent}"
  47. )
  48. asyncio.create_task(self.channel.safe_send(Frame(UDP_RECV, self.session_id, self.stream_id, 0, 0, data)))
  49. @dataclass
  50. class RelayChannel:
  51. reader: asyncio.StreamReader
  52. writer: asyncio.StreamWriter
  53. token: str
  54. tcp_sessions: Dict[tuple[int, int], TcpSession] = field(default_factory=dict)
  55. udp_sessions: Dict[tuple[int, int], UdpSession] = field(default_factory=dict)
  56. closed: bool = False
  57. send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
  58. authed_at: float = 0.0
  59. frame_count: int = 0
  60. authed_kind: str = "normal"
  61. udp_only: bool = False
  62. async def run(self) -> None:
  63. peer = self.writer.get_extra_info("peername")
  64. authed = False
  65. try:
  66. auth = await read_frame(self.reader)
  67. if auth.kind != AUTH:
  68. return
  69. try:
  70. payload = decode_json(auth.payload) if auth.payload else {}
  71. except Exception:
  72. return
  73. if payload.get("token") != self.token:
  74. return
  75. authed = True
  76. self.authed_at = time.monotonic()
  77. self.authed_kind = payload.get("purpose", "normal")
  78. ack_payload = {"status": "ok", "kind": self.authed_kind, "udp_only": self.udp_only}
  79. await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, encode_json(ack_payload)))
  80. while True:
  81. frame = await read_frame(self.reader)
  82. self.frame_count += 1
  83. await self.handle(frame)
  84. except asyncio.IncompleteReadError:
  85. if authed and self.authed_kind != "probe":
  86. lived = time.monotonic() - self.authed_at if self.authed_at else 0.0
  87. if lived >= 15 or self.frame_count > 20:
  88. print(f"[relay] session closed peer={peer} kind={self.authed_kind} lived={lived:.1f}s frames={self.frame_count}")
  89. except asyncio.CancelledError:
  90. pass
  91. except Exception as exc:
  92. if authed and self.authed_kind != "probe":
  93. lived = time.monotonic() - self.authed_at if self.authed_at else 0.0
  94. print(f"[relay] session error peer={peer} kind={self.authed_kind} lived={lived:.1f}s frames={self.frame_count} error={exc!r}")
  95. finally:
  96. await self.close()
  97. async def safe_send(self, frame: Frame) -> bool:
  98. if self.closed:
  99. return False
  100. try:
  101. async with self.send_lock:
  102. if self.closed:
  103. return False
  104. await write_frame(self.writer, frame)
  105. return True
  106. except (BrokenPipeError, ConnectionResetError, RuntimeError, OSError, asyncio.CancelledError):
  107. return False
  108. async def handle(self, frame: Frame) -> None:
  109. key = (frame.session_id, frame.stream_id)
  110. if frame.kind == PING:
  111. await self.safe_send(Frame(PONG, 0, 0, frame.seq, 0, b"pong"))
  112. return
  113. if frame.kind == AUTH:
  114. return
  115. if frame.kind == TCP_OPEN:
  116. if self.udp_only:
  117. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, b"tcp disabled on udp-only relay"))
  118. return
  119. try:
  120. meta = decode_json(frame.payload) if frame.payload else {}
  121. family = int(meta.get("family", 0)) or 0
  122. reader, writer = await asyncio.open_connection(meta["host"], int(meta["port"]), family=family or 0)
  123. task = asyncio.create_task(self._tcp_pump(frame.session_id, frame.stream_id, reader))
  124. self.tcp_sessions[key] = TcpSession(frame.session_id, frame.stream_id, writer, task)
  125. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_OK, b"ok"))
  126. except Exception as exc:
  127. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
  128. return
  129. if frame.kind == TCP_DATA:
  130. if self.udp_only:
  131. return
  132. session = self.tcp_sessions.get(key)
  133. if session:
  134. try:
  135. session.writer.write(frame.payload)
  136. await session.writer.drain()
  137. except Exception:
  138. await self._close_tcp(key)
  139. return
  140. if frame.kind == TCP_CLOSE:
  141. if self.udp_only:
  142. return
  143. await self._close_tcp(key)
  144. return
  145. if frame.kind == UDP_SEND:
  146. session = self.udp_sessions.get(key)
  147. meta = None
  148. payload = frame.payload
  149. if frame.packet_id > 0 and frame.packet_id <= len(frame.payload):
  150. try:
  151. meta = decode_json(frame.payload[: frame.packet_id])
  152. payload = frame.payload[frame.packet_id :]
  153. except Exception:
  154. if session is None:
  155. return
  156. payload = frame.payload
  157. if session is None:
  158. if meta is None:
  159. return
  160. try:
  161. family = int(meta.get("family", 0)) or 0
  162. transport, protocol = await asyncio.get_running_loop().create_datagram_endpoint(
  163. lambda: RelayUdpProtocol(self, frame.session_id, frame.stream_id),
  164. remote_addr=(meta["host"], int(meta["port"])),
  165. family=family or 0,
  166. )
  167. session = UdpSession(
  168. frame.session_id,
  169. frame.stream_id,
  170. transport,
  171. protocol,
  172. meta["host"],
  173. int(meta["port"]),
  174. family,
  175. created_at=time.monotonic(),
  176. )
  177. self.udp_sessions[key] = session
  178. print(
  179. f"[relay] udp session open peer={peer} target={session.host}:{session.port} "
  180. f"session={session.session_id} stream={session.stream_id} family={session.family}"
  181. )
  182. except Exception as exc:
  183. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
  184. return
  185. if session.transport is not None:
  186. with contextlib.suppress(Exception):
  187. session.packets_sent += 1
  188. session.last_seen_at = time.monotonic()
  189. session.transport.sendto(payload)
  190. return
  191. async def _tcp_pump(self, session_id: int, stream_id: int, reader: asyncio.StreamReader) -> None:
  192. try:
  193. while True:
  194. chunk = await reader.read(65536)
  195. if not chunk:
  196. break
  197. sent = await self.safe_send(Frame(TCP_DATA, session_id, stream_id, 0, 0, chunk))
  198. if not sent:
  199. break
  200. except asyncio.CancelledError:
  201. pass
  202. except Exception:
  203. pass
  204. finally:
  205. if not self.closed:
  206. await self.safe_send(Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
  207. await self._close_tcp((session_id, stream_id), from_task=True)
  208. async def _close_tcp(self, key: tuple[int, int], from_task: bool = False) -> None:
  209. session = self.tcp_sessions.pop(key, None)
  210. if session is None:
  211. return
  212. if not from_task and session.task is not asyncio.current_task():
  213. session.task.cancel()
  214. with contextlib.suppress(Exception):
  215. await session.task
  216. session.writer.close()
  217. with contextlib.suppress(Exception):
  218. await session.writer.wait_closed()
  219. async def close(self) -> None:
  220. if self.closed:
  221. return
  222. self.closed = True
  223. for key in list(self.tcp_sessions):
  224. await self._close_tcp(key)
  225. for session in self.udp_sessions.values():
  226. if session.transport:
  227. lived = time.monotonic() - session.created_at if session.created_at else 0.0
  228. print(
  229. f"[relay] udp session closed target={session.host}:{session.port} "
  230. f"session={session.session_id} stream={session.stream_id} family={session.family} "
  231. f"lived={lived:.1f}s sent={session.packets_sent} recv={session.packets_received}"
  232. )
  233. session.transport.close()
  234. self.udp_sessions.clear()
  235. self.writer.close()
  236. with contextlib.suppress(Exception):
  237. await self.writer.wait_closed()
  238. class RelayServer:
  239. def __init__(self, token: str) -> None:
  240. self.token = token
  241. self.udp_only = False
  242. async def start(self, host: str, port: int, udp_only: bool = False) -> None:
  243. self.udp_only = udp_only
  244. server = await asyncio.start_server(self._accept, host, port)
  245. sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
  246. mode = "udp-only" if udp_only else "normal"
  247. print(f"[relay] listening on {sockets} mode={mode}")
  248. async with server:
  249. await server.serve_forever()
  250. async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  251. await RelayChannel(reader, writer, self.token, udp_only=self.udp_only).run()