mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-12-18 10:35:48 -07:00
Overhaul of VPN Codebase
This commit is contained in:
@@ -9,6 +9,8 @@
|
||||
# - GET /api/devices/<guid> (Token Authenticated) - Retrieves a single device record by GUID, including summary fields.
|
||||
# - GET /api/device/details/<hostname> (Token Authenticated) - Returns full device details keyed by hostname.
|
||||
# - POST /api/device/description/<hostname> (Token Authenticated) - Updates the human-readable description for a device.
|
||||
# - GET /api/device/vpn_config/<agent_id> (Token Authenticated) - Returns per-device VPN allowed port settings.
|
||||
# - PUT /api/device/vpn_config/<agent_id> (Token Authenticated) - Updates per-device VPN allowed port settings.
|
||||
# - GET /api/device_list_views (Token Authenticated) - Lists saved device table view definitions.
|
||||
# - GET /api/device_list_views/<int:view_id> (Token Authenticated) - Retrieves a specific saved device table view definition.
|
||||
# - POST /api/device_list_views (Token Authenticated) - Creates a custom device list view for the signed-in operator.
|
||||
@@ -426,6 +428,131 @@ class DeviceManagementService:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _parse_ports(self, raw: Any) -> List[int]:
|
||||
ports: List[int] = []
|
||||
if isinstance(raw, str):
|
||||
parts = [part.strip() for part in raw.split(",") if part.strip()]
|
||||
elif isinstance(raw, list):
|
||||
parts = raw
|
||||
else:
|
||||
parts = []
|
||||
for part in parts:
|
||||
try:
|
||||
value = int(part)
|
||||
except Exception:
|
||||
continue
|
||||
if 1 <= value <= 65535:
|
||||
ports.append(value)
|
||||
return list(dict.fromkeys(ports))
|
||||
|
||||
def _default_vpn_ports(self, os_name: Optional[str]) -> List[int]:
|
||||
ports = list(self.adapters.context.wireguard_acl_allowlist_windows or [])
|
||||
os_text = (os_name or "").strip().lower()
|
||||
if os_text and "windows" not in os_text:
|
||||
baseline = {5900, 3478}
|
||||
filtered = [p for p in ports if p in baseline]
|
||||
return filtered or ports
|
||||
return ports
|
||||
|
||||
def get_vpn_config(self, agent_id: str) -> Tuple[Dict[str, Any], int]:
|
||||
agent_id = (agent_id or "").strip()
|
||||
if not agent_id:
|
||||
return {"error": "agent_id_required"}, 400
|
||||
default_ports: List[int] = []
|
||||
shell_port = int(self.adapters.context.wireguard_shell_port)
|
||||
try:
|
||||
conn = self._db_conn()
|
||||
cur = conn.cursor()
|
||||
os_name = ""
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
|
||||
(agent_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row and row[0]:
|
||||
os_name = str(row[0])
|
||||
except Exception:
|
||||
os_name = ""
|
||||
default_ports = self._default_vpn_ports(os_name)
|
||||
cur.execute(
|
||||
"SELECT allowed_ports, updated_at, updated_by FROM device_vpn_config WHERE agent_id=?",
|
||||
(agent_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"allowed_ports": default_ports,
|
||||
"default_ports": default_ports,
|
||||
"shell_port": shell_port,
|
||||
"source": "default",
|
||||
}, 200
|
||||
raw_ports = row[0] or ""
|
||||
ports = []
|
||||
try:
|
||||
ports = json.loads(raw_ports) if raw_ports else []
|
||||
except Exception:
|
||||
ports = []
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"allowed_ports": ports or default_ports,
|
||||
"default_ports": default_ports,
|
||||
"shell_port": shell_port,
|
||||
"updated_at": row[1],
|
||||
"updated_by": row[2],
|
||||
"source": "custom" if ports else "default",
|
||||
}, 200
|
||||
except Exception as exc:
|
||||
self.logger.debug("Failed to load vpn config", exc_info=True)
|
||||
return {"error": "internal_error"}, 500
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def set_vpn_config(self, agent_id: str, payload: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
|
||||
agent_id = (agent_id or "").strip()
|
||||
if not agent_id:
|
||||
return {"error": "agent_id_required"}, 400
|
||||
ports = self._parse_ports(payload.get("allowed_ports"))
|
||||
if not ports:
|
||||
return {"error": "allowed_ports_required"}, 400
|
||||
user = self._current_user() or {}
|
||||
updated_by = user.get("username") or ""
|
||||
updated_at = datetime.now(timezone.utc).isoformat()
|
||||
try:
|
||||
conn = self._db_conn()
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO device_vpn_config(agent_id, allowed_ports, updated_at, updated_by)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(agent_id) DO UPDATE SET
|
||||
allowed_ports=excluded.allowed_ports,
|
||||
updated_at=excluded.updated_at,
|
||||
updated_by=excluded.updated_by
|
||||
""",
|
||||
(agent_id, json.dumps(ports), updated_at, updated_by),
|
||||
)
|
||||
conn.commit()
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"allowed_ports": ports,
|
||||
"updated_at": updated_at,
|
||||
"updated_by": updated_by,
|
||||
"source": "custom",
|
||||
}, 200
|
||||
except Exception:
|
||||
self.logger.debug("Failed to save vpn config", exc_info=True)
|
||||
return {"error": "internal_error"}, 500
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _require_login(self) -> Optional[Tuple[Dict[str, Any], int]]:
|
||||
if not self._current_user():
|
||||
return {"error": "unauthorized"}, 401
|
||||
@@ -1793,6 +1920,19 @@ def register_management(app, adapters: "EngineServiceAdapters") -> None:
|
||||
payload, status = service.set_device_description(hostname, description)
|
||||
return jsonify(payload), status
|
||||
|
||||
@blueprint.route("/api/device/vpn_config/<agent_id>", methods=["GET", "PUT"])
|
||||
def _vpn_config(agent_id: str):
|
||||
requirement = service._require_login()
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
if request.method == "GET":
|
||||
payload, status = service.get_vpn_config(agent_id)
|
||||
else:
|
||||
body = request.get_json(silent=True) or {}
|
||||
payload, status = service.set_vpn_config(agent_id, body)
|
||||
return jsonify(payload), status
|
||||
|
||||
@blueprint.route("/api/device_list_views", methods=["GET"])
|
||||
def _list_views():
|
||||
requirement = service._require_login()
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\API\devices\tunnel.py
|
||||
# Description: Negotiation endpoint for reverse tunnel leases (operator-initiated; dormant until tunnel listener is wired).
|
||||
# Description: WireGuard VPN tunnel API (connect/status/disconnect).
|
||||
#
|
||||
# API Endpoints (if applicable):
|
||||
# - POST /api/tunnel/request (Token Authenticated) - Allocates a reverse tunnel lease for the requested agent/protocol.
|
||||
# - POST /api/tunnel/connect (Token Authenticated) - Issues VPN session material for an agent.
|
||||
# - GET /api/tunnel/status (Token Authenticated) - Returns VPN status for an agent.
|
||||
# - DELETE /api/tunnel/disconnect (Token Authenticated) - Tears down VPN session for an agent.
|
||||
# ======================================================
|
||||
|
||||
"""Reverse tunnel negotiation API (Engine side)."""
|
||||
"""WireGuard VPN tunnel API (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
@@ -15,15 +17,13 @@ from typing import Any, Dict, Optional, Tuple
|
||||
from flask import Blueprint, jsonify, request, session
|
||||
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
|
||||
|
||||
from ...WebSocket.Agent.reverse_tunnel_orchestrator import ReverseTunnelService
|
||||
from ...VPN import VpnTunnelService
|
||||
|
||||
if False: # pragma: no cover - import cycle hint for type checkers
|
||||
from .. import EngineServiceAdapters
|
||||
|
||||
|
||||
def _current_user(app) -> Optional[Dict[str, str]]:
|
||||
"""Resolve operator identity from session or signed token."""
|
||||
|
||||
username = session.get("username")
|
||||
role = session.get("role") or "User"
|
||||
if username:
|
||||
@@ -58,18 +58,22 @@ def _require_login(app) -> Optional[Tuple[Dict[str, Any], int]]:
|
||||
return None
|
||||
|
||||
|
||||
def _get_tunnel_service(adapters: "EngineServiceAdapters") -> ReverseTunnelService:
|
||||
service = getattr(adapters.context, "reverse_tunnel_service", None) or getattr(adapters, "_reverse_tunnel_service", None)
|
||||
def _get_tunnel_service(adapters: "EngineServiceAdapters") -> VpnTunnelService:
|
||||
service = getattr(adapters.context, "vpn_tunnel_service", None) or getattr(adapters, "_vpn_tunnel_service", None)
|
||||
if service is None:
|
||||
service = ReverseTunnelService(
|
||||
adapters.context,
|
||||
signer=getattr(adapters, "script_signer", None),
|
||||
manager = getattr(adapters.context, "wireguard_server_manager", None)
|
||||
if manager is None:
|
||||
raise RuntimeError("wireguard_manager_unavailable")
|
||||
service = VpnTunnelService(
|
||||
context=adapters.context,
|
||||
wireguard_manager=manager,
|
||||
db_conn_factory=adapters.db_conn_factory,
|
||||
socketio=getattr(adapters.context, "socketio", None),
|
||||
service_log=adapters.service_log,
|
||||
signer=getattr(adapters, "script_signer", None),
|
||||
)
|
||||
service.start()
|
||||
setattr(adapters, "_reverse_tunnel_service", service)
|
||||
setattr(adapters.context, "reverse_tunnel_service", service)
|
||||
setattr(adapters, "_vpn_tunnel_service", service)
|
||||
setattr(adapters.context, "vpn_tunnel_service", service)
|
||||
return service
|
||||
|
||||
|
||||
@@ -83,14 +87,11 @@ def _normalize_text(value: Any) -> str:
|
||||
|
||||
|
||||
def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
|
||||
"""Register reverse tunnel negotiation endpoints."""
|
||||
blueprint = Blueprint("vpn_tunnel", __name__)
|
||||
logger = adapters.context.logger.getChild("vpn_tunnel.api")
|
||||
|
||||
blueprint = Blueprint("reverse_tunnel", __name__)
|
||||
service_log = adapters.service_log
|
||||
logger = adapters.context.logger.getChild("tunnel.api")
|
||||
|
||||
@blueprint.route("/api/tunnel/request", methods=["POST"])
|
||||
def request_tunnel():
|
||||
@blueprint.route("/api/tunnel/connect", methods=["POST"])
|
||||
def connect_tunnel():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
@@ -101,69 +102,67 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
agent_id = _normalize_text(body.get("agent_id"))
|
||||
protocol = _normalize_text(body.get("protocol") or "ps").lower() or "ps"
|
||||
domain = _normalize_text(body.get("domain") or protocol).lower() or protocol
|
||||
if protocol == "ps" and domain == "ps":
|
||||
domain = "remote-interactive-shell"
|
||||
|
||||
if not agent_id:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
try:
|
||||
lease = tunnel_service.request_lease(
|
||||
agent_id=agent_id,
|
||||
protocol=protocol,
|
||||
domain=domain,
|
||||
operator_id=operator_id,
|
||||
)
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
payload = tunnel_service.connect(agent_id=agent_id, operator_id=operator_id)
|
||||
except RuntimeError as exc:
|
||||
message = str(exc)
|
||||
if message.startswith("domain_limit:"):
|
||||
domain_name = message.split(":", 1)[-1] if ":" in message else domain
|
||||
return jsonify({"error": "domain_limit", "domain": domain_name}), 409
|
||||
if message == "port_pool_exhausted":
|
||||
return jsonify({"error": "port_pool_exhausted"}), 503
|
||||
logger.warning("tunnel lease request failed for agent_id=%s: %s", agent_id, message)
|
||||
return jsonify({"error": "lease_allocation_failed"}), 500
|
||||
logger.warning("vpn connect failed for agent_id=%s: %s", agent_id, exc)
|
||||
return jsonify({"error": "connect_failed"}), 500
|
||||
|
||||
summary = tunnel_service.lease_summary(lease)
|
||||
summary["fixed_port"] = tunnel_service.fixed_port
|
||||
summary["heartbeat_seconds"] = tunnel_service.heartbeat_seconds
|
||||
return jsonify(payload), 200
|
||||
|
||||
service_log(
|
||||
"reverse_tunnel",
|
||||
f"lease created tunnel_id={lease.tunnel_id} agent_id={lease.agent_id} domain={lease.domain} protocol={lease.protocol}",
|
||||
)
|
||||
return jsonify(summary), 200
|
||||
|
||||
@blueprint.route("/api/tunnel/<tunnel_id>", methods=["DELETE"])
|
||||
def stop_tunnel(tunnel_id: str):
|
||||
@blueprint.route("/api/tunnel/status", methods=["GET"])
|
||||
def tunnel_status():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
|
||||
tunnel_id_norm = _normalize_text(tunnel_id)
|
||||
if not tunnel_id_norm:
|
||||
return jsonify({"error": "tunnel_id_required"}), 400
|
||||
agent_id = _normalize_text(request.args.get("agent_id") or "")
|
||||
if not agent_id:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
payload = tunnel_service.status(agent_id)
|
||||
if not payload:
|
||||
return jsonify({"status": "down", "agent_id": agent_id}), 200
|
||||
payload["status"] = "up"
|
||||
bump = _normalize_text(request.args.get("bump") or "")
|
||||
if bump:
|
||||
tunnel_service.bump_activity(agent_id)
|
||||
return jsonify(payload), 200
|
||||
|
||||
@blueprint.route("/api/tunnel/connect/status", methods=["GET"])
|
||||
def tunnel_connect_status():
|
||||
return tunnel_status()
|
||||
|
||||
@blueprint.route("/api/tunnel/disconnect", methods=["DELETE"])
|
||||
def disconnect_tunnel():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
agent_id = _normalize_text(body.get("agent_id"))
|
||||
tunnel_id = _normalize_text(body.get("tunnel_id"))
|
||||
reason = _normalize_text(body.get("reason") or "operator_stop")
|
||||
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
stopped = False
|
||||
try:
|
||||
stopped = tunnel_service.stop_tunnel(tunnel_id_norm, reason=reason)
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("stop_tunnel failed tunnel_id=%s: %s", tunnel_id_norm, exc, exc_info=True)
|
||||
if tunnel_id:
|
||||
stopped = tunnel_service.disconnect_by_tunnel(tunnel_id, reason=reason)
|
||||
elif agent_id:
|
||||
stopped = tunnel_service.disconnect(agent_id, reason=reason)
|
||||
else:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
|
||||
if not stopped:
|
||||
return jsonify({"error": "not_found"}), 404
|
||||
|
||||
service_log(
|
||||
"reverse_tunnel",
|
||||
f"lease stopped tunnel_id={tunnel_id_norm} reason={reason or '-'}",
|
||||
)
|
||||
return jsonify({"status": "stopped", "tunnel_id": tunnel_id_norm}), 200
|
||||
return jsonify({"status": "stopped", "reason": reason}), 200
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
Reference in New Issue
Block a user