mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-12-19 00:35:48 -07:00
Overhaul of VPN Codebase
This commit is contained in:
@@ -9,6 +9,8 @@
|
||||
# - GET /api/devices/<guid> (Token Authenticated) - Retrieves a single device record by GUID, including summary fields.
|
||||
# - GET /api/device/details/<hostname> (Token Authenticated) - Returns full device details keyed by hostname.
|
||||
# - POST /api/device/description/<hostname> (Token Authenticated) - Updates the human-readable description for a device.
|
||||
# - GET /api/device/vpn_config/<agent_id> (Token Authenticated) - Returns per-device VPN allowed port settings.
|
||||
# - PUT /api/device/vpn_config/<agent_id> (Token Authenticated) - Updates per-device VPN allowed port settings.
|
||||
# - GET /api/device_list_views (Token Authenticated) - Lists saved device table view definitions.
|
||||
# - GET /api/device_list_views/<int:view_id> (Token Authenticated) - Retrieves a specific saved device table view definition.
|
||||
# - POST /api/device_list_views (Token Authenticated) - Creates a custom device list view for the signed-in operator.
|
||||
@@ -426,6 +428,131 @@ class DeviceManagementService:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _parse_ports(self, raw: Any) -> List[int]:
|
||||
ports: List[int] = []
|
||||
if isinstance(raw, str):
|
||||
parts = [part.strip() for part in raw.split(",") if part.strip()]
|
||||
elif isinstance(raw, list):
|
||||
parts = raw
|
||||
else:
|
||||
parts = []
|
||||
for part in parts:
|
||||
try:
|
||||
value = int(part)
|
||||
except Exception:
|
||||
continue
|
||||
if 1 <= value <= 65535:
|
||||
ports.append(value)
|
||||
return list(dict.fromkeys(ports))
|
||||
|
||||
def _default_vpn_ports(self, os_name: Optional[str]) -> List[int]:
|
||||
ports = list(self.adapters.context.wireguard_acl_allowlist_windows or [])
|
||||
os_text = (os_name or "").strip().lower()
|
||||
if os_text and "windows" not in os_text:
|
||||
baseline = {5900, 3478}
|
||||
filtered = [p for p in ports if p in baseline]
|
||||
return filtered or ports
|
||||
return ports
|
||||
|
||||
def get_vpn_config(self, agent_id: str) -> Tuple[Dict[str, Any], int]:
|
||||
agent_id = (agent_id or "").strip()
|
||||
if not agent_id:
|
||||
return {"error": "agent_id_required"}, 400
|
||||
default_ports: List[int] = []
|
||||
shell_port = int(self.adapters.context.wireguard_shell_port)
|
||||
try:
|
||||
conn = self._db_conn()
|
||||
cur = conn.cursor()
|
||||
os_name = ""
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
|
||||
(agent_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row and row[0]:
|
||||
os_name = str(row[0])
|
||||
except Exception:
|
||||
os_name = ""
|
||||
default_ports = self._default_vpn_ports(os_name)
|
||||
cur.execute(
|
||||
"SELECT allowed_ports, updated_at, updated_by FROM device_vpn_config WHERE agent_id=?",
|
||||
(agent_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"allowed_ports": default_ports,
|
||||
"default_ports": default_ports,
|
||||
"shell_port": shell_port,
|
||||
"source": "default",
|
||||
}, 200
|
||||
raw_ports = row[0] or ""
|
||||
ports = []
|
||||
try:
|
||||
ports = json.loads(raw_ports) if raw_ports else []
|
||||
except Exception:
|
||||
ports = []
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"allowed_ports": ports or default_ports,
|
||||
"default_ports": default_ports,
|
||||
"shell_port": shell_port,
|
||||
"updated_at": row[1],
|
||||
"updated_by": row[2],
|
||||
"source": "custom" if ports else "default",
|
||||
}, 200
|
||||
except Exception as exc:
|
||||
self.logger.debug("Failed to load vpn config", exc_info=True)
|
||||
return {"error": "internal_error"}, 500
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def set_vpn_config(self, agent_id: str, payload: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
|
||||
agent_id = (agent_id or "").strip()
|
||||
if not agent_id:
|
||||
return {"error": "agent_id_required"}, 400
|
||||
ports = self._parse_ports(payload.get("allowed_ports"))
|
||||
if not ports:
|
||||
return {"error": "allowed_ports_required"}, 400
|
||||
user = self._current_user() or {}
|
||||
updated_by = user.get("username") or ""
|
||||
updated_at = datetime.now(timezone.utc).isoformat()
|
||||
try:
|
||||
conn = self._db_conn()
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO device_vpn_config(agent_id, allowed_ports, updated_at, updated_by)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(agent_id) DO UPDATE SET
|
||||
allowed_ports=excluded.allowed_ports,
|
||||
updated_at=excluded.updated_at,
|
||||
updated_by=excluded.updated_by
|
||||
""",
|
||||
(agent_id, json.dumps(ports), updated_at, updated_by),
|
||||
)
|
||||
conn.commit()
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"allowed_ports": ports,
|
||||
"updated_at": updated_at,
|
||||
"updated_by": updated_by,
|
||||
"source": "custom",
|
||||
}, 200
|
||||
except Exception:
|
||||
self.logger.debug("Failed to save vpn config", exc_info=True)
|
||||
return {"error": "internal_error"}, 500
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _require_login(self) -> Optional[Tuple[Dict[str, Any], int]]:
|
||||
if not self._current_user():
|
||||
return {"error": "unauthorized"}, 401
|
||||
@@ -1793,6 +1920,19 @@ def register_management(app, adapters: "EngineServiceAdapters") -> None:
|
||||
payload, status = service.set_device_description(hostname, description)
|
||||
return jsonify(payload), status
|
||||
|
||||
@blueprint.route("/api/device/vpn_config/<agent_id>", methods=["GET", "PUT"])
|
||||
def _vpn_config(agent_id: str):
|
||||
requirement = service._require_login()
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
if request.method == "GET":
|
||||
payload, status = service.get_vpn_config(agent_id)
|
||||
else:
|
||||
body = request.get_json(silent=True) or {}
|
||||
payload, status = service.set_vpn_config(agent_id, body)
|
||||
return jsonify(payload), status
|
||||
|
||||
@blueprint.route("/api/device_list_views", methods=["GET"])
|
||||
def _list_views():
|
||||
requirement = service._require_login()
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\API\devices\tunnel.py
|
||||
# Description: Negotiation endpoint for reverse tunnel leases (operator-initiated; dormant until tunnel listener is wired).
|
||||
# Description: WireGuard VPN tunnel API (connect/status/disconnect).
|
||||
#
|
||||
# API Endpoints (if applicable):
|
||||
# - POST /api/tunnel/request (Token Authenticated) - Allocates a reverse tunnel lease for the requested agent/protocol.
|
||||
# - POST /api/tunnel/connect (Token Authenticated) - Issues VPN session material for an agent.
|
||||
# - GET /api/tunnel/status (Token Authenticated) - Returns VPN status for an agent.
|
||||
# - DELETE /api/tunnel/disconnect (Token Authenticated) - Tears down VPN session for an agent.
|
||||
# ======================================================
|
||||
|
||||
"""Reverse tunnel negotiation API (Engine side)."""
|
||||
"""WireGuard VPN tunnel API (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
@@ -15,15 +17,13 @@ from typing import Any, Dict, Optional, Tuple
|
||||
from flask import Blueprint, jsonify, request, session
|
||||
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
|
||||
|
||||
from ...WebSocket.Agent.reverse_tunnel_orchestrator import ReverseTunnelService
|
||||
from ...VPN import VpnTunnelService
|
||||
|
||||
if False: # pragma: no cover - import cycle hint for type checkers
|
||||
from .. import EngineServiceAdapters
|
||||
|
||||
|
||||
def _current_user(app) -> Optional[Dict[str, str]]:
|
||||
"""Resolve operator identity from session or signed token."""
|
||||
|
||||
username = session.get("username")
|
||||
role = session.get("role") or "User"
|
||||
if username:
|
||||
@@ -58,18 +58,22 @@ def _require_login(app) -> Optional[Tuple[Dict[str, Any], int]]:
|
||||
return None
|
||||
|
||||
|
||||
def _get_tunnel_service(adapters: "EngineServiceAdapters") -> ReverseTunnelService:
|
||||
service = getattr(adapters.context, "reverse_tunnel_service", None) or getattr(adapters, "_reverse_tunnel_service", None)
|
||||
def _get_tunnel_service(adapters: "EngineServiceAdapters") -> VpnTunnelService:
|
||||
service = getattr(adapters.context, "vpn_tunnel_service", None) or getattr(adapters, "_vpn_tunnel_service", None)
|
||||
if service is None:
|
||||
service = ReverseTunnelService(
|
||||
adapters.context,
|
||||
signer=getattr(adapters, "script_signer", None),
|
||||
manager = getattr(adapters.context, "wireguard_server_manager", None)
|
||||
if manager is None:
|
||||
raise RuntimeError("wireguard_manager_unavailable")
|
||||
service = VpnTunnelService(
|
||||
context=adapters.context,
|
||||
wireguard_manager=manager,
|
||||
db_conn_factory=adapters.db_conn_factory,
|
||||
socketio=getattr(adapters.context, "socketio", None),
|
||||
service_log=adapters.service_log,
|
||||
signer=getattr(adapters, "script_signer", None),
|
||||
)
|
||||
service.start()
|
||||
setattr(adapters, "_reverse_tunnel_service", service)
|
||||
setattr(adapters.context, "reverse_tunnel_service", service)
|
||||
setattr(adapters, "_vpn_tunnel_service", service)
|
||||
setattr(adapters.context, "vpn_tunnel_service", service)
|
||||
return service
|
||||
|
||||
|
||||
@@ -83,14 +87,11 @@ def _normalize_text(value: Any) -> str:
|
||||
|
||||
|
||||
def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
|
||||
"""Register reverse tunnel negotiation endpoints."""
|
||||
blueprint = Blueprint("vpn_tunnel", __name__)
|
||||
logger = adapters.context.logger.getChild("vpn_tunnel.api")
|
||||
|
||||
blueprint = Blueprint("reverse_tunnel", __name__)
|
||||
service_log = adapters.service_log
|
||||
logger = adapters.context.logger.getChild("tunnel.api")
|
||||
|
||||
@blueprint.route("/api/tunnel/request", methods=["POST"])
|
||||
def request_tunnel():
|
||||
@blueprint.route("/api/tunnel/connect", methods=["POST"])
|
||||
def connect_tunnel():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
@@ -101,69 +102,67 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
agent_id = _normalize_text(body.get("agent_id"))
|
||||
protocol = _normalize_text(body.get("protocol") or "ps").lower() or "ps"
|
||||
domain = _normalize_text(body.get("domain") or protocol).lower() or protocol
|
||||
if protocol == "ps" and domain == "ps":
|
||||
domain = "remote-interactive-shell"
|
||||
|
||||
if not agent_id:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
try:
|
||||
lease = tunnel_service.request_lease(
|
||||
agent_id=agent_id,
|
||||
protocol=protocol,
|
||||
domain=domain,
|
||||
operator_id=operator_id,
|
||||
)
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
payload = tunnel_service.connect(agent_id=agent_id, operator_id=operator_id)
|
||||
except RuntimeError as exc:
|
||||
message = str(exc)
|
||||
if message.startswith("domain_limit:"):
|
||||
domain_name = message.split(":", 1)[-1] if ":" in message else domain
|
||||
return jsonify({"error": "domain_limit", "domain": domain_name}), 409
|
||||
if message == "port_pool_exhausted":
|
||||
return jsonify({"error": "port_pool_exhausted"}), 503
|
||||
logger.warning("tunnel lease request failed for agent_id=%s: %s", agent_id, message)
|
||||
return jsonify({"error": "lease_allocation_failed"}), 500
|
||||
logger.warning("vpn connect failed for agent_id=%s: %s", agent_id, exc)
|
||||
return jsonify({"error": "connect_failed"}), 500
|
||||
|
||||
summary = tunnel_service.lease_summary(lease)
|
||||
summary["fixed_port"] = tunnel_service.fixed_port
|
||||
summary["heartbeat_seconds"] = tunnel_service.heartbeat_seconds
|
||||
return jsonify(payload), 200
|
||||
|
||||
service_log(
|
||||
"reverse_tunnel",
|
||||
f"lease created tunnel_id={lease.tunnel_id} agent_id={lease.agent_id} domain={lease.domain} protocol={lease.protocol}",
|
||||
)
|
||||
return jsonify(summary), 200
|
||||
|
||||
@blueprint.route("/api/tunnel/<tunnel_id>", methods=["DELETE"])
|
||||
def stop_tunnel(tunnel_id: str):
|
||||
@blueprint.route("/api/tunnel/status", methods=["GET"])
|
||||
def tunnel_status():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
|
||||
tunnel_id_norm = _normalize_text(tunnel_id)
|
||||
if not tunnel_id_norm:
|
||||
return jsonify({"error": "tunnel_id_required"}), 400
|
||||
agent_id = _normalize_text(request.args.get("agent_id") or "")
|
||||
if not agent_id:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
payload = tunnel_service.status(agent_id)
|
||||
if not payload:
|
||||
return jsonify({"status": "down", "agent_id": agent_id}), 200
|
||||
payload["status"] = "up"
|
||||
bump = _normalize_text(request.args.get("bump") or "")
|
||||
if bump:
|
||||
tunnel_service.bump_activity(agent_id)
|
||||
return jsonify(payload), 200
|
||||
|
||||
@blueprint.route("/api/tunnel/connect/status", methods=["GET"])
|
||||
def tunnel_connect_status():
|
||||
return tunnel_status()
|
||||
|
||||
@blueprint.route("/api/tunnel/disconnect", methods=["DELETE"])
|
||||
def disconnect_tunnel():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
agent_id = _normalize_text(body.get("agent_id"))
|
||||
tunnel_id = _normalize_text(body.get("tunnel_id"))
|
||||
reason = _normalize_text(body.get("reason") or "operator_stop")
|
||||
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
stopped = False
|
||||
try:
|
||||
stopped = tunnel_service.stop_tunnel(tunnel_id_norm, reason=reason)
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("stop_tunnel failed tunnel_id=%s: %s", tunnel_id_norm, exc, exc_info=True)
|
||||
if tunnel_id:
|
||||
stopped = tunnel_service.disconnect_by_tunnel(tunnel_id, reason=reason)
|
||||
elif agent_id:
|
||||
stopped = tunnel_service.disconnect(agent_id, reason=reason)
|
||||
else:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
|
||||
if not stopped:
|
||||
return jsonify({"error": "not_found"}), 404
|
||||
|
||||
service_log(
|
||||
"reverse_tunnel",
|
||||
f"lease stopped tunnel_id={tunnel_id_norm} reason={reason or '-'}",
|
||||
)
|
||||
return jsonify({"status": "stopped", "tunnel_id": tunnel_id_norm}), 200
|
||||
return jsonify({"status": "stopped", "reason": reason}), 200
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
@@ -8,4 +8,4 @@
|
||||
"""VPN service helpers for the Engine runtime."""
|
||||
|
||||
from .wireguard_server import WireGuardServerConfig, WireGuardServerManager # noqa: F401
|
||||
|
||||
from .vpn_tunnel_service import VpnTunnelService # noqa: F401
|
||||
|
||||
473
Data/Engine/services/VPN/vpn_tunnel_service.py
Normal file
473
Data/Engine/services/VPN/vpn_tunnel_service.py
Normal file
@@ -0,0 +1,473 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\VPN\vpn_tunnel_service.py
|
||||
# Description: WireGuard tunnel orchestration (single tunnel per agent, token issuance, idle handling).
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""WireGuard tunnel orchestration helpers for the Engine runtime."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import ipaddress
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
from .wireguard_server import WireGuardServerManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class VpnSession:
|
||||
tunnel_id: str
|
||||
agent_id: str
|
||||
virtual_ip: str
|
||||
token: Dict[str, Any]
|
||||
client_public_key: str
|
||||
client_private_key: str
|
||||
allowed_ports: Tuple[int, ...]
|
||||
created_at: float
|
||||
expires_at: float
|
||||
last_activity: float
|
||||
operator_ids: set[str] = field(default_factory=set)
|
||||
firewall_rules: List[str] = field(default_factory=list)
|
||||
activity_id: Optional[int] = None
|
||||
hostname: Optional[str] = None
|
||||
|
||||
|
||||
class VpnTunnelService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
context: Any,
|
||||
wireguard_manager: WireGuardServerManager,
|
||||
db_conn_factory,
|
||||
socketio,
|
||||
service_log,
|
||||
signer: Optional[Any] = None,
|
||||
idle_seconds: int = 900,
|
||||
) -> None:
|
||||
self.context = context
|
||||
self.wg = wireguard_manager
|
||||
self.db_conn_factory = db_conn_factory
|
||||
self.socketio = socketio
|
||||
self.service_log = service_log
|
||||
self.signer = signer
|
||||
self.logger = context.logger.getChild("vpn_tunnel")
|
||||
self.activity_logger = self.wg.logger.getChild("device_activity")
|
||||
self.idle_seconds = max(60, int(idle_seconds))
|
||||
self._lock = threading.Lock()
|
||||
self._sessions_by_agent: Dict[str, VpnSession] = {}
|
||||
self._sessions_by_tunnel: Dict[str, VpnSession] = {}
|
||||
self._engine_ip = ipaddress.ip_interface(context.wireguard_engine_virtual_ip)
|
||||
self._peer_network = ipaddress.ip_network(context.wireguard_peer_network, strict=False)
|
||||
self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True)
|
||||
self._idle_thread.start()
|
||||
|
||||
def _idle_loop(self) -> None:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
now = time.time()
|
||||
expired: List[VpnSession] = []
|
||||
with self._lock:
|
||||
for session in list(self._sessions_by_agent.values()):
|
||||
if session.last_activity + self.idle_seconds <= now:
|
||||
expired.append(session)
|
||||
for session in expired:
|
||||
self.disconnect(session.agent_id, reason="idle_timeout")
|
||||
|
||||
def _allocate_virtual_ip(self, agent_id: str) -> str:
|
||||
existing = self._sessions_by_agent.get(agent_id)
|
||||
if existing:
|
||||
return existing.virtual_ip
|
||||
|
||||
used = {s.virtual_ip for s in self._sessions_by_agent.values()}
|
||||
for host in self._peer_network.hosts():
|
||||
if host == self._engine_ip.ip:
|
||||
continue
|
||||
candidate = f"{host}/32"
|
||||
if candidate not in used:
|
||||
return candidate
|
||||
raise RuntimeError("vpn_ip_pool_exhausted")
|
||||
|
||||
def _load_allowed_ports(self, agent_id: str) -> Tuple[int, ...]:
|
||||
default = tuple(self.context.wireguard_acl_allowlist_windows or ())
|
||||
try:
|
||||
conn = self.db_conn_factory()
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
|
||||
(agent_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
os_name = str(row[0]).lower() if row and row[0] else ""
|
||||
except Exception:
|
||||
os_name = ""
|
||||
if os_name and "windows" not in os_name:
|
||||
baseline = {5900, 3478}
|
||||
filtered = [p for p in default if p in baseline]
|
||||
if filtered:
|
||||
default = tuple(filtered)
|
||||
cur.execute(
|
||||
"SELECT allowed_ports FROM device_vpn_config WHERE agent_id=?",
|
||||
(agent_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return default
|
||||
raw = row[0] or ""
|
||||
ports = json.loads(raw) if raw else []
|
||||
ports = [int(p) for p in ports if isinstance(p, (int, float, str))]
|
||||
ports = [p for p in ports if 1 <= p <= 65535]
|
||||
return tuple(dict.fromkeys(ports)) or default
|
||||
except Exception:
|
||||
return default
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _generate_client_keys(self) -> Tuple[str, str]:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import x25519
|
||||
|
||||
key = x25519.X25519PrivateKey.generate()
|
||||
priv = base64.b64encode(
|
||||
key.private_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PrivateFormat.Raw,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
).decode("ascii").strip()
|
||||
pub = base64.b64encode(
|
||||
key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw,
|
||||
)
|
||||
).decode("ascii").strip()
|
||||
return priv, pub
|
||||
|
||||
def _issue_token(self, agent_id: str, tunnel_id: str, expires_at: float) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"agent_id": agent_id,
|
||||
"tunnel_id": tunnel_id,
|
||||
"port": self.context.wireguard_port,
|
||||
"expires_at": expires_at,
|
||||
"issued_at": time.time(),
|
||||
}
|
||||
if not self.signer:
|
||||
return dict(payload)
|
||||
|
||||
token = dict(payload)
|
||||
try:
|
||||
payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
|
||||
signature = self.signer.sign(payload_bytes)
|
||||
token["signature"] = base64.b64encode(signature).decode("ascii")
|
||||
if hasattr(self.signer, "public_base64_spki"):
|
||||
token["signing_key"] = self.signer.public_base64_spki()
|
||||
token["sig_alg"] = "ed25519"
|
||||
except Exception:
|
||||
self.logger.debug("Failed to sign VPN orchestration token; sending unsigned.", exc_info=True)
|
||||
return token
|
||||
|
||||
def _refresh_listener(self) -> None:
|
||||
peers: List[Mapping[str, object]] = []
|
||||
for session in self._sessions_by_agent.values():
|
||||
peer = self.wg.build_peer_profile(
|
||||
session.agent_id,
|
||||
session.virtual_ip,
|
||||
allowed_ports=session.allowed_ports,
|
||||
)
|
||||
peer = dict(peer)
|
||||
peer["public_key"] = session.client_public_key
|
||||
peers.append(peer)
|
||||
if not peers:
|
||||
self.wg.stop_listener()
|
||||
return
|
||||
self.wg.start_listener(peers)
|
||||
|
||||
def connect(self, *, agent_id: str, operator_id: Optional[str]) -> Mapping[str, Any]:
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
existing = self._sessions_by_agent.get(agent_id)
|
||||
if existing:
|
||||
if operator_id:
|
||||
existing.operator_ids.add(operator_id)
|
||||
existing.last_activity = now
|
||||
return self._session_payload(existing)
|
||||
|
||||
tunnel_id = uuid.uuid4().hex
|
||||
virtual_ip = self._allocate_virtual_ip(agent_id)
|
||||
allowed_ports = self._load_allowed_ports(agent_id)
|
||||
client_private, client_public = self._generate_client_keys()
|
||||
token = self._issue_token(agent_id, tunnel_id, now + 300)
|
||||
self.wg.require_orchestration_token(token)
|
||||
|
||||
session = VpnSession(
|
||||
tunnel_id=tunnel_id,
|
||||
agent_id=agent_id,
|
||||
virtual_ip=virtual_ip,
|
||||
token=token,
|
||||
client_public_key=client_public,
|
||||
client_private_key=client_private,
|
||||
allowed_ports=allowed_ports,
|
||||
created_at=now,
|
||||
expires_at=now + 300,
|
||||
last_activity=now,
|
||||
)
|
||||
if operator_id:
|
||||
session.operator_ids.add(operator_id)
|
||||
self._sessions_by_agent[agent_id] = session
|
||||
self._sessions_by_tunnel[tunnel_id] = session
|
||||
|
||||
try:
|
||||
self._refresh_listener()
|
||||
|
||||
peer = self.wg.build_peer_profile(
|
||||
agent_id,
|
||||
virtual_ip,
|
||||
allowed_ports=allowed_ports,
|
||||
)
|
||||
rule_names = self.wg.apply_firewall_rules(peer)
|
||||
session.firewall_rules = rule_names
|
||||
except Exception:
|
||||
with self._lock:
|
||||
self._sessions_by_agent.pop(agent_id, None)
|
||||
self._sessions_by_tunnel.pop(tunnel_id, None)
|
||||
try:
|
||||
self._refresh_listener()
|
||||
except Exception:
|
||||
self.logger.debug("Failed to refresh WireGuard listener after connect rollback.", exc_info=True)
|
||||
raise
|
||||
|
||||
payload = self._session_payload(session)
|
||||
self._emit_start(payload)
|
||||
self._log_device_activity(session, event="start")
|
||||
return payload
|
||||
|
||||
def status(self, agent_id: str) -> Optional[Mapping[str, Any]]:
|
||||
with self._lock:
|
||||
session = self._sessions_by_agent.get(agent_id)
|
||||
if not session:
|
||||
return None
|
||||
return self._session_payload(session, include_token=False)
|
||||
|
||||
def bump_activity(self, agent_id: str) -> None:
|
||||
with self._lock:
|
||||
session = self._sessions_by_agent.get(agent_id)
|
||||
if not session:
|
||||
return
|
||||
session.last_activity = time.time()
|
||||
try:
|
||||
if self.socketio:
|
||||
self.socketio.emit("vpn_tunnel_activity", {"agent_id": agent_id}, namespace="/")
|
||||
except Exception:
|
||||
self.logger.debug("vpn_tunnel_activity emit failed for agent_id=%s", agent_id, exc_info=True)
|
||||
|
||||
def disconnect(self, agent_id: str, reason: str = "operator_stop") -> bool:
|
||||
with self._lock:
|
||||
session = self._sessions_by_agent.pop(agent_id, None)
|
||||
if not session:
|
||||
return False
|
||||
self._sessions_by_tunnel.pop(session.tunnel_id, None)
|
||||
|
||||
try:
|
||||
self.wg.remove_firewall_rules(session.firewall_rules)
|
||||
except Exception:
|
||||
self.logger.debug("Failed to remove firewall rules for agent=%s", agent_id, exc_info=True)
|
||||
|
||||
self._refresh_listener()
|
||||
self._emit_stop(session, reason)
|
||||
self._log_device_activity(session, event="stop", reason=reason)
|
||||
return True
|
||||
|
||||
def disconnect_by_tunnel(self, tunnel_id: str, reason: str = "operator_stop") -> bool:
|
||||
with self._lock:
|
||||
session = self._sessions_by_tunnel.get(tunnel_id)
|
||||
if not session:
|
||||
return False
|
||||
return self.disconnect(session.agent_id, reason=reason)
|
||||
|
||||
def _emit_start(self, payload: Mapping[str, Any]) -> None:
|
||||
if not self.socketio:
|
||||
return
|
||||
try:
|
||||
self.socketio.emit("vpn_tunnel_start", payload, namespace="/")
|
||||
except Exception:
|
||||
self.logger.debug("vpn_tunnel_start emit failed", exc_info=True)
|
||||
|
||||
def _emit_stop(self, session: VpnSession, reason: str) -> None:
|
||||
if not self.socketio:
|
||||
return
|
||||
try:
|
||||
self.socketio.emit(
|
||||
"vpn_tunnel_stop",
|
||||
{"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason},
|
||||
namespace="/",
|
||||
)
|
||||
except Exception:
|
||||
self.logger.debug("vpn_tunnel_stop emit failed", exc_info=True)
|
||||
|
||||
def _log_device_activity(self, session: VpnSession, *, event: str, reason: Optional[str] = None) -> None:
|
||||
if self.db_conn_factory is None:
|
||||
self.activity_logger.info(
|
||||
"device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s",
|
||||
event,
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
",".join(sorted(filter(None, session.operator_ids))) or "-",
|
||||
reason or "-",
|
||||
)
|
||||
return
|
||||
|
||||
conn = None
|
||||
try:
|
||||
conn = self.db_conn_factory()
|
||||
cur = conn.cursor()
|
||||
hostname = session.hostname
|
||||
if not hostname:
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT hostname FROM devices WHERE agent_id = ? ORDER BY last_seen DESC LIMIT 1",
|
||||
(session.agent_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row and row[0]:
|
||||
hostname = str(row[0]).strip()
|
||||
session.hostname = hostname
|
||||
except Exception:
|
||||
hostname = None
|
||||
|
||||
if not hostname:
|
||||
self.activity_logger.info(
|
||||
"device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s hostname=unknown",
|
||||
event,
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
",".join(sorted(filter(None, session.operator_ids))) or "-",
|
||||
reason or "-",
|
||||
)
|
||||
return
|
||||
|
||||
now_ts = int(time.time())
|
||||
script_name = "Reverse VPN Tunnel (WireGuard)"
|
||||
|
||||
if event == "start":
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO activity_history(hostname, script_path, script_name, script_type, ran_at, status, stdout, stderr)
|
||||
VALUES(?,?,?,?,?,?,?,?)
|
||||
""",
|
||||
(
|
||||
hostname,
|
||||
session.tunnel_id,
|
||||
script_name,
|
||||
"vpn_tunnel",
|
||||
now_ts,
|
||||
"Running",
|
||||
"",
|
||||
"",
|
||||
),
|
||||
)
|
||||
session.activity_id = cur.lastrowid
|
||||
conn.commit()
|
||||
if self.socketio:
|
||||
try:
|
||||
self.socketio.emit(
|
||||
"device_activity_changed",
|
||||
{
|
||||
"hostname": hostname,
|
||||
"activity_id": session.activity_id,
|
||||
"change": "created",
|
||||
"source": "vpn_tunnel",
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self.activity_logger.info(
|
||||
"device_activity_start hostname=%s agent_id=%s tunnel_id=%s operator=%s activity_id=%s",
|
||||
hostname,
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
",".join(sorted(filter(None, session.operator_ids))) or "-",
|
||||
session.activity_id or "-",
|
||||
)
|
||||
return
|
||||
|
||||
if session.activity_id:
|
||||
status = "Completed" if event == "stop" else "Closed"
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE activity_history
|
||||
SET status=?,
|
||||
stderr=COALESCE(stderr, '') || ?
|
||||
WHERE id=?
|
||||
""",
|
||||
(
|
||||
status,
|
||||
f"\nreason: {reason}" if reason else "",
|
||||
session.activity_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
if self.socketio:
|
||||
try:
|
||||
self.socketio.emit(
|
||||
"device_activity_changed",
|
||||
{
|
||||
"hostname": hostname,
|
||||
"activity_id": session.activity_id,
|
||||
"change": "updated",
|
||||
"source": "vpn_tunnel",
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.activity_logger.info(
|
||||
"device_activity event=%s hostname=%s agent_id=%s tunnel_id=%s operator=%s reason=%s activity_id=%s",
|
||||
event,
|
||||
hostname,
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
",".join(sorted(filter(None, session.operator_ids))) or "-",
|
||||
reason or "-",
|
||||
session.activity_id or "-",
|
||||
)
|
||||
except Exception:
|
||||
self.activity_logger.debug(
|
||||
"device_activity logging failed for tunnel_id=%s",
|
||||
session.tunnel_id,
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
if conn is not None:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _session_payload(self, session: VpnSession, *, include_token: bool = True) -> Mapping[str, Any]:
|
||||
payload: Dict[str, Any] = {
|
||||
"tunnel_id": session.tunnel_id,
|
||||
"agent_id": session.agent_id,
|
||||
"virtual_ip": session.virtual_ip,
|
||||
"engine_virtual_ip": str(self._engine_ip.ip),
|
||||
"allowed_ips": f"{self._engine_ip.ip}/32",
|
||||
"endpoint": f"{self._engine_ip.ip}:{self.context.wireguard_port}",
|
||||
"server_public_key": self.wg.server_public_key,
|
||||
"client_public_key": session.client_public_key,
|
||||
"client_private_key": session.client_private_key,
|
||||
"idle_seconds": self.idle_seconds,
|
||||
"allowed_ports": list(session.allowed_ports),
|
||||
"connected_operators": len([o for o in session.operator_ids if o]),
|
||||
}
|
||||
if include_token:
|
||||
payload["token"] = session.token
|
||||
return payload
|
||||
@@ -70,7 +70,7 @@ class WireGuardServerManager:
|
||||
self.logger = _build_logger(config.log_path)
|
||||
self._ensure_cert_dir()
|
||||
self.server_private_key, self.server_public_key = self._ensure_server_keys()
|
||||
self._service_name = "BorealisWireGuard"
|
||||
self._service_name = "borealis-wg"
|
||||
self._temp_dir = Path(tempfile.gettempdir()) / "borealis-wg-engine"
|
||||
|
||||
def _ensure_cert_dir(self) -> None:
|
||||
@@ -157,7 +157,7 @@ class WireGuardServerManager:
|
||||
if not token:
|
||||
raise ValueError("Missing orchestration token for WireGuard peer")
|
||||
|
||||
required_fields = ("agent_id", "tunnel_id", "expires_at")
|
||||
required_fields = ("agent_id", "tunnel_id", "expires_at", "port")
|
||||
missing = [field for field in required_fields if field not in token or token[field] in (None, "")]
|
||||
if missing:
|
||||
raise ValueError(f"Invalid orchestration token; missing {', '.join(missing)}")
|
||||
@@ -167,6 +167,13 @@ class WireGuardServerManager:
|
||||
except Exception:
|
||||
raise ValueError("Invalid orchestration token expiry")
|
||||
|
||||
try:
|
||||
port = int(token["port"])
|
||||
except Exception:
|
||||
raise ValueError("Invalid orchestration token port")
|
||||
if port != int(self.config.port):
|
||||
raise ValueError("Orchestration token port mismatch")
|
||||
|
||||
now = time.time()
|
||||
if expires_at <= now:
|
||||
raise ValueError("Orchestration token expired")
|
||||
@@ -253,12 +260,14 @@ class WireGuardServerManager:
|
||||
"host_only": True,
|
||||
}
|
||||
|
||||
def apply_firewall_rules(self, peer: Mapping[str, object]) -> None:
|
||||
def apply_firewall_rules(self, peer: Mapping[str, object]) -> List[str]:
|
||||
"""Apply outbound firewall allow rules for the agent's virtual IP/ports (Windows netsh)."""
|
||||
|
||||
rules = self.build_firewall_rules(peer)
|
||||
rule_names: List[str] = []
|
||||
for idx, rule in enumerate(rules):
|
||||
name = f"Borealis-WG-Agent-{peer.get('agent_id','')}-{idx}"
|
||||
protocol = str(rule.get("protocol") or "TCP").upper()
|
||||
args = [
|
||||
"netsh",
|
||||
"advfirewall",
|
||||
@@ -269,7 +278,7 @@ class WireGuardServerManager:
|
||||
"dir=out",
|
||||
"action=allow",
|
||||
f"remoteip={rule.get('remote_address','')}",
|
||||
f"protocol=TCP",
|
||||
f"protocol={protocol}",
|
||||
f"localport={rule.get('local_port','')}",
|
||||
]
|
||||
code, out, err = self._run_command(args)
|
||||
@@ -277,6 +286,19 @@ class WireGuardServerManager:
|
||||
self.logger.warning("Failed to apply firewall rule %s code=%s err=%s", name, code, err)
|
||||
else:
|
||||
self.logger.info("Applied firewall rule %s", name)
|
||||
rule_names.append(name)
|
||||
return rule_names
|
||||
|
||||
def remove_firewall_rules(self, rule_names: Sequence[str]) -> None:
|
||||
for name in rule_names:
|
||||
if not name:
|
||||
continue
|
||||
args = ["netsh", "advfirewall", "firewall", "delete", "rule", f"name={name}"]
|
||||
code, out, err = self._run_command(args)
|
||||
if code != 0:
|
||||
self.logger.warning("Failed to remove firewall rule %s code=%s err=%s", name, code, err)
|
||||
else:
|
||||
self.logger.info("Removed firewall rule %s", name)
|
||||
|
||||
def start_listener(self, peers: Sequence[Mapping[str, object]]) -> None:
|
||||
"""Render a temporary WireGuard config and start the service."""
|
||||
@@ -291,6 +313,9 @@ class WireGuardServerManager:
|
||||
config_path.write_text(rendered, encoding="utf-8")
|
||||
self.logger.info("Rendered WireGuard config to %s", config_path)
|
||||
|
||||
# Ensure old service is removed before re-installing.
|
||||
self.stop_listener()
|
||||
|
||||
args = ["wireguard.exe", "/installtunnelservice", str(config_path)]
|
||||
code, out, err = self._run_command(args)
|
||||
if code != 0:
|
||||
@@ -301,7 +326,7 @@ class WireGuardServerManager:
|
||||
def stop_listener(self) -> None:
|
||||
"""Stop and remove the WireGuard tunnel service."""
|
||||
|
||||
args = ["wireguard.exe", "/uninstalltunnelservice", "borealis-wg"]
|
||||
args = ["wireguard.exe", "/uninstalltunnelservice", self._service_name]
|
||||
code, out, err = self._run_command(args)
|
||||
if code != 0:
|
||||
self.logger.warning("Failed to uninstall WireGuard tunnel service code=%s err=%s", code, err)
|
||||
@@ -323,15 +348,17 @@ class WireGuardServerManager:
|
||||
port_list = []
|
||||
|
||||
for port in port_list:
|
||||
rules.append(
|
||||
{
|
||||
"direction": "outbound",
|
||||
"remote_address": ip,
|
||||
"local_port": port,
|
||||
"action": "allow",
|
||||
"description": f"WireGuard engine->agent allow port {port}",
|
||||
}
|
||||
)
|
||||
for protocol in ("TCP", "UDP"):
|
||||
rules.append(
|
||||
{
|
||||
"direction": "outbound",
|
||||
"remote_address": ip,
|
||||
"local_port": port,
|
||||
"protocol": protocol,
|
||||
"action": "allow",
|
||||
"description": f"WireGuard engine->agent allow port {port}/{protocol}",
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"Prepared firewall rule plan for agent=%s rules=%s",
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Namespace package for reverse tunnel domain handlers (Engine side)."""
|
||||
|
||||
__all__ = ["remote_interactive_shell", "remote_management", "remote_video"]
|
||||
@@ -1,78 +0,0 @@
|
||||
"""Placeholder Bash channel server (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
|
||||
class BashChannelServer:
|
||||
"""Stub Bash handler until the agent-side channel is implemented."""
|
||||
|
||||
protocol_name = "bash"
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"bash.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output = deque()
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
# No-op placeholder; output collection for future Bash support.
|
||||
try:
|
||||
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
|
||||
self._ack_received = True
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"bash channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
# Placeholder: no agent-side Bash yet.
|
||||
self.logger.info("bash placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
# Placeholder: not implemented.
|
||||
return
|
||||
|
||||
def close(self, code: int = 6, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
try:
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def drain_output(self):
|
||||
items = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def status(self):
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": None,
|
||||
"close_code": None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["BashChannelServer"]
|
||||
@@ -1,139 +0,0 @@
|
||||
"""Engine-side PowerShell tunnel channel helper (remote interactive shell domain)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Optional
|
||||
|
||||
# Mirror framing constants to avoid circular imports.
|
||||
MSG_CHANNEL_OPEN = 0x03
|
||||
MSG_CHANNEL_ACK = 0x04
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_OK = 0
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class PowershellChannelServer:
|
||||
"""Coordinate PowerShell channel frames over a TunnelBridge."""
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"ps.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output: Deque[str] = deque()
|
||||
self._close_reason: Optional[str] = None
|
||||
self._close_code: Optional[int] = None
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
# ------------------------------------------------------------------ Agent frame handling
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
if frame.channel_id != self.channel_id:
|
||||
return
|
||||
if frame.msg_type == MSG_CHANNEL_ACK:
|
||||
self._ack_received = True
|
||||
self.logger.info("ps channel acked tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
return
|
||||
if frame.msg_type == MSG_DATA:
|
||||
try:
|
||||
text = frame.payload.decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
text = ""
|
||||
if text:
|
||||
self._append_output(text)
|
||||
return
|
||||
if frame.msg_type == MSG_CLOSE:
|
||||
try:
|
||||
payload = json.loads(frame.payload.decode("utf-8"))
|
||||
except Exception:
|
||||
payload = {}
|
||||
self._closed = True
|
||||
self._close_code = payload.get("code") if isinstance(payload, dict) else None
|
||||
self._close_reason = payload.get("reason") if isinstance(payload, dict) else None
|
||||
self.logger.info(
|
||||
"ps channel closed tunnel_id=%s code=%s reason=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self._close_code,
|
||||
self._close_reason or "-",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------ Operator actions
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
if self._open_sent:
|
||||
return
|
||||
payload = json.dumps(
|
||||
{"protocol": "ps", "metadata": {"cols": cols, "rows": rows}},
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8")
|
||||
frame = self._frame_cls(msg_type=MSG_CHANNEL_OPEN, channel_id=self.channel_id, payload=payload)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"ps channel open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
payload = data.encode("utf-8", errors="replace")
|
||||
frame = self._frame_cls(msg_type=MSG_DATA, channel_id=self.channel_id, payload=payload)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
payload = json.dumps({"cols": cols, "rows": rows}, separators=(",", ":")).encode("utf-8")
|
||||
frame = self._frame_cls(msg_type=MSG_CONTROL, channel_id=self.channel_id, payload=payload)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
|
||||
def close(self, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
else:
|
||||
frame = self._frame_cls(
|
||||
msg_type=MSG_CLOSE,
|
||||
channel_id=self.channel_id,
|
||||
payload=json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8"),
|
||||
)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
|
||||
# ------------------------------------------------------------------ Output polling
|
||||
def drain_output(self) -> List[str]:
|
||||
items: List[str] = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def _append_output(self, text: str) -> None:
|
||||
self._output.append(text)
|
||||
# Cap buffer to avoid unbounded memory growth.
|
||||
while len(self._output) > 500:
|
||||
self._output.popleft()
|
||||
|
||||
# ------------------------------------------------------------------ Status helpers
|
||||
def status(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": self._close_reason,
|
||||
"close_code": self._close_code,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["PowershellChannelServer"]
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Protocol handlers for remote interactive shell tunnels (Engine side)."""
|
||||
|
||||
from .Powershell import PowershellChannelServer
|
||||
from .Bash import BashChannelServer
|
||||
|
||||
__all__ = ["PowershellChannelServer", "BashChannelServer"]
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Domain handlers for remote interactive shells (PowerShell/Bash)."""
|
||||
|
||||
__all__ = ["Protocols"]
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Placeholder SSH channel server (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
|
||||
class SSHChannelServer:
|
||||
protocol_name = "ssh"
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"ssh.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output = deque()
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
try:
|
||||
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
|
||||
self._ack_received = True
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"ssh channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
self.logger.info("ssh placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
return
|
||||
|
||||
def close(self, code: int = 6, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
try:
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def drain_output(self):
|
||||
items = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def status(self):
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": None,
|
||||
"close_code": None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["SSHChannelServer"]
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Placeholder WinRM channel server (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
|
||||
class WinRMChannelServer:
|
||||
protocol_name = "winrm"
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"winrm.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output = deque()
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
try:
|
||||
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
|
||||
self._ack_received = True
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"winrm channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
self.logger.info("winrm placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
return
|
||||
|
||||
def close(self, code: int = 6, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
try:
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def drain_output(self):
|
||||
items = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def status(self):
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": None,
|
||||
"close_code": None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["WinRMChannelServer"]
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Protocol handlers for remote management tunnels (Engine side)."""
|
||||
|
||||
from .SSH import SSHChannelServer
|
||||
from .WinRM import WinRMChannelServer
|
||||
|
||||
__all__ = ["SSHChannelServer", "WinRMChannelServer"]
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Domain handlers for remote management tunnels (SSH/WinRM)."""
|
||||
|
||||
__all__ = ["Protocols"]
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Placeholder RDP channel server (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
|
||||
class RDPChannelServer:
|
||||
protocol_name = "rdp"
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"rdp.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output = deque()
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
try:
|
||||
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
|
||||
self._ack_received = True
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"rdp channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
self.logger.info("rdp placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
return
|
||||
|
||||
def close(self, code: int = 6, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
try:
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def drain_output(self):
|
||||
items = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def status(self):
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": None,
|
||||
"close_code": None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["RDPChannelServer"]
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Placeholder VNC channel server (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
|
||||
class VNCChannelServer:
|
||||
protocol_name = "vnc"
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"vnc.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output = deque()
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
try:
|
||||
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
|
||||
self._ack_received = True
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"vnc channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
self.logger.info("vnc placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
return
|
||||
|
||||
def close(self, code: int = 6, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
try:
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def drain_output(self):
|
||||
items = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def status(self):
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": None,
|
||||
"close_code": None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["VNCChannelServer"]
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Placeholder WebRTC channel server (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
|
||||
class WebRTCChannelServer:
|
||||
protocol_name = "webrtc"
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"webrtc.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output = deque()
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
try:
|
||||
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
|
||||
self._ack_received = True
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"webrtc channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
self.logger.info("webrtc placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
return
|
||||
|
||||
def close(self, code: int = 6, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
try:
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def drain_output(self):
|
||||
items = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def status(self):
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": None,
|
||||
"close_code": None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["WebRTCChannelServer"]
|
||||
@@ -1,7 +0,0 @@
|
||||
"""Protocol handlers for remote video tunnels (Engine side)."""
|
||||
|
||||
from .WebRTC import WebRTCChannelServer
|
||||
from .RDP import RDPChannelServer
|
||||
from .VNC import VNCChannelServer
|
||||
|
||||
__all__ = ["WebRTCChannelServer", "RDPChannelServer", "VNCChannelServer"]
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Domain handlers for remote video/desktop tunnels (RDP/VNC/WebRTC)."""
|
||||
|
||||
__all__ = ["Protocols"]
|
||||
@@ -1,10 +0,0 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\WebSocket\Agent\__init__.py
|
||||
# Description: Package marker for Agent-facing WebSocket services (reverse tunnel scaffolding).
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""Agent-facing WebSocket services for the Engine runtime."""
|
||||
|
||||
__all__ = []
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\WebSocket\__init__.py
|
||||
# Description: Socket.IO handlers for Engine runtime quick job updates and realtime notifications.
|
||||
# Description: Socket.IO handlers for Engine runtime quick job updates and VPN shell bridging.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
@@ -8,24 +8,20 @@
|
||||
"""WebSocket service registration for the Borealis Engine runtime."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import sqlite3
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from flask import session, request
|
||||
from flask import request
|
||||
from flask_socketio import SocketIO
|
||||
|
||||
from ...database import initialise_engine_database
|
||||
from ...security import signing
|
||||
from ...server import EngineContext
|
||||
from .Agent.reverse_tunnel_orchestrator import (
|
||||
ReverseTunnelService,
|
||||
TunnelBridge,
|
||||
decode_frame,
|
||||
TunnelFrame,
|
||||
)
|
||||
from ..VPN import VpnTunnelService
|
||||
from .vpn_shell import VpnShellBridge
|
||||
|
||||
|
||||
def _now_ts() -> int:
|
||||
@@ -70,20 +66,31 @@ class EngineRealtimeAdapters:
|
||||
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
"""Register Socket.IO event handlers for the Engine runtime."""
|
||||
|
||||
from ..API import _make_db_conn_factory, _make_service_logger # Local import to avoid circular import at module load
|
||||
|
||||
adapters = EngineRealtimeAdapters(context)
|
||||
logger = context.logger.getChild("realtime.quick_jobs")
|
||||
tunnel_service = getattr(context, "reverse_tunnel_service", None)
|
||||
if tunnel_service is None:
|
||||
tunnel_service = ReverseTunnelService(
|
||||
context,
|
||||
signer=None,
|
||||
shell_bridge = VpnShellBridge(socket_server, context)
|
||||
|
||||
def _get_tunnel_service() -> Optional[VpnTunnelService]:
|
||||
service = getattr(context, "vpn_tunnel_service", None)
|
||||
if service is not None:
|
||||
return service
|
||||
manager = getattr(context, "wireguard_server_manager", None)
|
||||
if manager is None:
|
||||
return None
|
||||
try:
|
||||
signer = signing.load_signer()
|
||||
except Exception:
|
||||
signer = None
|
||||
service = VpnTunnelService(
|
||||
context=context,
|
||||
wireguard_manager=manager,
|
||||
db_conn_factory=adapters.db_conn_factory,
|
||||
socketio=socket_server,
|
||||
service_log=adapters.service_log,
|
||||
signer=signer,
|
||||
)
|
||||
tunnel_service.start()
|
||||
setattr(context, "reverse_tunnel_service", tunnel_service)
|
||||
setattr(context, "vpn_tunnel_service", service)
|
||||
return service
|
||||
|
||||
@socket_server.on("quick_job_result")
|
||||
def _handle_quick_job_result(data: Any) -> None:
|
||||
@@ -246,252 +253,45 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
exc,
|
||||
)
|
||||
|
||||
@socket_server.on("tunnel_bridge_attach")
|
||||
def _tunnel_bridge_attach(data: Any) -> Any:
|
||||
"""Placeholder operator bridge attach handler (no data channel yet)."""
|
||||
@socket_server.on("vpn_shell_open")
|
||||
def _vpn_shell_open(data: Any) -> Dict[str, Any]:
|
||||
agent_id = ""
|
||||
if isinstance(data, dict):
|
||||
agent_id = str(data.get("agent_id") or "").strip()
|
||||
elif isinstance(data, str):
|
||||
agent_id = data.strip()
|
||||
if not agent_id:
|
||||
return {"error": "agent_id_required"}
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return {"error": "invalid_payload"}
|
||||
service = _get_tunnel_service()
|
||||
if service is None:
|
||||
return {"error": "vpn_service_unavailable"}
|
||||
if not service.status(agent_id):
|
||||
return {"error": "tunnel_down"}
|
||||
|
||||
tunnel_id = str(data.get("tunnel_id") or "").strip()
|
||||
operator_id = str(data.get("operator_id") or "").strip() or None
|
||||
if not tunnel_id:
|
||||
return {"error": "tunnel_id_required"}
|
||||
session = shell_bridge.open_session(request.sid, agent_id)
|
||||
if session is None:
|
||||
return {"error": "shell_connect_failed"}
|
||||
service.bump_activity(agent_id)
|
||||
return {"status": "ok"}
|
||||
|
||||
try:
|
||||
tunnel_service.operator_attach(tunnel_id, operator_id)
|
||||
except ValueError as exc:
|
||||
return {"error": str(exc)}
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("tunnel_bridge_attach failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "bridge_attach_failed"}
|
||||
|
||||
return {"status": "ok", "tunnel_id": tunnel_id, "operator_id": operator_id or "-"}
|
||||
|
||||
def _encode_frame(frame: TunnelFrame) -> str:
|
||||
return base64.b64encode(frame.encode()).decode("ascii")
|
||||
|
||||
def _decode_frame_payload(raw: Any) -> TunnelFrame:
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
raw_bytes = base64.b64decode(raw)
|
||||
except Exception:
|
||||
raise ValueError("invalid_frame")
|
||||
elif isinstance(raw, (bytes, bytearray)):
|
||||
raw_bytes = bytes(raw)
|
||||
@socket_server.on("vpn_shell_send")
|
||||
def _vpn_shell_send(data: Any) -> Dict[str, Any]:
|
||||
payload = None
|
||||
if isinstance(data, dict):
|
||||
payload = data.get("data")
|
||||
else:
|
||||
raise ValueError("invalid_frame")
|
||||
return decode_frame(raw_bytes)
|
||||
|
||||
@socket_server.on("tunnel_operator_send")
|
||||
def _tunnel_operator_send(data: Any) -> Any:
|
||||
"""Operator -> agent frame enqueue (placeholder queue)."""
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return {"error": "invalid_payload"}
|
||||
tunnel_id = str(data.get("tunnel_id") or "").strip()
|
||||
frame_raw = data.get("frame")
|
||||
if not tunnel_id or frame_raw is None:
|
||||
return {"error": "tunnel_id_and_frame_required"}
|
||||
try:
|
||||
frame = _decode_frame_payload(frame_raw)
|
||||
except Exception as exc:
|
||||
return {"error": str(exc)}
|
||||
|
||||
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
bridge.operator_to_agent(frame)
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("tunnel_operator_poll")
|
||||
def _tunnel_operator_poll(data: Any) -> Any:
|
||||
"""Operator polls queued frames from agent."""
|
||||
|
||||
tunnel_id = ""
|
||||
if isinstance(data, dict):
|
||||
tunnel_id = str(data.get("tunnel_id") or "").strip()
|
||||
if not tunnel_id:
|
||||
return {"error": "tunnel_id_required"}
|
||||
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
|
||||
frames = []
|
||||
while True:
|
||||
frame = bridge.next_for_operator()
|
||||
if frame is None:
|
||||
break
|
||||
frames.append(_encode_frame(frame))
|
||||
return {"frames": frames}
|
||||
|
||||
# WebUI operator bridge namespace for browser clients
|
||||
tunnel_namespace = "/tunnel"
|
||||
_operator_sessions: Dict[str, str] = {}
|
||||
|
||||
def _current_operator() -> Optional[str]:
|
||||
username = session.get("username")
|
||||
if username:
|
||||
return str(username)
|
||||
auth_header = (request.headers.get("Authorization") or "").strip()
|
||||
token = None
|
||||
if auth_header.lower().startswith("bearer "):
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
if not token:
|
||||
token = request.cookies.get("borealis_auth")
|
||||
return token or None
|
||||
|
||||
@socket_server.on("join", namespace=tunnel_namespace)
|
||||
def _ws_tunnel_join(data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return {"error": "invalid_payload"}
|
||||
operator_id = _current_operator()
|
||||
if not operator_id:
|
||||
return {"error": "unauthorized"}
|
||||
tunnel_id = str(data.get("tunnel_id") or "").strip()
|
||||
if not tunnel_id:
|
||||
return {"error": "tunnel_id_required"}
|
||||
bridge = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
try:
|
||||
tunnel_service.operator_attach(tunnel_id, operator_id)
|
||||
except Exception as exc:
|
||||
logger.debug("ws_tunnel_join failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "attach_failed"}
|
||||
sid = request.sid
|
||||
_operator_sessions[sid] = tunnel_id
|
||||
return {"status": "ok", "tunnel_id": tunnel_id}
|
||||
|
||||
@socket_server.on("send", namespace=tunnel_namespace)
|
||||
def _ws_tunnel_send(data: Any) -> Any:
|
||||
sid = request.sid
|
||||
tunnel_id = _operator_sessions.get(sid)
|
||||
if not tunnel_id:
|
||||
return {"error": "not_joined"}
|
||||
if not isinstance(data, dict):
|
||||
return {"error": "invalid_payload"}
|
||||
frame_raw = data.get("frame")
|
||||
if frame_raw is None:
|
||||
return {"error": "frame_required"}
|
||||
try:
|
||||
frame = _decode_frame_payload(frame_raw)
|
||||
except Exception:
|
||||
return {"error": "invalid_frame"}
|
||||
bridge = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
bridge.operator_to_agent(frame)
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("poll", namespace=tunnel_namespace)
|
||||
def _ws_tunnel_poll() -> Any:
|
||||
sid = request.sid
|
||||
tunnel_id = _operator_sessions.get(sid)
|
||||
if not tunnel_id:
|
||||
return {"error": "not_joined"}
|
||||
bridge = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
frames = []
|
||||
while True:
|
||||
frame = bridge.next_for_operator()
|
||||
if frame is None:
|
||||
break
|
||||
frames.append(_encode_frame(frame))
|
||||
return {"frames": frames}
|
||||
|
||||
def _require_ps_server():
|
||||
sid = request.sid
|
||||
tunnel_id = _operator_sessions.get(sid)
|
||||
if not tunnel_id:
|
||||
return None, None, {"error": "not_joined"}
|
||||
server = tunnel_service.ensure_protocol_server(tunnel_id)
|
||||
if server is None or not hasattr(server, "open_channel"):
|
||||
return None, tunnel_id, {"error": "ps_unsupported"}
|
||||
return server, tunnel_id, None
|
||||
|
||||
@socket_server.on("ps_open", namespace=tunnel_namespace)
|
||||
def _ws_ps_open(data: Any) -> Any:
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
cols = 120
|
||||
rows = 32
|
||||
if isinstance(data, dict):
|
||||
try:
|
||||
cols = int(data.get("cols", cols))
|
||||
rows = int(data.get("rows", rows))
|
||||
except Exception:
|
||||
pass
|
||||
cols = max(20, min(cols, 300))
|
||||
rows = max(10, min(rows, 200))
|
||||
try:
|
||||
server.open_channel(cols=cols, rows=rows)
|
||||
except Exception as exc:
|
||||
logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_open_failed"}
|
||||
return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows}
|
||||
|
||||
@socket_server.on("ps_send", namespace=tunnel_namespace)
|
||||
def _ws_ps_send(data: Any) -> Any:
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
if data is None:
|
||||
payload = data
|
||||
if payload is None:
|
||||
return {"error": "payload_required"}
|
||||
text = data
|
||||
if isinstance(data, dict):
|
||||
text = data.get("data")
|
||||
if text is None:
|
||||
return {"error": "payload_required"}
|
||||
try:
|
||||
server.send_input(str(text))
|
||||
except Exception as exc:
|
||||
logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_send_failed"}
|
||||
shell_bridge.send(request.sid, str(payload))
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("ps_resize", namespace=tunnel_namespace)
|
||||
def _ws_ps_resize(data: Any) -> Any:
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
cols = None
|
||||
rows = None
|
||||
if isinstance(data, dict):
|
||||
cols = data.get("cols")
|
||||
rows = data.get("rows")
|
||||
try:
|
||||
cols_int = int(cols) if cols is not None else 120
|
||||
rows_int = int(rows) if rows is not None else 32
|
||||
cols_int = max(20, min(cols_int, 300))
|
||||
rows_int = max(10, min(rows_int, 200))
|
||||
server.send_resize(cols_int, rows_int)
|
||||
return {"status": "ok", "cols": cols_int, "rows": rows_int}
|
||||
except Exception as exc:
|
||||
logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_resize_failed"}
|
||||
@socket_server.on("vpn_shell_close")
|
||||
def _vpn_shell_close() -> Dict[str, Any]:
|
||||
shell_bridge.close(request.sid)
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("ps_poll", namespace=tunnel_namespace)
|
||||
def _ws_ps_poll(data: Any = None) -> Any: # data is ignored; socketio passes it even when unused
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
try:
|
||||
output = server.drain_output()
|
||||
status = server.status()
|
||||
return {"output": output, "status": status}
|
||||
except Exception as exc:
|
||||
logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_poll_failed"}
|
||||
|
||||
@socket_server.on("disconnect", namespace=tunnel_namespace)
|
||||
def _ws_tunnel_disconnect():
|
||||
sid = request.sid
|
||||
tunnel_id = _operator_sessions.pop(sid, None)
|
||||
if tunnel_id and tunnel_id not in _operator_sessions.values():
|
||||
try:
|
||||
tunnel_service.stop_tunnel(tunnel_id, reason="operator_socket_disconnect")
|
||||
except Exception as exc:
|
||||
logger.debug("ws_tunnel_disconnect stop_tunnel failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
@socket_server.on("disconnect")
|
||||
def _ws_disconnect() -> None:
|
||||
shell_bridge.close(request.sid)
|
||||
|
||||
127
Data/Engine/services/WebSocket/vpn_shell.py
Normal file
127
Data/Engine/services/WebSocket/vpn_shell.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\WebSocket\vpn_shell.py
|
||||
# Description: Socket.IO handlers bridging UI shell to agent TCP server over WireGuard.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""WireGuard VPN PowerShell bridge (Engine side)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
def _b64encode(data: bytes) -> str:
|
||||
return base64.b64encode(data).decode("ascii").strip()
|
||||
|
||||
|
||||
def _b64decode(value: str) -> bytes:
|
||||
return base64.b64decode(value.encode("ascii"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShellSession:
|
||||
sid: str
|
||||
agent_id: str
|
||||
socketio: Any
|
||||
tcp: socket.socket
|
||||
_reader: Optional[threading.Thread] = None
|
||||
|
||||
def start_reader(self) -> None:
|
||||
t = threading.Thread(target=self._read_loop, daemon=True)
|
||||
t.start()
|
||||
self._reader = t
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
buffer = b""
|
||||
try:
|
||||
while True:
|
||||
data = self.tcp.recv(4096)
|
||||
if not data:
|
||||
break
|
||||
buffer += data
|
||||
while b"\n" in buffer:
|
||||
line, buffer = buffer.split(b"\n", 1)
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
msg = json.loads(line.decode("utf-8"))
|
||||
except Exception:
|
||||
continue
|
||||
if msg.get("type") == "stdout":
|
||||
payload = msg.get("data") or ""
|
||||
try:
|
||||
decoded = _b64decode(str(payload)).decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
decoded = ""
|
||||
self.socketio.emit("vpn_shell_output", {"data": decoded}, to=self.sid)
|
||||
finally:
|
||||
self.socketio.emit("vpn_shell_closed", {"agent_id": self.agent_id}, to=self.sid)
|
||||
try:
|
||||
self.tcp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def send(self, payload: str) -> None:
|
||||
data = json.dumps({"type": "stdin", "data": _b64encode(payload.encode("utf-8"))})
|
||||
self.tcp.sendall(data.encode("utf-8") + b"\n")
|
||||
|
||||
def close(self) -> None:
|
||||
try:
|
||||
data = json.dumps({"type": "close"})
|
||||
self.tcp.sendall(data.encode("utf-8") + b"\n")
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self.tcp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class VpnShellBridge:
|
||||
def __init__(self, socketio, context) -> None:
|
||||
self.socketio = socketio
|
||||
self.context = context
|
||||
self._sessions: Dict[str, ShellSession] = {}
|
||||
self.logger = context.logger.getChild("vpn_shell")
|
||||
|
||||
def open_session(self, sid: str, agent_id: str) -> Optional[ShellSession]:
|
||||
service = getattr(self.context, "vpn_tunnel_service", None)
|
||||
if service is None:
|
||||
return None
|
||||
status = service.status(agent_id)
|
||||
if not status:
|
||||
return None
|
||||
host = str(status.get("virtual_ip") or "").split("/")[0]
|
||||
port = int(self.context.wireguard_shell_port)
|
||||
try:
|
||||
tcp = socket.create_connection((host, port), timeout=5)
|
||||
except Exception:
|
||||
self.logger.debug("Failed to connect vpn shell to %s:%s", host, port, exc_info=True)
|
||||
return None
|
||||
session = ShellSession(sid=sid, agent_id=agent_id, socketio=self.socketio, tcp=tcp)
|
||||
self._sessions[sid] = session
|
||||
session.start_reader()
|
||||
return session
|
||||
|
||||
def send(self, sid: str, payload: str) -> None:
|
||||
session = self._sessions.get(sid)
|
||||
if not session:
|
||||
return
|
||||
session.send(payload)
|
||||
service = getattr(self.context, "vpn_tunnel_service", None)
|
||||
if service:
|
||||
service.bump_activity(session.agent_id)
|
||||
|
||||
def close(self, sid: str) -> None:
|
||||
session = self._sessions.pop(sid, None)
|
||||
if not session:
|
||||
return
|
||||
session.close()
|
||||
|
||||
Reference in New Issue
Block a user