Removed RDP in favor of VNC / Made WireGuard Tunnel Persistent

This commit is contained in:
2026-02-05 23:05:23 -07:00
parent 287d3b1cf7
commit 0d40ca6edb
35 changed files with 2207 additions and 1400 deletions

View File

@@ -33,7 +33,8 @@ from ..auth import DevModeManager
from .enrollment import routes as enrollment_routes
from .tokens import routes as token_routes
from .devices.tunnel import register_tunnel
from .devices.rdp import register_rdp
from .devices.vnc import register_vnc
from .devices.shell import register_shell
from ...server import EngineContext
from .access_management.login import register_auth
@@ -292,7 +293,8 @@ def _register_devices(app: Flask, adapters: EngineServiceAdapters) -> None:
register_admin_endpoints(app, adapters)
device_routes.register_agents(app, adapters)
register_tunnel(app, adapters)
register_rdp(app, adapters)
register_vnc(app, adapters)
register_shell(app, adapters)
def _register_filters(app: Flask, adapters: EngineServiceAdapters) -> None:
filters_management.register_filters(app, adapters)

View File

@@ -5,6 +5,7 @@
# API Endpoints (if applicable):
# - POST /api/agent/heartbeat (Device Authenticated) - Updates device last-seen metadata and inventory snapshots.
# - POST /api/agent/script/request (Device Authenticated) - Provides script execution payloads or idle signals to agents.
# - POST /api/agent/vpn/ensure (Device Authenticated) - Ensures persistent WireGuard tunnel material.
# ======================================================
"""Device-affiliated agent endpoints for the Borealis Engine runtime."""
@@ -13,12 +14,14 @@ from __future__ import annotations
import json
import sqlite3
import time
from urllib.parse import urlsplit
from typing import TYPE_CHECKING, Any, Dict, Optional
from flask import Blueprint, jsonify, request, g
from ....auth.device_auth import AGENT_CONTEXT_HEADER, require_device_auth
from ....auth.guid_utils import normalize_guid
from .tunnel import _get_tunnel_service
if TYPE_CHECKING: # pragma: no cover - typing aide
from .. import EngineServiceAdapters
@@ -42,6 +45,20 @@ def _json_or_none(value: Any) -> Optional[str]:
return None
def _infer_endpoint_host(req) -> str:
forwarded = (req.headers.get("X-Forwarded-Host") or req.headers.get("X-Original-Host") or "").strip()
host = forwarded.split(",")[0].strip() if forwarded else (req.host or "").strip()
if not host:
return ""
try:
parsed = urlsplit(f"//{host}")
if parsed.hostname:
return parsed.hostname
except Exception:
return host
return host
def register_agents(app, adapters: "EngineServiceAdapters") -> None:
"""Register agent heartbeat and script polling routes."""
@@ -218,4 +235,76 @@ def register_agents(app, adapters: "EngineServiceAdapters") -> None:
}
)
@blueprint.route("/api/agent/vpn/ensure", methods=["POST"])
@require_device_auth(auth_manager)
def vpn_ensure():
ctx = _auth_context()
if ctx is None:
return jsonify({"error": "auth_context_missing"}), 500
body = request.get_json(silent=True) or {}
requested_agent = (body.get("agent_id") or "").strip()
guid = normalize_guid(ctx.guid)
conn = db_conn_factory()
resolved_agent = ""
try:
cur = conn.cursor()
cur.execute(
"SELECT agent_id FROM devices WHERE UPPER(guid) = ? ORDER BY last_seen DESC LIMIT 1",
(guid,),
)
row = cur.fetchone()
if row and row[0]:
resolved_agent = str(row[0]).strip()
finally:
conn.close()
if not resolved_agent:
log("VPN_Tunnel/tunnel", f"vpn_agent_ensure_missing_agent guid={guid}", _context_hint(ctx), level="ERROR")
return jsonify({"error": "agent_id_missing"}), 404
if requested_agent and requested_agent != resolved_agent:
log(
"VPN_Tunnel/tunnel",
"vpn_agent_ensure_agent_mismatch requested={0} resolved={1}".format(
requested_agent, resolved_agent
),
_context_hint(ctx),
level="WARNING",
)
try:
tunnel_service = _get_tunnel_service(adapters)
endpoint_host = _infer_endpoint_host(request)
log(
"VPN_Tunnel/tunnel",
"vpn_agent_ensure_request agent_id={0} endpoint_host={1}".format(
resolved_agent, endpoint_host or "-"
),
_context_hint(ctx),
)
payload = tunnel_service.connect(
agent_id=resolved_agent,
operator_id=None,
endpoint_host=endpoint_host,
)
except Exception as exc:
log(
"VPN_Tunnel/tunnel",
"vpn_agent_ensure_failed agent_id={0} error={1}".format(resolved_agent, str(exc)),
_context_hint(ctx),
level="ERROR",
)
return jsonify({"error": "tunnel_start_failed", "detail": str(exc)}), 500
log(
"VPN_Tunnel/tunnel",
"vpn_agent_ensure_response agent_id={0} tunnel_id={1}".format(
payload.get("agent_id", resolved_agent), payload.get("tunnel_id", "-")
),
_context_hint(ctx),
)
return jsonify(payload), 200
app.register_blueprint(blueprint)

View File

