Overhaul of VPN Codebase

This commit is contained in:
2025-12-18 01:35:03 -07:00
parent 2f81061a1b
commit 6ceb59f717
56 changed files with 1786 additions and 4778 deletions

View File

@@ -9,6 +9,8 @@
# - GET /api/devices/<guid> (Token Authenticated) - Retrieves a single device record by GUID, including summary fields.
# - GET /api/device/details/<hostname> (Token Authenticated) - Returns full device details keyed by hostname.
# - POST /api/device/description/<hostname> (Token Authenticated) - Updates the human-readable description for a device.
# - GET /api/device/vpn_config/<agent_id> (Token Authenticated) - Returns per-device VPN allowed port settings.
# - PUT /api/device/vpn_config/<agent_id> (Token Authenticated) - Updates per-device VPN allowed port settings.
# - GET /api/device_list_views (Token Authenticated) - Lists saved device table view definitions.
# - GET /api/device_list_views/<int:view_id> (Token Authenticated) - Retrieves a specific saved device table view definition.
# - POST /api/device_list_views (Token Authenticated) - Creates a custom device list view for the signed-in operator.
@@ -426,6 +428,131 @@ class DeviceManagementService:
return None
return None
def _parse_ports(self, raw: Any) -> List[int]:
ports: List[int] = []
if isinstance(raw, str):
parts = [part.strip() for part in raw.split(",") if part.strip()]
elif isinstance(raw, list):
parts = raw
else:
parts = []
for part in parts:
try:
value = int(part)
except Exception:
continue
if 1 <= value <= 65535:
ports.append(value)
return list(dict.fromkeys(ports))
def _default_vpn_ports(self, os_name: Optional[str]) -> List[int]:
ports = list(self.adapters.context.wireguard_acl_allowlist_windows or [])
os_text = (os_name or "").strip().lower()
if os_text and "windows" not in os_text:
baseline = {5900, 3478}
filtered = [p for p in ports if p in baseline]
return filtered or ports
return ports
def get_vpn_config(self, agent_id: str) -> Tuple[Dict[str, Any], int]:
agent_id = (agent_id or "").strip()
if not agent_id:
return {"error": "agent_id_required"}, 400
default_ports: List[int] = []
shell_port = int(self.adapters.context.wireguard_shell_port)
try:
conn = self._db_conn()
cur = conn.cursor()
os_name = ""
try:
cur.execute(
"SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
(agent_id,),
)
row = cur.fetchone()
if row and row[0]:
os_name = str(row[0])
except Exception:
os_name = ""
default_ports = self._default_vpn_ports(os_name)
cur.execute(
"SELECT allowed_ports, updated_at, updated_by FROM device_vpn_config WHERE agent_id=?",
(agent_id,),
)
row = cur.fetchone()
if not row:
return {
"agent_id": agent_id,
"allowed_ports": default_ports,
"default_ports": default_ports,
"shell_port": shell_port,
"source": "default",
}, 200
raw_ports = row[0] or ""
ports = []
try:
ports = json.loads(raw_ports) if raw_ports else []
except Exception:
ports = []
return {
"agent_id": agent_id,
"allowed_ports": ports or default_ports,
"default_ports": default_ports,
"shell_port": shell_port,
"updated_at": row[1],
"updated_by": row[2],
"source": "custom" if ports else "default",
}, 200
except Exception as exc:
self.logger.debug("Failed to load vpn config", exc_info=True)
return {"error": "internal_error"}, 500
finally:
try:
conn.close()
except Exception:
pass
def set_vpn_config(self, agent_id: str, payload: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
agent_id = (agent_id or "").strip()
if not agent_id:
return {"error": "agent_id_required"}, 400
ports = self._parse_ports(payload.get("allowed_ports"))
if not ports:
return {"error": "allowed_ports_required"}, 400
user = self._current_user() or {}
updated_by = user.get("username") or ""
updated_at = datetime.now(timezone.utc).isoformat()
try:
conn = self._db_conn()
cur = conn.cursor()
cur.execute(
"""
INSERT INTO device_vpn_config(agent_id, allowed_ports, updated_at, updated_by)
VALUES (?, ?, ?, ?)
ON CONFLICT(agent_id) DO UPDATE SET
allowed_ports=excluded.allowed_ports,
updated_at=excluded.updated_at,
updated_by=excluded.updated_by
""",
(agent_id, json.dumps(ports), updated_at, updated_by),
)
conn.commit()
return {
"agent_id": agent_id,
"allowed_ports": ports,
"updated_at": updated_at,
"updated_by": updated_by,
"source": "custom",
}, 200
except Exception:
self.logger.debug("Failed to save vpn config", exc_info=True)
return {"error": "internal_error"}, 500
finally:
try:
conn.close()
except Exception:
pass
def _require_login(self) -> Optional[Tuple[Dict[str, Any], int]]:
if not self._current_user():
return {"error": "unauthorized"}, 401
@@ -1793,6 +1920,19 @@ def register_management(app, adapters: "EngineServiceAdapters") -> None:
payload, status = service.set_device_description(hostname, description)
return jsonify(payload), status
@blueprint.route("/api/device/vpn_config/<agent_id>", methods=["GET", "PUT"])
def _vpn_config(agent_id: str):
requirement = service._require_login()
if requirement:
payload, status = requirement
return jsonify(payload), status
if request.method == "GET":
payload, status = service.get_vpn_config(agent_id)
else:
body = request.get_json(silent=True) or {}
payload, status = service.set_vpn_config(agent_id, body)
return jsonify(payload), status
@blueprint.route("/api/device_list_views", methods=["GET"])
def _list_views():
requirement = service._require_login()

View File

@@ -1,12 +1,14 @@
# ======================================================
# Data\Engine\services\API\devices\tunnel.py
# Description: Negotiation endpoint for reverse tunnel leases (operator-initiated; dormant until tunnel listener is wired).
# Description: WireGuard VPN tunnel API (connect/status/disconnect).
#
# API Endpoints (if applicable):
# - POST /api/tunnel/request (Token Authenticated) - Allocates a reverse tunnel lease for the requested agent/protocol.
# - POST /api/tunnel/connect (Token Authenticated) - Issues VPN session material for an agent.
# - GET /api/tunnel/status (Token Authenticated) - Returns VPN status for an agent.
# - DELETE /api/tunnel/disconnect (Token Authenticated) - Tears down VPN session for an agent.
# ======================================================
"""Reverse tunnel negotiation API (Engine side)."""
"""WireGuard VPN tunnel API (Engine side)."""
from __future__ import annotations
import os
@@ -15,15 +17,13 @@ from typing import Any, Dict, Optional, Tuple
from flask import Blueprint, jsonify, request, session
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from ...WebSocket.Agent.reverse_tunnel_orchestrator import ReverseTunnelService
from ...VPN import VpnTunnelService
if False: # pragma: no cover - import cycle hint for type checkers
from .. import EngineServiceAdapters
def _current_user(app) -> Optional[Dict[str, str]]:
"""Resolve operator identity from session or signed token."""
username = session.get("username")
role = session.get("role") or "User"
if username:
@@ -58,18 +58,22 @@ def _require_login(app) -> Optional[Tuple[Dict[str, Any], int]]:
return None
def _get_tunnel_service(adapters: "EngineServiceAdapters") -> ReverseTunnelService:
service = getattr(adapters.context, "reverse_tunnel_service", None) or getattr(adapters, "_reverse_tunnel_service", None)
def _get_tunnel_service(adapters: "EngineServiceAdapters") -> VpnTunnelService:
service = getattr(adapters.context, "vpn_tunnel_service", None) or getattr(adapters, "_vpn_tunnel_service", None)
if service is None:
service = ReverseTunnelService(
adapters.context,
signer=getattr(adapters, "script_signer", None),
manager = getattr(adapters.context, "wireguard_server_manager", None)
if manager is None:
raise RuntimeError("wireguard_manager_unavailable")
service = VpnTunnelService(
context=adapters.context,
wireguard_manager=manager,
db_conn_factory=adapters.db_conn_factory,
socketio=getattr(adapters.context, "socketio", None),
service_log=adapters.service_log,
signer=getattr(adapters, "script_signer", None),
)
service.start()
setattr(adapters, "_reverse_tunnel_service", service)
setattr(adapters.context, "reverse_tunnel_service", service)
setattr(adapters, "_vpn_tunnel_service", service)
setattr(adapters.context, "vpn_tunnel_service", service)
return service
@@ -83,14 +87,11 @@ def _normalize_text(value: Any) -> str:
def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
"""Register reverse tunnel negotiation endpoints."""
blueprint = Blueprint("vpn_tunnel", __name__)
logger = adapters.context.logger.getChild("vpn_tunnel.api")
blueprint = Blueprint("reverse_tunnel", __name__)
service_log = adapters.service_log
logger = adapters.context.logger.getChild("tunnel.api")
@blueprint.route("/api/tunnel/request", methods=["POST"])
def request_tunnel():
@blueprint.route("/api/tunnel/connect", methods=["POST"])
def connect_tunnel():
requirement = _require_login(app)
if requirement:
payload, status = requirement
@@ -101,69 +102,67 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
body = request.get_json(silent=True) or {}
agent_id = _normalize_text(body.get("agent_id"))
protocol = _normalize_text(body.get("protocol") or "ps").lower() or "ps"
domain = _normalize_text(body.get("domain") or protocol).lower() or protocol
if protocol == "ps" and domain == "ps":
domain = "remote-interactive-shell"
if not agent_id:
return jsonify({"error": "agent_id_required"}), 400
tunnel_service = _get_tunnel_service(adapters)
try:
lease = tunnel_service.request_lease(
agent_id=agent_id,
protocol=protocol,
domain=domain,
operator_id=operator_id,
)
tunnel_service = _get_tunnel_service(adapters)
payload = tunnel_service.connect(agent_id=agent_id, operator_id=operator_id)
except RuntimeError as exc:
message = str(exc)
if message.startswith("domain_limit:"):
domain_name = message.split(":", 1)[-1] if ":" in message else domain
return jsonify({"error": "domain_limit", "domain": domain_name}), 409
if message == "port_pool_exhausted":
return jsonify({"error": "port_pool_exhausted"}), 503
logger.warning("tunnel lease request failed for agent_id=%s: %s", agent_id, message)
return jsonify({"error": "lease_allocation_failed"}), 500
logger.warning("vpn connect failed for agent_id=%s: %s", agent_id, exc)
return jsonify({"error": "connect_failed"}), 500
summary = tunnel_service.lease_summary(lease)
summary["fixed_port"] = tunnel_service.fixed_port
summary["heartbeat_seconds"] = tunnel_service.heartbeat_seconds
return jsonify(payload), 200
service_log(
"reverse_tunnel",
f"lease created tunnel_id={lease.tunnel_id} agent_id={lease.agent_id} domain={lease.domain} protocol={lease.protocol}",
)
return jsonify(summary), 200
@blueprint.route("/api/tunnel/<tunnel_id>", methods=["DELETE"])
def stop_tunnel(tunnel_id: str):
@blueprint.route("/api/tunnel/status", methods=["GET"])
def tunnel_status():
requirement = _require_login(app)
if requirement:
payload, status = requirement
return jsonify(payload), status
tunnel_id_norm = _normalize_text(tunnel_id)
if not tunnel_id_norm:
return jsonify({"error": "tunnel_id_required"}), 400
agent_id = _normalize_text(request.args.get("agent_id") or "")
if not agent_id:
return jsonify({"error": "agent_id_required"}), 400
tunnel_service = _get_tunnel_service(adapters)
payload = tunnel_service.status(agent_id)
if not payload:
return jsonify({"status": "down", "agent_id": agent_id}), 200
payload["status"] = "up"
bump = _normalize_text(request.args.get("bump") or "")
if bump:
tunnel_service.bump_activity(agent_id)
return jsonify(payload), 200
@blueprint.route("/api/tunnel/connect/status", methods=["GET"])
def tunnel_connect_status():
return tunnel_status()
@blueprint.route("/api/tunnel/disconnect", methods=["DELETE"])
def disconnect_tunnel():
requirement = _require_login(app)
if requirement:
payload, status = requirement
return jsonify(payload), status
body = request.get_json(silent=True) or {}
agent_id = _normalize_text(body.get("agent_id"))
tunnel_id = _normalize_text(body.get("tunnel_id"))
reason = _normalize_text(body.get("reason") or "operator_stop")
tunnel_service = _get_tunnel_service(adapters)
stopped = False
try:
stopped = tunnel_service.stop_tunnel(tunnel_id_norm, reason=reason)
except Exception as exc: # pragma: no cover - defensive guard
logger.debug("stop_tunnel failed tunnel_id=%s: %s", tunnel_id_norm, exc, exc_info=True)
if tunnel_id:
stopped = tunnel_service.disconnect_by_tunnel(tunnel_id, reason=reason)
elif agent_id:
stopped = tunnel_service.disconnect(agent_id, reason=reason)
else:
return jsonify({"error": "agent_id_required"}), 400
if not stopped:
return jsonify({"error": "not_found"}), 404
service_log(
"reverse_tunnel",
f"lease stopped tunnel_id={tunnel_id_norm} reason={reason or '-'}",
)
return jsonify({"status": "stopped", "tunnel_id": tunnel_id_norm}), 200
return jsonify({"status": "stopped", "reason": reason}), 200
app.register_blueprint(blueprint)

View File

@@ -8,4 +8,4 @@
"""VPN service helpers for the Engine runtime."""
from .wireguard_server import WireGuardServerConfig, WireGuardServerManager # noqa: F401
from .vpn_tunnel_service import VpnTunnelService # noqa: F401

View File

@@ -0,0 +1,473 @@
# ======================================================
# Data\Engine\services\VPN\vpn_tunnel_service.py
# Description: WireGuard tunnel orchestration (single tunnel per agent, token issuance, idle handling).
#
# API Endpoints (if applicable): None
# ======================================================
"""WireGuard tunnel orchestration helpers for the Engine runtime."""
from __future__ import annotations
import base64
import ipaddress
import json
import threading
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from .wireguard_server import WireGuardServerManager
@dataclass
class VpnSession:
tunnel_id: str
agent_id: str
virtual_ip: str
token: Dict[str, Any]
client_public_key: str
client_private_key: str
allowed_ports: Tuple[int, ...]
created_at: float
expires_at: float
last_activity: float
operator_ids: set[str] = field(default_factory=set)
firewall_rules: List[str] = field(default_factory=list)
activity_id: Optional[int] = None
hostname: Optional[str] = None
class VpnTunnelService:
def __init__(
self,
*,
context: Any,
wireguard_manager: WireGuardServerManager,
db_conn_factory,
socketio,
service_log,
signer: Optional[Any] = None,
idle_seconds: int = 900,
) -> None:
self.context = context
self.wg = wireguard_manager
self.db_conn_factory = db_conn_factory
self.socketio = socketio
self.service_log = service_log
self.signer = signer
self.logger = context.logger.getChild("vpn_tunnel")
self.activity_logger = self.wg.logger.getChild("device_activity")
self.idle_seconds = max(60, int(idle_seconds))
self._lock = threading.Lock()
self._sessions_by_agent: Dict[str, VpnSession] = {}
self._sessions_by_tunnel: Dict[str, VpnSession] = {}
self._engine_ip = ipaddress.ip_interface(context.wireguard_engine_virtual_ip)
self._peer_network = ipaddress.ip_network(context.wireguard_peer_network, strict=False)
self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True)
self._idle_thread.start()
def _idle_loop(self) -> None:
while True:
time.sleep(10)
now = time.time()
expired: List[VpnSession] = []
with self._lock:
for session in list(self._sessions_by_agent.values()):
if session.last_activity + self.idle_seconds <= now:
expired.append(session)
for session in expired:
self.disconnect(session.agent_id, reason="idle_timeout")
def _allocate_virtual_ip(self, agent_id: str) -> str:
existing = self._sessions_by_agent.get(agent_id)
if existing:
return existing.virtual_ip
used = {s.virtual_ip for s in self._sessions_by_agent.values()}
for host in self._peer_network.hosts():
if host == self._engine_ip.ip:
continue
candidate = f"{host}/32"
if candidate not in used:
return candidate
raise RuntimeError("vpn_ip_pool_exhausted")
def _load_allowed_ports(self, agent_id: str) -> Tuple[int, ...]:
default = tuple(self.context.wireguard_acl_allowlist_windows or ())
try:
conn = self.db_conn_factory()
cur = conn.cursor()
try:
cur.execute(
"SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
(agent_id,),
)
row = cur.fetchone()
os_name = str(row[0]).lower() if row and row[0] else ""
except Exception:
os_name = ""
if os_name and "windows" not in os_name:
baseline = {5900, 3478}
filtered = [p for p in default if p in baseline]
if filtered:
default = tuple(filtered)
cur.execute(
"SELECT allowed_ports FROM device_vpn_config WHERE agent_id=?",
(agent_id,),
)
row = cur.fetchone()
if not row:
return default
raw = row[0] or ""
ports = json.loads(raw) if raw else []
ports = [int(p) for p in ports if isinstance(p, (int, float, str))]
ports = [p for p in ports if 1 <= p <= 65535]
return tuple(dict.fromkeys(ports)) or default
except Exception:
return default
finally:
try:
conn.close()
except Exception:
pass
def _generate_client_keys(self) -> Tuple[str, str]:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x25519
key = x25519.X25519PrivateKey.generate()
priv = base64.b64encode(
key.private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption(),
)
).decode("ascii").strip()
pub = base64.b64encode(
key.public_key().public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
).decode("ascii").strip()
return priv, pub
def _issue_token(self, agent_id: str, tunnel_id: str, expires_at: float) -> Dict[str, Any]:
payload = {
"agent_id": agent_id,
"tunnel_id": tunnel_id,
"port": self.context.wireguard_port,
"expires_at": expires_at,
"issued_at": time.time(),
}
if not self.signer:
return dict(payload)
token = dict(payload)
try:
payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
signature = self.signer.sign(payload_bytes)
token["signature"] = base64.b64encode(signature).decode("ascii")
if hasattr(self.signer, "public_base64_spki"):
token["signing_key"] = self.signer.public_base64_spki()
token["sig_alg"] = "ed25519"
except Exception:
self.logger.debug("Failed to sign VPN orchestration token; sending unsigned.", exc_info=True)
return token
def _refresh_listener(self) -> None:
peers: List[Mapping[str, object]] = []
for session in self._sessions_by_agent.values():
peer = self.wg.build_peer_profile(
session.agent_id,
session.virtual_ip,
allowed_ports=session.allowed_ports,
)
peer = dict(peer)
peer["public_key"] = session.client_public_key
peers.append(peer)
if not peers:
self.wg.stop_listener()
return
self.wg.start_listener(peers)
def connect(self, *, agent_id: str, operator_id: Optional[str]) -> Mapping[str, Any]:
now = time.time()
with self._lock:
existing = self._sessions_by_agent.get(agent_id)
if existing:
if operator_id:
existing.operator_ids.add(operator_id)
existing.last_activity = now
return self._session_payload(existing)
tunnel_id = uuid.uuid4().hex
virtual_ip = self._allocate_virtual_ip(agent_id)
allowed_ports = self._load_allowed_ports(agent_id)
client_private, client_public = self._generate_client_keys()
token = self._issue_token(agent_id, tunnel_id, now + 300)
self.wg.require_orchestration_token(token)
session = VpnSession(
tunnel_id=tunnel_id,
agent_id=agent_id,
virtual_ip=virtual_ip,
token=token,
client_public_key=client_public,
client_private_key=client_private,
allowed_ports=allowed_ports,
created_at=now,
expires_at=now + 300,
last_activity=now,
)
if operator_id:
session.operator_ids.add(operator_id)
self._sessions_by_agent[agent_id] = session
self._sessions_by_tunnel[tunnel_id] = session
try:
self._refresh_listener()
peer = self.wg.build_peer_profile(
agent_id,
virtual_ip,
allowed_ports=allowed_ports,
)
rule_names = self.wg.apply_firewall_rules(peer)
session.firewall_rules = rule_names
except Exception:
with self._lock:
self._sessions_by_agent.pop(agent_id, None)
self._sessions_by_tunnel.pop(tunnel_id, None)
try:
self._refresh_listener()
except Exception:
self.logger.debug("Failed to refresh WireGuard listener after connect rollback.", exc_info=True)
raise
payload = self._session_payload(session)
self._emit_start(payload)
self._log_device_activity(session, event="start")
return payload
def status(self, agent_id: str) -> Optional[Mapping[str, Any]]:
with self._lock:
session = self._sessions_by_agent.get(agent_id)
if not session:
return None
return self._session_payload(session, include_token=False)
def bump_activity(self, agent_id: str) -> None:
with self._lock:
session = self._sessions_by_agent.get(agent_id)
if not session:
return
session.last_activity = time.time()
try:
if self.socketio:
self.socketio.emit("vpn_tunnel_activity", {"agent_id": agent_id}, namespace="/")
except Exception:
self.logger.debug("vpn_tunnel_activity emit failed for agent_id=%s", agent_id, exc_info=True)
def disconnect(self, agent_id: str, reason: str = "operator_stop") -> bool:
with self._lock:
session = self._sessions_by_agent.pop(agent_id, None)
if not session:
return False
self._sessions_by_tunnel.pop(session.tunnel_id, None)
try:
self.wg.remove_firewall_rules(session.firewall_rules)
except Exception:
self.logger.debug("Failed to remove firewall rules for agent=%s", agent_id, exc_info=True)
self._refresh_listener()
self._emit_stop(session, reason)
self._log_device_activity(session, event="stop", reason=reason)
return True
def disconnect_by_tunnel(self, tunnel_id: str, reason: str = "operator_stop") -> bool:
with self._lock:
session = self._sessions_by_tunnel.get(tunnel_id)
if not session:
return False
return self.disconnect(session.agent_id, reason=reason)
def _emit_start(self, payload: Mapping[str, Any]) -> None:
if not self.socketio:
return
try:
self.socketio.emit("vpn_tunnel_start", payload, namespace="/")
except Exception:
self.logger.debug("vpn_tunnel_start emit failed", exc_info=True)
def _emit_stop(self, session: VpnSession, reason: str) -> None:
if not self.socketio:
return
try:
self.socketio.emit(
"vpn_tunnel_stop",
{"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason},
namespace="/",
)
except Exception:
self.logger.debug("vpn_tunnel_stop emit failed", exc_info=True)
def _log_device_activity(self, session: VpnSession, *, event: str, reason: Optional[str] = None) -> None:
if self.db_conn_factory is None:
self.activity_logger.info(
"device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s",
event,
session.agent_id,
session.tunnel_id,
",".join(sorted(filter(None, session.operator_ids))) or "-",
reason or "-",
)
return
conn = None
try:
conn = self.db_conn_factory()
cur = conn.cursor()
hostname = session.hostname
if not hostname:
try:
cur.execute(
"SELECT hostname FROM devices WHERE agent_id = ? ORDER BY last_seen DESC LIMIT 1",
(session.agent_id,),
)
row = cur.fetchone()
if row and row[0]:
hostname = str(row[0]).strip()
session.hostname = hostname
except Exception:
hostname = None
if not hostname:
self.activity_logger.info(
"device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s hostname=unknown",
event,
session.agent_id,
session.tunnel_id,
",".join(sorted(filter(None, session.operator_ids))) or "-",
reason or "-",
)
return
now_ts = int(time.time())
script_name = "Reverse VPN Tunnel (WireGuard)"
if event == "start":
cur.execute(
"""
INSERT INTO activity_history(hostname, script_path, script_name, script_type, ran_at, status, stdout, stderr)
VALUES(?,?,?,?,?,?,?,?)
""",
(
hostname,
session.tunnel_id,
script_name,
"vpn_tunnel",
now_ts,
"Running",
"",
"",
),
)
session.activity_id = cur.lastrowid
conn.commit()
if self.socketio:
try:
self.socketio.emit(
"device_activity_changed",
{
"hostname": hostname,
"activity_id": session.activity_id,
"change": "created",
"source": "vpn_tunnel",
},
)
except Exception:
pass
self.activity_logger.info(
"device_activity_start hostname=%s agent_id=%s tunnel_id=%s operator=%s activity_id=%s",
hostname,
session.agent_id,
session.tunnel_id,
",".join(sorted(filter(None, session.operator_ids))) or "-",
session.activity_id or "-",
)
return
if session.activity_id:
status = "Completed" if event == "stop" else "Closed"
cur.execute(
"""
UPDATE activity_history
SET status=?,
stderr=COALESCE(stderr, '') || ?
WHERE id=?
""",
(
status,
f"\nreason: {reason}" if reason else "",
session.activity_id,
),
)
conn.commit()
if self.socketio:
try:
self.socketio.emit(
"device_activity_changed",
{
"hostname": hostname,
"activity_id": session.activity_id,
"change": "updated",
"source": "vpn_tunnel",
},
)
except Exception:
pass
self.activity_logger.info(
"device_activity event=%s hostname=%s agent_id=%s tunnel_id=%s operator=%s reason=%s activity_id=%s",
event,
hostname,
session.agent_id,
session.tunnel_id,
",".join(sorted(filter(None, session.operator_ids))) or "-",
reason or "-",
session.activity_id or "-",
)
except Exception:
self.activity_logger.debug(
"device_activity logging failed for tunnel_id=%s",
session.tunnel_id,
exc_info=True,
)
finally:
if conn is not None:
try:
conn.close()
except Exception:
pass
def _session_payload(self, session: VpnSession, *, include_token: bool = True) -> Mapping[str, Any]:
payload: Dict[str, Any] = {
"tunnel_id": session.tunnel_id,
"agent_id": session.agent_id,
"virtual_ip": session.virtual_ip,
"engine_virtual_ip": str(self._engine_ip.ip),
"allowed_ips": f"{self._engine_ip.ip}/32",
"endpoint": f"{self._engine_ip.ip}:{self.context.wireguard_port}",
"server_public_key": self.wg.server_public_key,
"client_public_key": session.client_public_key,
"client_private_key": session.client_private_key,
"idle_seconds": self.idle_seconds,
"allowed_ports": list(session.allowed_ports),
"connected_operators": len([o for o in session.operator_ids if o]),
}
if include_token:
payload["token"] = session.token
return payload

View File

@@ -70,7 +70,7 @@ class WireGuardServerManager:
self.logger = _build_logger(config.log_path)
self._ensure_cert_dir()
self.server_private_key, self.server_public_key = self._ensure_server_keys()
self._service_name = "BorealisWireGuard"
self._service_name = "borealis-wg"
self._temp_dir = Path(tempfile.gettempdir()) / "borealis-wg-engine"
def _ensure_cert_dir(self) -> None:
@@ -157,7 +157,7 @@ class WireGuardServerManager:
if not token:
raise ValueError("Missing orchestration token for WireGuard peer")
required_fields = ("agent_id", "tunnel_id", "expires_at")
required_fields = ("agent_id", "tunnel_id", "expires_at", "port")
missing = [field for field in required_fields if field not in token or token[field] in (None, "")]
if missing:
raise ValueError(f"Invalid orchestration token; missing {', '.join(missing)}")
@@ -167,6 +167,13 @@ class WireGuardServerManager:
except Exception:
raise ValueError("Invalid orchestration token expiry")
try:
port = int(token["port"])
except Exception:
raise ValueError("Invalid orchestration token port")
if port != int(self.config.port):
raise ValueError("Orchestration token port mismatch")
now = time.time()
if expires_at <= now:
raise ValueError("Orchestration token expired")
@@ -253,12 +260,14 @@ class WireGuardServerManager:
"host_only": True,
}
def apply_firewall_rules(self, peer: Mapping[str, object]) -> None:
def apply_firewall_rules(self, peer: Mapping[str, object]) -> List[str]:
"""Apply outbound firewall allow rules for the agent's virtual IP/ports (Windows netsh)."""
rules = self.build_firewall_rules(peer)
rule_names: List[str] = []
for idx, rule in enumerate(rules):
name = f"Borealis-WG-Agent-{peer.get('agent_id','')}-{idx}"
protocol = str(rule.get("protocol") or "TCP").upper()
args = [
"netsh",
"advfirewall",
@@ -269,7 +278,7 @@ class WireGuardServerManager:
"dir=out",
"action=allow",
f"remoteip={rule.get('remote_address','')}",
f"protocol=TCP",
f"protocol={protocol}",
f"localport={rule.get('local_port','')}",
]
code, out, err = self._run_command(args)
@@ -277,6 +286,19 @@ class WireGuardServerManager:
self.logger.warning("Failed to apply firewall rule %s code=%s err=%s", name, code, err)
else:
self.logger.info("Applied firewall rule %s", name)
rule_names.append(name)
return rule_names
def remove_firewall_rules(self, rule_names: Sequence[str]) -> None:
for name in rule_names:
if not name:
continue
args = ["netsh", "advfirewall", "firewall", "delete", "rule", f"name={name}"]
code, out, err = self._run_command(args)
if code != 0:
self.logger.warning("Failed to remove firewall rule %s code=%s err=%s", name, code, err)
else:
self.logger.info("Removed firewall rule %s", name)
def start_listener(self, peers: Sequence[Mapping[str, object]]) -> None:
"""Render a temporary WireGuard config and start the service."""
@@ -291,6 +313,9 @@ class WireGuardServerManager:
config_path.write_text(rendered, encoding="utf-8")
self.logger.info("Rendered WireGuard config to %s", config_path)
# Ensure old service is removed before re-installing.
self.stop_listener()
args = ["wireguard.exe", "/installtunnelservice", str(config_path)]
code, out, err = self._run_command(args)
if code != 0:
@@ -301,7 +326,7 @@ class WireGuardServerManager:
def stop_listener(self) -> None:
"""Stop and remove the WireGuard tunnel service."""
args = ["wireguard.exe", "/uninstalltunnelservice", "borealis-wg"]
args = ["wireguard.exe", "/uninstalltunnelservice", self._service_name]
code, out, err = self._run_command(args)
if code != 0:
self.logger.warning("Failed to uninstall WireGuard tunnel service code=%s err=%s", code, err)
@@ -323,15 +348,17 @@ class WireGuardServerManager:
port_list = []
for port in port_list:
rules.append(
{
"direction": "outbound",
"remote_address": ip,
"local_port": port,
"action": "allow",
"description": f"WireGuard engine->agent allow port {port}",
}
)
for protocol in ("TCP", "UDP"):
rules.append(
{
"direction": "outbound",
"remote_address": ip,
"local_port": port,
"protocol": protocol,
"action": "allow",
"description": f"WireGuard engine->agent allow port {port}/{protocol}",
}
)
self.logger.info(
"Prepared firewall rule plan for agent=%s rules=%s",

View File

@@ -1,3 +0,0 @@
"""Namespace package for reverse tunnel domain handlers (Engine side)."""
__all__ = ["remote_interactive_shell", "remote_management", "remote_video"]

View File

@@ -1,78 +0,0 @@
"""Placeholder Bash channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class BashChannelServer:
"""Stub Bash handler until the agent-side channel is implemented."""
protocol_name = "bash"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"bash.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
# No-op placeholder; output collection for future Bash support.
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"bash channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
# Placeholder: no agent-side Bash yet.
self.logger.info("bash placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
# Placeholder: not implemented.
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["BashChannelServer"]

View File

@@ -1,139 +0,0 @@
"""Engine-side PowerShell tunnel channel helper (remote interactive shell domain)."""
from __future__ import annotations
import json
from collections import deque
from typing import Any, Deque, Dict, List, Optional
# Mirror framing constants to avoid circular imports.
MSG_CHANNEL_OPEN = 0x03
MSG_CHANNEL_ACK = 0x04
MSG_DATA = 0x05
MSG_CONTROL = 0x09
MSG_CLOSE = 0x08
CLOSE_OK = 0
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AGENT_SHUTDOWN = 6
class PowershellChannelServer:
"""Coordinate PowerShell channel frames over a TunnelBridge."""
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"ps.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output: Deque[str] = deque()
self._close_reason: Optional[str] = None
self._close_code: Optional[int] = None
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
# ------------------------------------------------------------------ Agent frame handling
def handle_agent_frame(self, frame) -> None:
if frame.channel_id != self.channel_id:
return
if frame.msg_type == MSG_CHANNEL_ACK:
self._ack_received = True
self.logger.info("ps channel acked tunnel_id=%s", self.bridge.lease.tunnel_id)
return
if frame.msg_type == MSG_DATA:
try:
text = frame.payload.decode("utf-8", errors="replace")
except Exception:
text = ""
if text:
self._append_output(text)
return
if frame.msg_type == MSG_CLOSE:
try:
payload = json.loads(frame.payload.decode("utf-8"))
except Exception:
payload = {}
self._closed = True
self._close_code = payload.get("code") if isinstance(payload, dict) else None
self._close_reason = payload.get("reason") if isinstance(payload, dict) else None
self.logger.info(
"ps channel closed tunnel_id=%s code=%s reason=%s",
self.bridge.lease.tunnel_id,
self._close_code,
self._close_reason or "-",
)
# ------------------------------------------------------------------ Operator actions
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
if self._open_sent:
return
payload = json.dumps(
{"protocol": "ps", "metadata": {"cols": cols, "rows": rows}},
separators=(",", ":"),
).encode("utf-8")
frame = self._frame_cls(msg_type=MSG_CHANNEL_OPEN, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
self._open_sent = True
self.logger.info(
"ps channel open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
if self._closed:
return
payload = data.encode("utf-8", errors="replace")
frame = self._frame_cls(msg_type=MSG_DATA, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
def send_resize(self, cols: int, rows: int) -> None:
if self._closed:
return
payload = json.dumps({"cols": cols, "rows": rows}, separators=(",", ":")).encode("utf-8")
frame = self._frame_cls(msg_type=MSG_CONTROL, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
def close(self, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
frame = self._close_frame_fn(self.channel_id, code, reason)
else:
frame = self._frame_cls(
msg_type=MSG_CLOSE,
channel_id=self.channel_id,
payload=json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8"),
)
self.bridge.operator_to_agent(frame)
# ------------------------------------------------------------------ Output polling
def drain_output(self) -> List[str]:
items: List[str] = []
while self._output:
items.append(self._output.popleft())
return items
def _append_output(self, text: str) -> None:
self._output.append(text)
# Cap buffer to avoid unbounded memory growth.
while len(self._output) > 500:
self._output.popleft()
# ------------------------------------------------------------------ Status helpers
def status(self) -> Dict[str, Any]:
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": self._close_reason,
"close_code": self._close_code,
}
__all__ = ["PowershellChannelServer"]

View File

@@ -1,6 +0,0 @@
"""Protocol handlers for remote interactive shell tunnels (Engine side)."""
from .Powershell import PowershellChannelServer
from .Bash import BashChannelServer
__all__ = ["PowershellChannelServer", "BashChannelServer"]

View File

@@ -1,3 +0,0 @@
"""Domain handlers for remote interactive shells (PowerShell/Bash)."""
__all__ = ["Protocols"]

View File

@@ -1,73 +0,0 @@
"""Placeholder SSH channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class SSHChannelServer:
protocol_name = "ssh"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"ssh.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"ssh channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("ssh placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["SSHChannelServer"]

View File

@@ -1,73 +0,0 @@
"""Placeholder WinRM channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class WinRMChannelServer:
protocol_name = "winrm"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"winrm.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"winrm channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("winrm placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["WinRMChannelServer"]

View File

@@ -1,6 +0,0 @@
"""Protocol handlers for remote management tunnels (Engine side)."""
from .SSH import SSHChannelServer
from .WinRM import WinRMChannelServer
__all__ = ["SSHChannelServer", "WinRMChannelServer"]

View File

@@ -1,3 +0,0 @@
"""Domain handlers for remote management tunnels (SSH/WinRM)."""
__all__ = ["Protocols"]

View File

@@ -1,73 +0,0 @@
"""Placeholder RDP channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class RDPChannelServer:
protocol_name = "rdp"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"rdp.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"rdp channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("rdp placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["RDPChannelServer"]

View File

@@ -1,73 +0,0 @@
"""Placeholder VNC channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class VNCChannelServer:
protocol_name = "vnc"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"vnc.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"vnc channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("vnc placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["VNCChannelServer"]

View File

@@ -1,73 +0,0 @@
"""Placeholder WebRTC channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class WebRTCChannelServer:
protocol_name = "webrtc"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"webrtc.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"webrtc channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("webrtc placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["WebRTCChannelServer"]

View File

@@ -1,7 +0,0 @@
"""Protocol handlers for remote video tunnels (Engine side)."""
from .WebRTC import WebRTCChannelServer
from .RDP import RDPChannelServer
from .VNC import VNCChannelServer
__all__ = ["WebRTCChannelServer", "RDPChannelServer", "VNCChannelServer"]

View File

@@ -1,3 +0,0 @@
"""Domain handlers for remote video/desktop tunnels (RDP/VNC/WebRTC)."""
__all__ = ["Protocols"]

View File

@@ -1,10 +0,0 @@
# ======================================================
# Data\Engine\services\WebSocket\Agent\__init__.py
# Description: Package marker for Agent-facing WebSocket services (reverse tunnel scaffolding).
#
# API Endpoints (if applicable): None
# ======================================================
"""Agent-facing WebSocket services for the Engine runtime."""
__all__ = []

View File

@@ -1,6 +1,6 @@
# ======================================================
# Data\Engine\services\WebSocket\__init__.py
# Description: Socket.IO handlers for Engine runtime quick job updates and realtime notifications.
# Description: Socket.IO handlers for Engine runtime quick job updates and VPN shell bridging.
#
# API Endpoints (if applicable): None
# ======================================================
@@ -8,24 +8,20 @@
"""WebSocket service registration for the Borealis Engine runtime."""
from __future__ import annotations
import base64
import sqlite3
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Optional
from flask import session, request
from flask import request
from flask_socketio import SocketIO
from ...database import initialise_engine_database
from ...security import signing
from ...server import EngineContext
from .Agent.reverse_tunnel_orchestrator import (
ReverseTunnelService,
TunnelBridge,
decode_frame,
TunnelFrame,
)
from ..VPN import VpnTunnelService
from .vpn_shell import VpnShellBridge
def _now_ts() -> int:
@@ -70,20 +66,31 @@ class EngineRealtimeAdapters:
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
"""Register Socket.IO event handlers for the Engine runtime."""
from ..API import _make_db_conn_factory, _make_service_logger # Local import to avoid circular import at module load
adapters = EngineRealtimeAdapters(context)
logger = context.logger.getChild("realtime.quick_jobs")
tunnel_service = getattr(context, "reverse_tunnel_service", None)
if tunnel_service is None:
tunnel_service = ReverseTunnelService(
context,
signer=None,
shell_bridge = VpnShellBridge(socket_server, context)
def _get_tunnel_service() -> Optional[VpnTunnelService]:
service = getattr(context, "vpn_tunnel_service", None)
if service is not None:
return service
manager = getattr(context, "wireguard_server_manager", None)
if manager is None:
return None
try:
signer = signing.load_signer()
except Exception:
signer = None
service = VpnTunnelService(
context=context,
wireguard_manager=manager,
db_conn_factory=adapters.db_conn_factory,
socketio=socket_server,
service_log=adapters.service_log,
signer=signer,
)
tunnel_service.start()
setattr(context, "reverse_tunnel_service", tunnel_service)
setattr(context, "vpn_tunnel_service", service)
return service
@socket_server.on("quick_job_result")
def _handle_quick_job_result(data: Any) -> None:
@@ -246,252 +253,45 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
exc,
)
@socket_server.on("tunnel_bridge_attach")
def _tunnel_bridge_attach(data: Any) -> Any:
"""Placeholder operator bridge attach handler (no data channel yet)."""
@socket_server.on("vpn_shell_open")
def _vpn_shell_open(data: Any) -> Dict[str, Any]:
agent_id = ""
if isinstance(data, dict):
agent_id = str(data.get("agent_id") or "").strip()
elif isinstance(data, str):
agent_id = data.strip()
if not agent_id:
return {"error": "agent_id_required"}
if not isinstance(data, dict):
return {"error": "invalid_payload"}
service = _get_tunnel_service()
if service is None:
return {"error": "vpn_service_unavailable"}
if not service.status(agent_id):
return {"error": "tunnel_down"}
tunnel_id = str(data.get("tunnel_id") or "").strip()
operator_id = str(data.get("operator_id") or "").strip() or None
if not tunnel_id:
return {"error": "tunnel_id_required"}
session = shell_bridge.open_session(request.sid, agent_id)
if session is None:
return {"error": "shell_connect_failed"}
service.bump_activity(agent_id)
return {"status": "ok"}
try:
tunnel_service.operator_attach(tunnel_id, operator_id)
except ValueError as exc:
return {"error": str(exc)}
except Exception as exc: # pragma: no cover - defensive guard
logger.debug("tunnel_bridge_attach failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "bridge_attach_failed"}
return {"status": "ok", "tunnel_id": tunnel_id, "operator_id": operator_id or "-"}
def _encode_frame(frame: TunnelFrame) -> str:
return base64.b64encode(frame.encode()).decode("ascii")
def _decode_frame_payload(raw: Any) -> TunnelFrame:
if isinstance(raw, str):
try:
raw_bytes = base64.b64decode(raw)
except Exception:
raise ValueError("invalid_frame")
elif isinstance(raw, (bytes, bytearray)):
raw_bytes = bytes(raw)
@socket_server.on("vpn_shell_send")
def _vpn_shell_send(data: Any) -> Dict[str, Any]:
payload = None
if isinstance(data, dict):
payload = data.get("data")
else:
raise ValueError("invalid_frame")
return decode_frame(raw_bytes)
@socket_server.on("tunnel_operator_send")
def _tunnel_operator_send(data: Any) -> Any:
"""Operator -> agent frame enqueue (placeholder queue)."""
if not isinstance(data, dict):
return {"error": "invalid_payload"}
tunnel_id = str(data.get("tunnel_id") or "").strip()
frame_raw = data.get("frame")
if not tunnel_id or frame_raw is None:
return {"error": "tunnel_id_and_frame_required"}
try:
frame = _decode_frame_payload(frame_raw)
except Exception as exc:
return {"error": str(exc)}
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
bridge.operator_to_agent(frame)
return {"status": "ok"}
@socket_server.on("tunnel_operator_poll")
def _tunnel_operator_poll(data: Any) -> Any:
"""Operator polls queued frames from agent."""
tunnel_id = ""
if isinstance(data, dict):
tunnel_id = str(data.get("tunnel_id") or "").strip()
if not tunnel_id:
return {"error": "tunnel_id_required"}
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
frames = []
while True:
frame = bridge.next_for_operator()
if frame is None:
break
frames.append(_encode_frame(frame))
return {"frames": frames}
# WebUI operator bridge namespace for browser clients
tunnel_namespace = "/tunnel"
_operator_sessions: Dict[str, str] = {}
def _current_operator() -> Optional[str]:
username = session.get("username")
if username:
return str(username)
auth_header = (request.headers.get("Authorization") or "").strip()
token = None
if auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1].strip()
if not token:
token = request.cookies.get("borealis_auth")
return token or None
@socket_server.on("join", namespace=tunnel_namespace)
def _ws_tunnel_join(data: Any) -> Any:
if not isinstance(data, dict):
return {"error": "invalid_payload"}
operator_id = _current_operator()
if not operator_id:
return {"error": "unauthorized"}
tunnel_id = str(data.get("tunnel_id") or "").strip()
if not tunnel_id:
return {"error": "tunnel_id_required"}
bridge = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
try:
tunnel_service.operator_attach(tunnel_id, operator_id)
except Exception as exc:
logger.debug("ws_tunnel_join failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "attach_failed"}
sid = request.sid
_operator_sessions[sid] = tunnel_id
return {"status": "ok", "tunnel_id": tunnel_id}
@socket_server.on("send", namespace=tunnel_namespace)
def _ws_tunnel_send(data: Any) -> Any:
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return {"error": "not_joined"}
if not isinstance(data, dict):
return {"error": "invalid_payload"}
frame_raw = data.get("frame")
if frame_raw is None:
return {"error": "frame_required"}
try:
frame = _decode_frame_payload(frame_raw)
except Exception:
return {"error": "invalid_frame"}
bridge = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
bridge.operator_to_agent(frame)
return {"status": "ok"}
@socket_server.on("poll", namespace=tunnel_namespace)
def _ws_tunnel_poll() -> Any:
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return {"error": "not_joined"}
bridge = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
frames = []
while True:
frame = bridge.next_for_operator()
if frame is None:
break
frames.append(_encode_frame(frame))
return {"frames": frames}
def _require_ps_server():
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return None, None, {"error": "not_joined"}
server = tunnel_service.ensure_protocol_server(tunnel_id)
if server is None or not hasattr(server, "open_channel"):
return None, tunnel_id, {"error": "ps_unsupported"}
return server, tunnel_id, None
@socket_server.on("ps_open", namespace=tunnel_namespace)
def _ws_ps_open(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
cols = 120
rows = 32
if isinstance(data, dict):
try:
cols = int(data.get("cols", cols))
rows = int(data.get("rows", rows))
except Exception:
pass
cols = max(20, min(cols, 300))
rows = max(10, min(rows, 200))
try:
server.open_channel(cols=cols, rows=rows)
except Exception as exc:
logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_open_failed"}
return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows}
@socket_server.on("ps_send", namespace=tunnel_namespace)
def _ws_ps_send(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
if data is None:
payload = data
if payload is None:
return {"error": "payload_required"}
text = data
if isinstance(data, dict):
text = data.get("data")
if text is None:
return {"error": "payload_required"}
try:
server.send_input(str(text))
except Exception as exc:
logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_send_failed"}
shell_bridge.send(request.sid, str(payload))
return {"status": "ok"}
@socket_server.on("ps_resize", namespace=tunnel_namespace)
def _ws_ps_resize(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
cols = None
rows = None
if isinstance(data, dict):
cols = data.get("cols")
rows = data.get("rows")
try:
cols_int = int(cols) if cols is not None else 120
rows_int = int(rows) if rows is not None else 32
cols_int = max(20, min(cols_int, 300))
rows_int = max(10, min(rows_int, 200))
server.send_resize(cols_int, rows_int)
return {"status": "ok", "cols": cols_int, "rows": rows_int}
except Exception as exc:
logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_resize_failed"}
@socket_server.on("vpn_shell_close")
def _vpn_shell_close() -> Dict[str, Any]:
shell_bridge.close(request.sid)
return {"status": "ok"}
@socket_server.on("ps_poll", namespace=tunnel_namespace)
def _ws_ps_poll(data: Any = None) -> Any: # data is ignored; socketio passes it even when unused
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
try:
output = server.drain_output()
status = server.status()
return {"output": output, "status": status}
except Exception as exc:
logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_poll_failed"}
@socket_server.on("disconnect", namespace=tunnel_namespace)
def _ws_tunnel_disconnect():
sid = request.sid
tunnel_id = _operator_sessions.pop(sid, None)
if tunnel_id and tunnel_id not in _operator_sessions.values():
try:
tunnel_service.stop_tunnel(tunnel_id, reason="operator_socket_disconnect")
except Exception as exc:
logger.debug("ws_tunnel_disconnect stop_tunnel failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
@socket_server.on("disconnect")
def _ws_disconnect() -> None:
shell_bridge.close(request.sid)

View File

@@ -0,0 +1,127 @@
# ======================================================
# Data\Engine\services\WebSocket\vpn_shell.py
# Description: Socket.IO handlers bridging UI shell to agent TCP server over WireGuard.
#
# API Endpoints (if applicable): None
# ======================================================
"""WireGuard VPN PowerShell bridge (Engine side)."""
from __future__ import annotations
import base64
import json
import socket
import threading
from dataclasses import dataclass
from typing import Any, Dict, Optional
def _b64encode(data: bytes) -> str:
return base64.b64encode(data).decode("ascii").strip()
def _b64decode(value: str) -> bytes:
return base64.b64decode(value.encode("ascii"))
@dataclass
class ShellSession:
sid: str
agent_id: str
socketio: Any
tcp: socket.socket
_reader: Optional[threading.Thread] = None
def start_reader(self) -> None:
t = threading.Thread(target=self._read_loop, daemon=True)
t.start()
self._reader = t
def _read_loop(self) -> None:
buffer = b""
try:
while True:
data = self.tcp.recv(4096)
if not data:
break
buffer += data
while b"\n" in buffer:
line, buffer = buffer.split(b"\n", 1)
if not line:
continue
try:
msg = json.loads(line.decode("utf-8"))
except Exception:
continue
if msg.get("type") == "stdout":
payload = msg.get("data") or ""
try:
decoded = _b64decode(str(payload)).decode("utf-8", errors="replace")
except Exception:
decoded = ""
self.socketio.emit("vpn_shell_output", {"data": decoded}, to=self.sid)
finally:
self.socketio.emit("vpn_shell_closed", {"agent_id": self.agent_id}, to=self.sid)
try:
self.tcp.close()
except Exception:
pass
def send(self, payload: str) -> None:
data = json.dumps({"type": "stdin", "data": _b64encode(payload.encode("utf-8"))})
self.tcp.sendall(data.encode("utf-8") + b"\n")
def close(self) -> None:
try:
data = json.dumps({"type": "close"})
self.tcp.sendall(data.encode("utf-8") + b"\n")
except Exception:
pass
try:
self.tcp.close()
except Exception:
pass
class VpnShellBridge:
def __init__(self, socketio, context) -> None:
self.socketio = socketio
self.context = context
self._sessions: Dict[str, ShellSession] = {}
self.logger = context.logger.getChild("vpn_shell")
def open_session(self, sid: str, agent_id: str) -> Optional[ShellSession]:
service = getattr(self.context, "vpn_tunnel_service", None)
if service is None:
return None
status = service.status(agent_id)
if not status:
return None
host = str(status.get("virtual_ip") or "").split("/")[0]
port = int(self.context.wireguard_shell_port)
try:
tcp = socket.create_connection((host, port), timeout=5)
except Exception:
self.logger.debug("Failed to connect vpn shell to %s:%s", host, port, exc_info=True)
return None
session = ShellSession(sid=sid, agent_id=agent_id, socketio=self.socketio, tcp=tcp)
self._sessions[sid] = session
session.start_reader()
return session
def send(self, sid: str, payload: str) -> None:
session = self._sessions.get(sid)
if not session:
return
session.send(payload)
service = getattr(self.context, "vpn_tunnel_service", None)
if service:
service.bump_activity(session.agent_id)
def close(self, sid: str) -> None:
session = self._sessions.pop(sid, None)
if not session:
return
session.close()