relay_server.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
from __future__ import annotations

import asyncio
import contextlib
import hmac
from dataclasses import dataclass, field
from typing import Dict

from .protocol import (
    AUTH,
    PING,
    PONG,
    STATUS_ERR,
    STATUS_OK,
    TCP_CLOSE,
    TCP_DATA,
    TCP_OPEN,
    TCP_STATUS,
    UDP_RECV,
    UDP_SEND,
    Frame,
    decode_json,
    read_frame,
    write_frame,
)
  7. @dataclass
  8. class TcpSession:
  9. session_id: int
  10. stream_id: int
  11. writer: asyncio.StreamWriter
  12. task: asyncio.Task
  13. @dataclass
  14. class UdpSession:
  15. session_id: int
  16. stream_id: int
  17. transport: asyncio.DatagramTransport | None = None
  18. protocol: "RelayUdpProtocol | None" = None
  19. class RelayUdpProtocol(asyncio.DatagramProtocol):
  20. def __init__(self, channel: "RelayChannel", session_id: int, stream_id: int) -> None:
  21. self.channel = channel
  22. self.session_id = session_id
  23. self.stream_id = stream_id
  24. def datagram_received(self, data: bytes, _addr) -> None:
  25. if self.channel.closed:
  26. return
  27. asyncio.create_task(self.channel.safe_send(Frame(UDP_RECV, self.session_id, self.stream_id, 0, 0, data)))
  28. @dataclass
  29. class RelayChannel:
  30. reader: asyncio.StreamReader
  31. writer: asyncio.StreamWriter
  32. token: str
  33. tcp_sessions: Dict[tuple[int, int], TcpSession] = field(default_factory=dict)
  34. udp_sessions: Dict[tuple[int, int], UdpSession] = field(default_factory=dict)
  35. closed: bool = False
  36. async def run(self) -> None:
  37. peer = self.writer.get_extra_info("peername")
  38. authed = False
  39. try:
  40. auth = await read_frame(self.reader)
  41. if auth.kind != AUTH or decode_json(auth.payload).get("token") != self.token:
  42. raise PermissionError("invalid token")
  43. authed = True
  44. print(f"[relay] auth ok peer={peer}")
  45. await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, b"ok"))
  46. while True:
  47. frame = await read_frame(self.reader)
  48. await self.handle(frame)
  49. except asyncio.IncompleteReadError:
  50. if authed:
  51. print(f"[relay] disconnected peer={peer}")
  52. except asyncio.CancelledError:
  53. pass
  54. except Exception as exc:
  55. if authed:
  56. print(f"[relay] channel error peer={peer} error={exc!r}")
  57. finally:
  58. await self.close()
  59. async def safe_send(self, frame: Frame) -> bool:
  60. if self.closed:
  61. return False
  62. try:
  63. await write_frame(self.writer, frame)
  64. return True
  65. except (BrokenPipeError, ConnectionResetError, RuntimeError, OSError, asyncio.CancelledError):
  66. return False
  67. async def handle(self, frame: Frame) -> None:
  68. key = (frame.session_id, frame.stream_id)
  69. if frame.kind == PING:
  70. await self.safe_send(Frame(PONG, 0, 0, frame.seq, 0, b"pong"))
  71. return
  72. if frame.kind == TCP_OPEN:
  73. meta = decode_json(frame.payload)
  74. family = int(meta.get("family", 0)) or 0
  75. try:
  76. reader, writer = await asyncio.open_connection(meta["host"], int(meta["port"]), family=family or 0)
  77. task = asyncio.create_task(self._tcp_pump(frame.session_id, frame.stream_id, reader))
  78. self.tcp_sessions[key] = TcpSession(frame.session_id, frame.stream_id, writer, task)
  79. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_OK, b"ok"))
  80. except Exception as exc:
  81. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
  82. return
  83. if frame.kind == TCP_DATA:
  84. session = self.tcp_sessions.get(key)
  85. if session:
  86. try:
  87. session.writer.write(frame.payload)
  88. await session.writer.drain()
  89. except Exception:
  90. await self._close_tcp(key)
  91. return
  92. if frame.kind == TCP_CLOSE:
  93. await self._close_tcp(key)
  94. return
  95. if frame.kind == UDP_SEND:
  96. meta = decode_json(frame.payload[: frame.packet_id])
  97. payload = frame.payload[frame.packet_id :]
  98. session = self.udp_sessions.get(key)
  99. if session is None:
  100. family = int(meta.get("family", 0)) or 0
  101. transport, protocol = await asyncio.get_running_loop().create_datagram_endpoint(
  102. lambda: RelayUdpProtocol(self, frame.session_id, frame.stream_id),
  103. remote_addr=(meta["host"], int(meta["port"])),
  104. family=family or 0,
  105. )
  106. session = UdpSession(frame.session_id, frame.stream_id, transport, protocol)
  107. self.udp_sessions[key] = session
  108. with contextlib.suppress(Exception):
  109. session.transport.sendto(payload)
  110. return
  111. async def _tcp_pump(self, session_id: int, stream_id: int, reader: asyncio.StreamReader) -> None:
  112. try:
  113. while True:
  114. chunk = await reader.read(65536)
  115. if not chunk:
  116. break
  117. sent = await self.safe_send(Frame(TCP_DATA, session_id, stream_id, 0, 0, chunk))
  118. if not sent:
  119. break
  120. except asyncio.CancelledError:
  121. pass
  122. except Exception:
  123. pass
  124. finally:
  125. if not self.closed:
  126. await self.safe_send(Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
  127. await self._close_tcp((session_id, stream_id), from_task=True)
  128. async def _close_tcp(self, key: tuple[int, int], from_task: bool = False) -> None:
  129. session = self.tcp_sessions.pop(key, None)
  130. if session is None:
  131. return
  132. if not from_task and session.task is not asyncio.current_task():
  133. session.task.cancel()
  134. with contextlib.suppress(Exception):
  135. await session.task
  136. session.writer.close()
  137. with contextlib.suppress(Exception):
  138. await session.writer.wait_closed()
  139. async def close(self) -> None:
  140. if self.closed:
  141. return
  142. self.closed = True
  143. for key in list(self.tcp_sessions):
  144. await self._close_tcp(key)
  145. for session in self.udp_sessions.values():
  146. if session.transport:
  147. session.transport.close()
  148. self.udp_sessions.clear()
  149. self.writer.close()
  150. with contextlib.suppress(Exception):
  151. await self.writer.wait_closed()
  152. class RelayServer:
  153. def __init__(self, token: str) -> None:
  154. self.token = token
  155. async def start(self, host: str, port: int) -> None:
  156. server = await asyncio.start_server(self._accept, host, port)
  157. sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
  158. print(f"[relay] listening on {sockets}")
  159. async with server:
  160. await server.serve_forever()
  161. async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  162. await RelayChannel(reader, writer, self.token).run()