relay_server.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. import time
  5. from dataclasses import dataclass, field
  6. from typing import Dict
  7. from .protocol import AUTH, PING, PONG, STATUS_ERR, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, decode_json, encode_json, read_frame, write_frame
  8. @dataclass
  9. class TcpSession:
  10. session_id: int
  11. stream_id: int
  12. writer: asyncio.StreamWriter
  13. task: asyncio.Task
  14. @dataclass
  15. class UdpSession:
  16. session_id: int
  17. stream_id: int
  18. transport: asyncio.DatagramTransport | None = None
  19. protocol: "RelayUdpProtocol | None" = None
  20. host: str = ""
  21. port: int = 0
  22. family: int = 0
  23. class RelayUdpProtocol(asyncio.DatagramProtocol):
  24. def __init__(self, channel: "RelayChannel", session_id: int, stream_id: int) -> None:
  25. self.channel = channel
  26. self.session_id = session_id
  27. self.stream_id = stream_id
  28. def datagram_received(self, data: bytes, _addr) -> None:
  29. if self.channel.closed:
  30. return
  31. asyncio.create_task(self.channel.safe_send(Frame(UDP_RECV, self.session_id, self.stream_id, 0, 0, data)))
  32. @dataclass
  33. class RelayChannel:
  34. reader: asyncio.StreamReader
  35. writer: asyncio.StreamWriter
  36. token: str
  37. tcp_sessions: Dict[tuple[int, int], TcpSession] = field(default_factory=dict)
  38. udp_sessions: Dict[tuple[int, int], UdpSession] = field(default_factory=dict)
  39. closed: bool = False
  40. send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
  41. authed_at: float = 0.0
  42. frame_count: int = 0
  43. authed_kind: str = "normal"
  44. def _should_log_lifecycle(self, lived: float) -> bool:
  45. if self.authed_kind == "probe":
  46. return False
  47. return lived >= 15 or self.frame_count > 20
  48. async def run(self) -> None:
  49. peer = self.writer.get_extra_info("peername")
  50. authed = False
  51. try:
  52. auth = await read_frame(self.reader)
  53. if auth.kind != AUTH:
  54. return
  55. try:
  56. payload = decode_json(auth.payload) if auth.payload else {}
  57. except Exception:
  58. return
  59. if payload.get("token") != self.token:
  60. return
  61. authed = True
  62. self.authed_at = time.monotonic()
  63. self.authed_kind = payload.get("purpose", "normal")
  64. ack_payload = {"status": "ok", "kind": self.authed_kind}
  65. await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, encode_json(ack_payload)))
  66. while True:
  67. frame = await read_frame(self.reader)
  68. self.frame_count += 1
  69. await self.handle(frame)
  70. except asyncio.IncompleteReadError:
  71. if authed:
  72. lived = time.monotonic() - self.authed_at if self.authed_at else 0.0
  73. if self._should_log_lifecycle(lived):
  74. print(f"[relay] session closed peer={peer} kind={self.authed_kind} lived={lived:.1f}s frames={self.frame_count}")
  75. except asyncio.CancelledError:
  76. pass
  77. except Exception as exc:
  78. if authed:
  79. lived = time.monotonic() - self.authed_at if self.authed_at else 0.0
  80. if self._should_log_lifecycle(lived):
  81. print(f"[relay] session error peer={peer} kind={self.authed_kind} lived={lived:.1f}s frames={self.frame_count} error={exc!r}")
  82. finally:
  83. await self.close()
  84. async def safe_send(self, frame: Frame) -> bool:
  85. if self.closed:
  86. return False
  87. try:
  88. async with self.send_lock:
  89. if self.closed:
  90. return False
  91. await write_frame(self.writer, frame)
  92. return True
  93. except (BrokenPipeError, ConnectionResetError, RuntimeError, OSError, asyncio.CancelledError):
  94. return False
  95. async def handle(self, frame: Frame) -> None:
  96. key = (frame.session_id, frame.stream_id)
  97. if frame.kind == PING:
  98. await self.safe_send(Frame(PONG, 0, 0, frame.seq, 0, b"pong"))
  99. return
  100. if frame.kind == AUTH:
  101. return
  102. if frame.kind == TCP_OPEN:
  103. try:
  104. meta = decode_json(frame.payload) if frame.payload else {}
  105. family = int(meta.get("family", 0)) or 0
  106. reader, writer = await asyncio.open_connection(meta["host"], int(meta["port"]), family=family or 0)
  107. task = asyncio.create_task(self._tcp_pump(frame.session_id, frame.stream_id, reader))
  108. self.tcp_sessions[key] = TcpSession(frame.session_id, frame.stream_id, writer, task)
  109. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_OK, b"ok"))
  110. except Exception as exc:
  111. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
  112. return
  113. if frame.kind == TCP_DATA:
  114. session = self.tcp_sessions.get(key)
  115. if session:
  116. try:
  117. session.writer.write(frame.payload)
  118. await session.writer.drain()
  119. except Exception:
  120. await self._close_tcp(key)
  121. return
  122. if frame.kind == TCP_CLOSE:
  123. await self._close_tcp(key)
  124. return
  125. if frame.kind == UDP_SEND:
  126. session = self.udp_sessions.get(key)
  127. meta = None
  128. payload = frame.payload
  129. if frame.packet_id > 0 and frame.packet_id <= len(frame.payload):
  130. try:
  131. meta = decode_json(frame.payload[: frame.packet_id])
  132. payload = frame.payload[frame.packet_id :]
  133. except Exception:
  134. print(
  135. f"[relay] udp meta decode failed session={frame.session_id} stream={frame.stream_id} "
  136. f"meta_len={frame.packet_id} payload_len={len(frame.payload)}"
  137. )
  138. if session is None:
  139. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, b"udp meta decode failed"))
  140. return
  141. payload = frame.payload
  142. if session is None:
  143. if meta is None:
  144. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, b"udp meta missing"))
  145. return
  146. try:
  147. family = int(meta.get("family", 0)) or 0
  148. print(
  149. f"[relay] udp session open session={frame.session_id} stream={frame.stream_id} "
  150. f"target={meta.get('host')}:{meta.get('port')} family={family} meta_len={frame.packet_id} payload_len={len(payload)}"
  151. )
  152. transport, protocol = await asyncio.get_running_loop().create_datagram_endpoint(
  153. lambda: RelayUdpProtocol(self, frame.session_id, frame.stream_id),
  154. remote_addr=(meta["host"], int(meta["port"])),
  155. family=family or 0,
  156. )
  157. session = UdpSession(frame.session_id, frame.stream_id, transport, protocol, meta["host"], int(meta["port"]), family)
  158. self.udp_sessions[key] = session
  159. except Exception as exc:
  160. print(
  161. f"[relay] udp session open failed session={frame.session_id} stream={frame.stream_id} "
  162. f"meta={meta!r} error={exc!r}"
  163. )
  164. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
  165. return
  166. if session.transport is not None:
  167. with contextlib.suppress(Exception):
  168. session.transport.sendto(payload)
  169. return
  170. async def _tcp_pump(self, session_id: int, stream_id: int, reader: asyncio.StreamReader) -> None:
  171. try:
  172. while True:
  173. chunk = await reader.read(65536)
  174. if not chunk:
  175. break
  176. sent = await self.safe_send(Frame(TCP_DATA, session_id, stream_id, 0, 0, chunk))
  177. if not sent:
  178. break
  179. except asyncio.CancelledError:
  180. pass
  181. except Exception:
  182. pass
  183. finally:
  184. if not self.closed:
  185. await self.safe_send(Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
  186. await self._close_tcp((session_id, stream_id), from_task=True)
  187. async def _close_tcp(self, key: tuple[int, int], from_task: bool = False) -> None:
  188. session = self.tcp_sessions.pop(key, None)
  189. if session is None:
  190. return
  191. if not from_task and session.task is not asyncio.current_task():
  192. session.task.cancel()
  193. with contextlib.suppress(Exception):
  194. await session.task
  195. session.writer.close()
  196. with contextlib.suppress(Exception):
  197. await session.writer.wait_closed()
  198. async def close(self) -> None:
  199. if self.closed:
  200. return
  201. self.closed = True
  202. for key in list(self.tcp_sessions):
  203. await self._close_tcp(key)
  204. for session in self.udp_sessions.values():
  205. if session.transport:
  206. session.transport.close()
  207. self.udp_sessions.clear()
  208. self.writer.close()
  209. with contextlib.suppress(Exception):
  210. await self.writer.wait_closed()
  211. class RelayServer:
  212. def __init__(self, token: str) -> None:
  213. self.token = token
  214. async def start(self, host: str, port: int) -> None:
  215. server = await asyncio.start_server(self._accept, host, port)
  216. sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
  217. print(f"[relay] listening on {sockets}")
  218. async with server:
  219. await server.serve_forever()
  220. async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  221. await RelayChannel(reader, writer, self.token).run()