Overhaul of VPN Codebase

This commit is contained in:
2025-12-18 01:35:03 -07:00
parent 2f81061a1b
commit 6ceb59f717
56 changed files with 1786 additions and 4778 deletions

View File

@@ -1,90 +0,0 @@
# ======================================================
# Data\Engine\Unit_Tests\test_reverse_tunnel.py
# Description: Validates reverse tunnel lease API basics (allocation, token contents, and domain limit).
#
# API Endpoints (if applicable):
# - POST /api/tunnel/request
# ======================================================
from __future__ import annotations
import base64
import json
import pytest
from .conftest import EngineTestHarness
def _client_with_admin_session(harness: EngineTestHarness):
client = harness.app.test_client()
with client.session_transaction() as sess:
sess["username"] = "admin"
sess["role"] = "Admin"
return client
def _decode_token_segment(token: str) -> dict:
"""Decode the unsigned payload segment from the tunnel token."""
if not token:
return {}
segment = token.split(".")[0]
padding = "=" * (-len(segment) % 4)
raw = base64.urlsafe_b64decode(segment + padding)
try:
return json.loads(raw.decode("utf-8"))
except Exception:
return {}
@pytest.mark.parametrize("agent_id", ["test-device-agent"])
def test_tunnel_request_happy_path(engine_harness: EngineTestHarness, agent_id: str) -> None:
client = _client_with_admin_session(engine_harness)
resp = client.post(
"/api/tunnel/request",
json={"agent_id": agent_id, "protocol": "ps", "domain": "ps"},
)
assert resp.status_code == 200
payload = resp.get_json()
assert payload["agent_id"] == agent_id
assert payload["protocol"] == "ps"
assert payload["domain"] == "ps"
assert isinstance(payload["port"], int) and payload["port"] >= 30000
assert payload.get("token")
claims = _decode_token_segment(payload["token"])
assert claims.get("agent_id") == agent_id
assert claims.get("protocol") == "ps"
assert claims.get("domain") == "ps"
assert claims.get("tunnel_id") == payload["tunnel_id"]
assert claims.get("assigned_port") == payload["port"]
def test_tunnel_request_domain_limit(engine_harness: EngineTestHarness) -> None:
client = _client_with_admin_session(engine_harness)
first = client.post(
"/api/tunnel/request",
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
)
assert first.status_code == 200
second = client.post(
"/api/tunnel/request",
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
)
assert second.status_code == 409
data = second.get_json()
assert data.get("error") == "domain_limit"
def test_tunnel_request_includes_timeouts(engine_harness: EngineTestHarness) -> None:
client = _client_with_admin_session(engine_harness)
resp = client.post(
"/api/tunnel/request",
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
)
assert resp.status_code == 200
payload = resp.get_json()
assert payload.get("idle_seconds") and payload["idle_seconds"] > 0
assert payload.get("grace_seconds") and payload["grace_seconds"] > 0
assert payload.get("expires_at") and int(payload["expires_at"]) > 0

View File

@@ -1,101 +0,0 @@
# ======================================================
# Data\Engine\Unit_Tests\test_reverse_tunnel_integration.py
# Description: Integration test that exercises a full reverse tunnel PowerShell round-trip
# against a running Engine + Agent (requires live services).
#
# Requirements:
# - Environment variables must be set to point at a live Engine + Agent:
# TUNNEL_TEST_HOST (e.g., https://localhost:5000)
# TUNNEL_TEST_AGENT_ID (agent_id/agent_guid for the target device)
# TUNNEL_TEST_BEARER (Authorization bearer token for an admin/operator)
# - A live Agent must be reachable and allowed to establish the reverse tunnel.
# - TLS verification can be controlled via TUNNEL_TEST_VERIFY ("false" to disable).
#
# API Endpoints (if applicable):
# - POST /api/tunnel/request
# - Socket.IO namespace /tunnel (join, ps_open, ps_send, ps_poll)
# ======================================================
from __future__ import annotations
import os
import time
import pytest
import requests
import socketio
HOST = os.environ.get("TUNNEL_TEST_HOST", "").strip()
AGENT_ID = os.environ.get("TUNNEL_TEST_AGENT_ID", "").strip()
BEARER = os.environ.get("TUNNEL_TEST_BEARER", "").strip()
VERIFY_ENV = os.environ.get("TUNNEL_TEST_VERIFY", "").strip().lower()
VERIFY = False if VERIFY_ENV in {"false", "0", "no"} else True
SKIP_MSG = (
"Live tunnel test skipped (set TUNNEL_TEST_HOST, TUNNEL_TEST_AGENT_ID, TUNNEL_TEST_BEARER to run)"
)
def _require_env() -> None:
if not HOST or not AGENT_ID or not BEARER:
pytest.skip(SKIP_MSG)
def _make_session() -> requests.Session:
sess = requests.Session()
sess.verify = VERIFY
sess.headers.update({"Authorization": f"Bearer {BEARER}"})
return sess
def test_reverse_tunnel_powershell_roundtrip() -> None:
_require_env()
sess = _make_session()
# 1) Request a tunnel lease
resp = sess.post(
f"{HOST}/api/tunnel/request",
json={"agent_id": AGENT_ID, "protocol": "ps", "domain": "remote-interactive-shell"},
)
assert resp.status_code == 200, f"lease request failed: {resp.status_code} {resp.text}"
lease = resp.json()
tunnel_id = lease["tunnel_id"]
# 2) Connect to Socket.IO /tunnel namespace
sio = socketio.Client()
sio.connect(
HOST,
namespaces=["/tunnel"],
headers={"Authorization": f"Bearer {BEARER}"},
transports=["websocket"],
wait_timeout=10,
)
# 3) Join tunnel and open PS channel
join_resp = sio.call("join", {"tunnel_id": tunnel_id}, namespace="/tunnel", timeout=10)
assert join_resp.get("status") == "ok", f"join failed: {join_resp}"
open_resp = sio.call("ps_open", {"cols": 120, "rows": 32}, namespace="/tunnel", timeout=10)
assert not open_resp.get("error"), f"ps_open failed: {open_resp}"
# 4) Send a command
send_resp = sio.call("ps_send", {"data": 'Write-Host "Hello World"\r\n'}, namespace="/tunnel", timeout=10)
assert not send_resp.get("error"), f"ps_send failed: {send_resp}"
# 5) Poll for output
output_text = ""
deadline = time.time() + 15
while time.time() < deadline:
poll_resp = sio.call("ps_poll", {}, namespace="/tunnel", timeout=10)
if poll_resp.get("error"):
pytest.fail(f"ps_poll failed: {poll_resp}")
lines = poll_resp.get("output") or []
output_text += "".join(lines)
if "Hello World" in output_text:
break
time.sleep(0.5)
sio.disconnect()
assert "Hello World" in output_text, f"expected command output not found; got: {output_text!r}"

View File