@@ -1,12 +1,13 @@
# ======================================================
# Data\Engine\services\API\devices\rdp.py
# Description: RDP session bootstrap for Guacamole WebSocket tunnels.
# Data\Engine\services\API\devices\shell.py
# Description: Remote PowerShell session endpoints for persistent WireGuard tunnels.
#
# API Endpoints (if applicable):
# - POST /api/rdp/session (Token Authenticated) - Issues a one-time Guacamole tunnel token for RDP.
# - POST /api/shell/establish (Token Authenticated) - Ensure shell readiness over WireGuard.
# - POST /api/shell/disconnect (Token Authenticated) - Disconnect the operator shell session.
# ======================================================
"""RDP session bootstrap endpoints for the Borealis Engine."""
"""Remote PowerShell session endpoints for the Borealis Engine."""
from __future__ import annotations
import os
@@ -16,7 +17,6 @@ from urllib.parse import urlsplit
from flask import Blueprint, jsonify, request, session
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from ...RemoteDesktop.guacamole_proxy import GUAC_WS_PATH, ensure_guacamole_proxy
from .tunnel import _get_tunnel_service
if False: # pragma: no cover - hint for type checkers
@@ -81,25 +81,18 @@ def _infer_endpoint_host(req) -> str:
return host
def _is_secure(req) -> bool:
if req.is_secure:
return True
forwarded = (req.headers.get("X-Forwarded-Proto") or "").split(",")[0].strip().lower()
return forwarded == "https"
def register_rdp(app, adapters: "EngineServiceAdapters") -> None:
blueprint = Blueprint("rdp", __name__)
logger = adapters.context.logger.getChild("rdp.api")
def register_shell(app, adapters: "EngineServiceAdapters") -> None:
blueprint = Blueprint("vpn_shell", __name__)
logger = adapters.context.logger.getChild("vpn_shell.api")
service_log = adapters.service_log
def _service_log_event(message: str, *, level: str = "INFO") -> None:
if not callable(service_log):
return
try:
service_log("RDP", message, level=level)
service_log("VPN_Tunnel/remote_shell", message, level=level)
except Exception:
logger.debug("rdp service log write failed", exc_info=True)
logger.debug("vpn_shell service log write failed", exc_info=True)
def _request_remote() -> str:
forwarded = (request.headers.get("X-Forwarded-For") or "").strip()
@@ -107,8 +100,8 @@ def register_rdp(app, adapters: "EngineServiceAdapters") -> None:
return forwarded.split(",")[0].strip()
return (request.remote_addr or "").strip()
@blueprint.route("/api/rdp/session", methods=["POST"])
def rdp_session():
@blueprint.route("/api/shell/establish", methods=["POST"])
def shell_establish():
requirement = _require_login(app)
if requirement:
payload, status = requirement
@@ -119,77 +112,82 @@ def register_rdp(app, adapters: "EngineServiceAdapters") -> None:
body = request.get_json(silent=True) or {}
agent_id = _normalize_text(body.get("agent_id"))
protocol = _normalize_text(body.get("protocol") or "rdp").lower()
username = _normalize_text(body.get("username"))
password = str(body.get("password") or "")
if not agent_id:
return jsonify({"error": "agent_id_required"}), 400
try:
tunnel_service = _get_tunnel_service(adapters)
endpoint_host = _infer_endpoint_host(request)
_service_log_event(
"vpn_shell_establish_request agent_id={0} operator={1} endpoint_host={2} remote={3}".format(
agent_id,
operator_id or "-",
endpoint_host or "-",
_request_remote() or "-",
)
)
payload = tunnel_service.connect(
agent_id=agent_id,
operator_id=operator_id,
endpoint_host=endpoint_host,
)
except Exception as exc:
_service_log_event(
"vpn_shell_establish_failed agent_id={0} operator={1} error={2}".format(
agent_id,
operator_id or "-",
str(exc),
),
level="ERROR",
)
return jsonify({"error": "establish_failed", "detail": str(exc)}), 500
agent_socket = False
registry = getattr(adapters.context, "agent_socket_registry", None)
if registry and hasattr(registry, "is_registered"):
try:
agent_socket = bool(registry.is_registered(agent_id))
except Exception:
agent_socket = False
response = dict(payload)
response["status"] = "ok"
response["agent_socket"] = agent_socket
_service_log_event(
"vpn_shell_establish_response agent_id={0} tunnel_id={1} agent_socket={2}".format(
agent_id,
response.get("tunnel_id", "-"),
str(agent_socket).lower(),
)
)
return jsonify(response), 200
@blueprint.route("/api/shell/disconnect", methods=["POST"])
def shell_disconnect():
requirement = _require_login(app)
if requirement:
payload, status = requirement
return jsonify(payload), status
body = request.get_json(silent=True) or {}
agent_id = _normalize_text(body.get("agent_id"))
reason = _normalize_text(body.get("reason") or "operator_disconnect")
operator_id = (_current_user(app) or {}).get("username") or None
if not agent_id:
return jsonify({"error": "agent_id_required"}), 400
if protocol != "rdp":
return jsonify({"error": "unsupported_protocol"}), 400
tunnel_service = _get_tunnel_service(adapters)
session_payload = tunnel_service.session_payload(agent_id, include_token=False)
if not session_payload:
return jsonify({"error": "tunnel_down"}), 409
allowed_ports = session_payload.get("allowed_ports") or []
if 3389 not in allowed_ports:
return jsonify({"error": "rdp_not_allowed"}), 403
virtual_ip = _normalize_text(session_payload.get("virtual_ip"))
host = virtual_ip.split("/")[0] if virtual_ip else ""
if not host:
return jsonify({"error": "virtual_ip_missing"}), 500
registry = ensure_guacamole_proxy(adapters.context, logger=logger)
if registry is None:
return jsonify({"error": "rdp_proxy_unavailable"}), 503
_service_log_event(
"rdp_session_request agent_id={0} operator={1} protocol={2} remote={3}".format(
"vpn_shell_disconnect_request agent_id={0} operator={1} reason={2} remote={3}".format(
agent_id,
operator_id or "-",
protocol,
reason or "-",
_request_remote() or "-",
)
)
rdp_session = registry.create(
agent_id=agent_id,
host=host,
port=3389,
username=username,
password=password,
protocol=protocol,
operator_id=operator_id,
ignore_cert=True,
)
ws_scheme = "wss" if _is_secure(request) else "ws"
ws_host = _infer_endpoint_host(request)
ws_port = int(getattr(adapters.context, "rdp_ws_port", 4823))
ws_url = f"{ws_scheme}://{ws_host}:{ws_port}{GUAC_WS_PATH}"
_service_log_event(
"rdp_session_ready agent_id={0} token={1} expires_at={2}".format(
agent_id,
rdp_session.token[:8],
int(rdp_session.expires_at),
)
)
return (
jsonify(
{
"token": rdp_session.token,
"ws_url": ws_url,
"expires_at": int(rdp_session.expires_at),
"virtual_ip": host,
"tunnel_id": session_payload.get("tunnel_id"),
}
),
200,
)
return jsonify({"status": "disconnected", "reason": reason}), 200
app.register_blueprint(blueprint)
__all__ = ["register_shell"]

View File

@@ -1,12 +1,11 @@
# ======================================================
# Data\Engine\services\API\devices\tunnel.py
# Description: WireGuard VPN tunnel API (connect/status/disconnect).
# Description: WireGuard VPN tunnel API (connect/status).
#
# API Endpoints (if applicable):
# - POST /api/tunnel/connect (Token Authenticated) - Issues VPN session material for an agent.
# - GET /api/tunnel/status (Token Authenticated) - Returns VPN status for an agent.
# - GET /api/tunnel/active (Token Authenticated) - Lists active VPN tunnel sessions.
# - DELETE /api/tunnel/disconnect (Token Authenticated) - Tears down VPN session for an agent.
# ======================================================
"""WireGuard VPN tunnel API (Engine side)."""
@@ -254,52 +253,4 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
)
return jsonify({"count": len(sessions), "tunnels": sessions}), 200
@blueprint.route("/api/tunnel/disconnect", methods=["DELETE"])
def disconnect_tunnel():
requirement = _require_login(app)
if requirement:
payload, status = requirement
return jsonify(payload), status
body = request.get_json(silent=True) or {}
agent_id = _normalize_text(body.get("agent_id"))
tunnel_id = _normalize_text(body.get("tunnel_id"))
reason = _normalize_text(body.get("reason") or "operator_stop")
tunnel_service = _get_tunnel_service(adapters)
_service_log_event(
"vpn_api_disconnect_request agent_id={0} tunnel_id={1} reason={2} operator={3} remote={4}".format(
agent_id or "-",
tunnel_id or "-",
reason or "-",
(_current_user(app) or {}).get("username") or "-",
_request_remote() or "-",
)
)
stopped = False
if tunnel_id:
stopped = tunnel_service.disconnect_by_tunnel(tunnel_id, reason=reason)
elif agent_id:
stopped = tunnel_service.disconnect(agent_id, reason=reason)
else:
return jsonify({"error": "agent_id_required"}), 400
if not stopped:
_service_log_event(
"vpn_api_disconnect_not_found agent_id={0} tunnel_id={1}".format(
agent_id or "-",
tunnel_id or "-",
),
level="WARNING",
)
return jsonify({"error": "not_found"}), 404
_service_log_event(
"vpn_api_disconnect_response agent_id={0} tunnel_id={1} status=stopped".format(
agent_id or "-",
tunnel_id or "-",
)
)
return jsonify({"status": "stopped", "reason": reason}), 200
app.register_blueprint(blueprint)

