mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2026-02-06 20:00:31 -07:00
Removed RDP in favor of VNC / Made WireGuard Tunnel Persistent
This commit is contained in:
@@ -33,7 +33,8 @@ from ..auth import DevModeManager
|
||||
from .enrollment import routes as enrollment_routes
|
||||
from .tokens import routes as token_routes
|
||||
from .devices.tunnel import register_tunnel
|
||||
from .devices.rdp import register_rdp
|
||||
from .devices.vnc import register_vnc
|
||||
from .devices.shell import register_shell
|
||||
|
||||
from ...server import EngineContext
|
||||
from .access_management.login import register_auth
|
||||
@@ -292,7 +293,8 @@ def _register_devices(app: Flask, adapters: EngineServiceAdapters) -> None:
|
||||
register_admin_endpoints(app, adapters)
|
||||
device_routes.register_agents(app, adapters)
|
||||
register_tunnel(app, adapters)
|
||||
register_rdp(app, adapters)
|
||||
register_vnc(app, adapters)
|
||||
register_shell(app, adapters)
|
||||
|
||||
def _register_filters(app: Flask, adapters: EngineServiceAdapters) -> None:
|
||||
filters_management.register_filters(app, adapters)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
# API Endpoints (if applicable):
|
||||
# - POST /api/agent/heartbeat (Device Authenticated) - Updates device last-seen metadata and inventory snapshots.
|
||||
# - POST /api/agent/script/request (Device Authenticated) - Provides script execution payloads or idle signals to agents.
|
||||
# - POST /api/agent/vpn/ensure (Device Authenticated) - Ensures persistent WireGuard tunnel material.
|
||||
# ======================================================
|
||||
|
||||
"""Device-affiliated agent endpoints for the Borealis Engine runtime."""
|
||||
@@ -13,12 +14,14 @@ from __future__ import annotations
|
||||
import json
|
||||
import sqlite3
|
||||
import time
|
||||
from urllib.parse import urlsplit
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from flask import Blueprint, jsonify, request, g
|
||||
|
||||
from ....auth.device_auth import AGENT_CONTEXT_HEADER, require_device_auth
|
||||
from ....auth.guid_utils import normalize_guid
|
||||
from .tunnel import _get_tunnel_service
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - typing aide
|
||||
from .. import EngineServiceAdapters
|
||||
@@ -42,6 +45,20 @@ def _json_or_none(value: Any) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _infer_endpoint_host(req) -> str:
|
||||
forwarded = (req.headers.get("X-Forwarded-Host") or req.headers.get("X-Original-Host") or "").strip()
|
||||
host = forwarded.split(",")[0].strip() if forwarded else (req.host or "").strip()
|
||||
if not host:
|
||||
return ""
|
||||
try:
|
||||
parsed = urlsplit(f"//{host}")
|
||||
if parsed.hostname:
|
||||
return parsed.hostname
|
||||
except Exception:
|
||||
return host
|
||||
return host
|
||||
|
||||
|
||||
def register_agents(app, adapters: "EngineServiceAdapters") -> None:
|
||||
"""Register agent heartbeat and script polling routes."""
|
||||
|
||||
@@ -218,4 +235,76 @@ def register_agents(app, adapters: "EngineServiceAdapters") -> None:
|
||||
}
|
||||
)
|
||||
|
||||
@blueprint.route("/api/agent/vpn/ensure", methods=["POST"])
|
||||
@require_device_auth(auth_manager)
|
||||
def vpn_ensure():
|
||||
ctx = _auth_context()
|
||||
if ctx is None:
|
||||
return jsonify({"error": "auth_context_missing"}), 500
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
requested_agent = (body.get("agent_id") or "").strip()
|
||||
guid = normalize_guid(ctx.guid)
|
||||
|
||||
conn = db_conn_factory()
|
||||
resolved_agent = ""
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"SELECT agent_id FROM devices WHERE UPPER(guid) = ? ORDER BY last_seen DESC LIMIT 1",
|
||||
(guid,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row and row[0]:
|
||||
resolved_agent = str(row[0]).strip()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if not resolved_agent:
|
||||
log("VPN_Tunnel/tunnel", f"vpn_agent_ensure_missing_agent guid={guid}", _context_hint(ctx), level="ERROR")
|
||||
return jsonify({"error": "agent_id_missing"}), 404
|
||||
|
||||
if requested_agent and requested_agent != resolved_agent:
|
||||
log(
|
||||
"VPN_Tunnel/tunnel",
|
||||
"vpn_agent_ensure_agent_mismatch requested={0} resolved={1}".format(
|
||||
requested_agent, resolved_agent
|
||||
),
|
||||
_context_hint(ctx),
|
||||
level="WARNING",
|
||||
)
|
||||
|
||||
try:
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
endpoint_host = _infer_endpoint_host(request)
|
||||
log(
|
||||
"VPN_Tunnel/tunnel",
|
||||
"vpn_agent_ensure_request agent_id={0} endpoint_host={1}".format(
|
||||
resolved_agent, endpoint_host or "-"
|
||||
),
|
||||
_context_hint(ctx),
|
||||
)
|
||||
payload = tunnel_service.connect(
|
||||
agent_id=resolved_agent,
|
||||
operator_id=None,
|
||||
endpoint_host=endpoint_host,
|
||||
)
|
||||
except Exception as exc:
|
||||
log(
|
||||
"VPN_Tunnel/tunnel",
|
||||
"vpn_agent_ensure_failed agent_id={0} error={1}".format(resolved_agent, str(exc)),
|
||||
_context_hint(ctx),
|
||||
level="ERROR",
|
||||
)
|
||||
return jsonify({"error": "tunnel_start_failed", "detail": str(exc)}), 500
|
||||
|
||||
log(
|
||||
"VPN_Tunnel/tunnel",
|
||||
"vpn_agent_ensure_response agent_id={0} tunnel_id={1}".format(
|
||||
payload.get("agent_id", resolved_agent), payload.get("tunnel_id", "-")
|
||||
),
|
||||
_context_hint(ctx),
|
||||
)
|
||||
return jsonify(payload), 200
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\API\devices\rdp.py
|
||||
# Description: RDP session bootstrap for Guacamole WebSocket tunnels.
|
||||
# Data\Engine\services\API\devices\shell.py
|
||||
# Description: Remote PowerShell session endpoints for persistent WireGuard tunnels.
|
||||
#
|
||||
# API Endpoints (if applicable):
|
||||
# - POST /api/rdp/session (Token Authenticated) - Issues a one-time Guacamole tunnel token for RDP.
|
||||
# - POST /api/shell/establish (Token Authenticated) - Ensure shell readiness over WireGuard.
|
||||
# - POST /api/shell/disconnect (Token Authenticated) - Disconnect the operator shell session.
|
||||
# ======================================================
|
||||
|
||||
"""RDP session bootstrap endpoints for the Borealis Engine."""
|
||||
"""Remote PowerShell session endpoints for the Borealis Engine."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
@@ -16,7 +17,6 @@ from urllib.parse import urlsplit
|
||||
from flask import Blueprint, jsonify, request, session
|
||||
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
|
||||
|
||||
from ...RemoteDesktop.guacamole_proxy import GUAC_WS_PATH, ensure_guacamole_proxy
|
||||
from .tunnel import _get_tunnel_service
|
||||
|
||||
if False: # pragma: no cover - hint for type checkers
|
||||
@@ -81,25 +81,18 @@ def _infer_endpoint_host(req) -> str:
|
||||
return host
|
||||
|
||||
|
||||
def _is_secure(req) -> bool:
|
||||
if req.is_secure:
|
||||
return True
|
||||
forwarded = (req.headers.get("X-Forwarded-Proto") or "").split(",")[0].strip().lower()
|
||||
return forwarded == "https"
|
||||
|
||||
|
||||
def register_rdp(app, adapters: "EngineServiceAdapters") -> None:
|
||||
blueprint = Blueprint("rdp", __name__)
|
||||
logger = adapters.context.logger.getChild("rdp.api")
|
||||
def register_shell(app, adapters: "EngineServiceAdapters") -> None:
|
||||
blueprint = Blueprint("vpn_shell", __name__)
|
||||
logger = adapters.context.logger.getChild("vpn_shell.api")
|
||||
service_log = adapters.service_log
|
||||
|
||||
def _service_log_event(message: str, *, level: str = "INFO") -> None:
|
||||
if not callable(service_log):
|
||||
return
|
||||
try:
|
||||
service_log("RDP", message, level=level)
|
||||
service_log("VPN_Tunnel/remote_shell", message, level=level)
|
||||
except Exception:
|
||||
logger.debug("rdp service log write failed", exc_info=True)
|
||||
logger.debug("vpn_shell service log write failed", exc_info=True)
|
||||
|
||||
def _request_remote() -> str:
|
||||
forwarded = (request.headers.get("X-Forwarded-For") or "").strip()
|
||||
@@ -107,8 +100,8 @@ def register_rdp(app, adapters: "EngineServiceAdapters") -> None:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return (request.remote_addr or "").strip()
|
||||
|
||||
@blueprint.route("/api/rdp/session", methods=["POST"])
|
||||
def rdp_session():
|
||||
@blueprint.route("/api/shell/establish", methods=["POST"])
|
||||
def shell_establish():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
@@ -119,77 +112,82 @@ def register_rdp(app, adapters: "EngineServiceAdapters") -> None:
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
agent_id = _normalize_text(body.get("agent_id"))
|
||||
protocol = _normalize_text(body.get("protocol") or "rdp").lower()
|
||||
username = _normalize_text(body.get("username"))
|
||||
password = str(body.get("password") or "")
|
||||
if not agent_id:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
|
||||
try:
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
endpoint_host = _infer_endpoint_host(request)
|
||||
_service_log_event(
|
||||
"vpn_shell_establish_request agent_id={0} operator={1} endpoint_host={2} remote={3}".format(
|
||||
agent_id,
|
||||
operator_id or "-",
|
||||
endpoint_host or "-",
|
||||
_request_remote() or "-",
|
||||
)
|
||||
)
|
||||
payload = tunnel_service.connect(
|
||||
agent_id=agent_id,
|
||||
operator_id=operator_id,
|
||||
endpoint_host=endpoint_host,
|
||||
)
|
||||
except Exception as exc:
|
||||
_service_log_event(
|
||||
"vpn_shell_establish_failed agent_id={0} operator={1} error={2}".format(
|
||||
agent_id,
|
||||
operator_id or "-",
|
||||
str(exc),
|
||||
),
|
||||
level="ERROR",
|
||||
)
|
||||
return jsonify({"error": "establish_failed", "detail": str(exc)}), 500
|
||||
|
||||
agent_socket = False
|
||||
registry = getattr(adapters.context, "agent_socket_registry", None)
|
||||
if registry and hasattr(registry, "is_registered"):
|
||||
try:
|
||||
agent_socket = bool(registry.is_registered(agent_id))
|
||||
except Exception:
|
||||
agent_socket = False
|
||||
|
||||
response = dict(payload)
|
||||
response["status"] = "ok"
|
||||
response["agent_socket"] = agent_socket
|
||||
_service_log_event(
|
||||
"vpn_shell_establish_response agent_id={0} tunnel_id={1} agent_socket={2}".format(
|
||||
agent_id,
|
||||
response.get("tunnel_id", "-"),
|
||||
str(agent_socket).lower(),
|
||||
)
|
||||
)
|
||||
return jsonify(response), 200
|
||||
|
||||
@blueprint.route("/api/shell/disconnect", methods=["POST"])
|
||||
def shell_disconnect():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
agent_id = _normalize_text(body.get("agent_id"))
|
||||
reason = _normalize_text(body.get("reason") or "operator_disconnect")
|
||||
operator_id = (_current_user(app) or {}).get("username") or None
|
||||
|
||||
if not agent_id:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
if protocol != "rdp":
|
||||
return jsonify({"error": "unsupported_protocol"}), 400
|
||||
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
session_payload = tunnel_service.session_payload(agent_id, include_token=False)
|
||||
if not session_payload:
|
||||
return jsonify({"error": "tunnel_down"}), 409
|
||||
|
||||
allowed_ports = session_payload.get("allowed_ports") or []
|
||||
if 3389 not in allowed_ports:
|
||||
return jsonify({"error": "rdp_not_allowed"}), 403
|
||||
|
||||
virtual_ip = _normalize_text(session_payload.get("virtual_ip"))
|
||||
host = virtual_ip.split("/")[0] if virtual_ip else ""
|
||||
if not host:
|
||||
return jsonify({"error": "virtual_ip_missing"}), 500
|
||||
|
||||
registry = ensure_guacamole_proxy(adapters.context, logger=logger)
|
||||
if registry is None:
|
||||
return jsonify({"error": "rdp_proxy_unavailable"}), 503
|
||||
|
||||
_service_log_event(
|
||||
"rdp_session_request agent_id={0} operator={1} protocol={2} remote={3}".format(
|
||||
"vpn_shell_disconnect_request agent_id={0} operator={1} reason={2} remote={3}".format(
|
||||
agent_id,
|
||||
operator_id or "-",
|
||||
protocol,
|
||||
reason or "-",
|
||||
_request_remote() or "-",
|
||||
)
|
||||
)
|
||||
|
||||
rdp_session = registry.create(
|
||||
agent_id=agent_id,
|
||||
host=host,
|
||||
port=3389,
|
||||
username=username,
|
||||
password=password,
|
||||
protocol=protocol,
|
||||
operator_id=operator_id,
|
||||
ignore_cert=True,
|
||||
)
|
||||
|
||||
ws_scheme = "wss" if _is_secure(request) else "ws"
|
||||
ws_host = _infer_endpoint_host(request)
|
||||
ws_port = int(getattr(adapters.context, "rdp_ws_port", 4823))
|
||||
ws_url = f"{ws_scheme}://{ws_host}:{ws_port}{GUAC_WS_PATH}"
|
||||
|
||||
_service_log_event(
|
||||
"rdp_session_ready agent_id={0} token={1} expires_at={2}".format(
|
||||
agent_id,
|
||||
rdp_session.token[:8],
|
||||
int(rdp_session.expires_at),
|
||||
)
|
||||
)
|
||||
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"token": rdp_session.token,
|
||||
"ws_url": ws_url,
|
||||
"expires_at": int(rdp_session.expires_at),
|
||||
"virtual_ip": host,
|
||||
"tunnel_id": session_payload.get("tunnel_id"),
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
return jsonify({"status": "disconnected", "reason": reason}), 200
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
__all__ = ["register_shell"]
|
||||
@@ -1,12 +1,11 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\API\devices\tunnel.py
|
||||
# Description: WireGuard VPN tunnel API (connect/status/disconnect).
|
||||
# Description: WireGuard VPN tunnel API (connect/status).
|
||||
#
|
||||
# API Endpoints (if applicable):
|
||||
# - POST /api/tunnel/connect (Token Authenticated) - Issues VPN session material for an agent.
|
||||
# - GET /api/tunnel/status (Token Authenticated) - Returns VPN status for an agent.
|
||||
# - GET /api/tunnel/active (Token Authenticated) - Lists active VPN tunnel sessions.
|
||||
# - DELETE /api/tunnel/disconnect (Token Authenticated) - Tears down VPN session for an agent.
|
||||
# ======================================================
|
||||
|
||||
"""WireGuard VPN tunnel API (Engine side)."""
|
||||
@@ -254,52 +253,4 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
|
||||
)
|
||||
return jsonify({"count": len(sessions), "tunnels": sessions}), 200
|
||||
|
||||
@blueprint.route("/api/tunnel/disconnect", methods=["DELETE"])
|
||||
def disconnect_tunnel():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
agent_id = _normalize_text(body.get("agent_id"))
|
||||
tunnel_id = _normalize_text(body.get("tunnel_id"))
|
||||
reason = _normalize_text(body.get("reason") or "operator_stop")
|
||||
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
_service_log_event(
|
||||
"vpn_api_disconnect_request agent_id={0} tunnel_id={1} reason={2} operator={3} remote={4}".format(
|
||||
agent_id or "-",
|
||||
tunnel_id or "-",
|
||||
reason or "-",
|
||||
(_current_user(app) or {}).get("username") or "-",
|
||||
_request_remote() or "-",
|
||||
)
|
||||
)
|
||||
stopped = False
|
||||
if tunnel_id:
|
||||
stopped = tunnel_service.disconnect_by_tunnel(tunnel_id, reason=reason)
|
||||
elif agent_id:
|
||||
stopped = tunnel_service.disconnect(agent_id, reason=reason)
|
||||
else:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
|
||||
if not stopped:
|
||||
_service_log_event(
|
||||
"vpn_api_disconnect_not_found agent_id={0} tunnel_id={1}".format(
|
||||
agent_id or "-",
|
||||
tunnel_id or "-",
|
||||
),
|
||||
level="WARNING",
|
||||
)
|
||||
return jsonify({"error": "not_found"}), 404
|
||||
|
||||
_service_log_event(
|
||||
"vpn_api_disconnect_response agent_id={0} tunnel_id={1} status=stopped".format(
|
||||
agent_id or "-",
|
||||
tunnel_id or "-",
|
||||
)
|
||||
)
|
||||
return jsonify({"status": "stopped", "reason": reason}), 200
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
338
Data/Engine/services/API/devices/vnc.py
Normal file
338
Data/Engine/services/API/devices/vnc.py
Normal file
@@ -0,0 +1,338 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\API\devices\vnc.py
|
||||
# Description: VNC session bootstrap for noVNC WebSocket tunnels.
|
||||
#
|
||||
# API Endpoints (if applicable):
|
||||
# - POST /api/vnc/establish (Token Authenticated) - Establish a VNC session for noVNC.
|
||||
# - POST /api/vnc/disconnect (Token Authenticated) - Disconnect the operator VNC session.
|
||||
# - POST /api/vnc/session (Token Authenticated) - Legacy alias for establish.
|
||||
# ======================================================
|
||||
|
||||
"""VNC session bootstrap endpoints for the Borealis Engine."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from flask import Blueprint, jsonify, request, session
|
||||
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
|
||||
|
||||
from ...RemoteDesktop.vnc_proxy import VNC_WS_PATH, ensure_vnc_proxy
|
||||
from .tunnel import _get_tunnel_service
|
||||
|
||||
if False: # pragma: no cover - hint for type checkers
|
||||
from .. import EngineServiceAdapters
|
||||
|
||||
|
||||
def _current_user(app) -> Optional[Dict[str, str]]:
|
||||
username = session.get("username")
|
||||
role = session.get("role") or "User"
|
||||
if username:
|
||||
return {"username": username, "role": role}
|
||||
|
||||
token = None
|
||||
auth_header = request.headers.get("Authorization") or ""
|
||||
if auth_header.lower().startswith("bearer "):
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
if not token:
|
||||
token = request.cookies.get("borealis_auth")
|
||||
if not token:
|
||||
return None
|
||||
|
||||
try:
|
||||
serializer = URLSafeTimedSerializer(app.secret_key or "borealis-dev-secret", salt="borealis-auth")
|
||||
token_ttl = int(os.environ.get("BOREALIS_TOKEN_TTL_SECONDS", 60 * 60 * 24 * 30))
|
||||
data = serializer.loads(token, max_age=token_ttl)
|
||||
username = data.get("u")
|
||||
role = data.get("r") or "User"
|
||||
if username:
|
||||
return {"username": username, "role": role}
|
||||
except (BadSignature, SignatureExpired, Exception):
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _require_login(app) -> Optional[Tuple[Dict[str, Any], int]]:
|
||||
user = _current_user(app)
|
||||
if not user:
|
||||
return {"error": "unauthorized"}, 401
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
try:
|
||||
return str(value).strip()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _infer_endpoint_host(req) -> str:
|
||||
forwarded = (req.headers.get("X-Forwarded-Host") or req.headers.get("X-Original-Host") or "").strip()
|
||||
host = forwarded.split(",")[0].strip() if forwarded else (req.host or "").strip()
|
||||
if not host:
|
||||
return ""
|
||||
try:
|
||||
parsed = urlsplit(f"//{host}")
|
||||
if parsed.hostname:
|
||||
return parsed.hostname
|
||||
except Exception:
|
||||
return host
|
||||
return host
|
||||
|
||||
|
||||
def _is_secure(req) -> bool:
|
||||
if req.is_secure:
|
||||
return True
|
||||
forwarded = (req.headers.get("X-Forwarded-Proto") or "").split(",")[0].strip().lower()
|
||||
return forwarded == "https"
|
||||
|
||||
|
||||
def _generate_vnc_password() -> str:
|
||||
# UltraVNC uses the first 8 characters for VNC auth; keep the token to 8 for compatibility.
|
||||
return secrets.token_hex(4)
|
||||
|
||||
|
||||
def _load_vnc_password(adapters: "EngineServiceAdapters", agent_id: str) -> Optional[str]:
|
||||
conn = adapters.db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"SELECT agent_vnc_password FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
|
||||
(agent_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row and row[0]:
|
||||
return str(row[0]).strip()
|
||||
except Exception:
|
||||
adapters.context.logger.debug("Failed to load agent VNC password", exc_info=True)
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _store_vnc_password(adapters: "EngineServiceAdapters", agent_id: str, password: str) -> None:
|
||||
conn = adapters.db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"UPDATE devices SET agent_vnc_password=? WHERE agent_id=?",
|
||||
(password, agent_id),
|
||||
)
|
||||
conn.commit()
|
||||
except Exception:
|
||||
adapters.context.logger.debug("Failed to store agent VNC password", exc_info=True)
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def register_vnc(app, adapters: "EngineServiceAdapters") -> None:
|
||||
blueprint = Blueprint("vnc", __name__)
|
||||
logger = adapters.context.logger.getChild("vnc.api")
|
||||
service_log = adapters.service_log
|
||||
|
||||
def _service_log_event(message: str, *, level: str = "INFO") -> None:
|
||||
if not callable(service_log):
|
||||
return
|
||||
try:
|
||||
service_log("VNC", message, level=level)
|
||||
except Exception:
|
||||
logger.debug("vnc service log write failed", exc_info=True)
|
||||
|
||||
def _request_remote() -> str:
|
||||
forwarded = (request.headers.get("X-Forwarded-For") or "").strip()
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return (request.remote_addr or "").strip()
|
||||
|
||||
def _issue_session(agent_id: str, operator_id: Optional[str]) -> Tuple[Dict[str, Any], int]:
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
session_payload = tunnel_service.session_payload(agent_id, include_token=False)
|
||||
if not session_payload:
|
||||
try:
|
||||
session_payload = tunnel_service.connect(
|
||||
agent_id=agent_id,
|
||||
operator_id=operator_id,
|
||||
endpoint_host=_infer_endpoint_host(request),
|
||||
)
|
||||
except Exception:
|
||||
return {"error": "tunnel_down"}, 409
|
||||
|
||||
vnc_port = int(getattr(adapters.context, "vnc_port", 5900))
|
||||
raw_ports = session_payload.get("allowed_ports") or []
|
||||
allowed_ports = []
|
||||
for value in raw_ports:
|
||||
try:
|
||||
allowed_ports.append(int(value))
|
||||
except Exception:
|
||||
continue
|
||||
if vnc_port not in allowed_ports:
|
||||
return {"error": "vnc_not_allowed", "vnc_port": vnc_port}, 403
|
||||
|
||||
virtual_ip = _normalize_text(session_payload.get("virtual_ip"))
|
||||
host = virtual_ip.split("/")[0] if virtual_ip else ""
|
||||
if not host:
|
||||
return {"error": "virtual_ip_missing"}, 500
|
||||
|
||||
vnc_password = _load_vnc_password(adapters, agent_id)
|
||||
if not vnc_password:
|
||||
vnc_password = _generate_vnc_password()
|
||||
_store_vnc_password(adapters, agent_id, vnc_password)
|
||||
if len(vnc_password) > 8:
|
||||
vnc_password = vnc_password[:8]
|
||||
_store_vnc_password(adapters, agent_id, vnc_password)
|
||||
|
||||
registry = ensure_vnc_proxy(adapters.context, logger=logger)
|
||||
if registry is None:
|
||||
return {"error": "vnc_proxy_unavailable"}, 503
|
||||
|
||||
_service_log_event(
|
||||
"vnc_establish_request agent_id={0} operator={1} remote={2}".format(
|
||||
agent_id,
|
||||
operator_id or "-",
|
||||
_request_remote() or "-",
|
||||
)
|
||||
)
|
||||
|
||||
vnc_session = registry.create(
|
||||
agent_id=agent_id,
|
||||
host=host,
|
||||
port=vnc_port,
|
||||
operator_id=operator_id,
|
||||
)
|
||||
|
||||
emit_agent = getattr(adapters.context, "emit_agent_event", None)
|
||||
payload = {
|
||||
"agent_id": agent_id,
|
||||
"port": vnc_port,
|
||||
"allowed_ips": session_payload.get("allowed_ips"),
|
||||
"virtual_ip": host,
|
||||
"password": vnc_password,
|
||||
"reason": "vnc_session_start",
|
||||
}
|
||||
agent_socket_ready = True
|
||||
if callable(emit_agent):
|
||||
agent_socket_ready = bool(emit_agent(agent_id, "vnc_start", payload))
|
||||
if agent_socket_ready:
|
||||
_service_log_event(
|
||||
"vnc_start_emit agent_id={0} port={1} virtual_ip={2}".format(
|
||||
agent_id,
|
||||
vnc_port,
|
||||
host or "-",
|
||||
)
|
||||
)
|
||||
else:
|
||||
_service_log_event(
|
||||
"vnc_start_emit_failed agent_id={0} port={1}".format(
|
||||
agent_id,
|
||||
vnc_port,
|
||||
),
|
||||
level="WARNING",
|
||||
)
|
||||
if not agent_socket_ready:
|
||||
return {"error": "agent_socket_missing"}, 409
|
||||
|
||||
ws_scheme = "wss" if _is_secure(request) else "ws"
|
||||
ws_host = _infer_endpoint_host(request)
|
||||
ws_port = int(getattr(adapters.context, "vnc_ws_port", 4823))
|
||||
ws_url = f"{ws_scheme}://{ws_host}:{ws_port}{VNC_WS_PATH}"
|
||||
|
||||
_service_log_event(
|
||||
"vnc_session_ready agent_id={0} token={1} expires_at={2}".format(
|
||||
agent_id,
|
||||
vnc_session.token[:8],
|
||||
int(vnc_session.expires_at),
|
||||
)
|
||||
)
|
||||
|
||||
return (
|
||||
{
|
||||
"token": vnc_session.token,
|
||||
"ws_url": ws_url,
|
||||
"expires_at": int(vnc_session.expires_at),
|
||||
"virtual_ip": host,
|
||||
"tunnel_id": session_payload.get("tunnel_id"),
|
||||
"engine_virtual_ip": session_payload.get("engine_virtual_ip"),
|
||||
"allowed_ports": session_payload.get("allowed_ports"),
|
||||
"vnc_password": vnc_password,
|
||||
"vnc_port": vnc_port,
|
||||
},
|
||||
200,
|
||||
)
|
||||
|
||||
@blueprint.route("/api/vnc/establish", methods=["POST"])
|
||||
def vnc_establish():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
|
||||
user = _current_user(app) or {}
|
||||
operator_id = user.get("username") or None
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
agent_id = _normalize_text(body.get("agent_id"))
|
||||
|
||||
if not agent_id:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
|
||||
payload, status = _issue_session(agent_id, operator_id)
|
||||
return jsonify(payload), status
|
||||
|
||||
@blueprint.route("/api/vnc/session", methods=["POST"])
|
||||
def vnc_session():
|
||||
return vnc_establish()
|
||||
|
||||
@blueprint.route("/api/vnc/disconnect", methods=["POST"])
|
||||
def vnc_disconnect():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
|
||||
user = _current_user(app) or {}
|
||||
operator_id = user.get("username") or None
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
agent_id = _normalize_text(body.get("agent_id"))
|
||||
reason = _normalize_text(body.get("reason") or "operator_disconnect")
|
||||
|
||||
if not agent_id:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
|
||||
registry = ensure_vnc_proxy(adapters.context, logger=logger)
|
||||
revoked = 0
|
||||
if registry is not None:
|
||||
try:
|
||||
revoked = registry.revoke_agent(agent_id)
|
||||
except Exception:
|
||||
revoked = 0
|
||||
|
||||
emit_agent = getattr(adapters.context, "emit_agent_event", None)
|
||||
if callable(emit_agent):
|
||||
try:
|
||||
emit_agent(agent_id, "vnc_stop", {"agent_id": agent_id, "reason": reason})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_service_log_event(
|
||||
"vnc_disconnect agent_id={0} operator={1} reason={2} revoked={3}".format(
|
||||
agent_id,
|
||||
operator_id or "-",
|
||||
reason or "-",
|
||||
revoked,
|
||||
)
|
||||
)
|
||||
|
||||
return jsonify({"status": "disconnected", "revoked": revoked, "reason": reason}), 200
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
@@ -1,9 +1,8 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\RemoteDesktop\__init__.py
|
||||
# Description: Remote desktop services (Guacamole proxy + session management).
|
||||
# Description: Remote desktop services (VNC proxy + session management).
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""Remote desktop service helpers for the Borealis Engine runtime."""
|
||||
|
||||
|
||||
@@ -1,369 +0,0 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\RemoteDesktop\guacamole_proxy.py
|
||||
# Description: Guacamole tunnel proxy (WebSocket -> guacd) for RDP sessions.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""Guacamole WebSocket proxy that bridges browser tunnels to guacd."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import ssl
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from urllib.parse import parse_qs, urlsplit
|
||||
|
||||
import websockets
|
||||
|
||||
GUAC_WS_PATH = "/guacamole"
|
||||
_MAX_MESSAGE_SIZE = 100_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
class RdpSession:
|
||||
token: str
|
||||
agent_id: str
|
||||
host: str
|
||||
port: int
|
||||
protocol: str
|
||||
username: str
|
||||
password: str
|
||||
ignore_cert: bool
|
||||
created_at: float
|
||||
expires_at: float
|
||||
operator_id: Optional[str] = None
|
||||
domain: Optional[str] = None
|
||||
security: Optional[str] = None
|
||||
|
||||
|
||||
class RdpSessionRegistry:
|
||||
def __init__(self, ttl_seconds: int, logger: logging.Logger) -> None:
|
||||
self.ttl_seconds = max(30, int(ttl_seconds))
|
||||
self.logger = logger
|
||||
self._lock = threading.Lock()
|
||||
self._sessions: Dict[str, RdpSession] = {}
|
||||
|
||||
def _cleanup(self, now: Optional[float] = None) -> None:
|
||||
current = now if now is not None else time.time()
|
||||
expired = [token for token, session in self._sessions.items() if session.expires_at <= current]
|
||||
for token in expired:
|
||||
self._sessions.pop(token, None)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
agent_id: str,
|
||||
host: str,
|
||||
port: int,
|
||||
username: str,
|
||||
password: str,
|
||||
protocol: str = "rdp",
|
||||
ignore_cert: bool = True,
|
||||
operator_id: Optional[str] = None,
|
||||
domain: Optional[str] = None,
|
||||
security: Optional[str] = None,
|
||||
) -> RdpSession:
|
||||
token = uuid.uuid4().hex
|
||||
now = time.time()
|
||||
expires_at = now + self.ttl_seconds
|
||||
session = RdpSession(
|
||||
token=token,
|
||||
agent_id=agent_id,
|
||||
host=host,
|
||||
port=port,
|
||||
protocol=protocol,
|
||||
username=username,
|
||||
password=password,
|
||||
ignore_cert=ignore_cert,
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
operator_id=operator_id,
|
||||
domain=domain,
|
||||
security=security,
|
||||
)
|
||||
with self._lock:
|
||||
self._cleanup(now)
|
||||
self._sessions[token] = session
|
||||
return session
|
||||
|
||||
def consume(self, token: str) -> Optional[RdpSession]:
|
||||
if not token:
|
||||
return None
|
||||
with self._lock:
|
||||
self._cleanup()
|
||||
session = self._sessions.pop(token, None)
|
||||
return session
|
||||
|
||||
|
||||
class GuacamoleProxyServer:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
host: str,
|
||||
port: int,
|
||||
guacd_host: str,
|
||||
guacd_port: int,
|
||||
registry: RdpSessionRegistry,
|
||||
logger: logging.Logger,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
) -> None:
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.guacd_host = guacd_host
|
||||
self.guacd_port = guacd_port
|
||||
self.registry = registry
|
||||
self.logger = logger
|
||||
self.ssl_context = ssl_context
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._ready = threading.Event()
|
||||
self._failed = threading.Event()
|
||||
|
||||
def ensure_started(self, timeout: float = 3.0) -> bool:
|
||||
if self._thread and self._thread.is_alive():
|
||||
return not self._failed.is_set()
|
||||
self._failed.clear()
|
||||
self._ready.clear()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
self._ready.wait(timeout)
|
||||
return not self._failed.is_set()
|
||||
|
||||
def _run(self) -> None:
|
||||
try:
|
||||
asyncio.run(self._serve())
|
||||
except Exception as exc:
|
||||
self._failed.set()
|
||||
self.logger.error("Guacamole proxy server failed: %s", exc)
|
||||
self._ready.set()
|
||||
|
||||
async def _serve(self) -> None:
|
||||
self.logger.info(
|
||||
"Starting Guacamole proxy on %s:%s (guacd %s:%s)",
|
||||
self.host,
|
||||
self.port,
|
||||
self.guacd_host,
|
||||
self.guacd_port,
|
||||
)
|
||||
try:
|
||||
server = await websockets.serve(
|
||||
self._handle_client,
|
||||
self.host,
|
||||
self.port,
|
||||
ssl=self.ssl_context,
|
||||
max_size=_MAX_MESSAGE_SIZE,
|
||||
ping_interval=20,
|
||||
ping_timeout=20,
|
||||
)
|
||||
except Exception:
|
||||
self._failed.set()
|
||||
self._ready.set()
|
||||
raise
|
||||
self._ready.set()
|
||||
await server.wait_closed()
|
||||
|
||||
async def _handle_client(self, websocket, path: str) -> None:
|
||||
parsed = urlsplit(path)
|
||||
if parsed.path != GUAC_WS_PATH:
|
||||
await websocket.close(code=1008, reason="invalid_path")
|
||||
return
|
||||
query = parse_qs(parsed.query or "")
|
||||
token = (query.get("token") or [""])[0]
|
||||
session = self.registry.consume(token)
|
||||
if not session:
|
||||
await websocket.close(code=1008, reason="invalid_session")
|
||||
return
|
||||
|
||||
logger = self.logger.getChild("session")
|
||||
logger.info("Guacamole session start agent_id=%s protocol=%s", session.agent_id, session.protocol)
|
||||
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(self.guacd_host, self.guacd_port)
|
||||
except Exception as exc:
|
||||
logger.warning("guacd connect failed: %s", exc)
|
||||
await websocket.close(code=1011, reason="guacd_unavailable")
|
||||
return
|
||||
|
||||
try:
|
||||
await self._perform_handshake(reader, writer, session)
|
||||
except Exception as exc:
|
||||
logger.warning("guacd handshake failed: %s", exc)
|
||||
try:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
await websocket.close(code=1011, reason="handshake_failed")
|
||||
return
|
||||
|
||||
async def _ws_to_guacd() -> None:
|
||||
try:
|
||||
async for message in websocket:
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, str):
|
||||
data = message.encode("utf-8")
|
||||
else:
|
||||
data = bytes(message)
|
||||
writer.write(data)
|
||||
await writer.drain()
|
||||
finally:
|
||||
try:
|
||||
writer.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _guacd_to_ws() -> None:
|
||||
try:
|
||||
while True:
|
||||
data = await reader.read(8192)
|
||||
if not data:
|
||||
break
|
||||
await websocket.send(data.decode("utf-8", errors="ignore"))
|
||||
finally:
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.wait(
|
||||
[asyncio.create_task(_ws_to_guacd()), asyncio.create_task(_guacd_to_ws())],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
logger.info("Guacamole session ended agent_id=%s", session.agent_id)
|
||||
|
||||
async def _perform_handshake(self, reader, writer, session: RdpSession) -> None:
|
||||
writer.write(_encode_instruction("select", session.protocol))
|
||||
await writer.drain()
|
||||
|
||||
buffer = b""
|
||||
args = None
|
||||
deadline = time.time() + 8
|
||||
|
||||
while time.time() < deadline:
|
||||
parts, buffer = await _read_instruction(reader, buffer)
|
||||
if not parts:
|
||||
continue
|
||||
op = parts[0]
|
||||
if op == "args":
|
||||
args = parts[1:]
|
||||
break
|
||||
if op == "error":
|
||||
raise RuntimeError("guacd_error:" + " ".join(parts[1:]))
|
||||
if not args:
|
||||
raise RuntimeError("guacd_args_timeout")
|
||||
|
||||
params = {
|
||||
"hostname": session.host,
|
||||
"port": str(session.port),
|
||||
"username": session.username or "",
|
||||
"password": session.password or "",
|
||||
}
|
||||
if session.domain:
|
||||
params["domain"] = session.domain
|
||||
if session.security:
|
||||
params["security"] = session.security
|
||||
if session.ignore_cert:
|
||||
params["ignore-cert"] = "true"
|
||||
|
||||
values = [params.get(name, "") for name in args]
|
||||
writer.write(_encode_instruction("connect", *values))
|
||||
await writer.drain()
|
||||
|
||||
|
||||
def _encode_instruction(*elements: str) -> bytes:
|
||||
parts = []
|
||||
for element in elements:
|
||||
text = "" if element is None else str(element)
|
||||
parts.append(f"{len(text)}.{text}".encode("utf-8"))
|
||||
return b",".join(parts) + b";"
|
||||
|
||||
|
||||
def _parse_instruction(raw: bytes) -> Tuple[str, ...]:
|
||||
parts = []
|
||||
idx = 0
|
||||
length = len(raw)
|
||||
while idx < length:
|
||||
dot = raw.find(b".", idx)
|
||||
if dot < 0:
|
||||
break
|
||||
try:
|
||||
element_len = int(raw[idx:dot].decode("ascii") or "0")
|
||||
except Exception:
|
||||
break
|
||||
start = dot + 1
|
||||
end = start + element_len
|
||||
if end > length:
|
||||
break
|
||||
parts.append(raw[start:end].decode("utf-8", errors="ignore"))
|
||||
idx = end
|
||||
if idx < length and raw[idx:idx + 1] == b",":
|
||||
idx += 1
|
||||
return tuple(parts)
|
||||
|
||||
|
||||
async def _read_instruction(reader, buffer: bytes) -> Tuple[Tuple[str, ...], bytes]:
|
||||
while b";" not in buffer:
|
||||
chunk = await reader.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
buffer += chunk
|
||||
if b";" not in buffer:
|
||||
return tuple(), buffer
|
||||
instruction, remainder = buffer.split(b";", 1)
|
||||
if not instruction:
|
||||
return tuple(), remainder
|
||||
return _parse_instruction(instruction), remainder
|
||||
|
||||
|
||||
def _build_ssl_context(cert_path: Optional[str], key_path: Optional[str]) -> Optional[ssl.SSLContext]:
|
||||
if not cert_path or not key_path:
|
||||
return None
|
||||
try:
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.load_cert_chain(certfile=cert_path, keyfile=key_path)
|
||||
return context
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def ensure_guacamole_proxy(context: Any, *, logger: Optional[logging.Logger] = None) -> Optional[RdpSessionRegistry]:
|
||||
if logger is None:
|
||||
logger = context.logger if hasattr(context, "logger") else logging.getLogger("borealis.engine.rdp")
|
||||
|
||||
registry = getattr(context, "rdp_registry", None)
|
||||
if registry is None:
|
||||
ttl = int(getattr(context, "rdp_session_ttl_seconds", 120))
|
||||
registry = RdpSessionRegistry(ttl_seconds=ttl, logger=logger)
|
||||
setattr(context, "rdp_registry", registry)
|
||||
|
||||
proxy = getattr(context, "rdp_proxy", None)
|
||||
if proxy is None:
|
||||
cert_path = getattr(context, "tls_bundle_path", None) or getattr(context, "tls_cert_path", None)
|
||||
ssl_context = _build_ssl_context(
|
||||
cert_path,
|
||||
getattr(context, "tls_key_path", None),
|
||||
)
|
||||
proxy = GuacamoleProxyServer(
|
||||
host=str(getattr(context, "rdp_ws_host", "0.0.0.0")),
|
||||
port=int(getattr(context, "rdp_ws_port", 4823)),
|
||||
guacd_host=str(getattr(context, "guacd_host", "127.0.0.1")),
|
||||
guacd_port=int(getattr(context, "guacd_port", 4822)),
|
||||
registry=registry,
|
||||
logger=logger.getChild("guacamole_proxy"),
|
||||
ssl_context=ssl_context,
|
||||
)
|
||||
setattr(context, "rdp_proxy", proxy)
|
||||
|
||||
if not proxy.ensure_started():
|
||||
logger.error("Guacamole proxy failed to start; RDP sessions unavailable.")
|
||||
return None
|
||||
return registry
|
||||
|
||||
|
||||
__all__ = ["GUAC_WS_PATH", "RdpSessionRegistry", "GuacamoleProxyServer", "ensure_guacamole_proxy"]
|
||||
285
Data/Engine/services/RemoteDesktop/vnc_proxy.py
Normal file
285
Data/Engine/services/RemoteDesktop/vnc_proxy.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\RemoteDesktop\vnc_proxy.py
|
||||
# Description: VNC tunnel proxy (WebSocket -> TCP) for noVNC sessions.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""VNC WebSocket proxy that bridges browser sessions to agent VNC servers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import ssl
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
from urllib.parse import parse_qs, urlsplit
|
||||
|
||||
import websockets
|
||||
|
||||
VNC_WS_PATH = "/vnc"
|
||||
_MAX_MESSAGE_SIZE = 100_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
class VncSession:
|
||||
token: str
|
||||
agent_id: str
|
||||
host: str
|
||||
port: int
|
||||
created_at: float
|
||||
expires_at: float
|
||||
operator_id: Optional[str] = None
|
||||
|
||||
|
||||
class VncSessionRegistry:
|
||||
def __init__(self, ttl_seconds: int, logger: logging.Logger) -> None:
|
||||
self.ttl_seconds = max(30, int(ttl_seconds))
|
||||
self.logger = logger
|
||||
self._lock = threading.Lock()
|
||||
self._sessions: Dict[str, VncSession] = {}
|
||||
|
||||
def _cleanup(self, now: Optional[float] = None) -> None:
|
||||
current = now if now is not None else time.time()
|
||||
expired = [token for token, session in self._sessions.items() if session.expires_at <= current]
|
||||
for token in expired:
|
||||
self._sessions.pop(token, None)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
agent_id: str,
|
||||
host: str,
|
||||
port: int,
|
||||
operator_id: Optional[str] = None,
|
||||
) -> VncSession:
|
||||
token = uuid.uuid4().hex
|
||||
now = time.time()
|
||||
expires_at = now + self.ttl_seconds
|
||||
session = VncSession(
|
||||
token=token,
|
||||
agent_id=agent_id,
|
||||
host=host,
|
||||
port=port,
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
operator_id=operator_id,
|
||||
)
|
||||
with self._lock:
|
||||
self._cleanup(now)
|
||||
self._sessions[token] = session
|
||||
return session
|
||||
|
||||
def consume(self, token: str) -> Optional[VncSession]:
|
||||
if not token:
|
||||
return None
|
||||
with self._lock:
|
||||
self._cleanup()
|
||||
session = self._sessions.pop(token, None)
|
||||
return session
|
||||
|
||||
def revoke_agent(self, agent_id: str) -> int:
|
||||
if not agent_id:
|
||||
return 0
|
||||
removed = 0
|
||||
with self._lock:
|
||||
self._cleanup()
|
||||
tokens = [token for token, session in self._sessions.items() if session.agent_id == agent_id]
|
||||
for token in tokens:
|
||||
if self._sessions.pop(token, None):
|
||||
removed += 1
|
||||
return removed
|
||||
|
||||
|
||||
class VncProxyServer:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
host: str,
|
||||
port: int,
|
||||
registry: VncSessionRegistry,
|
||||
logger: logging.Logger,
|
||||
emit_agent_event: Optional[Callable[[str, str, Any], bool]] = None,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
) -> None:
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.registry = registry
|
||||
self.logger = logger
|
||||
self._emit_agent_event = emit_agent_event
|
||||
self.ssl_context = ssl_context
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._ready = threading.Event()
|
||||
self._failed = threading.Event()
|
||||
|
||||
def ensure_started(self, timeout: float = 3.0) -> bool:
|
||||
if self._thread and self._thread.is_alive():
|
||||
return not self._failed.is_set()
|
||||
self._failed.clear()
|
||||
self._ready.clear()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
self._ready.wait(timeout)
|
||||
return not self._failed.is_set()
|
||||
|
||||
def _run(self) -> None:
|
||||
try:
|
||||
asyncio.run(self._serve())
|
||||
except Exception as exc:
|
||||
self._failed.set()
|
||||
self.logger.error("VNC proxy server failed: %s", exc)
|
||||
self._ready.set()
|
||||
|
||||
async def _serve(self) -> None:
|
||||
self.logger.info("Starting VNC proxy on %s:%s", self.host, self.port)
|
||||
try:
|
||||
server = await websockets.serve(
|
||||
self._handle_client,
|
||||
self.host,
|
||||
self.port,
|
||||
ssl=self.ssl_context,
|
||||
max_size=_MAX_MESSAGE_SIZE,
|
||||
ping_interval=20,
|
||||
ping_timeout=20,
|
||||
)
|
||||
except Exception:
|
||||
self._failed.set()
|
||||
self._ready.set()
|
||||
raise
|
||||
self._ready.set()
|
||||
await server.wait_closed()
|
||||
|
||||
async def _handle_client(self, websocket, path: str) -> None:
|
||||
parsed = urlsplit(path)
|
||||
if parsed.path != VNC_WS_PATH:
|
||||
await websocket.close(code=1008, reason="invalid_path")
|
||||
return
|
||||
query = parse_qs(parsed.query or "")
|
||||
token = (query.get("token") or [""])[0]
|
||||
session = self.registry.consume(token)
|
||||
if not session:
|
||||
await websocket.close(code=1008, reason="invalid_session")
|
||||
return
|
||||
|
||||
logger = self.logger.getChild("session")
|
||||
logger.info("VNC session start agent_id=%s", session.agent_id)
|
||||
|
||||
try:
|
||||
try:
|
||||
reader, writer = await self._connect_vnc(session.host, session.port)
|
||||
except Exception as exc:
|
||||
logger.warning("VNC connect failed: %s", exc)
|
||||
await websocket.close(code=1011, reason="vnc_unavailable")
|
||||
return
|
||||
|
||||
async def _ws_to_tcp() -> None:
|
||||
try:
|
||||
async for message in websocket:
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, str):
|
||||
data = message.encode("utf-8")
|
||||
else:
|
||||
data = bytes(message)
|
||||
writer.write(data)
|
||||
await writer.drain()
|
||||
finally:
|
||||
try:
|
||||
writer.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _tcp_to_ws() -> None:
|
||||
try:
|
||||
while True:
|
||||
data = await reader.read(8192)
|
||||
if not data:
|
||||
break
|
||||
await websocket.send(data)
|
||||
finally:
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.wait(
|
||||
[asyncio.create_task(_ws_to_tcp()), asyncio.create_task(_tcp_to_ws())],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
finally:
|
||||
logger.info("VNC session ended agent_id=%s", session.agent_id)
|
||||
self._notify_agent_session_end(session, reason="vnc_session_end")
|
||||
|
||||
async def _connect_vnc(self, host: str, port: int) -> Tuple[Any, Any]:
|
||||
attempts = 5
|
||||
delay = 0.5
|
||||
last_exc: Optional[Exception] = None
|
||||
for attempt in range(attempts):
|
||||
try:
|
||||
return await asyncio.open_connection(host, port)
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
if attempt < attempts - 1:
|
||||
await asyncio.sleep(delay)
|
||||
if last_exc:
|
||||
raise last_exc
|
||||
raise RuntimeError("vnc_connect_failed")
|
||||
|
||||
def _notify_agent_session_end(self, session: VncSession, reason: str) -> None:
|
||||
if not self._emit_agent_event:
|
||||
return
|
||||
payload = {"agent_id": session.agent_id, "reason": reason}
|
||||
try:
|
||||
self._emit_agent_event(session.agent_id, "vnc_stop", payload)
|
||||
except Exception:
|
||||
self.logger.debug("Failed to emit vnc_stop for agent_id=%s", session.agent_id, exc_info=True)
|
||||
|
||||
|
||||
def _build_ssl_context(cert_path: Optional[str], key_path: Optional[str]) -> Optional[ssl.SSLContext]:
|
||||
if not cert_path or not key_path:
|
||||
return None
|
||||
try:
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.load_cert_chain(certfile=cert_path, keyfile=key_path)
|
||||
return context
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def ensure_vnc_proxy(context: Any, *, logger: Optional[logging.Logger] = None) -> Optional[VncSessionRegistry]:
|
||||
if logger is None:
|
||||
logger = context.logger if hasattr(context, "logger") else logging.getLogger("borealis.engine.vnc")
|
||||
|
||||
registry = getattr(context, "vnc_registry", None)
|
||||
if registry is None:
|
||||
ttl = int(getattr(context, "vnc_session_ttl_seconds", 120))
|
||||
registry = VncSessionRegistry(ttl_seconds=ttl, logger=logger)
|
||||
setattr(context, "vnc_registry", registry)
|
||||
|
||||
proxy = getattr(context, "vnc_proxy", None)
|
||||
if proxy is None:
|
||||
cert_path = getattr(context, "tls_bundle_path", None) or getattr(context, "tls_cert_path", None)
|
||||
ssl_context = _build_ssl_context(
|
||||
cert_path,
|
||||
getattr(context, "tls_key_path", None),
|
||||
)
|
||||
proxy = VncProxyServer(
|
||||
host=str(getattr(context, "vnc_ws_host", "0.0.0.0")),
|
||||
port=int(getattr(context, "vnc_ws_port", 4823)),
|
||||
registry=registry,
|
||||
logger=logger.getChild("vnc_proxy"),
|
||||
emit_agent_event=getattr(context, "emit_agent_event", None),
|
||||
ssl_context=ssl_context,
|
||||
)
|
||||
setattr(context, "vnc_proxy", proxy)
|
||||
|
||||
if not proxy.ensure_started():
|
||||
logger.error("VNC proxy failed to start; VNC sessions unavailable.")
|
||||
return None
|
||||
return registry
|
||||
|
||||
|
||||
__all__ = ["VNC_WS_PATH", "VncSessionRegistry", "VncProxyServer", "ensure_vnc_proxy"]
|
||||
@@ -12,6 +12,7 @@ from __future__ import annotations
|
||||
import base64
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
@@ -22,6 +23,13 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||||
from .wireguard_server import WireGuardServerManager
|
||||
|
||||
|
||||
def _env_flag(name: str, *, default: bool) -> bool:
|
||||
value = os.environ.get(name)
|
||||
if value is None:
|
||||
return default
|
||||
return str(value).strip().lower() not in ("0", "false", "no", "off")
|
||||
|
||||
|
||||
@dataclass
|
||||
class VpnSession:
|
||||
tunnel_id: str
|
||||
@@ -62,14 +70,17 @@ class VpnTunnelService:
|
||||
self.logger = context.logger.getChild("vpn_tunnel")
|
||||
self.activity_logger = self.wg.logger.getChild("device_activity")
|
||||
self.idle_seconds = max(60, int(idle_seconds))
|
||||
self.persistent = _env_flag("BOREALIS_WIREGUARD_PERSISTENT", default=True)
|
||||
self._lock = threading.Lock()
|
||||
self._sessions_by_agent: Dict[str, VpnSession] = {}
|
||||
self._sessions_by_tunnel: Dict[str, VpnSession] = {}
|
||||
self._engine_ip = ipaddress.ip_interface(context.wireguard_engine_virtual_ip)
|
||||
self._peer_network = ipaddress.ip_network(context.wireguard_peer_network, strict=False)
|
||||
self._cleanup_listener()
|
||||
self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True)
|
||||
self._idle_thread.start()
|
||||
self._idle_thread: Optional[threading.Thread] = None
|
||||
if not self.persistent:
|
||||
self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True)
|
||||
self._idle_thread.start()
|
||||
|
||||
def _idle_loop(self) -> None:
|
||||
while True:
|
||||
@@ -90,7 +101,7 @@ class VpnTunnelService:
|
||||
self.idle_seconds,
|
||||
)
|
||||
)
|
||||
self.disconnect(session.agent_id, reason="idle_timeout")
|
||||
self.disconnect(session.agent_id, reason="idle_timeout", force=True)
|
||||
|
||||
def _allocate_virtual_ip(self, agent_id: str) -> str:
|
||||
existing = self._sessions_by_agent.get(agent_id)
|
||||
@@ -226,12 +237,12 @@ class VpnTunnelService:
|
||||
self.logger.debug("Failed to write vpn_tunnel service log entry", exc_info=True)
|
||||
|
||||
def _cleanup_listener(self) -> None:
|
||||
try:
|
||||
self.wg.stop_listener(ignore_missing=True)
|
||||
self._service_log_event("vpn_listener_cleanup reason=startup")
|
||||
except Exception:
|
||||
self.logger.debug("Failed to clean up WireGuard listener on startup.", exc_info=True)
|
||||
self._service_log_event("vpn_listener_cleanup_failed reason=startup", level="WARNING")
|
||||
self._service_log_event("vpn_listener_cleanup_skipped reason=startup")
|
||||
|
||||
def _is_soft_disconnect(self, reason: Optional[str]) -> bool:
|
||||
if not reason:
|
||||
return False
|
||||
return str(reason).lower() in ("operator_disconnect", "component_unmount")
|
||||
|
||||
def _refresh_listener(self) -> None:
|
||||
peers: List[Mapping[str, object]] = []
|
||||
@@ -432,15 +443,60 @@ class VpnTunnelService:
|
||||
except Exception:
|
||||
self.logger.debug("vpn_tunnel_activity emit failed for agent_id=%s", agent_id, exc_info=True)
|
||||
|
||||
def disconnect(self, agent_id: str, reason: str = "operator_stop") -> bool:
|
||||
def disconnect(
|
||||
self,
|
||||
agent_id: str,
|
||||
reason: str = "operator_stop",
|
||||
*,
|
||||
operator_id: Optional[str] = None,
|
||||
force: bool = False,
|
||||
) -> bool:
|
||||
with self._lock:
|
||||
session = self._sessions_by_agent.pop(agent_id, None)
|
||||
session = self._sessions_by_agent.get(agent_id)
|
||||
if not session:
|
||||
self._service_log_event(
|
||||
"vpn_tunnel_disconnect_missing agent_id={0} reason={1}".format(agent_id or "-", reason or "-")
|
||||
)
|
||||
return False
|
||||
self._sessions_by_tunnel.pop(session.tunnel_id, None)
|
||||
if self.persistent and not force:
|
||||
if operator_id:
|
||||
try:
|
||||
session.operator_ids.discard(operator_id)
|
||||
except Exception:
|
||||
pass
|
||||
session.last_activity = time.time()
|
||||
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
|
||||
self._service_log_event(
|
||||
"vpn_tunnel_keepalive agent_id={0} tunnel_id={1} reason={2} operators={3}".format(
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
reason or "-",
|
||||
operator_text,
|
||||
)
|
||||
)
|
||||
return True
|
||||
if not force and self._is_soft_disconnect(reason):
|
||||
if operator_id:
|
||||
try:
|
||||
session.operator_ids.discard(operator_id)
|
||||
except Exception:
|
||||
pass
|
||||
session.last_activity = time.time()
|
||||
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
|
||||
self._service_log_event(
|
||||
"vpn_tunnel_keepalive agent_id={0} tunnel_id={1} reason={2} operators={3}".format(
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
reason or "-",
|
||||
operator_text,
|
||||
)
|
||||
)
|
||||
return True
|
||||
session = self._sessions_by_agent.pop(agent_id, None)
|
||||
if session:
|
||||
self._sessions_by_tunnel.pop(session.tunnel_id, None)
|
||||
else:
|
||||
return False
|
||||
|
||||
try:
|
||||
self.wg.remove_firewall_rules(session.firewall_rules)
|
||||
@@ -461,7 +517,14 @@ class VpnTunnelService:
|
||||
self._log_device_activity(session, event="stop", reason=reason)
|
||||
return True
|
||||
|
||||
def disconnect_by_tunnel(self, tunnel_id: str, reason: str = "operator_stop") -> bool:
|
||||
def disconnect_by_tunnel(
|
||||
self,
|
||||
tunnel_id: str,
|
||||
reason: str = "operator_stop",
|
||||
*,
|
||||
operator_id: Optional[str] = None,
|
||||
force: bool = False,
|
||||
) -> bool:
|
||||
with self._lock:
|
||||
session = self._sessions_by_tunnel.get(tunnel_id)
|
||||
if not session:
|
||||
@@ -469,7 +532,7 @@ class VpnTunnelService:
|
||||
"vpn_tunnel_disconnect_missing tunnel_id={0} reason={1}".format(tunnel_id or "-", reason or "-")
|
||||
)
|
||||
return False
|
||||
return self.disconnect(session.agent_id, reason=reason)
|
||||
return self.disconnect(session.agent_id, reason=reason, operator_id=operator_id, force=force)
|
||||
|
||||
def _emit_start(self, payload: Mapping[str, Any]) -> None:
|
||||
if not self.socketio:
|
||||
@@ -704,7 +767,7 @@ class VpnTunnelService:
|
||||
"server_public_key": self.wg.server_public_key,
|
||||
"client_public_key": session.client_public_key,
|
||||
"client_private_key": session.client_private_key,
|
||||
"idle_seconds": self.idle_seconds,
|
||||
"idle_seconds": 0 if self.persistent else self.idle_seconds,
|
||||
"allowed_ports": list(session.allowed_ports),
|
||||
"connected_operators": len([o for o in session.operator_ids if o]),
|
||||
}
|
||||
@@ -729,6 +792,6 @@ class VpnTunnelService:
|
||||
"last_activity_iso": self._ts_to_iso(session.last_activity),
|
||||
"expires_at": int(session.expires_at),
|
||||
"expires_at_iso": self._ts_to_iso(session.expires_at),
|
||||
"idle_seconds": self.idle_seconds,
|
||||
"idle_seconds": 0 if self.persistent else self.idle_seconds,
|
||||
"status": "up",
|
||||
}
|
||||
|
||||
@@ -164,6 +164,25 @@ class WireGuardServerManager:
|
||||
return match.group(1).upper()
|
||||
return None
|
||||
|
||||
def _service_exists(self) -> bool:
|
||||
code, _, _ = self._run_command(["sc.exe", "query", self._service_id()])
|
||||
return code == 0
|
||||
|
||||
def _stop_service(self, *, timeout: int = 20) -> bool:
|
||||
service_id = self._service_id()
|
||||
state = self._query_service_state()
|
||||
if not state:
|
||||
return False
|
||||
if state == "STOPPED":
|
||||
return True
|
||||
self._run_command(["sc.exe", "stop", service_id])
|
||||
for _ in range(max(1, timeout)):
|
||||
time.sleep(1)
|
||||
state = self._query_service_state()
|
||||
if state == "STOPPED":
|
||||
return True
|
||||
return False
|
||||
|
||||
def _ensure_service_display_name(self) -> None:
|
||||
if not self._service_display_name:
|
||||
return
|
||||
@@ -172,9 +191,9 @@ class WireGuardServerManager:
|
||||
if code != 0 and err:
|
||||
self.logger.warning("Failed to set WireGuard service display name: %s", err)
|
||||
|
||||
def _ensure_service_running(self) -> None:
|
||||
def _ensure_service_running(self, *, timeout: int = 20) -> None:
|
||||
service_id = self._service_id()
|
||||
for _ in range(6):
|
||||
for _ in range(max(1, timeout)):
|
||||
state = self._query_service_state()
|
||||
if state == "RUNNING":
|
||||
return
|
||||
@@ -183,8 +202,20 @@ class WireGuardServerManager:
|
||||
if code != 0:
|
||||
self.logger.error("Failed to start WireGuard tunnel service %s err=%s", service_id, err)
|
||||
break
|
||||
if state in ("START_PENDING", "STOP_PENDING"):
|
||||
time.sleep(1)
|
||||
continue
|
||||
time.sleep(1)
|
||||
state = self._query_service_state()
|
||||
if state == "START_PENDING":
|
||||
self.logger.warning("WireGuard tunnel service still START_PENDING; attempting restart.")
|
||||
self._stop_service(timeout=10)
|
||||
self._run_command(["sc.exe", "start", service_id])
|
||||
for _ in range(10):
|
||||
time.sleep(1)
|
||||
if self._query_service_state() == "RUNNING":
|
||||
return
|
||||
state = self._query_service_state()
|
||||
raise RuntimeError(f"WireGuard tunnel service {service_id} failed to start (state={state})")
|
||||
|
||||
def _normalise_allowed_ports(
|
||||
@@ -329,6 +360,7 @@ class WireGuardServerManager:
|
||||
for idx, rule in enumerate(rules):
|
||||
name = f"Borealis-WG-Agent-{peer.get('agent_id','')}-{idx}"
|
||||
protocol = str(rule.get("protocol") or "TCP").upper()
|
||||
self._run_command(["netsh", "advfirewall", "firewall", "delete", "rule", f"name={name}"])
|
||||
args = [
|
||||
"netsh",
|
||||
"advfirewall",
|
||||
@@ -374,8 +406,12 @@ class WireGuardServerManager:
|
||||
config_path.write_text(rendered, encoding="utf-8")
|
||||
self.logger.info("Rendered WireGuard config to %s", config_path)
|
||||
|
||||
# Ensure old service is removed before re-installing.
|
||||
self.stop_listener()
|
||||
if self._service_exists():
|
||||
if not self._stop_service(timeout=20):
|
||||
self.logger.warning("WireGuard tunnel service did not stop cleanly before restart.")
|
||||
self._ensure_service_display_name()
|
||||
self._ensure_service_running(timeout=25)
|
||||
return
|
||||
|
||||
args = [self._wireguard_exe, "/installtunnelservice", str(config_path)]
|
||||
code, out, err = self._run_command(args)
|
||||
@@ -384,21 +420,22 @@ class WireGuardServerManager:
|
||||
raise RuntimeError(f"WireGuard installtunnelservice failed: {err}")
|
||||
self.logger.info("WireGuard listener installed (service=%s)", config_path.stem)
|
||||
self._ensure_service_display_name()
|
||||
self._ensure_service_running()
|
||||
self._ensure_service_running(timeout=25)
|
||||
|
||||
def stop_listener(self, *, ignore_missing: bool = False) -> None:
|
||||
"""Stop and remove the WireGuard tunnel service."""
|
||||
"""Stop the WireGuard tunnel service (leave installed for reuse)."""
|
||||
|
||||
args = [self._wireguard_exe, "/uninstalltunnelservice", self._service_name]
|
||||
code, out, err = self._run_command(args)
|
||||
if code != 0:
|
||||
err_text = " ".join([out or "", err or ""]).strip().lower()
|
||||
if ignore_missing and ("does not exist" in err_text or "not exist" in err_text):
|
||||
if not self._service_exists():
|
||||
if ignore_missing:
|
||||
self.logger.info("WireGuard tunnel service already absent")
|
||||
return
|
||||
self.logger.warning("Failed to uninstall WireGuard tunnel service code=%s err=%s", code, err)
|
||||
else:
|
||||
self.logger.info("WireGuard tunnel service removed")
|
||||
self.logger.warning("WireGuard tunnel service not found during stop.")
|
||||
return
|
||||
|
||||
if not self._stop_service(timeout=20):
|
||||
self.logger.warning("WireGuard tunnel service did not stop cleanly.")
|
||||
return
|
||||
self.logger.info("WireGuard tunnel service stopped")
|
||||
|
||||
def build_firewall_rules(
|
||||
self,
|
||||
|
||||
@@ -187,7 +187,10 @@ class VpnShellBridge:
|
||||
existing.close()
|
||||
status = service.status(agent_id)
|
||||
if not status:
|
||||
return None
|
||||
try:
|
||||
status = service.connect(agent_id=agent_id, operator_id=None, endpoint_host=None)
|
||||
except Exception:
|
||||
return None
|
||||
host = str(status.get("virtual_ip") or "").split("/")[0]
|
||||
port = int(self.context.wireguard_shell_port)
|
||||
tcp = None
|
||||
|
||||
Reference in New Issue
Block a user