mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-12-19 00:35:48 -07:00
Overhaul of VPN Codebase
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\WebSocket\__init__.py
|
||||
# Description: Socket.IO handlers for Engine runtime quick job updates and realtime notifications.
|
||||
# Description: Socket.IO handlers for Engine runtime quick job updates and VPN shell bridging.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
@@ -8,24 +8,20 @@
|
||||
"""WebSocket service registration for the Borealis Engine runtime."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import sqlite3
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from flask import session, request
|
||||
from flask import request
|
||||
from flask_socketio import SocketIO
|
||||
|
||||
from ...database import initialise_engine_database
|
||||
from ...security import signing
|
||||
from ...server import EngineContext
|
||||
from .Agent.reverse_tunnel_orchestrator import (
|
||||
ReverseTunnelService,
|
||||
TunnelBridge,
|
||||
decode_frame,
|
||||
TunnelFrame,
|
||||
)
|
||||
from ..VPN import VpnTunnelService
|
||||
from .vpn_shell import VpnShellBridge
|
||||
|
||||
|
||||
def _now_ts() -> int:
|
||||
@@ -70,20 +66,31 @@ class EngineRealtimeAdapters:
|
||||
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
"""Register Socket.IO event handlers for the Engine runtime."""
|
||||
|
||||
from ..API import _make_db_conn_factory, _make_service_logger # Local import to avoid circular import at module load
|
||||
|
||||
adapters = EngineRealtimeAdapters(context)
|
||||
logger = context.logger.getChild("realtime.quick_jobs")
|
||||
tunnel_service = getattr(context, "reverse_tunnel_service", None)
|
||||
if tunnel_service is None:
|
||||
tunnel_service = ReverseTunnelService(
|
||||
context,
|
||||
signer=None,
|
||||
shell_bridge = VpnShellBridge(socket_server, context)
|
||||
|
||||
def _get_tunnel_service() -> Optional[VpnTunnelService]:
|
||||
service = getattr(context, "vpn_tunnel_service", None)
|
||||
if service is not None:
|
||||
return service
|
||||
manager = getattr(context, "wireguard_server_manager", None)
|
||||
if manager is None:
|
||||
return None
|
||||
try:
|
||||
signer = signing.load_signer()
|
||||
except Exception:
|
||||
signer = None
|
||||
service = VpnTunnelService(
|
||||
context=context,
|
||||
wireguard_manager=manager,
|
||||
db_conn_factory=adapters.db_conn_factory,
|
||||
socketio=socket_server,
|
||||
service_log=adapters.service_log,
|
||||
signer=signer,
|
||||
)
|
||||
tunnel_service.start()
|
||||
setattr(context, "reverse_tunnel_service", tunnel_service)
|
||||
setattr(context, "vpn_tunnel_service", service)
|
||||
return service
|
||||
|
||||
@socket_server.on("quick_job_result")
|
||||
def _handle_quick_job_result(data: Any) -> None:
|
||||
@@ -246,252 +253,45 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
exc,
|
||||
)
|
||||
|
||||
@socket_server.on("tunnel_bridge_attach")
|
||||
def _tunnel_bridge_attach(data: Any) -> Any:
|
||||
"""Placeholder operator bridge attach handler (no data channel yet)."""
|
||||
@socket_server.on("vpn_shell_open")
|
||||
def _vpn_shell_open(data: Any) -> Dict[str, Any]:
|
||||
agent_id = ""
|
||||
if isinstance(data, dict):
|
||||
agent_id = str(data.get("agent_id") or "").strip()
|
||||
elif isinstance(data, str):
|
||||
agent_id = data.strip()
|
||||
if not agent_id:
|
||||
return {"error": "agent_id_required"}
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return {"error": "invalid_payload"}
|
||||
service = _get_tunnel_service()
|
||||
if service is None:
|
||||
return {"error": "vpn_service_unavailable"}
|
||||
if not service.status(agent_id):
|
||||
return {"error": "tunnel_down"}
|
||||
|
||||
tunnel_id = str(data.get("tunnel_id") or "").strip()
|
||||
operator_id = str(data.get("operator_id") or "").strip() or None
|
||||
if not tunnel_id:
|
||||
return {"error": "tunnel_id_required"}
|
||||
session = shell_bridge.open_session(request.sid, agent_id)
|
||||
if session is None:
|
||||
return {"error": "shell_connect_failed"}
|
||||
service.bump_activity(agent_id)
|
||||
return {"status": "ok"}
|
||||
|
||||
try:
|
||||
tunnel_service.operator_attach(tunnel_id, operator_id)
|
||||
except ValueError as exc:
|
||||
return {"error": str(exc)}
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("tunnel_bridge_attach failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "bridge_attach_failed"}
|
||||
|
||||
return {"status": "ok", "tunnel_id": tunnel_id, "operator_id": operator_id or "-"}
|
||||
|
||||
def _encode_frame(frame: TunnelFrame) -> str:
|
||||
return base64.b64encode(frame.encode()).decode("ascii")
|
||||
|
||||
def _decode_frame_payload(raw: Any) -> TunnelFrame:
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
raw_bytes = base64.b64decode(raw)
|
||||
except Exception:
|
||||
raise ValueError("invalid_frame")
|
||||
elif isinstance(raw, (bytes, bytearray)):
|
||||
raw_bytes = bytes(raw)
|
||||
@socket_server.on("vpn_shell_send")
|
||||
def _vpn_shell_send(data: Any) -> Dict[str, Any]:
|
||||
payload = None
|
||||
if isinstance(data, dict):
|
||||
payload = data.get("data")
|
||||
else:
|
||||
raise ValueError("invalid_frame")
|
||||
return decode_frame(raw_bytes)
|
||||
|
||||
@socket_server.on("tunnel_operator_send")
|
||||
def _tunnel_operator_send(data: Any) -> Any:
|
||||
"""Operator -> agent frame enqueue (placeholder queue)."""
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return {"error": "invalid_payload"}
|
||||
tunnel_id = str(data.get("tunnel_id") or "").strip()
|
||||
frame_raw = data.get("frame")
|
||||
if not tunnel_id or frame_raw is None:
|
||||
return {"error": "tunnel_id_and_frame_required"}
|
||||
try:
|
||||
frame = _decode_frame_payload(frame_raw)
|
||||
except Exception as exc:
|
||||
return {"error": str(exc)}
|
||||
|
||||
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
bridge.operator_to_agent(frame)
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("tunnel_operator_poll")
|
||||
def _tunnel_operator_poll(data: Any) -> Any:
|
||||
"""Operator polls queued frames from agent."""
|
||||
|
||||
tunnel_id = ""
|
||||
if isinstance(data, dict):
|
||||
tunnel_id = str(data.get("tunnel_id") or "").strip()
|
||||
if not tunnel_id:
|
||||
return {"error": "tunnel_id_required"}
|
||||
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
|
||||
frames = []
|
||||
while True:
|
||||
frame = bridge.next_for_operator()
|
||||
if frame is None:
|
||||
break
|
||||
frames.append(_encode_frame(frame))
|
||||
return {"frames": frames}
|
||||
|
||||
# WebUI operator bridge namespace for browser clients
|
||||
tunnel_namespace = "/tunnel"
|
||||
_operator_sessions: Dict[str, str] = {}
|
||||
|
||||
def _current_operator() -> Optional[str]:
|
||||
username = session.get("username")
|
||||
if username:
|
||||
return str(username)
|
||||
auth_header = (request.headers.get("Authorization") or "").strip()
|
||||
token = None
|
||||
if auth_header.lower().startswith("bearer "):
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
if not token:
|
||||
token = request.cookies.get("borealis_auth")
|
||||
return token or None
|
||||
|
||||
@socket_server.on("join", namespace=tunnel_namespace)
|
||||
def _ws_tunnel_join(data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return {"error": "invalid_payload"}
|
||||
operator_id = _current_operator()
|
||||
if not operator_id:
|
||||
return {"error": "unauthorized"}
|
||||
tunnel_id = str(data.get("tunnel_id") or "").strip()
|
||||
if not tunnel_id:
|
||||
return {"error": "tunnel_id_required"}
|
||||
bridge = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
try:
|
||||
tunnel_service.operator_attach(tunnel_id, operator_id)
|
||||
except Exception as exc:
|
||||
logger.debug("ws_tunnel_join failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "attach_failed"}
|
||||
sid = request.sid
|
||||
_operator_sessions[sid] = tunnel_id
|
||||
return {"status": "ok", "tunnel_id": tunnel_id}
|
||||
|
||||
@socket_server.on("send", namespace=tunnel_namespace)
|
||||
def _ws_tunnel_send(data: Any) -> Any:
|
||||
sid = request.sid
|
||||
tunnel_id = _operator_sessions.get(sid)
|
||||
if not tunnel_id:
|
||||
return {"error": "not_joined"}
|
||||
if not isinstance(data, dict):
|
||||
return {"error": "invalid_payload"}
|
||||
frame_raw = data.get("frame")
|
||||
if frame_raw is None:
|
||||
return {"error": "frame_required"}
|
||||
try:
|
||||
frame = _decode_frame_payload(frame_raw)
|
||||
except Exception:
|
||||
return {"error": "invalid_frame"}
|
||||
bridge = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
bridge.operator_to_agent(frame)
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("poll", namespace=tunnel_namespace)
|
||||
def _ws_tunnel_poll() -> Any:
|
||||
sid = request.sid
|
||||
tunnel_id = _operator_sessions.get(sid)
|
||||
if not tunnel_id:
|
||||
return {"error": "not_joined"}
|
||||
bridge = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
frames = []
|
||||
while True:
|
||||
frame = bridge.next_for_operator()
|
||||
if frame is None:
|
||||
break
|
||||
frames.append(_encode_frame(frame))
|
||||
return {"frames": frames}
|
||||
|
||||
def _require_ps_server():
|
||||
sid = request.sid
|
||||
tunnel_id = _operator_sessions.get(sid)
|
||||
if not tunnel_id:
|
||||
return None, None, {"error": "not_joined"}
|
||||
server = tunnel_service.ensure_protocol_server(tunnel_id)
|
||||
if server is None or not hasattr(server, "open_channel"):
|
||||
return None, tunnel_id, {"error": "ps_unsupported"}
|
||||
return server, tunnel_id, None
|
||||
|
||||
@socket_server.on("ps_open", namespace=tunnel_namespace)
|
||||
def _ws_ps_open(data: Any) -> Any:
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
cols = 120
|
||||
rows = 32
|
||||
if isinstance(data, dict):
|
||||
try:
|
||||
cols = int(data.get("cols", cols))
|
||||
rows = int(data.get("rows", rows))
|
||||
except Exception:
|
||||
pass
|
||||
cols = max(20, min(cols, 300))
|
||||
rows = max(10, min(rows, 200))
|
||||
try:
|
||||
server.open_channel(cols=cols, rows=rows)
|
||||
except Exception as exc:
|
||||
logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_open_failed"}
|
||||
return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows}
|
||||
|
||||
@socket_server.on("ps_send", namespace=tunnel_namespace)
|
||||
def _ws_ps_send(data: Any) -> Any:
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
if data is None:
|
||||
payload = data
|
||||
if payload is None:
|
||||
return {"error": "payload_required"}
|
||||
text = data
|
||||
if isinstance(data, dict):
|
||||
text = data.get("data")
|
||||
if text is None:
|
||||
return {"error": "payload_required"}
|
||||
try:
|
||||
server.send_input(str(text))
|
||||
except Exception as exc:
|
||||
logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_send_failed"}
|
||||
shell_bridge.send(request.sid, str(payload))
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("ps_resize", namespace=tunnel_namespace)
|
||||
def _ws_ps_resize(data: Any) -> Any:
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
cols = None
|
||||
rows = None
|
||||
if isinstance(data, dict):
|
||||
cols = data.get("cols")
|
||||
rows = data.get("rows")
|
||||
try:
|
||||
cols_int = int(cols) if cols is not None else 120
|
||||
rows_int = int(rows) if rows is not None else 32
|
||||
cols_int = max(20, min(cols_int, 300))
|
||||
rows_int = max(10, min(rows_int, 200))
|
||||
server.send_resize(cols_int, rows_int)
|
||||
return {"status": "ok", "cols": cols_int, "rows": rows_int}
|
||||
except Exception as exc:
|
||||
logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_resize_failed"}
|
||||
@socket_server.on("vpn_shell_close")
|
||||
def _vpn_shell_close() -> Dict[str, Any]:
|
||||
shell_bridge.close(request.sid)
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("ps_poll", namespace=tunnel_namespace)
|
||||
def _ws_ps_poll(data: Any = None) -> Any: # data is ignored; socketio passes it even when unused
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
try:
|
||||
output = server.drain_output()
|
||||
status = server.status()
|
||||
return {"output": output, "status": status}
|
||||
except Exception as exc:
|
||||
logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_poll_failed"}
|
||||
|
||||
@socket_server.on("disconnect", namespace=tunnel_namespace)
|
||||
def _ws_tunnel_disconnect():
|
||||
sid = request.sid
|
||||
tunnel_id = _operator_sessions.pop(sid, None)
|
||||
if tunnel_id and tunnel_id not in _operator_sessions.values():
|
||||
try:
|
||||
tunnel_service.stop_tunnel(tunnel_id, reason="operator_socket_disconnect")
|
||||
except Exception as exc:
|
||||
logger.debug("ws_tunnel_disconnect stop_tunnel failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
@socket_server.on("disconnect")
|
||||
def _ws_disconnect() -> None:
|
||||
shell_bridge.close(request.sid)
|
||||
|
||||
Reference in New Issue
Block a user