View File

@@ -0,0 +1,338 @@
# ======================================================
# Data\Engine\services\API\devices\vnc.py
# Description: VNC session bootstrap for noVNC WebSocket tunnels.
#
# API Endpoints (if applicable):
# - POST /api/vnc/establish (Token Authenticated) - Establish a VNC session for noVNC.
# - POST /api/vnc/disconnect (Token Authenticated) - Disconnect the operator VNC session.
# - POST /api/vnc/session (Token Authenticated) - Legacy alias for establish.
# ======================================================
"""VNC session bootstrap endpoints for the Borealis Engine."""
from __future__ import annotations
import os
import secrets
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlsplit
from flask import Blueprint, jsonify, request, session
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from ...RemoteDesktop.vnc_proxy import VNC_WS_PATH, ensure_vnc_proxy
from .tunnel import _get_tunnel_service
if False: # pragma: no cover - hint for type checkers
from .. import EngineServiceAdapters
def _current_user(app) -> Optional[Dict[str, str]]:
username = session.get("username")
role = session.get("role") or "User"
if username:
return {"username": username, "role": role}
token = None
auth_header = request.headers.get("Authorization") or ""
if auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1].strip()
if not token:
token = request.cookies.get("borealis_auth")
if not token:
return None
try:
serializer = URLSafeTimedSerializer(app.secret_key or "borealis-dev-secret", salt="borealis-auth")
token_ttl = int(os.environ.get("BOREALIS_TOKEN_TTL_SECONDS", 60 * 60 * 24 * 30))
data = serializer.loads(token, max_age=token_ttl)
username = data.get("u")
role = data.get("r") or "User"
if username:
return {"username": username, "role": role}
except (BadSignature, SignatureExpired, Exception):
return None
return None
def _require_login(app) -> Optional[Tuple[Dict[str, Any], int]]:
user = _current_user(app)
if not user:
return {"error": "unauthorized"}, 401
return None
def _normalize_text(value: Any) -> str:
if value is None:
return ""
try:
return str(value).strip()
except Exception:
return ""
def _infer_endpoint_host(req) -> str:
forwarded = (req.headers.get("X-Forwarded-Host") or req.headers.get("X-Original-Host") or "").strip()
host = forwarded.split(",")[0].strip() if forwarded else (req.host or "").strip()
if not host:
return ""
try:
parsed = urlsplit(f"//{host}")
if parsed.hostname:
return parsed.hostname
except Exception:
return host
return host
def _is_secure(req) -> bool:
if req.is_secure:
return True
forwarded = (req.headers.get("X-Forwarded-Proto") or "").split(",")[0].strip().lower()
return forwarded == "https"
def _generate_vnc_password() -> str:
# UltraVNC uses the first 8 characters for VNC auth; keep the token to 8 for compatibility.
return secrets.token_hex(4)
def _load_vnc_password(adapters: "EngineServiceAdapters", agent_id: str) -> Optional[str]:
conn = adapters.db_conn_factory()
try:
cur = conn.cursor()
cur.execute(
"SELECT agent_vnc_password FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
(agent_id,),
)
row = cur.fetchone()
if row and row[0]:
return str(row[0]).strip()
except Exception:
adapters.context.logger.debug("Failed to load agent VNC password", exc_info=True)
finally:
try:
conn.close()
except Exception:
pass
return None
def _store_vnc_password(adapters: "EngineServiceAdapters", agent_id: str, password: str) -> None:
conn = adapters.db_conn_factory()
try:
cur = conn.cursor()
cur.execute(
"UPDATE devices SET agent_vnc_password=? WHERE agent_id=?",
(password, agent_id),
)
conn.commit()
except Exception:
adapters.context.logger.debug("Failed to store agent VNC password", exc_info=True)
finally:
try:
conn.close()
except Exception:
pass
def register_vnc(app, adapters: "EngineServiceAdapters") -> None:
blueprint = Blueprint("vnc", __name__)
logger = adapters.context.logger.getChild("vnc.api")
service_log = adapters.service_log
def _service_log_event(message: str, *, level: str = "INFO") -> None:
if not callable(service_log):
return
try:
service_log("VNC", message, level=level)
except Exception:
logger.debug("vnc service log write failed", exc_info=True)
def _request_remote() -> str:
forwarded = (request.headers.get("X-Forwarded-For") or "").strip()
if forwarded:
return forwarded.split(",")[0].strip()
return (request.remote_addr or "").strip()
def _issue_session(agent_id: str, operator_id: Optional[str]) -> Tuple[Dict[str, Any], int]:
tunnel_service = _get_tunnel_service(adapters)
session_payload = tunnel_service.session_payload(agent_id, include_token=False)
if not session_payload:
try:
session_payload = tunnel_service.connect(
agent_id=agent_id,
operator_id=operator_id,
endpoint_host=_infer_endpoint_host(request),
)
except Exception:
return {"error": "tunnel_down"}, 409
vnc_port = int(getattr(adapters.context, "vnc_port", 5900))
raw_ports = session_payload.get("allowed_ports") or []
allowed_ports = []
for value in raw_ports:
try:
allowed_ports.append(int(value))
except Exception:
continue
if vnc_port not in allowed_ports:
return {"error": "vnc_not_allowed", "vnc_port": vnc_port}, 403
virtual_ip = _normalize_text(session_payload.get("virtual_ip"))
host = virtual_ip.split("/")[0] if virtual_ip else ""
if not host:
return {"error": "virtual_ip_missing"}, 500
vnc_password = _load_vnc_password(adapters, agent_id)
if not vnc_password:
vnc_password = _generate_vnc_password()
_store_vnc_password(adapters, agent_id, vnc_password)
if len(vnc_password) > 8:
vnc_password = vnc_password[:8]
_store_vnc_password(adapters, agent_id, vnc_password)
registry = ensure_vnc_proxy(adapters.context, logger=logger)
if registry is None:
return {"error": "vnc_proxy_unavailable"}, 503
_service_log_event(
"vnc_establish_request agent_id={0} operator={1} remote={2}".format(
agent_id,
operator_id or "-",
_request_remote() or "-",
)
)
vnc_session = registry.create(
agent_id=agent_id,
host=host,
port=vnc_port,
operator_id=operator_id,
)
emit_agent = getattr(adapters.context, "emit_agent_event", None)
payload = {
"agent_id": agent_id,
"port": vnc_port,
"allowed_ips": session_payload.get("allowed_ips"),
"virtual_ip": host,
"password": vnc_password,
"reason": "vnc_session_start",
}
agent_socket_ready = True
if callable(emit_agent):
agent_socket_ready = bool(emit_agent(agent_id, "vnc_start", payload))
if agent_socket_ready:
_service_log_event(
"vnc_start_emit agent_id={0} port={1} virtual_ip={2}".format(
agent_id,
vnc_port,
host or "-",
)
)
else:
_service_log_event(
"vnc_start_emit_failed agent_id={0} port={1}".format(
agent_id,
vnc_port,
),
level="WARNING",
)
if not agent_socket_ready:
return {"error": "agent_socket_missing"}, 409
ws_scheme = "wss" if _is_secure(request) else "ws"
ws_host = _infer_endpoint_host(request)
ws_port = int(getattr(adapters.context, "vnc_ws_port", 4823))
ws_url = f"{ws_scheme}://{ws_host}:{ws_port}{VNC_WS_PATH}"
_service_log_event(
"vnc_session_ready agent_id={0} token={1} expires_at={2}".format(
agent_id,
vnc_session.token[:8],
int(vnc_session.expires_at),
)
)
return (
{
"token": vnc_session.token,
"ws_url": ws_url,
"expires_at": int(vnc_session.expires_at),
"virtual_ip": host,
"tunnel_id": session_payload.get("tunnel_id"),
"engine_virtual_ip": session_payload.get("engine_virtual_ip"),
"allowed_ports": session_payload.get("allowed_ports"),
"vnc_password": vnc_password,
"vnc_port": vnc_port,
},
200,
)
@blueprint.route("/api/vnc/establish", methods=["POST"])
def vnc_establish():
requirement = _require_login(app)
if requirement:
payload, status = requirement
return jsonify(payload), status
user = _current_user(app) or {}
operator_id = user.get("username") or None
body = request.get_json(silent=True) or {}
agent_id = _normalize_text(body.get("agent_id"))
if not agent_id:
return jsonify({"error": "agent_id_required"}), 400
payload, status = _issue_session(agent_id, operator_id)
return jsonify(payload), status
@blueprint.route("/api/vnc/session", methods=["POST"])
def vnc_session():
return vnc_establish()
@blueprint.route("/api/vnc/disconnect", methods=["POST"])
def vnc_disconnect():
requirement = _require_login(app)
if requirement:
payload, status = requirement
return jsonify(payload), status
user = _current_user(app) or {}
operator_id = user.get("username") or None
body = request.get_json(silent=True) or {}
agent_id = _normalize_text(body.get("agent_id"))
reason = _normalize_text(body.get("reason") or "operator_disconnect")
if not agent_id:
return jsonify({"error": "agent_id_required"}), 400
registry = ensure_vnc_proxy(adapters.context, logger=logger)
revoked = 0
if registry is not None:
try:
revoked = registry.revoke_agent(agent_id)
except Exception:
revoked = 0
emit_agent = getattr(adapters.context, "emit_agent_event", None)
if callable(emit_agent):
try:
emit_agent(agent_id, "vnc_stop", {"agent_id": agent_id, "reason": reason})
except Exception:
pass
_service_log_event(
"vnc_disconnect agent_id={0} operator={1} reason={2} revoked={3}".format(
agent_id,
operator_id or "-",
reason or "-",
revoked,
)
)
return jsonify({"status": "disconnected", "revoked": revoked, "reason": reason}), 200
app.register_blueprint(blueprint)

