Major Progress Towards Interactive Remote Powershell

This commit is contained in:
2025-12-06 00:27:57 -07:00
parent 52e40c3753
commit 68dd46347b
9 changed files with 1247 additions and 53 deletions

View File

@@ -995,6 +995,13 @@ function InstallOrUpdate-BorealisAgent {
)
Copy-Item $coreAgentFiles -Destination $agentDestinationFolder -Recurse -Force
# Ensure ReverseTunnel role is refreshed explicitly (covers incremental changes)
$rtSource = Join-Path $agentSourceRoot 'Roles\ReverseTunnel'
$rtDest = Join-Path $agentDestinationFolder 'Roles'
if (Test-Path $rtSource) {
Copy-Item $rtSource -Destination $rtDest -Recurse -Force
}
}
. (Join-Path $venvFolderPath 'Scripts\Activate')
}

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import os
import sys
import subprocess
from typing import Any, Dict, Optional
# Message types mirrored from the tunnel framing (kept local to avoid import cycles).
@@ -30,6 +31,7 @@ class PowershellChannel:
self._writer_task = None
self._stdin_queue: asyncio.Queue = asyncio.Queue()
self._pty = None
self._proc: Optional[asyncio.subprocess.Process] = None
self._exit_code: Optional[int] = None
self._frame_cls = getattr(role, "_frame_cls", None)
@@ -62,12 +64,11 @@ class PowershellChannel:
)
await self._send_frame(frame)
def _powershell_path(self) -> str:
def _powershell_argv(self) -> list:
preferred = self.metadata.get("shell") if isinstance(self.metadata, dict) else None
if isinstance(preferred, str) and preferred.strip():
return preferred.strip()
# Default to Windows PowerShell; fallback to pwsh if provided later.
return "powershell.exe"
shell = preferred.strip() if isinstance(preferred, str) and preferred.strip() else "powershell.exe"
# Keep the process alive and read commands from stdin; -Command - tells PS to consume stdin.
return [shell, "-NoLogo", "-NoProfile", "-NoExit", "-Command", "-"]
def _initial_size(self) -> tuple:
cols = int(self.metadata.get("cols") or self.metadata.get("columns") or 120) if isinstance(self.metadata, dict) else 120
@@ -79,30 +80,46 @@ class PowershellChannel:
# ------------------------------------------------------------------ Lifecycle
async def start(self) -> None:
if sys.platform.lower().startswith("win") is False:
self.role._log("reverse_tunnel ps start aborted: non-windows platform", error=True)
await self._send_close(CLOSE_PROTOCOL_ERROR, "windows_only")
return
argv = self._powershell_argv()
cols, rows = self._initial_size()
self.role._log(f"reverse_tunnel ps start channel={self.channel_id} argv={' '.join(argv)} cols={cols} rows={rows}")
# Preferred: ConPTY via pywinpty.
try:
import pywinpty # type: ignore
except Exception as exc: # pragma: no cover - dependency guard
self.role._log(f"reverse_tunnel ps channel missing pywinpty: {exc}", error=True)
await self._send_close(CLOSE_PROTOCOL_ERROR, "pywinpty_missing")
return
shell = self._powershell_path()
cols, rows = self._initial_size()
try:
self._pty = pywinpty.Process(
spawn_cmd=shell,
spawn_cmd=" ".join(argv[:-1]) if argv[-1] == "-" else " ".join(argv),
dimensions=(cols, rows),
)
self._reader_task = self.loop.create_task(self._pump_pty_stdout())
self._writer_task = self.loop.create_task(self._pump_pty_stdin())
self.role._log(f"reverse_tunnel ps channel started (pty) argv={' '.join(argv)} cols={cols} rows={rows}")
return
except Exception as exc:
self.role._log(f"reverse_tunnel ps channel failed to spawn {shell}: {exc}", error=True)
self.role._log(f"reverse_tunnel ps channel pywinpty unavailable, falling back to pipes: {exc}", error=True)
# Fallback: subprocess pipes (no PTY).
try:
self._proc = await asyncio.create_subprocess_exec(
*argv,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0),
)
except Exception as exc:
self.role._log(f"reverse_tunnel ps channel spawn failed argv={' '.join(argv)}: {exc}", error=True)
await self._send_close(CLOSE_PROTOCOL_ERROR, "spawn_failed")
return
self._reader_task = self.loop.create_task(self._pump_stdout())
self._writer_task = self.loop.create_task(self._pump_stdin())
self.role._log(f"reverse_tunnel ps channel started shell={shell} cols={cols} rows={rows}")
self._reader_task = self.loop.create_task(self._pump_proc_stdout())
self._writer_task = self.loop.create_task(self._pump_proc_stdin())
self.role._log(f"reverse_tunnel ps channel started (pipes) argv={' '.join(argv)} cols={cols} rows={rows}")
async def on_frame(self, frame) -> None:
if self._closed:
@@ -140,6 +157,7 @@ class PowershellChannel:
await self._resize(cols_int, rows_int)
async def _resize(self, cols: Optional[int], rows: Optional[int]) -> None:
# Resize only applies to PTY sessions; pipe mode ignores.
if self._pty is None:
return
try:
@@ -155,17 +173,14 @@ class PowershellChannel:
except Exception:
self.role._log("reverse_tunnel ps channel resize failed", error=True)
async def _pump_stdout(self) -> None:
async def _pump_pty_stdout(self) -> None:
loop = asyncio.get_event_loop()
try:
while not self._closed and self._pty:
chunk = await loop.run_in_executor(None, self._pty.read, 4096)
if chunk is None:
break
if isinstance(chunk, str):
data = chunk.encode("utf-8", errors="replace")
else:
data = bytes(chunk)
data = chunk.encode("utf-8", errors="replace") if isinstance(chunk, str) else bytes(chunk)
if not data:
break
frame = self._make_frame(MSG_DATA, payload=data)
@@ -173,11 +188,11 @@ class PowershellChannel:
except asyncio.CancelledError:
pass
except Exception:
self.role._log("reverse_tunnel ps stdout pump error", error=True)
self.role._log("reverse_tunnel ps pty stdout pump error", error=True)
finally:
await self.stop(reason="stdout_closed")
async def _pump_stdin(self) -> None:
async def _pump_pty_stdin(self) -> None:
loop = asyncio.get_event_loop()
try:
while not self._closed and self._pty:
@@ -187,10 +202,7 @@ class PowershellChannel:
break
if data is None:
break
if isinstance(data, (bytes, bytearray)):
text = data.decode("utf-8", errors="replace")
else:
text = str(data)
text = data.decode("utf-8", errors="replace") if isinstance(data, (bytes, bytearray)) else str(data)
try:
await loop.run_in_executor(None, self._pty.write, text)
except Exception:
@@ -198,7 +210,47 @@ class PowershellChannel:
except asyncio.CancelledError:
pass
except Exception:
self.role._log("reverse_tunnel ps stdin pump error", error=True)
self.role._log("reverse_tunnel ps pty stdin pump error", error=True)
finally:
await self.stop(reason="stdin_closed")
# -------------------- Pipe fallback pumps --------------------
async def _pump_proc_stdout(self) -> None:
try:
while self._proc and not self._closed:
chunk = await self._proc.stdout.read(4096)
if not chunk:
break
frame = self._make_frame(MSG_DATA, payload=bytes(chunk))
await self._send_frame(frame)
except asyncio.CancelledError:
pass
except Exception:
self.role._log("reverse_tunnel ps pipe stdout pump error", error=True)
finally:
if self._proc and not self._closed:
try:
self._exit_code = await self._proc.wait()
except Exception:
pass
await self.stop(reason="stdout_closed")
async def _pump_proc_stdin(self) -> None:
try:
while self._proc and not self._closed:
data = await self._stdin_queue.get()
if self._closed or not self._proc or not self._proc.stdin:
break
try:
self._proc.stdin.write(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
await self._proc.stdin.drain()
except Exception:
self.role._log("reverse_tunnel ps pipe stdin pump error", error=True)
break
except asyncio.CancelledError:
pass
except Exception:
self.role._log("reverse_tunnel ps pipe stdin pump error", error=True)
finally:
await self.stop(reason="stdin_closed")
@@ -211,6 +263,11 @@ class PowershellChannel:
self._pty.terminate()
except Exception:
pass
if self._proc is not None:
try:
self._proc.terminate()
except Exception:
pass
current = asyncio.current_task()
if self._reader_task and self._reader_task is not current:
try:
@@ -222,5 +279,7 @@ class PowershellChannel:
self._writer_task.cancel()
except Exception:
pass
await self._send_close(code, reason or "powershell_exit")
# Include exit code in the close reason for debugging.
exit_suffix = f" (exit={self._exit_code})" if self._exit_code is not None else ""
await self._send_close(code, (reason or "powershell_exit") + exit_suffix)
self.role._log(f"reverse_tunnel ps channel stopped channel={self.channel_id} reason={reason or 'exit'}")

View File

@@ -1,10 +1,12 @@
import asyncio
import base64
import importlib.util
import json
import os
import struct
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional
from urllib.parse import urlparse
@@ -12,10 +14,25 @@ import aiohttp
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
# Capture import errors for the PowerShell handler so we can report why it is missing.
PS_IMPORT_ERROR: Optional[str] = None
tunnel_Powershell = None
try:
from .ReverseTunnel import tunnel_Powershell
except Exception:
tunnel_Powershell = None
from .ReverseTunnel import tunnel_Powershell # type: ignore
except Exception as exc: # pragma: no cover - best-effort logging only
PS_IMPORT_ERROR = repr(exc)
# Try manual import from file to survive non-package execution.
try:
_ps_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
if _ps_path.exists():
spec = importlib.util.spec_from_file_location("tunnel_Powershell", _ps_path)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
tunnel_Powershell = module
PS_IMPORT_ERROR = None
except Exception as exc2: # pragma: no cover - diagnostic only
PS_IMPORT_ERROR = f"{PS_IMPORT_ERROR} | fallback_load_failed={exc2!r}"
ROLE_NAME = "reverse_tunnel"
ROLE_CONTEXTS = ["interactive", "system"]
@@ -178,8 +195,21 @@ class Role:
self._protocol_handlers: Dict[str, Any] = {}
self._frame_cls = TunnelFrame
self.close_frame = close_frame
if tunnel_Powershell and hasattr(tunnel_Powershell, "PowershellChannel"):
self._protocol_handlers["ps"] = tunnel_Powershell.PowershellChannel
try:
if tunnel_Powershell and hasattr(tunnel_Powershell, "PowershellChannel"):
self._protocol_handlers["ps"] = tunnel_Powershell.PowershellChannel
module_path = getattr(tunnel_Powershell, "__file__", None)
self._log(f"reverse_tunnel ps handler registered (PowershellChannel) module={module_path}")
else:
hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else ""
module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
exists_hint = f" exists={module_path.exists()}"
self._log(
f"reverse_tunnel ps handler NOT registered (missing module/class){hint}{exists_hint}",
error=True,
)
except Exception as exc:
self._log(f"reverse_tunnel ps handler registration failed: {exc}", error=True)
# ------------------------------------------------------------------ Logging
def _log(self, message: str, *, error: bool = False) -> None:
@@ -359,6 +389,11 @@ class Role:
idle_seconds = int(payload.get("idle_seconds") or 3600)
grace_seconds = int(payload.get("grace_seconds") or 3600)
signing_key_hint = _norm_text(payload.get("signing_key"))
agent_hint = _norm_text(payload.get("agent_id"))
# Ignore broadcasts targeting other agents (Socket.IO fanout sends to both contexts).
if agent_hint and agent_hint.lower() != _norm_text(self.ctx.agent_id).lower():
return
if not token:
self._log("reverse_tunnel_start rejected: missing token", error=True)
@@ -378,6 +413,9 @@ class Role:
signing_key_hint=signing_key_hint,
)
except Exception as exc:
if str(exc) == "token_agent_mismatch":
# Broadcast hit the wrong agent context; ignore quietly.
return
self._log(f"reverse_tunnel_start rejected: token validation failed ({exc})", error=True)
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_invalid"})
return
@@ -439,6 +477,10 @@ class Role:
async def _run_tunnel(self, tunnel: ActiveTunnel, *, host: str) -> None:
ssl_ctx = self._ssl_context(host)
timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=None)
self._log(
f"reverse_tunnel dialing ws url={tunnel.url} tunnel_id={tunnel.tunnel_id} "
f"agent_id={self.ctx.agent_id} ssl={'yes' if ssl_ctx else 'no'}"
)
try:
tunnel.session = aiohttp.ClientSession(timeout=timeout)
tunnel.websocket = await tunnel.session.ws_connect(
@@ -449,6 +491,7 @@ class Role:
timeout=timeout,
)
self._mark_activity(tunnel)
self._log(f"reverse_tunnel connected ws tunnel_id={tunnel.tunnel_id} peer={getattr(tunnel.websocket, 'remote', None)}")
await tunnel.websocket.send_bytes(
TunnelFrame(
msg_type=MSG_CONNECT,
@@ -466,6 +509,7 @@ class Role:
).encode("utf-8"),
).encode()
)
self._log(f"reverse_tunnel CONNECT sent tunnel_id={tunnel.tunnel_id}")
sender = self.loop.create_task(self._pump_sender(tunnel))
receiver = self.loop.create_task(self._pump_receiver(tunnel))
@@ -486,12 +530,18 @@ class Role:
try:
await tunnel.websocket.send_bytes(frame.encode())
self._mark_activity(tunnel)
self._log(
f"reverse_tunnel send frame tunnel_id={tunnel.tunnel_id} "
f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}"
)
except Exception:
break
except asyncio.CancelledError:
pass
except Exception:
self._log(f"reverse_tunnel sender failed tunnel_id={tunnel.tunnel_id}", error=True)
finally:
self._log(f"reverse_tunnel sender stopped tunnel_id={tunnel.tunnel_id}")
async def _pump_receiver(self, tunnel: ActiveTunnel) -> None:
ws = tunnel.websocket
@@ -508,21 +558,30 @@ class Role:
self._mark_activity(tunnel)
await self._handle_frame(tunnel, frame)
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSE):
self._log(
f"reverse_tunnel websocket closed tunnel_id={tunnel.tunnel_id} "
f"code={ws.close_code} reason={ws.close_reason}"
)
break
except asyncio.CancelledError:
pass
except Exception:
self._log(f"reverse_tunnel receiver failed tunnel_id={tunnel.tunnel_id}", error=True)
finally:
self._log(f"reverse_tunnel receiver stopped tunnel_id={tunnel.tunnel_id}")
async def _heartbeat_loop(self, tunnel: ActiveTunnel) -> None:
try:
while tunnel.websocket and not tunnel.websocket.closed:
await asyncio.sleep(tunnel.heartbeat_seconds)
await self._send_frame(tunnel, heartbeat_frame())
self._log(f"reverse_tunnel heartbeat sent tunnel_id={tunnel.tunnel_id}")
except asyncio.CancelledError:
pass
except Exception:
self._log(f"reverse_tunnel heartbeat failed tunnel_id={tunnel.tunnel_id}", error=True)
finally:
self._log(f"reverse_tunnel heartbeat loop stopped tunnel_id={tunnel.tunnel_id}")
async def _watchdog(self, tunnel: ActiveTunnel) -> None:
try:
@@ -531,24 +590,34 @@ class Role:
now = time.time()
if tunnel.idle_seconds and (now - tunnel.last_activity) >= tunnel.idle_seconds:
await self._send_frame(tunnel, close_frame(0, CLOSE_IDLE_TIMEOUT, "idle_timeout"))
self._log(f"reverse_tunnel watchdog idle_timeout tunnel_id={tunnel.tunnel_id}")
break
if tunnel.expires_at and (now - tunnel.expires_at) >= tunnel.grace_seconds:
await self._send_frame(tunnel, close_frame(0, CLOSE_GRACE_EXPIRED, "grace_expired"))
self._log(f"reverse_tunnel watchdog grace_expired tunnel_id={tunnel.tunnel_id}")
break
except asyncio.CancelledError:
pass
except Exception:
self._log(f"reverse_tunnel watchdog failed tunnel_id={tunnel.tunnel_id}", error=True)
finally:
self._log(f"reverse_tunnel watchdog stopped tunnel_id={tunnel.tunnel_id}")
async def _handle_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
self._log(
f"reverse_tunnel recv frame tunnel_id={tunnel.tunnel_id} "
f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}"
)
if frame.msg_type == MSG_HEARTBEAT:
if frame.flags & 0x1:
self._log(f"reverse_tunnel heartbeat ack tunnel_id={tunnel.tunnel_id}")
return
await self._send_frame(tunnel, heartbeat_frame(channel_id=frame.channel_id, is_ack=True))
return
if frame.msg_type == MSG_CONNECT_ACK:
tunnel.connected = True
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "connected"})
self._log(f"reverse_tunnel CONNECT_ACK tunnel_id={tunnel.tunnel_id}")
return
if frame.msg_type == MSG_CHANNEL_OPEN:
await self._handle_channel_open(tunnel, frame)
@@ -584,6 +653,15 @@ class Role:
await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "channel_exists"))
return
if protocol.lower() == "ps" and "ps" not in self._protocol_handlers:
hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else ""
module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
exists_hint = f" exists={module_path.exists()}"
self._log(
f"reverse_tunnel ps handler missing; falling back to BaseChannel{hint}{exists_hint}",
error=True,
)
handler_cls = self._protocol_handlers.get(protocol.lower()) or BaseChannel
try:
handler = handler_cls(self, tunnel, frame.channel_id, metadata)
@@ -599,6 +677,10 @@ class Role:
payload=json.dumps({"status": "ok", "protocol": protocol}, separators=(",", ":")).encode("utf-8"),
),
)
self._log(
f"reverse_tunnel channel_opened tunnel_id={tunnel.tunnel_id} channel={frame.channel_id} "
f"protocol={protocol} handler={handler.__class__.__name__} metadata={metadata}"
)
async def _send_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
if tunnel.stopping:

