|
@@ -169,6 +169,7 @@ class TransparentSession:
|
|
|
writer: asyncio.StreamWriter
|
|
writer: asyncio.StreamWriter
|
|
|
paths: list[BasePath]
|
|
paths: list[BasePath]
|
|
|
warmup_bytes: int
|
|
warmup_bytes: int
|
|
|
|
|
+ loser_grace_ms: int
|
|
|
stats: dict[str, int]
|
|
stats: dict[str, int]
|
|
|
target_stats: dict[tuple[str, int], dict[str, int]]
|
|
target_stats: dict[tuple[str, int], dict[str, int]]
|
|
|
opened_count: int = 0
|
|
opened_count: int = 0
|
|
@@ -180,6 +181,7 @@ class TransparentSession:
|
|
|
winner_event: asyncio.Event = field(default_factory=asyncio.Event)
|
|
winner_event: asyncio.Event = field(default_factory=asyncio.Event)
|
|
|
closed: bool = False
|
|
closed: bool = False
|
|
|
pump_task: asyncio.Task | None = None
|
|
pump_task: asyncio.Task | None = None
|
|
|
|
|
+ loser_close_task: asyncio.Task | None = None
|
|
|
|
|
|
|
|
def _record_win(self, winner: BasePath) -> None:
|
|
def _record_win(self, winner: BasePath) -> None:
|
|
|
self.stats[winner.name] = self.stats.get(winner.name, 0) + 1
|
|
self.stats[winner.name] = self.stats.get(winner.name, 0) + 1
|
|
@@ -212,7 +214,7 @@ class TransparentSession:
|
|
|
active = [path for path in self.paths if path.opened and not path.closed]
|
|
active = [path for path in self.paths if path.opened and not path.closed]
|
|
|
if not active:
|
|
if not active:
|
|
|
break
|
|
break
|
|
|
- if self.winner is None and self.uplink_bytes <= self.warmup_bytes:
|
|
|
|
|
|
|
+ if self.uplink_bytes <= self.warmup_bytes:
|
|
|
await asyncio.gather(*(path.send(chunk) for path in active), return_exceptions=True)
|
|
await asyncio.gather(*(path.send(chunk) for path in active), return_exceptions=True)
|
|
|
else:
|
|
else:
|
|
|
if self.winner is None:
|
|
if self.winner is None:
|
|
@@ -241,7 +243,10 @@ class TransparentSession:
|
|
|
self.winner = path
|
|
self.winner = path
|
|
|
self._record_win(path)
|
|
self._record_win(path)
|
|
|
self.winner_event.set()
|
|
self.winner_event.set()
|
|
|
- await self._close_losers(path)
|
|
|
|
|
|
|
+ if self.loser_grace_ms > 0:
|
|
|
|
|
+ self.loser_close_task = asyncio.create_task(self._close_losers_after_grace(path))
|
|
|
|
|
+ else:
|
|
|
|
|
+ await self._close_losers(path)
|
|
|
if path is self.winner and payload is not None:
|
|
if path is self.winner and payload is not None:
|
|
|
self.writer.write(payload)
|
|
self.writer.write(payload)
|
|
|
await self.writer.drain()
|
|
await self.writer.drain()
|
|
@@ -258,6 +263,11 @@ class TransparentSession:
|
|
|
async def _close_losers(self, winner: BasePath) -> None:
|
|
async def _close_losers(self, winner: BasePath) -> None:
|
|
|
await asyncio.gather(*(path.close() for path in self.paths if path is not winner), return_exceptions=True)
|
|
await asyncio.gather(*(path.close() for path in self.paths if path is not winner), return_exceptions=True)
|
|
|
|
|
|
|
|
|
|
+ async def _close_losers_after_grace(self, winner: BasePath) -> None:
|
|
|
|
|
+ await asyncio.sleep(self.loser_grace_ms / 1000)
|
|
|
|
|
+ if not self.closed:
|
|
|
|
|
+ await self._close_losers(winner)
|
|
|
|
|
+
|
|
|
async def close(self) -> None:
|
|
async def close(self) -> None:
|
|
|
if self.closed:
|
|
if self.closed:
|
|
|
return
|
|
return
|
|
@@ -267,6 +277,10 @@ class TransparentSession:
|
|
|
self.pump_task.cancel()
|
|
self.pump_task.cancel()
|
|
|
with contextlib.suppress(Exception):
|
|
with contextlib.suppress(Exception):
|
|
|
await self.pump_task
|
|
await self.pump_task
|
|
|
|
|
+ if self.loser_close_task and self.loser_close_task is not asyncio.current_task():
|
|
|
|
|
+ self.loser_close_task.cancel()
|
|
|
|
|
+ with contextlib.suppress(Exception):
|
|
|
|
|
+ await self.loser_close_task
|
|
|
await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True)
|
|
await asyncio.gather(*(path.close() for path in self.paths), return_exceptions=True)
|
|
|
self.writer.close()
|
|
self.writer.close()
|
|
|
with contextlib.suppress(Exception):
|
|
with contextlib.suppress(Exception):
|
|
@@ -425,8 +439,7 @@ class TransparentUdpListener:
|
|
|
data, ancdata, _flags, src = self.socket.recvmsg(65535, 512)
|
|
data, ancdata, _flags, src = self.socket.recvmsg(65535, 512)
|
|
|
except BlockingIOError:
|
|
except BlockingIOError:
|
|
|
return
|
|
return
|
|
|
- except Exception as exc:
|
|
|
|
|
- print(f"[edge] udp recv failed family={self.family} error={exc!r}")
|
|
|
|
|
|
|
+ except Exception:
|
|
|
return
|
|
return
|
|
|
original = None
|
|
original = None
|
|
|
for level, ctype, cdata in ancdata:
|
|
for level, ctype, cdata in ancdata:
|
|
@@ -437,12 +450,13 @@ class TransparentUdpListener:
|
|
|
original = parse_sockaddr(cdata)
|
|
original = parse_sockaddr(cdata)
|
|
|
break
|
|
break
|
|
|
if original is None:
|
|
if original is None:
|
|
|
- print(f"[edge] udp missing original dst family={self.family} src={src}")
|
|
|
|
|
return
|
|
return
|
|
|
if self.family == socket.AF_INET:
|
|
if self.family == socket.AF_INET:
|
|
|
source = PeerAddress(host=src[0], port=src[1], family=socket.AF_INET)
|
|
source = PeerAddress(host=src[0], port=src[1], family=socket.AF_INET)
|
|
|
else:
|
|
else:
|
|
|
source = PeerAddress(host=src[0], port=src[1], family=socket.AF_INET6)
|
|
source = PeerAddress(host=src[0], port=src[1], family=socket.AF_INET6)
|
|
|
|
|
+ if original.port == self.port and (original.host in ("127.0.0.1", "::1") or original.host == self.bind_host):
|
|
|
|
|
+ return
|
|
|
asyncio.create_task(self.edge.handle_udp_datagram(source, original, data, self))
|
|
asyncio.create_task(self.edge.handle_udp_datagram(source, original, data, self))
|
|
|
|
|
|
|
|
async def send_response(self, source: PeerAddress, payload: bytes) -> None:
|
|
async def send_response(self, source: PeerAddress, payload: bytes) -> None:
|
|
@@ -461,10 +475,11 @@ class TransparentUdpListener:
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransparentEdge:
|
|
class TransparentEdge:
|
|
|
- def __init__(self, listen_host: str, listen_port: int, config: Config) -> None:
|
|
|
|
|
|
|
+ def __init__(self, listen_host: str, listen_port: int, config: Config, enable_udp: bool = False) -> None:
|
|
|
self.listen_host = listen_host
|
|
self.listen_host = listen_host
|
|
|
self.listen_port = listen_port
|
|
self.listen_port = listen_port
|
|
|
self.config = config
|
|
self.config = config
|
|
|
|
|
+ self.enable_udp = enable_udp
|
|
|
self.manager = RelayManager(config)
|
|
self.manager = RelayManager(config)
|
|
|
self.session_ids = itertools.count(1)
|
|
self.session_ids = itertools.count(1)
|
|
|
self.stream_ids = itertools.count(1)
|
|
self.stream_ids = itertools.count(1)
|
|
@@ -488,7 +503,8 @@ class TransparentEdge:
|
|
|
sockets.extend(str(sock.getsockname()) for sock in server6.sockets or [])
|
|
sockets.extend(str(sock.getsockname()) for sock in server6.sockets or [])
|
|
|
except Exception as exc:
|
|
except Exception as exc:
|
|
|
print(f"[edge] ipv6 tcp listener skipped: {exc!r}")
|
|
print(f"[edge] ipv6 tcp listener skipped: {exc!r}")
|
|
|
- self._start_udp_listeners()
|
|
|
|
|
|
|
+ if self.enable_udp:
|
|
|
|
|
+ self._start_udp_listeners()
|
|
|
self.udp_gc_task = asyncio.create_task(self._gc_udp_flows())
|
|
self.udp_gc_task = asyncio.create_task(self._gc_udp_flows())
|
|
|
print(f"[edge] transparent tcp listening on {', '.join(sockets)}")
|
|
print(f"[edge] transparent tcp listening on {', '.join(sockets)}")
|
|
|
if server6 is None:
|
|
if server6 is None:
|
|
@@ -520,7 +536,7 @@ class TransparentEdge:
|
|
|
try:
|
|
try:
|
|
|
target = self._get_original_dst(writer)
|
|
target = self._get_original_dst(writer)
|
|
|
session_id = next(self.session_ids)
|
|
session_id = next(self.session_ids)
|
|
|
- session = TransparentSession(session_id=session_id, target=target, reader=reader, writer=writer, paths=[], warmup_bytes=self.config.tcp_warmup_bytes, stats=self.tcp_win_counts, target_stats=self.tcp_target_wins)
|
|
|
|
|
|
|
+ session = TransparentSession(session_id=session_id, target=target, reader=reader, writer=writer, paths=[], warmup_bytes=self.config.tcp_warmup_bytes, loser_grace_ms=self.config.tcp_loser_grace_ms, stats=self.tcp_win_counts, target_stats=self.tcp_target_wins)
|
|
|
paths: list[BasePath] = [DirectTcpPath(name="direct", on_frame=lambda path, event, payload, s=session: self._handle_tcp_session(s, path, event, payload))]
|
|
paths: list[BasePath] = [DirectTcpPath(name="direct", on_frame=lambda path, event, payload, s=session: self._handle_tcp_session(s, path, event, payload))]
|
|
|
for connection in self.manager.available():
|
|
for connection in self.manager.available():
|
|
|
stream_id = next(self.stream_ids)
|
|
stream_id = next(self.stream_ids)
|
|
@@ -551,6 +567,10 @@ class TransparentEdge:
|
|
|
raise RuntimeError(f"unsupported socket family={family}")
|
|
raise RuntimeError(f"unsupported socket family={family}")
|
|
|
|
|
|
|
|
async def handle_udp_datagram(self, source: PeerAddress, target: TargetAddress, payload: bytes, listener: TransparentUdpListener) -> None:
|
|
async def handle_udp_datagram(self, source: PeerAddress, target: TargetAddress, payload: bytes, listener: TransparentUdpListener) -> None:
|
|
|
|
|
+ if not self.enable_udp:
|
|
|
|
|
+ return
|
|
|
|
|
+ if target.port == self.listen_port and target.host in ("127.0.0.1", "::1", self.listen_host):
|
|
|
|
|
+ return
|
|
|
key = (source, target)
|
|
key = (source, target)
|
|
|
flow = self.udp_flows.get(key)
|
|
flow = self.udp_flows.get(key)
|
|
|
if flow is None:
|
|
if flow is None:
|