relay_server_tcp.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. from dataclasses import dataclass, field
  5. from typing import Dict
  6. from .logging_utils import log_print as print
  7. from .protocol import AUTH, PING, PONG, STATUS_ERR, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, Frame, decode_json, encode_json, read_frame, write_frame
  8. @dataclass
  9. class TcpSession:
  10. session_id: int
  11. stream_id: int
  12. writer: asyncio.StreamWriter
  13. task: asyncio.Task
  14. @dataclass
  15. class TcpRelayChannel:
  16. reader: asyncio.StreamReader
  17. writer: asyncio.StreamWriter
  18. token: str
  19. tcp_sessions: Dict[tuple[int, int], TcpSession] = field(default_factory=dict)
  20. closed: bool = False
  21. send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
  22. async def run(self) -> None:
  23. try:
  24. auth = await read_frame(self.reader)
  25. if auth.kind != AUTH:
  26. return
  27. payload = decode_json(auth.payload) if auth.payload else {}
  28. if payload.get("token") != self.token:
  29. return
  30. await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, encode_json({"status": "ok", "kind": payload.get("purpose", "normal"), "udp_only": False})))
  31. while True:
  32. frame = await read_frame(self.reader)
  33. await self.handle(frame)
  34. finally:
  35. await self.close()
  36. async def safe_send(self, frame: Frame) -> bool:
  37. if self.closed:
  38. return False
  39. try:
  40. async with self.send_lock:
  41. if self.closed:
  42. return False
  43. await write_frame(self.writer, frame)
  44. return True
  45. except Exception:
  46. return False
  47. async def handle(self, frame: Frame) -> None:
  48. key = (frame.session_id, frame.stream_id)
  49. if frame.kind == PING:
  50. await self.safe_send(Frame(PONG, 0, 0, frame.seq, 0, b"pong"))
  51. return
  52. if frame.kind == TCP_OPEN:
  53. try:
  54. meta = decode_json(frame.payload) if frame.payload else {}
  55. family = int(meta.get("family", 0)) or 0
  56. reader, writer = await asyncio.open_connection(meta["host"], int(meta["port"]), family=family or 0)
  57. task = asyncio.create_task(self._tcp_pump(frame.session_id, frame.stream_id, reader))
  58. self.tcp_sessions[key] = TcpSession(frame.session_id, frame.stream_id, writer, task)
  59. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_OK, b"ok"))
  60. except Exception as exc:
  61. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
  62. return
  63. if frame.kind == TCP_DATA:
  64. session = self.tcp_sessions.get(key)
  65. if session:
  66. try:
  67. session.writer.write(frame.payload)
  68. await session.writer.drain()
  69. except Exception:
  70. await self._close_tcp(key)
  71. return
  72. if frame.kind == TCP_CLOSE:
  73. await self._close_tcp(key)
  74. async def _tcp_pump(self, session_id: int, stream_id: int, reader: asyncio.StreamReader) -> None:
  75. try:
  76. while True:
  77. chunk = await reader.read(65536)
  78. if not chunk:
  79. break
  80. sent = await self.safe_send(Frame(TCP_DATA, session_id, stream_id, 0, 0, chunk))
  81. if not sent:
  82. break
  83. finally:
  84. if not self.closed:
  85. await self.safe_send(Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
  86. await self._close_tcp((session_id, stream_id), from_task=True)
  87. async def _close_tcp(self, key: tuple[int, int], from_task: bool = False) -> None:
  88. session = self.tcp_sessions.pop(key, None)
  89. if session is None:
  90. return
  91. if not from_task and session.task is not asyncio.current_task():
  92. session.task.cancel()
  93. with contextlib.suppress(Exception):
  94. await session.task
  95. session.writer.close()
  96. with contextlib.suppress(Exception):
  97. await session.writer.wait_closed()
  98. async def close(self) -> None:
  99. if self.closed:
  100. return
  101. self.closed = True
  102. for key in list(self.tcp_sessions):
  103. await self._close_tcp(key)
  104. self.writer.close()
  105. with contextlib.suppress(Exception):
  106. await self.writer.wait_closed()
  107. class TcpRelayServer:
  108. def __init__(self, token: str) -> None:
  109. self.token = token
  110. async def start(self, host: str, port: int) -> None:
  111. async def accept(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  112. await TcpRelayChannel(reader, writer, self.token).run()
  113. server = await asyncio.start_server(accept, host, port)
  114. async with server:
  115. await server.serve_forever()