relay_server.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. from dataclasses import dataclass, field
  5. from typing import Dict
  6. from .protocol import AUTH, PING, PONG, STATUS_ERR, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, decode_json, read_frame, write_frame
  7. @dataclass
  8. class TcpSession:
  9. session_id: int
  10. stream_id: int
  11. writer: asyncio.StreamWriter
  12. task: asyncio.Task
  13. @dataclass
  14. class UdpSession:
  15. session_id: int
  16. stream_id: int
  17. transport: asyncio.DatagramTransport | None = None
  18. protocol: "RelayUdpProtocol | None" = None
  19. host: str = ""
  20. port: int = 0
  21. family: int = 0
  22. class RelayUdpProtocol(asyncio.DatagramProtocol):
  23. def __init__(self, channel: "RelayChannel", session_id: int, stream_id: int) -> None:
  24. self.channel = channel
  25. self.session_id = session_id
  26. self.stream_id = stream_id
  27. def datagram_received(self, data: bytes, _addr) -> None:
  28. if self.channel.closed:
  29. return
  30. asyncio.create_task(self.channel.safe_send(Frame(UDP_RECV, self.session_id, self.stream_id, 0, 0, data)))
  31. @dataclass
  32. class RelayChannel:
  33. reader: asyncio.StreamReader
  34. writer: asyncio.StreamWriter
  35. token: str
  36. tcp_sessions: Dict[tuple[int, int], TcpSession] = field(default_factory=dict)
  37. udp_sessions: Dict[tuple[int, int], UdpSession] = field(default_factory=dict)
  38. closed: bool = False
  39. send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
  40. async def run(self) -> None:
  41. peer = self.writer.get_extra_info("peername")
  42. authed = False
  43. try:
  44. auth = await read_frame(self.reader)
  45. if auth.kind != AUTH or decode_json(auth.payload).get("token") != self.token:
  46. raise PermissionError("invalid token")
  47. authed = True
  48. print(f"[relay] auth ok peer={peer}")
  49. await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, b"ok"))
  50. while True:
  51. frame = await read_frame(self.reader)
  52. await self.handle(frame)
  53. except asyncio.IncompleteReadError:
  54. if authed:
  55. print(f"[relay] disconnected peer={peer}")
  56. except asyncio.CancelledError:
  57. pass
  58. except Exception as exc:
  59. if authed:
  60. print(f"[relay] channel error peer={peer} error={exc!r}")
  61. finally:
  62. await self.close()
  63. async def safe_send(self, frame: Frame) -> bool:
  64. if self.closed:
  65. return False
  66. try:
  67. async with self.send_lock:
  68. if self.closed:
  69. return False
  70. await write_frame(self.writer, frame)
  71. return True
  72. except (BrokenPipeError, ConnectionResetError, RuntimeError, OSError, asyncio.CancelledError):
  73. return False
  74. async def handle(self, frame: Frame) -> None:
  75. key = (frame.session_id, frame.stream_id)
  76. if frame.kind == PING:
  77. await self.safe_send(Frame(PONG, 0, 0, frame.seq, 0, b"pong"))
  78. return
  79. if frame.kind == TCP_OPEN:
  80. meta = decode_json(frame.payload)
  81. family = int(meta.get("family", 0)) or 0
  82. try:
  83. reader, writer = await asyncio.open_connection(meta["host"], int(meta["port"]), family=family or 0)
  84. task = asyncio.create_task(self._tcp_pump(frame.session_id, frame.stream_id, reader))
  85. self.tcp_sessions[key] = TcpSession(frame.session_id, frame.stream_id, writer, task)
  86. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_OK, b"ok"))
  87. except Exception as exc:
  88. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
  89. return
  90. if frame.kind == TCP_DATA:
  91. session = self.tcp_sessions.get(key)
  92. if session:
  93. try:
  94. session.writer.write(frame.payload)
  95. await session.writer.drain()
  96. except Exception:
  97. await self._close_tcp(key)
  98. return
  99. if frame.kind == TCP_CLOSE:
  100. await self._close_tcp(key)
  101. return
  102. if frame.kind == UDP_SEND:
  103. session = self.udp_sessions.get(key)
  104. meta = None
  105. payload = frame.payload
  106. if frame.packet_id > 0:
  107. meta = decode_json(frame.payload[: frame.packet_id])
  108. payload = frame.payload[frame.packet_id :]
  109. if session is None:
  110. if meta is None:
  111. return
  112. family = int(meta.get("family", 0)) or 0
  113. transport, protocol = await asyncio.get_running_loop().create_datagram_endpoint(
  114. lambda: RelayUdpProtocol(self, frame.session_id, frame.stream_id),
  115. remote_addr=(meta["host"], int(meta["port"])),
  116. family=family or 0,
  117. )
  118. session = UdpSession(frame.session_id, frame.stream_id, transport, protocol, meta["host"], int(meta["port"]), family)
  119. self.udp_sessions[key] = session
  120. with contextlib.suppress(Exception):
  121. session.transport.sendto(payload)
  122. return
  123. async def _tcp_pump(self, session_id: int, stream_id: int, reader: asyncio.StreamReader) -> None:
  124. try:
  125. while True:
  126. chunk = await reader.read(65536)
  127. if not chunk:
  128. break
  129. sent = await self.safe_send(Frame(TCP_DATA, session_id, stream_id, 0, 0, chunk))
  130. if not sent:
  131. break
  132. except asyncio.CancelledError:
  133. pass
  134. except Exception:
  135. pass
  136. finally:
  137. if not self.closed:
  138. await self.safe_send(Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
  139. await self._close_tcp((session_id, stream_id), from_task=True)
  140. async def _close_tcp(self, key: tuple[int, int], from_task: bool = False) -> None:
  141. session = self.tcp_sessions.pop(key, None)
  142. if session is None:
  143. return
  144. if not from_task and session.task is not asyncio.current_task():
  145. session.task.cancel()
  146. with contextlib.suppress(Exception):
  147. await session.task
  148. session.writer.close()
  149. with contextlib.suppress(Exception):
  150. await session.writer.wait_closed()
  151. async def close(self) -> None:
  152. if self.closed:
  153. return
  154. self.closed = True
  155. for key in list(self.tcp_sessions):
  156. await self._close_tcp(key)
  157. for session in self.udp_sessions.values():
  158. if session.transport:
  159. session.transport.close()
  160. self.udp_sessions.clear()
  161. self.writer.close()
  162. with contextlib.suppress(Exception):
  163. await self.writer.wait_closed()
  164. class RelayServer:
  165. def __init__(self, token: str) -> None:
  166. self.token = token
  167. async def start(self, host: str, port: int) -> None:
  168. server = await asyncio.start_server(self._accept, host, port)
  169. sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
  170. print(f"[relay] listening on {sockets}")
  171. async with server:
  172. await server.serve_forever()
  173. async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  174. await RelayChannel(reader, writer, self.token).run()