Overhaul of VPN Codebase

This commit is contained in:
2025-12-18 01:35:03 -07:00
parent 2f81061a1b
commit 6ceb59f717
56 changed files with 1786 additions and 4778 deletions

View File

@@ -1,6 +1,6 @@
# ======================================================
# Data\Engine\services\WebSocket\__init__.py
# Description: Socket.IO handlers for Engine runtime quick job updates and realtime notifications.
# Description: Socket.IO handlers for Engine runtime quick job updates and VPN shell bridging.
#
# API Endpoints (if applicable): None
# ======================================================
@@ -8,24 +8,20 @@
"""WebSocket service registration for the Borealis Engine runtime."""
from __future__ import annotations
import base64
import sqlite3
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Optional
from flask import session, request
from flask import request
from flask_socketio import SocketIO
from ...database import initialise_engine_database
from ...security import signing
from ...server import EngineContext
from .Agent.reverse_tunnel_orchestrator import (
ReverseTunnelService,
TunnelBridge,
decode_frame,
TunnelFrame,
)
from ..VPN import VpnTunnelService
from .vpn_shell import VpnShellBridge
def _now_ts() -> int:
@@ -70,20 +66,31 @@ class EngineRealtimeAdapters:
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
"""Register Socket.IO event handlers for the Engine runtime."""
from ..API import _make_db_conn_factory, _make_service_logger # Local import to avoid circular import at module load
adapters = EngineRealtimeAdapters(context)
logger = context.logger.getChild("realtime.quick_jobs")
tunnel_service = getattr(context, "reverse_tunnel_service", None)
if tunnel_service is None:
tunnel_service = ReverseTunnelService(
context,
signer=None,
shell_bridge = VpnShellBridge(socket_server, context)
def _get_tunnel_service() -> Optional[VpnTunnelService]:
service = getattr(context, "vpn_tunnel_service", None)
if service is not None:
return service
manager = getattr(context, "wireguard_server_manager", None)
if manager is None:
return None
try:
signer = signing.load_signer()
except Exception:
signer = None
service = VpnTunnelService(
context=context,
wireguard_manager=manager,
db_conn_factory=adapters.db_conn_factory,
socketio=socket_server,
service_log=adapters.service_log,
signer=signer,
)
tunnel_service.start()
setattr(context, "reverse_tunnel_service", tunnel_service)
setattr(context, "vpn_tunnel_service", service)
return service
@socket_server.on("quick_job_result")
def _handle_quick_job_result(data: Any) -> None:
@@ -246,252 +253,45 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
exc,
)
@socket_server.on("tunnel_bridge_attach")
def _tunnel_bridge_attach(data: Any) -> Any:
"""Placeholder operator bridge attach handler (no data channel yet)."""
@socket_server.on("vpn_shell_open")
def _vpn_shell_open(data: Any) -> Dict[str, Any]:
agent_id = ""
if isinstance(data, dict):
agent_id = str(data.get("agent_id") or "").strip()
elif isinstance(data, str):
agent_id = data.strip()
if not agent_id:
return {"error": "agent_id_required"}
if not isinstance(data, dict):
return {"error": "invalid_payload"}
service = _get_tunnel_service()
if service is None:
return {"error": "vpn_service_unavailable"}
if not service.status(agent_id):
return {"error": "tunnel_down"}
tunnel_id = str(data.get("tunnel_id") or "").strip()
operator_id = str(data.get("operator_id") or "").strip() or None
if not tunnel_id:
return {"error": "tunnel_id_required"}
session = shell_bridge.open_session(request.sid, agent_id)
if session is None:
return {"error": "shell_connect_failed"}
service.bump_activity(agent_id)
return {"status": "ok"}
try:
tunnel_service.operator_attach(tunnel_id, operator_id)
except ValueError as exc:
return {"error": str(exc)}
except Exception as exc: # pragma: no cover - defensive guard
logger.debug("tunnel_bridge_attach failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "bridge_attach_failed"}
return {"status": "ok", "tunnel_id": tunnel_id, "operator_id": operator_id or "-"}
def _encode_frame(frame: TunnelFrame) -> str:
return base64.b64encode(frame.encode()).decode("ascii")
def _decode_frame_payload(raw: Any) -> TunnelFrame:
if isinstance(raw, str):
try:
raw_bytes = base64.b64decode(raw)
except Exception:
raise ValueError("invalid_frame")
elif isinstance(raw, (bytes, bytearray)):
raw_bytes = bytes(raw)
@socket_server.on("vpn_shell_send")
def _vpn_shell_send(data: Any) -> Dict[str, Any]:
payload = None
if isinstance(data, dict):
payload = data.get("data")
else:
raise ValueError("invalid_frame")
return decode_frame(raw_bytes)
@socket_server.on("tunnel_operator_send")
def _tunnel_operator_send(data: Any) -> Any:
"""Operator -> agent frame enqueue (placeholder queue)."""
if not isinstance(data, dict):
return {"error": "invalid_payload"}
tunnel_id = str(data.get("tunnel_id") or "").strip()
frame_raw = data.get("frame")
if not tunnel_id or frame_raw is None:
return {"error": "tunnel_id_and_frame_required"}
try:
frame = _decode_frame_payload(frame_raw)
except Exception as exc:
return {"error": str(exc)}
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
bridge.operator_to_agent(frame)
return {"status": "ok"}
@socket_server.on("tunnel_operator_poll")
def _tunnel_operator_poll(data: Any) -> Any:
"""Operator polls queued frames from agent."""
tunnel_id = ""
if isinstance(data, dict):
tunnel_id = str(data.get("tunnel_id") or "").strip()
if not tunnel_id:
return {"error": "tunnel_id_required"}
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
frames = []
while True:
frame = bridge.next_for_operator()
if frame is None:
break
frames.append(_encode_frame(frame))
return {"frames": frames}
# WebUI operator bridge namespace for browser clients
tunnel_namespace = "/tunnel"
_operator_sessions: Dict[str, str] = {}
def _current_operator() -> Optional[str]:
username = session.get("username")
if username:
return str(username)
auth_header = (request.headers.get("Authorization") or "").strip()
token = None
if auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1].strip()
if not token:
token = request.cookies.get("borealis_auth")
return token or None
@socket_server.on("join", namespace=tunnel_namespace)
def _ws_tunnel_join(data: Any) -> Any:
if not isinstance(data, dict):
return {"error": "invalid_payload"}
operator_id = _current_operator()
if not operator_id:
return {"error": "unauthorized"}
tunnel_id = str(data.get("tunnel_id") or "").strip()
if not tunnel_id:
return {"error": "tunnel_id_required"}
bridge = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
try:
tunnel_service.operator_attach(tunnel_id, operator_id)
except Exception as exc:
logger.debug("ws_tunnel_join failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "attach_failed"}
sid = request.sid
_operator_sessions[sid] = tunnel_id
return {"status": "ok", "tunnel_id": tunnel_id}
@socket_server.on("send", namespace=tunnel_namespace)
def _ws_tunnel_send(data: Any) -> Any:
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return {"error": "not_joined"}
if not isinstance(data, dict):
return {"error": "invalid_payload"}
frame_raw = data.get("frame")
if frame_raw is None:
return {"error": "frame_required"}
try:
frame = _decode_frame_payload(frame_raw)
except Exception:
return {"error": "invalid_frame"}
bridge = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
bridge.operator_to_agent(frame)
return {"status": "ok"}
@socket_server.on("poll", namespace=tunnel_namespace)
def _ws_tunnel_poll() -> Any:
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return {"error": "not_joined"}
bridge = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
frames = []
while True:
frame = bridge.next_for_operator()
if frame is None:
break
frames.append(_encode_frame(frame))
return {"frames": frames}
def _require_ps_server():
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return None, None, {"error": "not_joined"}
server = tunnel_service.ensure_protocol_server(tunnel_id)
if server is None or not hasattr(server, "open_channel"):
return None, tunnel_id, {"error": "ps_unsupported"}
return server, tunnel_id, None
@socket_server.on("ps_open", namespace=tunnel_namespace)
def _ws_ps_open(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
cols = 120
rows = 32
if isinstance(data, dict):
try:
cols = int(data.get("cols", cols))
rows = int(data.get("rows", rows))
except Exception:
pass
cols = max(20, min(cols, 300))
rows = max(10, min(rows, 200))
try:
server.open_channel(cols=cols, rows=rows)
except Exception as exc:
logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_open_failed"}
return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows}
@socket_server.on("ps_send", namespace=tunnel_namespace)
def _ws_ps_send(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
if data is None:
payload = data
if payload is None:
return {"error": "payload_required"}
text = data
if isinstance(data, dict):
text = data.get("data")
if text is None:
return {"error": "payload_required"}
try:
server.send_input(str(text))
except Exception as exc:
logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_send_failed"}
shell_bridge.send(request.sid, str(payload))
return {"status": "ok"}
@socket_server.on("ps_resize", namespace=tunnel_namespace)
def _ws_ps_resize(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
cols = None
rows = None
if isinstance(data, dict):
cols = data.get("cols")
rows = data.get("rows")
try:
cols_int = int(cols) if cols is not None else 120
rows_int = int(rows) if rows is not None else 32
cols_int = max(20, min(cols_int, 300))
rows_int = max(10, min(rows_int, 200))
server.send_resize(cols_int, rows_int)
return {"status": "ok", "cols": cols_int, "rows": rows_int}
except Exception as exc:
logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_resize_failed"}
@socket_server.on("vpn_shell_close")
def _vpn_shell_close() -> Dict[str, Any]:
shell_bridge.close(request.sid)
return {"status": "ok"}
@socket_server.on("ps_poll", namespace=tunnel_namespace)
def _ws_ps_poll(data: Any = None) -> Any: # data is ignored; socketio passes it even when unused
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
try:
output = server.drain_output()
status = server.status()
return {"output": output, "status": status}
except Exception as exc:
logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_poll_failed"}
@socket_server.on("disconnect", namespace=tunnel_namespace)
def _ws_tunnel_disconnect():
sid = request.sid
tunnel_id = _operator_sessions.pop(sid, None)
if tunnel_id and tunnel_id not in _operator_sessions.values():
try:
tunnel_service.stop_tunnel(tunnel_id, reason="operator_socket_disconnect")
except Exception as exc:
logger.debug("ws_tunnel_disconnect stop_tunnel failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
@socket_server.on("disconnect")
def _ws_disconnect() -> None:
shell_bridge.close(request.sid)