mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2026-02-04 12:40:32 -07:00
510 lines
19 KiB
Python
510 lines
19 KiB
Python
# ======================================================
|
|
# Data\Engine\services\WebSocket\__init__.py
|
|
# Description: Socket.IO handlers for Engine runtime quick job updates, agent socket registration (VPN tunnel start), and VPN shell bridging.
|
|
#
|
|
# API Endpoints (if applicable): None
|
|
# ======================================================
|
|
|
|
"""WebSocket service registration for the Borealis Engine runtime."""
|
|
from __future__ import annotations
|
|
|
|
import sqlite3
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, Optional
|
|
|
|
from flask import request
|
|
from flask_socketio import SocketIO
|
|
|
|
from ...database import initialise_engine_database
|
|
from ...security import signing
|
|
from ...server import EngineContext
|
|
from ..VPN import WireGuardServerConfig, WireGuardServerManager, VpnTunnelService
|
|
from .vpn_shell import VpnShellBridge
|
|
|
|
|
|
def _now_ts() -> int:
|
|
return int(time.time())
|
|
|
|
|
|
def _normalize_text(value: Any) -> str:
|
|
if value is None:
|
|
return ""
|
|
if isinstance(value, bytes):
|
|
try:
|
|
return value.decode("utf-8")
|
|
except Exception:
|
|
return value.decode("utf-8", errors="replace")
|
|
return str(value)
|
|
|
|
|
|
@dataclass
class EngineRealtimeAdapters:
    """Bundle the database and service-log helpers used by realtime handlers.

    Construction initialises the Engine database at ``context.database_path``
    and derives both helpers from the context. Neither helper is accepted as
    an ``__init__`` argument.
    """

    context: EngineContext
    # Zero-argument factory returning a sqlite3 connection; assigned in __post_init__.
    db_conn_factory: Callable[[], sqlite3.Connection] = field(init=False)
    # Service-log callable; assigned in __post_init__.
    service_log: Callable[[str, str, Optional[str]], None] = field(init=False)

    def __post_init__(self) -> None:
        # Imported lazily: importing ..API at module load would be circular.
        from ..API import _make_db_conn_factory, _make_service_logger

        ctx = self.context
        db_path = ctx.database_path
        initialise_engine_database(db_path, logger=ctx.logger)
        self.db_conn_factory = _make_db_conn_factory(db_path)

        cfg = ctx.config
        configured_log = str(cfg.get("log_file") or cfg.get("LOG_FILE") or "").strip()
        # Service logs are rooted beside the configured log file when one is
        # set, otherwise beside the database file.
        anchor = Path(configured_log) if configured_log else Path(db_path)
        self.service_log = _make_service_logger(anchor.resolve().parent, ctx.logger)
|
|
|
|
|
|
class AgentSocketRegistry:
    """Map connected agents to their Socket.IO session ids (sids).

    A forward index (agent_id -> sid) and a reverse index (sid -> agent_id)
    are kept in step. Each agent therefore has at most one live sid, and each
    sid belongs to at most one agent.
    """

    def __init__(self, socketio: SocketIO, logger) -> None:
        self.socketio = socketio
        self.logger = logger
        self._sid_by_agent: Dict[str, str] = {}
        self._agent_by_sid: Dict[str, str] = {}

    def register(self, agent_id: str, sid: str) -> None:
        """Bind ``agent_id`` to ``sid``, dropping any stale binding on either side.

        Calls with an empty ``agent_id`` or ``sid`` are ignored.
        """
        if not agent_id or not sid:
            return
        # The agent moved to a new socket: forget its previous sid.
        previous_sid = self._sid_by_agent.get(agent_id)
        if previous_sid and previous_sid != sid:
            self._agent_by_sid.pop(previous_sid, None)
        # The same socket re-registered under a different agent id: drop the
        # old agent's forward entry. Otherwise that agent would keep pointing
        # at this sid after disconnect, because unregister() only clears the
        # agent currently bound to the sid.
        previous_agent = self._agent_by_sid.get(sid)
        if (
            previous_agent
            and previous_agent != agent_id
            and self._sid_by_agent.get(previous_agent) == sid
        ):
            self._sid_by_agent.pop(previous_agent, None)
        self._sid_by_agent[agent_id] = sid
        self._agent_by_sid[sid] = agent_id

    def unregister(self, sid: str) -> Optional[str]:
        """Forget ``sid`` and return the agent id it was bound to, or None."""
        agent_id = self._agent_by_sid.pop(sid, None)
        # Defensive: only drop the forward entry if it still points at this sid.
        if agent_id and self._sid_by_agent.get(agent_id) == sid:
            self._sid_by_agent.pop(agent_id, None)
        return agent_id

    def is_registered(self, agent_id: str) -> bool:
        """Return True when ``agent_id`` currently has a registered sid."""
        return bool(self._sid_by_agent.get(agent_id))

    def emit(self, agent_id: str, event: str, payload: Any) -> bool:
        """Emit ``event`` with ``payload`` to the agent's socket.

        Returns False when the agent has no registered sid, or when the emit
        raises (the error is logged at debug level). Returns True otherwise.
        """
        sid = self._sid_by_agent.get(agent_id)
        if not sid:
            return False
        try:
            self.socketio.emit(event, payload, to=sid)
            return True
        except Exception:
            self.logger.debug("Failed to emit %s to agent_id=%s", event, agent_id, exc_info=True)
            return False
