浏览代码

修改协议不一致问题

Gogs 3 天之前
父节点
当前提交
2a01b7bedc
共有 2 个文件被更改,包括 11 次插入8 次删除
  1. 1 1
      relay_client.py
  2. 10 7
      relay_server.py

+ 1 - 1
relay_client.py

@@ -50,10 +50,10 @@ class RelayConnection:
         frame = await read_frame(self.reader)
         if frame.kind != AUTH or frame.packet_id != STATUS_OK:
             raise ConnectionError(f"relay auth failed: {self.node.name}")
-        print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port}")
         self.last_pong_at = time.monotonic()
         self.keepalive_task = asyncio.create_task(self._keepalive())
         self.pump_task = asyncio.create_task(self._pump())
+        print(f"[edge] relay connected name={self.node.name} addr={self.node.host}:{self.node.port}")
 
     async def _keepalive(self) -> None:
         try:

+ 10 - 7
relay_server.py

@@ -58,16 +58,19 @@ class RelayChannel:
         authed = False
         try:
             auth = await read_frame(self.reader)
-            payload = decode_json(auth.payload)
-            if auth.kind != AUTH or payload.get("token") != self.token:
+            if auth.kind != AUTH:
+                raise PermissionError("invalid handshake kind")
+            try:
+                payload = decode_json(auth.payload) if auth.payload else {}
+            except Exception as exc:
+                raise PermissionError(f"invalid auth payload: {exc!r}") from exc
+            if payload.get("token") != self.token:
                 raise PermissionError("invalid token")
             authed = True
             self.authed_at = time.monotonic()
             self.authed_kind = payload.get("purpose", "normal")
-            if self.authed_kind != "probe":
-                await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, b"ok"))
-            else:
-                await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, encode_json({"status": "ok", "kind": "probe"})))
+            ack_payload = {"status": "ok", "kind": self.authed_kind}
+            await self.safe_send(Frame(AUTH, 0, 0, 0, STATUS_OK, encode_json(ack_payload)))
             while True:
                 frame = await read_frame(self.reader)
                 self.frame_count += 1
@@ -80,7 +83,7 @@ class RelayChannel:
         except asyncio.CancelledError:
             pass
         except Exception as exc:
-            if authed:
+            if authed and self.authed_kind != "probe":
                 lived = time.monotonic() - self.authed_at if self.authed_at else 0.0
                 print(f"[relay] session error peer={peer} kind={self.authed_kind} lived={lived:.1f}s frames={self.frame_count} error={exc!r}")
         finally: