Additional Changes to VPN Tunneling

This commit is contained in:
2026-01-11 19:02:53 -07:00
parent 6ceb59f717
commit df14a1e26a
18 changed files with 681 additions and 175 deletions

View File

@@ -47,11 +47,11 @@ def _b64decode(value: str) -> bytes:
def _resolve_shell_port() -> int:
raw = os.environ.get("BOREALIS_WIREGUARD_SHELL_PORT")
try:
value = int(raw) if raw is not None else 47001
value = int(raw) if raw is not None else 47002
except Exception:
value = 47001
value = 47002
if value < 1 or value > 65535:
return 47001
return 47002
return value

View File

@@ -22,13 +22,22 @@ import os
import subprocess
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x25519
from signature_utils import verify_and_store_script_signature
try:
from signature_utils import verify_and_store_script_signature
except Exception: # pragma: no cover - fallback for runtime path issues
import sys
from pathlib import Path as _Path
base_dir = _Path(__file__).resolve().parents[1]
if str(base_dir) not in sys.path:
sys.path.insert(0, str(base_dir))
from signature_utils import verify_and_store_script_signature
ROLE_NAME = "WireGuardTunnel"
ROLE_CONTEXTS = ["system"]
@@ -88,18 +97,31 @@ def _generate_client_keys(root: Path) -> Dict[str, str]:
return {"private": priv, "public": pub}
@dataclass
class SessionConfig:
token: Dict[str, Any]
virtual_ip: str
allowed_ips: str
endpoint: str
server_public_key: str
allowed_ports: str
idle_seconds: int = 900
preshared_key: Optional[str] = None
client_private_key: Optional[str] = None
client_public_key: Optional[str] = None
def __init__(
self,
*,
token: Dict[str, Any],
virtual_ip: str,
allowed_ips: str,
endpoint: str,
server_public_key: str,
allowed_ports: str,
idle_seconds: int = 900,
preshared_key: Optional[str] = None,
client_private_key: Optional[str] = None,
client_public_key: Optional[str] = None,
) -> None:
self.token = token
self.virtual_ip = virtual_ip
self.allowed_ips = allowed_ips
self.endpoint = endpoint
self.server_public_key = server_public_key
self.allowed_ports = allowed_ports
self.idle_seconds = idle_seconds
self.preshared_key = preshared_key
self.client_private_key = client_private_key
self.client_public_key = client_public_key
class WireGuardClient:
@@ -150,6 +172,19 @@ class WireGuardClient:
if port < 1 or port > 65535:
raise ValueError("Invalid token port")
if not signature:
if sig_alg or signing_key:
raise ValueError("Token signature missing")
stored_key = None
if signing_client is not None and hasattr(signing_client, "load_server_signing_key"):
try:
stored_key = signing_client.load_server_signing_key()
except Exception:
stored_key = None
if isinstance(stored_key, str) and stored_key.strip():
raise ValueError("Token signature missing")
return
if signature:
if sig_alg and str(sig_alg).lower() not in ("ed25519", "eddsa"):
raise ValueError("Unsupported token signature algorithm")
@@ -292,6 +327,11 @@ class Role:
self._log("WireGuard start payload missing/invalid.", error=True)
return None
payload_agent_id = payload.get("agent_id") or payload.get("agent_guid")
if payload_agent_id:
if str(payload_agent_id).strip() != str(self.ctx.agent_id).strip():
return None
token = payload.get("token") or payload.get("orchestration_token")
if not isinstance(token, dict):
self._log("WireGuard start missing token payload.", error=True)
@@ -351,6 +391,9 @@ class Role:
async def _vpn_tunnel_stop(payload):
reason = "server_stop"
if isinstance(payload, dict):
target_agent = payload.get("agent_id")
if target_agent and str(target_agent).strip() != str(self.ctx.agent_id).strip():
return
reason = payload.get("reason") or reason
self._log(f"WireGuard stop requested (reason={reason}).")
self.client.stop_session(reason=str(reason))

View File

