From 68dd46347bf4ce5deb33f1cbb88c2e9f03a53971 Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Sat, 6 Dec 2025 00:27:57 -0700 Subject: [PATCH] Major Progress Towards Interactive Remote Powershell --- Borealis.ps1 | 7 + .../Roles/ReverseTunnel/tunnel_Powershell.py | 119 ++- Data/Agent/Roles/role_ReverseTunnel.py | 92 ++- Data/Engine/Unit_Tests/test_reverse_tunnel.py | 90 +++ .../test_reverse_tunnel_integration.py | 101 +++ .../services/WebSocket/Agent/ReverseTunnel.py | 131 +++- Data/Engine/services/WebSocket/__init__.py | 2 +- .../src/Devices/Device_Details.jsx | 54 +- .../src/Devices/ReverseTunnel/Powershell.jsx | 704 ++++++++++++++++++ 9 files changed, 1247 insertions(+), 53 deletions(-) create mode 100644 Data/Engine/Unit_Tests/test_reverse_tunnel.py create mode 100644 Data/Engine/Unit_Tests/test_reverse_tunnel_integration.py create mode 100644 Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx diff --git a/Borealis.ps1 b/Borealis.ps1 index ab6c34e6..f3462cb9 100644 --- a/Borealis.ps1 +++ b/Borealis.ps1 @@ -995,6 +995,13 @@ function InstallOrUpdate-BorealisAgent { ) Copy-Item $coreAgentFiles -Destination $agentDestinationFolder -Recurse -Force + + # Ensure ReverseTunnel role is refreshed explicitly (covers incremental changes) + $rtSource = Join-Path $agentSourceRoot 'Roles\ReverseTunnel' + $rtDest = Join-Path $agentDestinationFolder 'Roles' + if (Test-Path $rtSource) { + Copy-Item $rtSource -Destination $rtDest -Recurse -Force + } } . (Join-Path $venvFolderPath 'Scripts\Activate') } diff --git a/Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py b/Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py index 52125db8..528ed0fc 100644 --- a/Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py +++ b/Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio import os import sys +import subprocess from typing import Any, Dict, Optional # Message types mirrored from the tunnel framing (kept local to avoid import cycles). @@ -30,6 +31,7 @@ class PowershellChannel: self._writer_task = None self._stdin_queue: asyncio.Queue = asyncio.Queue() self._pty = None + self._proc: Optional[asyncio.subprocess.Process] = None self._exit_code: Optional[int] = None self._frame_cls = getattr(role, "_frame_cls", None) @@ -62,12 +64,11 @@ class PowershellChannel: ) await self._send_frame(frame) - def _powershell_path(self) -> str: + def _powershell_argv(self) -> list: preferred = self.metadata.get("shell") if isinstance(self.metadata, dict) else None - if isinstance(preferred, str) and preferred.strip(): - return preferred.strip() - # Default to Windows PowerShell; fallback to pwsh if provided later. - return "powershell.exe" + shell = preferred.strip() if isinstance(preferred, str) and preferred.strip() else "powershell.exe" + # Keep the process alive and read commands from stdin; -Command - tells PS to consume stdin. + return [shell, "-NoLogo", "-NoProfile", "-NoExit", "-Command", "-"] def _initial_size(self) -> tuple: cols = int(self.metadata.get("cols") or self.metadata.get("columns") or 120) if isinstance(self.metadata, dict) else 120 @@ -79,30 +80,46 @@ class PowershellChannel: # ------------------------------------------------------------------ Lifecycle async def start(self) -> None: if sys.platform.lower().startswith("win") is False: + self.role._log("reverse_tunnel ps start aborted: non-windows platform", error=True) await self._send_close(CLOSE_PROTOCOL_ERROR, "windows_only") return + + argv = self._powershell_argv() + cols, rows = self._initial_size() + self.role._log(f"reverse_tunnel ps start channel={self.channel_id} argv={' '.join(argv)} cols={cols} rows={rows}") + + # Preferred: ConPTY via pywinpty. try: import pywinpty # type: ignore - except Exception as exc: # pragma: no cover - dependency guard - self.role._log(f"reverse_tunnel ps channel missing pywinpty: {exc}", error=True) - await self._send_close(CLOSE_PROTOCOL_ERROR, "pywinpty_missing") - return - shell = self._powershell_path() - cols, rows = self._initial_size() - try: self._pty = pywinpty.Process( - spawn_cmd=shell, + spawn_cmd=" ".join(argv[:-1]) if argv[-1] == "-" else " ".join(argv), dimensions=(cols, rows), ) + self._reader_task = self.loop.create_task(self._pump_pty_stdout()) + self._writer_task = self.loop.create_task(self._pump_pty_stdin()) + self.role._log(f"reverse_tunnel ps channel started (pty) argv={' '.join(argv)} cols={cols} rows={rows}") + return except Exception as exc: - self.role._log(f"reverse_tunnel ps channel failed to spawn {shell}: {exc}", error=True) + self.role._log(f"reverse_tunnel ps channel pywinpty unavailable, falling back to pipes: {exc}", error=True) + + # Fallback: subprocess pipes (no PTY). + try: + self._proc = await asyncio.create_subprocess_exec( + *argv, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0), + ) + except Exception as exc: + self.role._log(f"reverse_tunnel ps channel spawn failed argv={' '.join(argv)}: {exc}", error=True) await self._send_close(CLOSE_PROTOCOL_ERROR, "spawn_failed") return - self._reader_task = self.loop.create_task(self._pump_stdout()) - self._writer_task = self.loop.create_task(self._pump_stdin()) - self.role._log(f"reverse_tunnel ps channel started shell={shell} cols={cols} rows={rows}") + self._reader_task = self.loop.create_task(self._pump_proc_stdout()) + self._writer_task = self.loop.create_task(self._pump_proc_stdin()) + self.role._log(f"reverse_tunnel ps channel started (pipes) argv={' '.join(argv)} cols={cols} rows={rows}") async def on_frame(self, frame) -> None: if self._closed: @@ -140,6 +157,7 @@ class PowershellChannel: await self._resize(cols_int, rows_int) async def _resize(self, cols: Optional[int], rows: Optional[int]) -> None: + # Resize only applies to PTY sessions; pipe mode ignores. if self._pty is None: return try: @@ -155,17 +173,14 @@ class PowershellChannel: except Exception: self.role._log("reverse_tunnel ps channel resize failed", error=True) - async def _pump_stdout(self) -> None: + async def _pump_pty_stdout(self) -> None: loop = asyncio.get_event_loop() try: while not self._closed and self._pty: chunk = await loop.run_in_executor(None, self._pty.read, 4096) if chunk is None: break - if isinstance(chunk, str): - data = chunk.encode("utf-8", errors="replace") - else: - data = bytes(chunk) + data = chunk.encode("utf-8", errors="replace") if isinstance(chunk, str) else bytes(chunk) if not data: break frame = self._make_frame(MSG_DATA, payload=data) @@ -173,11 +188,11 @@ class PowershellChannel: except asyncio.CancelledError: pass except Exception: - self.role._log("reverse_tunnel ps stdout pump error", error=True) + self.role._log("reverse_tunnel ps pty stdout pump error", error=True) finally: await self.stop(reason="stdout_closed") - async def _pump_stdin(self) -> None: + async def _pump_pty_stdin(self) -> None: loop = asyncio.get_event_loop() try: while not self._closed and self._pty: @@ -187,10 +202,7 @@ class PowershellChannel: break if data is None: break - if isinstance(data, (bytes, bytearray)): - text = data.decode("utf-8", errors="replace") - else: - text = str(data) + text = data.decode("utf-8", errors="replace") if isinstance(data, (bytes, bytearray)) else str(data) try: await loop.run_in_executor(None, self._pty.write, text) except Exception: @@ -198,7 +210,47 @@ class PowershellChannel: except asyncio.CancelledError: pass except Exception: - self.role._log("reverse_tunnel ps stdin pump error", error=True) + self.role._log("reverse_tunnel ps pty stdin pump error", error=True) + finally: + await self.stop(reason="stdin_closed") + + # -------------------- Pipe fallback pumps -------------------- + async def _pump_proc_stdout(self) -> None: + try: + while self._proc and not self._closed: + chunk = await self._proc.stdout.read(4096) + if not chunk: + break + frame = self._make_frame(MSG_DATA, payload=bytes(chunk)) + await self._send_frame(frame) + except asyncio.CancelledError: + pass + except Exception: + self.role._log("reverse_tunnel ps pipe stdout pump error", error=True) + finally: + if self._proc and not self._closed: + try: + self._exit_code = await self._proc.wait() + except Exception: + pass + await self.stop(reason="stdout_closed") + + async def _pump_proc_stdin(self) -> None: + try: + while self._proc and not self._closed: + data = await self._stdin_queue.get() + if self._closed or not self._proc or not self._proc.stdin: + break + try: + self._proc.stdin.write(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8")) + await self._proc.stdin.drain() + except Exception: + self.role._log("reverse_tunnel ps pipe stdin pump error", error=True) + break + except asyncio.CancelledError: + pass + except Exception: + self.role._log("reverse_tunnel ps pipe stdin pump error", error=True) finally: await self.stop(reason="stdin_closed") @@ -211,6 +263,11 @@ class PowershellChannel: self._pty.terminate() except Exception: pass + if self._proc is not None: + try: + self._proc.terminate() + except Exception: + pass current = asyncio.current_task() if self._reader_task and self._reader_task is not current: try: @@ -222,5 +279,7 @@ class PowershellChannel: self._writer_task.cancel() except Exception: pass - await self._send_close(code, reason or "powershell_exit") + # Include exit code in the close reason for debugging. + exit_suffix = f" (exit={self._exit_code})" if self._exit_code is not None else "" + await self._send_close(code, (reason or "powershell_exit") + exit_suffix) self.role._log(f"reverse_tunnel ps channel stopped channel={self.channel_id} reason={reason or 'exit'}") diff --git a/Data/Agent/Roles/role_ReverseTunnel.py b/Data/Agent/Roles/role_ReverseTunnel.py index 8161b00a..bd0ceb97 100644 --- a/Data/Agent/Roles/role_ReverseTunnel.py +++ b/Data/Agent/Roles/role_ReverseTunnel.py @@ -1,10 +1,12 @@ import asyncio import base64 +import importlib.util import json import os import struct import time from dataclasses import dataclass, field +from pathlib import Path from typing import Any, Dict, Optional from urllib.parse import urlparse @@ -12,10 +14,25 @@ import aiohttp from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ed25519 +# Capture import errors for the PowerShell handler so we can report why it is missing. +PS_IMPORT_ERROR: Optional[str] = None +tunnel_Powershell = None try: - from .ReverseTunnel import tunnel_Powershell -except Exception: - tunnel_Powershell = None + from .ReverseTunnel import tunnel_Powershell # type: ignore +except Exception as exc: # pragma: no cover - best-effort logging only + PS_IMPORT_ERROR = repr(exc) + # Try manual import from file to survive non-package execution. + try: + _ps_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py" + if _ps_path.exists(): + spec = importlib.util.spec_from_file_location("tunnel_Powershell", _ps_path) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) # type: ignore + tunnel_Powershell = module + PS_IMPORT_ERROR = None + except Exception as exc2: # pragma: no cover - diagnostic only + PS_IMPORT_ERROR = f"{PS_IMPORT_ERROR} | fallback_load_failed={exc2!r}" ROLE_NAME = "reverse_tunnel" ROLE_CONTEXTS = ["interactive", "system"] @@ -178,8 +195,21 @@ class Role: self._protocol_handlers: Dict[str, Any] = {} self._frame_cls = TunnelFrame self.close_frame = close_frame - if tunnel_Powershell and hasattr(tunnel_Powershell, "PowershellChannel"): - self._protocol_handlers["ps"] = tunnel_Powershell.PowershellChannel + try: + if tunnel_Powershell and hasattr(tunnel_Powershell, "PowershellChannel"): + self._protocol_handlers["ps"] = tunnel_Powershell.PowershellChannel + module_path = getattr(tunnel_Powershell, "__file__", None) + self._log(f"reverse_tunnel ps handler registered (PowershellChannel) module={module_path}") + else: + hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else "" + module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py" + exists_hint = f" exists={module_path.exists()}" + self._log( + f"reverse_tunnel ps handler NOT registered (missing module/class){hint}{exists_hint}", + error=True, + ) + except Exception as exc: + self._log(f"reverse_tunnel ps handler registration failed: {exc}", error=True) # ------------------------------------------------------------------ Logging def _log(self, message: str, *, error: bool = False) -> None: @@ -359,6 +389,11 @@ class Role: idle_seconds = int(payload.get("idle_seconds") or 3600) grace_seconds = int(payload.get("grace_seconds") or 3600) signing_key_hint = _norm_text(payload.get("signing_key")) + agent_hint = _norm_text(payload.get("agent_id")) + + # Ignore broadcasts targeting other agents (Socket.IO fanout sends to both contexts). + if agent_hint and agent_hint.lower() != _norm_text(self.ctx.agent_id).lower(): + return if not token: self._log("reverse_tunnel_start rejected: missing token", error=True) @@ -378,6 +413,9 @@ class Role: signing_key_hint=signing_key_hint, ) except Exception as exc: + if str(exc) == "token_agent_mismatch": + # Broadcast hit the wrong agent context; ignore quietly. + return self._log(f"reverse_tunnel_start rejected: token validation failed ({exc})", error=True) await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_invalid"}) return @@ -439,6 +477,10 @@ class Role: async def _run_tunnel(self, tunnel: ActiveTunnel, *, host: str) -> None: ssl_ctx = self._ssl_context(host) timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=None) + self._log( + f"reverse_tunnel dialing ws url={tunnel.url} tunnel_id={tunnel.tunnel_id} " + f"agent_id={self.ctx.agent_id} ssl={'yes' if ssl_ctx else 'no'}" + ) try: tunnel.session = aiohttp.ClientSession(timeout=timeout) tunnel.websocket = await tunnel.session.ws_connect( @@ -449,6 +491,7 @@ class Role: timeout=timeout, ) self._mark_activity(tunnel) + self._log(f"reverse_tunnel connected ws tunnel_id={tunnel.tunnel_id} peer={getattr(tunnel.websocket, 'remote', None)}") await tunnel.websocket.send_bytes( TunnelFrame( msg_type=MSG_CONNECT, @@ -466,6 +509,7 @@ class Role: ).encode("utf-8"), ).encode() ) + self._log(f"reverse_tunnel CONNECT sent tunnel_id={tunnel.tunnel_id}") sender = self.loop.create_task(self._pump_sender(tunnel)) receiver = self.loop.create_task(self._pump_receiver(tunnel)) @@ -486,12 +530,18 @@ class Role: try: await tunnel.websocket.send_bytes(frame.encode()) self._mark_activity(tunnel) + self._log( + f"reverse_tunnel send frame tunnel_id={tunnel.tunnel_id} " + f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}" + ) except Exception: break except asyncio.CancelledError: pass except Exception: self._log(f"reverse_tunnel sender failed tunnel_id={tunnel.tunnel_id}", error=True) + finally: + self._log(f"reverse_tunnel sender stopped tunnel_id={tunnel.tunnel_id}") async def _pump_receiver(self, tunnel: ActiveTunnel) -> None: ws = tunnel.websocket @@ -508,21 +558,30 @@ class Role: self._mark_activity(tunnel) await self._handle_frame(tunnel, frame) elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSE): + self._log( + f"reverse_tunnel websocket closed tunnel_id={tunnel.tunnel_id} " + f"code={ws.close_code} reason={ws.close_reason}" + ) break except asyncio.CancelledError: pass except Exception: self._log(f"reverse_tunnel receiver failed tunnel_id={tunnel.tunnel_id}", error=True) + finally: + self._log(f"reverse_tunnel receiver stopped tunnel_id={tunnel.tunnel_id}") async def _heartbeat_loop(self, tunnel: ActiveTunnel) -> None: try: while tunnel.websocket and not tunnel.websocket.closed: await asyncio.sleep(tunnel.heartbeat_seconds) await self._send_frame(tunnel, heartbeat_frame()) + self._log(f"reverse_tunnel heartbeat sent tunnel_id={tunnel.tunnel_id}") except asyncio.CancelledError: pass except Exception: self._log(f"reverse_tunnel heartbeat failed tunnel_id={tunnel.tunnel_id}", error=True) + finally: + self._log(f"reverse_tunnel heartbeat loop stopped tunnel_id={tunnel.tunnel_id}") async def _watchdog(self, tunnel: ActiveTunnel) -> None: try: @@ -531,24 +590,34 @@ class Role: now = time.time() if tunnel.idle_seconds and (now - tunnel.last_activity) >= tunnel.idle_seconds: await self._send_frame(tunnel, close_frame(0, CLOSE_IDLE_TIMEOUT, "idle_timeout")) + self._log(f"reverse_tunnel watchdog idle_timeout tunnel_id={tunnel.tunnel_id}") break if tunnel.expires_at and (now - tunnel.expires_at) >= tunnel.grace_seconds: await self._send_frame(tunnel, close_frame(0, CLOSE_GRACE_EXPIRED, "grace_expired")) + self._log(f"reverse_tunnel watchdog grace_expired tunnel_id={tunnel.tunnel_id}") break except asyncio.CancelledError: pass except Exception: self._log(f"reverse_tunnel watchdog failed tunnel_id={tunnel.tunnel_id}", error=True) + finally: + self._log(f"reverse_tunnel watchdog stopped tunnel_id={tunnel.tunnel_id}") async def _handle_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None: + self._log( + f"reverse_tunnel recv frame tunnel_id={tunnel.tunnel_id} " + f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}" + ) if frame.msg_type == MSG_HEARTBEAT: if frame.flags & 0x1: + self._log(f"reverse_tunnel heartbeat ack tunnel_id={tunnel.tunnel_id}") return await self._send_frame(tunnel, heartbeat_frame(channel_id=frame.channel_id, is_ack=True)) return if frame.msg_type == MSG_CONNECT_ACK: tunnel.connected = True await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "connected"}) + self._log(f"reverse_tunnel CONNECT_ACK tunnel_id={tunnel.tunnel_id}") return if frame.msg_type == MSG_CHANNEL_OPEN: await self._handle_channel_open(tunnel, frame) @@ -584,6 +653,15 @@ class Role: await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "channel_exists")) return + if protocol.lower() == "ps" and "ps" not in self._protocol_handlers: + hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else "" + module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py" + exists_hint = f" exists={module_path.exists()}" + self._log( + f"reverse_tunnel ps handler missing; falling back to BaseChannel{hint}{exists_hint}", + error=True, + ) + handler_cls = self._protocol_handlers.get(protocol.lower()) or BaseChannel try: handler = handler_cls(self, tunnel, frame.channel_id, metadata) @@ -599,6 +677,10 @@ class Role: payload=json.dumps({"status": "ok", "protocol": protocol}, separators=(",", ":")).encode("utf-8"), ), ) + self._log( + f"reverse_tunnel channel_opened tunnel_id={tunnel.tunnel_id} channel={frame.channel_id} " + f"protocol={protocol} handler={handler.__class__.__name__} metadata={metadata}" + ) async def _send_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None: if tunnel.stopping: diff --git a/Data/Engine/Unit_Tests/test_reverse_tunnel.py b/Data/Engine/Unit_Tests/test_reverse_tunnel.py new file mode 100644 index 00000000..9f4b14f1 --- /dev/null +++ b/Data/Engine/Unit_Tests/test_reverse_tunnel.py @@ -0,0 +1,90 @@ +# ====================================================== +# Data\Engine\Unit_Tests\test_reverse_tunnel.py +# Description: Validates reverse tunnel lease API basics (allocation, token contents, and domain limit). +# +# API Endpoints (if applicable): +# - POST /api/tunnel/request +# ====================================================== + +from __future__ import annotations + +import base64 +import json + +import pytest + +from .conftest import EngineTestHarness + + +def _client_with_admin_session(harness: EngineTestHarness): + client = harness.app.test_client() + with client.session_transaction() as sess: + sess["username"] = "admin" + sess["role"] = "Admin" + return client + + +def _decode_token_segment(token: str) -> dict: + """Decode the unsigned payload segment from the tunnel token.""" + if not token: + return {} + segment = token.split(".")[0] + padding = "=" * (-len(segment) % 4) + raw = base64.urlsafe_b64decode(segment + padding) + try: + return json.loads(raw.decode("utf-8")) + except Exception: + return {} + + +@pytest.mark.parametrize("agent_id", ["test-device-agent"]) +def test_tunnel_request_happy_path(engine_harness: EngineTestHarness, agent_id: str) -> None: + client = _client_with_admin_session(engine_harness) + resp = client.post( + "/api/tunnel/request", + json={"agent_id": agent_id, "protocol": "ps", "domain": "ps"}, + ) + assert resp.status_code == 200 + payload = resp.get_json() + assert payload["agent_id"] == agent_id + assert payload["protocol"] == "ps" + assert payload["domain"] == "ps" + assert isinstance(payload["port"], int) and payload["port"] >= 30000 + assert payload.get("token") + + claims = _decode_token_segment(payload["token"]) + assert claims.get("agent_id") == agent_id + assert claims.get("protocol") == "ps" + assert claims.get("domain") == "ps" + assert claims.get("tunnel_id") == payload["tunnel_id"] + assert claims.get("assigned_port") == payload["port"] + + +def test_tunnel_request_domain_limit(engine_harness: EngineTestHarness) -> None: + client = _client_with_admin_session(engine_harness) + first = client.post( + "/api/tunnel/request", + json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"}, + ) + assert first.status_code == 200 + + second = client.post( + "/api/tunnel/request", + json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"}, + ) + assert second.status_code == 409 + data = second.get_json() + assert data.get("error") == "domain_limit" + + +def test_tunnel_request_includes_timeouts(engine_harness: EngineTestHarness) -> None: + client = _client_with_admin_session(engine_harness) + resp = client.post( + "/api/tunnel/request", + json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"}, + ) + assert resp.status_code == 200 + payload = resp.get_json() + assert payload.get("idle_seconds") and payload["idle_seconds"] > 0 + assert payload.get("grace_seconds") and payload["grace_seconds"] > 0 + assert payload.get("expires_at") and int(payload["expires_at"]) > 0 diff --git a/Data/Engine/Unit_Tests/test_reverse_tunnel_integration.py b/Data/Engine/Unit_Tests/test_reverse_tunnel_integration.py new file mode 100644 index 00000000..0837469a --- /dev/null +++ b/Data/Engine/Unit_Tests/test_reverse_tunnel_integration.py @@ -0,0 +1,101 @@ +# ====================================================== +# Data\Engine\Unit_Tests\test_reverse_tunnel_integration.py +# Description: Integration test that exercises a full reverse tunnel PowerShell round-trip +# against a running Engine + Agent (requires live services). +# +# Requirements: +# - Environment variables must be set to point at a live Engine + Agent: +# TUNNEL_TEST_HOST (e.g., https://localhost:5000) +# TUNNEL_TEST_AGENT_ID (agent_id/agent_guid for the target device) +# TUNNEL_TEST_BEARER (Authorization bearer token for an admin/operator) +# - A live Agent must be reachable and allowed to establish the reverse tunnel. +# - TLS verification can be controlled via TUNNEL_TEST_VERIFY ("false" to disable). +# +# API Endpoints (if applicable): +# - POST /api/tunnel/request +# - Socket.IO namespace /tunnel (join, ps_open, ps_send, ps_poll) +# ====================================================== + +from __future__ import annotations + +import os +import time + +import pytest +import requests +import socketio + + +HOST = os.environ.get("TUNNEL_TEST_HOST", "").strip() +AGENT_ID = os.environ.get("TUNNEL_TEST_AGENT_ID", "").strip() +BEARER = os.environ.get("TUNNEL_TEST_BEARER", "").strip() +VERIFY_ENV = os.environ.get("TUNNEL_TEST_VERIFY", "").strip().lower() +VERIFY = False if VERIFY_ENV in {"false", "0", "no"} else True + +SKIP_MSG = ( + "Live tunnel test skipped (set TUNNEL_TEST_HOST, TUNNEL_TEST_AGENT_ID, TUNNEL_TEST_BEARER to run)" +) + + +def _require_env() -> None: + if not HOST or not AGENT_ID or not BEARER: + pytest.skip(SKIP_MSG) + + +def _make_session() -> requests.Session: + sess = requests.Session() + sess.verify = VERIFY + sess.headers.update({"Authorization": f"Bearer {BEARER}"}) + return sess + + +def test_reverse_tunnel_powershell_roundtrip() -> None: + _require_env() + sess = _make_session() + + # 1) Request a tunnel lease + resp = sess.post( + f"{HOST}/api/tunnel/request", + json={"agent_id": AGENT_ID, "protocol": "ps", "domain": "ps"}, + ) + assert resp.status_code == 200, f"lease request failed: {resp.status_code} {resp.text}" + lease = resp.json() + tunnel_id = lease["tunnel_id"] + + # 2) Connect to Socket.IO /tunnel namespace + sio = socketio.Client() + sio.connect( + HOST, + namespaces=["/tunnel"], + headers={"Authorization": f"Bearer {BEARER}"}, + transports=["websocket"], + wait_timeout=10, + ) + + # 3) Join tunnel and open PS channel + join_resp = sio.call("join", {"tunnel_id": tunnel_id}, namespace="/tunnel", timeout=10) + assert join_resp.get("status") == "ok", f"join failed: {join_resp}" + + open_resp = sio.call("ps_open", {"cols": 120, "rows": 32}, namespace="/tunnel", timeout=10) + assert not open_resp.get("error"), f"ps_open failed: {open_resp}" + + # 4) Send a command + send_resp = sio.call("ps_send", {"data": 'Write-Host "Hello World"\r\n'}, namespace="/tunnel", timeout=10) + assert not send_resp.get("error"), f"ps_send failed: {send_resp}" + + # 5) Poll for output + output_text = "" + deadline = time.time() + 15 + while time.time() < deadline: + poll_resp = sio.call("ps_poll", {}, namespace="/tunnel", timeout=10) + if poll_resp.get("error"): + pytest.fail(f"ps_poll failed: {poll_resp}") + lines = poll_resp.get("output") or [] + output_text += "".join(lines) + if "Hello World" in output_text: + break + time.sleep(0.5) + + sio.disconnect() + + assert "Hello World" in output_text, f"expected command output not found; got: {output_text!r}" diff --git a/Data/Engine/services/WebSocket/Agent/ReverseTunnel.py b/Data/Engine/services/WebSocket/Agent/ReverseTunnel.py index 6beecfc2..602799ee 100644 --- a/Data/Engine/services/WebSocket/Agent/ReverseTunnel.py +++ b/Data/Engine/services/WebSocket/Agent/ReverseTunnel.py @@ -391,11 +391,13 @@ class TunnelLeaseManager: def expire_idle(self, *, now_ts: Optional[float] = None) -> List[TunnelLease]: now = now_ts or _utc_ts() expired: List[TunnelLease] = [] + pending_timeout = min(self.grace_timeout_seconds, 300) # avoid long-lived pending locks for lease in list(self._leases.values()): if lease.state == "expired": continue idle_age = now - lease.last_activity_ts + pending_age = now - lease.created_at if lease.state == "active" and idle_age >= lease.idle_timeout_seconds: lease.mark_expired() expired.append(lease) @@ -409,6 +411,14 @@ class TunnelLeaseManager: expired.append(lease) self.release(lease.tunnel_id, reason="grace_expired") continue + + if lease.state == "pending": + hard_expiry = lease.expires_at or (lease.created_at + lease.grace_timeout_seconds) + if pending_age >= pending_timeout or (hard_expiry and now >= hard_expiry): + lease.mark_expired() + expired.append(lease) + self.release(lease.tunnel_id, reason="pending_timeout") + continue return expired def all_leases(self) -> Iterable[TunnelLease]: @@ -591,6 +601,7 @@ class ReverseTunnelService: ) lease.token = self.issue_token(lease) self._spawn_port_listener(lease.assigned_port) + self._push_start_to_agent(lease) self.audit_logger.info( "lease_created tunnel_id=%s agent_id=%s domain=%s protocol=%s port=%s operator=%s", lease.tunnel_id, @@ -618,6 +629,35 @@ class ReverseTunnelService: lease.expires_at = expires_at return token + def _push_start_to_agent(self, lease: TunnelLease) -> None: + """Notify the target agent about the new lease over Socket.IO (best-effort).""" + + if not self._socketio: + return + payload = { + "tunnel_id": lease.tunnel_id, + "lease_id": lease.tunnel_id, + "agent_id": lease.agent_id, + "token": lease.token, + "port": lease.assigned_port, + "assigned_port": lease.assigned_port, + "protocol": lease.protocol, + "domain": lease.domain, + "idle_seconds": lease.idle_timeout_seconds, + "grace_seconds": lease.grace_timeout_seconds, + "heartbeat_seconds": self.heartbeat_seconds, + } + try: + self._socketio.emit("reverse_tunnel_start", payload, namespace="/") + self.audit_logger.info( + "lease_push_start tunnel_id=%s agent_id=%s port=%s", + lease.tunnel_id, + lease.agent_id, + lease.assigned_port, + ) + except Exception: + self.logger.debug("Failed to emit reverse_tunnel_start for tunnel_id=%s", lease.tunnel_id, exc_info=True) + def lease_summary(self, lease: TunnelLease) -> Dict[str, object]: return { "tunnel_id": lease.tunnel_id, @@ -874,6 +914,7 @@ class ReverseTunnelService: cert = self.context.tls_cert_path or self.context.tls_bundle_path key = self.context.tls_key_path if not cert or not key: + self.audit_logger.info("tunnel_listener_ssl_missing cert=%s key=%s", cert, key) return None try: ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) @@ -890,6 +931,7 @@ class ReverseTunnelService: if port in self._port_servers: return ssl_ctx = self._build_ssl_context() + self.audit_logger.info("tunnel_listener_start port=%s ssl=%s", port, bool(ssl_ctx)) async def _handler(websocket, path): await self._handle_agent_socket(websocket, path, port=port) @@ -897,6 +939,7 @@ class ReverseTunnelService: async def _start(): server = await ws_serve(_handler, host="0.0.0.0", port=port, ssl=ssl_ctx, max_size=None, ping_interval=None) self._port_servers[port] = server + self.audit_logger.info("tunnel_listener_bound port=%s", port) asyncio.run_coroutine_threadsafe(_start(), self._loop) @@ -904,15 +947,24 @@ class ReverseTunnelService: """Handle agent tunnel socket on assigned port.""" tunnel_id = None + sock_log = self.audit_logger.getChild("agent_socket") try: + peer = None + try: + peer = getattr(websocket, "remote_address", None) + except Exception: + peer = None + sock_log.info("agent_socket_open port=%s path=%s peer=%s", port, path, peer) raw = await asyncio.wait_for(websocket.recv(), timeout=10) frame = decode_frame(raw) if frame.msg_type != MSG_CONNECT: + sock_log.info("agent_socket_first_frame_not_connect port=%s msg_type=%s", port, frame.msg_type) await websocket.close() return try: payload = json.loads(frame.payload.decode("utf-8")) except Exception: + sock_log.info("agent_socket_connect_payload_decode_failed port=%s", port, exc_info=True) await websocket.close() return tunnel_id = str(payload.get("tunnel_id") or "").strip() @@ -920,23 +972,56 @@ class ReverseTunnelService: token = payload.get("token") or "" lease = self.lease_manager.get(tunnel_id) if lease is None or lease.assigned_port != port: + sock_log.info( + "agent_socket_unknown_lease port=%s tunnel_id=%s assigned=%s expected=%s", + port, + tunnel_id, + lease.assigned_port if lease else None, + port, + ) await websocket.close() return # Token validation - self.validate_token( - token, - agent_id=agent_id, - tunnel_id=tunnel_id, - domain=lease.domain, - protocol=lease.protocol, - ) + try: + self.validate_token( + token, + agent_id=agent_id, + tunnel_id=tunnel_id, + domain=lease.domain, + protocol=lease.protocol, + ) + sock_log.info( + "agent_socket_token_valid port=%s tunnel_id=%s agent_id=%s domain=%s protocol=%s", + port, + tunnel_id, + agent_id, + lease.domain, + lease.protocol, + ) + except Exception as exc: + sock_log.info( + "agent_socket_token_invalid port=%s tunnel_id=%s agent_id=%s error=%s", + port, + tunnel_id, + agent_id, + exc, + ) + await websocket.close() + return bridge = self.ensure_bridge(lease) bridge.attach_agent(token) self._agent_sockets[tunnel_id] = websocket await websocket.send(heartbeat_frame(channel_id=0, is_ack=True).encode()) await websocket.send(TunnelFrame(msg_type=MSG_CONNECT_ACK, channel_id=0, payload=b"").encode()) + sock_log.info( + "agent_socket_connected port=%s tunnel_id=%s agent_id=%s", + port, + tunnel_id, + agent_id, + ) async def _pump_to_operator(): + sock_log_local = sock_log.getChild("recv") while not websocket.closed: try: raw_msg = await websocket.recv() @@ -945,14 +1030,23 @@ class ReverseTunnelService: try: recv_frame = decode_frame(raw_msg) except Exception: + sock_log_local.info("agent_socket_frame_decode_failed tunnel_id=%s", tunnel_id, exc_info=True) continue self.lease_manager.touch(tunnel_id) + sock_log_local.info( + "agent_to_operator tunnel_id=%s msg_type=%s channel=%s payload_len=%s", + tunnel_id, + recv_frame.msg_type, + recv_frame.channel_id, + len(recv_frame.payload or b""), + ) try: self._dispatch_agent_frame(tunnel_id, recv_frame) except Exception: pass bridge.agent_to_operator(recv_frame) async def _pump_to_agent(): + sock_log_local = sock_log.getChild("send") while not websocket.closed: frame = bridge.next_for_agent() if frame is None: @@ -960,13 +1054,22 @@ class ReverseTunnelService: continue try: await websocket.send(frame.encode()) + sock_log_local.info( + "operator_to_agent tunnel_id=%s msg_type=%s channel=%s payload_len=%s", + tunnel_id, + frame.msg_type, + frame.channel_id, + len(frame.payload or b""), + ) except Exception: break async def _heartbeat(): + sock_log_local = sock_log.getChild("heartbeat") while not websocket.closed: try: await websocket.send(heartbeat_frame(channel_id=0).encode()) except Exception: + sock_log_local.info("heartbeat_send_failed tunnel_id=%s", tunnel_id, exc_info=True) break await asyncio.sleep(self.heartbeat_seconds) @@ -975,8 +1078,18 @@ class ReverseTunnelService: heart = asyncio.create_task(_heartbeat()) await asyncio.wait([consumer, producer, heart], return_when=asyncio.FIRST_COMPLETED) except Exception: - self.logger.debug("Agent socket handler failed on port %s", port, exc_info=True) + sock_log.info("agent_socket_handler_failed port=%s tunnel_id=%s", port, tunnel_id, exc_info=True) finally: + try: + sock_log.info( + "agent_socket_closed port=%s tunnel_id=%s code=%s reason=%s", + port, + tunnel_id, + getattr(websocket, "close_code", None), + getattr(websocket, "close_reason", None), + ) + except Exception: + pass if tunnel_id and tunnel_id in self._agent_sockets: self._agent_sockets.pop(tunnel_id, None) if tunnel_id: @@ -997,7 +1110,7 @@ class ReverseTunnelService: if server: return server lease = self.lease_manager.get(tunnel_id) - if lease is None or (lease.domain or "").lower() != "ps": + if lease is None: return None bridge = self.ensure_bridge(lease) server = PowershellChannelServer( diff --git a/Data/Engine/services/WebSocket/__init__.py b/Data/Engine/services/WebSocket/__init__.py index 632729c1..911c131d 100644 --- a/Data/Engine/services/WebSocket/__init__.py +++ b/Data/Engine/services/WebSocket/__init__.py @@ -474,7 +474,7 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None: return {"error": "ps_resize_failed"} @socket_server.on("ps_poll", namespace=tunnel_namespace) - def _ws_ps_poll() -> Any: + def _ws_ps_poll(data: Any = None) -> Any: # data is ignored; socketio passes it even when unused server, tunnel_id, error = _require_ps_server() if server is None: return error diff --git a/Data/Engine/web-interface/src/Devices/Device_Details.jsx b/Data/Engine/web-interface/src/Devices/Device_Details.jsx index 716eb442..55810004 100644 --- a/Data/Engine/web-interface/src/Devices/Device_Details.jsx +++ b/Data/Engine/web-interface/src/Devices/Device_Details.jsx @@ -31,6 +31,7 @@ import "prismjs/themes/prism-okaidia.css"; import Editor from "react-simple-code-editor"; import { AgGridReact } from "ag-grid-react"; import { ModuleRegistry, AllCommunityModule, themeQuartz } from "ag-grid-community"; +import ReverseTunnelPowershell from "./ReverseTunnel/Powershell.jsx"; ModuleRegistry.registerModules([AllCommunityModule]); @@ -63,7 +64,15 @@ const SECTION_HEIGHTS = { network: 260, }; -const TOP_TABS = ["Device Summary", "Storage", "Memory", "Network", "Installed Software", "Activity History"]; +const TOP_TABS = [ + "Device Summary", + "Storage", + "Memory", + "Network", + "Installed Software", + "Activity History", + "Remote Shell", +]; const myTheme = themeQuartz.withParams({ accentColor: "#8b5cf6", @@ -727,6 +736,17 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage ); const summary = details.summary || {}; + const tunnelDevice = useMemo( + () => ({ + ...(device || {}), + ...(agent || {}), + summary, + hostname: meta.hostname || summary.hostname || device?.hostname || agent?.hostname, + agent_id: meta.agentId || summary.agent_id || agent?.agent_id || agent?.id || device?.agent_id || device?.agent_guid, + agent_guid: meta.agentGuid || summary.agent_guid || device?.agent_guid || device?.guid || agent?.agent_guid || agent?.guid, + }), + [agent, device, meta.agentGuid, meta.agentId, meta.hostname, summary] + ); // Build a best-effort CPU display from summary fields const cpuInfo = useMemo(() => { const cpu = details.cpu || summary.cpu || {}; @@ -850,16 +870,20 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage [] ); + const formatScriptType = useCallback((raw) => { + const value = String(raw || "").toLowerCase(); + if (value === "ansible") return "Ansible Playbook"; + if (value === "reverse_tunnel") return "Reverse Tunnel"; + return "Script"; + }, []); + const historyColumnDefs = useMemo( () => [ { - headerName: "Assembly", + headerName: "Activity", field: "script_type", minWidth: 180, - valueGetter: (params) => - String(params.data?.script_type || "").toLowerCase() === "ansible" - ? "Ansible Playbook" - : "Script", + valueGetter: (params) => formatScriptType(params.data?.script_type), }, { headerName: "Task", @@ -891,7 +915,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage cellRenderer: "HistoryActionsCell", }, ], - [formatTimestamp] + [formatScriptType, formatTimestamp] ); const MetricCard = ({ icon, title, main, sub, compact = false }) => ( @@ -1265,6 +1289,19 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage ); + const renderRemoteShellTab = () => ( + + + + ); + const memoryRows = useMemo( () => (details.memory || []).map((m, idx) => ({ @@ -1523,6 +1560,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage renderNetworkTab, renderSoftware, renderHistory, + renderRemoteShellTab, ]; const tabContent = (topTabRenderers[tab] || renderDeviceSummaryTab)(); @@ -1642,7 +1680,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage anchorEl={menuAnchor} open={Boolean(menuAnchor)} onClose={() => setMenuAnchor(null)} - PaperProps={{ + PaperProps={{ sx: { bgcolor: "rgba(8,12,24,0.96)", color: "#fff", diff --git a/Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx b/Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx new file mode 100644 index 00000000..ec75dc15 --- /dev/null +++ b/Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx @@ -0,0 +1,704 @@ +import React, { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { + Box, + Typography, + Button, + Stack, + Chip, + TextField, + MenuItem, + IconButton, + Tooltip, + Alert, + LinearProgress, +} from "@mui/material"; +import { + TerminalRounded as TerminalIcon, + PlayArrowRounded as PlayIcon, + StopRounded as StopIcon, + ContentCopy as CopyIcon, + RefreshRounded as RefreshIcon, + LanRounded as PortIcon, + SensorsRounded as ActivityIcon, + LinkRounded as LinkIcon, +} from "@mui/icons-material"; +import { io } from "socket.io-client"; +import Prism from "prismjs"; +import "prismjs/components/prism-powershell"; +import "prismjs/themes/prism-okaidia.css"; +import Editor from "react-simple-code-editor"; + +const MAGIC_UI = { + panelBg: "rgba(7,11,24,0.92)", + panelBorder: "rgba(148, 163, 184, 0.35)", + textMuted: "#94a3b8", + textBright: "#e2e8f0", + accentA: "#7dd3fc", + accentB: "#c084fc", + accentC: "#34d399", +}; + +const gradientButtonSx = { + backgroundImage: "linear-gradient(135deg,#7dd3fc,#c084fc)", + color: "#0b1220", + borderRadius: 999, + textTransform: "none", + boxShadow: "0 10px 26px rgba(124,58,237,0.28)", + px: 2.2, + minWidth: 120, + "&:hover": { + backgroundImage: "linear-gradient(135deg,#86e1ff,#d1a6ff)", + boxShadow: "0 12px 34px rgba(124,58,237,0.38)", + }, +}; + +const FRAME_HEADER_BYTES = 12; // version, msg_type, flags, reserved, channel_id(u32), length(u32) +const MSG_CLOSE = 0x08; +const CLOSE_AGENT_SHUTDOWN = 6; + +const fontFamilyMono = + 'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace'; + +function normalizeText(value) { + if (value == null) return ""; + try { + return String(value).trim(); + } catch { + return ""; + } +} + +function base64FromBytes(bytes) { + let binary = ""; + bytes.forEach((b) => { + binary += String.fromCharCode(b); + }); + return btoa(binary); +} + +function buildCloseFrame(channelId = 1, code = CLOSE_AGENT_SHUTDOWN, reason = "operator_close") { + const payload = new TextEncoder().encode(JSON.stringify({ code, reason })); + const buffer = new ArrayBuffer(FRAME_HEADER_BYTES + payload.length); + const view = new DataView(buffer); + view.setUint8(0, 1); // version + view.setUint8(1, MSG_CLOSE); + view.setUint8(2, 0); // flags + view.setUint8(3, 0); // reserved + view.setUint32(4, channelId >>> 0, true); + view.setUint32(8, payload.length >>> 0, true); + new Uint8Array(buffer, FRAME_HEADER_BYTES).set(payload); + return base64FromBytes(new Uint8Array(buffer)); +} + +function highlightPs(code) { + try { + return Prism.highlight(code || "", Prism.languages.powershell, "powershell"); + } catch { + return code || ""; + } +} + +export default function ReverseTunnelPowershell({ device }) { + const [connectionType, setConnectionType] = useState("ps"); + const [tunnel, setTunnel] = useState(null); + const [sessionState, setSessionState] = useState("idle"); + const [statusMessage, setStatusMessage] = useState(""); + const [statusSeverity, setStatusSeverity] = useState("info"); + const [output, setOutput] = useState(""); + const [input, setInput] = useState(""); + const [copyFlash, setCopyFlash] = useState(false); + const [, setPolling] = useState(false); + const [psStatus, setPsStatus] = useState({}); + const socketRef = useRef(null); + const pollTimerRef = useRef(null); + const resizeTimerRef = useRef(null); + const terminalRef = useRef(null); + const joinRetryRef = useRef(null); + const joinAttemptsRef = useRef(0); + + const hostname = useMemo(() => { + return ( + normalizeText(device?.hostname) || + normalizeText(device?.summary?.hostname) || + normalizeText(device?.agent_hostname) || + "" + ); + }, [device]); + + const agentId = useMemo(() => { + return ( + normalizeText(device?.agent_id) || + normalizeText(device?.agentId) || + normalizeText(device?.agent_guid) || + normalizeText(device?.agentGuid) || + normalizeText(device?.id) || + normalizeText(device?.guid) || + normalizeText(device?.summary?.agent_id) || + "" + ); + }, [device]); + + const resetState = useCallback(() => { + setTunnel(null); + setSessionState("idle"); + setStatusMessage(""); + setStatusSeverity("info"); + setOutput(""); + setInput(""); + setPsStatus({}); + }, []); + + const disconnectSocket = useCallback(() => { + const socket = socketRef.current; + if (socket) { + socket.off(); + socket.disconnect(); + } + socketRef.current = null; + }, []); + + const stopPolling = useCallback(() => { + if (pollTimerRef.current) { + clearTimeout(pollTimerRef.current); + pollTimerRef.current = null; + } + setPolling(false); + }, []); + + useEffect(() => { + return () => { + stopPolling(); + disconnectSocket(); + if (joinRetryRef.current) { + clearTimeout(joinRetryRef.current); + joinRetryRef.current = null; + } + }; + }, [disconnectSocket, stopPolling]); + + const appendOutput = useCallback((text) => { + if (!text) return; + setOutput((prev) => { + const next = `${prev}${text}`; + const limit = 40000; + return next.length > limit ? next.slice(next.length - limit) : next; + }); + }, []); + + const scrollToBottom = useCallback(() => { + const el = terminalRef.current; + if (!el) return; + requestAnimationFrame(() => { + el.scrollTop = el.scrollHeight; + }); + }, []); + + useEffect(() => { + scrollToBottom(); + }, [output, scrollToBottom]); + + const handleCopy = async () => { + try { + await navigator.clipboard.writeText(output || ""); + setCopyFlash(true); + setTimeout(() => setCopyFlash(false), 1200); + } catch { + setCopyFlash(false); + } + }; + + const measureTerminal = useCallback(() => { + const el = terminalRef.current; + if (!el) return { cols: 120, rows: 32 }; + const width = el.clientWidth || 960; + const height = el.clientHeight || 460; + const charWidth = 8.2; + const charHeight = 18; + const cols = Math.max(20, Math.min(Math.floor(width / charWidth), 300)); + const rows = Math.max(10, Math.min(Math.floor(height / charHeight), 200)); + return { cols, rows }; + }, []); + + const emitAsync = useCallback((socket, event, payload, timeoutMs = 4000) => { + return new Promise((resolve) => { + let settled = false; + const timer = setTimeout(() => { + if (settled) return; + settled = true; + resolve({ error: "timeout" }); + }, timeoutMs); + socket.emit(event, payload, (resp) => { + if (settled) return; + settled = true; + clearTimeout(timer); + resolve(resp || {}); + }); + }); + }, []); + + const pollLoop = useCallback( + (socket, tunnelId) => { + if (!socket || !tunnelId) return; + setPolling(true); + pollTimerRef.current = setTimeout(async () => { + const resp = await emitAsync(socket, "ps_poll", {}); + if (resp?.error) { + if (resp.error === "ps_unsupported") { + setStatusSeverity("info"); + setStatusMessage("PowerShell channel warming up..."); + } else { + setStatusSeverity("warning"); + setStatusMessage(resp.error); + } + } + if (Array.isArray(resp?.output) && resp.output.length) { + appendOutput(resp.output.join("")); + } + if (resp?.status) { + setPsStatus(resp.status); + if (resp.status.closed) { + setSessionState("closed"); + setStatusSeverity("warning"); + setStatusMessage(resp.status.close_reason || "Session closed"); + stopPolling(); + return; + } + if (resp.status.ack) { + setSessionState("connected"); + } + } + pollLoop(socket, tunnelId); + }, 520); + }, + [appendOutput, emitAsync, stopPolling] + ); + + const handleDisconnect = useCallback(() => { + const socket = socketRef.current; + const tunnelId = tunnel?.tunnel_id; + if (joinRetryRef.current) { + clearTimeout(joinRetryRef.current); + joinRetryRef.current = null; + } + joinAttemptsRef.current = 0; + if (socket && tunnelId) { + const frame = buildCloseFrame(1, CLOSE_AGENT_SHUTDOWN, "operator_close"); + socket.emit("send", { frame }); + } + stopPolling(); + disconnectSocket(); + setSessionState("closed"); + setStatusSeverity("info"); + setStatusMessage("Session closed by operator."); + }, [disconnectSocket, stopPolling, tunnel?.tunnel_id]); + + const handleResize = useCallback(() => { + if (!socketRef.current || sessionState === "idle") return; + const dims = measureTerminal(); + socketRef.current.emit("ps_resize", dims); + }, [measureTerminal, sessionState]); + + useEffect(() => { + const observer = + typeof ResizeObserver !== "undefined" + ? new ResizeObserver(() => { + if (resizeTimerRef.current) clearTimeout(resizeTimerRef.current); + resizeTimerRef.current = setTimeout(() => handleResize(), 200); + }) + : null; + const el = terminalRef.current; + if (observer && el) observer.observe(el); + const onWinResize = () => handleResize(); + window.addEventListener("resize", onWinResize); + return () => { + window.removeEventListener("resize", onWinResize); + if (observer && el) observer.unobserve(el); + }; + }, [handleResize]); + + const connectSocket = useCallback( + (lease, { isRetry = false } = {}) => { + if (!lease?.tunnel_id) return; + if (joinRetryRef.current) { + clearTimeout(joinRetryRef.current); + joinRetryRef.current = null; + } + if (!isRetry) { + joinAttemptsRef.current = 0; + } + disconnectSocket(); + stopPolling(); + setSessionState("waiting"); + const socket = io(`${window.location.origin}/tunnel`, { transports: ["websocket"] }); + socketRef.current = socket; + + socket.on("connect_error", () => { + setStatusSeverity("warning"); + setStatusMessage("Tunnel namespace unavailable."); + }); + + socket.on("disconnect", () => { + stopPolling(); + if (sessionState !== "closed") { + setSessionState("disconnected"); + setStatusSeverity("warning"); + setStatusMessage("Socket disconnected."); + } + }); + + socket.on("connect", async () => { + setStatusSeverity("info"); + setStatusMessage("Joining tunnel..."); + const joinResp = await emitAsync(socket, "join", { tunnel_id: lease.tunnel_id }); + if (joinResp?.error) { + if (joinResp.error === "unknown_tunnel") { + setSessionState("waiting_agent"); + setStatusSeverity("info"); + setStatusMessage("Waiting for agent to establish tunnel..."); + joinAttemptsRef.current += 1; + const attempt = joinAttemptsRef.current; + if (attempt <= 15) { + joinRetryRef.current = setTimeout(() => connectSocket(lease, { isRetry: true }), 1000); + } else { + setSessionState("error"); + setTunnel(null); + setStatusSeverity("warning"); + setStatusMessage("Agent did not attach to tunnel (timeout). Try Connect again."); + } + } else { + setSessionState("error"); + setStatusSeverity("error"); + setStatusMessage(joinResp.error); + } + return; + } + const dims = measureTerminal(); + const openResp = await emitAsync(socket, "ps_open", dims); + if (openResp?.error && openResp.error === "ps_unsupported") { + setStatusSeverity("info"); + setStatusMessage("PowerShell channel warming up..."); + } + appendOutput(""); + setStatusMessage("Attached — waiting for agent to acknowledge..."); + setSessionState("waiting_agent"); + pollLoop(socket, lease.tunnel_id); + handleResize(); + }); + }, + [appendOutput, disconnectSocket, emitAsync, handleResize, measureTerminal, pollLoop, sessionState, stopPolling] + ); + + const requestTunnel = useCallback(async () => { + if (tunnel && sessionState !== "closed" && sessionState !== "idle") { + setStatusSeverity("info"); + setStatusMessage("Re-attaching to existing tunnel..."); + connectSocket(tunnel); + return; + } + if (!agentId) { + setStatusSeverity("warning"); + setStatusMessage("Agent ID is required to request a tunnel."); + return; + } + if (connectionType !== "ps") { + setStatusSeverity("warning"); + setStatusMessage("Only PowerShell is supported right now."); + return; + } + resetState(); + setSessionState("requesting"); + setStatusSeverity("info"); + setStatusMessage("Requesting tunnel lease..."); + try { + const resp = await fetch("/api/tunnel/request", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ agent_id: agentId, protocol: "ps", domain: "ps" }), + }); + const data = await resp.json().catch(() => ({})); + if (!resp.ok) { + const err = data?.error || `HTTP ${resp.status}`; + setSessionState("error"); + setStatusSeverity(err === "domain_limit" ? "warning" : "error"); + setStatusMessage( + err === "domain_limit" + ? "PowerShell session already active for this agent. Try again after it closes." + : err + ); + return; + } + setTunnel(data); + setStatusMessage("Lease issued. Waiting for agent to connect..."); + setSessionState("lease_issued"); + connectSocket(data); + } catch (e) { + setSessionState("error"); + setStatusSeverity("error"); + setStatusMessage(e?.message || "Failed to request tunnel"); + } + }, [agentId, connectSocket, connectionType, resetState]); + + const handleSend = useCallback( + async (text) => { + const socket = socketRef.current; + if (!socket) return; + const payload = `${text}${text.endsWith("\n") ? "" : "\r\n"}`; + appendOutput(`\nPS> ${text}\n`); + setInput(""); + const resp = await emitAsync(socket, "ps_send", { data: payload }); + if (resp?.error) { + setStatusSeverity("warning"); + setStatusMessage(resp.error); + } + }, + [appendOutput, emitAsync] + ); + + const isConnected = sessionState === "connected" || psStatus?.ack; + const isBusy = + sessionState === "requesting" || + sessionState === "waiting" || + sessionState === "waiting_agent" || + sessionState === "lease_issued"; + const canStart = Boolean(agentId) && !isBusy && !isConnected; + + const sessionChips = [ + { + label: isConnected ? "Connected" : sessionState === "idle" ? "Idle" : sessionState.replace(/_/g, " "), + color: isConnected ? MAGIC_UI.accentC : MAGIC_UI.accentA, + icon: , + }, + tunnel?.tunnel_id + ? { + label: `Tunnel ${tunnel.tunnel_id.slice(0, 8)}`, + color: MAGIC_UI.accentB, + icon: , + } + : null, + tunnel?.port + ? { + label: `Port ${tunnel.port}`, + color: MAGIC_UI.accentA, + icon: , + } + : null, + ].filter(Boolean); + + return ( + + + + + + + Remote Shell + + {hostname ? ( + + ) : null} + {agentId ? ( + + ) : null} + + + + setConnectionType(e.target.value)} + sx={{ + minWidth: 180, + "& .MuiInputBase-root": { + backgroundColor: "rgba(12,18,35,0.85)", + color: MAGIC_UI.textBright, + borderRadius: 1.5, + }, + "& fieldset": { borderColor: MAGIC_UI.panelBorder }, + "&:hover fieldset": { borderColor: MAGIC_UI.accentA }, + }} + > + PowerShell + + + + + + + + + + + {sessionChips.map((chip) => ( + + ))} + + + + {statusMessage ? ( + + {statusMessage} + + ) : null} + + + {isBusy ? : null} + + {}} + highlight={highlightPs} + padding={12} + readOnly + style={{ + minHeight: "100%", + background: "transparent", + color: "#e6edf3", + fontFamily: fontFamilyMono, + fontSize: 13, + }} + /> + + + + + + + + setOutput("")} + sx={{ color: MAGIC_UI.textMuted }} + > + + + + + + + + setInput(e.target.value)} + onKeyDown={(e) => { + if (e.key === "Enter" && !e.shiftKey) { + e.preventDefault(); + const text = input.trim(); + if (text) handleSend(text); + } + }} + InputProps={{ + sx: { + backgroundColor: "rgba(12,18,35,0.9)", + color: "#e2e8f0", + borderRadius: 2, + "& fieldset": { borderColor: "rgba(148,163,184,0.45)" }, + "&:hover fieldset": { borderColor: MAGIC_UI.accentA }, + }, + }} + /> + + + + ); +}