@@ -77,17 +77,12 @@ LOG_ROOT = PROJECT_ROOT / "Engine" / "Logs"
LOG_FILE_PATH = LOG_ROOT / "engine.log"
ERROR_LOG_FILE_PATH = LOG_ROOT / "error.log"
API_LOG_FILE_PATH = LOG_ROOT / "api.log"
REVERSE_TUNNEL_LOG_FILE_PATH = LOG_ROOT / "reverse_tunnel.log"
DEFAULT_TUNNEL_FIXED_PORT = 8443
DEFAULT_TUNNEL_PORT_RANGE = (30000, 40000)
DEFAULT_TUNNEL_IDLE_TIMEOUT_SECONDS = 3600
DEFAULT_TUNNEL_GRACE_TIMEOUT_SECONDS = 3600
DEFAULT_TUNNEL_HEARTBEAT_INTERVAL_SECONDS = 20
VPN_TUNNEL_LOG_FILE_PATH = LOG_ROOT / "reverse_tunnel.log"
DEFAULT_WIREGUARD_PORT = 30000
DEFAULT_WIREGUARD_ENGINE_VIRTUAL_IP = "10.255.0.1/32"
DEFAULT_WIREGUARD_PEER_NETWORK = "10.255.0.0/24"
DEFAULT_WIREGUARD_ACL_WINDOWS = (3389, 5985, 5986, 5900, 3478)
DEFAULT_WIREGUARD_SHELL_PORT = 47001
DEFAULT_WIREGUARD_ACL_WINDOWS = (3389, 5985, 5986, 5900, 3478, DEFAULT_WIREGUARD_SHELL_PORT)
VPN_SERVER_CERT_ROOT = PROJECT_ROOT / "Engine" / "Certificates" / "VPN_Server"
@@ -282,18 +277,14 @@ class EngineSettings:
error_log_file: str
api_log_file: str
api_groups: Tuple[str, ...]
reverse_tunnel_fixed_port: int
reverse_tunnel_port_range: Tuple[int, int]
reverse_tunnel_idle_timeout_seconds: int
reverse_tunnel_grace_timeout_seconds: int
reverse_tunnel_heartbeat_seconds: int
reverse_tunnel_log_file: str
vpn_tunnel_log_file: str
wireguard_port: int
wireguard_engine_virtual_ip: str
wireguard_peer_network: str
wireguard_server_private_key_path: str
wireguard_server_public_key_path: str
wireguard_acl_allowlist_windows: Tuple[int, ...]
wireguard_shell_port: int
raw: MutableMapping[str, Any] = field(default_factory=dict)
def to_flask_config(self) -> MutableMapping[str, Any]:
@@ -390,10 +381,14 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine
api_log_file = str(runtime_config.get("API_LOG_FILE") or API_LOG_FILE_PATH)
_ensure_parent(Path(api_log_file))
reverse_tunnel_log_file = str(
runtime_config.get("REVERSE_TUNNEL_LOG_FILE") or REVERSE_TUNNEL_LOG_FILE_PATH
vpn_tunnel_log_file = str(
runtime_config.get("VPN_TUNNEL_LOG_FILE")
or runtime_config.get("WIREGUARD_LOG_FILE")
or os.environ.get("BOREALIS_VPN_TUNNEL_LOG_FILE")
or os.environ.get("BOREALIS_WIREGUARD_LOG_FILE")
or VPN_TUNNEL_LOG_FILE_PATH
)
_ensure_parent(Path(reverse_tunnel_log_file))
_ensure_parent(Path(vpn_tunnel_log_file))
wireguard_port = _parse_int(
runtime_config.get("WIREGUARD_PORT") or os.environ.get("BOREALIS_WIREGUARD_PORT"),
@@ -416,6 +411,13 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine
or os.environ.get("BOREALIS_WIREGUARD_WINDOWS_ALLOWLIST"),
default=DEFAULT_WIREGUARD_ACL_WINDOWS,
)
wireguard_shell_port = _parse_int(
runtime_config.get("WIREGUARD_SHELL_PORT")
or os.environ.get("BOREALIS_WIREGUARD_SHELL_PORT"),
default=DEFAULT_WIREGUARD_SHELL_PORT,
minimum=1,
maximum=65535,
)
wireguard_key_root = Path(
runtime_config.get("WIREGUARD_KEY_ROOT")
or os.environ.get("BOREALIS_WIREGUARD_KEY_ROOT")
@@ -440,35 +442,6 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine
"scheduled_jobs",
)
tunnel_fixed_port = _parse_int(
runtime_config.get("TUNNEL_FIXED_PORT") or os.environ.get("BOREALIS_TUNNEL_FIXED_PORT"),
default=DEFAULT_TUNNEL_FIXED_PORT,
minimum=1,
maximum=65535,
)
tunnel_port_range = _parse_port_range(
runtime_config.get("TUNNEL_PORT_RANGE") or os.environ.get("BOREALIS_TUNNEL_PORT_RANGE"),
default=DEFAULT_TUNNEL_PORT_RANGE,
)
tunnel_idle_timeout_seconds = _parse_int(
runtime_config.get("TUNNEL_IDLE_TIMEOUT_SECONDS")
or os.environ.get("BOREALIS_TUNNEL_IDLE_TIMEOUT_SECONDS"),
default=DEFAULT_TUNNEL_IDLE_TIMEOUT_SECONDS,
minimum=60,
)
tunnel_grace_timeout_seconds = _parse_int(
runtime_config.get("TUNNEL_GRACE_TIMEOUT_SECONDS")
or os.environ.get("BOREALIS_TUNNEL_GRACE_TIMEOUT_SECONDS"),
default=DEFAULT_TUNNEL_GRACE_TIMEOUT_SECONDS,
minimum=60,
)
tunnel_heartbeat_seconds = _parse_int(
runtime_config.get("TUNNEL_HEARTBEAT_SECONDS")
or os.environ.get("BOREALIS_TUNNEL_HEARTBEAT_SECONDS"),
default=DEFAULT_TUNNEL_HEARTBEAT_INTERVAL_SECONDS,
minimum=5,
)
settings = EngineSettings(
database_path=database_path,
static_folder=static_folder,
@@ -484,18 +457,14 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine
error_log_file=str(error_log_file),
api_log_file=str(api_log_file),
api_groups=api_groups,
reverse_tunnel_fixed_port=tunnel_fixed_port,
reverse_tunnel_port_range=tunnel_port_range,
reverse_tunnel_idle_timeout_seconds=tunnel_idle_timeout_seconds,
reverse_tunnel_grace_timeout_seconds=tunnel_grace_timeout_seconds,
reverse_tunnel_heartbeat_seconds=tunnel_heartbeat_seconds,
reverse_tunnel_log_file=reverse_tunnel_log_file,
vpn_tunnel_log_file=vpn_tunnel_log_file,
wireguard_port=wireguard_port,
wireguard_engine_virtual_ip=wireguard_engine_virtual_ip,
wireguard_peer_network=wireguard_peer_network,
wireguard_server_private_key_path=wireguard_server_private_key_path,
wireguard_server_public_key_path=wireguard_server_public_key_path,
wireguard_acl_allowlist_windows=wireguard_acl_allowlist_windows,
wireguard_shell_port=wireguard_shell_port,
raw=runtime_config,
)
return settings

View File

@@ -2,6 +2,8 @@
# Data\Engine\database_migrations.py
# Description: Provides schema evolution helpers for the Engine sqlite
# database without importing the legacy ``Modules`` package.
#
# API Endpoints (if applicable): None
# ======================================================
"""Engine database schema migration helpers."""
@@ -24,6 +26,7 @@ def apply_all(conn: sqlite3.Connection) -> None:
_ensure_devices_table(conn)
_ensure_device_aux_tables(conn)
_ensure_device_vpn_config_table(conn)
_ensure_refresh_token_table(conn)
_ensure_install_code_table(conn)
_ensure_install_code_persistence_table(conn)
@@ -112,6 +115,20 @@ def _ensure_device_aux_tables(conn: sqlite3.Connection) -> None:
)
def _ensure_device_vpn_config_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS device_vpn_config (
agent_id TEXT PRIMARY KEY,
allowed_ports TEXT,
updated_at TEXT,
updated_by TEXT
)
"""
)
def _ensure_refresh_token_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor()
cur.execute(

View File

@@ -120,18 +120,14 @@ class EngineContext:
config: Mapping[str, Any]
api_groups: Sequence[str]
api_log_path: str
reverse_tunnel_fixed_port: int
reverse_tunnel_port_range: Tuple[int, int]
reverse_tunnel_idle_timeout_seconds: int
reverse_tunnel_grace_timeout_seconds: int
reverse_tunnel_heartbeat_seconds: int
reverse_tunnel_log_path: str
vpn_tunnel_log_path: str
wireguard_port: int
wireguard_engine_virtual_ip: str
wireguard_peer_network: str
wireguard_server_private_key_path: str
wireguard_server_public_key_path: str
wireguard_acl_allowlist_windows: Tuple[int, ...]
wireguard_shell_port: int
wireguard_server_manager: Optional[Any] = None
assembly_cache: Optional[Any] = None
@@ -151,18 +147,14 @@ def _build_engine_context(settings: EngineSettings, logger: logging.Logger) -> E
config=settings.as_dict(),
api_groups=settings.api_groups,
api_log_path=settings.api_log_file,
reverse_tunnel_fixed_port=settings.reverse_tunnel_fixed_port,
reverse_tunnel_port_range=settings.reverse_tunnel_port_range,
reverse_tunnel_idle_timeout_seconds=settings.reverse_tunnel_idle_timeout_seconds,
reverse_tunnel_grace_timeout_seconds=settings.reverse_tunnel_grace_timeout_seconds,
reverse_tunnel_heartbeat_seconds=settings.reverse_tunnel_heartbeat_seconds,
reverse_tunnel_log_path=settings.reverse_tunnel_log_file,
vpn_tunnel_log_path=settings.vpn_tunnel_log_file,
wireguard_port=settings.wireguard_port,
wireguard_engine_virtual_ip=settings.wireguard_engine_virtual_ip,
wireguard_peer_network=settings.wireguard_peer_network,
wireguard_server_private_key_path=settings.wireguard_server_private_key_path,
wireguard_server_public_key_path=settings.wireguard_server_public_key_path,
wireguard_acl_allowlist_windows=settings.wireguard_acl_allowlist_windows,
wireguard_shell_port=settings.wireguard_shell_port,
assembly_cache=None,
)
@@ -249,7 +241,7 @@ def create_app(config: Optional[Mapping[str, Any]] = None) -> Tuple[Flask, Socke
private_key_path=Path(context.wireguard_server_private_key_path),
public_key_path=Path(context.wireguard_server_public_key_path),
acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows),
log_path=Path(context.reverse_tunnel_log_path),
log_path=Path(context.vpn_tunnel_log_path),
)
context.wireguard_server_manager = WireGuardServerManager(wg_config)
except Exception:
@@ -325,7 +317,7 @@ def register_engine_api(app: Flask, *, config: Optional[Mapping[str, Any]] = Non
private_key_path=Path(context.wireguard_server_private_key_path),
public_key_path=Path(context.wireguard_server_public_key_path),
acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows),
log_path=Path(context.reverse_tunnel_log_path),
log_path=Path(context.vpn_tunnel_log_path),
)
context.wireguard_server_manager = WireGuardServerManager(wg_config)
except Exception:

View File

@@ -9,6 +9,8 @@
# - GET /api/devices/<guid> (Token Authenticated) - Retrieves a single device record by GUID, including summary fields.
# - GET /api/device/details/<hostname> (Token Authenticated) - Returns full device details keyed by hostname.
# - POST /api/device/description/<hostname> (Token Authenticated) - Updates the human-readable description for a device.
# - GET /api/device/vpn_config/<agent_id> (Token Authenticated) - Returns per-device VPN allowed port settings.
# - PUT /api/device/vpn_config/<agent_id> (Token Authenticated) - Updates per-device VPN allowed port settings.
# - GET /api/device_list_views (Token Authenticated) - Lists saved device table view definitions.
# - GET /api/device_list_views/<int:view_id> (Token Authenticated) - Retrieves a specific saved device table view definition.
# - POST /api/device_list_views (Token Authenticated) - Creates a custom device list view for the signed-in operator.
@@ -426,6 +428,131 @@ class DeviceManagementService:
return None
return None
def _parse_ports(self, raw: Any) -> List[int]:
ports: List[int] = []
if isinstance(raw, str):
parts = [part.strip() for part in raw.split(",") if part.strip()]
elif isinstance(raw, list):
parts = raw
else:
parts = []
for part in parts:
try:
value = int(part)
except Exception:
continue
if 1 <= value <= 65535:
ports.append(value)
return list(dict.fromkeys(ports))
def _default_vpn_ports(self, os_name: Optional[str]) -> List[int]:
ports = list(self.adapters.context.wireguard_acl_allowlist_windows or [])
os_text = (os_name or "").strip().lower()
if os_text and "windows" not in os_text:
baseline = {5900, 3478}
filtered = [p for p in ports if p in baseline]
return filtered or ports
return ports
def get_vpn_config(self, agent_id: str) -> Tuple[Dict[str, Any], int]:
agent_id = (agent_id or "").strip()
if not agent_id:
return {"error": "agent_id_required"}, 400
default_ports: List[int] = []
shell_port = int(self.adapters.context.wireguard_shell_port)
try:
conn = self._db_conn()
cur = conn.cursor()
os_name = ""
try:
cur.execute(
"SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
(agent_id,),
)
row = cur.fetchone()
if row and row[0]:
os_name = str(row[0])
except Exception:
os_name = ""
default_ports = self._default_vpn_ports(os_name)
cur.execute(
"SELECT allowed_ports, updated_at, updated_by FROM device_vpn_config WHERE agent_id=?",
(agent_id,),
)
row = cur.fetchone()
if not row:
return {
"agent_id": agent_id,
"allowed_ports": default_ports,
"default_ports": default_ports,
"shell_port": shell_port,
"source": "default",
}, 200
raw_ports = row[0] or ""
ports = []
try:
ports = json.loads(raw_ports) if raw_ports else []
except Exception:
ports = []
return {
"agent_id": agent_id,
"allowed_ports": ports or default_ports,
"default_ports": default_ports,
"shell_port": shell_port,
"updated_at": row[1],
"updated_by": row[2],
"source": "custom" if ports else "default",
}, 200
except Exception as exc:
self.logger.debug("Failed to load vpn config", exc_info=True)
return {"error": "internal_error"}, 500
finally:
try:
conn.close()
except Exception:
pass
def set_vpn_config(self, agent_id: str, payload: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
agent_id = (agent_id or "").strip()
if not agent_id:
return {"error": "agent_id_required"}, 400
ports = self._parse_ports(payload.get("allowed_ports"))
if not ports:
return {"error": "allowed_ports_required"}, 400
user = self._current_user() or {}
updated_by = user.get("username") or ""
updated_at = datetime.now(timezone.utc).isoformat()
try:
conn = self._db_conn()
cur = conn.cursor()
cur.execute(
"""
INSERT INTO device_vpn_config(agent_id, allowed_ports, updated_at, updated_by)
VALUES (?, ?, ?, ?)
ON CONFLICT(agent_id) DO UPDATE SET
allowed_ports=excluded.allowed_ports,
updated_at=excluded.updated_at,
updated_by=excluded.updated_by
""",
(agent_id, json.dumps(ports), updated_at, updated_by),
)
conn.commit()
return {
"agent_id": agent_id,
"allowed_ports": ports,
"updated_at": updated_at,
"updated_by": updated_by,
"source": "custom",
}, 200
except Exception:
self.logger.debug("Failed to save vpn config", exc_info=True)
return {"error": "internal_error"}, 500
finally:
try:
conn.close()
except Exception:
pass
def _require_login(self) -> Optional[Tuple[Dict[str, Any], int]]:
if not self._current_user():
return {"error": "unauthorized"}, 401
@@ -1793,6 +1920,19 @@ def register_management(app, adapters: "EngineServiceAdapters") -> None:
payload, status = service.set_device_description(hostname, description)
return jsonify(payload), status
@blueprint.route("/api/device/vpn_config/<agent_id>", methods=["GET", "PUT"])
def _vpn_config(agent_id: str):
requirement = service._require_login()
if requirement:
payload, status = requirement
return jsonify(payload), status
if request.method == "GET":
payload, status = service.get_vpn_config(agent_id)
else:
body = request.get_json(silent=True) or {}
payload, status = service.set_vpn_config(agent_id, body)
return jsonify(payload), status
@blueprint.route("/api/device_list_views", methods=["GET"])
def _list_views():
requirement = service._require_login()

View File

@@ -1,12 +1,14 @@
# ======================================================
# Data\Engine\services\API\devices\tunnel.py
# Description: Negotiation endpoint for reverse tunnel leases (operator-initiated; dormant until tunnel listener is wired).
# Description: WireGuard VPN tunnel API (connect/status/disconnect).
#
# API Endpoints (if applicable):
# - POST /api/tunnel/request (Token Authenticated) - Allocates a reverse tunnel lease for the requested agent/protocol.
# - POST /api/tunnel/connect (Token Authenticated) - Issues VPN session material for an agent.
# - GET /api/tunnel/status (Token Authenticated) - Returns VPN status for an agent.
# - DELETE /api/tunnel/disconnect (Token Authenticated) - Tears down VPN session for an agent.
# ======================================================
"""Reverse tunnel negotiation API (Engine side)."""
"""WireGuard VPN tunnel API (Engine side)."""
from __future__ import annotations
import os
@@ -15,15 +17,13 @@ from typing import Any, Dict, Optional, Tuple
from flask import Blueprint, jsonify, request, session
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from ...WebSocket.Agent.reverse_tunnel_orchestrator import ReverseTunnelService
from ...VPN import VpnTunnelService
if False: # pragma: no cover - import cycle hint for type checkers
from .. import EngineServiceAdapters
def _current_user(app) -> Optional[Dict[str, str]]:
"""Resolve operator identity from session or signed token."""
username = session.get("username")
role = session.get("role") or "User"
if username:
@@ -58,18 +58,22 @@ def _require_login(app) -> Optional[Tuple[Dict[str, Any], int]]:
return None
def _get_tunnel_service(adapters: "EngineServiceAdapters") -> ReverseTunnelService:
service = getattr(adapters.context, "reverse_tunnel_service", None) or getattr(adapters, "_reverse_tunnel_service", None)
def _get_tunnel_service(adapters: "EngineServiceAdapters") -> VpnTunnelService:
service = getattr(adapters.context, "vpn_tunnel_service", None) or getattr(adapters, "_vpn_tunnel_service", None)
if service is None:
service = ReverseTunnelService(
adapters.context,
signer=getattr(adapters, "script_signer", None),
manager = getattr(adapters.context, "wireguard_server_manager", None)
if manager is None:
raise RuntimeError("wireguard_manager_unavailable")
service = VpnTunnelService(
context=adapters.context,
wireguard_manager=manager,
db_conn_factory=adapters.db_conn_factory,
socketio=getattr(adapters.context, "socketio", None),
service_log=adapters.service_log,
signer=getattr(adapters, "script_signer", None),
)
service.start()
setattr(adapters, "_reverse_tunnel_service", service)
setattr(adapters.context, "reverse_tunnel_service", service)
setattr(adapters, "_vpn_tunnel_service", service)
setattr(adapters.context, "vpn_tunnel_service", service)
return service
@@ -83,14 +87,11 @@ def _normalize_text(value: Any) -> str:
def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
"""Register reverse tunnel negotiation endpoints."""
blueprint = Blueprint("vpn_tunnel", __name__)
logger = adapters.context.logger.getChild("vpn_tunnel.api")
blueprint = Blueprint("reverse_tunnel", __name__)
service_log = adapters.service_log
logger = adapters.context.logger.getChild("tunnel.api")
@blueprint.route("/api/tunnel/request", methods=["POST"])
def request_tunnel():
@blueprint.route("/api/tunnel/connect", methods=["POST"])
def connect_tunnel():
requirement = _require_login(app)
if requirement:
payload, status = requirement
@@ -101,69 +102,67 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
body = request.get_json(silent=True) or {}
agent_id = _normalize_text(body.get("agent_id"))
protocol = _normalize_text(body.get("protocol") or "ps").lower() or "ps"
domain = _normalize_text(body.get("domain") or protocol).lower() or protocol
if protocol == "ps" and domain == "ps":
domain = "remote-interactive-shell"
if not agent_id:
return jsonify({"error": "agent_id_required"}), 400
tunnel_service = _get_tunnel_service(adapters)
try:
lease = tunnel_service.request_lease(
agent_id=agent_id,
protocol=protocol,
domain=domain,
operator_id=operator_id,
)
tunnel_service = _get_tunnel_service(adapters)
payload = tunnel_service.connect(agent_id=agent_id, operator_id=operator_id)
except RuntimeError as exc:
message = str(exc)
if message.startswith("domain_limit:"):
domain_name = message.split(":", 1)[-1] if ":" in message else domain
return jsonify({"error": "domain_limit", "domain": domain_name}), 409
if message == "port_pool_exhausted":
return jsonify({"error": "port_pool_exhausted"}), 503
logger.warning("tunnel lease request failed for agent_id=%s: %s", agent_id, message)
return jsonify({"error": "lease_allocation_failed"}), 500
logger.warning("vpn connect failed for agent_id=%s: %s", agent_id, exc)
return jsonify({"error": "connect_failed"}), 500
summary = tunnel_service.lease_summary(lease)
summary["fixed_port"] = tunnel_service.fixed_port
summary["heartbeat_seconds"] = tunnel_service.heartbeat_seconds
return jsonify(payload), 200
service_log(
"reverse_tunnel",
f"lease created tunnel_id={lease.tunnel_id} agent_id={lease.agent_id} domain={lease.domain} protocol={lease.protocol}",
)
return jsonify(summary), 200
@blueprint.route("/api/tunnel/<tunnel_id>", methods=["DELETE"])
def stop_tunnel(tunnel_id: str):
@blueprint.route("/api/tunnel/status", methods=["GET"])
def tunnel_status():
requirement = _require_login(app)
if requirement:
payload, status = requirement
return jsonify(payload), status
tunnel_id_norm = _normalize_text(tunnel_id)
if not tunnel_id_norm:
return jsonify({"error": "tunnel_id_required"}), 400
agent_id = _normalize_text(request.args.get("agent_id") or "")
if not agent_id:
return jsonify({"error": "agent_id_required"}), 400
tunnel_service = _get_tunnel_service(adapters)
payload = tunnel_service.status(agent_id)
if not payload:
return jsonify({"status": "down", "agent_id": agent_id}), 200
payload["status"] = "up"
bump = _normalize_text(request.args.get("bump") or "")
if bump:
tunnel_service.bump_activity(agent_id)
return jsonify(payload), 200
@blueprint.route("/api/tunnel/connect/status", methods=["GET"])
def tunnel_connect_status():
return tunnel_status()
@blueprint.route("/api/tunnel/disconnect", methods=["DELETE"])
def disconnect_tunnel():
requirement = _require_login(app)
if requirement:
payload, status = requirement
return jsonify(payload), status
body = request.get_json(silent=True) or {}
agent_id = _normalize_text(body.get("agent_id"))
tunnel_id = _normalize_text(body.get("tunnel_id"))
reason = _normalize_text(body.get("reason") or "operator_stop")
tunnel_service = _get_tunnel_service(adapters)
stopped = False
try:
stopped = tunnel_service.stop_tunnel(tunnel_id_norm, reason=reason)
except Exception as exc: # pragma: no cover - defensive guard
logger.debug("stop_tunnel failed tunnel_id=%s: %s", tunnel_id_norm, exc, exc_info=True)
if tunnel_id:
stopped = tunnel_service.disconnect_by_tunnel(tunnel_id, reason=reason)
elif agent_id:
stopped = tunnel_service.disconnect(agent_id, reason=reason)
else:
return jsonify({"error": "agent_id_required"}), 400
if not stopped:
return jsonify({"error": "not_found"}), 404
service_log(
"reverse_tunnel",
f"lease stopped tunnel_id={tunnel_id_norm} reason={reason or '-'}",
)
return jsonify({"status": "stopped", "tunnel_id": tunnel_id_norm}), 200
return jsonify({"status": "stopped", "reason": reason}), 200
app.register_blueprint(blueprint)

View File

@@ -8,4 +8,4 @@
"""VPN service helpers for the Engine runtime."""
from .wireguard_server import WireGuardServerConfig, WireGuardServerManager # noqa: F401
from .vpn_tunnel_service import VpnTunnelService # noqa: F401

View File

@@ -0,0 +1,473 @@
# ======================================================
# Data\Engine\services\VPN\vpn_tunnel_service.py
# Description: WireGuard tunnel orchestration (single tunnel per agent, token issuance, idle handling).
#
# API Endpoints (if applicable): None
# ======================================================
"""WireGuard tunnel orchestration helpers for the Engine runtime."""
from __future__ import annotations
import base64
import ipaddress
import json
import threading
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from .wireguard_server import WireGuardServerManager
@dataclass
class VpnSession:
tunnel_id: str
agent_id: str
virtual_ip: str
token: Dict[str, Any]
client_public_key: str
client_private_key: str
allowed_ports: Tuple[int, ...]
created_at: float
expires_at: float
last_activity: float
operator_ids: set[str] = field(default_factory=set)
firewall_rules: List[str] = field(default_factory=list)
activity_id: Optional[int] = None
hostname: Optional[str] = None
class VpnTunnelService:
def __init__(
self,
*,
context: Any,
wireguard_manager: WireGuardServerManager,
db_conn_factory,
socketio,
service_log,
signer: Optional[Any] = None,
idle_seconds: int = 900,
) -> None:
self.context = context
self.wg = wireguard_manager
self.db_conn_factory = db_conn_factory
self.socketio = socketio
self.service_log = service_log
self.signer = signer
self.logger = context.logger.getChild("vpn_tunnel")
self.activity_logger = self.wg.logger.getChild("device_activity")
self.idle_seconds = max(60, int(idle_seconds))
self._lock = threading.Lock()
self._sessions_by_agent: Dict[str, VpnSession] = {}
self._sessions_by_tunnel: Dict[str, VpnSession] = {}
self._engine_ip = ipaddress.ip_interface(context.wireguard_engine_virtual_ip)
self._peer_network = ipaddress.ip_network(context.wireguard_peer_network, strict=False)
self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True)
self._idle_thread.start()
def _idle_loop(self) -> None:
while True:
time.sleep(10)
now = time.time()
expired: List[VpnSession] = []
with self._lock:
for session in list(self._sessions_by_agent.values()):
if session.last_activity + self.idle_seconds <= now:
expired.append(session)
for session in expired:
self.disconnect(session.agent_id, reason="idle_timeout")
def _allocate_virtual_ip(self, agent_id: str) -> str:
existing = self._sessions_by_agent.get(agent_id)
if existing:
return existing.virtual_ip
used = {s.virtual_ip for s in self._sessions_by_agent.values()}
for host in self._peer_network.hosts():
if host == self._engine_ip.ip:
continue
candidate = f"{host}/32"
if candidate not in used:
return candidate
raise RuntimeError("vpn_ip_pool_exhausted")
def _load_allowed_ports(self, agent_id: str) -> Tuple[int, ...]:
default = tuple(self.context.wireguard_acl_allowlist_windows or ())
try:
conn = self.db_conn_factory()
cur = conn.cursor()
try:
cur.execute(
"SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
(agent_id,),
)
row = cur.fetchone()
os_name = str(row[0]).lower() if row and row[0] else ""
except Exception:
os_name = ""
if os_name and "windows" not in os_name:
baseline = {5900, 3478}
filtered = [p for p in default if p in baseline]
if filtered:
default = tuple(filtered)
cur.execute(
"SELECT allowed_ports FROM device_vpn_config WHERE agent_id=?",
(agent_id,),
)
row = cur.fetchone()
if not row:
return default
raw = row[0] or ""
ports = json.loads(raw) if raw else []
ports = [int(p) for p in ports if isinstance(p, (int, float, str))]
ports = [p for p in ports if 1 <= p <= 65535]
return tuple(dict.fromkeys(ports)) or default
except Exception:
return default
finally:
try:
conn.close()
except Exception:
pass
def _generate_client_keys(self) -> Tuple[str, str]:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x25519
key = x25519.X25519PrivateKey.generate()
priv = base64.b64encode(
key.private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption(),
)
).decode("ascii").strip()
pub = base64.b64encode(
key.public_key().public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
).decode("ascii").strip()
return priv, pub
def _issue_token(self, agent_id: str, tunnel_id: str, expires_at: float) -> Dict[str, Any]:
payload = {
"agent_id": agent_id,
"tunnel_id": tunnel_id,
"port": self.context.wireguard_port,
"expires_at": expires_at,
"issued_at": time.time(),
}
if not self.signer:
return dict(payload)
token = dict(payload)
try:
payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
signature = self.signer.sign(payload_bytes)
token["signature"] = base64.b64encode(signature).decode("ascii")
if hasattr(self.signer, "public_base64_spki"):
token["signing_key"] = self.signer.public_base64_spki()
token["sig_alg"] = "ed25519"
except Exception:
self.logger.debug("Failed to sign VPN orchestration token; sending unsigned.", exc_info=True)
return token
def _refresh_listener(self) -> None:
peers: List[Mapping[str, object]] = []
for session in self._sessions_by_agent.values():
peer = self.wg.build_peer_profile(
session.agent_id,
session.virtual_ip,
allowed_ports=session.allowed_ports,
)
peer = dict(peer)
peer["public_key"] = session.client_public_key
peers.append(peer)
if not peers:
self.wg.stop_listener()
return
self.wg.start_listener(peers)
def connect(self, *, agent_id: str, operator_id: Optional[str]) -> Mapping[str, Any]:
now = time.time()
with self._lock:
existing = self._sessions_by_agent.get(agent_id)
if existing:
if operator_id:
existing.operator_ids.add(operator_id)
existing.last_activity = now
return self._session_payload(existing)
tunnel_id = uuid.uuid4().hex
virtual_ip = self._allocate_virtual_ip(agent_id)
allowed_ports = self._load_allowed_ports(agent_id)
client_private, client_public = self._generate_client_keys()
token = self._issue_token(agent_id, tunnel_id, now + 300)
self.wg.require_orchestration_token(token)
session = VpnSession(
tunnel_id=tunnel_id,
agent_id=agent_id,
virtual_ip=virtual_ip,
token=token,
client_public_key=client_public,
client_private_key=client_private,
allowed_ports=allowed_ports,
created_at=now,
expires_at=now + 300,
last_activity=now,
)
if operator_id:
session.operator_ids.add(operator_id)
self._sessions_by_agent[agent_id] = session
self._sessions_by_tunnel[tunnel_id] = session
try:
self._refresh_listener()
peer = self.wg.build_peer_profile(
agent_id,
virtual_ip,
allowed_ports=allowed_ports,
)
rule_names = self.wg.apply_firewall_rules(peer)
session.firewall_rules = rule_names
except Exception:
with self._lock:
self._sessions_by_agent.pop(agent_id, None)
self._sessions_by_tunnel.pop(tunnel_id, None)
try:
self._refresh_listener()
except Exception:
self.logger.debug("Failed to refresh WireGuard listener after connect rollback.", exc_info=True)
raise
payload = self._session_payload(session)
self._emit_start(payload)
self._log_device_activity(session, event="start")
return payload
def status(self, agent_id: str) -> Optional[Mapping[str, Any]]:
with self._lock:
session = self._sessions_by_agent.get(agent_id)
if not session:
return None
return self._session_payload(session, include_token=False)
def bump_activity(self, agent_id: str) -> None:
with self._lock:
session = self._sessions_by_agent.get(agent_id)
if not session:
return
session.last_activity = time.time()
try:
if self.socketio:
self.socketio.emit("vpn_tunnel_activity", {"agent_id": agent_id}, namespace="/")
except Exception:
self.logger.debug("vpn_tunnel_activity emit failed for agent_id=%s", agent_id, exc_info=True)
def disconnect(self, agent_id: str, reason: str = "operator_stop") -> bool:
with self._lock:
session = self._sessions_by_agent.pop(agent_id, None)
if not session:
return False
self._sessions_by_tunnel.pop(session.tunnel_id, None)
try:
self.wg.remove_firewall_rules(session.firewall_rules)
except Exception:
self.logger.debug("Failed to remove firewall rules for agent=%s", agent_id, exc_info=True)
self._refresh_listener()
self._emit_stop(session, reason)
self._log_device_activity(session, event="stop", reason=reason)
return True
def disconnect_by_tunnel(self, tunnel_id: str, reason: str = "operator_stop") -> bool:
with self._lock:
session = self._sessions_by_tunnel.get(tunnel_id)
if not session:
return False
return self.disconnect(session.agent_id, reason=reason)
def _emit_start(self, payload: Mapping[str, Any]) -> None:
if not self.socketio:
return
try:
self.socketio.emit("vpn_tunnel_start", payload, namespace="/")
except Exception:
self.logger.debug("vpn_tunnel_start emit failed", exc_info=True)
def _emit_stop(self, session: VpnSession, reason: str) -> None:
if not self.socketio:
return
try:
self.socketio.emit(
"vpn_tunnel_stop",
{"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason},
namespace="/",
)
except Exception:
self.logger.debug("vpn_tunnel_stop emit failed", exc_info=True)
def _log_device_activity(self, session: VpnSession, *, event: str, reason: Optional[str] = None) -> None:
if self.db_conn_factory is None:
self.activity_logger.info(
"device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s",
event,
session.agent_id,
session.tunnel_id,
",".join(sorted(filter(None, session.operator_ids))) or "-",
reason or "-",
)
return
conn = None
try:
conn = self.db_conn_factory()
cur = conn.cursor()
hostname = session.hostname
if not hostname:
try:
cur.execute(
"SELECT hostname FROM devices WHERE agent_id = ? ORDER BY last_seen DESC LIMIT 1",
(session.agent_id,),
)
row = cur.fetchone()
if row and row[0]:
hostname = str(row[0]).strip()
session.hostname = hostname
except Exception:
hostname = None
if not hostname:
self.activity_logger.info(
"device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s hostname=unknown",
event,
session.agent_id,
session.tunnel_id,
",".join(sorted(filter(None, session.operator_ids))) or "-",
reason or "-",
)
return
now_ts = int(time.time())
script_name = "Reverse VPN Tunnel (WireGuard)"
if event == "start":
cur.execute(
"""
INSERT INTO activity_history(hostname, script_path, script_name, script_type, ran_at, status, stdout, stderr)
VALUES(?,?,?,?,?,?,?,?)
""",
(
hostname,
session.tunnel_id,
script_name,
"vpn_tunnel",
now_ts,
"Running",
"",
"",
),
)
session.activity_id = cur.lastrowid
conn.commit()
if self.socketio:
try:
self.socketio.emit(
"device_activity_changed",
{
"hostname": hostname,
"activity_id": session.activity_id,
"change": "created",
"source": "vpn_tunnel",
},
)
except Exception:
pass
self.activity_logger.info(
"device_activity_start hostname=%s agent_id=%s tunnel_id=%s operator=%s activity_id=%s",
hostname,
session.agent_id,
session.tunnel_id,
",".join(sorted(filter(None, session.operator_ids))) or "-",
session.activity_id or "-",
)
return
if session.activity_id:
status = "Completed" if event == "stop" else "Closed"
cur.execute(
"""
UPDATE activity_history
SET status=?,
stderr=COALESCE(stderr, '') || ?
WHERE id=?
""",
(
status,
f"\nreason: {reason}" if reason else "",
session.activity_id,
),
)
conn.commit()
if self.socketio:
try:
self.socketio.emit(
"device_activity_changed",
{
"hostname": hostname,
"activity_id": session.activity_id,
"change": "updated",
"source": "vpn_tunnel",
},
)
except Exception:
pass
self.activity_logger.info(
"device_activity event=%s hostname=%s agent_id=%s tunnel_id=%s operator=%s reason=%s activity_id=%s",
event,
hostname,
session.agent_id,
session.tunnel_id,
",".join(sorted(filter(None, session.operator_ids))) or "-",
reason or "-",
session.activity_id or "-",
)
except Exception:
self.activity_logger.debug(
"device_activity logging failed for tunnel_id=%s",
session.tunnel_id,
exc_info=True,
)
finally:
if conn is not None:
try:
conn.close()
except Exception:
pass
def _session_payload(self, session: VpnSession, *, include_token: bool = True) -> Mapping[str, Any]:
payload: Dict[str, Any] = {
"tunnel_id": session.tunnel_id,
"agent_id": session.agent_id,
"virtual_ip": session.virtual_ip,
"engine_virtual_ip": str(self._engine_ip.ip),
"allowed_ips": f"{self._engine_ip.ip}/32",
"endpoint": f"{self._engine_ip.ip}:{self.context.wireguard_port}",
"server_public_key": self.wg.server_public_key,
"client_public_key": session.client_public_key,
"client_private_key": session.client_private_key,
"idle_seconds": self.idle_seconds,
"allowed_ports": list(session.allowed_ports),
"connected_operators": len([o for o in session.operator_ids if o]),
}
if include_token:
payload["token"] = session.token
return payload

View File

@@ -70,7 +70,7 @@ class WireGuardServerManager:
self.logger = _build_logger(config.log_path)
self._ensure_cert_dir()
self.server_private_key, self.server_public_key = self._ensure_server_keys()
self._service_name = "BorealisWireGuard"
self._service_name = "borealis-wg"
self._temp_dir = Path(tempfile.gettempdir()) / "borealis-wg-engine"
def _ensure_cert_dir(self) -> None:
@@ -157,7 +157,7 @@ class WireGuardServerManager:
if not token:
raise ValueError("Missing orchestration token for WireGuard peer")
required_fields = ("agent_id", "tunnel_id", "expires_at")
required_fields = ("agent_id", "tunnel_id", "expires_at", "port")
missing = [field for field in required_fields if field not in token or token[field] in (None, "")]
if missing:
raise ValueError(f"Invalid orchestration token; missing {', '.join(missing)}")
@@ -167,6 +167,13 @@ class WireGuardServerManager:
except Exception:
raise ValueError("Invalid orchestration token expiry")
try:
port = int(token["port"])
except Exception:
raise ValueError("Invalid orchestration token port")
if port != int(self.config.port):
raise ValueError("Orchestration token port mismatch")
now = time.time()
if expires_at <= now:
raise ValueError("Orchestration token expired")
@@ -253,12 +260,14 @@ class WireGuardServerManager:
"host_only": True,
}
def apply_firewall_rules(self, peer: Mapping[str, object]) -> None:
def apply_firewall_rules(self, peer: Mapping[str, object]) -> List[str]:
"""Apply outbound firewall allow rules for the agent's virtual IP/ports (Windows netsh)."""
rules = self.build_firewall_rules(peer)
rule_names: List[str] = []
for idx, rule in enumerate(rules):
name = f"Borealis-WG-Agent-{peer.get('agent_id','')}-{idx}"
protocol = str(rule.get("protocol") or "TCP").upper()
args = [
"netsh",
"advfirewall",
@@ -269,7 +278,7 @@ class WireGuardServerManager:
"dir=out",
"action=allow",
f"remoteip={rule.get('remote_address','')}",
f"protocol=TCP",
f"protocol={protocol}",
f"localport={rule.get('local_port','')}",
]
code, out, err = self._run_command(args)
@@ -277,6 +286,19 @@ class WireGuardServerManager:
self.logger.warning("Failed to apply firewall rule %s code=%s err=%s", name, code, err)
else:
self.logger.info("Applied firewall rule %s", name)
rule_names.append(name)
return rule_names
def remove_firewall_rules(self, rule_names: Sequence[str]) -> None:
for name in rule_names:
if not name:
continue
args = ["netsh", "advfirewall", "firewall", "delete", "rule", f"name={name}"]
code, out, err = self._run_command(args)
if code != 0:
self.logger.warning("Failed to remove firewall rule %s code=%s err=%s", name, code, err)
else:
self.logger.info("Removed firewall rule %s", name)
def start_listener(self, peers: Sequence[Mapping[str, object]]) -> None:
"""Render a temporary WireGuard config and start the service."""
@@ -291,6 +313,9 @@ class WireGuardServerManager:
config_path.write_text(rendered, encoding="utf-8")
self.logger.info("Rendered WireGuard config to %s", config_path)
# Ensure old service is removed before re-installing.
self.stop_listener()
args = ["wireguard.exe", "/installtunnelservice", str(config_path)]
code, out, err = self._run_command(args)
if code != 0:
@@ -301,7 +326,7 @@ class WireGuardServerManager:
def stop_listener(self) -> None:
"""Stop and remove the WireGuard tunnel service."""
args = ["wireguard.exe", "/uninstalltunnelservice", "borealis-wg"]
args = ["wireguard.exe", "/uninstalltunnelservice", self._service_name]
code, out, err = self._run_command(args)
if code != 0:
self.logger.warning("Failed to uninstall WireGuard tunnel service code=%s err=%s", code, err)
@@ -323,15 +348,17 @@ class WireGuardServerManager:
port_list = []
for port in port_list:
rules.append(
{
"direction": "outbound",
"remote_address": ip,
"local_port": port,
"action": "allow",
"description": f"WireGuard engine->agent allow port {port}",
}
)
for protocol in ("TCP", "UDP"):
rules.append(
{
"direction": "outbound",
"remote_address": ip,
"local_port": port,
"protocol": protocol,
"action": "allow",
"description": f"WireGuard engine->agent allow port {port}/{protocol}",
}
)
self.logger.info(
"Prepared firewall rule plan for agent=%s rules=%s",

View File

@@ -1,3 +0,0 @@
"""Namespace package for reverse tunnel domain handlers (Engine side)."""
__all__ = ["remote_interactive_shell", "remote_management", "remote_video"]

View File

@@ -1,78 +0,0 @@
"""Placeholder Bash channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class BashChannelServer:
"""Stub Bash handler until the agent-side channel is implemented."""
protocol_name = "bash"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"bash.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
# No-op placeholder; output collection for future Bash support.
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"bash channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
# Placeholder: no agent-side Bash yet.
self.logger.info("bash placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
# Placeholder: not implemented.
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["BashChannelServer"]

View File

@@ -1,139 +0,0 @@
"""Engine-side PowerShell tunnel channel helper (remote interactive shell domain)."""
from __future__ import annotations
import json
from collections import deque
from typing import Any, Deque, Dict, List, Optional
# Mirror framing constants to avoid circular imports.
MSG_CHANNEL_OPEN = 0x03
MSG_CHANNEL_ACK = 0x04
MSG_DATA = 0x05
MSG_CONTROL = 0x09
MSG_CLOSE = 0x08
CLOSE_OK = 0
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AGENT_SHUTDOWN = 6
class PowershellChannelServer:
"""Coordinate PowerShell channel frames over a TunnelBridge."""
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"ps.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output: Deque[str] = deque()
self._close_reason: Optional[str] = None
self._close_code: Optional[int] = None
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
# ------------------------------------------------------------------ Agent frame handling
def handle_agent_frame(self, frame) -> None:
if frame.channel_id != self.channel_id:
return
if frame.msg_type == MSG_CHANNEL_ACK:
self._ack_received = True
self.logger.info("ps channel acked tunnel_id=%s", self.bridge.lease.tunnel_id)
return
if frame.msg_type == MSG_DATA:
try:
text = frame.payload.decode("utf-8", errors="replace")
except Exception:
text = ""
if text:
self._append_output(text)
return
if frame.msg_type == MSG_CLOSE:
try:
payload = json.loads(frame.payload.decode("utf-8"))
except Exception:
payload = {}
self._closed = True
self._close_code = payload.get("code") if isinstance(payload, dict) else None
self._close_reason = payload.get("reason") if isinstance(payload, dict) else None
self.logger.info(
"ps channel closed tunnel_id=%s code=%s reason=%s",
self.bridge.lease.tunnel_id,
self._close_code,
self._close_reason or "-",
)
# ------------------------------------------------------------------ Operator actions
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
if self._open_sent:
return
payload = json.dumps(
{"protocol": "ps", "metadata": {"cols": cols, "rows": rows}},
separators=(",", ":"),
).encode("utf-8")
frame = self._frame_cls(msg_type=MSG_CHANNEL_OPEN, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
self._open_sent = True
self.logger.info(
"ps channel open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
if self._closed:
return
payload = data.encode("utf-8", errors="replace")
frame = self._frame_cls(msg_type=MSG_DATA, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
def send_resize(self, cols: int, rows: int) -> None:
if self._closed:
return
payload = json.dumps({"cols": cols, "rows": rows}, separators=(",", ":")).encode("utf-8")
frame = self._frame_cls(msg_type=MSG_CONTROL, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
def close(self, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
frame = self._close_frame_fn(self.channel_id, code, reason)
else:
frame = self._frame_cls(
msg_type=MSG_CLOSE,
channel_id=self.channel_id,
payload=json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8"),
)
self.bridge.operator_to_agent(frame)
# ------------------------------------------------------------------ Output polling
def drain_output(self) -> List[str]:
items: List[str] = []
while self._output:
items.append(self._output.popleft())
return items
def _append_output(self, text: str) -> None:
self._output.append(text)
# Cap buffer to avoid unbounded memory growth.
while len(self._output) > 500:
self._output.popleft()
# ------------------------------------------------------------------ Status helpers
def status(self) -> Dict[str, Any]:
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": self._close_reason,
"close_code": self._close_code,
}
__all__ = ["PowershellChannelServer"]

View File

@@ -1,6 +0,0 @@
"""Protocol handlers for remote interactive shell tunnels (Engine side)."""
from .Powershell import PowershellChannelServer
from .Bash import BashChannelServer
__all__ = ["PowershellChannelServer", "BashChannelServer"]

View File

@@ -1,3 +0,0 @@
"""Domain handlers for remote interactive shells (PowerShell/Bash)."""
__all__ = ["Protocols"]

View File

@@ -1,73 +0,0 @@
"""Placeholder SSH channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class SSHChannelServer:
protocol_name = "ssh"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"ssh.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"ssh channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("ssh placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["SSHChannelServer"]

View File

@@ -1,73 +0,0 @@
"""Placeholder WinRM channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class WinRMChannelServer:
protocol_name = "winrm"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"winrm.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"winrm channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("winrm placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["WinRMChannelServer"]

View File

@@ -1,6 +0,0 @@
"""Protocol handlers for remote management tunnels (Engine side)."""
from .SSH import SSHChannelServer
from .WinRM import WinRMChannelServer
__all__ = ["SSHChannelServer", "WinRMChannelServer"]

View File

@@ -1,3 +0,0 @@
"""Domain handlers for remote management tunnels (SSH/WinRM)."""
__all__ = ["Protocols"]

View File

@@ -1,73 +0,0 @@
"""Placeholder RDP channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class RDPChannelServer:
protocol_name = "rdp"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"rdp.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"rdp channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("rdp placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["RDPChannelServer"]

View File

@@ -1,73 +0,0 @@
"""Placeholder VNC channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class VNCChannelServer:
protocol_name = "vnc"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"vnc.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"vnc channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("vnc placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["VNCChannelServer"]

View File

@@ -1,73 +0,0 @@
"""Placeholder WebRTC channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class WebRTCChannelServer:
protocol_name = "webrtc"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"webrtc.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"webrtc channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("webrtc placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["WebRTCChannelServer"]

View File

@@ -1,7 +0,0 @@
"""Protocol handlers for remote video tunnels (Engine side)."""
from .WebRTC import WebRTCChannelServer
from .RDP import RDPChannelServer
from .VNC import VNCChannelServer
__all__ = ["WebRTCChannelServer", "RDPChannelServer", "VNCChannelServer"]

View File

@@ -1,3 +0,0 @@
"""Domain handlers for remote video/desktop tunnels (RDP/VNC/WebRTC)."""
__all__ = ["Protocols"]

View File

@@ -1,10 +0,0 @@
# ======================================================
# Data\Engine\services\WebSocket\Agent\__init__.py
# Description: Package marker for Agent-facing WebSocket services (reverse tunnel scaffolding).
#
# API Endpoints (if applicable): None
# ======================================================
"""Agent-facing WebSocket services for the Engine runtime."""
__all__ = []

View File

@@ -1,6 +1,6 @@
# ======================================================
# Data\Engine\services\WebSocket\__init__.py
# Description: Socket.IO handlers for Engine runtime quick job updates and realtime notifications.
# Description: Socket.IO handlers for Engine runtime quick job updates and VPN shell bridging.
#
# API Endpoints (if applicable): None
# ======================================================
@@ -8,24 +8,20 @@
"""WebSocket service registration for the Borealis Engine runtime."""
from __future__ import annotations
import base64
import sqlite3
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Optional
from flask import session, request
from flask import request
from flask_socketio import SocketIO
from ...database import initialise_engine_database
from ...security import signing
from ...server import EngineContext
from .Agent.reverse_tunnel_orchestrator import (
ReverseTunnelService,
TunnelBridge,
decode_frame,
TunnelFrame,
)
from ..VPN import VpnTunnelService
from .vpn_shell import VpnShellBridge
def _now_ts() -> int:
@@ -70,20 +66,31 @@ class EngineRealtimeAdapters:
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
"""Register Socket.IO event handlers for the Engine runtime."""
from ..API import _make_db_conn_factory, _make_service_logger # Local import to avoid circular import at module load
adapters = EngineRealtimeAdapters(context)
logger = context.logger.getChild("realtime.quick_jobs")
tunnel_service = getattr(context, "reverse_tunnel_service", None)
if tunnel_service is None:
tunnel_service = ReverseTunnelService(
context,
signer=None,
shell_bridge = VpnShellBridge(socket_server, context)
def _get_tunnel_service() -> Optional[VpnTunnelService]:
service = getattr(context, "vpn_tunnel_service", None)
if service is not None:
return service
manager = getattr(context, "wireguard_server_manager", None)
if manager is None:
return None
try:
signer = signing.load_signer()
except Exception:
signer = None
service = VpnTunnelService(
context=context,
wireguard_manager=manager,
db_conn_factory=adapters.db_conn_factory,
socketio=socket_server,
service_log=adapters.service_log,
signer=signer,
)
tunnel_service.start()
setattr(context, "reverse_tunnel_service", tunnel_service)
setattr(context, "vpn_tunnel_service", service)
return service
@socket_server.on("quick_job_result")
def _handle_quick_job_result(data: Any) -> None:
@@ -246,252 +253,45 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
exc,
)
@socket_server.on("tunnel_bridge_attach")
def _tunnel_bridge_attach(data: Any) -> Any:
"""Placeholder operator bridge attach handler (no data channel yet)."""
@socket_server.on("vpn_shell_open")
def _vpn_shell_open(data: Any) -> Dict[str, Any]:
agent_id = ""
if isinstance(data, dict):
agent_id = str(data.get("agent_id") or "").strip()
elif isinstance(data, str):
agent_id = data.strip()
if not agent_id:
return {"error": "agent_id_required"}
if not isinstance(data, dict):
return {"error": "invalid_payload"}
service = _get_tunnel_service()
if service is None:
return {"error": "vpn_service_unavailable"}
if not service.status(agent_id):
return {"error": "tunnel_down"}
tunnel_id = str(data.get("tunnel_id") or "").strip()
operator_id = str(data.get("operator_id") or "").strip() or None
if not tunnel_id:
return {"error": "tunnel_id_required"}
session = shell_bridge.open_session(request.sid, agent_id)
if session is None:
return {"error": "shell_connect_failed"}
service.bump_activity(agent_id)
return {"status": "ok"}
try:
tunnel_service.operator_attach(tunnel_id, operator_id)
except ValueError as exc:
return {"error": str(exc)}
except Exception as exc: # pragma: no cover - defensive guard
logger.debug("tunnel_bridge_attach failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "bridge_attach_failed"}
return {"status": "ok", "tunnel_id": tunnel_id, "operator_id": operator_id or "-"}
def _encode_frame(frame: TunnelFrame) -> str:
return base64.b64encode(frame.encode()).decode("ascii")
def _decode_frame_payload(raw: Any) -> TunnelFrame:
if isinstance(raw, str):
try:
raw_bytes = base64.b64decode(raw)
except Exception:
raise ValueError("invalid_frame")
elif isinstance(raw, (bytes, bytearray)):
raw_bytes = bytes(raw)
@socket_server.on("vpn_shell_send")
def _vpn_shell_send(data: Any) -> Dict[str, Any]:
payload = None
if isinstance(data, dict):
payload = data.get("data")
else:
raise ValueError("invalid_frame")
return decode_frame(raw_bytes)
@socket_server.on("tunnel_operator_send")
def _tunnel_operator_send(data: Any) -> Any:
"""Operator -> agent frame enqueue (placeholder queue)."""
if not isinstance(data, dict):
return {"error": "invalid_payload"}
tunnel_id = str(data.get("tunnel_id") or "").strip()
frame_raw = data.get("frame")
if not tunnel_id or frame_raw is None:
return {"error": "tunnel_id_and_frame_required"}
try:
frame = _decode_frame_payload(frame_raw)
except Exception as exc:
return {"error": str(exc)}
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
bridge.operator_to_agent(frame)
return {"status": "ok"}
@socket_server.on("tunnel_operator_poll")
def _tunnel_operator_poll(data: Any) -> Any:
"""Operator polls queued frames from agent."""
tunnel_id = ""
if isinstance(data, dict):
tunnel_id = str(data.get("tunnel_id") or "").strip()
if not tunnel_id:
return {"error": "tunnel_id_required"}
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
frames = []
while True:
frame = bridge.next_for_operator()
if frame is None:
break
frames.append(_encode_frame(frame))
return {"frames": frames}
# WebUI operator bridge namespace for browser clients
tunnel_namespace = "/tunnel"
_operator_sessions: Dict[str, str] = {}
def _current_operator() -> Optional[str]:
username = session.get("username")
if username:
return str(username)
auth_header = (request.headers.get("Authorization") or "").strip()
token = None
if auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1].strip()
if not token:
token = request.cookies.get("borealis_auth")
return token or None
@socket_server.on("join", namespace=tunnel_namespace)
def _ws_tunnel_join(data: Any) -> Any:
if not isinstance(data, dict):
return {"error": "invalid_payload"}
operator_id = _current_operator()
if not operator_id:
return {"error": "unauthorized"}
tunnel_id = str(data.get("tunnel_id") or "").strip()
if not tunnel_id:
return {"error": "tunnel_id_required"}
bridge = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
try:
tunnel_service.operator_attach(tunnel_id, operator_id)
except Exception as exc:
logger.debug("ws_tunnel_join failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "attach_failed"}
sid = request.sid
_operator_sessions[sid] = tunnel_id
return {"status": "ok", "tunnel_id": tunnel_id}
@socket_server.on("send", namespace=tunnel_namespace)
def _ws_tunnel_send(data: Any) -> Any:
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return {"error": "not_joined"}
if not isinstance(data, dict):
return {"error": "invalid_payload"}
frame_raw = data.get("frame")
if frame_raw is None:
return {"error": "frame_required"}
try:
frame = _decode_frame_payload(frame_raw)
except Exception:
return {"error": "invalid_frame"}
bridge = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
bridge.operator_to_agent(frame)
return {"status": "ok"}
@socket_server.on("poll", namespace=tunnel_namespace)
def _ws_tunnel_poll() -> Any:
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return {"error": "not_joined"}
bridge = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
frames = []
while True:
frame = bridge.next_for_operator()
if frame is None:
break
frames.append(_encode_frame(frame))
return {"frames": frames}
def _require_ps_server():
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return None, None, {"error": "not_joined"}
server = tunnel_service.ensure_protocol_server(tunnel_id)
if server is None or not hasattr(server, "open_channel"):
return None, tunnel_id, {"error": "ps_unsupported"}
return server, tunnel_id, None
@socket_server.on("ps_open", namespace=tunnel_namespace)
def _ws_ps_open(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
cols = 120
rows = 32
if isinstance(data, dict):
try:
cols = int(data.get("cols", cols))
rows = int(data.get("rows", rows))
except Exception:
pass
cols = max(20, min(cols, 300))
rows = max(10, min(rows, 200))
try:
server.open_channel(cols=cols, rows=rows)
except Exception as exc:
logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_open_failed"}
return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows}
@socket_server.on("ps_send", namespace=tunnel_namespace)
def _ws_ps_send(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
if data is None:
payload = data
if payload is None:
return {"error": "payload_required"}
text = data
if isinstance(data, dict):
text = data.get("data")
if text is None:
return {"error": "payload_required"}
try:
server.send_input(str(text))
except Exception as exc:
logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_send_failed"}
shell_bridge.send(request.sid, str(payload))
return {"status": "ok"}
@socket_server.on("ps_resize", namespace=tunnel_namespace)
def _ws_ps_resize(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
cols = None
rows = None
if isinstance(data, dict):
cols = data.get("cols")
rows = data.get("rows")
try:
cols_int = int(cols) if cols is not None else 120
rows_int = int(rows) if rows is not None else 32
cols_int = max(20, min(cols_int, 300))
rows_int = max(10, min(rows_int, 200))
server.send_resize(cols_int, rows_int)
return {"status": "ok", "cols": cols_int, "rows": rows_int}
except Exception as exc:
logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_resize_failed"}
@socket_server.on("vpn_shell_close")
def _vpn_shell_close() -> Dict[str, Any]:
shell_bridge.close(request.sid)
return {"status": "ok"}
@socket_server.on("ps_poll", namespace=tunnel_namespace)
def _ws_ps_poll(data: Any = None) -> Any: # data is ignored; socketio passes it even when unused
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
try:
output = server.drain_output()
status = server.status()
return {"output": output, "status": status}
except Exception as exc:
logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_poll_failed"}
@socket_server.on("disconnect", namespace=tunnel_namespace)
def _ws_tunnel_disconnect():
sid = request.sid
tunnel_id = _operator_sessions.pop(sid, None)
if tunnel_id and tunnel_id not in _operator_sessions.values():
try:
tunnel_service.stop_tunnel(tunnel_id, reason="operator_socket_disconnect")
except Exception as exc:
logger.debug("ws_tunnel_disconnect stop_tunnel failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
@socket_server.on("disconnect")
def _ws_disconnect() -> None:
shell_bridge.close(request.sid)

View File

@@ -0,0 +1,127 @@
# ======================================================
# Data\Engine\services\WebSocket\vpn_shell.py
# Description: Socket.IO handlers bridging UI shell to agent TCP server over WireGuard.
#
# API Endpoints (if applicable): None
# ======================================================
"""WireGuard VPN PowerShell bridge (Engine side)."""
from __future__ import annotations
import base64
import json
import socket
import threading
from dataclasses import dataclass
from typing import Any, Dict, Optional
def _b64encode(data: bytes) -> str:
return base64.b64encode(data).decode("ascii").strip()
def _b64decode(value: str) -> bytes:
return base64.b64decode(value.encode("ascii"))
@dataclass
class ShellSession:
sid: str
agent_id: str
socketio: Any
tcp: socket.socket
_reader: Optional[threading.Thread] = None
def start_reader(self) -> None:
t = threading.Thread(target=self._read_loop, daemon=True)
t.start()
self._reader = t
def _read_loop(self) -> None:
buffer = b""
try:
while True:
data = self.tcp.recv(4096)
if not data:
break
buffer += data
while b"\n" in buffer:
line, buffer = buffer.split(b"\n", 1)
if not line:
continue
try:
msg = json.loads(line.decode("utf-8"))
except Exception:
continue
if msg.get("type") == "stdout":
payload = msg.get("data") or ""
try:
decoded = _b64decode(str(payload)).decode("utf-8", errors="replace")
except Exception:
decoded = ""
self.socketio.emit("vpn_shell_output", {"data": decoded}, to=self.sid)
finally:
self.socketio.emit("vpn_shell_closed", {"agent_id": self.agent_id}, to=self.sid)
try:
self.tcp.close()
except Exception:
pass
def send(self, payload: str) -> None:
data = json.dumps({"type": "stdin", "data": _b64encode(payload.encode("utf-8"))})
self.tcp.sendall(data.encode("utf-8") + b"\n")
def close(self) -> None:
try:
data = json.dumps({"type": "close"})
self.tcp.sendall(data.encode("utf-8") + b"\n")
except Exception:
pass
try:
self.tcp.close()
except Exception:
pass
class VpnShellBridge:
def __init__(self, socketio, context) -> None:
self.socketio = socketio
self.context = context
self._sessions: Dict[str, ShellSession] = {}
self.logger = context.logger.getChild("vpn_shell")
def open_session(self, sid: str, agent_id: str) -> Optional[ShellSession]:
service = getattr(self.context, "vpn_tunnel_service", None)
if service is None:
return None
status = service.status(agent_id)
if not status:
return None
host = str(status.get("virtual_ip") or "").split("/")[0]
port = int(self.context.wireguard_shell_port)
try:
tcp = socket.create_connection((host, port), timeout=5)
except Exception:
self.logger.debug("Failed to connect vpn shell to %s:%s", host, port, exc_info=True)
return None
session = ShellSession(sid=sid, agent_id=agent_id, socketio=self.socketio, tcp=tcp)
self._sessions[sid] = session
session.start_reader()
return session
def send(self, sid: str, payload: str) -> None:
session = self._sessions.get(sid)
if not session:
return
session.send(payload)
service = getattr(self.context, "vpn_tunnel_service", None)
if service:
service.bump_activity(session.agent_id)
def close(self, sid: str) -> None:
session = self._sessions.pop(sid, None)
if not session:
return
session.close()

View File

@@ -8,13 +8,17 @@ import {
Tab,
Typography,
Button,
Switch,
Chip,
Divider,
Menu,
MenuItem,
TextField,
Dialog,
DialogTitle,
DialogContent,
DialogActions
DialogActions,
LinearProgress
} from "@mui/material";
import InfoOutlinedIcon from "@mui/icons-material/InfoOutlined";
import StorageRoundedIcon from "@mui/icons-material/StorageRounded";
@@ -23,6 +27,7 @@ import LanRoundedIcon from "@mui/icons-material/LanRounded";
import AppsRoundedIcon from "@mui/icons-material/AppsRounded";
import ListAltRoundedIcon from "@mui/icons-material/ListAltRounded";
import TerminalRoundedIcon from "@mui/icons-material/TerminalRounded";
import TuneRoundedIcon from "@mui/icons-material/TuneRounded";
import SpeedRoundedIcon from "@mui/icons-material/SpeedRounded";
import DeveloperBoardRoundedIcon from "@mui/icons-material/DeveloperBoardRounded";
import MoreHorizIcon from "@mui/icons-material/MoreHoriz";
@@ -69,14 +74,51 @@ const SECTION_HEIGHTS = {
network: 260,
};
const buildVpnGroups = (shellPort) => {
const normalizedShell = Number(shellPort) || 47001;
return [
{
key: "shell",
label: "Borealis PowerShell",
description: "Web terminal access over the VPN tunnel.",
ports: [normalizedShell],
},
{
key: "rdp",
label: "RDP",
description: "Remote Desktop (TCP 3389).",
ports: [3389],
},
{
key: "winrm",
label: "WinRM",
description: "PowerShell/WinRM management (TCP 5985/5986).",
ports: [5985, 5986],
},
{
key: "vnc",
label: "VNC",
description: "Remote desktop streaming (TCP 5900).",
ports: [5900],
},
{
key: "webrtc",
label: "WebRTC",
description: "Real-time comms (UDP 3478).",
ports: [3478],
},
];
};
const TOP_TABS = [
{ label: "Device Summary", icon: InfoOutlinedIcon },
{ label: "Storage", icon: StorageRoundedIcon },
{ label: "Memory", icon: MemoryRoundedIcon },
{ label: "Network", icon: LanRoundedIcon },
{ label: "Installed Software", icon: AppsRoundedIcon },
{ label: "Activity History", icon: ListAltRoundedIcon },
{ label: "Remote Shell", icon: TerminalRoundedIcon },
{ key: "summary", label: "Device Summary", icon: InfoOutlinedIcon },
{ key: "storage", label: "Storage", icon: StorageRoundedIcon },
{ key: "memory", label: "Memory", icon: MemoryRoundedIcon },
{ key: "network", label: "Network", icon: LanRoundedIcon },
{ key: "software", label: "Installed Software", icon: AppsRoundedIcon },
{ key: "activity", label: "Activity History", icon: ListAltRoundedIcon },
{ key: "advanced", label: "Advanced Config", icon: TuneRoundedIcon },
{ key: "shell", label: "Remote Shell", icon: TerminalRoundedIcon },
];
const myTheme = themeQuartz.withParams({
@@ -286,6 +328,15 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
const [menuAnchor, setMenuAnchor] = useState(null);
const [clearDialogOpen, setClearDialogOpen] = useState(false);
const [assemblyNameMap, setAssemblyNameMap] = useState({});
const [vpnLoading, setVpnLoading] = useState(false);
const [vpnSaving, setVpnSaving] = useState(false);
const [vpnError, setVpnError] = useState("");
const [vpnSource, setVpnSource] = useState("default");
const [vpnToggles, setVpnToggles] = useState({});
const [vpnCustomPorts, setVpnCustomPorts] = useState([]);
const [vpnDefaultPorts, setVpnDefaultPorts] = useState([]);
const [vpnShellPort, setVpnShellPort] = useState(47001);
const [vpnLoadedFor, setVpnLoadedFor] = useState("");
// Snapshotted status for the lifetime of this page
const [lockedStatus, setLockedStatus] = useState(() => {
// Prefer status provided by the device list row if available
@@ -655,6 +706,104 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
};
}, [activityHostname, loadHistory]);
const applyVpnPorts = useCallback((ports, defaults, shellPort, source) => {
const normalized = Array.isArray(ports) ? ports : [];
const normalizedDefaults = Array.isArray(defaults) ? defaults : [];
const numericPorts = normalized
.map((p) => Number(p))
.filter((p) => Number.isFinite(p) && p > 0);
const numericDefaults = normalizedDefaults
.map((p) => Number(p))
.filter((p) => Number.isFinite(p) && p > 0);
const effectiveShell = Number(shellPort) || 47001;
const groups = buildVpnGroups(effectiveShell);
const knownPorts = new Set(groups.flatMap((group) => group.ports));
const allowedSet = new Set(numericPorts);
const nextToggles = {};
groups.forEach((group) => {
nextToggles[group.key] = group.ports.every((port) => allowedSet.has(port));
});
const customPorts = numericPorts.filter((port) => !knownPorts.has(port));
setVpnShellPort(effectiveShell);
setVpnDefaultPorts(numericDefaults);
setVpnCustomPorts(customPorts);
setVpnToggles(nextToggles);
setVpnSource(source || "default");
}, []);
const loadVpnConfig = useCallback(async () => {
if (!vpnAgentId) return;
setVpnLoading(true);
setVpnError("");
setVpnLoadedFor(vpnAgentId);
try {
const resp = await fetch(`/api/device/vpn_config/${encodeURIComponent(vpnAgentId)}`);
const data = await resp.json().catch(() => ({}));
if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`);
const allowedPorts = Array.isArray(data?.allowed_ports) ? data.allowed_ports : [];
const defaultPorts = Array.isArray(data?.default_ports) ? data.default_ports : [];
const shellPort = data?.shell_port;
applyVpnPorts(allowedPorts.length ? allowedPorts : defaultPorts, defaultPorts, shellPort, data?.source);
setVpnLoadedFor(vpnAgentId);
} catch (err) {
setVpnError(String(err.message || err));
} finally {
setVpnLoading(false);
}
}, [applyVpnPorts, vpnAgentId]);
const saveVpnConfig = useCallback(async () => {
if (!vpnAgentId) return;
const ports = [];
vpnPortGroups.forEach((group) => {
if (vpnToggles[group.key]) {
ports.push(...group.ports);
}
});
vpnCustomPorts.forEach((port) => ports.push(port));
const uniquePorts = Array.from(new Set(ports)).filter((p) => p > 0);
if (!uniquePorts.length) {
setVpnError("Enable at least one port before saving.");
return;
}
setVpnSaving(true);
setVpnError("");
try {
const resp = await fetch(`/api/device/vpn_config/${encodeURIComponent(vpnAgentId)}`, {
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ allowed_ports: uniquePorts }),
});
const data = await resp.json().catch(() => ({}));
if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`);
const allowedPorts = Array.isArray(data?.allowed_ports) ? data.allowed_ports : uniquePorts;
const defaultPorts = Array.isArray(data?.default_ports) ? data.default_ports : vpnDefaultPorts;
applyVpnPorts(allowedPorts, defaultPorts, data?.shell_port || vpnShellPort, data?.source || "custom");
} catch (err) {
setVpnError(String(err.message || err));
} finally {
setVpnSaving(false);
}
}, [applyVpnPorts, vpnAgentId, vpnCustomPorts, vpnDefaultPorts, vpnPortGroups, vpnShellPort, vpnToggles]);
const resetVpnConfig = useCallback(() => {
if (!vpnDefaultPorts.length) {
setVpnError("No default ports available to reset.");
return;
}
setVpnError("");
applyVpnPorts(vpnDefaultPorts, vpnDefaultPorts, vpnShellPort, "default");
}, [applyVpnPorts, vpnDefaultPorts, vpnShellPort]);
useEffect(() => {
const advancedIndex = TOP_TABS.findIndex((item) => item.key === "advanced");
if (advancedIndex < 0) return;
if (tab !== advancedIndex) return;
if (!vpnAgentId) return;
if (vpnLoadedFor === vpnAgentId) return;
loadVpnConfig();
}, [loadVpnConfig, tab, vpnAgentId, vpnLoadedFor]);
// No explicit live recap tab; recaps are recorded into Activity History
const clearHistory = async () => {
@@ -739,6 +888,19 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
);
const summary = details.summary || {};
const vpnAgentId = useMemo(() => {
return (
meta.agentId ||
summary.agent_id ||
agent?.agent_id ||
agent?.id ||
device?.agent_id ||
device?.agent_guid ||
device?.id ||
""
);
}, [agent?.agent_id, agent?.id, device?.agent_guid, device?.agent_id, device?.id, meta.agentId, summary.agent_id]);
const vpnPortGroups = useMemo(() => buildVpnGroups(vpnShellPort), [vpnShellPort]);
const tunnelDevice = useMemo(
() => ({
...(device || {}),
@@ -876,7 +1038,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
const formatScriptType = useCallback((raw) => {
const value = String(raw || "").toLowerCase();
if (value === "ansible") return "Ansible Playbook";
if (value === "reverse_tunnel") return "Reverse Tunnel";
if (value === "reverse_tunnel" || value === "vpn_tunnel") return "Reverse VPN Tunnel";
return "Script";
}, []);
@@ -1368,6 +1530,150 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
</Box>
);
const handleVpnToggle = useCallback((key, checked) => {
setVpnToggles((prev) => ({ ...(prev || {}), [key]: checked }));
setVpnSource("custom");
}, []);
const renderAdvancedConfigTab = () => {
const sourceLabel = vpnSource === "custom" ? "Custom overrides" : "Defaults";
const showProgress = vpnLoading || vpnSaving;
return (
<Box sx={{ display: "flex", flexDirection: "column", gap: 2, flexGrow: 1, minHeight: 0 }}>
<Box
sx={{
borderRadius: 3,
border: `1px solid ${MAGIC_UI.panelBorder}`,
background:
"linear-gradient(160deg, rgba(8,12,24,0.94), rgba(10,16,30,0.9)), radial-gradient(circle at 20% 10%, rgba(125,211,252,0.08), transparent 40%)",
boxShadow: "0 25px 80px rgba(2,6,23,0.65)",
p: { xs: 2, md: 3 },
}}
>
{showProgress ? <LinearProgress color="info" sx={{ height: 3, mb: 2 }} /> : null}
<Stack direction={{ xs: "column", md: "row" }} spacing={2} alignItems={{ xs: "flex-start", md: "center" }}>
<Box sx={{ flexGrow: 1 }}>
<Typography variant="h6" sx={{ color: MAGIC_UI.textBright, fontWeight: 700 }}>
Reverse VPN Tunnel - Allowed Ports
</Typography>
<Typography variant="body2" sx={{ color: MAGIC_UI.textMuted, mt: 0.5 }}>
Toggle which services the Engine can reach over the WireGuard tunnel for this device.
</Typography>
</Box>
<Chip
label={sourceLabel}
sx={{
borderRadius: 999,
fontWeight: 600,
letterSpacing: 0.2,
color: vpnSource === "custom" ? MAGIC_UI.accentA : MAGIC_UI.textMuted,
border: `1px solid ${MAGIC_UI.panelBorder}`,
backgroundColor: "rgba(8,12,24,0.75)",
}}
/>
</Stack>
<Divider sx={{ my: 2, borderColor: "rgba(148,163,184,0.2)" }} />
<Stack spacing={1.5}>
{vpnPortGroups.map((group) => (
<Box
key={group.key}
sx={{
display: "flex",
alignItems: { xs: "flex-start", md: "center" },
justifyContent: "space-between",
gap: 2,
p: 2,
borderRadius: 2,
border: `1px solid ${MAGIC_UI.panelBorder}`,
background: "rgba(6,10,20,0.7)",
}}
>
<Box sx={{ flexGrow: 1 }}>
<Typography sx={{ color: MAGIC_UI.textBright, fontWeight: 600 }}>
{group.label}
</Typography>
<Typography variant="body2" sx={{ color: MAGIC_UI.textMuted, mt: 0.35 }}>
{group.description}
</Typography>
<Stack direction="row" spacing={0.75} sx={{ mt: 0.8, flexWrap: "wrap" }}>
{group.ports.map((port) => (
<Chip
key={`${group.key}-${port}`}
label={`TCP ${port}`}
size="small"
sx={{
borderRadius: 999,
backgroundColor: "rgba(15,23,42,0.65)",
color: MAGIC_UI.textMuted,
border: `1px solid rgba(148,163,184,0.25)`,
}}
/>
))}
</Stack>
</Box>
<Switch
checked={Boolean(vpnToggles[group.key])}
onChange={(event) => handleVpnToggle(group.key, event.target.checked)}
color="info"
disabled={vpnLoading || vpnSaving}
/>
</Box>
))}
</Stack>
{vpnCustomPorts.length ? (
<Box sx={{ mt: 2 }}>
<Typography variant="body2" sx={{ color: MAGIC_UI.textMuted }}>
Custom ports preserved: {vpnCustomPorts.join(", ")}
</Typography>
</Box>
) : null}
{vpnError ? (
<Typography variant="body2" sx={{ color: "#ff7b89", mt: 1 }}>
{vpnError}
</Typography>
) : null}
<Stack direction="row" spacing={1.25} sx={{ mt: 2 }}>
<Button
size="small"
disabled={!vpnAgentId || vpnSaving || vpnLoading}
onClick={saveVpnConfig}
sx={{
backgroundImage: "linear-gradient(135deg,#7dd3fc,#c084fc)",
color: "#0b1220",
borderRadius: 999,
textTransform: "none",
px: 2.4,
"&:hover": {
backgroundImage: "linear-gradient(135deg,#86e1ff,#d1a6ff)",
},
}}
>
Save Config
</Button>
<Button
size="small"
disabled={!vpnDefaultPorts.length || vpnSaving || vpnLoading}
onClick={resetVpnConfig}
sx={{
borderRadius: 999,
textTransform: "none",
px: 2.4,
color: MAGIC_UI.textBright,
border: `1px solid ${MAGIC_UI.panelBorder}`,
backgroundColor: "rgba(8,12,24,0.6)",
"&:hover": {
backgroundColor: "rgba(12,18,35,0.8)",
},
}}
>
Reset Defaults
</Button>
</Stack>
</Box>
</Box>
);
};
const memoryRows = useMemo(
() =>
(details.memory || []).map((m, idx) => ({
@@ -1618,6 +1924,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
renderNetworkTab,
renderSoftware,
renderHistory,
renderAdvancedConfigTab,
renderRemoteShellTab,
];
const tabContent = (topTabRenderers[tab] || renderDeviceSummaryTab)();
@@ -1742,7 +2049,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
>
{TOP_TABS.map((tabDef) => (
<Tab
key={tabDef.label}
key={tabDef.key || tabDef.label}
label={tabDef.label}
icon={<tabDef.icon sx={{ fontSize: 18 }} />}
iconPosition="start"

View File

@@ -5,17 +5,17 @@ import {
Button,
Stack,
TextField,
MenuItem,
IconButton,
Tooltip,
LinearProgress,
Chip,
} from "@mui/material";
import {
PlayArrowRounded as PlayIcon,
StopRounded as StopIcon,
ContentCopy as CopyIcon,
RefreshRounded as RefreshIcon,
LanRounded as PortIcon,
LanRounded as IpIcon,
LinkRounded as LinkIcon,
} from "@mui/icons-material";
import { io } from "socket.io-client";
@@ -24,18 +24,7 @@ import "prismjs/components/prism-powershell";
import "prismjs/themes/prism-okaidia.css";
import Editor from "react-simple-code-editor";
// Console diagnostics for troubleshooting the connect/disconnect flow.
const debugLog = (...args) => {
try {
// eslint-disable-next-line no-console
console.error("[ReverseTunnel][PS]", ...args);
} catch {
// ignore
}
};
const MAGIC_UI = {
panelBg: "rgba(7,11,24,0.92)",
panelBorder: "rgba(148, 163, 184, 0.35)",
textMuted: "#94a3b8",
textBright: "#e2e8f0",
@@ -56,13 +45,25 @@ const gradientButtonSx = {
},
};
const FRAME_HEADER_BYTES = 12; // version, msg_type, flags, reserved, channel_id(u32), length(u32)
const MSG_CLOSE = 0x08;
const CLOSE_AGENT_SHUTDOWN = 6;
const fontFamilyMono =
'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace';
const emitAsync = (socket, event, payload, timeoutMs = 4000) =>
new Promise((resolve) => {
let settled = false;
const timer = setTimeout(() => {
if (settled) return;
settled = true;
resolve({ error: "timeout" });
}, timeoutMs);
socket.emit(event, payload, (resp) => {
if (settled) return;
settled = true;
clearTimeout(timer);
resolve(resp || {});
});
});
function normalizeText(value) {
if (value == null) return "";
try {
@@ -72,28 +73,6 @@ function normalizeText(value) {
}
}
function base64FromBytes(bytes) {
let binary = "";
bytes.forEach((b) => {
binary += String.fromCharCode(b);
});
return btoa(binary);
}
function buildCloseFrame(channelId = 1, code = CLOSE_AGENT_SHUTDOWN, reason = "operator_close") {
const payload = new TextEncoder().encode(JSON.stringify({ code, reason }));
const buffer = new ArrayBuffer(FRAME_HEADER_BYTES + payload.length);
const view = new DataView(buffer);
view.setUint8(0, 1); // version
view.setUint8(1, MSG_CLOSE);
view.setUint8(2, 0); // flags
view.setUint8(3, 0); // reserved
view.setUint32(4, channelId >>> 0, true);
view.setUint32(8, payload.length >>> 0, true);
new Uint8Array(buffer, FRAME_HEADER_BYTES).set(payload);
return base64FromBytes(new Uint8Array(buffer));
}
function highlightPs(code) {
try {
return Prism.highlight(code || "", Prism.languages.powershell, "powershell");
@@ -102,52 +81,18 @@ function highlightPs(code) {
}
}
const INITIAL_MILESTONES = {
tunnelReady: false,
operatorAttached: false,
shellEstablished: false,
};
const INITIAL_STATUS_CHAIN = ["Offline"];
export default function ReverseTunnelPowershell({ device }) {
const [connectionType, setConnectionType] = useState("ps");
const [tunnel, setTunnel] = useState(null);
const [sessionState, setSessionState] = useState("idle");
const [, setStatusMessage] = useState("");
const [, setStatusSeverity] = useState("info");
const [shellState, setShellState] = useState("idle");
const [tunnel, setTunnel] = useState(null);
const [output, setOutput] = useState("");
const [input, setInput] = useState("");
const [statusMessage, setStatusMessage] = useState("");
const [copyFlash, setCopyFlash] = useState(false);
const [, setPolling] = useState(false);
const [psStatus, setPsStatus] = useState({});
const [milestones, setMilestones] = useState(() => ({ ...INITIAL_MILESTONES }));
const [tunnelSteps, setTunnelSteps] = useState(() => [...INITIAL_STATUS_CHAIN]);
const [websocketSteps, setWebsocketSteps] = useState(() => [...INITIAL_STATUS_CHAIN]);
const [shellSteps, setShellSteps] = useState(() => [...INITIAL_STATUS_CHAIN]);
const [loading, setLoading] = useState(false);
const socketRef = useRef(null);
const pollTimerRef = useRef(null);
const resizeTimerRef = useRef(null);
const localSocketRef = useRef(false);
const terminalRef = useRef(null);
const joinRetryRef = useRef(null);
const joinAttemptsRef = useRef(0);
const tunnelRef = useRef(null);
const shellFlagsRef = useRef({ openSent: false, ack: false });
const DOMAIN_REMOTE_SHELL = "remote-interactive-shell";
useEffect(() => {
debugLog("component mount", { hostname: device?.hostname, agentId });
return () => debugLog("component unmount");
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const hostname = useMemo(() => {
return (
normalizeText(device?.hostname) ||
normalizeText(device?.summary?.hostname) ||
normalizeText(device?.agent_hostname) ||
""
);
}, [device]);
const agentId = useMemo(() => {
return (
@@ -162,78 +107,18 @@ export default function ReverseTunnelPowershell({ device }) {
);
}, [device]);
const appendStatus = useCallback((setter, label) => {
if (!label) return;
setter((prev) => {
const next = [...prev, label];
const cap = 6;
return next.length > cap ? next.slice(next.length - cap) : next;
});
}, []);
const resetState = useCallback(() => {
debugLog("resetState invoked");
setTunnel(null);
setSessionState("idle");
setStatusMessage("");
setStatusSeverity("info");
setOutput("");
setInput("");
setPsStatus({});
setMilestones({ ...INITIAL_MILESTONES });
setTunnelSteps([...INITIAL_STATUS_CHAIN]);
setWebsocketSteps([...INITIAL_STATUS_CHAIN]);
setShellSteps([...INITIAL_STATUS_CHAIN]);
shellFlagsRef.current = { openSent: false, ack: false };
}, []);
useEffect(() => {
tunnelRef.current = tunnel?.tunnel_id || null;
}, [tunnel?.tunnel_id]);
const disconnectSocket = useCallback(() => {
const socket = socketRef.current;
if (socket) {
socket.off();
socket.disconnect();
const ensureSocket = useCallback(() => {
if (socketRef.current) return socketRef.current;
const existing = typeof window !== "undefined" ? window.BorealisSocket : null;
if (existing) {
socketRef.current = existing;
localSocketRef.current = false;
return existing;
}
socketRef.current = null;
}, []);
const stopPolling = useCallback(() => {
if (pollTimerRef.current) {
clearTimeout(pollTimerRef.current);
pollTimerRef.current = null;
}
setPolling(false);
}, []);
const stopTunnel = useCallback(async (reason = "operator_disconnect", tunnelIdOverride = null) => {
const tunnelId = tunnelIdOverride || tunnelRef.current;
if (!tunnelId) return;
try {
await fetch(`/api/tunnel/${tunnelId}`, {
method: "DELETE",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ reason }),
});
} catch (err) {
// best-effort; socket close frame acts as fallback
}
}, []);
useEffect(() => {
return () => {
debugLog("cleanup on unmount", { tunnelId: tunnelRef.current });
stopPolling();
disconnectSocket();
if (joinRetryRef.current) {
clearTimeout(joinRetryRef.current);
joinRetryRef.current = null;
}
stopTunnel("component_unmount", tunnelRef.current);
};
// eslint-disable-next-line react-hooks/exhaustive-deps
const socket = io(window.location.origin, { transports: ["websocket"] });
socketRef.current = socket;
localSocketRef.current = true;
return socket;
}, []);
const appendOutput = useCallback((text) => {
@@ -257,6 +142,137 @@ export default function ReverseTunnelPowershell({ device }) {
scrollToBottom();
}, [output, scrollToBottom]);
const stopTunnel = useCallback(
async (reason = "operator_disconnect") => {
if (!agentId) return;
try {
await fetch("/api/tunnel/disconnect", {
method: "DELETE",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ agent_id: agentId, tunnel_id: tunnel?.tunnel_id, reason }),
});
} catch {
// best-effort
}
},
[agentId, tunnel?.tunnel_id]
);
const closeShell = useCallback(async () => {
const socket = ensureSocket();
await emitAsync(socket, "vpn_shell_close", {});
}, [ensureSocket]);
const handleDisconnect = useCallback(async () => {
setLoading(true);
setStatusMessage("");
try {
await closeShell();
await stopTunnel("operator_disconnect");
} finally {
setTunnel(null);
setShellState("closed");
setSessionState("idle");
setLoading(false);
}
}, [closeShell, stopTunnel]);
useEffect(() => {
const socket = ensureSocket();
const handleDisconnectEvent = () => {
if (sessionState === "connected") {
setShellState("closed");
setSessionState("idle");
setStatusMessage("Socket disconnected.");
}
};
const handleOutput = (payload) => {
appendOutput(payload?.data || "");
};
const handleClosed = () => {
setShellState("closed");
setSessionState("idle");
setStatusMessage("Shell closed.");
};
socket.on("disconnect", handleDisconnectEvent);
socket.on("vpn_shell_output", handleOutput);
socket.on("vpn_shell_closed", handleClosed);
return () => {
socket.off("disconnect", handleDisconnectEvent);
socket.off("vpn_shell_output", handleOutput);
socket.off("vpn_shell_closed", handleClosed);
if (localSocketRef.current) {
socket.disconnect();
}
};
}, [appendOutput, ensureSocket, sessionState]);
useEffect(() => {
return () => {
closeShell();
stopTunnel("component_unmount");
};
}, [closeShell, stopTunnel]);
const requestTunnel = useCallback(async () => {
if (!agentId) {
setStatusMessage("Agent ID is required to connect.");
return;
}
setLoading(true);
setStatusMessage("");
setSessionState("connecting");
setShellState("opening");
try {
const resp = await fetch("/api/tunnel/connect", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ agent_id: agentId }),
});
const data = await resp.json().catch(() => ({}));
if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`);
const statusResp = await fetch(
`/api/tunnel/connect/status?agent_id=${encodeURIComponent(agentId)}&bump=1`
);
const statusData = await statusResp.json().catch(() => ({}));
if (!statusResp.ok || statusData?.status !== "up") {
throw new Error(statusData?.error || "Tunnel not ready");
}
setTunnel({ ...data, ...statusData });
const socket = ensureSocket();
const openResp = await emitAsync(socket, "vpn_shell_open", { agent_id: agentId }, 6000);
if (openResp?.error) {
throw new Error(openResp.error);
}
setSessionState("connected");
setShellState("connected");
} catch (err) {
setSessionState("error");
setShellState("closed");
setStatusMessage(String(err.message || err));
} finally {
setLoading(false);
}
}, [agentId, ensureSocket]);
const handleSend = useCallback(
async (text) => {
const socket = ensureSocket();
if (!socket || sessionState !== "connected") return;
const payload = `${text}${text.endsWith("\n") ? "" : "\r\n"}`;
appendOutput(`\nPS> ${text}\n`);
setInput("");
const resp = await emitAsync(socket, "vpn_shell_send", { data: payload });
if (resp?.error) {
setStatusMessage("Send failed.");
}
},
[appendOutput, ensureSocket, sessionState]
);
const handleCopy = async () => {
try {
await navigator.clipboard.writeText(output || "");
@@ -267,329 +283,7 @@ export default function ReverseTunnelPowershell({ device }) {
}
};
const measureTerminal = useCallback(() => {
const el = terminalRef.current;
if (!el) return { cols: 120, rows: 32 };
const width = el.clientWidth || 960;
const height = el.clientHeight || 460;
const charWidth = 8.2;
const charHeight = 18;
const cols = Math.max(20, Math.min(Math.floor(width / charWidth), 300));
const rows = Math.max(10, Math.min(Math.floor(height / charHeight), 200));
return { cols, rows };
}, []);
const emitAsync = useCallback((socket, event, payload, timeoutMs = 4000) => {
return new Promise((resolve) => {
let settled = false;
const timer = setTimeout(() => {
if (settled) return;
settled = true;
resolve({ error: "timeout" });
}, timeoutMs);
socket.emit(event, payload, (resp) => {
if (settled) return;
settled = true;
clearTimeout(timer);
resolve(resp || {});
});
});
}, []);
const pollLoop = useCallback(
(socket, tunnelId) => {
if (!socket || !tunnelId) return;
debugLog("pollLoop tick", { tunnelId });
setPolling(true);
pollTimerRef.current = setTimeout(async () => {
const resp = await emitAsync(socket, "ps_poll", {});
if (resp?.error) {
debugLog("pollLoop error", resp);
stopPolling();
disconnectSocket();
setPsStatus({});
setTunnel(null);
setSessionState("error");
return;
}
if (Array.isArray(resp?.output) && resp.output.length) {
appendOutput(resp.output.join(""));
}
if (resp?.status) {
setPsStatus(resp.status);
debugLog("pollLoop status", resp.status);
if (resp.status.closed) {
setSessionState("closed");
setTunnel(null);
setMilestones({ ...INITIAL_MILESTONES });
appendStatus(setShellSteps, "Shell closed");
appendStatus(setTunnelSteps, "Stopped");
appendStatus(setWebsocketSteps, "Relay stopped");
shellFlagsRef.current = { openSent: false, ack: false };
stopPolling();
return;
}
if (resp.status.open_sent && !shellFlagsRef.current.openSent) {
appendStatus(setShellSteps, "Opening remote shell");
shellFlagsRef.current.openSent = true;
}
if (resp.status.ack && !shellFlagsRef.current.ack) {
setSessionState("connected");
setMilestones((prev) => ({ ...prev, shellEstablished: true }));
appendStatus(setShellSteps, "Remote shell established");
shellFlagsRef.current.ack = true;
}
}
pollLoop(socket, tunnelId);
}, 520);
},
[appendOutput, emitAsync, stopPolling, disconnectSocket, appendStatus]
);
const handleDisconnect = useCallback(
async (reason = "operator_disconnect") => {
debugLog("handleDisconnect begin", { reason, tunnelId: tunnel?.tunnel_id, psStatus, sessionState });
setPsStatus({});
const socket = socketRef.current;
const tunnelId = tunnel?.tunnel_id;
if (joinRetryRef.current) {
clearTimeout(joinRetryRef.current);
joinRetryRef.current = null;
}
joinAttemptsRef.current = 0;
if (socket && tunnelId) {
const frame = buildCloseFrame(1, CLOSE_AGENT_SHUTDOWN, "operator_close");
debugLog("emit CLOSE", { tunnelId });
socket.emit("send", { frame });
}
await stopTunnel(reason);
debugLog("stopTunnel issued", { tunnelId });
stopPolling();
disconnectSocket();
setTunnel(null);
setSessionState("closed");
setMilestones({ ...INITIAL_MILESTONES });
appendStatus(setTunnelSteps, "Stopped");
appendStatus(setWebsocketSteps, "Relay closed");
appendStatus(setShellSteps, "Shell closed");
shellFlagsRef.current = { openSent: false, ack: false };
debugLog("handleDisconnect finished", { tunnelId });
},
[appendStatus, disconnectSocket, stopPolling, stopTunnel, tunnel?.tunnel_id]
);
const handleResize = useCallback(() => {
if (!socketRef.current || sessionState === "idle") return;
const dims = measureTerminal();
socketRef.current.emit("ps_resize", dims);
}, [measureTerminal, sessionState]);
useEffect(() => {
const observer =
typeof ResizeObserver !== "undefined"
? new ResizeObserver(() => {
if (resizeTimerRef.current) clearTimeout(resizeTimerRef.current);
resizeTimerRef.current = setTimeout(() => handleResize(), 200);
})
: null;
const el = terminalRef.current;
if (observer && el) observer.observe(el);
const onWinResize = () => handleResize();
window.addEventListener("resize", onWinResize);
return () => {
window.removeEventListener("resize", onWinResize);
if (observer && el) observer.unobserve(el);
};
}, [handleResize]);
const connectSocket = useCallback(
(lease, { isRetry = false } = {}) => {
if (!lease?.tunnel_id) return;
if (joinRetryRef.current) {
clearTimeout(joinRetryRef.current);
joinRetryRef.current = null;
}
if (!isRetry) {
joinAttemptsRef.current = 0;
}
disconnectSocket();
stopPolling();
setSessionState("waiting");
const socket = io(`${window.location.origin}/tunnel`, { transports: ["websocket", "polling"] });
socketRef.current = socket;
socket.on("connect_error", () => {
debugLog("socket connect_error");
setStatusSeverity("warning");
setStatusMessage("Tunnel namespace unavailable.");
setTunnel(null);
setSessionState("error");
appendStatus(setWebsocketSteps, "Relay connect error");
});
socket.on("disconnect", () => {
debugLog("socket disconnect", { tunnelId: tunnel?.tunnel_id });
stopPolling();
if (sessionState !== "closed") {
setSessionState("disconnected");
setStatusSeverity("warning");
setStatusMessage("Socket disconnected.");
setTunnel(null);
appendStatus(setWebsocketSteps, "Relay disconnected");
}
});
socket.on("connect", async () => {
debugLog("socket connect", { tunnelId: lease.tunnel_id });
setMilestones((prev) => ({ ...prev, operatorAttached: true }));
setStatusSeverity("info");
setStatusMessage("Joining tunnel...");
appendStatus(setWebsocketSteps, "Relay connected");
const joinResp = await emitAsync(socket, "join", { tunnel_id: lease.tunnel_id }, 5000);
if (joinResp?.error) {
const attempt = (joinAttemptsRef.current += 1);
const isTimeout = joinResp.error === "timeout";
if (joinResp.error === "unknown_tunnel") {
setSessionState("waiting_agent");
setStatusSeverity("info");
setStatusMessage("Waiting for agent to establish tunnel...");
appendStatus(setWebsocketSteps, "Waiting for agent");
} else if (isTimeout || joinResp.error === "attach_failed") {
setSessionState("waiting_agent");
setStatusSeverity("warning");
setStatusMessage("Tunnel join timed out. Retrying...");
appendStatus(setWebsocketSteps, `Join retry ${attempt}`);
} else {
debugLog("join error", joinResp);
setSessionState("error");
setStatusSeverity("error");
setStatusMessage(joinResp.error);
appendStatus(setWebsocketSteps, `Join failed: ${joinResp.error}`);
return;
}
if (attempt <= 5) {
joinRetryRef.current = setTimeout(() => connectSocket(lease, { isRetry: true }), 800);
} else {
setSessionState("error");
setTunnel(null);
setStatusSeverity("warning");
setStatusMessage("Operator could not attach to tunnel. Try Connect again.");
appendStatus(setWebsocketSteps, "Join failed after retries");
}
return;
}
appendStatus(setWebsocketSteps, "Relay joined");
const dims = measureTerminal();
debugLog("ps_open emit", { tunnelId: lease.tunnel_id, dims });
const openResp = await emitAsync(socket, "ps_open", dims, 5000);
if (openResp?.error && openResp.error === "ps_unsupported") {
// Suppress warming message; channel will settle once agent attaches.
}
if (!shellFlagsRef.current.openSent) {
appendStatus(setShellSteps, "Opening remote shell");
shellFlagsRef.current.openSent = true;
}
appendOutput("");
setSessionState("waiting_agent");
pollLoop(socket, lease.tunnel_id);
handleResize();
});
},
[appendOutput, appendStatus, disconnectSocket, emitAsync, handleResize, measureTerminal, pollLoop, sessionState, stopPolling]
);
const requestTunnel = useCallback(async () => {
if (tunnel && sessionState !== "closed" && sessionState !== "idle") {
setStatusSeverity("info");
setStatusMessage("");
connectSocket(tunnel);
return;
}
debugLog("requestTunnel", { agentId, connectionType });
if (!agentId) {
setStatusSeverity("warning");
setStatusMessage("Agent ID is required to request a tunnel.");
return;
}
if (connectionType !== "ps") {
setStatusSeverity("warning");
setStatusMessage("Only PowerShell is supported right now.");
return;
}
resetState();
setSessionState("requesting");
setStatusSeverity("info");
setStatusMessage("");
appendStatus(setTunnelSteps, "Requesting lease");
try {
const resp = await fetch("/api/tunnel/request", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ agent_id: agentId, protocol: "ps", domain: DOMAIN_REMOTE_SHELL }),
});
const data = await resp.json().catch(() => ({}));
if (!resp.ok) {
const err = data?.error || `HTTP ${resp.status}`;
setSessionState("error");
setStatusSeverity(err === "domain_limit" ? "warning" : "error");
setStatusMessage("");
return;
}
setMilestones((prev) => ({ ...prev, tunnelReady: true }));
setTunnel(data);
setStatusMessage("");
setSessionState("lease_issued");
appendStatus(
setTunnelSteps,
data?.tunnel_id
? `Lease issued (${data.tunnel_id.slice(0, 8)} @ Port ${data.port || "-"})`
: "Lease issued"
);
connectSocket(data);
} catch (e) {
setSessionState("error");
setStatusSeverity("error");
setStatusMessage("");
appendStatus(setTunnelSteps, "Lease request failed");
}
}, [DOMAIN_REMOTE_SHELL, agentId, appendStatus, connectSocket, connectionType, resetState]);
const handleSend = useCallback(
async (text) => {
const socket = socketRef.current;
if (!socket) return;
const payload = `${text}${text.endsWith("\n") ? "" : "\r\n"}`;
appendOutput(`\nPS> ${text}\n`);
setInput("");
const resp = await emitAsync(socket, "ps_send", { data: payload });
if (resp?.error) {
setStatusSeverity("warning");
setStatusMessage("");
}
},
[appendOutput, emitAsync]
);
const isConnected = sessionState === "connected" || (psStatus?.ack && !psStatus?.closed);
const isClosed = sessionState === "closed" || psStatus?.closed;
const isBusy =
sessionState === "requesting" ||
sessionState === "waiting" ||
sessionState === "waiting_agent" ||
sessionState === "lease_issued";
const canStart = Boolean(agentId) && !isBusy;
useEffect(() => {
const handleUnload = () => {
stopTunnel("window_unload");
};
if (tunnel?.tunnel_id) {
window.addEventListener("beforeunload", handleUnload);
return () => window.removeEventListener("beforeunload", handleUnload);
}
return undefined;
}, [stopTunnel, tunnel?.tunnel_id]);
const isConnected = sessionState === "connected";
const sessionChips = [
tunnel?.tunnel_id
? {
@@ -598,58 +292,43 @@ export default function ReverseTunnelPowershell({ device }) {
icon: <LinkIcon sx={{ fontSize: 18 }} />,
}
: null,
tunnel?.port
tunnel?.virtual_ip
? {
label: `Port ${tunnel.port}`,
label: `IP ${String(tunnel.virtual_ip).split("/")[0]}`,
color: MAGIC_UI.accentA,
icon: <PortIcon sx={{ fontSize: 18 }} />,
icon: <IpIcon sx={{ fontSize: 18 }} />,
}
: null,
].filter(Boolean);
return (
<Box sx={{ display: "flex", flexDirection: "column", gap: 1.5, flexGrow: 1, minHeight: 0 }}>
<Box>
<Stack
direction={{ xs: "column", sm: "row" }}
spacing={1.5}
alignItems={{ xs: "flex-start", sm: "center" }}
justifyContent={{ xs: "flex-start", sm: "flex-end" }}
<Stack direction={{ xs: "column", sm: "row" }} spacing={1.5} alignItems={{ xs: "flex-start", sm: "center" }}>
<Button
size="small"
startIcon={isConnected ? <StopIcon /> : <PlayIcon />}
sx={gradientButtonSx}
disabled={loading || (!isConnected && !agentId)}
onClick={isConnected ? handleDisconnect : requestTunnel}
>
<TextField
select
label="Connection Protocol"
size="small"
value={connectionType}
onChange={(e) => setConnectionType(e.target.value)}
sx={{
minWidth: 180,
"& .MuiInputBase-root": {
backgroundColor: "rgba(12,18,35,0.85)",
color: MAGIC_UI.textBright,
borderRadius: 1.5,
},
"& fieldset": { borderColor: MAGIC_UI.panelBorder },
"&:hover fieldset": { borderColor: MAGIC_UI.accentA },
}}
>
<MenuItem value="ps">PowerShell</MenuItem>
</TextField>
<Tooltip title={isConnected ? "Disconnect session" : "Connect to agent"}>
<span>
<Button
size="small"
startIcon={isConnected ? <StopIcon /> : <PlayIcon />}
sx={gradientButtonSx}
disabled={!isConnected && !canStart}
onClick={isConnected ? handleDisconnect : requestTunnel}
>
{isConnected ? "Disconnect" : "Connect"}
</Button>
</span>
</Tooltip>
{isConnected ? "Disconnect" : "Connect"}
</Button>
<Stack direction="row" spacing={1}>
{sessionChips.map((chip) => (
<Chip
key={chip.label}
icon={chip.icon}
label={chip.label}
sx={{
borderRadius: 999,
color: chip.color,
border: `1px solid ${MAGIC_UI.panelBorder}`,
backgroundColor: "rgba(8,12,24,0.65)",
}}
/>
))}
</Stack>
</Box>
</Stack>
<Box
sx={{
@@ -665,7 +344,7 @@ export default function ReverseTunnelPowershell({ device }) {
overflow: "hidden",
}}
>
{isBusy ? <LinearProgress color="info" sx={{ height: 3 }} /> : null}
{loading ? <LinearProgress color="info" sx={{ height: 3 }} /> : null}
<Box
ref={terminalRef}
sx={{
@@ -728,11 +407,7 @@ export default function ReverseTunnelPowershell({ device }) {
size="small"
value={input}
disabled={!isConnected}
placeholder={
isConnected
? "Enter PowerShell command and press Enter"
: "Connect to start sending commands"
}
placeholder={isConnected ? "Enter PowerShell command and press Enter" : "Connect to start sending commands"}
onChange={(e) => setInput(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter" && !e.shiftKey) {
@@ -753,43 +428,19 @@ export default function ReverseTunnelPowershell({ device }) {
/>
</Box>
</Box>
<Stack spacing={0.3} sx={{ mt: 1.25 }}>
<Typography
variant="body2"
sx={{
color: milestones.tunnelReady ? MAGIC_UI.accentC : MAGIC_UI.textMuted,
fontWeight: 700,
}}
>
Tunnel:{" "}
<Typography component="span" variant="body2" sx={{ color: MAGIC_UI.textMuted, fontWeight: 500 }}>
{tunnelSteps.join(" > ")}
</Typography>
<Stack spacing={0.3} sx={{ mt: 1 }}>
<Typography variant="body2" sx={{ color: MAGIC_UI.textMuted }}>
Tunnel: {sessionState === "connected" ? "Active" : sessionState}
</Typography>
<Typography
variant="body2"
sx={{
color: milestones.operatorAttached ? MAGIC_UI.accentC : MAGIC_UI.textMuted,
fontWeight: 700,
}}
>
Websocket:{" "}
<Typography component="span" variant="body2" sx={{ color: MAGIC_UI.textMuted, fontWeight: 500 }}>
{websocketSteps.join(" > ")}
</Typography>
<Typography variant="body2" sx={{ color: MAGIC_UI.textMuted }}>
Shell: {shellState === "connected" ? "Ready" : shellState}
</Typography>
<Typography
variant="body2"
sx={{
color: milestones.shellEstablished ? MAGIC_UI.accentC : MAGIC_UI.textMuted,
fontWeight: 700,
}}
>
Remote Shell:{" "}
<Typography component="span" variant="body2" sx={{ color: MAGIC_UI.textMuted, fontWeight: 500 }}>
{shellSteps.join(" > ")}
{statusMessage ? (
<Typography variant="body2" sx={{ color: "#ff7b89" }}>
{statusMessage}
</Typography>
</Typography>
) : null}
</Stack>
</Box>
);