@@ -1,5 +1,7 @@
import os
import importlib.util
import sys
from pathlib import Path
from typing import Dict, List, Optional
@@ -37,8 +39,27 @@ class RoleManager:
self.config = config
self.loop = loop
self.hooks = hooks or {}
self._log_hook = self.hooks.get('log_agent')
self.roles: Dict[str, object] = {}
# Ensure role helpers alongside Roles/ are importable (e.g., signature_utils.py).
try:
base_path = Path(self.base_dir).resolve()
parent_path = base_path.parent
for candidate in (base_path, parent_path):
if candidate and str(candidate) not in sys.path:
sys.path.insert(0, str(candidate))
except Exception:
pass
def _log(self, message: str, *, error: bool = False) -> None:
if callable(self._log_hook):
try:
target = "agent.error.log" if error else "agent.log"
self._log_hook(message, fname=target)
except Exception:
pass
def _iter_role_files(self) -> List[str]:
roles_dir = os.path.join(self.base_dir, 'Roles')
if not os.path.isdir(roles_dir):
@@ -56,7 +77,8 @@ class RoleManager:
mod = importlib.util.module_from_spec(spec)
assert spec and spec.loader
spec.loader.exec_module(mod)
except Exception:
except Exception as exc:
self._log(f"Role load failed during import path={path} error={exc}", error=True)
continue
role_name = getattr(mod, 'ROLE_NAME', None)
@@ -75,10 +97,12 @@ class RoleManager:
if hasattr(role_obj, 'register_events'):
try:
role_obj.register_events()
except Exception:
pass
except Exception as exc:
self._log(f"Role register_events failed name={role_name} error={exc}", error=True)
self.roles[role_name] = role_obj
except Exception:
self._log(f"Role loaded name={role_name} context={self.context}")
except Exception as exc:
self._log(f"Role init failed name={role_name} path={path} error={exc}", error=True)
continue
def on_config(self, roles_cfg: List[dict]):
@@ -96,4 +120,3 @@ class RoleManager:
role.stop_all()
except Exception:
pass

View File

@@ -81,7 +81,7 @@ VPN_TUNNEL_LOG_FILE_PATH = LOG_ROOT / "reverse_tunnel.log"
DEFAULT_WIREGUARD_PORT = 30000
DEFAULT_WIREGUARD_ENGINE_VIRTUAL_IP = "10.255.0.1/32"
DEFAULT_WIREGUARD_PEER_NETWORK = "10.255.0.0/24"
DEFAULT_WIREGUARD_SHELL_PORT = 47001
DEFAULT_WIREGUARD_SHELL_PORT = 47002
DEFAULT_WIREGUARD_ACL_WINDOWS = (3389, 5985, 5986, 5900, 3478, DEFAULT_WIREGUARD_SHELL_PORT)
VPN_SERVER_CERT_ROOT = PROJECT_ROOT / "Engine" / "Certificates" / "VPN_Server"

View File

@@ -117,7 +117,7 @@ def _rotate_daily(path: Path) -> None:
pass
_QUIET_SERVICE_LOGS = {"scheduled_jobs"}
_QUIET_SERVICE_LOGS = {"scheduled_jobs", "device_enrollment"}
def _make_service_logger(base: Path, logger: logging.Logger) -> Callable[[str, str, Optional[str]], None]:

View File

