relay_server.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. from dataclasses import dataclass, field
  5. from typing import Dict
  6. from .protocol import AUTH, PING, PONG, STATUS_ERR, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, UDP_RECV, UDP_SEND, Frame, decode_json, read_frame, write_frame
  7. @dataclass
  8. class TcpSession:
  9. session_id: int
  10. stream_id: int
  11. writer: asyncio.StreamWriter
  12. task: asyncio.Task
  13. @dataclass
  14. class UdpSession:
  15. session_id: int
  16. stream_id: int
  17. transport: asyncio.DatagramTransport | None = None
  18. protocol: "RelayUdpProtocol | None" = None
  19. class RelayUdpProtocol(asyncio.DatagramProtocol):
  20. def __init__(self, channel: "RelayChannel", session_id: int, stream_id: int) -> None:
  21. self.channel = channel
  22. self.session_id = session_id
  23. self.stream_id = stream_id
  24. def datagram_received(self, data: bytes, _addr) -> None:
  25. if self.channel.closed:
  26. return
  27. asyncio.create_task(self.channel.safe_send(Frame(UDP_RECV, self.session_id, self.stream_id, 0, 0, data)))
  28. @dataclass
  29. class RelayChannel:
  30. reader: asyncio.StreamReader
  31. writer: asyncio.StreamWriter
  32. token: str
  33. tcp_sessions: Dict[tuple[int, int], TcpSession] = field(default_factory=dict)
  34. udp_sessions: Dict[tuple[int, int], UdpSession] = field(default_factory=dict)
  35. closed: bool = False
  36. send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
  37. async def run(self) -> None:
  38. peer = self.writer.get_extra_info("peername")
  39. authed = False
  40. try:
  41. auth = await read_frame(self.reader)
  42. if auth.kind != AUTH or decode_json(auth.payload).get("token") != self.token:
  43. raise PermissionError("invalid token")
  44. authed = True
  45. print(f"[relay] auth ok peer={peer}")
  46. await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, b"ok"))
  47. while True:
  48. frame = await read_frame(self.reader)
  49. await self.handle(frame)
  50. except asyncio.IncompleteReadError:
  51. if authed:
  52. print(f"[relay] disconnected peer={peer}")
  53. except asyncio.CancelledError:
  54. pass
  55. except Exception as exc:
  56. if authed:
  57. print(f"[relay] channel error peer={peer} error={exc!r}")
  58. finally:
  59. await self.close()
  60. async def safe_send(self, frame: Frame) -> bool:
  61. if self.closed:
  62. return False
  63. try:
  64. async with self.send_lock:
  65. if self.closed:
  66. return False
  67. await write_frame(self.writer, frame)
  68. return True
  69. except (BrokenPipeError, ConnectionResetError, RuntimeError, OSError, asyncio.CancelledError):
  70. return False
  71. async def handle(self, frame: Frame) -> None:
  72. key = (frame.session_id, frame.stream_id)
  73. if frame.kind == PING:
  74. await self.safe_send(Frame(PONG, 0, 0, frame.seq, 0, b"pong"))
  75. return
  76. if frame.kind == TCP_OPEN:
  77. meta = decode_json(frame.payload)
  78. family = int(meta.get("family", 0)) or 0
  79. try:
  80. reader, writer = await asyncio.open_connection(meta["host"], int(meta["port"]), family=family or 0)
  81. task = asyncio.create_task(self._tcp_pump(frame.session_id, frame.stream_id, reader))
  82. self.tcp_sessions[key] = TcpSession(frame.session_id, frame.stream_id, writer, task)
  83. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_OK, b"ok"))
  84. except Exception as exc:
  85. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
  86. return
  87. if frame.kind == TCP_DATA:
  88. session = self.tcp_sessions.get(key)
  89. if session:
  90. try:
  91. session.writer.write(frame.payload)
  92. await session.writer.drain()
  93. except Exception:
  94. await self._close_tcp(key)
  95. return
  96. if frame.kind == TCP_CLOSE:
  97. await self._close_tcp(key)
  98. return
  99. if frame.kind == UDP_SEND:
  100. meta = decode_json(frame.payload[: frame.packet_id])
  101. payload = frame.payload[frame.packet_id :]
  102. session = self.udp_sessions.get(key)
  103. if session is None:
  104. family = int(meta.get("family", 0)) or 0
  105. transport, protocol = await asyncio.get_running_loop().create_datagram_endpoint(
  106. lambda: RelayUdpProtocol(self, frame.session_id, frame.stream_id),
  107. remote_addr=(meta["host"], int(meta["port"])),
  108. family=family or 0,
  109. )
  110. session = UdpSession(frame.session_id, frame.stream_id, transport, protocol)
  111. self.udp_sessions[key] = session
  112. with contextlib.suppress(Exception):
  113. session.transport.sendto(payload)
  114. return
  115. async def _tcp_pump(self, session_id: int, stream_id: int, reader: asyncio.StreamReader) -> None:
  116. try:
  117. while True:
  118. chunk = await reader.read(65536)
  119. if not chunk:
  120. break
  121. sent = await self.safe_send(Frame(TCP_DATA, session_id, stream_id, 0, 0, chunk))
  122. if not sent:
  123. break
  124. except asyncio.CancelledError:
  125. pass
  126. except Exception:
  127. pass
  128. finally:
  129. if not self.closed:
  130. await self.safe_send(Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
  131. await self._close_tcp((session_id, stream_id), from_task=True)
  132. async def _close_tcp(self, key: tuple[int, int], from_task: bool = False) -> None:
  133. session = self.tcp_sessions.pop(key, None)
  134. if session is None:
  135. return
  136. if not from_task and session.task is not asyncio.current_task():
  137. session.task.cancel()
  138. with contextlib.suppress(Exception):
  139. await session.task
  140. session.writer.close()
  141. with contextlib.suppress(Exception):
  142. await session.writer.wait_closed()
  143. async def close(self) -> None:
  144. if self.closed:
  145. return
  146. self.closed = True
  147. for key in list(self.tcp_sessions):
  148. await self._close_tcp(key)
  149. for session in self.udp_sessions.values():
  150. if session.transport:
  151. session.transport.close()
  152. self.udp_sessions.clear()
  153. self.writer.close()
  154. with contextlib.suppress(Exception):
  155. await self.writer.wait_closed()
  156. class RelayServer:
  157. def __init__(self, token: str) -> None:
  158. self.token = token
  159. async def start(self, host: str, port: int) -> None:
  160. server = await asyncio.start_server(self._accept, host, port)
  161. sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
  162. print(f"[relay] listening on {sockets}")
  163. async with server:
  164. await server.serve_forever()
  165. async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  166. await RelayChannel(reader, writer, self.token).run()