Major Progress Towards Interactive Remote Powershell

This commit is contained in:
2025-12-06 00:27:57 -07:00
parent 52e40c3753
commit 68dd46347b
9 changed files with 1247 additions and 53 deletions

View File

@@ -0,0 +1,90 @@
# ======================================================
# Data\Engine\Unit_Tests\test_reverse_tunnel.py
# Description: Validates reverse tunnel lease API basics (allocation, token contents, and domain limit).
#
# API Endpoints (if applicable):
# - POST /api/tunnel/request
# ======================================================
from __future__ import annotations
import base64
import json
import pytest
from .conftest import EngineTestHarness
def _client_with_admin_session(harness: EngineTestHarness):
client = harness.app.test_client()
with client.session_transaction() as sess:
sess["username"] = "admin"
sess["role"] = "Admin"
return client
def _decode_token_segment(token: str) -> dict:
"""Decode the unsigned payload segment from the tunnel token."""
if not token:
return {}
segment = token.split(".")[0]
padding = "=" * (-len(segment) % 4)
raw = base64.urlsafe_b64decode(segment + padding)
try:
return json.loads(raw.decode("utf-8"))
except Exception:
return {}
@pytest.mark.parametrize("agent_id", ["test-device-agent"])
def test_tunnel_request_happy_path(engine_harness: EngineTestHarness, agent_id: str) -> None:
client = _client_with_admin_session(engine_harness)
resp = client.post(
"/api/tunnel/request",
json={"agent_id": agent_id, "protocol": "ps", "domain": "ps"},
)
assert resp.status_code == 200
payload = resp.get_json()
assert payload["agent_id"] == agent_id
assert payload["protocol"] == "ps"
assert payload["domain"] == "ps"
assert isinstance(payload["port"], int) and payload["port"] >= 30000
assert payload.get("token")
claims = _decode_token_segment(payload["token"])
assert claims.get("agent_id") == agent_id
assert claims.get("protocol") == "ps"
assert claims.get("domain") == "ps"
assert claims.get("tunnel_id") == payload["tunnel_id"]
assert claims.get("assigned_port") == payload["port"]
def test_tunnel_request_domain_limit(engine_harness: EngineTestHarness) -> None:
client = _client_with_admin_session(engine_harness)
first = client.post(
"/api/tunnel/request",
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
)
assert first.status_code == 200
second = client.post(
"/api/tunnel/request",
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
)
assert second.status_code == 409
data = second.get_json()
assert data.get("error") == "domain_limit"
def test_tunnel_request_includes_timeouts(engine_harness: EngineTestHarness) -> None:
client = _client_with_admin_session(engine_harness)
resp = client.post(
"/api/tunnel/request",
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
)
assert resp.status_code == 200
payload = resp.get_json()
assert payload.get("idle_seconds") and payload["idle_seconds"] > 0
assert payload.get("grace_seconds") and payload["grace_seconds"] > 0
assert payload.get("expires_at") and int(payload["expires_at"]) > 0

View File

@@ -0,0 +1,101 @@
# ======================================================
# Data\Engine\Unit_Tests\test_reverse_tunnel_integration.py
# Description: Integration test that exercises a full reverse tunnel PowerShell round-trip
# against a running Engine + Agent (requires live services).
#
# Requirements:
# - Environment variables must be set to point at a live Engine + Agent:
# TUNNEL_TEST_HOST (e.g., https://localhost:5000)
# TUNNEL_TEST_AGENT_ID (agent_id/agent_guid for the target device)
# TUNNEL_TEST_BEARER (Authorization bearer token for an admin/operator)
# - A live Agent must be reachable and allowed to establish the reverse tunnel.
# - TLS verification can be controlled via TUNNEL_TEST_VERIFY ("false" to disable).
#
# API Endpoints (if applicable):
# - POST /api/tunnel/request
# - Socket.IO namespace /tunnel (join, ps_open, ps_send, ps_poll)
# ======================================================
from __future__ import annotations
import os
import time
import pytest
import requests
import socketio
HOST = os.environ.get("TUNNEL_TEST_HOST", "").strip()
AGENT_ID = os.environ.get("TUNNEL_TEST_AGENT_ID", "").strip()
BEARER = os.environ.get("TUNNEL_TEST_BEARER", "").strip()
VERIFY_ENV = os.environ.get("TUNNEL_TEST_VERIFY", "").strip().lower()
VERIFY = False if VERIFY_ENV in {"false", "0", "no"} else True
SKIP_MSG = (
"Live tunnel test skipped (set TUNNEL_TEST_HOST, TUNNEL_TEST_AGENT_ID, TUNNEL_TEST_BEARER to run)"
)
def _require_env() -> None:
if not HOST or not AGENT_ID or not BEARER:
pytest.skip(SKIP_MSG)
def _make_session() -> requests.Session:
sess = requests.Session()
sess.verify = VERIFY
sess.headers.update({"Authorization": f"Bearer {BEARER}"})
return sess
def test_reverse_tunnel_powershell_roundtrip() -> None:
_require_env()
sess = _make_session()
# 1) Request a tunnel lease
resp = sess.post(
f"{HOST}/api/tunnel/request",
json={"agent_id": AGENT_ID, "protocol": "ps", "domain": "ps"},
)
assert resp.status_code == 200, f"lease request failed: {resp.status_code} {resp.text}"
lease = resp.json()
tunnel_id = lease["tunnel_id"]
# 2) Connect to Socket.IO /tunnel namespace
sio = socketio.Client()
sio.connect(
HOST,
namespaces=["/tunnel"],
headers={"Authorization": f"Bearer {BEARER}"},
transports=["websocket"],
wait_timeout=10,
)
# 3) Join tunnel and open PS channel
join_resp = sio.call("join", {"tunnel_id": tunnel_id}, namespace="/tunnel", timeout=10)
assert join_resp.get("status") == "ok", f"join failed: {join_resp}"
open_resp = sio.call("ps_open", {"cols": 120, "rows": 32}, namespace="/tunnel", timeout=10)
assert not open_resp.get("error"), f"ps_open failed: {open_resp}"
# 4) Send a command
send_resp = sio.call("ps_send", {"data": 'Write-Host "Hello World"\r\n'}, namespace="/tunnel", timeout=10)
assert not send_resp.get("error"), f"ps_send failed: {send_resp}"
# 5) Poll for output
output_text = ""
deadline = time.time() + 15
while time.time() < deadline:
poll_resp = sio.call("ps_poll", {}, namespace="/tunnel", timeout=10)
if poll_resp.get("error"):
pytest.fail(f"ps_poll failed: {poll_resp}")
lines = poll_resp.get("output") or []
output_text += "".join(lines)
if "Hello World" in output_text:
break
time.sleep(0.5)
sio.disconnect()
assert "Hello World" in output_text, f"expected command output not found; got: {output_text!r}"

