relay_server_tcp.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
from __future__ import annotations

import asyncio
import contextlib
import hmac
from dataclasses import dataclass, field
from typing import Dict

from .protocol import AUTH, PING, PONG, STATUS_ERR, STATUS_OK, TCP_CLOSE, TCP_DATA, TCP_OPEN, TCP_STATUS, Frame, decode_json, encode_json, read_frame, write_frame
  7. @dataclass
  8. class TcpSession:
  9. session_id: int
  10. stream_id: int
  11. writer: asyncio.StreamWriter
  12. task: asyncio.Task
  13. @dataclass
  14. class TcpRelayChannel:
  15. reader: asyncio.StreamReader
  16. writer: asyncio.StreamWriter
  17. token: str
  18. tcp_sessions: Dict[tuple[int, int], TcpSession] = field(default_factory=dict)
  19. closed: bool = False
  20. send_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
  21. async def run(self) -> None:
  22. try:
  23. auth = await read_frame(self.reader)
  24. if auth.kind != AUTH:
  25. return
  26. payload = decode_json(auth.payload) if auth.payload else {}
  27. if payload.get("token") != self.token:
  28. return
  29. await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, encode_json({"status": "ok", "kind": payload.get("purpose", "normal"), "udp_only": False})))
  30. while True:
  31. frame = await read_frame(self.reader)
  32. await self.handle(frame)
  33. finally:
  34. await self.close()
  35. async def safe_send(self, frame: Frame) -> bool:
  36. if self.closed:
  37. return False
  38. try:
  39. async with self.send_lock:
  40. if self.closed:
  41. return False
  42. await write_frame(self.writer, frame)
  43. return True
  44. except Exception:
  45. return False
  46. async def handle(self, frame: Frame) -> None:
  47. key = (frame.session_id, frame.stream_id)
  48. if frame.kind == PING:
  49. await self.safe_send(Frame(PONG, 0, 0, frame.seq, 0, b"pong"))
  50. return
  51. if frame.kind == TCP_OPEN:
  52. try:
  53. meta = decode_json(frame.payload) if frame.payload else {}
  54. family = int(meta.get("family", 0)) or 0
  55. reader, writer = await asyncio.open_connection(meta["host"], int(meta["port"]), family=family or 0)
  56. task = asyncio.create_task(self._tcp_pump(frame.session_id, frame.stream_id, reader))
  57. self.tcp_sessions[key] = TcpSession(frame.session_id, frame.stream_id, writer, task)
  58. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_OK, b"ok"))
  59. except Exception as exc:
  60. await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
  61. return
  62. if frame.kind == TCP_DATA:
  63. session = self.tcp_sessions.get(key)
  64. if session:
  65. try:
  66. session.writer.write(frame.payload)
  67. await session.writer.drain()
  68. except Exception:
  69. await self._close_tcp(key)
  70. return
  71. if frame.kind == TCP_CLOSE:
  72. await self._close_tcp(key)
  73. async def _tcp_pump(self, session_id: int, stream_id: int, reader: asyncio.StreamReader) -> None:
  74. try:
  75. while True:
  76. chunk = await reader.read(65536)
  77. if not chunk:
  78. break
  79. sent = await self.safe_send(Frame(TCP_DATA, session_id, stream_id, 0, 0, chunk))
  80. if not sent:
  81. break
  82. finally:
  83. if not self.closed:
  84. await self.safe_send(Frame(TCP_CLOSE, session_id, stream_id, 0, 0, b""))
  85. await self._close_tcp((session_id, stream_id), from_task=True)
  86. async def _close_tcp(self, key: tuple[int, int], from_task: bool = False) -> None:
  87. session = self.tcp_sessions.pop(key, None)
  88. if session is None:
  89. return
  90. if not from_task and session.task is not asyncio.current_task():
  91. session.task.cancel()
  92. with contextlib.suppress(Exception):
  93. await session.task
  94. session.writer.close()
  95. with contextlib.suppress(Exception):
  96. await session.writer.wait_closed()
  97. async def close(self) -> None:
  98. if self.closed:
  99. return
  100. self.closed = True
  101. for key in list(self.tcp_sessions):
  102. await self._close_tcp(key)
  103. self.writer.close()
  104. with contextlib.suppress(Exception):
  105. await self.writer.wait_closed()
  106. class TcpRelayServer:
  107. def __init__(self, token: str) -> None:
  108. self.token = token
  109. async def start(self, host: str, port: int) -> None:
  110. async def accept(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  111. await TcpRelayChannel(reader, writer, self.token).run()
  112. server = await asyncio.start_server(accept, host, port)
  113. async with server:
  114. await server.serve_forever()