import asyncio import base64 import importlib.util import json import os import struct import time from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Optional from urllib.parse import urlparse import aiohttp from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ed25519 # Capture import errors for protocol handlers so we can report why they are missing. PS_IMPORT_ERROR: Optional[str] = None BASH_IMPORT_ERROR: Optional[str] = None tunnel_SSH = None tunnel_WinRM = None tunnel_VNC = None tunnel_RDP = None tunnel_WebRTC = None tunnel_Powershell = None tunnel_Bash = None def _load_protocol_module(module_name: str, rel_parts: list[str]) -> tuple[Optional[object], Optional[str]]: """Load a protocol handler directly from a file path to survive non-package runtimes.""" base = Path(__file__).parent path = base for part in rel_parts: path = path / part if not path.exists(): return None, f"path_missing:{path}" try: spec = importlib.util.spec_from_file_location(module_name, path) if not spec or not spec.loader: return None, "spec_failed" module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) # type: ignore return module, None except Exception as exc: # pragma: no cover - defensive return None, repr(exc) try: from .Reverse_Tunnels.remote_interactive_shell.Protocols import Powershell as tunnel_Powershell # type: ignore except Exception as exc: # pragma: no cover - best-effort logging only PS_IMPORT_ERROR = repr(exc) _module, _err = _load_protocol_module( "tunnel_Powershell", ["Reverse_Tunnels", "remote_interactive_shell", "Protocols", "Powershell.py"], ) if _module: tunnel_Powershell = _module PS_IMPORT_ERROR = None else: try: from .ReverseTunnel import tunnel_Powershell # type: ignore # legacy fallback PS_IMPORT_ERROR = None except Exception as exc2: # pragma: no cover - diagnostic only PS_IMPORT_ERROR = f"{PS_IMPORT_ERROR} | legacy_fallback={exc2!r} | 
file_load_failed={_err}" try: from .Reverse_Tunnels.remote_interactive_shell.Protocols import Bash as tunnel_Bash # type: ignore except Exception as exc: # pragma: no cover - best-effort logging only BASH_IMPORT_ERROR = repr(exc) _module, _err = _load_protocol_module( "tunnel_Bash", ["Reverse_Tunnels", "remote_interactive_shell", "Protocols", "Bash.py"], ) if _module: tunnel_Bash = _module BASH_IMPORT_ERROR = None else: BASH_IMPORT_ERROR = f"{BASH_IMPORT_ERROR} | file_load_failed={_err}" try: from .Reverse_Tunnels.remote_management.Protocols import SSH as tunnel_SSH # type: ignore except Exception: _module, _err = _load_protocol_module( "tunnel_SSH", ["Reverse_Tunnels", "remote_management", "Protocols", "SSH.py"], ) tunnel_SSH = _module try: from .Reverse_Tunnels.remote_management.Protocols import WinRM as tunnel_WinRM # type: ignore except Exception: _module, _err = _load_protocol_module( "tunnel_WinRM", ["Reverse_Tunnels", "remote_management", "Protocols", "WinRM.py"], ) tunnel_WinRM = _module try: from .Reverse_Tunnels.remote_video.Protocols import VNC as tunnel_VNC # type: ignore except Exception: _module, _err = _load_protocol_module( "tunnel_VNC", ["Reverse_Tunnels", "remote_video", "Protocols", "VNC.py"], ) tunnel_VNC = _module try: from .Reverse_Tunnels.remote_video.Protocols import RDP as tunnel_RDP # type: ignore except Exception: _module, _err = _load_protocol_module( "tunnel_RDP", ["Reverse_Tunnels", "remote_video", "Protocols", "RDP.py"], ) tunnel_RDP = _module try: from .Reverse_Tunnels.remote_video.Protocols import WebRTC as tunnel_WebRTC # type: ignore except Exception: _module, _err = _load_protocol_module( "tunnel_WebRTC", ["Reverse_Tunnels", "remote_video", "Protocols", "WebRTC.py"], ) tunnel_WebRTC = _module ROLE_NAME = "reverse_tunnel" ROLE_CONTEXTS = ["interactive", "system"] # Message types (keep in sync with Engine service) MSG_CONNECT = 0x01 MSG_CONNECT_ACK = 0x02 MSG_CHANNEL_OPEN = 0x03 MSG_CHANNEL_ACK = 0x04 MSG_DATA = 0x05 
MSG_WINDOW_UPDATE = 0x06
MSG_HEARTBEAT = 0x07
MSG_CLOSE = 0x08
MSG_CONTROL = 0x09

# Close codes
CLOSE_OK = 0
CLOSE_IDLE_TIMEOUT = 1
CLOSE_GRACE_EXPIRED = 2
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AUTH_FAILED = 4
CLOSE_SERVER_SHUTDOWN = 5
CLOSE_AGENT_SHUTDOWN = 6
CLOSE_DOMAIN_LIMIT = 7
CLOSE_UNEXPECTED_DISCONNECT = 8

FRAME_VERSION = 1
# Header fields, in pack order: version, msg_type, flags, reserved, channel_id, payload length.
# NOTE(review): the format string and the TunnelFrame field declarations below were unreadable
# in the reviewed copy and have been reconstructed from how encode()/decode_frame()/call sites
# use them. "<BBBBII" (little-endian, four u8 + two u32) is an ASSUMPTION -- confirm it
# byte-for-byte against the Engine service before merging, or the wire format will not match.
FRAME_HEADER_STRUCT = struct.Struct("<BBBBII")


