relay_server.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. import json
  5. import time
  6. from dataclasses import dataclass, field
  7. from typing import Dict
  8. from .protocol import AUTH, PING, PONG, STATUS_ERR, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, decode_json, encode_json, read_frame, write_frame
  9. @dataclass
  10. class TcpSession:
  11. session_id: int
  12. stream_id: int
  13. writer: asyncio.StreamWriter
  14. task: asyncio.Task
  15. @dataclass
  16. class UdpSession:
  17. session_id: int
  18. stream_id: int
  19. transport: asyncio.DatagramTransport | None = None
  20. protocol: "RelayUdpProtocol | None" = None
  21. host: str = ""
  22. port: int = 0
  23. family: int = 0
  24. class RelayUdpProtocol(asyncio.DatagramProtocol):
  25. def __init__(self, channel: "RelayChannel", session_id: int, stream_id: int) -> None:
  26. self.channel = channel
  27. self.session_id = session_id
  28. self.stream_id = stream_id
  29. def datagram_received(self, data: bytes, _addr) -> None:
  30. if self.channel.closed:
  31. return
  32. asyncio.create_task(self.channel.safe_send(Frame(UDP_RECV, self.session_id, self.stream_id, 0, 0, data)))
  33. @dataclass
  34. class RelayChannel:
  35. reader: asyncio.StreamReader
  36. writer: asyncio.StreamWriter
  37. token: str
  38. tcp_sessions: Dict[tuple[int, int], TcpSession] = field(default_factory=dict)
  39. udp_sessions: Dict[tuple[int, int], UdpSession] = field(default_factory=dict)
  40. closed: bool = False
  41. send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
  42. authed_at: float = 0.0
  43. frame_count: int = 0
  44. authed_kind: str = "normal"
  45. def _should_log_lifecycle(self, lived: float) -> bool:
  46. if self.authed_kind == "probe":
  47. return False
  48. return lived >= 15 or self.frame_count > 20
  49. def _should_log_auth(self) -> bool:
  50. return self.authed_kind not in {"probe", "normal"}
  51. async def run(self) -> None:
  52. peer = self.writer.get_extra_info("peername")
  53. authed = False
  54. try:
  55. auth = await read_frame(self.reader)
  56. if auth.kind != AUTH:
  57. return
  58. try:
  59. payload = decode_json(auth.payload) if auth.payload else {}
  60. except json.JSONDecodeError:
  61. return
  62. except Exception:
  63. return
  64. if payload.get("token") != self.token:
  65. return
  66. authed = True
  67. self.authed_at = time.monotonic()
  68. self.authed_kind = payload.get("purpose", "normal")
  69. if self._should_log_auth():
  70. print(f"[relay] auth ok peer={peer} kind={self.authed_kind}")
  71. ack_payload = {"status": "ok", "kind": self.authed_kind}
  72. await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, encode_json(ack_payload)))
  73. while True:
  74. frame = await read_frame(self.reader)
  75. self.frame_count += 1
  76. await self.handle(frame)
  77. except asyncio.IncompleteReadError:
  78. if authed:
  79. lived = time.monotonic() - self.authed_at if self.authed_at else 0.0
  80. if self._should_log_lifecycle(lived):
  81. print(f"[relay] session closed peer={peer} kind={self.authed_kind} lived={lived:.1f}s frames={self.frame_count}")
  82. except asyncio.CancelledError:
  83. pass
  84. except Exception as exc:
  85. if authed:
  86. lived = time.monotonic() - self.authed_at if self.authed_at else 0.0
  87. if self._should_log_lifecycle(lived):
  88. print(f"[relay] session error peer={peer} kind={self.authed_kind} lived={lived:.1f}s frames={self.frame_count} error={exc!r}")
  89. finally:
  90. await self.close()
  91. async def safe_send(self, frame: Frame) -> bool:
  92. if self.closed:
  93. return False
  94. try:
  95. async with self.send_lock:
  96. if self.closed:
  97. return False
  98. await write_frame(self.writer, frame)
  99. return True
  100. except (BrokenPipeError, ConnectionResetError, RuntimeError, OSError, asyncio.CancelledError):
  101. return False
  102. async def handle(self, frame: Frame) -> None:
  103. key = (frame.session_id, frame.stream_id)
  104. if frame.kind == PING:
  105. await self.safe_send(Frame(PONG, 0, 0, frame.seq, 0, b"pong"))
  106. return
  107. if frame.kind == AUTH:
  108. return
  109. if frame.kind == TCP_OPEN:
  110. try:
  111. meta = decode_json(frame.payload) if frame.payload else {}
  112. family = int(meta.get("family", 0)) or 0
  113. reader, writer = await asyncio.open_connection(meta["host"], int(meta["port"]), family=family or 0)
  114. task = asyncio.create_task(self._tcp_pump(frame.session_id, frame.stream_id, reader))
  115. self.tcp_sessions[key] = TcpSession(frame.session_id, frame.stream_id, writer, task)
  116. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_OK, b"ok"))
  117. except Exception as exc:
  118. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
  119. return
  120. if frame.kind == TCP_DATA:
  121. session = self.tcp_sessions.get(key)
  122. if session:
  123. try:
  124. session.writer.write(frame.payload)
  125. await session.writer.drain()
  126. except Exception:
  127. await self._close_tcp(key)
  128. return
  129. if frame.kind == TCP_CLOSE:
  130. await self._close_tcp(key)
  131. return
  132. if frame.kind == UDP_SEND:
  133. session = self.udp_sessions.get(key)
  134. meta = None
  135. payload = frame.payload
  136. if frame.packet_id > 0 and frame.packet_id <= len(frame.payload):
  137. try:
  138. meta = decode_json(frame.payload[: frame.packet_id])
  139. payload = frame.payload[frame.packet_id :]
  140. except Exception:
  141. print(
  142. f"[relay] udp meta decode failed session={frame.session_id} stream={frame.stream_id} "
  143. f"meta_len={frame.packet_id} payload_len={len(frame.payload)}"
  144. )
  145. if session is None:
  146. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, b"udp meta decode failed"))
  147. return
  148. payload = frame.payload
  149. if session is None:
  150. if meta is None:
  151. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, b"udp meta missing"))
  152. return
  153. try:
  154. family = int(meta.get("family", 0)) or 0
  155. print(
  156. f"[relay] udp session open session={frame.session_id} stream={frame.stream_id} "
  157. f"target={meta.get('host')}:{meta.get('port')} family={family} meta_len={frame.packet_id} payload_len={len(payload)}"
  158. )
  159. transport, protocol = await asyncio.get_running_loop().create_datagram_endpoint(
  160. lambda: RelayUdpProtocol(self, frame.session_id, frame.stream_id),
  161. remote_addr=(meta["host"], int(meta["port"])),
  162. family=family or 0,
  163. )
  164. session = UdpSession(frame.session_id, frame.stream_id, transport, protocol, meta["host"], int(meta["port"]), family)
  165. self.udp_sessions[key] = session
  166. except Exception as exc:
  167. print(
  168. f"[relay] udp session open failed session={frame.session_id} stream={frame.stream_id} "
  169. f"meta={meta!r} error={exc!r}"
  170. )
  171. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
  172. return
  173. if session.transport is not None:
  174. with contextlib.suppress(Exception):
  175. session.transport.sendto(payload)
  176. return
  177. async def _tcp_pump(self, session_id: int, stream_id: int, reader: asyncio.StreamReader) -> None:
  178. try:
  179. while True:
  180. chunk = await reader.read(65536)
  181. if not chunk:
  182. break
  183. sent = await self.safe_send(Frame(TCP_DATA, session_id, stream_id, 0, 0, chunk))
  184. if not sent:
  185. break
  186. except asyncio.CancelledError:
  187. pass
  188. except Exception:
  189. pass
  190. finally:
  191. if not self.closed:
  192. await self.safe_send(Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
  193. await self._close_tcp((session_id, stream_id), from_task=True)
  194. async def _close_tcp(self, key: tuple[int, int], from_task: bool = False) -> None:
  195. session = self.tcp_sessions.pop(key, None)
  196. if session is None:
  197. return
  198. if not from_task and session.task is not asyncio.current_task():
  199. session.task.cancel()
  200. with contextlib.suppress(Exception):
  201. await session.task
  202. session.writer.close()
  203. with contextlib.suppress(Exception):
  204. await session.writer.wait_closed()
  205. async def close(self) -> None:
  206. if self.closed:
  207. return
  208. self.closed = True
  209. for key in list(self.tcp_sessions):
  210. await self._close_tcp(key)
  211. for session in self.udp_sessions.values():
  212. if session.transport:
  213. session.transport.close()
  214. self.udp_sessions.clear()
  215. self.writer.close()
  216. with contextlib.suppress(Exception):
  217. await self.writer.wait_closed()
  218. class RelayServer:
  219. def __init__(self, token: str) -> None:
  220. self.token = token
  221. async def start(self, host: str, port: int) -> None:
  222. server = await asyncio.start_server(self._accept, host, port)
  223. sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
  224. print(f"[relay] listening on {sockets}")
  225. async with server:
  226. await server.serve_forever()
  227. async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  228. await RelayChannel(reader, writer, self.token).run()