View File

@@ -391,11 +391,13 @@ class TunnelLeaseManager:
def expire_idle(self, *, now_ts: Optional[float] = None) -> List[TunnelLease]:
now = now_ts or _utc_ts()
expired: List[TunnelLease] = []
pending_timeout = min(self.grace_timeout_seconds, 300) # avoid long-lived pending locks
for lease in list(self._leases.values()):
if lease.state == "expired":
continue
idle_age = now - lease.last_activity_ts
pending_age = now - lease.created_at
if lease.state == "active" and idle_age >= lease.idle_timeout_seconds:
lease.mark_expired()
expired.append(lease)
@@ -409,6 +411,14 @@ class TunnelLeaseManager:
expired.append(lease)
self.release(lease.tunnel_id, reason="grace_expired")
continue
if lease.state == "pending":
hard_expiry = lease.expires_at or (lease.created_at + lease.grace_timeout_seconds)
if pending_age >= pending_timeout or (hard_expiry and now >= hard_expiry):
lease.mark_expired()
expired.append(lease)
self.release(lease.tunnel_id, reason="pending_timeout")
continue
return expired
def all_leases(self) -> Iterable[TunnelLease]:
@@ -591,6 +601,7 @@ class ReverseTunnelService:
)
lease.token = self.issue_token(lease)
self._spawn_port_listener(lease.assigned_port)
self._push_start_to_agent(lease)
self.audit_logger.info(
"lease_created tunnel_id=%s agent_id=%s domain=%s protocol=%s port=%s operator=%s",
lease.tunnel_id,
@@ -618,6 +629,35 @@ class ReverseTunnelService:
lease.expires_at = expires_at
return token
def _push_start_to_agent(self, lease: TunnelLease) -> None:
"""Notify the target agent about the new lease over Socket.IO (best-effort)."""
if not self._socketio:
return
payload = {
"tunnel_id": lease.tunnel_id,
"lease_id": lease.tunnel_id,
"agent_id": lease.agent_id,
"token": lease.token,
"port": lease.assigned_port,
"assigned_port": lease.assigned_port,
"protocol": lease.protocol,
"domain": lease.domain,
"idle_seconds": lease.idle_timeout_seconds,
"grace_seconds": lease.grace_timeout_seconds,
"heartbeat_seconds": self.heartbeat_seconds,
}
try:
self._socketio.emit("reverse_tunnel_start", payload, namespace="/")
self.audit_logger.info(
"lease_push_start tunnel_id=%s agent_id=%s port=%s",
lease.tunnel_id,
lease.agent_id,
lease.assigned_port,
)
except Exception:
self.logger.debug("Failed to emit reverse_tunnel_start for tunnel_id=%s", lease.tunnel_id, exc_info=True)
def lease_summary(self, lease: TunnelLease) -> Dict[str, object]:
return {
"tunnel_id": lease.tunnel_id,
@@ -874,6 +914,7 @@ class ReverseTunnelService:
cert = self.context.tls_cert_path or self.context.tls_bundle_path
key = self.context.tls_key_path
if not cert or not key:
self.audit_logger.info("tunnel_listener_ssl_missing cert=%s key=%s", cert, key)
return None
try:
ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
@@ -890,6 +931,7 @@ class ReverseTunnelService:
if port in self._port_servers:
return
ssl_ctx = self._build_ssl_context()
self.audit_logger.info("tunnel_listener_start port=%s ssl=%s", port, bool(ssl_ctx))
async def _handler(websocket, path):
await self._handle_agent_socket(websocket, path, port=port)
@@ -897,6 +939,7 @@ class ReverseTunnelService:
async def _start():
server = await ws_serve(_handler, host="0.0.0.0", port=port, ssl=ssl_ctx, max_size=None, ping_interval=None)
self._port_servers[port] = server
self.audit_logger.info("tunnel_listener_bound port=%s", port)
asyncio.run_coroutine_threadsafe(_start(), self._loop)
@@ -904,15 +947,24 @@ class ReverseTunnelService:
"""Handle agent tunnel socket on assigned port."""
tunnel_id = None
sock_log = self.audit_logger.getChild("agent_socket")
try:
peer = None
try:
peer = getattr(websocket, "remote_address", None)
except Exception:
peer = None
sock_log.info("agent_socket_open port=%s path=%s peer=%s", port, path, peer)
raw = await asyncio.wait_for(websocket.recv(), timeout=10)
frame = decode_frame(raw)
if frame.msg_type != MSG_CONNECT:
sock_log.info("agent_socket_first_frame_not_connect port=%s msg_type=%s", port, frame.msg_type)
await websocket.close()
return
try:
payload = json.loads(frame.payload.decode("utf-8"))
except Exception:
sock_log.info("agent_socket_connect_payload_decode_failed port=%s", port, exc_info=True)
await websocket.close()
return
tunnel_id = str(payload.get("tunnel_id") or "").strip()
@@ -920,23 +972,56 @@ class ReverseTunnelService:
token = payload.get("token") or ""
lease = self.lease_manager.get(tunnel_id)
if lease is None or lease.assigned_port != port:
sock_log.info(
"agent_socket_unknown_lease port=%s tunnel_id=%s assigned=%s expected=%s",
port,
tunnel_id,
lease.assigned_port if lease else None,
port,
)
await websocket.close()
return
# Token validation
self.validate_token(
token,
agent_id=agent_id,
tunnel_id=tunnel_id,
domain=lease.domain,
protocol=lease.protocol,
)
try:
self.validate_token(
token,
agent_id=agent_id,
tunnel_id=tunnel_id,
domain=lease.domain,
protocol=lease.protocol,
)
sock_log.info(
"agent_socket_token_valid port=%s tunnel_id=%s agent_id=%s domain=%s protocol=%s",
port,
tunnel_id,
agent_id,
lease.domain,
lease.protocol,
)
except Exception as exc:
sock_log.info(
"agent_socket_token_invalid port=%s tunnel_id=%s agent_id=%s error=%s",
port,
tunnel_id,
agent_id,
exc,
)
await websocket.close()
return
bridge = self.ensure_bridge(lease)
bridge.attach_agent(token)
self._agent_sockets[tunnel_id] = websocket
await websocket.send(heartbeat_frame(channel_id=0, is_ack=True).encode())
await websocket.send(TunnelFrame(msg_type=MSG_CONNECT_ACK, channel_id=0, payload=b"").encode())
sock_log.info(
"agent_socket_connected port=%s tunnel_id=%s agent_id=%s",
port,
tunnel_id,
agent_id,
)
async def _pump_to_operator():
sock_log_local = sock_log.getChild("recv")
while not websocket.closed:
try:
raw_msg = await websocket.recv()
@@ -945,14 +1030,23 @@ class ReverseTunnelService:
try:
recv_frame = decode_frame(raw_msg)
except Exception:
sock_log_local.info("agent_socket_frame_decode_failed tunnel_id=%s", tunnel_id, exc_info=True)
continue
self.lease_manager.touch(tunnel_id)
sock_log_local.info(
"agent_to_operator tunnel_id=%s msg_type=%s channel=%s payload_len=%s",
tunnel_id,
recv_frame.msg_type,
recv_frame.channel_id,
len(recv_frame.payload or b""),
)
try:
self._dispatch_agent_frame(tunnel_id, recv_frame)
except Exception:
pass
bridge.agent_to_operator(recv_frame)
async def _pump_to_agent():
sock_log_local = sock_log.getChild("send")
while not websocket.closed:
frame = bridge.next_for_agent()
if frame is None:
@@ -960,13 +1054,22 @@ class ReverseTunnelService:
continue
try:
await websocket.send(frame.encode())
sock_log_local.info(
"operator_to_agent tunnel_id=%s msg_type=%s channel=%s payload_len=%s",
tunnel_id,
frame.msg_type,
frame.channel_id,
len(frame.payload or b""),
)
except Exception:
break
async def _heartbeat():
sock_log_local = sock_log.getChild("heartbeat")
while not websocket.closed:
try:
await websocket.send(heartbeat_frame(channel_id=0).encode())
except Exception:
sock_log_local.info("heartbeat_send_failed tunnel_id=%s", tunnel_id, exc_info=True)
break
await asyncio.sleep(self.heartbeat_seconds)
@@ -975,8 +1078,18 @@ class ReverseTunnelService:
heart = asyncio.create_task(_heartbeat())
await asyncio.wait([consumer, producer, heart], return_when=asyncio.FIRST_COMPLETED)
except Exception:
self.logger.debug("Agent socket handler failed on port %s", port, exc_info=True)
sock_log.info("agent_socket_handler_failed port=%s tunnel_id=%s", port, tunnel_id, exc_info=True)
finally:
try:
sock_log.info(
"agent_socket_closed port=%s tunnel_id=%s code=%s reason=%s",
port,
tunnel_id,
getattr(websocket, "close_code", None),
getattr(websocket, "close_reason", None),
)
except Exception:
pass
if tunnel_id and tunnel_id in self._agent_sockets:
self._agent_sockets.pop(tunnel_id, None)
if tunnel_id:
@@ -997,7 +1110,7 @@ class ReverseTunnelService:
if server:
return server
lease = self.lease_manager.get(tunnel_id)
if lease is None or (lease.domain or "").lower() != "ps":
if lease is None:
return None
bridge = self.ensure_bridge(lease)
server = PowershellChannelServer(

View File

@@ -474,7 +474,7 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
return {"error": "ps_resize_failed"}
@socket_server.on("ps_poll", namespace=tunnel_namespace)
def _ws_ps_poll() -> Any:
def _ws_ps_poll(data: Any = None) -> Any: # data is ignored; socketio passes it even when unused
server, tunnel_id, error = _require_ps_server()
if server is None:
return error

View File

@@ -31,6 +31,7 @@ import "prismjs/themes/prism-okaidia.css";
import Editor from "react-simple-code-editor";
import { AgGridReact } from "ag-grid-react";
import { ModuleRegistry, AllCommunityModule, themeQuartz } from "ag-grid-community";
import ReverseTunnelPowershell from "./ReverseTunnel/Powershell.jsx";
ModuleRegistry.registerModules([AllCommunityModule]);
@@ -63,7 +64,15 @@ const SECTION_HEIGHTS = {
network: 260,
};
const TOP_TABS = ["Device Summary", "Storage", "Memory", "Network", "Installed Software", "Activity History"];
const TOP_TABS = [
"Device Summary",
"Storage",
"Memory",
"Network",
"Installed Software",
"Activity History",
"Remote Shell",
];
const myTheme = themeQuartz.withParams({
accentColor: "#8b5cf6",
@@ -727,6 +736,17 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
);
const summary = details.summary || {};
const tunnelDevice = useMemo(
() => ({
...(device || {}),
...(agent || {}),
summary,
hostname: meta.hostname || summary.hostname || device?.hostname || agent?.hostname,
agent_id: meta.agentId || summary.agent_id || agent?.agent_id || agent?.id || device?.agent_id || device?.agent_guid,
agent_guid: meta.agentGuid || summary.agent_guid || device?.agent_guid || device?.guid || agent?.agent_guid || agent?.guid,
}),
[agent, device, meta.agentGuid, meta.agentId, meta.hostname, summary]
);
// Build a best-effort CPU display from summary fields
const cpuInfo = useMemo(() => {
const cpu = details.cpu || summary.cpu || {};
@@ -850,16 +870,20 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
[]
);
const formatScriptType = useCallback((raw) => {
const value = String(raw || "").toLowerCase();
if (value === "ansible") return "Ansible Playbook";
if (value === "reverse_tunnel") return "Reverse Tunnel";
return "Script";
}, []);
const historyColumnDefs = useMemo(
() => [
{
headerName: "Assembly",
headerName: "Activity",
field: "script_type",
minWidth: 180,
valueGetter: (params) =>
String(params.data?.script_type || "").toLowerCase() === "ansible"
? "Ansible Playbook"
: "Script",
valueGetter: (params) => formatScriptType(params.data?.script_type),
},
{
headerName: "Task",
@@ -891,7 +915,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
cellRenderer: "HistoryActionsCell",
},
],
[formatTimestamp]
[formatScriptType, formatTimestamp]
);
const MetricCard = ({ icon, title, main, sub, compact = false }) => (
@@ -1265,6 +1289,19 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
</Box>
);
const renderRemoteShellTab = () => (
<Box
sx={{
display: "flex",
flexDirection: "column",
flexGrow: 1,
minHeight: 0,
}}
>
<ReverseTunnelPowershell device={tunnelDevice} />
</Box>
);
const memoryRows = useMemo(
() =>
(details.memory || []).map((m, idx) => ({
@@ -1523,6 +1560,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
renderNetworkTab,
renderSoftware,
renderHistory,
renderRemoteShellTab,
];
const tabContent = (topTabRenderers[tab] || renderDeviceSummaryTab)();
@@ -1642,7 +1680,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
anchorEl={menuAnchor}
open={Boolean(menuAnchor)}
onClose={() => setMenuAnchor(null)}
PaperProps={{
PaperProps={{
sx: {
bgcolor: "rgba(8,12,24,0.96)",
color: "#fff",

View File

@@ -0,0 +1,704 @@
import React, { useCallback, useEffect, useMemo, useRef, useState } from "react";
import {
Box,
Typography,
Button,
Stack,
Chip,
TextField,
MenuItem,
IconButton,
Tooltip,
Alert,
LinearProgress,
} from "@mui/material";
import {
TerminalRounded as TerminalIcon,
PlayArrowRounded as PlayIcon,
StopRounded as StopIcon,
ContentCopy as CopyIcon,
RefreshRounded as RefreshIcon,
LanRounded as PortIcon,
SensorsRounded as ActivityIcon,
LinkRounded as LinkIcon,
} from "@mui/icons-material";
import { io } from "socket.io-client";
import Prism from "prismjs";
import "prismjs/components/prism-powershell";
import "prismjs/themes/prism-okaidia.css";
import Editor from "react-simple-code-editor";
const MAGIC_UI = {
panelBg: "rgba(7,11,24,0.92)",
panelBorder: "rgba(148, 163, 184, 0.35)",
textMuted: "#94a3b8",
textBright: "#e2e8f0",
accentA: "#7dd3fc",
accentB: "#c084fc",
accentC: "#34d399",
};
const gradientButtonSx = {
backgroundImage: "linear-gradient(135deg,#7dd3fc,#c084fc)",
color: "#0b1220",
borderRadius: 999,
textTransform: "none",
boxShadow: "0 10px 26px rgba(124,58,237,0.28)",
px: 2.2,
minWidth: 120,
"&:hover": {
backgroundImage: "linear-gradient(135deg,#86e1ff,#d1a6ff)",
boxShadow: "0 12px 34px rgba(124,58,237,0.38)",
},
};
const FRAME_HEADER_BYTES = 12; // version, msg_type, flags, reserved, channel_id(u32), length(u32)
const MSG_CLOSE = 0x08;
const CLOSE_AGENT_SHUTDOWN = 6;
const fontFamilyMono =
'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace';
function normalizeText(value) {
if (value == null) return "";
try {
return String(value).trim();
} catch {
return "";
}
}
function base64FromBytes(bytes) {
let binary = "";
bytes.forEach((b) => {
binary += String.fromCharCode(b);
});
return btoa(binary);
}
function buildCloseFrame(channelId = 1, code = CLOSE_AGENT_SHUTDOWN, reason = "operator_close") {
const payload = new TextEncoder().encode(JSON.stringify({ code, reason }));
const buffer = new ArrayBuffer(FRAME_HEADER_BYTES + payload.length);
const view = new DataView(buffer);
view.setUint8(0, 1); // version
view.setUint8(1, MSG_CLOSE);
view.setUint8(2, 0); // flags
view.setUint8(3, 0); // reserved
view.setUint32(4, channelId >>> 0, true);
view.setUint32(8, payload.length >>> 0, true);
new Uint8Array(buffer, FRAME_HEADER_BYTES).set(payload);
return base64FromBytes(new Uint8Array(buffer));
}
function highlightPs(code) {
try {
return Prism.highlight(code || "", Prism.languages.powershell, "powershell");
} catch {
return code || "";
}
}
export default function ReverseTunnelPowershell({ device }) {
const [connectionType, setConnectionType] = useState("ps");
const [tunnel, setTunnel] = useState(null);
const [sessionState, setSessionState] = useState("idle");
const [statusMessage, setStatusMessage] = useState("");
const [statusSeverity, setStatusSeverity] = useState("info");
const [output, setOutput] = useState("");
const [input, setInput] = useState("");
const [copyFlash, setCopyFlash] = useState(false);
const [, setPolling] = useState(false);
const [psStatus, setPsStatus] = useState({});
const socketRef = useRef(null);
const pollTimerRef = useRef(null);
const resizeTimerRef = useRef(null);
const terminalRef = useRef(null);
const joinRetryRef = useRef(null);
const joinAttemptsRef = useRef(0);
const hostname = useMemo(() => {
return (
normalizeText(device?.hostname) ||
normalizeText(device?.summary?.hostname) ||
normalizeText(device?.agent_hostname) ||
""
);
}, [device]);
const agentId = useMemo(() => {
return (
normalizeText(device?.agent_id) ||
normalizeText(device?.agentId) ||
normalizeText(device?.agent_guid) ||
normalizeText(device?.agentGuid) ||
normalizeText(device?.id) ||
normalizeText(device?.guid) ||
normalizeText(device?.summary?.agent_id) ||
""
);
}, [device]);
const resetState = useCallback(() => {
setTunnel(null);
setSessionState("idle");
setStatusMessage("");
setStatusSeverity("info");
setOutput("");
setInput("");
setPsStatus({});
}, []);
const disconnectSocket = useCallback(() => {
const socket = socketRef.current;
if (socket) {
socket.off();
socket.disconnect();
}
socketRef.current = null;
}, []);
const stopPolling = useCallback(() => {
if (pollTimerRef.current) {
clearTimeout(pollTimerRef.current);
pollTimerRef.current = null;
}
setPolling(false);
}, []);
useEffect(() => {
return () => {
stopPolling();
disconnectSocket();
if (joinRetryRef.current) {
clearTimeout(joinRetryRef.current);
joinRetryRef.current = null;
}
};
}, [disconnectSocket, stopPolling]);
const appendOutput = useCallback((text) => {
if (!text) return;
setOutput((prev) => {
const next = `${prev}${text}`;
const limit = 40000;
return next.length > limit ? next.slice(next.length - limit) : next;
});
}, []);
const scrollToBottom = useCallback(() => {
const el = terminalRef.current;
if (!el) return;
requestAnimationFrame(() => {
el.scrollTop = el.scrollHeight;
});
}, []);
useEffect(() => {
scrollToBottom();
}, [output, scrollToBottom]);
const handleCopy = async () => {
try {
await navigator.clipboard.writeText(output || "");
setCopyFlash(true);
setTimeout(() => setCopyFlash(false), 1200);
} catch {
setCopyFlash(false);
}
};
const measureTerminal = useCallback(() => {
const el = terminalRef.current;
if (!el) return { cols: 120, rows: 32 };
const width = el.clientWidth || 960;
const height = el.clientHeight || 460;
const charWidth = 8.2;
const charHeight = 18;
const cols = Math.max(20, Math.min(Math.floor(width / charWidth), 300));
const rows = Math.max(10, Math.min(Math.floor(height / charHeight), 200));
return { cols, rows };
}, []);
const emitAsync = useCallback((socket, event, payload, timeoutMs = 4000) => {
return new Promise((resolve) => {
let settled = false;
const timer = setTimeout(() => {
if (settled) return;
settled = true;
resolve({ error: "timeout" });
}, timeoutMs);
socket.emit(event, payload, (resp) => {
if (settled) return;
settled = true;
clearTimeout(timer);
resolve(resp || {});
});
});
}, []);
const pollLoop = useCallback(
(socket, tunnelId) => {
if (!socket || !tunnelId) return;
setPolling(true);
pollTimerRef.current = setTimeout(async () => {
const resp = await emitAsync(socket, "ps_poll", {});
if (resp?.error) {
if (resp.error === "ps_unsupported") {
setStatusSeverity("info");
setStatusMessage("PowerShell channel warming up...");
} else {
setStatusSeverity("warning");
setStatusMessage(resp.error);
}
}
if (Array.isArray(resp?.output) && resp.output.length) {
appendOutput(resp.output.join(""));
}
if (resp?.status) {
setPsStatus(resp.status);
if (resp.status.closed) {
setSessionState("closed");
setStatusSeverity("warning");
setStatusMessage(resp.status.close_reason || "Session closed");
stopPolling();
return;
}
if (resp.status.ack) {
setSessionState("connected");
}
}
pollLoop(socket, tunnelId);
}, 520);
},
[appendOutput, emitAsync, stopPolling]
);
const handleDisconnect = useCallback(() => {
const socket = socketRef.current;
const tunnelId = tunnel?.tunnel_id;
if (joinRetryRef.current) {
clearTimeout(joinRetryRef.current);
joinRetryRef.current = null;
}
joinAttemptsRef.current = 0;
if (socket && tunnelId) {
const frame = buildCloseFrame(1, CLOSE_AGENT_SHUTDOWN, "operator_close");
socket.emit("send", { frame });
}
stopPolling();
disconnectSocket();
setSessionState("closed");
setStatusSeverity("info");
setStatusMessage("Session closed by operator.");
}, [disconnectSocket, stopPolling, tunnel?.tunnel_id]);
const handleResize = useCallback(() => {
if (!socketRef.current || sessionState === "idle") return;
const dims = measureTerminal();
socketRef.current.emit("ps_resize", dims);
}, [measureTerminal, sessionState]);
useEffect(() => {
const observer =
typeof ResizeObserver !== "undefined"
? new ResizeObserver(() => {
if (resizeTimerRef.current) clearTimeout(resizeTimerRef.current);
resizeTimerRef.current = setTimeout(() => handleResize(), 200);
})
: null;
const el = terminalRef.current;
if (observer && el) observer.observe(el);
const onWinResize = () => handleResize();
window.addEventListener("resize", onWinResize);
return () => {
window.removeEventListener("resize", onWinResize);
if (observer && el) observer.unobserve(el);
};
}, [handleResize]);
const connectSocket = useCallback(
(lease, { isRetry = false } = {}) => {
if (!lease?.tunnel_id) return;
if (joinRetryRef.current) {
clearTimeout(joinRetryRef.current);
joinRetryRef.current = null;
}
if (!isRetry) {
joinAttemptsRef.current = 0;
}
disconnectSocket();
stopPolling();
setSessionState("waiting");
const socket = io(`${window.location.origin}/tunnel`, { transports: ["websocket"] });
socketRef.current = socket;
socket.on("connect_error", () => {
setStatusSeverity("warning");
setStatusMessage("Tunnel namespace unavailable.");
});
socket.on("disconnect", () => {
stopPolling();
if (sessionState !== "closed") {
setSessionState("disconnected");
setStatusSeverity("warning");
setStatusMessage("Socket disconnected.");
}
});
socket.on("connect", async () => {
setStatusSeverity("info");
setStatusMessage("Joining tunnel...");
const joinResp = await emitAsync(socket, "join", { tunnel_id: lease.tunnel_id });
if (joinResp?.error) {
if (joinResp.error === "unknown_tunnel") {
setSessionState("waiting_agent");
setStatusSeverity("info");
setStatusMessage("Waiting for agent to establish tunnel...");
joinAttemptsRef.current += 1;
const attempt = joinAttemptsRef.current;
if (attempt <= 15) {
joinRetryRef.current = setTimeout(() => connectSocket(lease, { isRetry: true }), 1000);
} else {
setSessionState("error");
setTunnel(null);
setStatusSeverity("warning");
setStatusMessage("Agent did not attach to tunnel (timeout). Try Connect again.");
}
} else {
setSessionState("error");
setStatusSeverity("error");
setStatusMessage(joinResp.error);
}
return;
}
const dims = measureTerminal();
const openResp = await emitAsync(socket, "ps_open", dims);
if (openResp?.error && openResp.error === "ps_unsupported") {
setStatusSeverity("info");
setStatusMessage("PowerShell channel warming up...");
}
appendOutput("");
setStatusMessage("Attached — waiting for agent to acknowledge...");
setSessionState("waiting_agent");
pollLoop(socket, lease.tunnel_id);
handleResize();
});
},
[appendOutput, disconnectSocket, emitAsync, handleResize, measureTerminal, pollLoop, sessionState, stopPolling]
);
const requestTunnel = useCallback(async () => {
if (tunnel && sessionState !== "closed" && sessionState !== "idle") {
setStatusSeverity("info");
setStatusMessage("Re-attaching to existing tunnel...");
connectSocket(tunnel);
return;
}
if (!agentId) {
setStatusSeverity("warning");
setStatusMessage("Agent ID is required to request a tunnel.");
return;
}
if (connectionType !== "ps") {
setStatusSeverity("warning");
setStatusMessage("Only PowerShell is supported right now.");
return;
}
resetState();
setSessionState("requesting");
setStatusSeverity("info");
setStatusMessage("Requesting tunnel lease...");
try {
const resp = await fetch("/api/tunnel/request", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ agent_id: agentId, protocol: "ps", domain: "ps" }),
});
const data = await resp.json().catch(() => ({}));
if (!resp.ok) {
const err = data?.error || `HTTP ${resp.status}`;
setSessionState("error");
setStatusSeverity(err === "domain_limit" ? "warning" : "error");
setStatusMessage(
err === "domain_limit"
? "PowerShell session already active for this agent. Try again after it closes."
: err
);
return;
}
setTunnel(data);
setStatusMessage("Lease issued. Waiting for agent to connect...");
setSessionState("lease_issued");
connectSocket(data);
} catch (e) {
setSessionState("error");
setStatusSeverity("error");
setStatusMessage(e?.message || "Failed to request tunnel");
}
}, [agentId, connectSocket, connectionType, resetState]);
const handleSend = useCallback(
async (text) => {
const socket = socketRef.current;
if (!socket) return;
const payload = `${text}${text.endsWith("\n") ? "" : "\r\n"}`;
appendOutput(`\nPS> ${text}\n`);
setInput("");
const resp = await emitAsync(socket, "ps_send", { data: payload });
if (resp?.error) {
setStatusSeverity("warning");
setStatusMessage(resp.error);
}
},
[appendOutput, emitAsync]
);
const isConnected = sessionState === "connected" || psStatus?.ack;
const isBusy =
sessionState === "requesting" ||
sessionState === "waiting" ||
sessionState === "waiting_agent" ||
sessionState === "lease_issued";
const canStart = Boolean(agentId) && !isBusy && !isConnected;
const sessionChips = [
{
label: isConnected ? "Connected" : sessionState === "idle" ? "Idle" : sessionState.replace(/_/g, " "),
color: isConnected ? MAGIC_UI.accentC : MAGIC_UI.accentA,
icon: <ActivityIcon sx={{ fontSize: 18 }} />,
},
tunnel?.tunnel_id
? {
label: `Tunnel ${tunnel.tunnel_id.slice(0, 8)}`,
color: MAGIC_UI.accentB,
icon: <LinkIcon sx={{ fontSize: 18 }} />,
}
: null,
tunnel?.port
? {
label: `Port ${tunnel.port}`,
color: MAGIC_UI.accentA,
icon: <PortIcon sx={{ fontSize: 18 }} />,
}
: null,
].filter(Boolean);
return (
<Box sx={{ display: "flex", flexDirection: "column", gap: 1.5, flexGrow: 1, minHeight: 0 }}>
<Box
sx={{
border: `1px solid ${MAGIC_UI.panelBorder}`,
borderRadius: 2,
background: "linear-gradient(120deg, rgba(12,18,35,0.9), rgba(8,12,24,0.92))",
boxShadow: "0 18px 45px rgba(2,6,23,0.6)",
p: { xs: 1.5, md: 2 },
}}
>
<Stack
direction={{ xs: "column", md: "row" }}
spacing={1.5}
alignItems={{ xs: "flex-start", md: "center" }}
justifyContent="space-between"
>
<Stack direction="row" spacing={1} alignItems="center" flexWrap="wrap">
<TerminalIcon sx={{ fontSize: 22, color: MAGIC_UI.accentA }} />
<Typography variant="h6" sx={{ fontWeight: 700, letterSpacing: 0.3 }}>
Remote Shell
</Typography>
{hostname ? (
<Chip
size="small"
label={hostname}
sx={{
background: "rgba(12,18,35,0.8)",
color: MAGIC_UI.textBright,
border: `1px solid ${MAGIC_UI.panelBorder}`,
}}
/>
) : null}
{agentId ? (
<Chip
size="small"
label={`Agent ${agentId}`}
sx={{
background: "rgba(12,18,35,0.8)",
color: MAGIC_UI.textMuted,
border: `1px solid ${MAGIC_UI.panelBorder}`,
}}
/>
) : null}
</Stack>
<Stack direction={{ xs: "column", sm: "row" }} spacing={1} alignItems="center">
<TextField
select
label="Connection Type"
size="small"
value={connectionType}
onChange={(e) => setConnectionType(e.target.value)}
sx={{
minWidth: 180,
"& .MuiInputBase-root": {
backgroundColor: "rgba(12,18,35,0.85)",
color: MAGIC_UI.textBright,
borderRadius: 1.5,
},
"& fieldset": { borderColor: MAGIC_UI.panelBorder },
"&:hover fieldset": { borderColor: MAGIC_UI.accentA },
}}
>
<MenuItem value="ps">PowerShell</MenuItem>
</TextField>
<Tooltip title={isConnected ? "Disconnect session" : "Connect to agent"}>
<span>
<Button
size="small"
startIcon={isConnected ? <StopIcon /> : <PlayIcon />}
sx={gradientButtonSx}
disabled={!isConnected && !canStart}
onClick={isConnected ? handleDisconnect : requestTunnel}
>
{isConnected ? "Disconnect" : "Connect"}
</Button>
</span>
</Tooltip>
</Stack>
</Stack>
<Stack direction={{ xs: "column", md: "row" }} spacing={1} sx={{ mt: 1, flexWrap: "wrap" }}>
{sessionChips.map((chip) => (
<Chip
key={chip.label}
icon={chip.icon}
label={chip.label}
sx={{
background: "rgba(12,18,35,0.85)",
color: chip.color,
border: `1px solid ${chip.color}44`,
fontWeight: 600,
}}
/>
))}
</Stack>
</Box>
{statusMessage ? (
<Alert
severity={statusSeverity}
sx={{
borderRadius: 2,
backgroundColor: "rgba(8,12,24,0.9)",
border: `1px solid ${MAGIC_UI.panelBorder}`,
color: MAGIC_UI.textBright,
}}
>
{statusMessage}
</Alert>
) : null}
<Box
sx={{
flexGrow: 1,
minHeight: 320,
display: "flex",
flexDirection: "column",
borderRadius: 3,
border: `1px solid ${MAGIC_UI.panelBorder}`,
background:
"linear-gradient(145deg, rgba(8,12,24,0.94), rgba(10,16,30,0.9)), radial-gradient(circle at 20% 20%, rgba(125,211,252,0.08), transparent 35%)",
boxShadow: "0 25px 80px rgba(2,6,23,0.85)",
overflow: "hidden",
}}
>
{isBusy ? <LinearProgress color="info" sx={{ height: 3 }} /> : null}
<Box
ref={terminalRef}
sx={{
flexGrow: 1,
minHeight: 240,
maxHeight: "100%",
overflow: "auto",
position: "relative",
p: 2,
"& pre": {
margin: 0,
fontFamily: fontFamilyMono,
fontSize: 13,
lineHeight: 1.5,
color: "#e6edf3",
},
}}
>
<Editor
value={output}
onValueChange={() => {}}
highlight={highlightPs}
padding={12}
readOnly
style={{
minHeight: "100%",
background: "transparent",
color: "#e6edf3",
fontFamily: fontFamilyMono,
fontSize: 13,
}}
/>
<Box sx={{ position: "absolute", top: 8, right: 8, display: "flex", gap: 0.5 }}>
<Tooltip title="Copy output">
<IconButton size="small" onClick={handleCopy} sx={{ color: copyFlash ? MAGIC_UI.accentC : MAGIC_UI.textMuted }}>
<CopyIcon fontSize="small" />
</IconButton>
</Tooltip>
<Tooltip title="Clear output">
<IconButton
size="small"
onClick={() => setOutput("")}
sx={{ color: MAGIC_UI.textMuted }}
>
<RefreshIcon fontSize="small" />
</IconButton>
</Tooltip>
</Box>
</Box>
<Box
sx={{
borderTop: `1px solid ${MAGIC_UI.panelBorder}`,
p: 1.5,
background: "rgba(6,10,20,0.92)",
}}
>
<TextField
fullWidth
size="small"
value={input}
disabled={!isConnected}
placeholder={
isConnected
? "Enter PowerShell command and press Enter"
: "Connect to start sending commands"
}
onChange={(e) => setInput(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
const text = input.trim();
if (text) handleSend(text);
}
}}
InputProps={{
sx: {
backgroundColor: "rgba(12,18,35,0.9)",
color: "#e2e8f0",
borderRadius: 2,
"& fieldset": { borderColor: "rgba(148,163,184,0.45)" },
"&:hover fieldset": { borderColor: MAGIC_UI.accentA },
},
}}
/>
</Box>
</Box>
</Box>
);
}