View File

@@ -0,0 +1,90 @@
# ======================================================
# Data\Engine\Unit_Tests\test_reverse_tunnel.py
# Description: Validates reverse tunnel lease API basics (allocation, token contents, and domain limit).
#
# API Endpoints (if applicable):
# - POST /api/tunnel/request
# ======================================================
from __future__ import annotations
import base64
import json
import pytest
from .conftest import EngineTestHarness
def _client_with_admin_session(harness: EngineTestHarness):
client = harness.app.test_client()
with client.session_transaction() as sess:
sess["username"] = "admin"
sess["role"] = "Admin"
return client
def _decode_token_segment(token: str) -> dict:
"""Decode the unsigned payload segment from the tunnel token."""
if not token:
return {}
segment = token.split(".")[0]
padding = "=" * (-len(segment) % 4)
raw = base64.urlsafe_b64decode(segment + padding)
try:
return json.loads(raw.decode("utf-8"))
except Exception:
return {}
@pytest.mark.parametrize("agent_id", ["test-device-agent"])
def test_tunnel_request_happy_path(engine_harness: EngineTestHarness, agent_id: str) -> None:
client = _client_with_admin_session(engine_harness)
resp = client.post(
"/api/tunnel/request",
json={"agent_id": agent_id, "protocol": "ps", "domain": "ps"},
)
assert resp.status_code == 200
payload = resp.get_json()
assert payload["agent_id"] == agent_id
assert payload["protocol"] == "ps"
assert payload["domain"] == "ps"
assert isinstance(payload["port"], int) and payload["port"] >= 30000
assert payload.get("token")
claims = _decode_token_segment(payload["token"])
assert claims.get("agent_id") == agent_id
assert claims.get("protocol") == "ps"
assert claims.get("domain") == "ps"
assert claims.get("tunnel_id") == payload["tunnel_id"]
assert claims.get("assigned_port") == payload["port"]
def test_tunnel_request_domain_limit(engine_harness: EngineTestHarness) -> None:
client = _client_with_admin_session(engine_harness)
first = client.post(
"/api/tunnel/request",
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
)
assert first.status_code == 200
second = client.post(
"/api/tunnel/request",
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
)
assert second.status_code == 409
data = second.get_json()
assert data.get("error") == "domain_limit"
def test_tunnel_request_includes_timeouts(engine_harness: EngineTestHarness) -> None:
client = _client_with_admin_session(engine_harness)
resp = client.post(
"/api/tunnel/request",
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
)
assert resp.status_code == 200
payload = resp.get_json()
assert payload.get("idle_seconds") and payload["idle_seconds"] > 0
assert payload.get("grace_seconds") and payload["grace_seconds"] > 0
assert payload.get("expires_at") and int(payload["expires_at"]) > 0

