|
@@ -69,6 +69,7 @@ class RelayChannel:
|
|
|
authed_at: float = 0.0
|
|
authed_at: float = 0.0
|
|
|
frame_count: int = 0
|
|
frame_count: int = 0
|
|
|
authed_kind: str = "normal"
|
|
authed_kind: str = "normal"
|
|
|
|
|
+ udp_only: bool = False
|
|
|
|
|
|
|
|
async def run(self) -> None:
|
|
async def run(self) -> None:
|
|
|
peer = self.writer.get_extra_info("peername")
|
|
peer = self.writer.get_extra_info("peername")
|
|
@@ -86,7 +87,7 @@ class RelayChannel:
|
|
|
authed = True
|
|
authed = True
|
|
|
self.authed_at = time.monotonic()
|
|
self.authed_at = time.monotonic()
|
|
|
self.authed_kind = payload.get("purpose", "normal")
|
|
self.authed_kind = payload.get("purpose", "normal")
|
|
|
- ack_payload = {"status": "ok", "kind": self.authed_kind}
|
|
|
|
|
|
|
+ ack_payload = {"status": "ok", "kind": self.authed_kind, "udp_only": self.udp_only}
|
|
|
await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, encode_json(ack_payload)))
|
|
await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, encode_json(ack_payload)))
|
|
|
while True:
|
|
while True:
|
|
|
frame = await read_frame(self.reader)
|
|
frame = await read_frame(self.reader)
|
|
@@ -126,6 +127,9 @@ class RelayChannel:
|
|
|
if frame.kind == AUTH:
|
|
if frame.kind == AUTH:
|
|
|
return
|
|
return
|
|
|
if frame.kind == TCP_OPEN:
|
|
if frame.kind == TCP_OPEN:
|
|
|
|
|
+ if self.udp_only:
|
|
|
|
|
+ await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, b"tcp disabled on udp-only relay"))
|
|
|
|
|
+ return
|
|
|
try:
|
|
try:
|
|
|
meta = decode_json(frame.payload) if frame.payload else {}
|
|
meta = decode_json(frame.payload) if frame.payload else {}
|
|
|
family = int(meta.get("family", 0)) or 0
|
|
family = int(meta.get("family", 0)) or 0
|
|
@@ -137,6 +141,8 @@ class RelayChannel:
|
|
|
await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
|
|
await self.safe_send(Frame(TCP_STATUS, frame.session_id, frame.stream_id, 0, STATUS_ERR, str(exc).encode()))
|
|
|
return
|
|
return
|
|
|
if frame.kind == TCP_DATA:
|
|
if frame.kind == TCP_DATA:
|
|
|
|
|
+ if self.udp_only:
|
|
|
|
|
+ return
|
|
|
session = self.tcp_sessions.get(key)
|
|
session = self.tcp_sessions.get(key)
|
|
|
if session:
|
|
if session:
|
|
|
try:
|
|
try:
|
|
@@ -146,6 +152,8 @@ class RelayChannel:
|
|
|
await self._close_tcp(key)
|
|
await self._close_tcp(key)
|
|
|
return
|
|
return
|
|
|
if frame.kind == TCP_CLOSE:
|
|
if frame.kind == TCP_CLOSE:
|
|
|
|
|
+ if self.udp_only:
|
|
|
|
|
+ return
|
|
|
await self._close_tcp(key)
|
|
await self._close_tcp(key)
|
|
|
return
|
|
return
|
|
|
if frame.kind == UDP_SEND:
|
|
if frame.kind == UDP_SEND:
|
|
@@ -249,13 +257,16 @@ class RelayChannel:
|
|
|
class RelayServer:
|
|
class RelayServer:
|
|
|
def __init__(self, token: str) -> None:
|
|
def __init__(self, token: str) -> None:
|
|
|
self.token = token
|
|
self.token = token
|
|
|
|
|
+ self.udp_only = False
|
|
|
|
|
|
|
|
- async def start(self, host: str, port: int) -> None:
|
|
|
|
|
|
|
+ async def start(self, host: str, port: int, udp_only: bool = False) -> None:
|
|
|
|
|
+ self.udp_only = udp_only
|
|
|
server = await asyncio.start_server(self._accept, host, port)
|
|
server = await asyncio.start_server(self._accept, host, port)
|
|
|
sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
|
|
sockets = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
|
|
|
- print(f"[relay] listening on {sockets}")
|
|
|
|
|
|
|
+ mode = "udp-only" if udp_only else "normal"
|
|
|
|
|
+ print(f"[relay] listening on {sockets} mode={mode}")
|
|
|
async with server:
|
|
async with server:
|
|
|
await server.serve_forever()
|
|
await server.serve_forever()
|
|
|
|
|
|
|
|
async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
|
async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
|
|
- await RelayChannel(reader, writer, self.token).run()
|
|
|
|
|
|
|
+ await RelayChannel(reader, writer, self.token, udp_only=self.udp_only).run()
|