@dataclass
class TunnelFrame:
    """One tunnel protocol frame: a fixed binary header followed by an opaque payload."""

    msg_type: int
    channel_id: int
    payload: bytes = b""
    flags: int = 0
    version: int = FRAME_VERSION
    reserved: int = 0

    def encode(self) -> bytes:
        """Serialize as ``FRAME_HEADER_STRUCT`` header + payload (a falsy payload encodes as empty)."""
        payload = self.payload or b""
        header = FRAME_HEADER_STRUCT.pack(
            self.version,
            self.msg_type,
            self.flags,
            self.reserved,
            int(self.channel_id),
            len(payload),
        )
        return header + payload


def decode_frame(raw: bytes) -> TunnelFrame:
    """Parse one binary frame.

    Bytes beyond the declared payload length are ignored.

    Raises:
        ValueError: ``frame_too_small``, ``unsupported_version:<v>``, ``invalid_length``
            or ``length_mismatch``.
    """
    if len(raw) < FRAME_HEADER_STRUCT.size:
        raise ValueError("frame_too_small")
    version, msg_type, flags, reserved, channel_id, length = FRAME_HEADER_STRUCT.unpack_from(raw, 0)
    if version != FRAME_VERSION:
        raise ValueError(f"unsupported_version:{version}")
    if length < 0 or len(raw) < FRAME_HEADER_STRUCT.size + length:
        raise ValueError("invalid_length")
    payload = raw[FRAME_HEADER_STRUCT.size : FRAME_HEADER_STRUCT.size + length]
    if len(payload) != length:
        raise ValueError("length_mismatch")
    return TunnelFrame(msg_type=msg_type, channel_id=channel_id, payload=payload, flags=flags, version=version, reserved=reserved)


def heartbeat_frame(channel_id: int = 0, *, is_ack: bool = False) -> TunnelFrame:
    """Build an empty HEARTBEAT frame; flag bit 0x1 marks it as an acknowledgement."""
    flags = 0x1 if is_ack else 0x0
    return TunnelFrame(msg_type=MSG_HEARTBEAT, channel_id=channel_id, flags=flags, payload=b"")


def close_frame(channel_id: int, code: int, reason: str = "") -> TunnelFrame:
    """Build a CLOSE frame whose payload is compact JSON ``{"code": ..., "reason": ...}``."""
    payload = json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8")
    return TunnelFrame(msg_type=MSG_CLOSE, channel_id=channel_id, payload=payload)


def _norm_text(value: Any) -> str:
    """Return ``str(value).strip()``, or ``""`` for ``None`` / values that fail to stringify."""
    if value is None:
        return ""
    try:
        return str(value).strip()
    except Exception:
        return ""


def _is_literal_ip(value: str) -> bool:
    """Return True if *value* is an IPv4/IPv6 literal (surrounding ``[]`` brackets allowed)."""
    try:
        import ipaddress

        ipaddress.ip_address(value.strip().strip("[]"))
        return True
    except Exception:
        return False


def _coerce_int(value: Any, default: int) -> int:
    """Return ``int(value)``, or *default* when *value* is falsy or not integer-convertible."""
    if not value:
        return default
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


@dataclass
class ActiveTunnel:
    """Mutable per-tunnel state shared by the Role's pump/heartbeat/watchdog tasks."""

    tunnel_id: str
    domain: str
    protocol: str
    port: int
    token: str
    url: str
    heartbeat_seconds: int
    idle_seconds: int
    grace_seconds: int
    expires_at: Optional[int]
    signing_key_hint: Optional[str] = None
    session: Optional[aiohttp.ClientSession] = None
    websocket: Optional[aiohttp.ClientWebSocketResponse] = None
    tasks: list = field(default_factory=list)
    send_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
    channels: Dict[int, Any] = field(default_factory=dict)
    last_activity: float = field(default_factory=lambda: time.time())
    connected: bool = False
    stopping: bool = False
    stop_reason: Optional[str] = None
    stop_origin: Optional[str] = None


class BaseChannel:
    """Placeholder channel handler; protocol-specific handlers plug in later."""

    def __init__(self, role: "Role", tunnel: ActiveTunnel, channel_id: int, metadata: Optional[dict]):
        self.role = role
        self.tunnel = tunnel
        self.channel_id = channel_id
        self.metadata = metadata or {}

    async def start(self) -> None:
        # Nothing to prime for placeholder channels.
        return

    async def on_frame(self, frame: TunnelFrame) -> None:
        # Drop frames until protocol module is provided.
        return

    async def stop(self, code: int = CLOSE_OK, reason: str = "") -> None:
        """Queue a CLOSE frame for this channel."""
        await self.role._send_frame(self.tunnel, close_frame(self.channel_id, code, reason))


