Files
Borealis-Github-Replica/Data/Engine/services/API/devices/tunnel.py

306 lines
12 KiB
Python

# ======================================================
# Data\Engine\services\API\devices\tunnel.py
# Description: WireGuard VPN tunnel API (connect/status/disconnect).
#
# API Endpoints (if applicable):
# - POST /api/tunnel/connect (Token Authenticated) - Issues VPN session material for an agent.
# - GET /api/tunnel/status (Token Authenticated) - Returns VPN status for an agent.
# - GET /api/tunnel/active (Token Authenticated) - Lists active VPN tunnel sessions.
# - DELETE /api/tunnel/disconnect (Token Authenticated) - Tears down VPN session for an agent.
# ======================================================
"""WireGuard VPN tunnel API (Engine side)."""
from __future__ import annotations
import os
from urllib.parse import urlsplit
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from flask import Blueprint, jsonify, request, session
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from ...VPN import WireGuardServerConfig, WireGuardServerManager, VpnTunnelService
if False: # pragma: no cover - import cycle hint for type checkers
from .. import EngineServiceAdapters
def _current_user(app) -> Optional[Dict[str, str]]:
username = session.get("username")
role = session.get("role") or "User"
if username:
return {"username": username, "role": role}
token = None
auth_header = request.headers.get("Authorization") or ""
if auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1].strip()
if not token:
token = request.cookies.get("borealis_auth")
if not token:
return None
try:
serializer = URLSafeTimedSerializer(app.secret_key or "borealis-dev-secret", salt="borealis-auth")
token_ttl = int(os.environ.get("BOREALIS_TOKEN_TTL_SECONDS", 60 * 60 * 24 * 30))
data = serializer.loads(token, max_age=token_ttl)
username = data.get("u")
role = data.get("r") or "User"
if username:
return {"username": username, "role": role}
except (BadSignature, SignatureExpired, Exception):
return None
return None
def _require_login(app) -> Optional[Tuple[Dict[str, Any], int]]:
user = _current_user(app)
if not user:
return {"error": "unauthorized"}, 401
return None
def _get_tunnel_service(adapters: "EngineServiceAdapters") -> VpnTunnelService:
service = getattr(adapters.context, "vpn_tunnel_service", None) or getattr(adapters, "_vpn_tunnel_service", None)
if service is None:
manager = getattr(adapters.context, "wireguard_server_manager", None)
if manager is None:
try:
manager = WireGuardServerManager(
WireGuardServerConfig(
port=adapters.context.wireguard_port,
engine_virtual_ip=adapters.context.wireguard_engine_virtual_ip,
peer_network=adapters.context.wireguard_peer_network,
private_key_path=Path(adapters.context.wireguard_server_private_key_path),
public_key_path=Path(adapters.context.wireguard_server_public_key_path),
acl_allowlist_windows=tuple(adapters.context.wireguard_acl_allowlist_windows),
log_path=Path(adapters.context.vpn_tunnel_log_path),
)
)
adapters.context.wireguard_server_manager = manager
except Exception as exc:
adapters.context.logger.error("Failed to initialize WireGuard server manager on demand.", exc_info=True)
raise RuntimeError("wireguard_manager_unavailable") from exc
service = VpnTunnelService(
context=adapters.context,
wireguard_manager=manager,
db_conn_factory=adapters.db_conn_factory,
socketio=getattr(adapters.context, "socketio", None),
service_log=adapters.service_log,
signer=getattr(adapters, "script_signer", None),
)
setattr(adapters, "_vpn_tunnel_service", service)
setattr(adapters.context, "vpn_tunnel_service", service)
return service
def _normalize_text(value: Any) -> str:
if value is None:
return ""
try:
return str(value).strip()
except Exception:
return ""
def _infer_endpoint_host(req) -> str:
forwarded = (req.headers.get("X-Forwarded-Host") or req.headers.get("X-Original-Host") or "").strip()
host = forwarded.split(",")[0].strip() if forwarded else (req.host or "").strip()
if not host:
return ""
try:
parsed = urlsplit(f"//{host}")
if parsed.hostname:
return parsed.hostname
except Exception:
return host
return host
def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
blueprint = Blueprint("vpn_tunnel", __name__)
logger = adapters.context.logger.getChild("vpn_tunnel.api")
service_log = adapters.service_log
def _service_log_event(message: str, *, level: str = "INFO") -> None:
if not callable(service_log):
return
try:
service_log("VPN_Tunnel/tunnel", message, level=level)
except Exception:
logger.debug("vpn_tunnel service log write failed", exc_info=True)
def _request_remote() -> str:
forwarded = (request.headers.get("X-Forwarded-For") or "").strip()
if forwarded:
return forwarded.split(",")[0].strip()
return (request.remote_addr or "").strip()
@blueprint.route("/api/tunnel/connect", methods=["POST"])
def connect_tunnel():
requirement = _require_login(app)
if requirement:
payload, status = requirement
return jsonify(payload), status
user = _current_user(app) or {}
operator_id = user.get("username") or None
body = request.get_json(silent=True) or {}
agent_id = _normalize_text(body.get("agent_id"))
if not agent_id:
return jsonify({"error": "agent_id_required"}), 400
try:
tunnel_service = _get_tunnel_service(adapters)
endpoint_host = _infer_endpoint_host(request)
_service_log_event(
"vpn_api_connect_request agent_id={0} operator={1} endpoint_host={2} remote={3}".format(
agent_id,
operator_id or "-",
endpoint_host or "-",
_request_remote() or "-",
)
)
payload = tunnel_service.connect(
agent_id=agent_id,
operator_id=operator_id,
endpoint_host=endpoint_host,
)
except Exception as exc:
_service_log_event(
"vpn_api_connect_failed agent_id={0} operator={1} error={2}".format(
agent_id,
operator_id or "-",
str(exc),
),
level="ERROR",
)
logger.warning("vpn connect failed for agent_id=%s: %s", agent_id, exc)
return jsonify({"error": "connect_failed", "detail": str(exc)}), 500
_service_log_event(
"vpn_api_connect_response agent_id={0} tunnel_id={1} status=ok".format(
payload.get("agent_id", agent_id),
payload.get("tunnel_id", "-"),
)
)
return jsonify(payload), 200
@blueprint.route("/api/tunnel/status", methods=["GET"])
def tunnel_status():
requirement = _require_login(app)
if requirement:
payload, status = requirement
return jsonify(payload), status
agent_id = _normalize_text(request.args.get("agent_id") or "")
if not agent_id:
return jsonify({"error": "agent_id_required"}), 400
tunnel_service = _get_tunnel_service(adapters)
payload = tunnel_service.status(agent_id)
agent_socket = False
registry = getattr(adapters.context, "agent_socket_registry", None)
if registry and hasattr(registry, "is_registered"):
try:
agent_socket = bool(registry.is_registered(agent_id))
except Exception:
agent_socket = False
bump = _normalize_text(request.args.get("bump") or "")
_service_log_event(
"vpn_api_status_request agent_id={0} bump={1} remote={2}".format(
agent_id,
"true" if bump else "false",
_request_remote() or "-",
)
)
if not payload:
_service_log_event(
"vpn_api_status_response agent_id={0} status=down".format(agent_id)
)
return jsonify({"status": "down", "agent_id": agent_id, "agent_socket": agent_socket}), 200
payload["status"] = "up"
payload["agent_socket"] = agent_socket
if bump:
tunnel_service.bump_activity(agent_id)
_service_log_event(
"vpn_api_status_response agent_id={0} status=up tunnel_id={1}".format(
agent_id,
payload.get("tunnel_id", "-"),
)
)
return jsonify(payload), 200
@blueprint.route("/api/tunnel/connect/status", methods=["GET"])
def tunnel_connect_status():
return tunnel_status()
@blueprint.route("/api/tunnel/active", methods=["GET"])
def tunnel_active():
requirement = _require_login(app)
if requirement:
payload, status = requirement
return jsonify(payload), status
tunnel_service = _get_tunnel_service(adapters)
sessions = list(tunnel_service.list_sessions())
_service_log_event(
"vpn_api_active_response count={0} remote={1}".format(
len(sessions),
_request_remote() or "-",
)
)
return jsonify({"count": len(sessions), "tunnels": sessions}), 200
@blueprint.route("/api/tunnel/disconnect", methods=["DELETE"])
def disconnect_tunnel():
requirement = _require_login(app)
if requirement:
payload, status = requirement
return jsonify(payload), status
body = request.get_json(silent=True) or {}
agent_id = _normalize_text(body.get("agent_id"))
tunnel_id = _normalize_text(body.get("tunnel_id"))
reason = _normalize_text(body.get("reason") or "operator_stop")
tunnel_service = _get_tunnel_service(adapters)
_service_log_event(
"vpn_api_disconnect_request agent_id={0} tunnel_id={1} reason={2} operator={3} remote={4}".format(
agent_id or "-",
tunnel_id or "-",
reason or "-",
(_current_user(app) or {}).get("username") or "-",
_request_remote() or "-",
)
)
stopped = False
if tunnel_id:
stopped = tunnel_service.disconnect_by_tunnel(tunnel_id, reason=reason)
elif agent_id:
stopped = tunnel_service.disconnect(agent_id, reason=reason)
else:
return jsonify({"error": "agent_id_required"}), 400
if not stopped:
_service_log_event(
"vpn_api_disconnect_not_found agent_id={0} tunnel_id={1}".format(
agent_id or "-",
tunnel_id or "-",
),
level="WARNING",
)
return jsonify({"error": "not_found"}), 404
_service_log_event(
"vpn_api_disconnect_response agent_id={0} tunnel_id={1} status=stopped".format(
agent_id or "-",
tunnel_id or "-",
)
)
return jsonify({"status": "stopped", "reason": reason}), 200
app.register_blueprint(blueprint)