|
|
|
|
|
|
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
    """Register Socket.IO event handlers for the Engine runtime.

    Publishes ``agent_socket_registry`` (the AgentSocketRegistry instance) and
    ``emit_agent_event`` (a callable ``(agent_id, event, payload) -> bool``) on
    ``context``. The VPN helpers below may also lazily set
    ``wireguard_server_manager`` and ``vpn_tunnel_service`` on ``context``.
    """

    adapters = EngineRealtimeAdapters(context)
    base_logger = context.logger
    logger = base_logger.getChild("realtime.quick_jobs")
    agent_logger = base_logger.getChild("realtime.agents")
    shell_bridge = VpnShellBridge(socket_server, context, adapters.service_log)
    agent_registry = AgentSocketRegistry(socket_server, agent_logger)
    context.agent_socket_registry = agent_registry

    def _emit_agent_event(agent_id: str, event: str, payload: Any) -> bool:
        # Thin forwarder so other services can target an agent by id without
        # holding a reference to the registry itself.
        return agent_registry.emit(agent_id, event, payload)

    context.emit_agent_event = _emit_agent_event
|
|
|
|
def _get_tunnel_service() -> Optional[VpnTunnelService]:
|
|
service = getattr(context, "vpn_tunnel_service", None)
|
|
if service is not None:
|
|
return service
|
|
manager = getattr(context, "wireguard_server_manager", None)
|
|
if manager is None:
|
|
try:
|
|
manager = WireGuardServerManager(
|
|
WireGuardServerConfig(
|
|
port=context.wireguard_port,
|
|
engine_virtual_ip=context.wireguard_engine_virtual_ip,
|
|
peer_network=context.wireguard_peer_network,
|
|
private_key_path=Path(context.wireguard_server_private_key_path),
|
|
public_key_path=Path(context.wireguard_server_public_key_path),
|
|
acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows),
|
|
log_path=Path(context.vpn_tunnel_log_path),
|
|
)
|
|
)
|
|
setattr(context, "wireguard_server_manager", manager)
|
|
except Exception:
|
|
context.logger.error("Failed to initialize WireGuard server manager on demand.", exc_info=True)
|
|
return None
|
|
try:
|
|
signer = signing.load_signer()
|
|
except Exception:
|
|
signer = None
|
|
service = VpnTunnelService(
|
|
context=context,
|
|
wireguard_manager=manager,
|
|
db_conn_factory=adapters.db_conn_factory,
|
|
socketio=socket_server,
|
|
service_log=adapters.service_log,
|
|
signer=signer,
|
|
)
|
|
setattr(context, "vpn_tunnel_service", service)
|
|
return service
|
|
|
|
def _tunnel_log(message: str, *, level: str = "INFO") -> None:
|
|
try:
|
|
adapters.service_log("VPN_Tunnel/tunnel", message, level=level)
|
|
except Exception:
|
|
agent_logger.debug("vpn_tunnel service log write failed", exc_info=True)
|
|
|
|
def _shell_log(message: str, *, level: str = "INFO") -> None:
|
|
try:
|
|
adapters.service_log("VPN_Tunnel/remote_shell", message, level=level)
|
|
except Exception:
|
|
agent_logger.debug("vpn_shell service log write failed", exc_info=True)
|
|
|
|
def _remote_addr() -> str:
|
|
forwarded = (request.headers.get("X-Forwarded-For") or "").strip()
|
|
if forwarded:
|
|
return forwarded.split(",")[0].strip()
|
|
return (request.remote_addr or "").strip()
|
|
|
|
    @socket_server.on("quick_job_result")
    def _handle_quick_job_result(data: Any) -> None:
        """Persist a quick-job result reported over Socket.IO and broadcast the change.

        Expects a dict payload with ``job_id`` (an ``activity_history`` row id)
        and optional ``status``, ``stdout``, ``stderr`` and ``context`` keys.

        Steps:
        1. Update the activity row and commit.
        2. Resolve a scheduled run id, first from ``scheduled_job_run_activity``
           and then from ``context`` (``scheduled_job_run_id`` or ``run_id``).
        3. If a run id is found, update that ``scheduled_job_runs`` row.
        4. Emit ``device_activity_changed`` when the activity row has a hostname.

        Malformed payloads are logged at debug level and ignored. Errors during
        the database work are logged rather than raised.
        """
        # NOTE(review): nothing here checks that request.sid belongs to a
        # registered agent; confirm the socket is authenticated elsewhere.
        if not isinstance(data, dict):
            logger.debug("quick_job_result payload ignored (non-dict): %r", data)
            return

        job_id_raw = data.get("job_id")
        try:
            job_id = int(job_id_raw)
        except (TypeError, ValueError):
            logger.debug("quick_job_result missing valid job_id: %r", job_id_raw)
            return

        # A missing or blank status is recorded as "Failed".
        status = str(data.get("status") or "").strip() or "Failed"
        stdout = _normalize_text(data.get("stdout"))
        stderr = _normalize_text(data.get("stderr"))

        conn: Optional[sqlite3.Connection] = None
        cursor = None
        # Set only when the activity row has a hostname; emitted after the
        # connection has been closed.
        broadcast_payload: Optional[Dict[str, Any]] = None

        ctx_payload = data.get("context")
        context_info: Optional[Dict[str, Any]] = ctx_payload if isinstance(ctx_payload, dict) else None

        try:
            conn = adapters.db_conn_factory()
            cursor = conn.cursor()
            cursor.execute(
                "UPDATE activity_history SET status=?, stdout=?, stderr=? WHERE id=?",
                (status, stdout, stderr, job_id),
            )
            if cursor.rowcount == 0:
                logger.debug("quick_job_result missing activity_history row for job_id=%s", job_id)
            # Committed on its own so a later scheduled-run failure cannot undo
            # the activity update.
            conn.commit()

            # Resolve the scheduled run: the database link table first ...
            try:
                cursor.execute(
                    "SELECT run_id FROM scheduled_job_run_activity WHERE activity_id=?",
                    (job_id,),
                )
                link = cursor.fetchone()
            except sqlite3.Error:
                link = None

            run_id: Optional[int] = None
            scheduled_ts_ctx: Optional[int] = None
            if link:
                try:
                    run_id = int(link[0])
                except Exception:
                    run_id = None

            # ... then the agent-supplied context as a fallback.
            # NOTE(review): scheduled_ts from the context is read only on this
            # fallback path, not when the link table already names the run.
            if run_id is None and context_info:
                ctx_run = context_info.get("scheduled_job_run_id") or context_info.get("run_id")
                try:
                    if ctx_run is not None:
                        run_id = int(ctx_run)
                except (TypeError, ValueError):
                    run_id = None
                try:
                    if context_info.get("scheduled_ts") is not None:
                        scheduled_ts_ctx = int(context_info.get("scheduled_ts"))
                except (TypeError, ValueError):
                    scheduled_ts_ctx = None

            if run_id is not None:
                ts_now = _now_ts()
                try:
                    if status.lower() == "running":
                        cursor.execute(
                            "UPDATE scheduled_job_runs SET status='Running', updated_at=? WHERE id=?",
                            (ts_now, run_id),
                        )
                    else:
                        # Any other status: COALESCE keeps an existing finished_ts.
                        cursor.execute(
                            """
                            UPDATE scheduled_job_runs
                            SET status=?,
                            finished_ts=COALESCE(finished_ts, ?),
                            updated_at=?
                            WHERE id=?
                            """,
                            (status, ts_now, ts_now, run_id),
                        )
                    # Backfill scheduled_ts only when the run row has none.
                    # NOTE(review): placed at this level (applies to every
                    # status); confirm against the upstream source.
                    if scheduled_ts_ctx is not None:
                        cursor.execute(
                            "UPDATE scheduled_job_runs SET scheduled_ts=COALESCE(scheduled_ts, ?) WHERE id=?",
                            (scheduled_ts_ctx, run_id),
                        )
                    conn.commit()
                    adapters.service_log(
                        "scheduled_jobs",
                        f"scheduled run update run_id={run_id} activity_id={job_id} status={status}",
                    )
                except Exception as exc:  # pragma: no cover - defensive guard
                    logger.debug(
                        "quick_job_result failed to update scheduled_job_runs for job_id=%s run_id=%s: %s",
                        job_id,
                        run_id,
                        exc,
                    )
            elif context_info:
                # A context was supplied but no run id could be resolved.
                adapters.service_log(
                    "scheduled_jobs",
                    f"scheduled run update skipped (no run_id) activity_id={job_id} status={status} context={context_info}",
                    level="WARNING",
                )

            # Re-read the row so the broadcast carries the stored status.
            try:
                cursor.execute(
                    "SELECT id, hostname, status FROM activity_history WHERE id=?",
                    (job_id,),
                )
                row = cursor.fetchone()
            except sqlite3.Error:
                row = None

            if row:
                hostname = (row[1] or "").strip()
                if hostname:
                    broadcast_payload = {
                        "activity_id": int(row[0]),
                        "hostname": hostname,
                        "status": row[2] or status,
                        "change": "updated",
                        "source": "quick_job",
                    }

            adapters.service_log(
                "assemblies",
                f"quick_job_result processed job_id={job_id} status={status}",
            )
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.warning(
                "quick_job_result handler error for job_id=%s: %s",
                job_id,
                exc,
                exc_info=True,
            )
        finally:
            # Best-effort cleanup; close errors are ignored.
            if cursor is not None:
                try:
                    cursor.close()
                except Exception:
                    pass
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

        # Emitted after the connection is released; emit failures are only logged.
        if broadcast_payload:
            try:
                socket_server.emit("device_activity_changed", broadcast_payload)
            except Exception as exc:  # pragma: no cover - defensive guard
                logger.debug(
                    "Failed to emit device_activity_changed for job_id=%s: %s",
                    job_id,
                    exc,
                )
