From ec85896e009041de54c7398c18ba17c7e270d89d Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Sun, 11 Jan 2026 22:06:05 -0700 Subject: [PATCH] Tunnel Functionality Validated (Initial) --- Data/Engine/services/API/devices/tunnel.py | 10 +++- Data/Engine/services/WebSocket/__init__.py | 18 ++++++ Data/Engine/services/WebSocket/vpn_shell.py | 27 +++++++-- .../src/Devices/ReverseTunnel/Powershell.jsx | 56 +++++++++++++++---- 4 files changed, 93 insertions(+), 18 deletions(-) diff --git a/Data/Engine/services/API/devices/tunnel.py b/Data/Engine/services/API/devices/tunnel.py index 515389a0..16947cce 100644 --- a/Data/Engine/services/API/devices/tunnel.py +++ b/Data/Engine/services/API/devices/tunnel.py @@ -201,6 +201,13 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None: tunnel_service = _get_tunnel_service(adapters) payload = tunnel_service.status(agent_id) + agent_socket = False + registry = getattr(adapters.context, "agent_socket_registry", None) + if registry and hasattr(registry, "is_registered"): + try: + agent_socket = bool(registry.is_registered(agent_id)) + except Exception: + agent_socket = False bump = _normalize_text(request.args.get("bump") or "") _service_log_event( "vpn_api_status_request agent_id={0} bump={1} remote={2}".format( @@ -213,8 +220,9 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None: _service_log_event( "vpn_api_status_response agent_id={0} status=down".format(agent_id) ) - return jsonify({"status": "down", "agent_id": agent_id}), 200 + return jsonify({"status": "down", "agent_id": agent_id, "agent_socket": agent_socket}), 200 payload["status"] = "up" + payload["agent_socket"] = agent_socket if bump: tunnel_service.bump_activity(agent_id) _service_log_event( diff --git a/Data/Engine/services/WebSocket/__init__.py b/Data/Engine/services/WebSocket/__init__.py index 62c89ccf..1849a4d0 100644 --- a/Data/Engine/services/WebSocket/__init__.py +++ b/Data/Engine/services/WebSocket/__init__.py @@ -85,6 +85,9 @@ class AgentSocketRegistry: self._sid_by_agent.pop(agent_id, None) return agent_id + def is_registered(self, agent_id: str) -> bool: + return bool(self._sid_by_agent.get(agent_id)) + def emit(self, agent_id: str, event: str, payload: Any) -> bool: sid = self._sid_by_agent.get(agent_id) if not sid: @@ -105,6 +108,7 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None: agent_logger = context.logger.getChild("realtime.agents") shell_bridge = VpnShellBridge(socket_server, context, adapters.service_log) agent_registry = AgentSocketRegistry(socket_server, agent_logger) + setattr(context, "agent_socket_registry", agent_registry) def _emit_agent_event(agent_id: str, event: str, payload: Any) -> bool: return agent_registry.emit(agent_id, event, payload) @@ -370,6 +374,20 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None: level="WARNING", ) return {"error": "tunnel_down"} + registry = getattr(context, "agent_socket_registry", None) + if registry and hasattr(registry, "is_registered"): + try: + if not registry.is_registered(agent_id): + _shell_log( + "vpn_shell_open_failed agent_id={0} sid={1} reason=agent_socket_missing".format( + agent_id, + request.sid, + ), + level="WARNING", + ) + return {"error": "agent_socket_missing"} + except Exception: + agent_logger.debug("agent_socket_registry lookup failed for agent_id=%s", agent_id, exc_info=True) session = shell_bridge.open_session(request.sid, agent_id) if session is None: diff --git a/Data/Engine/services/WebSocket/vpn_shell.py b/Data/Engine/services/WebSocket/vpn_shell.py index e9b56e82..ac10f4c2 100644 --- a/Data/Engine/services/WebSocket/vpn_shell.py +++ b/Data/Engine/services/WebSocket/vpn_shell.py @@ -40,9 +40,13 @@ class ShellSession: _reader: Optional[threading.Thread] = None def start_reader(self) -> None: - t = threading.Thread(target=self._read_loop, daemon=True) - t.start() - self._reader = t + starter = getattr(self.socketio, "start_background_task", None) + if callable(starter): + self._reader = starter(self._read_loop) + else: + t = threading.Thread(target=self._read_loop, daemon=True) + t.start() + self._reader = t def _service_log_event(self, message: str, *, level: str = "INFO") -> None: if not callable(self.service_log): @@ -171,6 +175,16 @@ class VpnShellBridge: service = getattr(self.context, "vpn_tunnel_service", None) if service is None: return None + existing = self._sessions.pop(sid, None) + if existing: + self._service_log_event( + "vpn_shell_replace_session agent_id={0} sid={1}".format( + existing.agent_id, + sid, + ), + level="WARNING", + ) + existing.close() status = service.status(agent_id) if not status: return None @@ -178,7 +192,8 @@ class VpnShellBridge: port = int(self.context.wireguard_shell_port) tcp = None last_error: Optional[Exception] = None - for attempt in range(3): + connect_timeout = 2.0 + for attempt in range(2): self._service_log_event( "vpn_shell_connect_attempt agent_id={0} sid={1} host={2} port={3} attempt={4}".format( agent_id, @@ -189,7 +204,7 @@ class VpnShellBridge: ) ) try: - tcp = socket.create_connection((host, port), timeout=5) + tcp = socket.create_connection((host, port), timeout=connect_timeout) break except Exception as exc: last_error = exc @@ -205,7 +220,7 @@ class VpnShellBridge: "vpn_shell_agent_start_failed agent_id={0} sid={1}".format(agent_id, sid), level="WARNING", ) - time.sleep(1) + time.sleep(0.5) if tcp is None: self._service_log_event( "vpn_shell_connect_failed agent_id={0} sid={1} host={2} port={3} error={4}".format( diff --git a/Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx b/Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx index c0023230..3c6d33a0 100644 --- a/Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx +++ b/Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx @@ -64,6 +64,8 @@ const emitAsync = (socket, event, payload, timeoutMs = 4000) => }); }); +const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms)); + function normalizeText(value) { if (value == null) return ""; try { @@ -245,20 +247,52 @@ export default function ReverseTunnelPowershell({ device }) { const detail = data?.detail ? `: ${data.detail}` : ""; throw new Error(`${data?.error || `HTTP ${resp.status}`}${detail}`); } - const statusResp = await fetch( - `/api/tunnel/connect/status?agent_id=${encodeURIComponent(agentId)}&bump=1` - ); - const statusData = await statusResp.json().catch(() => ({})); - if (!statusResp.ok || statusData?.status !== "up") { - throw new Error(statusData?.error || "Tunnel not ready"); - } + const waitForTunnelReady = async () => { + const deadline = Date.now() + 60000; + let lastError = ""; + while (Date.now() < deadline) { + const statusResp = await fetch( + `/api/tunnel/connect/status?agent_id=${encodeURIComponent(agentId)}&bump=1` + ); + const statusData = await statusResp.json().catch(() => ({})); + if (statusResp.ok && statusData?.status === "up") { + const agentSocket = statusData?.agent_socket; + const agentReady = agentSocket === undefined ? true : Boolean(agentSocket); + if (agentReady) { + return statusData; + } + setStatusMessage("Waiting for agent VPN socket to register..."); + } else if (statusData?.error) { + lastError = statusData.error; + } + await sleep(2000); + } + throw new Error(lastError || "Tunnel not ready"); + }; + + const statusData = await waitForTunnelReady(); setTunnel({ ...data, ...statusData }); const socket = ensureSocket(); - const openResp = await emitAsync(socket, "vpn_shell_open", { agent_id: agentId }, 6000); - if (openResp?.error) { - throw new Error(openResp.error); - } + const openShellWithRetry = async () => { + const deadline = Date.now() + 30000; + let lastError = ""; + let attempt = 0; + while (Date.now() < deadline) { + attempt += 1; + const openResp = await emitAsync(socket, "vpn_shell_open", { agent_id: agentId }, 6000); + if (!openResp?.error) { + return openResp; + } + lastError = openResp.error; + setStatusMessage(`Waiting for PowerShell shell (${attempt})...`); + await sleep(2000); + } + throw new Error(lastError || "shell_connect_failed"); + }; + + await openShellWithRetry(); + setStatusMessage(""); setSessionState("connected"); setShellState("connected"); } catch (err) {