View File

@@ -1,9 +1,8 @@
# ======================================================
# Data\Engine\services\RemoteDesktop\__init__.py
# Description: Remote desktop services (Guacamole proxy + session management).
# Description: Remote desktop services (VNC proxy + session management).
#
# API Endpoints (if applicable): None
# ======================================================
"""Remote desktop service helpers for the Borealis Engine runtime."""

View File

@@ -1,369 +0,0 @@
# ======================================================
# Data\Engine\services\RemoteDesktop\guacamole_proxy.py
# Description: Guacamole tunnel proxy (WebSocket -> guacd) for RDP sessions.
#
# API Endpoints (if applicable): None
# ======================================================
"""Guacamole WebSocket proxy that bridges browser tunnels to guacd."""
from __future__ import annotations
import asyncio
import logging
import ssl
import threading
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
from urllib.parse import parse_qs, urlsplit
import websockets
GUAC_WS_PATH = "/guacamole"
_MAX_MESSAGE_SIZE = 100_000_000
@dataclass
class RdpSession:
token: str
agent_id: str
host: str
port: int
protocol: str
username: str
password: str
ignore_cert: bool
created_at: float
expires_at: float
operator_id: Optional[str] = None
domain: Optional[str] = None
security: Optional[str] = None
class RdpSessionRegistry:
def __init__(self, ttl_seconds: int, logger: logging.Logger) -> None:
self.ttl_seconds = max(30, int(ttl_seconds))
self.logger = logger
self._lock = threading.Lock()
self._sessions: Dict[str, RdpSession] = {}
def _cleanup(self, now: Optional[float] = None) -> None:
current = now if now is not None else time.time()
expired = [token for token, session in self._sessions.items() if session.expires_at <= current]
for token in expired:
self._sessions.pop(token, None)
def create(
self,
*,
agent_id: str,
host: str,
port: int,
username: str,
password: str,
protocol: str = "rdp",
ignore_cert: bool = True,
operator_id: Optional[str] = None,
domain: Optional[str] = None,
security: Optional[str] = None,
) -> RdpSession:
token = uuid.uuid4().hex
now = time.time()
expires_at = now + self.ttl_seconds
session = RdpSession(
token=token,
agent_id=agent_id,
host=host,
port=port,
protocol=protocol,
username=username,
password=password,
ignore_cert=ignore_cert,
created_at=now,
expires_at=expires_at,
operator_id=operator_id,
domain=domain,
security=security,
)
with self._lock:
self._cleanup(now)
self._sessions[token] = session
return session
def consume(self, token: str) -> Optional[RdpSession]:
if not token:
return None
with self._lock:
self._cleanup()
session = self._sessions.pop(token, None)
return session
class GuacamoleProxyServer:
def __init__(
self,
*,
host: str,
port: int,
guacd_host: str,
guacd_port: int,
registry: RdpSessionRegistry,
logger: logging.Logger,
ssl_context: Optional[ssl.SSLContext] = None,
) -> None:
self.host = host
self.port = port
self.guacd_host = guacd_host
self.guacd_port = guacd_port
self.registry = registry
self.logger = logger
self.ssl_context = ssl_context
self._thread: Optional[threading.Thread] = None
self._ready = threading.Event()
self._failed = threading.Event()
def ensure_started(self, timeout: float = 3.0) -> bool:
if self._thread and self._thread.is_alive():
return not self._failed.is_set()
self._failed.clear()
self._ready.clear()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
self._ready.wait(timeout)
return not self._failed.is_set()
def _run(self) -> None:
try:
asyncio.run(self._serve())
except Exception as exc:
self._failed.set()
self.logger.error("Guacamole proxy server failed: %s", exc)
self._ready.set()
async def _serve(self) -> None:
self.logger.info(
"Starting Guacamole proxy on %s:%s (guacd %s:%s)",
self.host,
self.port,
self.guacd_host,
self.guacd_port,
)
try:
server = await websockets.serve(
self._handle_client,
self.host,
self.port,
ssl=self.ssl_context,
max_size=_MAX_MESSAGE_SIZE,
ping_interval=20,
ping_timeout=20,
)
except Exception:
self._failed.set()
self._ready.set()
raise
self._ready.set()
await server.wait_closed()
async def _handle_client(self, websocket, path: str) -> None:
parsed = urlsplit(path)
if parsed.path != GUAC_WS_PATH:
await websocket.close(code=1008, reason="invalid_path")
return
query = parse_qs(parsed.query or "")
token = (query.get("token") or [""])[0]
session = self.registry.consume(token)
if not session:
await websocket.close(code=1008, reason="invalid_session")
return
logger = self.logger.getChild("session")
logger.info("Guacamole session start agent_id=%s protocol=%s", session.agent_id, session.protocol)
try:
reader, writer = await asyncio.open_connection(self.guacd_host, self.guacd_port)
except Exception as exc:
logger.warning("guacd connect failed: %s", exc)
await websocket.close(code=1011, reason="guacd_unavailable")
return
try:
await self._perform_handshake(reader, writer, session)
except Exception as exc:
logger.warning("guacd handshake failed: %s", exc)
try:
writer.close()
await writer.wait_closed()
except Exception:
pass
await websocket.close(code=1011, reason="handshake_failed")
return
async def _ws_to_guacd() -> None:
try:
async for message in websocket:
if message is None:
break
if isinstance(message, str):
data = message.encode("utf-8")
else:
data = bytes(message)
writer.write(data)
await writer.drain()
finally:
try:
writer.close()
except Exception:
pass
async def _guacd_to_ws() -> None:
try:
while True:
data = await reader.read(8192)
if not data:
break
await websocket.send(data.decode("utf-8", errors="ignore"))
finally:
try:
await websocket.close()
except Exception:
pass
await asyncio.wait(
[asyncio.create_task(_ws_to_guacd()), asyncio.create_task(_guacd_to_ws())],
return_when=asyncio.FIRST_COMPLETED,
)
logger.info("Guacamole session ended agent_id=%s", session.agent_id)
async def _perform_handshake(self, reader, writer, session: RdpSession) -> None:
writer.write(_encode_instruction("select", session.protocol))
await writer.drain()
buffer = b""
args = None
deadline = time.time() + 8
while time.time() < deadline:
parts, buffer = await _read_instruction(reader, buffer)
if not parts:
continue
op = parts[0]
if op == "args":
args = parts[1:]
break
if op == "error":
raise RuntimeError("guacd_error:" + " ".join(parts[1:]))
if not args:
raise RuntimeError("guacd_args_timeout")
params = {
"hostname": session.host,
"port": str(session.port),
"username": session.username or "",
"password": session.password or "",
}
if session.domain:
params["domain"] = session.domain
if session.security:
params["security"] = session.security
if session.ignore_cert:
params["ignore-cert"] = "true"
values = [params.get(name, "") for name in args]
writer.write(_encode_instruction("connect", *values))
await writer.drain()
def _encode_instruction(*elements: str) -> bytes:
parts = []
for element in elements:
text = "" if element is None else str(element)
parts.append(f"{len(text)}.{text}".encode("utf-8"))
return b",".join(parts) + b";"
def _parse_instruction(raw: bytes) -> Tuple[str, ...]:
parts = []
idx = 0
length = len(raw)
while idx < length:
dot = raw.find(b".", idx)
if dot < 0:
break
try:
element_len = int(raw[idx:dot].decode("ascii") or "0")
except Exception:
break
start = dot + 1
end = start + element_len
if end > length:
break
parts.append(raw[start:end].decode("utf-8", errors="ignore"))
idx = end
if idx < length and raw[idx:idx + 1] == b",":
idx += 1
return tuple(parts)
async def _read_instruction(reader, buffer: bytes) -> Tuple[Tuple[str, ...], bytes]:
while b";" not in buffer:
chunk = await reader.read(4096)
if not chunk:
break
buffer += chunk
if b";" not in buffer:
return tuple(), buffer
instruction, remainder = buffer.split(b";", 1)
if not instruction:
return tuple(), remainder
return _parse_instruction(instruction), remainder
def _build_ssl_context(cert_path: Optional[str], key_path: Optional[str]) -> Optional[ssl.SSLContext]:
if not cert_path or not key_path:
return None
try:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(certfile=cert_path, keyfile=key_path)
return context
except Exception:
return None
def ensure_guacamole_proxy(context: Any, *, logger: Optional[logging.Logger] = None) -> Optional[RdpSessionRegistry]:
if logger is None:
logger = context.logger if hasattr(context, "logger") else logging.getLogger("borealis.engine.rdp")
registry = getattr(context, "rdp_registry", None)
if registry is None:
ttl = int(getattr(context, "rdp_session_ttl_seconds", 120))
registry = RdpSessionRegistry(ttl_seconds=ttl, logger=logger)
setattr(context, "rdp_registry", registry)
proxy = getattr(context, "rdp_proxy", None)
if proxy is None:
cert_path = getattr(context, "tls_bundle_path", None) or getattr(context, "tls_cert_path", None)
ssl_context = _build_ssl_context(
cert_path,
getattr(context, "tls_key_path", None),
)
proxy = GuacamoleProxyServer(
host=str(getattr(context, "rdp_ws_host", "0.0.0.0")),
port=int(getattr(context, "rdp_ws_port", 4823)),
guacd_host=str(getattr(context, "guacd_host", "127.0.0.1")),
guacd_port=int(getattr(context, "guacd_port", 4822)),
registry=registry,
logger=logger.getChild("guacamole_proxy"),
ssl_context=ssl_context,
)
setattr(context, "rdp_proxy", proxy)
if not proxy.ensure_started():
logger.error("Guacamole proxy failed to start; RDP sessions unavailable.")
return None
return registry
__all__ = ["GUAC_WS_PATH", "RdpSessionRegistry", "GuacamoleProxyServer", "ensure_guacamole_proxy"]