@@ -12,12 +12,14 @@
from __future__ import annotations
import os
from urllib.parse import urlsplit
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from flask import Blueprint, jsonify, request, session
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from ...VPN import VpnTunnelService
from ...VPN import WireGuardServerConfig, WireGuardServerManager, VpnTunnelService
if False: # pragma: no cover - import cycle hint for type checkers
from .. import EngineServiceAdapters
@@ -63,7 +65,22 @@ def _get_tunnel_service(adapters: "EngineServiceAdapters") -> VpnTunnelService:
if service is None:
manager = getattr(adapters.context, "wireguard_server_manager", None)
if manager is None:
raise RuntimeError("wireguard_manager_unavailable")
try:
manager = WireGuardServerManager(
WireGuardServerConfig(
port=adapters.context.wireguard_port,
engine_virtual_ip=adapters.context.wireguard_engine_virtual_ip,
peer_network=adapters.context.wireguard_peer_network,
private_key_path=Path(adapters.context.wireguard_server_private_key_path),
public_key_path=Path(adapters.context.wireguard_server_public_key_path),
acl_allowlist_windows=tuple(adapters.context.wireguard_acl_allowlist_windows),
log_path=Path(adapters.context.vpn_tunnel_log_path),
)
)
adapters.context.wireguard_server_manager = manager
except Exception as exc:
adapters.context.logger.error("Failed to initialize WireGuard server manager on demand.", exc_info=True)
raise RuntimeError("wireguard_manager_unavailable") from exc
service = VpnTunnelService(
context=adapters.context,
wireguard_manager=manager,
@@ -86,6 +103,20 @@ def _normalize_text(value: Any) -> str:
return ""
def _infer_endpoint_host(req) -> str:
forwarded = (req.headers.get("X-Forwarded-Host") or req.headers.get("X-Original-Host") or "").strip()
host = forwarded.split(",")[0].strip() if forwarded else (req.host or "").strip()
if not host:
return ""
try:
parsed = urlsplit(f"//{host}")
if parsed.hostname:
return parsed.hostname
except Exception:
return host
return host
def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
blueprint = Blueprint("vpn_tunnel", __name__)
logger = adapters.context.logger.getChild("vpn_tunnel.api")
@@ -107,10 +138,15 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
try:
tunnel_service = _get_tunnel_service(adapters)
payload = tunnel_service.connect(agent_id=agent_id, operator_id=operator_id)
except RuntimeError as exc:
endpoint_host = _infer_endpoint_host(request)
payload = tunnel_service.connect(
agent_id=agent_id,
operator_id=operator_id,
endpoint_host=endpoint_host,
)
except Exception as exc:
logger.warning("vpn connect failed for agent_id=%s: %s", agent_id, exc)
return jsonify({"error": "connect_failed"}), 500
return jsonify({"error": "connect_failed", "detail": str(exc)}), 500
return jsonify(payload), 200

View File

@@ -73,6 +73,9 @@ def register(
except Exception:
return ""
def _enrollment_log(message: str, context_hint: Optional[str] = None) -> None:
log("device_enrollment", message, context_hint)
def _rate_limited(
key: str,
limiter: SlidingWindowRateLimiter,
@@ -82,8 +85,7 @@ def register(
):
decision = limiter.check(key, limit, window_s)
if not decision.allowed:
log(
"server",
_enrollment_log(
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
context_hint,
)
@@ -93,6 +95,9 @@ def register(
return response
return None
def _poll_log(message: str, context_hint: Optional[str] = None) -> None:
_enrollment_log(message, context_hint)
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
cur.execute(
"""
@@ -347,8 +352,7 @@ def register(
agent_pubkey_b64 = payload.get("agent_pubkey")
client_nonce_b64 = payload.get("client_nonce")
log(
"server",
_enrollment_log(
"enrollment request received "
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
@@ -356,35 +360,35 @@ def register(
)
if not hostname:
log("server", f"enrollment rejected missing_hostname ip={remote}", context_hint)
_enrollment_log(f"enrollment rejected missing_hostname ip={remote}", context_hint)
return jsonify({"error": "hostname_required"}), 400
if not enrollment_code:
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
return jsonify({"error": "enrollment_code_required"}), 400
if not isinstance(agent_pubkey_b64, str):
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "agent_pubkey_required"}), 400
if not isinstance(client_nonce_b64, str):
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "client_nonce_required"}), 400
try:
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
except Exception:
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_agent_pubkey"}), 400
if len(agent_pubkey_der) < 10:
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_agent_pubkey"}), 400
try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
if len(client_nonce_bytes) < 16:
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
@@ -398,8 +402,7 @@ def register(
install_code = _load_install_code(cur, enrollment_code)
site_id = install_code.get("site_id") if install_code else None
if site_id is None:
log(
"server",
_enrollment_log(
"enrollment request rejected missing_site_binding "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
@@ -407,8 +410,7 @@ def register(
return jsonify({"error": "invalid_enrollment_code"}), 400
cur.execute("SELECT 1 FROM sites WHERE id = ?", (site_id,))
if cur.fetchone() is None:
log(
"server",
_enrollment_log(
"enrollment request rejected missing_site_owner "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
@@ -416,8 +418,7 @@ def register(
return jsonify({"error": "invalid_enrollment_code"}), 400
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
if not valid_code:
log(
"server",
_enrollment_log(
"enrollment request invalid_code "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
@@ -509,8 +510,7 @@ def register(
"server_certificate": _load_tls_bundle(tls_bundle_path),
"signing_key": _signing_key_b64(),
}
log(
"server",
_enrollment_log(
f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}",
context_hint,
)
@@ -524,8 +524,7 @@ def register(
proof_sig_b64 = payload.get("proof_sig")
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
log(
"server",
_poll_log(
"enrollment poll received "
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
f" proof_sig_len={len(proof_sig_b64 or '')}",
@@ -533,25 +532,25 @@ def register(
)
if not isinstance(approval_reference, str) or not approval_reference:
log("server", "enrollment poll rejected missing_reference", context_hint)
_poll_log("enrollment poll rejected missing_reference", context_hint)
return jsonify({"error": "approval_reference_required"}), 400
if not isinstance(client_nonce_b64, str):
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "client_nonce_required"}), 400
if not isinstance(proof_sig_b64, str):
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
return jsonify({"error": "proof_sig_required"}), 400
try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
try:
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_proof_sig"}), 400
conn = db_conn_factory()
@@ -569,7 +568,7 @@ def register(
)
row = cur.fetchone()
if not row:
log("server", f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
return jsonify({"status": "unknown"}), 404
(
@@ -589,13 +588,13 @@ def register(
) = row
if client_nonce_stored != client_nonce_b64:
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
return jsonify({"error": "nonce_mismatch"}), 400
try:
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "server_nonce_invalid"}), 400
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
@@ -603,52 +602,47 @@ def register(
try:
public_key = serialization.load_der_public_key(agent_pubkey_der)
except Exception:
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
public_key = None
if public_key is None:
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
return jsonify({"error": "agent_pubkey_invalid"}), 400
try:
public_key.verify(proof_sig, message)
except Exception:
log("server", f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_proof"}), 400
if status == "pending":
log(
"server",
_poll_log(
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
f" fingerprint={fingerprint[:12]}",
context_hint,
)
return jsonify({"status": "pending", "poll_after_ms": 5000})
if status == "denied":
log(
"server",
_poll_log(
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "denied", "reason": "operator_denied"})
if status == "expired":
log(
"server",
_poll_log(
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "expired"})
if status == "completed":
log(
"server",
_poll_log(
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "approved", "detail": "finalized"})
if status != "approved":
log(
"server",
_poll_log(
f"enrollment poll unexpected_status={status} ref={approval_reference}",
context_hint,
)
@@ -656,8 +650,7 @@ def register(
nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
if not nonce_cache.consume(nonce_key):
log(
"server",
_poll_log(
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
context_hint,
)
@@ -769,8 +762,7 @@ def register(
finally:
conn.close()
log(
"server",
_poll_log(
f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
context_hint,
)

View File

@@ -37,6 +37,7 @@ class VpnSession:
firewall_rules: List[str] = field(default_factory=list)
activity_id: Optional[int] = None
hostname: Optional[str] = None
endpoint_host: Optional[str] = None
class VpnTunnelService:
@@ -176,6 +177,37 @@ class VpnTunnelService:
self.logger.debug("Failed to sign VPN orchestration token; sending unsigned.", exc_info=True)
return token
def _ensure_token(self, session: VpnSession, *, now: Optional[float] = None) -> None:
if not session:
return
current = now if now is not None else time.time()
if session.expires_at > current + 30:
return
session.expires_at = current + 300
session.token = self._issue_token(session.agent_id, session.tunnel_id, session.expires_at)
def _normalize_endpoint_host(self, host: Optional[str]) -> Optional[str]:
if not host:
return None
try:
text = str(host).strip()
except Exception:
return None
return text or None
def _format_endpoint_host(self, host: str) -> str:
if ":" in host and not host.startswith("["):
return f"[{host}]"
return host
def _service_log_event(self, message: str, *, level: str = "INFO") -> None:
if not callable(self.service_log):
return
try:
self.service_log("reverse_tunnel", message, level=level)
except Exception:
self.logger.debug("Failed to write reverse_tunnel service log entry", exc_info=True)
def _refresh_listener(self) -> None:
peers: List[Mapping[str, object]] = []
for session in self._sessions_by_agent.values():
@@ -192,14 +224,24 @@ class VpnTunnelService:
return
self.wg.start_listener(peers)
def connect(self, *, agent_id: str, operator_id: Optional[str]) -> Mapping[str, Any]:
def connect(
self,
*,
agent_id: str,
operator_id: Optional[str],
endpoint_host: Optional[str] = None,
) -> Mapping[str, Any]:
now = time.time()
normalized_host = self._normalize_endpoint_host(endpoint_host)
with self._lock:
existing = self._sessions_by_agent.get(agent_id)
if existing:
if operator_id:
existing.operator_ids.add(operator_id)
if normalized_host and not existing.endpoint_host:
existing.endpoint_host = normalized_host
existing.last_activity = now
self._ensure_token(existing, now=now)
return self._session_payload(existing)
tunnel_id = uuid.uuid4().hex
@@ -220,6 +262,7 @@ class VpnTunnelService:
created_at=now,
expires_at=now + 300,
last_activity=now,
endpoint_host=normalized_host,
)
if operator_id:
session.operator_ids.add(operator_id)
@@ -247,6 +290,17 @@ class VpnTunnelService:
raise
payload = self._session_payload(session)
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
self._service_log_event(
"vpn_tunnel_start agent_id={0} tunnel_id={1} virtual_ip={2} endpoint={3} allowed_ports={4} operators={5}".format(
session.agent_id,
session.tunnel_id,
session.virtual_ip,
payload.get("endpoint", ""),
",".join(str(p) for p in session.allowed_ports),
operator_text,
)
)
self._emit_start(payload)
self._log_device_activity(session, event="start")
return payload
@@ -258,6 +312,22 @@ class VpnTunnelService:
return None
return self._session_payload(session, include_token=False)
def session_payload(self, agent_id: str, *, include_token: bool = True) -> Optional[Mapping[str, Any]]:
with self._lock:
session = self._sessions_by_agent.get(agent_id)
if not session:
return None
if include_token:
self._ensure_token(session)
return self._session_payload(session, include_token=include_token)
def request_agent_start(self, agent_id: str) -> Optional[Mapping[str, Any]]:
payload = self.session_payload(agent_id, include_token=True)
if not payload:
return None
self._emit_start(payload)
return payload
def bump_activity(self, agent_id: str) -> None:
with self._lock:
session = self._sessions_by_agent.get(agent_id)
@@ -283,6 +353,15 @@ class VpnTunnelService:
self.logger.debug("Failed to remove firewall rules for agent=%s", agent_id, exc_info=True)
self._refresh_listener()
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
self._service_log_event(
"vpn_tunnel_stop agent_id={0} tunnel_id={1} reason={2} operators={3}".format(
session.agent_id,
session.tunnel_id,
reason,
operator_text,
)
)
self._emit_stop(session, reason)
self._log_device_activity(session, event="stop", reason=reason)
return True
@@ -297,6 +376,16 @@ class VpnTunnelService:
def _emit_start(self, payload: Mapping[str, Any]) -> None:
if not self.socketio:
return
agent_id = None
if isinstance(payload, Mapping):
agent_id = payload.get("agent_id")
emit_agent = getattr(self.context, "emit_agent_event", None)
if agent_id and callable(emit_agent):
try:
if emit_agent(agent_id, "vpn_tunnel_start", payload):
return
except Exception:
self.logger.debug("emit_agent_event failed for vpn_tunnel_start", exc_info=True)
try:
self.socketio.emit("vpn_tunnel_start", payload, namespace="/")
except Exception:
@@ -305,6 +394,17 @@ class VpnTunnelService:
def _emit_stop(self, session: VpnSession, reason: str) -> None:
if not self.socketio:
return
emit_agent = getattr(self.context, "emit_agent_event", None)
if callable(emit_agent):
try:
if emit_agent(
session.agent_id,
"vpn_tunnel_stop",
{"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason},
):
return
except Exception:
self.logger.debug("emit_agent_event failed for vpn_tunnel_stop", exc_info=True)
try:
self.socketio.emit(
"vpn_tunnel_stop",
@@ -454,13 +554,15 @@ class VpnTunnelService:
pass
def _session_payload(self, session: VpnSession, *, include_token: bool = True) -> Mapping[str, Any]:
endpoint_host = session.endpoint_host or str(self._engine_ip.ip)
endpoint_host = self._format_endpoint_host(endpoint_host)
payload: Dict[str, Any] = {
"tunnel_id": session.tunnel_id,
"agent_id": session.agent_id,
"virtual_ip": session.virtual_ip,
"engine_virtual_ip": str(self._engine_ip.ip),
"allowed_ips": f"{self._engine_ip.ip}/32",
"endpoint": f"{self._engine_ip.ip}:{self.context.wireguard_port}",
"endpoint": f"{endpoint_host}:{self.context.wireguard_port}",
"server_public_key": self.wg.server_public_key,
"client_public_key": session.client_public_key,
"client_private_key": session.client_private_key,

View File

@@ -18,6 +18,7 @@ from __future__ import annotations
import base64
import ipaddress
import logging
import os
import subprocess
import tempfile
import time
@@ -72,6 +73,18 @@ class WireGuardServerManager:
self.server_private_key, self.server_public_key = self._ensure_server_keys()
self._service_name = "borealis-wg"
self._temp_dir = Path(tempfile.gettempdir()) / "borealis-wg-engine"
self._wireguard_exe = self._resolve_wireguard_exe()
def _resolve_wireguard_exe(self) -> str:
candidates = [
str(Path(os.environ.get("ProgramFiles", "C:\\Program Files")) / "WireGuard" / "wireguard.exe"),
str(Path(os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)")) / "WireGuard" / "wireguard.exe"),
"wireguard.exe",
]
for candidate in candidates:
if Path(candidate).is_file():
return candidate
return "wireguard.exe"
def _ensure_cert_dir(self) -> None:
try:
@@ -316,7 +329,7 @@ class WireGuardServerManager:
# Ensure old service is removed before re-installing.
self.stop_listener()
args = ["wireguard.exe", "/installtunnelservice", str(config_path)]
args = [self._wireguard_exe, "/installtunnelservice", str(config_path)]
code, out, err = self._run_command(args)
if code != 0:
self.logger.error("Failed to install WireGuard tunnel service code=%s err=%s", code, err)
@@ -326,7 +339,7 @@ class WireGuardServerManager:
def stop_listener(self) -> None:
"""Stop and remove the WireGuard tunnel service."""
args = ["wireguard.exe", "/uninstalltunnelservice", self._service_name]
args = [self._wireguard_exe, "/uninstalltunnelservice", self._service_name]
code, out, err = self._run_command(args)
if code != 0:
self.logger.warning("Failed to uninstall WireGuard tunnel service code=%s err=%s", code, err)

View File

@@ -20,7 +20,7 @@ from flask_socketio import SocketIO
from ...database import initialise_engine_database
from ...security import signing
from ...server import EngineContext
from ..VPN import VpnTunnelService
from ..VPN import WireGuardServerConfig, WireGuardServerManager, VpnTunnelService
from .vpn_shell import VpnShellBridge
@@ -63,12 +63,53 @@ class EngineRealtimeAdapters:
self.service_log = _make_service_logger(base, self.context.logger)
class AgentSocketRegistry:
def __init__(self, socketio: SocketIO, logger) -> None:
self.socketio = socketio
self.logger = logger
self._sid_by_agent: Dict[str, str] = {}
self._agent_by_sid: Dict[str, str] = {}
def register(self, agent_id: str, sid: str) -> None:
if not agent_id or not sid:
return
previous = self._sid_by_agent.get(agent_id)
if previous and previous != sid:
self._agent_by_sid.pop(previous, None)
self._sid_by_agent[agent_id] = sid
self._agent_by_sid[sid] = agent_id
def unregister(self, sid: str) -> Optional[str]:
agent_id = self._agent_by_sid.pop(sid, None)
if agent_id and self._sid_by_agent.get(agent_id) == sid:
self._sid_by_agent.pop(agent_id, None)
return agent_id
def emit(self, agent_id: str, event: str, payload: Any) -> bool:
sid = self._sid_by_agent.get(agent_id)
if not sid:
return False
try:
self.socketio.emit(event, payload, to=sid)
return True
except Exception:
self.logger.debug("Failed to emit %s to agent_id=%s", event, agent_id, exc_info=True)
return False
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
"""Register Socket.IO event handlers for the Engine runtime."""
adapters = EngineRealtimeAdapters(context)
logger = context.logger.getChild("realtime.quick_jobs")
agent_logger = context.logger.getChild("realtime.agents")
shell_bridge = VpnShellBridge(socket_server, context)
agent_registry = AgentSocketRegistry(socket_server, agent_logger)
def _emit_agent_event(agent_id: str, event: str, payload: Any) -> bool:
return agent_registry.emit(agent_id, event, payload)
setattr(context, "emit_agent_event", _emit_agent_event)
def _get_tunnel_service() -> Optional[VpnTunnelService]:
service = getattr(context, "vpn_tunnel_service", None)
@@ -76,7 +117,22 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
return service
manager = getattr(context, "wireguard_server_manager", None)
if manager is None:
return None
try:
manager = WireGuardServerManager(
WireGuardServerConfig(
port=context.wireguard_port,
engine_virtual_ip=context.wireguard_engine_virtual_ip,
peer_network=context.wireguard_peer_network,
private_key_path=Path(context.wireguard_server_private_key_path),
public_key_path=Path(context.wireguard_server_public_key_path),
acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows),
log_path=Path(context.vpn_tunnel_log_path),
)
)
setattr(context, "wireguard_server_manager", manager)
except Exception:
context.logger.error("Failed to initialize WireGuard server manager on demand.", exc_info=True)
return None
try:
signer = signing.load_signer()
except Exception:
@@ -275,6 +331,29 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
service.bump_activity(agent_id)
return {"status": "ok"}
@socket_server.on("connect_agent")
def _connect_agent(data: Any) -> Dict[str, Any]:
agent_id = ""
service_mode = ""
if isinstance(data, dict):
agent_id = str(data.get("agent_id") or "").strip()
service_mode = str(data.get("service_mode") or "").strip().lower()
elif isinstance(data, str):
agent_id = data.strip()
if not agent_id:
return {"error": "agent_id_required"}
agent_registry.register(agent_id, request.sid)
agent_logger.info("Agent socket registered agent_id=%s service_mode=%s sid=%s", agent_id, service_mode, request.sid)
service = _get_tunnel_service()
if service:
payload = service.session_payload(agent_id, include_token=True)
if payload:
agent_registry.emit(agent_id, "vpn_tunnel_start", payload)
return {"status": "ok"}
@socket_server.on("vpn_shell_send")
def _vpn_shell_send(data: Any) -> Dict[str, Any]:
payload = None
@@ -288,10 +367,13 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
return {"status": "ok"}
@socket_server.on("vpn_shell_close")
def _vpn_shell_close() -> Dict[str, Any]:
def _vpn_shell_close(data: Any = None) -> Dict[str, Any]:
shell_bridge.close(request.sid)
return {"status": "ok"}
@socket_server.on("disconnect")
def _ws_disconnect() -> None:
agent_id = agent_registry.unregister(request.sid)
if agent_id:
agent_logger.info("Agent socket disconnected agent_id=%s sid=%s", agent_id, request.sid)
shell_bridge.close(request.sid)

View File

@@ -13,6 +13,7 @@ import base64
import json
import socket
import threading
import time
from dataclasses import dataclass
from typing import Any, Dict, Optional
@@ -42,7 +43,10 @@ class ShellSession:
buffer = b""
try:
while True:
data = self.tcp.recv(4096)
try:
data = self.tcp.recv(4096)
except (socket.timeout, TimeoutError):
break
if not data:
break
buffer += data
@@ -100,12 +104,28 @@ class VpnShellBridge:
return None
host = str(status.get("virtual_ip") or "").split("/")[0]
port = int(self.context.wireguard_shell_port)
try:
tcp = socket.create_connection((host, port), timeout=5)
except Exception:
self.logger.debug("Failed to connect vpn shell to %s:%s", host, port, exc_info=True)
tcp = None
last_error: Optional[Exception] = None
for attempt in range(3):
try:
tcp = socket.create_connection((host, port), timeout=5)
break
except Exception as exc:
last_error = exc
if attempt == 0:
try:
service.request_agent_start(agent_id)
except Exception:
self.logger.debug("Failed to re-emit vpn_tunnel_start for agent=%s", agent_id, exc_info=True)
time.sleep(1)
if tcp is None:
self.logger.warning("Failed to connect vpn shell to %s:%s", host, port, exc_info=last_error)
return None
session = ShellSession(sid=sid, agent_id=agent_id, socketio=self.socketio, tcp=tcp)
try:
session.tcp.settimeout(15)
except Exception:
pass
self._sessions[sid] = session
session.start_reader()
return session
@@ -124,4 +144,3 @@ class VpnShellBridge:
if not session:
return
session.close()

View File

@@ -7,6 +7,15 @@
"""Service registration hooks for the Borealis Engine runtime."""
from . import API, WebSocket, WebUI
from __future__ import annotations
import importlib
from typing import Any
__all__ = ["API", "WebSocket", "WebUI"]
def __getattr__(name: str) -> Any:
if name in __all__:
return importlib.import_module(f"{__name__}.{name}")
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -75,7 +75,7 @@ const SECTION_HEIGHTS = {
};
const buildVpnGroups = (shellPort) => {
const normalizedShell = Number(shellPort) || 47001;
const normalizedShell = Number(shellPort) || 47002;
return [
{
key: "shell",
@@ -335,7 +335,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
const [vpnToggles, setVpnToggles] = useState({});
const [vpnCustomPorts, setVpnCustomPorts] = useState([]);
const [vpnDefaultPorts, setVpnDefaultPorts] = useState([]);
const [vpnShellPort, setVpnShellPort] = useState(47001);
const [vpnShellPort, setVpnShellPort] = useState(47002);
const [vpnLoadedFor, setVpnLoadedFor] = useState("");
// Snapshotted status for the lifetime of this page
const [lockedStatus, setLockedStatus] = useState(() => {
@@ -347,6 +347,31 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
const now = Date.now() / 1000;
return now - tsSec <= 300 ? "Online" : "Offline";
});
const summary = details.summary || {};
const vpnAgentId = useMemo(() => {
return (
meta.agentId ||
summary.agent_id ||
agent?.agent_id ||
agent?.id ||
device?.agent_id ||
device?.agent_guid ||
device?.id ||
""
);
}, [agent?.agent_id, agent?.id, device?.agent_guid, device?.agent_id, device?.id, meta.agentId, summary.agent_id]);
const vpnPortGroups = useMemo(() => buildVpnGroups(vpnShellPort), [vpnShellPort]);
const tunnelDevice = useMemo(
() => ({
...(device || {}),
...(agent || {}),
summary,
hostname: meta.hostname || summary.hostname || device?.hostname || agent?.hostname,
agent_id: meta.agentId || summary.agent_id || agent?.agent_id || agent?.id || device?.agent_id || device?.agent_guid,
agent_guid: meta.agentGuid || summary.agent_guid || device?.agent_guid || device?.guid || agent?.agent_guid || agent?.guid,
}),
[agent, device, meta.agentGuid, meta.agentId, meta.hostname, summary]
);
const quickJobTargets = useMemo(() => {
const values = [];
const push = (value) => {
@@ -715,7 +740,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
const numericDefaults = normalizedDefaults
.map((p) => Number(p))
.filter((p) => Number.isFinite(p) && p > 0);
const effectiveShell = Number(shellPort) || 47001;
const effectiveShell = Number(shellPort) || 47002;
const groups = buildVpnGroups(effectiveShell);
const knownPorts = new Set(groups.flatMap((group) => group.ports));
const allowedSet = new Set(numericPorts);
@@ -887,31 +912,6 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
[]
);
const summary = details.summary || {};
const vpnAgentId = useMemo(() => {
return (
meta.agentId ||
summary.agent_id ||
agent?.agent_id ||
agent?.id ||
device?.agent_id ||
device?.agent_guid ||
device?.id ||
""
);
}, [agent?.agent_id, agent?.id, device?.agent_guid, device?.agent_id, device?.id, meta.agentId, summary.agent_id]);
const vpnPortGroups = useMemo(() => buildVpnGroups(vpnShellPort), [vpnShellPort]);
const tunnelDevice = useMemo(
() => ({
...(device || {}),
...(agent || {}),
summary,
hostname: meta.hostname || summary.hostname || device?.hostname || agent?.hostname,
agent_id: meta.agentId || summary.agent_id || agent?.agent_id || agent?.id || device?.agent_id || device?.agent_guid,
agent_guid: meta.agentGuid || summary.agent_guid || device?.agent_guid || device?.guid || agent?.agent_guid || agent?.guid,
}),
[agent, device, meta.agentGuid, meta.agentId, meta.hostname, summary]
);
// Build a best-effort CPU display from summary fields
const cpuInfo = useMemo(() => {
const cpu = details.cpu || summary.cpu || {};

View File

@@ -93,6 +93,8 @@ export default function ReverseTunnelPowershell({ device }) {
const socketRef = useRef(null);
const localSocketRef = useRef(false);
const terminalRef = useRef(null);
const agentIdRef = useRef("");
const tunnelIdRef = useRef("");
const agentId = useMemo(() => {
return (
@@ -107,6 +109,14 @@ export default function ReverseTunnelPowershell({ device }) {
);
}, [device]);
useEffect(() => {
agentIdRef.current = agentId;
}, [agentId]);
useEffect(() => {
tunnelIdRef.current = tunnel?.tunnel_id || "";
}, [tunnel?.tunnel_id]);
const ensureSocket = useCallback(() => {
if (socketRef.current) return socketRef.current;
const existing = typeof window !== "undefined" ? window.BorealisSocket : null;
@@ -142,21 +152,20 @@ export default function ReverseTunnelPowershell({ device }) {
scrollToBottom();
}, [output, scrollToBottom]);
const stopTunnel = useCallback(
async (reason = "operator_disconnect") => {
if (!agentId) return;
try {
await fetch("/api/tunnel/disconnect", {
method: "DELETE",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ agent_id: agentId, tunnel_id: tunnel?.tunnel_id, reason }),
});
} catch {
// best-effort
}
},
[agentId, tunnel?.tunnel_id]
);
const stopTunnel = useCallback(async (reason = "operator_disconnect") => {
const currentAgentId = agentIdRef.current;
if (!currentAgentId) return;
const currentTunnelId = tunnelIdRef.current;
try {
await fetch("/api/tunnel/disconnect", {
method: "DELETE",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ agent_id: currentAgentId, tunnel_id: currentTunnelId, reason }),
});
} catch {
// best-effort
}
}, []);
const closeShell = useCallback(async () => {
const socket = ensureSocket();
@@ -232,7 +241,10 @@ export default function ReverseTunnelPowershell({ device }) {
body: JSON.stringify({ agent_id: agentId }),
});
const data = await resp.json().catch(() => ({}));
if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`);
if (!resp.ok) {
const detail = data?.detail ? `: ${data.detail}` : "";
throw new Error(`${data?.error || `HTTP ${resp.status}`}${detail}`);
}
const statusResp = await fetch(
`/api/tunnel/connect/status?agent_id=${encodeURIComponent(agentId)}&bump=1`
);