class Role:
    """Agent role that opens reverse websocket tunnels on ``reverse_tunnel_start`` events."""

    def __init__(self, ctx):
        self.ctx = ctx
        self.sio = ctx.sio
        self.loop = ctx.loop or asyncio.get_event_loop()
        self.hooks = ctx.hooks or {}
        self._http_client_factory = self.hooks.get("http_client")
        self._log_hook = self.hooks.get("log_agent")
        self._active: Dict[str, ActiveTunnel] = {}
        # domain -> most recently accepted tunnel_id for that domain.
        self._domain_claims: Dict[str, str] = {}
        # Max concurrent tunnels per domain; None means unlimited (see _domain_allowed).
        self._domain_limits: Dict[str, Optional[int]] = {
            "remote-interactive-shell": 2,
            "remote-management": 1,
            "remote-video": 2,
            # Legacy / protocol fallbacks
            "ps": 2,
            "rdp": 1,
            "vnc": 1,
            "webrtc": 2,
            "ssh": None,
            "winrm": None,
        }
        self._default_heartbeat = 20
        self._protocol_handlers: Dict[str, Any] = {}
        self._frame_cls = TunnelFrame
        self.close_frame = close_frame
        try:
            if tunnel_Powershell and hasattr(tunnel_Powershell, "PowershellChannel"):
                self._protocol_handlers["ps"] = tunnel_Powershell.PowershellChannel
                module_path = getattr(tunnel_Powershell, "__file__", None)
                self._log(f"reverse_tunnel ps handler registered (PowershellChannel) module={module_path}")
            else:
                hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else ""
                module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
                exists_hint = f" exists={module_path.exists()}"
                self._log(
                    f"reverse_tunnel ps handler NOT registered (missing module/class){hint}{exists_hint}",
                    error=True,
                )
        except Exception as exc:
            self._log(f"reverse_tunnel ps handler registration failed: {exc}", error=True)
        try:
            if tunnel_Bash and hasattr(tunnel_Bash, "BashChannel"):
                self._protocol_handlers["bash"] = tunnel_Bash.BashChannel
                module_path = getattr(tunnel_Bash, "__file__", None)
                self._log(f"reverse_tunnel bash handler registered (BashChannel) module={module_path}")
            elif BASH_IMPORT_ERROR:
                self._log(f"reverse_tunnel bash handler NOT registered (missing module/class) import_error={BASH_IMPORT_ERROR}", error=True)
        except Exception as exc:
            self._log(f"reverse_tunnel bash handler registration failed: {exc}", error=True)
        try:
            if tunnel_SSH and hasattr(tunnel_SSH, "SSHChannel"):
                self._protocol_handlers["ssh"] = tunnel_SSH.SSHChannel
            if tunnel_WinRM and hasattr(tunnel_WinRM, "WinRMChannel"):
                self._protocol_handlers["winrm"] = tunnel_WinRM.WinRMChannel
            if tunnel_VNC and hasattr(tunnel_VNC, "VNCChannel"):
                self._protocol_handlers["vnc"] = tunnel_VNC.VNCChannel
            if tunnel_RDP and hasattr(tunnel_RDP, "RDPChannel"):
                self._protocol_handlers["rdp"] = tunnel_RDP.RDPChannel
            if tunnel_WebRTC and hasattr(tunnel_WebRTC, "WebRTCChannel"):
                self._protocol_handlers["webrtc"] = tunnel_WebRTC.WebRTCChannel
        except Exception as exc:
            self._log(f"reverse_tunnel protocol handler registration failed: {exc}", error=True)

    # ------------------------------------------------------------------ Logging
    def _log(self, message: str, *, error: bool = False) -> None:
        """Forward *message* to the ``log_agent`` hook (also to the error log when *error*); never raises."""
        fname = "reverse_tunnel.log"
        try:
            if callable(self._log_hook):
                self._log_hook(message, fname=fname)
                if error:
                    self._log_hook(message, fname="agent.error.log")
        except Exception:
            pass

    # ------------------------------------------------------------------ Event wiring
    def register_events(self):
        """Subscribe to the Socket.IO start/stop events."""

        @self.sio.on("reverse_tunnel_start")
        async def _reverse_tunnel_start(payload):
            await self._handle_tunnel_start(payload)

        @self.sio.on("reverse_tunnel_stop")
        async def _reverse_tunnel_stop(payload):
            tid = ""
            if isinstance(payload, dict):
                tid = _norm_text(payload.get("tunnel_id"))
            await self._stop_tunnel(tid, code=CLOSE_AGENT_SHUTDOWN, reason="server_stop")

    # ------------------------------------------------------------------ Token helpers
    def _decode_token_payload(self, token: str, *, signing_key_hint: Optional[str] = None) -> Dict[str, Any]:
        """Decode ``<b64url(json)>[.<b64url(signature)>]`` and verify the Ed25519 signature if present.

        Candidate keys are the hint followed by the client's stored server key; a key that
        verifies is persisted through ``store_server_signing_key``.

        Raises:
            ValueError: ``token_missing``, ``token_malformed``, ``token_decode_error`` or
                ``token_signature_invalid`` (base64 errors surface as ``binascii.Error``).
        """
        if not token:
            raise ValueError("token_missing")

        def _b64decode(segment: str) -> bytes:
            padding = "=" * (-len(segment) % 4)
            return base64.urlsafe_b64decode(segment + padding)

        parts = token.split(".")
        if len(parts) > 2:
            # Previously only exactly-two-part tokens were verified, so extra segments
            # bypassed the signature check entirely.
            raise ValueError("token_malformed")
        payload_bytes = _b64decode(parts[0])
        try:
            payload = json.loads(payload_bytes.decode("utf-8"))
        except Exception as exc:
            raise ValueError("token_decode_error") from exc
        if not isinstance(payload, dict):
            raise ValueError("token_decode_error")

        # NOTE(review): single-segment (unsigned) tokens are still accepted without verification,
        # and a hint key that verifies replaces the stored key -- confirm both are intended.
        if len(parts) == 2:
            candidates = []
            hint = _norm_text(signing_key_hint)
            if hint:
                candidates.append(hint)
            client = self._http_client()
            if client and hasattr(client, "load_server_signing_key"):
                try:
                    stored = client.load_server_signing_key()
                except Exception:
                    stored = None
                if isinstance(stored, str) and stored.strip():
                    candidates.append(stored.strip())
            signature = _b64decode(parts[1])
            verified = False
            for candidate in candidates:
                try:
                    key_bytes = base64.b64decode(candidate, validate=True)
                    public_key = serialization.load_der_public_key(key_bytes)
                except Exception:
                    continue
                if not isinstance(public_key, ed25519.Ed25519PublicKey):
                    continue
                try:
                    public_key.verify(signature, payload_bytes)
                    verified = True
                    if client and hasattr(client, "store_server_signing_key"):
                        try:
                            client.store_server_signing_key(candidate)
                        except Exception:
                            pass
                    break
                except Exception:
                    continue
            if not verified:
                raise ValueError("token_signature_invalid")
        return payload

    def _validate_token(self, token: str, *, expected_agent: str, expected_domain: str, expected_protocol: str, expected_tunnel: str, signing_key_hint: Optional[str]) -> Dict[str, Any]:
        """Decode the token and check agent/tunnel/domain/protocol (case-insensitive) and expiry.

        Empty expectations are skipped.
        """
        payload = self._decode_token_payload(token, signing_key_hint=signing_key_hint)

        def _matches(expected: str, actual: Any) -> bool:
            return _norm_text(expected).lower() == _norm_text(actual).lower()

        if expected_agent and not _matches(expected_agent, payload.get("agent_id")):
            raise ValueError("token_agent_mismatch")
        if expected_tunnel and not _matches(expected_tunnel, payload.get("tunnel_id")):
            raise ValueError("token_id_mismatch")
        if expected_domain and not _matches(expected_domain, payload.get("domain")):
            raise ValueError("token_domain_mismatch")
        if expected_protocol and not _matches(expected_protocol, payload.get("protocol")):
            raise ValueError("token_protocol_mismatch")

        expires_at = payload.get("expires_at")
        try:
            expires_ts = int(expires_at) if expires_at is not None else None
        except Exception:
            expires_ts = None
        if expires_ts is not None and expires_ts < int(time.time()):
            raise ValueError("token_expired")
        return payload

    # ------------------------------------------------------------------ Utility
    def _http_client(self):
        """Return the client from the ``http_client`` hook, or None if absent/failing."""
        try:
            if callable(self._http_client_factory):
                return self._http_client_factory()
        except Exception:
            return None
        return None

    def _port_allowed(self, port: int) -> bool:
        return isinstance(port, int) and 1 <= port <= 65535

    def _domain_allowed(self, domain: str) -> bool:
        """Return True if another tunnel may open for *domain* under ``_domain_limits``."""
        limit = self._domain_limits.get(domain)
        if limit is None:
            return True
        active = [tid for tid, t in self._active.items() if t.domain == domain and not t.stopping]
        pending = [tid for dom, tid in self._domain_claims.items() if dom == domain and tid not in active]
        count = len(set(active + pending))
        return count < limit

    def _build_ws_url(self, port: int) -> str:
        """Build ``ws(s)://<server host>:<port>/`` from the client's base URL (default https://localhost:5000)."""
        client = self._http_client()
        if client:
            try:
                client.refresh_base_url()
            except Exception:
                pass
            base_url = getattr(client, "base_url", "") or "https://localhost:5000"
        else:
            base_url = "https://localhost:5000"
        parsed = urlparse(base_url)
        host = parsed.hostname or "localhost"
        scheme = "wss" if (parsed.scheme or "").lower() == "https" else "ws"
        if ":" in host:
            # urlparse().hostname strips IPv6 brackets; a URL authority needs them back.
            host = f"[{host}]"
        return f"{scheme}://{host}:{port}/"

    def _ssl_context(self, host: str):
        """Pick the ``ssl=`` value for ws_connect.

        Returns False when the client session disables verification, a key-store context
        when it verifies against a CA file, or None otherwise.
        """
        client = self._http_client()
        verify = getattr(getattr(client, "session", None), "verify", True)
        if verify is False:
            return False
        if isinstance(verify, str) and os.path.isfile(verify) and client and hasattr(client, "key_store"):
            try:
                ctx = client.key_store.build_ssl_context()
                if ctx and _is_literal_ip(host):
                    ctx.check_hostname = False
                return ctx
            except Exception:
                return None
        return None

    async def _emit_status(self, payload: Dict[str, Any]) -> None:
        try:
            await self.sio.emit("reverse_tunnel_status", payload)
        except Exception:
            pass

    def _mark_activity(self, tunnel: ActiveTunnel) -> None:
        tunnel.last_activity = time.time()

    # ------------------------------------------------------------------ Event handlers
    async def _handle_tunnel_start(self, payload: Any) -> None:
        """Validate a ``reverse_tunnel_start`` request, register the tunnel and spawn ``_run_tunnel``."""
        if not isinstance(payload, dict):
            self._log("reverse_tunnel_start ignored: payload not a dict", error=True)
            return
        tunnel_id = _norm_text(payload.get("tunnel_id"))
        token = _norm_text(payload.get("token"))
        port = payload.get("port") or payload.get("assigned_port")
        protocol = _norm_text(payload.get("protocol") or "ps").lower() or "ps"
        domain = _norm_text(payload.get("domain") or protocol).lower() or protocol
        # Malformed numeric fields fall back to defaults instead of raising out of the event handler.
        heartbeat_seconds = _coerce_int(payload.get("heartbeat_seconds"), self._default_heartbeat or 20)
        idle_seconds = _coerce_int(payload.get("idle_seconds"), 3600)
        grace_seconds = _coerce_int(payload.get("grace_seconds"), 3600)
        signing_key_hint = _norm_text(payload.get("signing_key"))
        agent_hint = _norm_text(payload.get("agent_id"))

        # Ignore broadcasts targeting other agents (Socket.IO fanout sends to both contexts).
        if agent_hint and agent_hint.lower() != _norm_text(self.ctx.agent_id).lower():
            return

        if not token:
            self._log("reverse_tunnel_start rejected: missing token", error=True)
            await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_missing"})
            return

        if not tunnel_id:
            tunnel_id = _norm_text(payload.get("lease_id"))  # fallback alias

        try:
            claims = self._validate_token(
                token,
                expected_agent=self.ctx.agent_id,
                expected_domain=domain,
                expected_protocol=protocol,
                expected_tunnel=tunnel_id,
                signing_key_hint=signing_key_hint,
            )
        except Exception as exc:
            if str(exc) == "token_agent_mismatch":
                # Broadcast hit the wrong agent context; ignore quietly.
                return
            self._log(f"reverse_tunnel_start rejected: token validation failed ({exc})", error=True)
            await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_invalid"})
            return

        tunnel_id = _norm_text(claims.get("tunnel_id") or tunnel_id)
        if not tunnel_id:
            self._log("reverse_tunnel_start rejected: tunnel_id missing after token parse", error=True)
            return

        try:
            port = int(port or claims.get("assigned_port") or 0)
        except Exception:
            port = 0
        if not self._port_allowed(port):
            self._log(f"reverse_tunnel_start rejected: invalid port {port}", error=True)
            await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "invalid_port"})
            return

        domain = _norm_text(claims.get("domain") or domain).lower()
        protocol = _norm_text(claims.get("protocol") or protocol).lower()
        expires_at = claims.get("expires_at")
        # Same tolerance as _validate_token: a non-numeric expiry is treated as "no expiry".
        try:
            expires_ts = int(expires_at) if expires_at is not None else None
        except (TypeError, ValueError):
            expires_ts = None

        if tunnel_id in self._active:
            self._log(f"reverse_tunnel_start ignored: tunnel already active tunnel_id={tunnel_id}")
            return

        if not self._domain_allowed(domain):
            self._log(f"reverse_tunnel_start rejected: domain limit for {domain}", error=True)
            await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "domain_limit"})
            return

        url = self._build_ws_url(port)
        parsed = urlparse(url)
        heartbeat_seconds = max(5, min(heartbeat_seconds, 120))
        idle_seconds = max(30, idle_seconds)
        grace_seconds = max(60, grace_seconds)

        tunnel = ActiveTunnel(
            tunnel_id=tunnel_id,
            domain=domain,
            protocol=protocol,
            port=port,
            token=token,
            url=url,
            heartbeat_seconds=heartbeat_seconds,
            idle_seconds=idle_seconds,
            grace_seconds=grace_seconds,
            expires_at=expires_ts,
            signing_key_hint=signing_key_hint or None,
        )
        self._active[tunnel_id] = tunnel
        self._domain_claims[domain] = tunnel_id
        self._log(f"reverse_tunnel_start accepted tunnel_id={tunnel_id} domain={domain} protocol={protocol} url={url}")
        await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "connecting", "url": url})

        task = self.loop.create_task(self._run_tunnel(tunnel, host=parsed.hostname or "localhost"))
        tunnel.tasks.append(task)

    # ------------------------------------------------------------------ Core tunnel handling
    async def _run_tunnel(self, tunnel: ActiveTunnel, *, host: str) -> None:
        """Dial the websocket, send CONNECT, run the worker tasks until one finishes, then shut down."""
        ssl_ctx = self._ssl_context(host)
        timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=None)
        self._log(
            f"reverse_tunnel dialing ws url={tunnel.url} tunnel_id={tunnel.tunnel_id} "
            f"agent_id={self.ctx.agent_id} ssl={'yes' if ssl_ctx else 'no'}"
        )
        try:
            tunnel.session = aiohttp.ClientSession(timeout=timeout)
            tunnel.websocket = await tunnel.session.ws_connect(
                tunnel.url,
                ssl=ssl_ctx,
                heartbeat=None,
                max_msg_size=0,
                timeout=timeout,
            )
            self._mark_activity(tunnel)
            self._log(f"reverse_tunnel connected ws tunnel_id={tunnel.tunnel_id} peer={getattr(tunnel.websocket, 'remote', None)}")

            await tunnel.websocket.send_bytes(
                TunnelFrame(
                    msg_type=MSG_CONNECT,
                    channel_id=0,
                    payload=json.dumps(
                        {
                            "agent_id": self.ctx.agent_id,
                            "tunnel_id": tunnel.tunnel_id,
                            "token": tunnel.token,
                            "protocol": tunnel.protocol,
                            "domain": tunnel.domain,
                            "version": FRAME_VERSION,
                        },
                        separators=(",", ":"),
                    ).encode("utf-8"),
                ).encode()
            )
            self._log(f"reverse_tunnel CONNECT sent tunnel_id={tunnel.tunnel_id}")

            sender = self.loop.create_task(self._pump_sender(tunnel))
            receiver = self.loop.create_task(self._pump_receiver(tunnel))
            heartbeats = self.loop.create_task(self._heartbeat_loop(tunnel))
            watchdog = self.loop.create_task(self._watchdog(tunnel))
            tunnel.tasks.extend([sender, receiver, heartbeats, watchdog])

            task_labels = {
                sender: "sender",
                receiver: "receiver",
                heartbeats: "heartbeat",
                watchdog: "watchdog",
            }
            done, pending = await asyncio.wait(task_labels.keys(), return_when=asyncio.FIRST_COMPLETED)
            for finished in done:
                label = task_labels.get(finished) or "unknown"
                exc_text = ""
                try:
                    exc_obj = finished.exception()
                except asyncio.CancelledError:
                    exc_obj = None
                    exc_text = " (cancelled)"
                except Exception as exc:  # pragma: no cover - defensive logging
                    exc_obj = exc
                if exc_obj:
                    exc_text = f" (exc={exc_obj!r})"
                if not tunnel.stop_reason:
                    tunnel.stop_reason = f"{label}_stopped{exc_text}"
                if not tunnel.stop_origin:
                    tunnel.stop_origin = label
                self._log(
                    f"reverse_tunnel task completed tunnel_id={tunnel.tunnel_id} task={label} stop_reason={tunnel.stop_reason}{exc_text}"
                )
            if pending:
                # Represent pending tasks by label for debugging (previously an unfilled %s template).
                pending_labels = ",".join(sorted(task_labels.get(task, "unknown") for task in pending))
                self._log(
                    f"reverse_tunnel pending tasks after first completion tunnel_id={tunnel.tunnel_id} pending={pending_labels}"
                )
        except Exception as exc:
            self._log(f"reverse_tunnel connection failed tunnel_id={tunnel.tunnel_id}: {exc}", error=True)
            await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "connect_failed"})
        finally:
            try:
                ws = tunnel.websocket
                self._log(
                    f"reverse_tunnel ws closing tunnel_id={tunnel.tunnel_id} "
                    f"code={getattr(ws, 'close_code', None)} reason={getattr(ws, 'close_reason', None)}"
                )
            except Exception:
                pass
            await self._shutdown_tunnel(tunnel)

    async def _pump_sender(self, tunnel: ActiveTunnel) -> None:
        """Drain ``send_queue`` onto the websocket until it closes or a send fails."""
        try:
            while tunnel.websocket and not tunnel.websocket.closed:
                frame: TunnelFrame = await tunnel.send_queue.get()
                try:
                    await tunnel.websocket.send_bytes(frame.encode())
                    self._mark_activity(tunnel)
                    self._log(
                        f"reverse_tunnel send frame tunnel_id={tunnel.tunnel_id} "
                        f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}"
                    )
                except Exception:
                    if not tunnel.stop_reason:
                        tunnel.stop_reason = "sender_error"
                    break
        except asyncio.CancelledError:
            pass
        except Exception:
            if not tunnel.stop_reason:
                tunnel.stop_reason = "sender_failed"
            self._log(f"reverse_tunnel sender failed tunnel_id={tunnel.tunnel_id}", error=True)
        finally:
            self._log(f"reverse_tunnel sender stopped tunnel_id={tunnel.tunnel_id}")

    async def _pump_receiver(self, tunnel: ActiveTunnel) -> None:
        """Decode binary websocket messages into frames and dispatch them until the socket closes."""
        ws = tunnel.websocket
        if ws is None:
            return
        try:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.BINARY:
                    try:
                        frame = decode_frame(msg.data)
                    except Exception:
                        self._log(f"reverse_tunnel frame decode failed tunnel_id={tunnel.tunnel_id}", error=True)
                        continue
                    self._mark_activity(tunnel)
                    await self._handle_frame(tunnel, frame)
                elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSE):
                    # close_reason is not present on every aiohttp version; read it defensively.
                    close_reason = getattr(ws, "close_reason", None)
                    self._log(
                        f"reverse_tunnel websocket closed tunnel_id={tunnel.tunnel_id} "
                        f"code={ws.close_code} reason={close_reason}"
                    )
                    tunnel.stop_reason = close_reason or "ws_closed"
                    break
        except asyncio.CancelledError:
            pass
        except Exception:
            if not tunnel.stop_reason:
                tunnel.stop_reason = "receiver_failed"
            self._log(f"reverse_tunnel receiver failed tunnel_id={tunnel.tunnel_id}", error=True)
        finally:
            self._log(f"reverse_tunnel receiver stopped tunnel_id={tunnel.tunnel_id}")
            # If no stop_reason was set, emit a CLOSE so engine/UI see a reason.
            if not tunnel.stop_reason:
                try:
                    await self._send_frame(tunnel, close_frame(0, CLOSE_UNEXPECTED_DISCONNECT, "receiver_stop"))
                    tunnel.stop_reason = "receiver_stop"
                except Exception:
                    pass

    async def _heartbeat_loop(self, tunnel: ActiveTunnel) -> None:
        """Queue a HEARTBEAT every ``heartbeat_seconds`` while the websocket is open."""
        try:
            while tunnel.websocket and not tunnel.websocket.closed:
                await asyncio.sleep(tunnel.heartbeat_seconds)
                await self._send_frame(tunnel, heartbeat_frame())
                self._log(f"reverse_tunnel heartbeat sent tunnel_id={tunnel.tunnel_id}")
        except asyncio.CancelledError:
            pass
        except Exception:
            if not tunnel.stop_reason:
                tunnel.stop_reason = "heartbeat_failed"
            self._log(f"reverse_tunnel heartbeat failed tunnel_id={tunnel.tunnel_id}", error=True)
        finally:
            self._log(f"reverse_tunnel heartbeat loop stopped tunnel_id={tunnel.tunnel_id}")

    async def _watchdog(self, tunnel: ActiveTunnel) -> None:
        """Every 10 s, end the tunnel on idle timeout or once ``grace_seconds`` past ``expires_at``."""
        try:
            while tunnel.websocket and not tunnel.websocket.closed:
                await asyncio.sleep(10)
                now = time.time()
                if tunnel.idle_seconds and (now - tunnel.last_activity) >= tunnel.idle_seconds:
                    await self._send_frame(tunnel, close_frame(0, CLOSE_IDLE_TIMEOUT, "idle_timeout"))
                    tunnel.stop_reason = tunnel.stop_reason or "idle_timeout"
                    self._log(f"reverse_tunnel watchdog idle_timeout tunnel_id={tunnel.tunnel_id}")
                    break
                if tunnel.expires_at and (now - tunnel.expires_at) >= tunnel.grace_seconds:
                    await self._send_frame(tunnel, close_frame(0, CLOSE_GRACE_EXPIRED, "grace_expired"))
                    tunnel.stop_reason = tunnel.stop_reason or "grace_expired"
                    self._log(f"reverse_tunnel watchdog grace_expired tunnel_id={tunnel.tunnel_id}")
                    break
        except asyncio.CancelledError:
            pass
        except Exception:
            if not tunnel.stop_reason:
                tunnel.stop_reason = "watchdog_failed"
            self._log(f"reverse_tunnel watchdog failed tunnel_id={tunnel.tunnel_id}", error=True)
        finally:
            self._log(f"reverse_tunnel watchdog stopped tunnel_id={tunnel.tunnel_id}")

    async def _handle_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
        """Dispatch one inbound frame: control types here, everything else to its channel handler."""
        self._log(
            f"reverse_tunnel recv frame tunnel_id={tunnel.tunnel_id} "
            f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}"
        )
        if frame.msg_type == MSG_HEARTBEAT:
            if frame.flags & 0x1:
                self._log(f"reverse_tunnel heartbeat ack tunnel_id={tunnel.tunnel_id}")
                return
            await self._send_frame(tunnel, heartbeat_frame(channel_id=frame.channel_id, is_ack=True))
            return
        if frame.msg_type == MSG_CONNECT_ACK:
            tunnel.connected = True
            await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "connected"})
            self._log(f"reverse_tunnel CONNECT_ACK tunnel_id={tunnel.tunnel_id}")
            return
        if frame.msg_type == MSG_CHANNEL_OPEN:
            await self._handle_channel_open(tunnel, frame)
            return
        if frame.msg_type == MSG_CLOSE:
            try:
                reason = json.loads(frame.payload.decode("utf-8"))
            except Exception:
                reason = {"code": CLOSE_UNEXPECTED_DISCONNECT, "reason": "close"}
            tunnel.stop_reason = reason.get("reason") if isinstance(reason, dict) else None
            await self._shutdown_tunnel(tunnel, send_close=False)
            return

        handler = tunnel.channels.get(frame.channel_id)
        if handler:
            try:
                await handler.on_frame(frame)
            except Exception:
                self._log(f"reverse_tunnel channel handler failed tunnel_id={tunnel.tunnel_id} channel={frame.channel_id}", error=True)
        else:
            if frame.msg_type in (MSG_DATA, MSG_CONTROL, MSG_WINDOW_UPDATE):
                await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "unknown_channel"))

    async def _handle_channel_open(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
        """Instantiate the protocol handler for a CHANNEL_OPEN (BaseChannel fallback), start it and ACK."""
        try:
            payload = json.loads(frame.payload.decode("utf-8"))
        except Exception:
            payload = {}
        if not isinstance(payload, dict):
            # Non-object JSON previously reached payload.get() and crashed the receiver.
            payload = {}
        protocol = _norm_text(payload.get("protocol") or tunnel.protocol)
        metadata = payload.get("metadata")
        if frame.channel_id in tunnel.channels:
            await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "channel_exists"))
            return

        if protocol.lower() == "ps" and "ps" not in self._protocol_handlers:
            hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else ""
            module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
            exists_hint = f" exists={module_path.exists()}"
            self._log(
                f"reverse_tunnel ps handler missing; falling back to BaseChannel{hint}{exists_hint}",
                error=True,
            )
        handler_cls = self._protocol_handlers.get(protocol.lower()) or BaseChannel
        try:
            handler = handler_cls(self, tunnel, frame.channel_id, metadata)
        except Exception:
            self._log(f"reverse_tunnel channel handler fallback to BaseChannel protocol={protocol}", error=True)
            handler = BaseChannel(self, tunnel, frame.channel_id, metadata)
        tunnel.channels[frame.channel_id] = handler
        await handler.start()
        await self._send_frame(
            tunnel,
            TunnelFrame(
                msg_type=MSG_CHANNEL_ACK,
                channel_id=frame.channel_id,
                payload=json.dumps({"status": "ok", "protocol": protocol}, separators=(",", ":")).encode("utf-8"),
            ),
        )
        self._log(
            f"reverse_tunnel channel_opened tunnel_id={tunnel.tunnel_id} channel={frame.channel_id} "
            f"protocol={protocol} handler={handler.__class__.__name__} metadata={metadata}"
        )

    async def _send_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
        """Queue *frame* for the sender; once stopping, only CLOSE frames are accepted."""
        if tunnel.stopping and getattr(frame, "msg_type", None) != MSG_CLOSE:
            return
        try:
            tunnel.send_queue.put_nowait(frame)
        except Exception:
            await tunnel.send_queue.put(frame)

    async def _stop_tunnel(self, tunnel_id: str, *, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "requested") -> None:
        """Queue a CLOSE with *code*/*reason* and shut the tunnel down (no-op for unknown ids)."""
        tunnel = self._active.get(tunnel_id)
        if not tunnel:
            return
        if not tunnel.stop_origin:
            tunnel.stop_origin = "stop_tunnel"
        self._log(f"reverse_tunnel stop_tunnel requested tunnel_id={tunnel_id} code={code} reason={reason}")
        if not tunnel.stop_reason:
            tunnel.stop_reason = reason or "requested"
        await self._send_frame(tunnel, close_frame(0, code, reason))
        await self._shutdown_tunnel(tunnel, send_close=False)

    async def _shutdown_tunnel(self, tunnel: ActiveTunnel, *, send_close: bool = True) -> None:
        """Tear a tunnel down once.

        Steps: stop channels, optionally send CLOSE, cancel worker tasks, close the
        websocket/session, deregister the tunnel and emit a "closed" status.
        """
        if tunnel.stopping:
            return
        tunnel.stopping = True
        reason_text = tunnel.stop_reason or "closed"
        if not tunnel.stop_reason:
            tunnel.stop_reason = reason_text
        if not tunnel.stop_origin:
            tunnel.stop_origin = "shutdown"
        self._log(
            f"reverse_tunnel shutdown start tunnel_id={tunnel.tunnel_id} stop_reason={tunnel.stop_reason} "
            f"stop_origin={tunnel.stop_origin} ws_closed={getattr(tunnel.websocket, 'closed', None)}"
        )

        # Stop all channels first so CLOSE frames (with reasons) are sent upstream.
        for handler in list(tunnel.channels.values()):
            try:
                await handler.stop(code=CLOSE_UNEXPECTED_DISCONNECT, reason=reason_text or "tunnel_shutdown")
            except Exception:
                pass

        if send_close:
            close_payload = close_frame(0, CLOSE_AGENT_SHUTDOWN, reason_text or "agent_shutdown")
            try:
                await self._send_frame(tunnel, close_payload)
                # Give the sender loop a brief window to flush the CLOSE upstream.
                await asyncio.sleep(0.05)
            except Exception:
                pass
            # Fallback: if sender task died, try sending directly on the websocket.
            try:
                if tunnel.websocket and not tunnel.websocket.closed:
                    await tunnel.websocket.send_bytes(close_payload.encode())
            except Exception:
                pass

        # This coroutine usually runs inside one of tunnel.tasks (the _run_tunnel task, or the
        # receiver while handling MSG_CLOSE). Cancelling that task would raise CancelledError at
        # the next await below, which `except Exception` does not catch, skipping the websocket
        # and session close and the deregistration -- so never cancel ourselves.
        current = asyncio.current_task()
        for task in list(tunnel.tasks):
            if task is current:
                continue
            try:
                task.cancel()
            except Exception:
                pass

        if tunnel.websocket is not None:
            try:
                message = (reason_text or "agent_shutdown").encode("utf-8", "ignore")[:120]
                await tunnel.websocket.close(message=message)
            except Exception:
                pass
        if tunnel.session is not None:
            try:
                await tunnel.session.close()
            except Exception:
                pass

        self._active.pop(tunnel.tunnel_id, None)
        if tunnel.domain in self._domain_claims and self._domain_claims.get(tunnel.domain) == tunnel.tunnel_id:
            self._domain_claims.pop(tunnel.domain, None)
        self._log(f"reverse_tunnel stopped tunnel_id={tunnel.tunnel_id} reason={tunnel.stop_reason or 'closed'}")
        await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "closed", "reason": tunnel.stop_reason or "closed"})

    # ------------------------------------------------------------------ Lifecycle
    def stop_all(self):
        """Schedule ``_stop_tunnel`` for every active tunnel on the role's loop."""
        for tunnel_id in list(self._active.keys()):
            try:
                self.loop.create_task(self._stop_tunnel(tunnel_id, code=CLOSE_AGENT_SHUTDOWN, reason="agent_shutdown"))
            except Exception:
                pass