View File

@@ -0,0 +1,285 @@
# ======================================================
# Data\Engine\services\RemoteDesktop\vnc_proxy.py
# Description: VNC tunnel proxy (WebSocket -> TCP) for noVNC sessions.
#
# API Endpoints (if applicable): None
# ======================================================
"""VNC WebSocket proxy that bridges browser sessions to agent VNC servers."""
from __future__ import annotations
import asyncio
import logging
import ssl
import threading
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple
from urllib.parse import parse_qs, urlsplit
import websockets
VNC_WS_PATH = "/vnc"
_MAX_MESSAGE_SIZE = 100_000_000
@dataclass
class VncSession:
token: str
agent_id: str
host: str
port: int
created_at: float
expires_at: float
operator_id: Optional[str] = None
class VncSessionRegistry:
def __init__(self, ttl_seconds: int, logger: logging.Logger) -> None:
self.ttl_seconds = max(30, int(ttl_seconds))
self.logger = logger
self._lock = threading.Lock()
self._sessions: Dict[str, VncSession] = {}
def _cleanup(self, now: Optional[float] = None) -> None:
current = now if now is not None else time.time()
expired = [token for token, session in self._sessions.items() if session.expires_at <= current]
for token in expired:
self._sessions.pop(token, None)
def create(
self,
*,
agent_id: str,
host: str,
port: int,
operator_id: Optional[str] = None,
) -> VncSession:
token = uuid.uuid4().hex
now = time.time()
expires_at = now + self.ttl_seconds
session = VncSession(
token=token,
agent_id=agent_id,
host=host,
port=port,
created_at=now,
expires_at=expires_at,
operator_id=operator_id,
)
with self._lock:
self._cleanup(now)
self._sessions[token] = session
return session
def consume(self, token: str) -> Optional[VncSession]:
if not token:
return None
with self._lock:
self._cleanup()
session = self._sessions.pop(token, None)
return session
def revoke_agent(self, agent_id: str) -> int:
if not agent_id:
return 0
removed = 0
with self._lock:
self._cleanup()
tokens = [token for token, session in self._sessions.items() if session.agent_id == agent_id]
for token in tokens:
if self._sessions.pop(token, None):
removed += 1
return removed
class VncProxyServer:
def __init__(
self,
*,
host: str,
port: int,
registry: VncSessionRegistry,
logger: logging.Logger,
emit_agent_event: Optional[Callable[[str, str, Any], bool]] = None,
ssl_context: Optional[ssl.SSLContext] = None,
) -> None:
self.host = host
self.port = port
self.registry = registry
self.logger = logger
self._emit_agent_event = emit_agent_event
self.ssl_context = ssl_context
self._thread: Optional[threading.Thread] = None
self._ready = threading.Event()
self._failed = threading.Event()
def ensure_started(self, timeout: float = 3.0) -> bool:
if self._thread and self._thread.is_alive():
return not self._failed.is_set()
self._failed.clear()
self._ready.clear()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
self._ready.wait(timeout)
return not self._failed.is_set()
def _run(self) -> None:
try:
asyncio.run(self._serve())
except Exception as exc:
self._failed.set()
self.logger.error("VNC proxy server failed: %s", exc)
self._ready.set()
async def _serve(self) -> None:
self.logger.info("Starting VNC proxy on %s:%s", self.host, self.port)
try:
server = await websockets.serve(
self._handle_client,
self.host,
self.port,
ssl=self.ssl_context,
max_size=_MAX_MESSAGE_SIZE,
ping_interval=20,
ping_timeout=20,
)
except Exception:
self._failed.set()
self._ready.set()
raise
self._ready.set()
await server.wait_closed()
async def _handle_client(self, websocket, path: str) -> None:
parsed = urlsplit(path)
if parsed.path != VNC_WS_PATH:
await websocket.close(code=1008, reason="invalid_path")
return
query = parse_qs(parsed.query or "")
token = (query.get("token") or [""])[0]
session = self.registry.consume(token)
if not session:
await websocket.close(code=1008, reason="invalid_session")
return
logger = self.logger.getChild("session")
logger.info("VNC session start agent_id=%s", session.agent_id)
try:
try:
reader, writer = await self._connect_vnc(session.host, session.port)
except Exception as exc:
logger.warning("VNC connect failed: %s", exc)
await websocket.close(code=1011, reason="vnc_unavailable")
return
async def _ws_to_tcp() -> None:
try:
async for message in websocket:
if message is None:
break
if isinstance(message, str):
data = message.encode("utf-8")
else:
data = bytes(message)
writer.write(data)
await writer.drain()
finally:
try:
writer.close()
except Exception:
pass
async def _tcp_to_ws() -> None:
try:
while True:
data = await reader.read(8192)
if not data:
break
await websocket.send(data)
finally:
try:
await websocket.close()
except Exception:
pass
await asyncio.wait(
[asyncio.create_task(_ws_to_tcp()), asyncio.create_task(_tcp_to_ws())],
return_when=asyncio.FIRST_COMPLETED,
)
finally:
logger.info("VNC session ended agent_id=%s", session.agent_id)
self._notify_agent_session_end(session, reason="vnc_session_end")
async def _connect_vnc(self, host: str, port: int) -> Tuple[Any, Any]:
attempts = 5
delay = 0.5
last_exc: Optional[Exception] = None
for attempt in range(attempts):
try:
return await asyncio.open_connection(host, port)
except Exception as exc:
last_exc = exc
if attempt < attempts - 1:
await asyncio.sleep(delay)
if last_exc:
raise last_exc
raise RuntimeError("vnc_connect_failed")
def _notify_agent_session_end(self, session: VncSession, reason: str) -> None:
if not self._emit_agent_event:
return
payload = {"agent_id": session.agent_id, "reason": reason}
try:
self._emit_agent_event(session.agent_id, "vnc_stop", payload)
except Exception:
self.logger.debug("Failed to emit vnc_stop for agent_id=%s", session.agent_id, exc_info=True)
def _build_ssl_context(cert_path: Optional[str], key_path: Optional[str]) -> Optional[ssl.SSLContext]:
if not cert_path or not key_path:
return None
try:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(certfile=cert_path, keyfile=key_path)
return context
except Exception:
return None
def ensure_vnc_proxy(context: Any, *, logger: Optional[logging.Logger] = None) -> Optional[VncSessionRegistry]:
if logger is None:
logger = context.logger if hasattr(context, "logger") else logging.getLogger("borealis.engine.vnc")
registry = getattr(context, "vnc_registry", None)
if registry is None:
ttl = int(getattr(context, "vnc_session_ttl_seconds", 120))
registry = VncSessionRegistry(ttl_seconds=ttl, logger=logger)
setattr(context, "vnc_registry", registry)
proxy = getattr(context, "vnc_proxy", None)
if proxy is None:
cert_path = getattr(context, "tls_bundle_path", None) or getattr(context, "tls_cert_path", None)
ssl_context = _build_ssl_context(
cert_path,
getattr(context, "tls_key_path", None),
)
proxy = VncProxyServer(
host=str(getattr(context, "vnc_ws_host", "0.0.0.0")),
port=int(getattr(context, "vnc_ws_port", 4823)),
registry=registry,
logger=logger.getChild("vnc_proxy"),
emit_agent_event=getattr(context, "emit_agent_event", None),
ssl_context=ssl_context,
)
setattr(context, "vnc_proxy", proxy)
if not proxy.ensure_started():
logger.error("VNC proxy failed to start; VNC sessions unavailable.")
return None
return registry
__all__ = ["VNC_WS_PATH", "VncSessionRegistry", "VncProxyServer", "ensure_vnc_proxy"]