View File

@@ -0,0 +1,101 @@
# ======================================================
# Data\Engine\Unit_Tests\test_reverse_tunnel_integration.py
# Description: Integration test that exercises a full reverse tunnel PowerShell round-trip
# against a running Engine + Agent (requires live services).
#
# Requirements:
# - Environment variables must be set to point at a live Engine + Agent:
# TUNNEL_TEST_HOST (e.g., https://localhost:5000)
# TUNNEL_TEST_AGENT_ID (agent_id/agent_guid for the target device)
# TUNNEL_TEST_BEARER (Authorization bearer token for an admin/operator)
# - A live Agent must be reachable and allowed to establish the reverse tunnel.
# - TLS verification can be controlled via TUNNEL_TEST_VERIFY ("false" to disable).
#
# API Endpoints (if applicable):
# - POST /api/tunnel/request
# - Socket.IO namespace /tunnel (join, ps_open, ps_send, ps_poll)
# ======================================================
from __future__ import annotations
import os
import time
import pytest
import requests
import socketio
HOST = os.environ.get("TUNNEL_TEST_HOST", "").strip()
AGENT_ID = os.environ.get("TUNNEL_TEST_AGENT_ID", "").strip()
BEARER = os.environ.get("TUNNEL_TEST_BEARER", "").strip()
VERIFY_ENV = os.environ.get("TUNNEL_TEST_VERIFY", "").strip().lower()
VERIFY = False if VERIFY_ENV in {"false", "0", "no"} else True
SKIP_MSG = (
"Live tunnel test skipped (set TUNNEL_TEST_HOST, TUNNEL_TEST_AGENT_ID, TUNNEL_TEST_BEARER to run)"
)
def _require_env() -> None:
if not HOST or not AGENT_ID or not BEARER:
pytest.skip(SKIP_MSG)
def _make_session() -> requests.Session:
sess = requests.Session()
sess.verify = VERIFY
sess.headers.update({"Authorization": f"Bearer {BEARER}"})
return sess
def test_reverse_tunnel_powershell_roundtrip() -> None:
_require_env()
sess = _make_session()
# 1) Request a tunnel lease
resp = sess.post(
f"{HOST}/api/tunnel/request",
json={"agent_id": AGENT_ID, "protocol": "ps", "domain": "ps"},
)
assert resp.status_code == 200, f"lease request failed: {resp.status_code} {resp.text}"
lease = resp.json()
tunnel_id = lease["tunnel_id"]
# 2) Connect to Socket.IO /tunnel namespace
sio = socketio.Client()
sio.connect(
HOST,
namespaces=["/tunnel"],
headers={"Authorization": f"Bearer {BEARER}"},
transports=["websocket"],
wait_timeout=10,
)
# 3) Join tunnel and open PS channel
join_resp = sio.call("join", {"tunnel_id": tunnel_id}, namespace="/tunnel", timeout=10)
assert join_resp.get("status") == "ok", f"join failed: {join_resp}"
open_resp = sio.call("ps_open", {"cols": 120, "rows": 32}, namespace="/tunnel", timeout=10)
assert not open_resp.get("error"), f"ps_open failed: {open_resp}"
# 4) Send a command
send_resp = sio.call("ps_send", {"data": 'Write-Host "Hello World"\r\n'}, namespace="/tunnel", timeout=10)
assert not send_resp.get("error"), f"ps_send failed: {send_resp}"
# 5) Poll for output
output_text = ""
deadline = time.time() + 15
while time.time() < deadline:
poll_resp = sio.call("ps_poll", {}, namespace="/tunnel", timeout=10)
if poll_resp.get("error"):
pytest.fail(f"ps_poll failed: {poll_resp}")
lines = poll_resp.get("output") or []
output_text += "".join(lines)
if "Hello World" in output_text:
break
time.sleep(0.5)
sio.disconnect()
assert "Hello World" in output_text, f"expected command output not found; got: {output_text!r}"

