mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2026-02-04 11:30:30 -07:00
Additional Changes to VPN Tunneling
This commit is contained in:
@@ -47,11 +47,11 @@ def _b64decode(value: str) -> bytes:
|
||||
def _resolve_shell_port() -> int:
|
||||
raw = os.environ.get("BOREALIS_WIREGUARD_SHELL_PORT")
|
||||
try:
|
||||
value = int(raw) if raw is not None else 47001
|
||||
value = int(raw) if raw is not None else 47002
|
||||
except Exception:
|
||||
value = 47001
|
||||
value = 47002
|
||||
if value < 1 or value > 65535:
|
||||
return 47001
|
||||
return 47002
|
||||
return value
|
||||
|
||||
|
||||
|
||||
@@ -22,13 +22,22 @@ import os
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import x25519
|
||||
from signature_utils import verify_and_store_script_signature
|
||||
|
||||
try:
|
||||
from signature_utils import verify_and_store_script_signature
|
||||
except Exception: # pragma: no cover - fallback for runtime path issues
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
||||
base_dir = _Path(__file__).resolve().parents[1]
|
||||
if str(base_dir) not in sys.path:
|
||||
sys.path.insert(0, str(base_dir))
|
||||
from signature_utils import verify_and_store_script_signature
|
||||
|
||||
ROLE_NAME = "WireGuardTunnel"
|
||||
ROLE_CONTEXTS = ["system"]
|
||||
@@ -88,18 +97,31 @@ def _generate_client_keys(root: Path) -> Dict[str, str]:
|
||||
return {"private": priv, "public": pub}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionConfig:
|
||||
token: Dict[str, Any]
|
||||
virtual_ip: str
|
||||
allowed_ips: str
|
||||
endpoint: str
|
||||
server_public_key: str
|
||||
allowed_ports: str
|
||||
idle_seconds: int = 900
|
||||
preshared_key: Optional[str] = None
|
||||
client_private_key: Optional[str] = None
|
||||
client_public_key: Optional[str] = None
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
token: Dict[str, Any],
|
||||
virtual_ip: str,
|
||||
allowed_ips: str,
|
||||
endpoint: str,
|
||||
server_public_key: str,
|
||||
allowed_ports: str,
|
||||
idle_seconds: int = 900,
|
||||
preshared_key: Optional[str] = None,
|
||||
client_private_key: Optional[str] = None,
|
||||
client_public_key: Optional[str] = None,
|
||||
) -> None:
|
||||
self.token = token
|
||||
self.virtual_ip = virtual_ip
|
||||
self.allowed_ips = allowed_ips
|
||||
self.endpoint = endpoint
|
||||
self.server_public_key = server_public_key
|
||||
self.allowed_ports = allowed_ports
|
||||
self.idle_seconds = idle_seconds
|
||||
self.preshared_key = preshared_key
|
||||
self.client_private_key = client_private_key
|
||||
self.client_public_key = client_public_key
|
||||
|
||||
|
||||
class WireGuardClient:
|
||||
@@ -150,6 +172,19 @@ class WireGuardClient:
|
||||
if port < 1 or port > 65535:
|
||||
raise ValueError("Invalid token port")
|
||||
|
||||
if not signature:
|
||||
if sig_alg or signing_key:
|
||||
raise ValueError("Token signature missing")
|
||||
stored_key = None
|
||||
if signing_client is not None and hasattr(signing_client, "load_server_signing_key"):
|
||||
try:
|
||||
stored_key = signing_client.load_server_signing_key()
|
||||
except Exception:
|
||||
stored_key = None
|
||||
if isinstance(stored_key, str) and stored_key.strip():
|
||||
raise ValueError("Token signature missing")
|
||||
return
|
||||
|
||||
if signature:
|
||||
if sig_alg and str(sig_alg).lower() not in ("ed25519", "eddsa"):
|
||||
raise ValueError("Unsupported token signature algorithm")
|
||||
@@ -292,6 +327,11 @@ class Role:
|
||||
self._log("WireGuard start payload missing/invalid.", error=True)
|
||||
return None
|
||||
|
||||
payload_agent_id = payload.get("agent_id") or payload.get("agent_guid")
|
||||
if payload_agent_id:
|
||||
if str(payload_agent_id).strip() != str(self.ctx.agent_id).strip():
|
||||
return None
|
||||
|
||||
token = payload.get("token") or payload.get("orchestration_token")
|
||||
if not isinstance(token, dict):
|
||||
self._log("WireGuard start missing token payload.", error=True)
|
||||
@@ -351,6 +391,9 @@ class Role:
|
||||
async def _vpn_tunnel_stop(payload):
|
||||
reason = "server_stop"
|
||||
if isinstance(payload, dict):
|
||||
target_agent = payload.get("agent_id")
|
||||
if target_agent and str(target_agent).strip() != str(self.ctx.agent_id).strip():
|
||||
return
|
||||
reason = payload.get("reason") or reason
|
||||
self._log(f"WireGuard stop requested (reason={reason}).")
|
||||
self.client.stop_session(reason=str(reason))
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
@@ -37,8 +39,27 @@ class RoleManager:
|
||||
self.config = config
|
||||
self.loop = loop
|
||||
self.hooks = hooks or {}
|
||||
self._log_hook = self.hooks.get('log_agent')
|
||||
self.roles: Dict[str, object] = {}
|
||||
|
||||
# Ensure role helpers alongside Roles/ are importable (e.g., signature_utils.py).
|
||||
try:
|
||||
base_path = Path(self.base_dir).resolve()
|
||||
parent_path = base_path.parent
|
||||
for candidate in (base_path, parent_path):
|
||||
if candidate and str(candidate) not in sys.path:
|
||||
sys.path.insert(0, str(candidate))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _log(self, message: str, *, error: bool = False) -> None:
|
||||
if callable(self._log_hook):
|
||||
try:
|
||||
target = "agent.error.log" if error else "agent.log"
|
||||
self._log_hook(message, fname=target)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _iter_role_files(self) -> List[str]:
|
||||
roles_dir = os.path.join(self.base_dir, 'Roles')
|
||||
if not os.path.isdir(roles_dir):
|
||||
@@ -56,7 +77,8 @@ class RoleManager:
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
spec.loader.exec_module(mod)
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
self._log(f"Role load failed during import path={path} error={exc}", error=True)
|
||||
continue
|
||||
|
||||
role_name = getattr(mod, 'ROLE_NAME', None)
|
||||
@@ -75,10 +97,12 @@ class RoleManager:
|
||||
if hasattr(role_obj, 'register_events'):
|
||||
try:
|
||||
role_obj.register_events()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
self._log(f"Role register_events failed name={role_name} error={exc}", error=True)
|
||||
self.roles[role_name] = role_obj
|
||||
except Exception:
|
||||
self._log(f"Role loaded name={role_name} context={self.context}")
|
||||
except Exception as exc:
|
||||
self._log(f"Role init failed name={role_name} path={path} error={exc}", error=True)
|
||||
continue
|
||||
|
||||
def on_config(self, roles_cfg: List[dict]):
|
||||
@@ -96,4 +120,3 @@ class RoleManager:
|
||||
role.stop_all()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ VPN_TUNNEL_LOG_FILE_PATH = LOG_ROOT / "reverse_tunnel.log"
|
||||
DEFAULT_WIREGUARD_PORT = 30000
|
||||
DEFAULT_WIREGUARD_ENGINE_VIRTUAL_IP = "10.255.0.1/32"
|
||||
DEFAULT_WIREGUARD_PEER_NETWORK = "10.255.0.0/24"
|
||||
DEFAULT_WIREGUARD_SHELL_PORT = 47001
|
||||
DEFAULT_WIREGUARD_SHELL_PORT = 47002
|
||||
DEFAULT_WIREGUARD_ACL_WINDOWS = (3389, 5985, 5986, 5900, 3478, DEFAULT_WIREGUARD_SHELL_PORT)
|
||||
VPN_SERVER_CERT_ROOT = PROJECT_ROOT / "Engine" / "Certificates" / "VPN_Server"
|
||||
|
||||
|
||||
@@ -117,7 +117,7 @@ def _rotate_daily(path: Path) -> None:
|
||||
pass
|
||||
|
||||
|
||||
_QUIET_SERVICE_LOGS = {"scheduled_jobs"}
|
||||
_QUIET_SERVICE_LOGS = {"scheduled_jobs", "device_enrollment"}
|
||||
|
||||
|
||||
def _make_service_logger(base: Path, logger: logging.Logger) -> Callable[[str, str, Optional[str]], None]:
|
||||
|
||||
@@ -12,12 +12,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from urllib.parse import urlsplit
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from flask import Blueprint, jsonify, request, session
|
||||
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
|
||||
|
||||
from ...VPN import VpnTunnelService
|
||||
from ...VPN import WireGuardServerConfig, WireGuardServerManager, VpnTunnelService
|
||||
|
||||
if False: # pragma: no cover - import cycle hint for type checkers
|
||||
from .. import EngineServiceAdapters
|
||||
@@ -63,7 +65,22 @@ def _get_tunnel_service(adapters: "EngineServiceAdapters") -> VpnTunnelService:
|
||||
if service is None:
|
||||
manager = getattr(adapters.context, "wireguard_server_manager", None)
|
||||
if manager is None:
|
||||
raise RuntimeError("wireguard_manager_unavailable")
|
||||
try:
|
||||
manager = WireGuardServerManager(
|
||||
WireGuardServerConfig(
|
||||
port=adapters.context.wireguard_port,
|
||||
engine_virtual_ip=adapters.context.wireguard_engine_virtual_ip,
|
||||
peer_network=adapters.context.wireguard_peer_network,
|
||||
private_key_path=Path(adapters.context.wireguard_server_private_key_path),
|
||||
public_key_path=Path(adapters.context.wireguard_server_public_key_path),
|
||||
acl_allowlist_windows=tuple(adapters.context.wireguard_acl_allowlist_windows),
|
||||
log_path=Path(adapters.context.vpn_tunnel_log_path),
|
||||
)
|
||||
)
|
||||
adapters.context.wireguard_server_manager = manager
|
||||
except Exception as exc:
|
||||
adapters.context.logger.error("Failed to initialize WireGuard server manager on demand.", exc_info=True)
|
||||
raise RuntimeError("wireguard_manager_unavailable") from exc
|
||||
service = VpnTunnelService(
|
||||
context=adapters.context,
|
||||
wireguard_manager=manager,
|
||||
@@ -86,6 +103,20 @@ def _normalize_text(value: Any) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _infer_endpoint_host(req) -> str:
|
||||
forwarded = (req.headers.get("X-Forwarded-Host") or req.headers.get("X-Original-Host") or "").strip()
|
||||
host = forwarded.split(",")[0].strip() if forwarded else (req.host or "").strip()
|
||||
if not host:
|
||||
return ""
|
||||
try:
|
||||
parsed = urlsplit(f"//{host}")
|
||||
if parsed.hostname:
|
||||
return parsed.hostname
|
||||
except Exception:
|
||||
return host
|
||||
return host
|
||||
|
||||
|
||||
def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
|
||||
blueprint = Blueprint("vpn_tunnel", __name__)
|
||||
logger = adapters.context.logger.getChild("vpn_tunnel.api")
|
||||
@@ -107,10 +138,15 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
|
||||
|
||||
try:
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
payload = tunnel_service.connect(agent_id=agent_id, operator_id=operator_id)
|
||||
except RuntimeError as exc:
|
||||
endpoint_host = _infer_endpoint_host(request)
|
||||
payload = tunnel_service.connect(
|
||||
agent_id=agent_id,
|
||||
operator_id=operator_id,
|
||||
endpoint_host=endpoint_host,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("vpn connect failed for agent_id=%s: %s", agent_id, exc)
|
||||
return jsonify({"error": "connect_failed"}), 500
|
||||
return jsonify({"error": "connect_failed", "detail": str(exc)}), 500
|
||||
|
||||
return jsonify(payload), 200
|
||||
|
||||
|
||||
@@ -73,6 +73,9 @@ def register(
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def _enrollment_log(message: str, context_hint: Optional[str] = None) -> None:
|
||||
log("device_enrollment", message, context_hint)
|
||||
|
||||
def _rate_limited(
|
||||
key: str,
|
||||
limiter: SlidingWindowRateLimiter,
|
||||
@@ -82,8 +85,7 @@ def register(
|
||||
):
|
||||
decision = limiter.check(key, limit, window_s)
|
||||
if not decision.allowed:
|
||||
log(
|
||||
"server",
|
||||
_enrollment_log(
|
||||
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
|
||||
context_hint,
|
||||
)
|
||||
@@ -93,6 +95,9 @@ def register(
|
||||
return response
|
||||
return None
|
||||
|
||||
def _poll_log(message: str, context_hint: Optional[str] = None) -> None:
|
||||
_enrollment_log(message, context_hint)
|
||||
|
||||
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
|
||||
cur.execute(
|
||||
"""
|
||||
@@ -347,8 +352,7 @@ def register(
|
||||
agent_pubkey_b64 = payload.get("agent_pubkey")
|
||||
client_nonce_b64 = payload.get("client_nonce")
|
||||
|
||||
log(
|
||||
"server",
|
||||
_enrollment_log(
|
||||
"enrollment request received "
|
||||
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
|
||||
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
|
||||
@@ -356,35 +360,35 @@ def register(
|
||||
)
|
||||
|
||||
if not hostname:
|
||||
log("server", f"enrollment rejected missing_hostname ip={remote}", context_hint)
|
||||
_enrollment_log(f"enrollment rejected missing_hostname ip={remote}", context_hint)
|
||||
return jsonify({"error": "hostname_required"}), 400
|
||||
if not enrollment_code:
|
||||
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
|
||||
_enrollment_log(f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "enrollment_code_required"}), 400
|
||||
if not isinstance(agent_pubkey_b64, str):
|
||||
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
|
||||
_enrollment_log(f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "agent_pubkey_required"}), 400
|
||||
if not isinstance(client_nonce_b64, str):
|
||||
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
|
||||
_enrollment_log(f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "client_nonce_required"}), 400
|
||||
|
||||
try:
|
||||
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
|
||||
except Exception:
|
||||
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
|
||||
_enrollment_log(f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||
|
||||
if len(agent_pubkey_der) < 10:
|
||||
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
|
||||
_enrollment_log(f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||
|
||||
try:
|
||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
|
||||
_enrollment_log(f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
if len(client_nonce_bytes) < 16:
|
||||
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
|
||||
_enrollment_log(f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
|
||||
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
|
||||
@@ -398,8 +402,7 @@ def register(
|
||||
install_code = _load_install_code(cur, enrollment_code)
|
||||
site_id = install_code.get("site_id") if install_code else None
|
||||
if site_id is None:
|
||||
log(
|
||||
"server",
|
||||
_enrollment_log(
|
||||
"enrollment request rejected missing_site_binding "
|
||||
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
|
||||
context_hint,
|
||||
@@ -407,8 +410,7 @@ def register(
|
||||
return jsonify({"error": "invalid_enrollment_code"}), 400
|
||||
cur.execute("SELECT 1 FROM sites WHERE id = ?", (site_id,))
|
||||
if cur.fetchone() is None:
|
||||
log(
|
||||
"server",
|
||||
_enrollment_log(
|
||||
"enrollment request rejected missing_site_owner "
|
||||
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
|
||||
context_hint,
|
||||
@@ -416,8 +418,7 @@ def register(
|
||||
return jsonify({"error": "invalid_enrollment_code"}), 400
|
||||
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
|
||||
if not valid_code:
|
||||
log(
|
||||
"server",
|
||||
_enrollment_log(
|
||||
"enrollment request invalid_code "
|
||||
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
|
||||
context_hint,
|
||||
@@ -509,8 +510,7 @@ def register(
|
||||
"server_certificate": _load_tls_bundle(tls_bundle_path),
|
||||
"signing_key": _signing_key_b64(),
|
||||
}
|
||||
log(
|
||||
"server",
|
||||
_enrollment_log(
|
||||
f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}",
|
||||
context_hint,
|
||||
)
|
||||
@@ -524,8 +524,7 @@ def register(
|
||||
proof_sig_b64 = payload.get("proof_sig")
|
||||
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
|
||||
|
||||
log(
|
||||
"server",
|
||||
_poll_log(
|
||||
"enrollment poll received "
|
||||
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
|
||||
f" proof_sig_len={len(proof_sig_b64 or '')}",
|
||||
@@ -533,25 +532,25 @@ def register(
|
||||
)
|
||||
|
||||
if not isinstance(approval_reference, str) or not approval_reference:
|
||||
log("server", "enrollment poll rejected missing_reference", context_hint)
|
||||
_poll_log("enrollment poll rejected missing_reference", context_hint)
|
||||
return jsonify({"error": "approval_reference_required"}), 400
|
||||
if not isinstance(client_nonce_b64, str):
|
||||
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
|
||||
_poll_log(f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "client_nonce_required"}), 400
|
||||
if not isinstance(proof_sig_b64, str):
|
||||
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
|
||||
_poll_log(f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "proof_sig_required"}), 400
|
||||
|
||||
try:
|
||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
|
||||
_poll_log(f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
|
||||
try:
|
||||
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
|
||||
_poll_log(f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "invalid_proof_sig"}), 400
|
||||
|
||||
conn = db_conn_factory()
|
||||
@@ -569,7 +568,7 @@ def register(
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
log("server", f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
|
||||
_poll_log(f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
|
||||
return jsonify({"status": "unknown"}), 404
|
||||
|
||||
(
|
||||
@@ -589,13 +588,13 @@ def register(
|
||||
) = row
|
||||
|
||||
if client_nonce_stored != client_nonce_b64:
|
||||
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
|
||||
_poll_log(f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "nonce_mismatch"}), 400
|
||||
|
||||
try:
|
||||
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
|
||||
_poll_log(f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "server_nonce_invalid"}), 400
|
||||
|
||||
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
|
||||
@@ -603,52 +602,47 @@ def register(
|
||||
try:
|
||||
public_key = serialization.load_der_public_key(agent_pubkey_der)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
|
||||
_poll_log(f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
|
||||
public_key = None
|
||||
|
||||
if public_key is None:
|
||||
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
|
||||
_poll_log(f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "agent_pubkey_invalid"}), 400
|
||||
|
||||
try:
|
||||
public_key.verify(proof_sig, message)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
|
||||
_poll_log(f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "invalid_proof"}), 400
|
||||
|
||||
if status == "pending":
|
||||
log(
|
||||
"server",
|
||||
_poll_log(
|
||||
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
|
||||
f" fingerprint={fingerprint[:12]}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": "pending", "poll_after_ms": 5000})
|
||||
if status == "denied":
|
||||
log(
|
||||
"server",
|
||||
_poll_log(
|
||||
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": "denied", "reason": "operator_denied"})
|
||||
if status == "expired":
|
||||
log(
|
||||
"server",
|
||||
_poll_log(
|
||||
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": "expired"})
|
||||
if status == "completed":
|
||||
log(
|
||||
"server",
|
||||
_poll_log(
|
||||
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": "approved", "detail": "finalized"})
|
||||
|
||||
if status != "approved":
|
||||
log(
|
||||
"server",
|
||||
_poll_log(
|
||||
f"enrollment poll unexpected_status={status} ref={approval_reference}",
|
||||
context_hint,
|
||||
)
|
||||
@@ -656,8 +650,7 @@ def register(
|
||||
|
||||
nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
|
||||
if not nonce_cache.consume(nonce_key):
|
||||
log(
|
||||
"server",
|
||||
_poll_log(
|
||||
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
|
||||
context_hint,
|
||||
)
|
||||
@@ -769,8 +762,7 @@ def register(
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
log(
|
||||
"server",
|
||||
_poll_log(
|
||||
f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
|
||||
context_hint,
|
||||
)
|
||||
|
||||
@@ -37,6 +37,7 @@ class VpnSession:
|
||||
firewall_rules: List[str] = field(default_factory=list)
|
||||
activity_id: Optional[int] = None
|
||||
hostname: Optional[str] = None
|
||||
endpoint_host: Optional[str] = None
|
||||
|
||||
|
||||
class VpnTunnelService:
|
||||
@@ -176,6 +177,37 @@ class VpnTunnelService:
|
||||
self.logger.debug("Failed to sign VPN orchestration token; sending unsigned.", exc_info=True)
|
||||
return token
|
||||
|
||||
def _ensure_token(self, session: VpnSession, *, now: Optional[float] = None) -> None:
|
||||
if not session:
|
||||
return
|
||||
current = now if now is not None else time.time()
|
||||
if session.expires_at > current + 30:
|
||||
return
|
||||
session.expires_at = current + 300
|
||||
session.token = self._issue_token(session.agent_id, session.tunnel_id, session.expires_at)
|
||||
|
||||
def _normalize_endpoint_host(self, host: Optional[str]) -> Optional[str]:
|
||||
if not host:
|
||||
return None
|
||||
try:
|
||||
text = str(host).strip()
|
||||
except Exception:
|
||||
return None
|
||||
return text or None
|
||||
|
||||
def _format_endpoint_host(self, host: str) -> str:
|
||||
if ":" in host and not host.startswith("["):
|
||||
return f"[{host}]"
|
||||
return host
|
||||
|
||||
def _service_log_event(self, message: str, *, level: str = "INFO") -> None:
|
||||
if not callable(self.service_log):
|
||||
return
|
||||
try:
|
||||
self.service_log("reverse_tunnel", message, level=level)
|
||||
except Exception:
|
||||
self.logger.debug("Failed to write reverse_tunnel service log entry", exc_info=True)
|
||||
|
||||
def _refresh_listener(self) -> None:
|
||||
peers: List[Mapping[str, object]] = []
|
||||
for session in self._sessions_by_agent.values():
|
||||
@@ -192,14 +224,24 @@ class VpnTunnelService:
|
||||
return
|
||||
self.wg.start_listener(peers)
|
||||
|
||||
def connect(self, *, agent_id: str, operator_id: Optional[str]) -> Mapping[str, Any]:
|
||||
def connect(
|
||||
self,
|
||||
*,
|
||||
agent_id: str,
|
||||
operator_id: Optional[str],
|
||||
endpoint_host: Optional[str] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
now = time.time()
|
||||
normalized_host = self._normalize_endpoint_host(endpoint_host)
|
||||
with self._lock:
|
||||
existing = self._sessions_by_agent.get(agent_id)
|
||||
if existing:
|
||||
if operator_id:
|
||||
existing.operator_ids.add(operator_id)
|
||||
if normalized_host and not existing.endpoint_host:
|
||||
existing.endpoint_host = normalized_host
|
||||
existing.last_activity = now
|
||||
self._ensure_token(existing, now=now)
|
||||
return self._session_payload(existing)
|
||||
|
||||
tunnel_id = uuid.uuid4().hex
|
||||
@@ -220,6 +262,7 @@ class VpnTunnelService:
|
||||
created_at=now,
|
||||
expires_at=now + 300,
|
||||
last_activity=now,
|
||||
endpoint_host=normalized_host,
|
||||
)
|
||||
if operator_id:
|
||||
session.operator_ids.add(operator_id)
|
||||
@@ -247,6 +290,17 @@ class VpnTunnelService:
|
||||
raise
|
||||
|
||||
payload = self._session_payload(session)
|
||||
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
|
||||
self._service_log_event(
|
||||
"vpn_tunnel_start agent_id={0} tunnel_id={1} virtual_ip={2} endpoint={3} allowed_ports={4} operators={5}".format(
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
session.virtual_ip,
|
||||
payload.get("endpoint", ""),
|
||||
",".join(str(p) for p in session.allowed_ports),
|
||||
operator_text,
|
||||
)
|
||||
)
|
||||
self._emit_start(payload)
|
||||
self._log_device_activity(session, event="start")
|
||||
return payload
|
||||
@@ -258,6 +312,22 @@ class VpnTunnelService:
|
||||
return None
|
||||
return self._session_payload(session, include_token=False)
|
||||
|
||||
def session_payload(self, agent_id: str, *, include_token: bool = True) -> Optional[Mapping[str, Any]]:
|
||||
with self._lock:
|
||||
session = self._sessions_by_agent.get(agent_id)
|
||||
if not session:
|
||||
return None
|
||||
if include_token:
|
||||
self._ensure_token(session)
|
||||
return self._session_payload(session, include_token=include_token)
|
||||
|
||||
def request_agent_start(self, agent_id: str) -> Optional[Mapping[str, Any]]:
|
||||
payload = self.session_payload(agent_id, include_token=True)
|
||||
if not payload:
|
||||
return None
|
||||
self._emit_start(payload)
|
||||
return payload
|
||||
|
||||
def bump_activity(self, agent_id: str) -> None:
|
||||
with self._lock:
|
||||
session = self._sessions_by_agent.get(agent_id)
|
||||
@@ -283,6 +353,15 @@ class VpnTunnelService:
|
||||
self.logger.debug("Failed to remove firewall rules for agent=%s", agent_id, exc_info=True)
|
||||
|
||||
self._refresh_listener()
|
||||
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
|
||||
self._service_log_event(
|
||||
"vpn_tunnel_stop agent_id={0} tunnel_id={1} reason={2} operators={3}".format(
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
reason,
|
||||
operator_text,
|
||||
)
|
||||
)
|
||||
self._emit_stop(session, reason)
|
||||
self._log_device_activity(session, event="stop", reason=reason)
|
||||
return True
|
||||
@@ -297,6 +376,16 @@ class VpnTunnelService:
|
||||
def _emit_start(self, payload: Mapping[str, Any]) -> None:
|
||||
if not self.socketio:
|
||||
return
|
||||
agent_id = None
|
||||
if isinstance(payload, Mapping):
|
||||
agent_id = payload.get("agent_id")
|
||||
emit_agent = getattr(self.context, "emit_agent_event", None)
|
||||
if agent_id and callable(emit_agent):
|
||||
try:
|
||||
if emit_agent(agent_id, "vpn_tunnel_start", payload):
|
||||
return
|
||||
except Exception:
|
||||
self.logger.debug("emit_agent_event failed for vpn_tunnel_start", exc_info=True)
|
||||
try:
|
||||
self.socketio.emit("vpn_tunnel_start", payload, namespace="/")
|
||||
except Exception:
|
||||
@@ -305,6 +394,17 @@ class VpnTunnelService:
|
||||
def _emit_stop(self, session: VpnSession, reason: str) -> None:
|
||||
if not self.socketio:
|
||||
return
|
||||
emit_agent = getattr(self.context, "emit_agent_event", None)
|
||||
if callable(emit_agent):
|
||||
try:
|
||||
if emit_agent(
|
||||
session.agent_id,
|
||||
"vpn_tunnel_stop",
|
||||
{"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason},
|
||||
):
|
||||
return
|
||||
except Exception:
|
||||
self.logger.debug("emit_agent_event failed for vpn_tunnel_stop", exc_info=True)
|
||||
try:
|
||||
self.socketio.emit(
|
||||
"vpn_tunnel_stop",
|
||||
@@ -454,13 +554,15 @@ class VpnTunnelService:
|
||||
pass
|
||||
|
||||
def _session_payload(self, session: VpnSession, *, include_token: bool = True) -> Mapping[str, Any]:
|
||||
endpoint_host = session.endpoint_host or str(self._engine_ip.ip)
|
||||
endpoint_host = self._format_endpoint_host(endpoint_host)
|
||||
payload: Dict[str, Any] = {
|
||||
"tunnel_id": session.tunnel_id,
|
||||
"agent_id": session.agent_id,
|
||||
"virtual_ip": session.virtual_ip,
|
||||
"engine_virtual_ip": str(self._engine_ip.ip),
|
||||
"allowed_ips": f"{self._engine_ip.ip}/32",
|
||||
"endpoint": f"{self._engine_ip.ip}:{self.context.wireguard_port}",
|
||||
"endpoint": f"{endpoint_host}:{self.context.wireguard_port}",
|
||||
"server_public_key": self.wg.server_public_key,
|
||||
"client_public_key": session.client_public_key,
|
||||
"client_private_key": session.client_private_key,
|
||||
|
||||
@@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
import base64
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
@@ -72,6 +73,18 @@ class WireGuardServerManager:
|
||||
self.server_private_key, self.server_public_key = self._ensure_server_keys()
|
||||
self._service_name = "borealis-wg"
|
||||
self._temp_dir = Path(tempfile.gettempdir()) / "borealis-wg-engine"
|
||||
self._wireguard_exe = self._resolve_wireguard_exe()
|
||||
|
||||
def _resolve_wireguard_exe(self) -> str:
|
||||
candidates = [
|
||||
str(Path(os.environ.get("ProgramFiles", "C:\\Program Files")) / "WireGuard" / "wireguard.exe"),
|
||||
str(Path(os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)")) / "WireGuard" / "wireguard.exe"),
|
||||
"wireguard.exe",
|
||||
]
|
||||
for candidate in candidates:
|
||||
if Path(candidate).is_file():
|
||||
return candidate
|
||||
return "wireguard.exe"
|
||||
|
||||
def _ensure_cert_dir(self) -> None:
|
||||
try:
|
||||
@@ -316,7 +329,7 @@ class WireGuardServerManager:
|
||||
# Ensure old service is removed before re-installing.
|
||||
self.stop_listener()
|
||||
|
||||
args = ["wireguard.exe", "/installtunnelservice", str(config_path)]
|
||||
args = [self._wireguard_exe, "/installtunnelservice", str(config_path)]
|
||||
code, out, err = self._run_command(args)
|
||||
if code != 0:
|
||||
self.logger.error("Failed to install WireGuard tunnel service code=%s err=%s", code, err)
|
||||
@@ -326,7 +339,7 @@ class WireGuardServerManager:
|
||||
def stop_listener(self) -> None:
|
||||
"""Stop and remove the WireGuard tunnel service."""
|
||||
|
||||
args = ["wireguard.exe", "/uninstalltunnelservice", self._service_name]
|
||||
args = [self._wireguard_exe, "/uninstalltunnelservice", self._service_name]
|
||||
code, out, err = self._run_command(args)
|
||||
if code != 0:
|
||||
self.logger.warning("Failed to uninstall WireGuard tunnel service code=%s err=%s", code, err)
|
||||
|
||||
@@ -20,7 +20,7 @@ from flask_socketio import SocketIO
|
||||
from ...database import initialise_engine_database
|
||||
from ...security import signing
|
||||
from ...server import EngineContext
|
||||
from ..VPN import VpnTunnelService
|
||||
from ..VPN import WireGuardServerConfig, WireGuardServerManager, VpnTunnelService
|
||||
from .vpn_shell import VpnShellBridge
|
||||
|
||||
|
||||
@@ -63,12 +63,53 @@ class EngineRealtimeAdapters:
|
||||
self.service_log = _make_service_logger(base, self.context.logger)
|
||||
|
||||
|
||||
class AgentSocketRegistry:
|
||||
def __init__(self, socketio: SocketIO, logger) -> None:
|
||||
self.socketio = socketio
|
||||
self.logger = logger
|
||||
self._sid_by_agent: Dict[str, str] = {}
|
||||
self._agent_by_sid: Dict[str, str] = {}
|
||||
|
||||
def register(self, agent_id: str, sid: str) -> None:
|
||||
if not agent_id or not sid:
|
||||
return
|
||||
previous = self._sid_by_agent.get(agent_id)
|
||||
if previous and previous != sid:
|
||||
self._agent_by_sid.pop(previous, None)
|
||||
self._sid_by_agent[agent_id] = sid
|
||||
self._agent_by_sid[sid] = agent_id
|
||||
|
||||
def unregister(self, sid: str) -> Optional[str]:
|
||||
agent_id = self._agent_by_sid.pop(sid, None)
|
||||
if agent_id and self._sid_by_agent.get(agent_id) == sid:
|
||||
self._sid_by_agent.pop(agent_id, None)
|
||||
return agent_id
|
||||
|
||||
def emit(self, agent_id: str, event: str, payload: Any) -> bool:
|
||||
sid = self._sid_by_agent.get(agent_id)
|
||||
if not sid:
|
||||
return False
|
||||
try:
|
||||
self.socketio.emit(event, payload, to=sid)
|
||||
return True
|
||||
except Exception:
|
||||
self.logger.debug("Failed to emit %s to agent_id=%s", event, agent_id, exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
"""Register Socket.IO event handlers for the Engine runtime."""
|
||||
|
||||
adapters = EngineRealtimeAdapters(context)
|
||||
logger = context.logger.getChild("realtime.quick_jobs")
|
||||
agent_logger = context.logger.getChild("realtime.agents")
|
||||
shell_bridge = VpnShellBridge(socket_server, context)
|
||||
agent_registry = AgentSocketRegistry(socket_server, agent_logger)
|
||||
|
||||
def _emit_agent_event(agent_id: str, event: str, payload: Any) -> bool:
|
||||
return agent_registry.emit(agent_id, event, payload)
|
||||
|
||||
setattr(context, "emit_agent_event", _emit_agent_event)
|
||||
|
||||
def _get_tunnel_service() -> Optional[VpnTunnelService]:
|
||||
service = getattr(context, "vpn_tunnel_service", None)
|
||||
@@ -76,7 +117,22 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
return service
|
||||
manager = getattr(context, "wireguard_server_manager", None)
|
||||
if manager is None:
|
||||
return None
|
||||
try:
|
||||
manager = WireGuardServerManager(
|
||||
WireGuardServerConfig(
|
||||
port=context.wireguard_port,
|
||||
engine_virtual_ip=context.wireguard_engine_virtual_ip,
|
||||
peer_network=context.wireguard_peer_network,
|
||||
private_key_path=Path(context.wireguard_server_private_key_path),
|
||||
public_key_path=Path(context.wireguard_server_public_key_path),
|
||||
acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows),
|
||||
log_path=Path(context.vpn_tunnel_log_path),
|
||||
)
|
||||
)
|
||||
setattr(context, "wireguard_server_manager", manager)
|
||||
except Exception:
|
||||
context.logger.error("Failed to initialize WireGuard server manager on demand.", exc_info=True)
|
||||
return None
|
||||
try:
|
||||
signer = signing.load_signer()
|
||||
except Exception:
|
||||
@@ -275,6 +331,29 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
service.bump_activity(agent_id)
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("connect_agent")
|
||||
def _connect_agent(data: Any) -> Dict[str, Any]:
|
||||
agent_id = ""
|
||||
service_mode = ""
|
||||
if isinstance(data, dict):
|
||||
agent_id = str(data.get("agent_id") or "").strip()
|
||||
service_mode = str(data.get("service_mode") or "").strip().lower()
|
||||
elif isinstance(data, str):
|
||||
agent_id = data.strip()
|
||||
if not agent_id:
|
||||
return {"error": "agent_id_required"}
|
||||
|
||||
agent_registry.register(agent_id, request.sid)
|
||||
agent_logger.info("Agent socket registered agent_id=%s service_mode=%s sid=%s", agent_id, service_mode, request.sid)
|
||||
|
||||
service = _get_tunnel_service()
|
||||
if service:
|
||||
payload = service.session_payload(agent_id, include_token=True)
|
||||
if payload:
|
||||
agent_registry.emit(agent_id, "vpn_tunnel_start", payload)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("vpn_shell_send")
|
||||
def _vpn_shell_send(data: Any) -> Dict[str, Any]:
|
||||
payload = None
|
||||
@@ -288,10 +367,13 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("vpn_shell_close")
|
||||
def _vpn_shell_close() -> Dict[str, Any]:
|
||||
def _vpn_shell_close(data: Any = None) -> Dict[str, Any]:
|
||||
shell_bridge.close(request.sid)
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("disconnect")
|
||||
def _ws_disconnect() -> None:
|
||||
agent_id = agent_registry.unregister(request.sid)
|
||||
if agent_id:
|
||||
agent_logger.info("Agent socket disconnected agent_id=%s sid=%s", agent_id, request.sid)
|
||||
shell_bridge.close(request.sid)
|
||||
|
||||
@@ -13,6 +13,7 @@ import base64
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -42,7 +43,10 @@ class ShellSession:
|
||||
buffer = b""
|
||||
try:
|
||||
while True:
|
||||
data = self.tcp.recv(4096)
|
||||
try:
|
||||
data = self.tcp.recv(4096)
|
||||
except (socket.timeout, TimeoutError):
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
buffer += data
|
||||
@@ -100,12 +104,28 @@ class VpnShellBridge:
|
||||
return None
|
||||
host = str(status.get("virtual_ip") or "").split("/")[0]
|
||||
port = int(self.context.wireguard_shell_port)
|
||||
try:
|
||||
tcp = socket.create_connection((host, port), timeout=5)
|
||||
except Exception:
|
||||
self.logger.debug("Failed to connect vpn shell to %s:%s", host, port, exc_info=True)
|
||||
tcp = None
|
||||
last_error: Optional[Exception] = None
|
||||
for attempt in range(3):
|
||||
try:
|
||||
tcp = socket.create_connection((host, port), timeout=5)
|
||||
break
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
if attempt == 0:
|
||||
try:
|
||||
service.request_agent_start(agent_id)
|
||||
except Exception:
|
||||
self.logger.debug("Failed to re-emit vpn_tunnel_start for agent=%s", agent_id, exc_info=True)
|
||||
time.sleep(1)
|
||||
if tcp is None:
|
||||
self.logger.warning("Failed to connect vpn shell to %s:%s", host, port, exc_info=last_error)
|
||||
return None
|
||||
session = ShellSession(sid=sid, agent_id=agent_id, socketio=self.socketio, tcp=tcp)
|
||||
try:
|
||||
session.tcp.settimeout(15)
|
||||
except Exception:
|
||||
pass
|
||||
self._sessions[sid] = session
|
||||
session.start_reader()
|
||||
return session
|
||||
@@ -124,4 +144,3 @@ class VpnShellBridge:
|
||||
if not session:
|
||||
return
|
||||
session.close()
|
||||
|
||||
|
||||
@@ -7,6 +7,15 @@
|
||||
|
||||
"""Service registration hooks for the Borealis Engine runtime."""
|
||||
|
||||
from . import API, WebSocket, WebUI
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["API", "WebSocket", "WebUI"]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in __all__:
|
||||
return importlib.import_module(f"{__name__}.{name}")
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -75,7 +75,7 @@ const SECTION_HEIGHTS = {
|
||||
};
|
||||
|
||||
const buildVpnGroups = (shellPort) => {
|
||||
const normalizedShell = Number(shellPort) || 47001;
|
||||
const normalizedShell = Number(shellPort) || 47002;
|
||||
return [
|
||||
{
|
||||
key: "shell",
|
||||
@@ -335,7 +335,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
|
||||
const [vpnToggles, setVpnToggles] = useState({});
|
||||
const [vpnCustomPorts, setVpnCustomPorts] = useState([]);
|
||||
const [vpnDefaultPorts, setVpnDefaultPorts] = useState([]);
|
||||
const [vpnShellPort, setVpnShellPort] = useState(47001);
|
||||
const [vpnShellPort, setVpnShellPort] = useState(47002);
|
||||
const [vpnLoadedFor, setVpnLoadedFor] = useState("");
|
||||
// Snapshotted status for the lifetime of this page
|
||||
const [lockedStatus, setLockedStatus] = useState(() => {
|
||||
@@ -347,6 +347,31 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
|
||||
const now = Date.now() / 1000;
|
||||
return now - tsSec <= 300 ? "Online" : "Offline";
|
||||
});
|
||||
const summary = details.summary || {};
|
||||
const vpnAgentId = useMemo(() => {
|
||||
return (
|
||||
meta.agentId ||
|
||||
summary.agent_id ||
|
||||
agent?.agent_id ||
|
||||
agent?.id ||
|
||||
device?.agent_id ||
|
||||
device?.agent_guid ||
|
||||
device?.id ||
|
||||
""
|
||||
);
|
||||
}, [agent?.agent_id, agent?.id, device?.agent_guid, device?.agent_id, device?.id, meta.agentId, summary.agent_id]);
|
||||
const vpnPortGroups = useMemo(() => buildVpnGroups(vpnShellPort), [vpnShellPort]);
|
||||
const tunnelDevice = useMemo(
|
||||
() => ({
|
||||
...(device || {}),
|
||||
...(agent || {}),
|
||||
summary,
|
||||
hostname: meta.hostname || summary.hostname || device?.hostname || agent?.hostname,
|
||||
agent_id: meta.agentId || summary.agent_id || agent?.agent_id || agent?.id || device?.agent_id || device?.agent_guid,
|
||||
agent_guid: meta.agentGuid || summary.agent_guid || device?.agent_guid || device?.guid || agent?.agent_guid || agent?.guid,
|
||||
}),
|
||||
[agent, device, meta.agentGuid, meta.agentId, meta.hostname, summary]
|
||||
);
|
||||
const quickJobTargets = useMemo(() => {
|
||||
const values = [];
|
||||
const push = (value) => {
|
||||
@@ -715,7 +740,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
|
||||
const numericDefaults = normalizedDefaults
|
||||
.map((p) => Number(p))
|
||||
.filter((p) => Number.isFinite(p) && p > 0);
|
||||
const effectiveShell = Number(shellPort) || 47001;
|
||||
const effectiveShell = Number(shellPort) || 47002;
|
||||
const groups = buildVpnGroups(effectiveShell);
|
||||
const knownPorts = new Set(groups.flatMap((group) => group.ports));
|
||||
const allowedSet = new Set(numericPorts);
|
||||
@@ -887,31 +912,6 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
|
||||
[]
|
||||
);
|
||||
|
||||
const summary = details.summary || {};
|
||||
const vpnAgentId = useMemo(() => {
|
||||
return (
|
||||
meta.agentId ||
|
||||
summary.agent_id ||
|
||||
agent?.agent_id ||
|
||||
agent?.id ||
|
||||
device?.agent_id ||
|
||||
device?.agent_guid ||
|
||||
device?.id ||
|
||||
""
|
||||
);
|
||||
}, [agent?.agent_id, agent?.id, device?.agent_guid, device?.agent_id, device?.id, meta.agentId, summary.agent_id]);
|
||||
const vpnPortGroups = useMemo(() => buildVpnGroups(vpnShellPort), [vpnShellPort]);
|
||||
const tunnelDevice = useMemo(
|
||||
() => ({
|
||||
...(device || {}),
|
||||
...(agent || {}),
|
||||
summary,
|
||||
hostname: meta.hostname || summary.hostname || device?.hostname || agent?.hostname,
|
||||
agent_id: meta.agentId || summary.agent_id || agent?.agent_id || agent?.id || device?.agent_id || device?.agent_guid,
|
||||
agent_guid: meta.agentGuid || summary.agent_guid || device?.agent_guid || device?.guid || agent?.agent_guid || agent?.guid,
|
||||
}),
|
||||
[agent, device, meta.agentGuid, meta.agentId, meta.hostname, summary]
|
||||
);
|
||||
// Build a best-effort CPU display from summary fields
|
||||
const cpuInfo = useMemo(() => {
|
||||
const cpu = details.cpu || summary.cpu || {};
|
||||
|
||||
@@ -93,6 +93,8 @@ export default function ReverseTunnelPowershell({ device }) {
|
||||
const socketRef = useRef(null);
|
||||
const localSocketRef = useRef(false);
|
||||
const terminalRef = useRef(null);
|
||||
const agentIdRef = useRef("");
|
||||
const tunnelIdRef = useRef("");
|
||||
|
||||
const agentId = useMemo(() => {
|
||||
return (
|
||||
@@ -107,6 +109,14 @@ export default function ReverseTunnelPowershell({ device }) {
|
||||
);
|
||||
}, [device]);
|
||||
|
||||
useEffect(() => {
|
||||
agentIdRef.current = agentId;
|
||||
}, [agentId]);
|
||||
|
||||
useEffect(() => {
|
||||
tunnelIdRef.current = tunnel?.tunnel_id || "";
|
||||
}, [tunnel?.tunnel_id]);
|
||||
|
||||
const ensureSocket = useCallback(() => {
|
||||
if (socketRef.current) return socketRef.current;
|
||||
const existing = typeof window !== "undefined" ? window.BorealisSocket : null;
|
||||
@@ -142,21 +152,20 @@ export default function ReverseTunnelPowershell({ device }) {
|
||||
scrollToBottom();
|
||||
}, [output, scrollToBottom]);
|
||||
|
||||
const stopTunnel = useCallback(
|
||||
async (reason = "operator_disconnect") => {
|
||||
if (!agentId) return;
|
||||
try {
|
||||
await fetch("/api/tunnel/disconnect", {
|
||||
method: "DELETE",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ agent_id: agentId, tunnel_id: tunnel?.tunnel_id, reason }),
|
||||
});
|
||||
} catch {
|
||||
// best-effort
|
||||
}
|
||||
},
|
||||
[agentId, tunnel?.tunnel_id]
|
||||
);
|
||||
const stopTunnel = useCallback(async (reason = "operator_disconnect") => {
|
||||
const currentAgentId = agentIdRef.current;
|
||||
if (!currentAgentId) return;
|
||||
const currentTunnelId = tunnelIdRef.current;
|
||||
try {
|
||||
await fetch("/api/tunnel/disconnect", {
|
||||
method: "DELETE",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ agent_id: currentAgentId, tunnel_id: currentTunnelId, reason }),
|
||||
});
|
||||
} catch {
|
||||
// best-effort
|
||||
}
|
||||
}, []);
|
||||
|
||||
const closeShell = useCallback(async () => {
|
||||
const socket = ensureSocket();
|
||||
@@ -232,7 +241,10 @@ export default function ReverseTunnelPowershell({ device }) {
|
||||
body: JSON.stringify({ agent_id: agentId }),
|
||||
});
|
||||
const data = await resp.json().catch(() => ({}));
|
||||
if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`);
|
||||
if (!resp.ok) {
|
||||
const detail = data?.detail ? `: ${data.detail}` : "";
|
||||
throw new Error(`${data?.error || `HTTP ${resp.status}`}${detail}`);
|
||||
}
|
||||
const statusResp = await fetch(
|
||||
`/api/tunnel/connect/status?agent_id=${encodeURIComponent(agentId)}&bump=1`
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user