View File

@@ -12,6 +12,7 @@ from __future__ import annotations
import base64
import ipaddress
import json
import os
import threading
import time
import uuid
@@ -22,6 +23,13 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from .wireguard_server import WireGuardServerManager
def _env_flag(name: str, *, default: bool) -> bool:
value = os.environ.get(name)
if value is None:
return default
return str(value).strip().lower() not in ("0", "false", "no", "off")
@dataclass
class VpnSession:
tunnel_id: str
@@ -62,14 +70,17 @@ class VpnTunnelService:
self.logger = context.logger.getChild("vpn_tunnel")
self.activity_logger = self.wg.logger.getChild("device_activity")
self.idle_seconds = max(60, int(idle_seconds))
self.persistent = _env_flag("BOREALIS_WIREGUARD_PERSISTENT", default=True)
self._lock = threading.Lock()
self._sessions_by_agent: Dict[str, VpnSession] = {}
self._sessions_by_tunnel: Dict[str, VpnSession] = {}
self._engine_ip = ipaddress.ip_interface(context.wireguard_engine_virtual_ip)
self._peer_network = ipaddress.ip_network(context.wireguard_peer_network, strict=False)
self._cleanup_listener()
self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True)
self._idle_thread.start()
self._idle_thread: Optional[threading.Thread] = None
if not self.persistent:
self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True)
self._idle_thread.start()
def _idle_loop(self) -> None:
while True:
@@ -90,7 +101,7 @@ class VpnTunnelService:
self.idle_seconds,
)
)
self.disconnect(session.agent_id, reason="idle_timeout")
self.disconnect(session.agent_id, reason="idle_timeout", force=True)
def _allocate_virtual_ip(self, agent_id: str) -> str:
existing = self._sessions_by_agent.get(agent_id)
@@ -226,12 +237,12 @@ class VpnTunnelService:
self.logger.debug("Failed to write vpn_tunnel service log entry", exc_info=True)
def _cleanup_listener(self) -> None:
try:
self.wg.stop_listener(ignore_missing=True)
self._service_log_event("vpn_listener_cleanup reason=startup")
except Exception:
self.logger.debug("Failed to clean up WireGuard listener on startup.", exc_info=True)
self._service_log_event("vpn_listener_cleanup_failed reason=startup", level="WARNING")
self._service_log_event("vpn_listener_cleanup_skipped reason=startup")
def _is_soft_disconnect(self, reason: Optional[str]) -> bool:
if not reason:
return False
return str(reason).lower() in ("operator_disconnect", "component_unmount")
def _refresh_listener(self) -> None:
peers: List[Mapping[str, object]] = []
@@ -432,15 +443,60 @@ class VpnTunnelService:
except Exception:
self.logger.debug("vpn_tunnel_activity emit failed for agent_id=%s", agent_id, exc_info=True)
def disconnect(self, agent_id: str, reason: str = "operator_stop") -> bool:
def disconnect(
self,
agent_id: str,
reason: str = "operator_stop",
*,
operator_id: Optional[str] = None,
force: bool = False,
) -> bool:
with self._lock:
session = self._sessions_by_agent.pop(agent_id, None)
session = self._sessions_by_agent.get(agent_id)
if not session:
self._service_log_event(
"vpn_tunnel_disconnect_missing agent_id={0} reason={1}".format(agent_id or "-", reason or "-")
)
return False
self._sessions_by_tunnel.pop(session.tunnel_id, None)
if self.persistent and not force:
if operator_id:
try:
session.operator_ids.discard(operator_id)
except Exception:
pass
session.last_activity = time.time()
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
self._service_log_event(
"vpn_tunnel_keepalive agent_id={0} tunnel_id={1} reason={2} operators={3}".format(
session.agent_id,
session.tunnel_id,
reason or "-",
operator_text,
)
)
return True
if not force and self._is_soft_disconnect(reason):
if operator_id:
try:
session.operator_ids.discard(operator_id)
except Exception:
pass
session.last_activity = time.time()
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
self._service_log_event(
"vpn_tunnel_keepalive agent_id={0} tunnel_id={1} reason={2} operators={3}".format(
session.agent_id,
session.tunnel_id,
reason or "-",
operator_text,
)
)
return True
session = self._sessions_by_agent.pop(agent_id, None)
if session:
self._sessions_by_tunnel.pop(session.tunnel_id, None)
else:
return False
try:
self.wg.remove_firewall_rules(session.firewall_rules)
@@ -461,7 +517,14 @@ class VpnTunnelService:
self._log_device_activity(session, event="stop", reason=reason)
return True
def disconnect_by_tunnel(self, tunnel_id: str, reason: str = "operator_stop") -> bool:
def disconnect_by_tunnel(
self,
tunnel_id: str,
reason: str = "operator_stop",
*,
operator_id: Optional[str] = None,
force: bool = False,
) -> bool:
with self._lock:
session = self._sessions_by_tunnel.get(tunnel_id)
if not session:
@@ -469,7 +532,7 @@ class VpnTunnelService:
"vpn_tunnel_disconnect_missing tunnel_id={0} reason={1}".format(tunnel_id or "-", reason or "-")
)
return False
return self.disconnect(session.agent_id, reason=reason)
return self.disconnect(session.agent_id, reason=reason, operator_id=operator_id, force=force)
def _emit_start(self, payload: Mapping[str, Any]) -> None:
if not self.socketio:
@@ -704,7 +767,7 @@ class VpnTunnelService:
"server_public_key": self.wg.server_public_key,
"client_public_key": session.client_public_key,
"client_private_key": session.client_private_key,
"idle_seconds": self.idle_seconds,
"idle_seconds": 0 if self.persistent else self.idle_seconds,
"allowed_ports": list(session.allowed_ports),
"connected_operators": len([o for o in session.operator_ids if o]),
}
@@ -729,6 +792,6 @@ class VpnTunnelService:
"last_activity_iso": self._ts_to_iso(session.last_activity),
"expires_at": int(session.expires_at),
"expires_at_iso": self._ts_to_iso(session.expires_at),
"idle_seconds": self.idle_seconds,
"idle_seconds": 0 if self.persistent else self.idle_seconds,
"status": "up",
}