View File

@@ -391,11 +391,13 @@ class TunnelLeaseManager:
def expire_idle(self, *, now_ts: Optional[float] = None) -> List[TunnelLease]:
now = now_ts or _utc_ts()
expired: List[TunnelLease] = []
pending_timeout = min(self.grace_timeout_seconds, 300) # avoid long-lived pending locks
for lease in list(self._leases.values()):
if lease.state == "expired":
continue
idle_age = now - lease.last_activity_ts
pending_age = now - lease.created_at
if lease.state == "active" and idle_age >= lease.idle_timeout_seconds:
lease.mark_expired()
expired.append(lease)
@@ -409,6 +411,14 @@ class TunnelLeaseManager:
expired.append(lease)
self.release(lease.tunnel_id, reason="grace_expired")
continue
if lease.state == "pending":
hard_expiry = lease.expires_at or (lease.created_at + lease.grace_timeout_seconds)
if pending_age >= pending_timeout or (hard_expiry and now >= hard_expiry):
lease.mark_expired()
expired.append(lease)
self.release(lease.tunnel_id, reason="pending_timeout")
continue
return expired
def all_leases(self) -> Iterable[TunnelLease]:
@@ -591,6 +601,7 @@ class ReverseTunnelService:
)
lease.token = self.issue_token(lease)
self._spawn_port_listener(lease.assigned_port)
self._push_start_to_agent(lease)
self.audit_logger.info(
"lease_created tunnel_id=%s agent_id=%s domain=%s protocol=%s port=%s operator=%s",
lease.tunnel_id,
@@ -618,6 +629,35 @@ class ReverseTunnelService:
lease.expires_at = expires_at
return token
def _push_start_to_agent(self, lease: TunnelLease) -> None:
"""Notify the target agent about the new lease over Socket.IO (best-effort)."""
if not self._socketio:
return
payload = {
"tunnel_id": lease.tunnel_id,
"lease_id": lease.tunnel_id,
"agent_id": lease.agent_id,
"token": lease.token,
"port": lease.assigned_port,
"assigned_port": lease.assigned_port,
"protocol": lease.protocol,
"domain": lease.domain,
"idle_seconds": lease.idle_timeout_seconds,
"grace_seconds": lease.grace_timeout_seconds,
"heartbeat_seconds": self.heartbeat_seconds,
}
try:
self._socketio.emit("reverse_tunnel_start", payload, namespace="/")
self.audit_logger.info(
"lease_push_start tunnel_id=%s agent_id=%s port=%s",
lease.tunnel_id,
lease.agent_id,
lease.assigned_port,
)
except Exception:
self.logger.debug("Failed to emit reverse_tunnel_start for tunnel_id=%s", lease.tunnel_id, exc_info=True)
def lease_summary(self, lease: TunnelLease) -> Dict[str, object]:
return {
"tunnel_id": lease.tunnel_id,
@@ -874,6 +914,7 @@ class ReverseTunnelService:
cert = self.context.tls_cert_path or self.context.tls_bundle_path
key = self.context.tls_key_path
if not cert or not key:
self.audit_logger.info("tunnel_listener_ssl_missing cert=%s key=%s", cert, key)
return None
try:
ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
@@ -890,6 +931,7 @@ class ReverseTunnelService:
if port in self._port_servers:
return
ssl_ctx = self._build_ssl_context()
self.audit_logger.info("tunnel_listener_start port=%s ssl=%s", port, bool(ssl_ctx))
async def _handler(websocket, path):
await self._handle_agent_socket(websocket, path, port=port)
@@ -897,6 +939,7 @@ class ReverseTunnelService:
async def _start():
server = await ws_serve(_handler, host="0.0.0.0", port=port, ssl=ssl_ctx, max_size=None, ping_interval=None)
self._port_servers[port] = server
self.audit_logger.info("tunnel_listener_bound port=%s", port)
asyncio.run_coroutine_threadsafe(_start(), self._loop)
@@ -904,15 +947,24 @@ class ReverseTunnelService:
"""Handle agent tunnel socket on assigned port."""
tunnel_id = None
sock_log = self.audit_logger.getChild("agent_socket")
try:
peer = None
try:
peer = getattr(websocket, "remote_address", None)
except Exception:
peer = None
sock_log.info("agent_socket_open port=%s path=%s peer=%s", port, path, peer)
raw = await asyncio.wait_for(websocket.recv(), timeout=10)
frame = decode_frame(raw)
if frame.msg_type != MSG_CONNECT:
sock_log.info("agent_socket_first_frame_not_connect port=%s msg_type=%s", port, frame.msg_type)
await websocket.close()
return
try:
payload = json.loads(frame.payload.decode("utf-8"))
except Exception:
sock_log.info("agent_socket_connect_payload_decode_failed port=%s", port, exc_info=True)
await websocket.close()
return
tunnel_id = str(payload.get("tunnel_id") or "").strip()
@@ -920,23 +972,56 @@ class ReverseTunnelService:
token = payload.get("token") or ""
lease = self.lease_manager.get(tunnel_id)
if lease is None or lease.assigned_port != port:
sock_log.info(
"agent_socket_unknown_lease port=%s tunnel_id=%s assigned=%s expected=%s",
port,
tunnel_id,
lease.assigned_port if lease else None,
port,
)
await websocket.close()
return
# Token validation
self.validate_token(
token,
agent_id=agent_id,
tunnel_id=tunnel_id,
domain=lease.domain,
protocol=lease.protocol,
)
try:
self.validate_token(
token,
agent_id=agent_id,
tunnel_id=tunnel_id,
domain=lease.domain,
protocol=lease.protocol,
)
sock_log.info(
"agent_socket_token_valid port=%s tunnel_id=%s agent_id=%s domain=%s protocol=%s",
port,
tunnel_id,
agent_id,
lease.domain,
lease.protocol,
)
except Exception as exc:
sock_log.info(
"agent_socket_token_invalid port=%s tunnel_id=%s agent_id=%s error=%s",
port,
tunnel_id,
agent_id,
exc,
)
await websocket.close()
return
bridge = self.ensure_bridge(lease)
bridge.attach_agent(token)
self._agent_sockets[tunnel_id] = websocket
await websocket.send(heartbeat_frame(channel_id=0, is_ack=True).encode())
await websocket.send(TunnelFrame(msg_type=MSG_CONNECT_ACK, channel_id=0, payload=b"").encode())
sock_log.info(
"agent_socket_connected port=%s tunnel_id=%s agent_id=%s",
port,
tunnel_id,
agent_id,
)
async def _pump_to_operator():
sock_log_local = sock_log.getChild("recv")
while not websocket.closed:
try:
raw_msg = await websocket.recv()
@@ -945,14 +1030,23 @@ class ReverseTunnelService:
try:
recv_frame = decode_frame(raw_msg)
except Exception:
sock_log_local.info("agent_socket_frame_decode_failed tunnel_id=%s", tunnel_id, exc_info=True)
continue
self.lease_manager.touch(tunnel_id)
sock_log_local.info(
"agent_to_operator tunnel_id=%s msg_type=%s channel=%s payload_len=%s",
tunnel_id,
recv_frame.msg_type,
recv_frame.channel_id,
len(recv_frame.payload or b""),
)
try:
self._dispatch_agent_frame(tunnel_id, recv_frame)
except Exception:
pass
bridge.agent_to_operator(recv_frame)
async def _pump_to_agent():
sock_log_local = sock_log.getChild("send")
while not websocket.closed:
frame = bridge.next_for_agent()
if frame is None:
@@ -960,13 +1054,22 @@ class ReverseTunnelService:
continue
try:
await websocket.send(frame.encode())
sock_log_local.info(
"operator_to_agent tunnel_id=%s msg_type=%s channel=%s payload_len=%s",
tunnel_id,
frame.msg_type,
frame.channel_id,
len(frame.payload or b""),
)
except Exception:
break
async def _heartbeat():
sock_log_local = sock_log.getChild("heartbeat")
while not websocket.closed:
try:
await websocket.send(heartbeat_frame(channel_id=0).encode())
except Exception:
sock_log_local.info("heartbeat_send_failed tunnel_id=%s", tunnel_id, exc_info=True)
break
await asyncio.sleep(self.heartbeat_seconds)
@@ -975,8 +1078,18 @@ class ReverseTunnelService:
heart = asyncio.create_task(_heartbeat())
await asyncio.wait([consumer, producer, heart], return_when=asyncio.FIRST_COMPLETED)
except Exception:
self.logger.debug("Agent socket handler failed on port %s", port, exc_info=True)
sock_log.info("agent_socket_handler_failed port=%s tunnel_id=%s", port, tunnel_id, exc_info=True)
finally:
try:
sock_log.info(
"agent_socket_closed port=%s tunnel_id=%s code=%s reason=%s",
port,
tunnel_id,
getattr(websocket, "close_code", None),
getattr(websocket, "close_reason", None),
)
except Exception:
pass
if tunnel_id and tunnel_id in self._agent_sockets:
self._agent_sockets.pop(tunnel_id, None)
if tunnel_id:
@@ -997,7 +1110,7 @@ class ReverseTunnelService:
if server:
return server
lease = self.lease_manager.get(tunnel_id)
if lease is None or (lease.domain or "").lower() != "ps":
if lease is None:
return None
bridge = self.ensure_bridge(lease)
server = PowershellChannelServer(

View File

@@ -474,7 +474,7 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
return {"error": "ps_resize_failed"}
@socket_server.on("ps_poll", namespace=tunnel_namespace)
def _ws_ps_poll() -> Any:
def _ws_ps_poll(data: Any = None) -> Any: # data is ignored; socketio passes it even when unused
server, tunnel_id, error = _require_ps_server()
if server is None:
return error

View File

@@ -31,6 +31,7 @@ import "prismjs/themes/prism-okaidia.css";
import Editor from "react-simple-code-editor";
import { AgGridReact } from "ag-grid-react";
import { ModuleRegistry, AllCommunityModule, themeQuartz } from "ag-grid-community";
import ReverseTunnelPowershell from "./ReverseTunnel/Powershell.jsx";
ModuleRegistry.registerModules([AllCommunityModule]);
@@ -63,7 +64,15 @@ const SECTION_HEIGHTS = {
network: 260,
};
const TOP_TABS = ["Device Summary", "Storage", "Memory", "Network", "Installed Software", "Activity History"];
const TOP_TABS = [
"Device Summary",
"Storage",
"Memory",
"Network",
"Installed Software",
"Activity History",
"Remote Shell",
];
const myTheme = themeQuartz.withParams({
accentColor: "#8b5cf6",
@@ -727,6 +736,17 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
);
const summary = details.summary || {};
const tunnelDevice = useMemo(
() => ({
...(device || {}),
...(agent || {}),
summary,
hostname: meta.hostname || summary.hostname || device?.hostname || agent?.hostname,
agent_id: meta.agentId || summary.agent_id || agent?.agent_id || agent?.id || device?.agent_id || device?.agent_guid,
agent_guid: meta.agentGuid || summary.agent_guid || device?.agent_guid || device?.guid || agent?.agent_guid || agent?.guid,
}),
[agent, device, meta.agentGuid, meta.agentId, meta.hostname, summary]
);
// Build a best-effort CPU display from summary fields
const cpuInfo = useMemo(() => {
const cpu = details.cpu || summary.cpu || {};
@@ -850,16 +870,20 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
[]
);
const formatScriptType = useCallback((raw) => {
const value = String(raw || "").toLowerCase();
if (value === "ansible") return "Ansible Playbook";
if (value === "reverse_tunnel") return "Reverse Tunnel";
return "Script";
}, []);
const historyColumnDefs = useMemo(
() => [
{
headerName: "Assembly",
headerName: "Activity",
field: "script_type",
minWidth: 180,
valueGetter: (params) =>
String(params.data?.script_type || "").toLowerCase() === "ansible"
? "Ansible Playbook"
: "Script",
valueGetter: (params) => formatScriptType(params.data?.script_type),
},
{
headerName: "Task",
@@ -891,7 +915,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
cellRenderer: "HistoryActionsCell",
},
],
[formatTimestamp]
[formatScriptType, formatTimestamp]
);
const MetricCard = ({ icon, title, main, sub, compact = false }) => (
@@ -1265,6 +1289,19 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
</Box>
);
const renderRemoteShellTab = () => (
<Box
sx={{
display: "flex",
flexDirection: "column",
flexGrow: 1,
minHeight: 0,
}}
>
<ReverseTunnelPowershell device={tunnelDevice} />
</Box>
);
const memoryRows = useMemo(
() =>
(details.memory || []).map((m, idx) => ({
@@ -1523,6 +1560,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
renderNetworkTab,
renderSoftware,
renderHistory,
renderRemoteShellTab,
];
const tabContent = (topTabRenderers[tab] || renderDeviceSummaryTab)();
@@ -1642,7 +1680,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
anchorEl={menuAnchor}
open={Boolean(menuAnchor)}
onClose={() => setMenuAnchor(null)}
PaperProps={{
PaperProps={{
sx: {
bgcolor: "rgba(8,12,24,0.96)",
color: "#fff",

View File

@@ -0,0 +1,704 @@
import React, { useCallback, useEffect, useMemo, useRef, useState } from "react";
import {
Box,
Typography,
Button,
Stack,
Chip,
TextField,
MenuItem,
IconButton,
Tooltip,
Alert,
LinearProgress,
} from "@mui/material";
import {
TerminalRounded as TerminalIcon,
PlayArrowRounded as PlayIcon,
StopRounded as StopIcon,
ContentCopy as CopyIcon,
RefreshRounded as RefreshIcon,
LanRounded as PortIcon,
SensorsRounded as ActivityIcon,
LinkRounded as LinkIcon,
} from "@mui/icons-material";
import { io } from "socket.io-client";
import Prism from "prismjs";
import "prismjs/components/prism-powershell";
import "prismjs/themes/prism-okaidia.css";
import Editor from "react-simple-code-editor";
const MAGIC_UI = {
panelBg: "rgba(7,11,24,0.92)",
panelBorder: "rgba(148, 163, 184, 0.35)",
textMuted: "#94a3b8",
textBright: "#e2e8f0",
accentA: "#7dd3fc",
accentB: "#c084fc",
accentC: "#34d399",
};
const gradientButtonSx = {
backgroundImage: "linear-gradient(135deg,#7dd3fc,#c084fc)",
color: "#0b1220",
borderRadius: 999,
textTransform: "none",
boxShadow: "0 10px 26px rgba(124,58,237,0.28)",
px: 2.2,
minWidth: 120,
"&:hover": {
backgroundImage: "linear-gradient(135deg,#86e1ff,#d1a6ff)",
boxShadow: "0 12px 34px rgba(124,58,237,0.38)",
},
};
const FRAME_HEADER_BYTES = 12; // version, msg_type, flags, reserved, channel_id(u32), length(u32)
const MSG_CLOSE = 0x08;
const CLOSE_AGENT_SHUTDOWN = 6;
const fontFamilyMono =
'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace';
function normalizeText(value) {
if (value == null) return "";
try {
return String(value).trim();
} catch {
return "";
}
}
function base64FromBytes(bytes) {
let binary = "";
bytes.forEach((b) => {
binary += String.fromCharCode(b);
});
return btoa(binary);
}
function buildCloseFrame(channelId = 1, code = CLOSE_AGENT_SHUTDOWN, reason = "operator_close") {
const payload = new TextEncoder().encode(JSON.stringify({ code, reason }));
const buffer = new ArrayBuffer(FRAME_HEADER_BYTES + payload.length);
const view = new DataView(buffer);
view.setUint8(0, 1); // version
view.setUint8(1, MSG_CLOSE);
view.setUint8(2, 0); // flags
view.setUint8(3, 0); // reserved
view.setUint32(4, channelId >>> 0, true);
view.setUint32(8, payload.length >>> 0, true);
new Uint8Array(buffer, FRAME_HEADER_BYTES).set(payload);
return base64FromBytes(new Uint8Array(buffer));
}
function highlightPs(code) {
try {
return Prism.highlight(code || "", Prism.languages.powershell, "powershell");
} catch {
return code || "";
}
}
export default function ReverseTunnelPowershell({ device }) {
const [connectionType, setConnectionType] = useState("ps");
const [tunnel, setTunnel] = useState(null);
const [sessionState, setSessionState] = useState("idle");
const [statusMessage, setStatusMessage] = useState("");
const [statusSeverity, setStatusSeverity] = useState("info");
const [output, setOutput] = useState("");
const [input, setInput] = useState("");
const [copyFlash, setCopyFlash] = useState(false);
const [, setPolling] = useState(false);
const [psStatus, setPsStatus] = useState({});
const socketRef = useRef(null);
const pollTimerRef = useRef(null);
const resizeTimerRef = useRef(null);
const terminalRef = useRef(null);
const joinRetryRef = useRef(null);
const joinAttemptsRef = useRef(0);
const hostname = useMemo(() => {
return (
normalizeText(device?.hostname) ||
normalizeText(device?.summary?.hostname) ||
normalizeText(device?.agent_hostname) ||
""
);
}, [device]);
const agentId = useMemo(() => {
return (
normalizeText(device?.agent_id) ||
normalizeText(device?.agentId) ||
normalizeText(device?.agent_guid) ||
normalizeText(device?.agentGuid) ||
normalizeText(device?.id) ||
normalizeText(device?.guid) ||
normalizeText(device?.summary?.agent_id) ||
""
);
}, [device]);
const resetState = useCallback(() => {
setTunnel(null);
setSessionState("idle");
setStatusMessage("");
setStatusSeverity("info");
setOutput("");
setInput("");
setPsStatus({});
}, []);
const disconnectSocket = useCallback(() => {
const socket = socketRef.current;
if (socket) {
socket.off();
socket.disconnect();
}
socketRef.current = null;
}, []);
const stopPolling = useCallback(() => {
if (pollTimerRef.current) {
clearTimeout(pollTimerRef.current);
pollTimerRef.current = null;
}
setPolling(false);
}, []);
useEffect(() => {
return () => {
stopPolling();
disconnectSocket();
if (joinRetryRef.current) {
clearTimeout(joinRetryRef.current);
joinRetryRef.current = null;
}
};
}, [disconnectSocket, stopPolling]);
const appendOutput = useCallback((text) => {
if (!text) return;
setOutput((prev) => {
const next = `${prev}${text}`;
const limit = 40000;
return next.length > limit ? next.slice(next.length - limit) : next;
});
}, []);
const scrollToBottom = useCallback(() => {
const el = terminalRef.current;
if (!el) return;
requestAnimationFrame(() => {
el.scrollTop = el.scrollHeight;
});
}, []);
useEffect(() => {
scrollToBottom();
}, [output, scrollToBottom]);
const handleCopy = async () => {
try {
await navigator.clipboard.writeText(output || "");
setCopyFlash(true);
setTimeout(() => setCopyFlash(false), 1200);
} catch {
setCopyFlash(false);
}
};
const measureTerminal = useCallback(() => {
const el = terminalRef.current;
if (!el) return { cols: 120, rows: 32 };
const width = el.clientWidth || 960;
const height = el.clientHeight || 460;
const charWidth = 8.2;
const charHeight = 18;
const cols = Math.max(20, Math.min(Math.floor(width / charWidth), 300));
const rows = Math.max(10, Math.min(Math.floor(height / charHeight), 200));
return { cols, rows };
}, []);
const emitAsync = useCallback((socket, event, payload, timeoutMs = 4000) => {
return new Promise((resolve) => {
let settled = false;
const timer = setTimeout(() => {
if (settled) return;
settled = true;
resolve({ error: "timeout" });
}, timeoutMs);
socket.emit(event, payload, (resp) => {
if (settled) return;
settled = true;
clearTimeout(timer);
resolve(resp || {});
});
});
}, []);
const pollLoop = useCallback(
(socket, tunnelId) => {
if (!socket || !tunnelId) return;
setPolling(true);
pollTimerRef.current = setTimeout(async () => {
const resp = await emitAsync(socket, "ps_poll", {});
if (resp?.error) {
if (resp.error === "ps_unsupported") {
setStatusSeverity("info");
setStatusMessage("PowerShell channel warming up...");
} else {
setStatusSeverity("warning");
setStatusMessage(resp.error);
}
}
if (Array.isArray(resp?.output) && resp.output.length) {
appendOutput(resp.output.join(""));
}
if (resp?.status) {
setPsStatus(resp.status);
if (resp.status.closed) {
setSessionState("closed");
setStatusSeverity("warning");
setStatusMessage(resp.status.close_reason || "Session closed");
stopPolling();
return;
}
if (resp.status.ack) {
setSessionState("connected");
}
}
pollLoop(socket, tunnelId);
}, 520);
},
[appendOutput, emitAsync, stopPolling]
);
const handleDisconnect = useCallback(() => {
const socket = socketRef.current;
const tunnelId = tunnel?.tunnel_id;
if (joinRetryRef.current) {
clearTimeout(joinRetryRef.current);
joinRetryRef.current = null;
}
joinAttemptsRef.current = 0;
if (socket && tunnelId) {
const frame = buildCloseFrame(1, CLOSE_AGENT_SHUTDOWN, "operator_close");
socket.emit("send", { frame });
}
stopPolling();
disconnectSocket();
setSessionState("closed");
setStatusSeverity("info");
setStatusMessage("Session closed by operator.");
}, [disconnectSocket, stopPolling, tunnel?.tunnel_id]);
const handleResize = useCallback(() => {
if (!socketRef.current || sessionState === "idle") return;
const dims = measureTerminal();
socketRef.current.emit("ps_resize", dims);
}, [measureTerminal, sessionState]);
useEffect(() => {
const observer =
typeof ResizeObserver !== "undefined"
? new ResizeObserver(() => {
if (resizeTimerRef.current) clearTimeout(resizeTimerRef.current);
resizeTimerRef.current = setTimeout(() => handleResize(), 200);
})
: null;
const el = terminalRef.current;
if (observer && el) observer.observe(el);
const onWinResize = () => handleResize();
window.addEventListener("resize", onWinResize);
return () => {
window.removeEventListener("resize", onWinResize);
if (observer && el) observer.unobserve(el);
};
}, [handleResize]);
const connectSocket = useCallback(
(lease, { isRetry = false } = {}) => {
if (!lease?.tunnel_id) return;
if (joinRetryRef.current) {
clearTimeout(joinRetryRef.current);
joinRetryRef.current = null;
}
if (!isRetry) {
joinAttemptsRef.current = 0;
}
disconnectSocket();
stopPolling();
setSessionState("waiting");
const socket = io(`${window.location.origin}/tunnel`, { transports: ["websocket"] });
socketRef.current = socket;
socket.on("connect_error", () => {
setStatusSeverity("warning");
setStatusMessage("Tunnel namespace unavailable.");
});
socket.on("disconnect", () => {
stopPolling();
if (sessionState !== "closed") {
setSessionState("disconnected");
setStatusSeverity("warning");
setStatusMessage("Socket disconnected.");
}
});
socket.on("connect", async () => {
setStatusSeverity("info");
setStatusMessage("Joining tunnel...");
const joinResp = await emitAsync(socket, "join", { tunnel_id: lease.tunnel_id });
if (joinResp?.error) {
if (joinResp.error === "unknown_tunnel") {
setSessionState("waiting_agent");
setStatusSeverity("info");
setStatusMessage("Waiting for agent to establish tunnel...");
joinAttemptsRef.current += 1;
const attempt = joinAttemptsRef.current;
if (attempt <= 15) {
joinRetryRef.current = setTimeout(() => connectSocket(lease, { isRetry: true }), 1000);
} else {
setSessionState("error");
setTunnel(null);
setStatusSeverity("warning");
setStatusMessage("Agent did not attach to tunnel (timeout). Try Connect again.");
}
} else {
setSessionState("error");
setStatusSeverity("error");
setStatusMessage(joinResp.error);
}
return;
}
const dims = measureTerminal();
const openResp = await emitAsync(socket, "ps_open", dims);
if (openResp?.error && openResp.error === "ps_unsupported") {
setStatusSeverity("info");
setStatusMessage("PowerShell channel warming up...");
}
appendOutput("");
setStatusMessage("Attached — waiting for agent to acknowledge...");
setSessionState("waiting_agent");
pollLoop(socket, lease.tunnel_id);
handleResize();
});
},
[appendOutput, disconnectSocket, emitAsync, handleResize, measureTerminal, pollLoop, sessionState, stopPolling]
);
const requestTunnel = useCallback(async () => {
if (tunnel && sessionState !== "closed" && sessionState !== "idle") {
setStatusSeverity("info");
setStatusMessage("Re-attaching to existing tunnel...");
connectSocket(tunnel);
return;
}
if (!agentId) {
setStatusSeverity("warning");
setStatusMessage("Agent ID is required to request a tunnel.");
return;
}
if (connectionType !== "ps") {
setStatusSeverity("warning");
setStatusMessage("Only PowerShell is supported right now.");
return;
}
resetState();
setSessionState("requesting");
setStatusSeverity("info");
setStatusMessage("Requesting tunnel lease...");
try {
const resp = await fetch("/api/tunnel/request", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ agent_id: agentId, protocol: "ps", domain: "ps" }),
});
const data = await resp.json().catch(() => ({}));
if (!resp.ok) {
const err = data?.error || `HTTP ${resp.status}`;
setSessionState("error");
setStatusSeverity(err === "domain_limit" ? "warning" : "error");
setStatusMessage(
err === "domain_limit"
? "PowerShell session already active for this agent. Try again after it closes."
: err
);
return;
}
setTunnel(data);
setStatusMessage("Lease issued. Waiting for agent to connect...");
setSessionState("lease_issued");
connectSocket(data);
} catch (e) {
setSessionState("error");
setStatusSeverity("error");
setStatusMessage(e?.message || "Failed to request tunnel");
}
}, [agentId, connectSocket, connectionType, resetState]);
const handleSend = useCallback(
async (text) => {
const socket = socketRef.current;
if (!socket) return;
const payload = `${text}${text.endsWith("\n") ? "" : "\r\n"}`;
appendOutput(`\nPS> ${text}\n`);
setInput("");
const resp = await emitAsync(socket, "ps_send", { data: payload });
if (resp?.error) {
setStatusSeverity("warning");
setStatusMessage(resp.error);
}
},
[appendOutput, emitAsync]
);
const isConnected = sessionState === "connected" || psStatus?.ack;
const isBusy =
sessionState === "requesting" ||
sessionState === "waiting" ||
sessionState === "waiting_agent" ||
sessionState === "lease_issued";
const canStart = Boolean(agentId) && !isBusy && !isConnected;
const sessionChips = [
{
label: isConnected ? "Connected" : sessionState === "idle" ? "Idle" : sessionState.replace(/_/g, " "),
color: isConnected ? MAGIC_UI.accentC : MAGIC_UI.accentA,
icon: <ActivityIcon sx={{ fontSize: 18 }} />,
},
tunnel?.tunnel_id
? {
label: `Tunnel ${tunnel.tunnel_id.slice(0, 8)}`,
color: MAGIC_UI.accentB,
icon: <LinkIcon sx={{ fontSize: 18 }} />,
}
: null,
tunnel?.port
? {
label: `Port ${tunnel.port}`,
color: MAGIC_UI.accentA,
icon: <PortIcon sx={{ fontSize: 18 }} />,
}
: null,
].filter(Boolean);
return (
<Box sx={{ display: "flex", flexDirection: "column", gap: 1.5, flexGrow: 1, minHeight: 0 }}>
<Box
sx={{
border: `1px solid ${MAGIC_UI.panelBorder}`,
borderRadius: 2,
background: "linear-gradient(120deg, rgba(12,18,35,0.9), rgba(8,12,24,0.92))",
boxShadow: "0 18px 45px rgba(2,6,23,0.6)",
p: { xs: 1.5, md: 2 },
}}
>
<Stack
direction={{ xs: "column", md: "row" }}
spacing={1.5}
alignItems={{ xs: "flex-start", md: "center" }}
justifyContent="space-between"
>
<Stack direction="row" spacing={1} alignItems="center" flexWrap="wrap">
<TerminalIcon sx={{ fontSize: 22, color: MAGIC_UI.accentA }} />
<Typography variant="h6" sx={{ fontWeight: 700, letterSpacing: 0.3 }}>
Remote Shell
</Typography>
{hostname ? (
<Chip
size="small"
label={hostname}
sx={{
background: "rgba(12,18,35,0.8)",
color: MAGIC_UI.textBright,
border: `1px solid ${MAGIC_UI.panelBorder}`,
}}
/>
) : null}
{agentId ? (
<Chip
size="small"
label={`Agent ${agentId}`}
sx={{
background: "rgba(12,18,35,0.8)",
color: MAGIC_UI.textMuted,
border: `1px solid ${MAGIC_UI.panelBorder}`,
}}
/>
) : null}
</Stack>
<Stack direction={{ xs: "column", sm: "row" }} spacing={1} alignItems="center">
<TextField
select
label="Connection Type"
size="small"
value={connectionType}
onChange={(e) => setConnectionType(e.target.value)}
sx={{
minWidth: 180,
"& .MuiInputBase-root": {
backgroundColor: "rgba(12,18,35,0.85)",
color: MAGIC_UI.textBright,
borderRadius: 1.5,
},
"& fieldset": { borderColor: MAGIC_UI.panelBorder },
"&:hover fieldset": { borderColor: MAGIC_UI.accentA },
}}
>
<MenuItem value="ps">PowerShell</MenuItem>
</TextField>
<Tooltip title={isConnected ? "Disconnect session" : "Connect to agent"}>
<span>
<Button
size="small"
startIcon={isConnected ? <StopIcon /> : <PlayIcon />}
sx={gradientButtonSx}
disabled={!isConnected && !canStart}
onClick={isConnected ? handleDisconnect : requestTunnel}
>
{isConnected ? "Disconnect" : "Connect"}
</Button>
</span>
</Tooltip>
</Stack>
</Stack>
<Stack direction={{ xs: "column", md: "row" }} spacing={1} sx={{ mt: 1, flexWrap: "wrap" }}>
{sessionChips.map((chip) => (
<Chip
key={chip.label}
icon={chip.icon}
label={chip.label}
sx={{
background: "rgba(12,18,35,0.85)",
color: chip.color,
border: `1px solid ${chip.color}44`,
fontWeight: 600,
}}
/>
))}
</Stack>
</Box>
{statusMessage ? (
<Alert
severity={statusSeverity}
sx={{
borderRadius: 2,
backgroundColor: "rgba(8,12,24,0.9)",
border: `1px solid ${MAGIC_UI.panelBorder}`,
color: MAGIC_UI.textBright,
}}
>
{statusMessage}
</Alert>
) : null}
<Box
sx={{
flexGrow: 1,
minHeight: 320,
display: "flex",
flexDirection: "column",
borderRadius: 3,
border: `1px solid ${MAGIC_UI.panelBorder}`,
background:
"linear-gradient(145deg, rgba(8,12,24,0.94), rgba(10,16,30,0.9)), radial-gradient(circle at 20% 20%, rgba(125,211,252,0.08), transparent 35%)",
boxShadow: "0 25px 80px rgba(2,6,23,0.85)",
overflow: "hidden",
}}
>
{isBusy ? <LinearProgress color="info" sx={{ height: 3 }} /> : null}
<Box
ref={terminalRef}
sx={{
flexGrow: 1,
minHeight: 240,
maxHeight: "100%",
overflow: "auto",
position: "relative",
p: 2,
"& pre": {
margin: 0,
fontFamily: fontFamilyMono,
fontSize: 13,
lineHeight: 1.5,
color: "#e6edf3",
},
}}
>
<Editor
value={output}
onValueChange={() => {}}
highlight={highlightPs}
padding={12}
readOnly
style={{
minHeight: "100%",
background: "transparent",
color: "#e6edf3",
fontFamily: fontFamilyMono,
fontSize: 13,
}}
/>
<Box sx={{ position: "absolute", top: 8, right: 8, display: "flex", gap: 0.5 }}>
<Tooltip title="Copy output">
<IconButton size="small" onClick={handleCopy} sx={{ color: copyFlash ? MAGIC_UI.accentC : MAGIC_UI.textMuted }}>
<CopyIcon fontSize="small" />
</IconButton>
</Tooltip>
<Tooltip title="Clear output">
<IconButton
size="small"
onClick={() => setOutput("")}
sx={{ color: MAGIC_UI.textMuted }}
>
<RefreshIcon fontSize="small" />
</IconButton>
</Tooltip>
</Box>
</Box>
<Box
sx={{
borderTop: `1px solid ${MAGIC_UI.panelBorder}`,
p: 1.5,
background: "rgba(6,10,20,0.92)",
}}
>
<TextField
fullWidth
size="small"
value={input}
disabled={!isConnected}
placeholder={
isConnected
? "Enter PowerShell command and press Enter"
: "Connect to start sending commands"
}
onChange={(e) => setInput(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
const text = input.trim();
if (text) handleSend(text);
}
}}
InputProps={{
sx: {
backgroundColor: "rgba(12,18,35,0.9)",
color: "#e2e8f0",
borderRadius: 2,
"& fieldset": { borderColor: "rgba(148,163,184,0.45)" },
"&:hover fieldset": { borderColor: MAGIC_UI.accentA },
},
}}
/>
</Box>
</Box>
</Box>
);
}