|
|
|
|
@socket_server.on("vpn_shell_open")
|
|
def _vpn_shell_open(data: Any) -> Dict[str, Any]:
|
|
agent_id = ""
|
|
if isinstance(data, dict):
|
|
agent_id = str(data.get("agent_id") or "").strip()
|
|
elif isinstance(data, str):
|
|
agent_id = data.strip()
|
|
if not agent_id:
|
|
_shell_log(
|
|
"vpn_shell_open_missing sid={0} remote={1}".format(
|
|
request.sid,
|
|
_remote_addr() or "-",
|
|
),
|
|
level="WARNING",
|
|
)
|
|
return {"error": "agent_id_required"}
|
|
|
|
_shell_log(
|
|
"vpn_shell_open_request agent_id={0} sid={1} remote={2}".format(
|
|
agent_id,
|
|
request.sid,
|
|
_remote_addr() or "-",
|
|
)
|
|
)
|
|
service = _get_tunnel_service()
|
|
if service is None:
|
|
_shell_log(
|
|
"vpn_shell_open_failed agent_id={0} sid={1} reason=vpn_service_unavailable".format(
|
|
agent_id,
|
|
request.sid,
|
|
),
|
|
level="WARNING",
|
|
)
|
|
return {"error": "vpn_service_unavailable"}
|
|
if not service.status(agent_id):
|
|
_shell_log(
|
|
"vpn_shell_open_failed agent_id={0} sid={1} reason=tunnel_down".format(
|
|
agent_id,
|
|
request.sid,
|
|
),
|
|
level="WARNING",
|
|
)
|
|
return {"error": "tunnel_down"}
|
|
registry = getattr(context, "agent_socket_registry", None)
|
|
if registry and hasattr(registry, "is_registered"):
|
|
try:
|
|
if not registry.is_registered(agent_id):
|
|
_shell_log(
|
|
"vpn_shell_open_failed agent_id={0} sid={1} reason=agent_socket_missing".format(
|
|
agent_id,
|
|
request.sid,
|
|
),
|
|
level="WARNING",
|
|
)
|
|
return {"error": "agent_socket_missing"}
|
|
except Exception:
|
|
agent_logger.debug("agent_socket_registry lookup failed for agent_id=%s", agent_id, exc_info=True)
|
|
|
|
session = shell_bridge.open_session(request.sid, agent_id)
|
|
if session is None:
|
|
_shell_log(
|
|
"vpn_shell_open_failed agent_id={0} sid={1} reason=shell_connect_failed".format(
|
|
agent_id,
|
|
request.sid,
|
|
),
|
|
level="WARNING",
|
|
)
|
|
return {"error": "shell_connect_failed"}
|
|
service.bump_activity(agent_id)
|
|
_shell_log(
|
|
"vpn_shell_open_success agent_id={0} sid={1}".format(
|
|
agent_id,
|
|
request.sid,
|
|
)
|
|
)
|
|
return {"status": "ok"}
|
|
|
|
@socket_server.on("connect_agent")
|
|
def _connect_agent(data: Any) -> Dict[str, Any]:
|
|
agent_id = ""
|
|
service_mode = ""
|
|
if isinstance(data, dict):
|
|
agent_id = str(data.get("agent_id") or "").strip()
|
|
service_mode = str(data.get("service_mode") or "").strip().lower()
|
|
elif isinstance(data, str):
|
|
agent_id = data.strip()
|
|
if not agent_id:
|
|
_tunnel_log(
|
|
"vpn_agent_socket_missing sid={0} remote={1}".format(
|
|
request.sid,
|
|
_remote_addr() or "-",
|
|
),
|
|
level="WARNING",
|
|
)
|
|
return {"error": "agent_id_required"}
|
|
|
|
agent_registry.register(agent_id, request.sid)
|
|
agent_logger.info("Agent socket registered agent_id=%s service_mode=%s sid=%s", agent_id, service_mode, request.sid)
|
|
_tunnel_log(
|
|
"vpn_agent_socket_register agent_id={0} service_mode={1} sid={2} remote={3}".format(
|
|
agent_id,
|
|
service_mode or "-",
|
|
request.sid,
|
|
_remote_addr() or "-",
|
|
)
|
|
)
|
|
|
|
service = _get_tunnel_service()
|
|
if service:
|
|
payload = service.session_payload(agent_id, include_token=True)
|
|
if payload:
|
|
if agent_registry.emit(agent_id, "vpn_tunnel_start", payload):
|
|
_tunnel_log(
|
|
"vpn_agent_socket_emit_start agent_id={0} tunnel_id={1} sid={2}".format(
|
|
agent_id,
|
|
payload.get("tunnel_id", "-"),
|
|
request.sid,
|
|
)
|
|
)
|
|
|
|
return {"status": "ok"}
|
|
|
|
@socket_server.on("vpn_shell_send")
|
|
def _vpn_shell_send(data: Any) -> Dict[str, Any]:
|
|
payload = None
|
|
if isinstance(data, dict):
|
|
payload = data.get("data")
|
|
else:
|
|
payload = data
|
|
if payload is None:
|
|
return {"error": "payload_required"}
|
|
try:
|
|
payload_len = len(str(payload))
|
|
except Exception:
|
|
payload_len = 0
|
|
_shell_log(
|
|
"vpn_shell_send_request sid={0} bytes={1} remote={2}".format(
|
|
request.sid,
|
|
payload_len,
|
|
_remote_addr() or "-",
|
|
)
|
|
)
|
|
shell_bridge.send(request.sid, str(payload))
|
|
return {"status": "ok"}
|
|
|
|
@socket_server.on("vpn_shell_close")
|
|
def _vpn_shell_close(data: Any = None) -> Dict[str, Any]:
|
|
_shell_log(
|
|
"vpn_shell_close_request sid={0} remote={1}".format(
|
|
request.sid,
|
|
_remote_addr() or "-",
|
|
)
|
|
)
|
|
shell_bridge.close(request.sid)
|
|
return {"status": "ok"}
|
|
|
|
@socket_server.on("disconnect")
|
|
def _ws_disconnect() -> None:
|
|
agent_id = agent_registry.unregister(request.sid)
|
|
if agent_id:
|
|
agent_logger.info("Agent socket disconnected agent_id=%s sid=%s", agent_id, request.sid)
|
|
_tunnel_log(
|
|
"vpn_agent_socket_disconnect agent_id={0} sid={1}".format(
|
|
agent_id,
|
|
request.sid,
|
|
)
|
|
)
|
|
else:
|
|
_shell_log(
|
|
"vpn_shell_client_disconnect sid={0} remote={1}".format(
|
|
request.sid,
|
|
_remote_addr() or "-",
|
|
),
|
|
level="WARNING",
|
|
)
|
|
shell_bridge.close(request.sid)
|