View File

@@ -164,6 +164,25 @@ class WireGuardServerManager:
return match.group(1).upper()
return None
def _service_exists(self) -> bool:
code, _, _ = self._run_command(["sc.exe", "query", self._service_id()])
return code == 0
def _stop_service(self, *, timeout: int = 20) -> bool:
service_id = self._service_id()
state = self._query_service_state()
if not state:
return False
if state == "STOPPED":
return True
self._run_command(["sc.exe", "stop", service_id])
for _ in range(max(1, timeout)):
time.sleep(1)
state = self._query_service_state()
if state == "STOPPED":
return True
return False
def _ensure_service_display_name(self) -> None:
if not self._service_display_name:
return
@@ -172,9 +191,9 @@ class WireGuardServerManager:
if code != 0 and err:
self.logger.warning("Failed to set WireGuard service display name: %s", err)
def _ensure_service_running(self) -> None:
def _ensure_service_running(self, *, timeout: int = 20) -> None:
service_id = self._service_id()
for _ in range(6):
for _ in range(max(1, timeout)):
state = self._query_service_state()
if state == "RUNNING":
return
@@ -183,8 +202,20 @@ class WireGuardServerManager:
if code != 0:
self.logger.error("Failed to start WireGuard tunnel service %s err=%s", service_id, err)
break
if state in ("START_PENDING", "STOP_PENDING"):
time.sleep(1)
continue
time.sleep(1)
state = self._query_service_state()
if state == "START_PENDING":
self.logger.warning("WireGuard tunnel service still START_PENDING; attempting restart.")
self._stop_service(timeout=10)
self._run_command(["sc.exe", "start", service_id])
for _ in range(10):
time.sleep(1)
if self._query_service_state() == "RUNNING":
return
state = self._query_service_state()
raise RuntimeError(f"WireGuard tunnel service {service_id} failed to start (state={state})")
def _normalise_allowed_ports(
@@ -329,6 +360,7 @@ class WireGuardServerManager:
for idx, rule in enumerate(rules):
name = f"Borealis-WG-Agent-{peer.get('agent_id','')}-{idx}"
protocol = str(rule.get("protocol") or "TCP").upper()
self._run_command(["netsh", "advfirewall", "firewall", "delete", "rule", f"name={name}"])
args = [
"netsh",
"advfirewall",
@@ -374,8 +406,12 @@ class WireGuardServerManager:
config_path.write_text(rendered, encoding="utf-8")
self.logger.info("Rendered WireGuard config to %s", config_path)
# Ensure old service is removed before re-installing.
self.stop_listener()
if self._service_exists():
if not self._stop_service(timeout=20):
self.logger.warning("WireGuard tunnel service did not stop cleanly before restart.")
self._ensure_service_display_name()
self._ensure_service_running(timeout=25)
return
args = [self._wireguard_exe, "/installtunnelservice", str(config_path)]
code, out, err = self._run_command(args)
@@ -384,21 +420,22 @@ class WireGuardServerManager:
raise RuntimeError(f"WireGuard installtunnelservice failed: {err}")
self.logger.info("WireGuard listener installed (service=%s)", config_path.stem)
self._ensure_service_display_name()
self._ensure_service_running()
self._ensure_service_running(timeout=25)
def stop_listener(self, *, ignore_missing: bool = False) -> None:
"""Stop and remove the WireGuard tunnel service."""
"""Stop the WireGuard tunnel service (leave installed for reuse)."""
args = [self._wireguard_exe, "/uninstalltunnelservice", self._service_name]
code, out, err = self._run_command(args)
if code != 0:
err_text = " ".join([out or "", err or ""]).strip().lower()
if ignore_missing and ("does not exist" in err_text or "not exist" in err_text):
if not self._service_exists():
if ignore_missing:
self.logger.info("WireGuard tunnel service already absent")
return
self.logger.warning("Failed to uninstall WireGuard tunnel service code=%s err=%s", code, err)
else:
self.logger.info("WireGuard tunnel service removed")
self.logger.warning("WireGuard tunnel service not found during stop.")
return
if not self._stop_service(timeout=20):
self.logger.warning("WireGuard tunnel service did not stop cleanly.")
return
self.logger.info("WireGuard tunnel service stopped")
def build_firewall_rules(
self,

View File

@@ -187,7 +187,10 @@ class VpnShellBridge:
existing.close()
status = service.status(agent_id)
if not status:
return None
try:
status = service.connect(agent_id=agent_id, operator_id=None, endpoint_host=None)
except Exception:
return None
host = str(status.get("virtual_ip") or "").split("/")[0]
port = int(self.context.wireguard_shell_port)
tcp = None