# ======================================================
# Data\Engine\services\WebSocket\__init__.py
# Description: Socket.IO handlers for Engine runtime quick job updates and realtime notifications.
#
# API Endpoints (if applicable): None
# ======================================================

"""WebSocket service registration for the Borealis Engine runtime."""

from __future__ import annotations

import base64
import sqlite3
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Optional

from flask_socketio import SocketIO
from flask import session, request

from ...database import initialise_engine_database
from ...server import EngineContext
from .Agent.ReverseTunnel import (
    ReverseTunnelService,
    TunnelBridge,
    decode_frame,
    TunnelFrame,
)
from ..API import _make_db_conn_factory, _make_service_logger


def _now_ts() -> int:
    """Return the current Unix time in whole seconds."""
    return int(time.time())


def _normalize_text(value: Any) -> str:
    """Coerce a payload field to ``str`` for storage.

    ``None`` becomes ``""``. Binary values (``bytes``, ``bytearray``,
    ``memoryview``) are decoded as UTF-8, with undecodable sequences replaced
    by U+FFFD. Anything else goes through ``str()``.
    """
    if value is None:
        return ""
    if isinstance(value, (bytes, bytearray, memoryview)):
        # errors="replace" produces exactly the strict-decode result for valid
        # UTF-8, so a separate strict attempt is unnecessary.
        return bytes(value).decode("utf-8", errors="replace")
    return str(value)


def _coerce_int(value: Any) -> Optional[int]:
    """Return ``int(value)``, or ``None`` if *value* is None or not convertible."""
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return None


def _lookup_linked_run_id(cursor: sqlite3.Cursor, activity_id: int) -> Optional[int]:
    """Return the scheduled run id linked to *activity_id*, or ``None``.

    Reads ``scheduled_job_run_activity``. Any ``sqlite3.Error`` raised by the
    lookup is treated as "no link".
    """
    try:
        cursor.execute(
            "SELECT run_id FROM scheduled_job_run_activity WHERE activity_id=?",
            (activity_id,),
        )
        link = cursor.fetchone()
    except sqlite3.Error:
        return None
    return _coerce_int(link[0]) if link else None


def _build_activity_broadcast(
    cursor: sqlite3.Cursor, activity_id: int, fallback_status: str
) -> Optional[Dict[str, Any]]:
    """Build the ``device_activity_changed`` payload for an activity_history row.

    Returns ``None`` when the row cannot be read, does not exist, or has an
    empty hostname. *fallback_status* is used when the stored status is empty.
    """
    try:
        cursor.execute(
            "SELECT id, hostname, status FROM activity_history WHERE id=?",
            (activity_id,),
        )
        row = cursor.fetchone()
    except sqlite3.Error:
        return None
    if not row:
        return None
    hostname = (row[1] or "").strip()
    if not hostname:
        return None
    return {
        "activity_id": int(row[0]),
        "hostname": hostname,
        "status": row[2] or fallback_status,
        "change": "updated",
        "source": "quick_job",
    }


def _encode_frame(frame: TunnelFrame) -> str:
    """Serialize *frame* and base64-encode it so it fits in a JSON Socket.IO payload."""
    return base64.b64encode(frame.encode()).decode("ascii")


def _decode_frame_payload(raw: Any) -> TunnelFrame:
    """Decode a tunnel frame sent by an operator.

    *raw* may be a base64 string or raw bytes/bytearray.

    Raises:
        ValueError("invalid_frame"): *raw* has another type, or the string is
            not decodable base64.

    Errors raised by ``decode_frame`` propagate unchanged.
    """
    if isinstance(raw, str):
        try:
            raw_bytes = base64.b64decode(raw)
        except ValueError as exc:  # binascii.Error subclasses ValueError
            raise ValueError("invalid_frame") from exc
    elif isinstance(raw, (bytes, bytearray)):
        raw_bytes = bytes(raw)
    else:
        raise ValueError("invalid_frame")
    return decode_frame(raw_bytes)


def _drain_operator_frames(bridge: TunnelBridge) -> list[str]:
    """Pop every frame queued for the operator and return them base64-encoded."""
    frames = []
    frame = bridge.next_for_operator()
    while frame is not None:
        frames.append(_encode_frame(frame))
        frame = bridge.next_for_operator()
    return frames


@dataclass
class EngineRealtimeAdapters:
    """Helpers derived from the Engine context for the realtime handlers.

    Construction initialises the Engine database. It then exposes a SQLite
    connection factory and a service logger built from the directory of the
    configured log file (``log_file``/``LOG_FILE``), or of the database when
    no log file is configured.
    """

    context: EngineContext
    db_conn_factory: Callable[[], sqlite3.Connection] = field(init=False)
    # Called as service_log(log_name, message) and also with a ``level=``
    # keyword, so the parameter list is left open.
    service_log: Callable[..., None] = field(init=False)

    def __post_init__(self) -> None:
        initialise_engine_database(self.context.database_path, logger=self.context.logger)
        self.db_conn_factory = _make_db_conn_factory(self.context.database_path)
        log_file = str(
            self.context.config.get("log_file") or self.context.config.get("LOG_FILE") or ""
        ).strip()
        if log_file:
            base = Path(log_file).resolve().parent
        else:
            base = Path(self.context.database_path).resolve().parent
        self.service_log = _make_service_logger(base, self.context.logger)


def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
    """Register Socket.IO event handlers for the Engine runtime.

    Default namespace:
      * ``quick_job_result``: persist a quick-job result to activity_history,
        update any linked scheduled run, and emit ``device_activity_changed``.
      * ``tunnel_bridge_attach`` / ``tunnel_operator_send`` /
        ``tunnel_operator_poll``: placeholder operator bridge for reverse tunnels.

    ``/tunnel`` namespace (browser operators):
      * ``join`` / ``send`` / ``poll`` / ``disconnect``.

    Reuses ``context.reverse_tunnel_service`` when it is set. Otherwise it
    creates and starts a ReverseTunnelService and stores it on the context.
    """
    adapters = EngineRealtimeAdapters(context)
    logger = context.logger.getChild("realtime.quick_jobs")

    tunnel_service = getattr(context, "reverse_tunnel_service", None)
    if tunnel_service is None:
        tunnel_service = ReverseTunnelService(
            context,
            signer=None,
            db_conn_factory=adapters.db_conn_factory,
            socketio=socket_server,
        )
        tunnel_service.start()
        setattr(context, "reverse_tunnel_service", tunnel_service)

    def _update_scheduled_run(
        cursor: sqlite3.Cursor,
        conn: sqlite3.Connection,
        run_id: int,
        job_id: int,
        status: str,
        scheduled_ts: Optional[int],
    ) -> None:
        """Propagate a quick-job status to its scheduled_job_runs row (best effort)."""
        ts_now = _now_ts()
        try:
            if status.lower() == "running":
                cursor.execute(
                    "UPDATE scheduled_job_runs SET status='Running', updated_at=? WHERE id=?",
                    (ts_now, run_id),
                )
            else:
                # finished_ts is stamped only by the first non-running status.
                cursor.execute(
                    """
                    UPDATE scheduled_job_runs
                    SET status=?,
                        finished_ts=COALESCE(finished_ts, ?),
                        updated_at=?
                    WHERE id=?
                    """,
                    (status, ts_now, ts_now, run_id),
                )
            if scheduled_ts is not None:
                # Backfill only: an existing scheduled_ts is never overwritten.
                cursor.execute(
                    "UPDATE scheduled_job_runs SET scheduled_ts=COALESCE(scheduled_ts, ?) WHERE id=?",
                    (scheduled_ts, run_id),
                )
            conn.commit()
            adapters.service_log(
                "scheduled_jobs",
                f"scheduled run update run_id={run_id} activity_id={job_id} status={status}",
            )
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.debug(
                "quick_job_result failed to update scheduled_job_runs for job_id=%s run_id=%s: %s",
                job_id,
                run_id,
                exc,
            )

    @socket_server.on("quick_job_result")
    def _handle_quick_job_result(data: Any) -> None:
        """Persist a quick-job result and notify listeners.

        Payload keys read:
          * ``job_id``: activity_history row id.
          * ``status``: defaults to "Failed" when empty.
          * ``stdout`` and ``stderr``.
          * ``context`` (optional dict): may carry ``scheduled_job_run_id`` or
            ``run_id``, plus ``scheduled_ts``.

        Non-dict payloads and missing/invalid job ids are logged and ignored.
        """
        if not isinstance(data, dict):
            logger.debug("quick_job_result payload ignored (non-dict): %r", data)
            return
        job_id_raw = data.get("job_id")
        job_id = _coerce_int(job_id_raw)
        if job_id is None:
            logger.debug("quick_job_result missing valid job_id: %r", job_id_raw)
            return

        status = str(data.get("status") or "").strip() or "Failed"
        stdout = _normalize_text(data.get("stdout"))
        stderr = _normalize_text(data.get("stderr"))
        ctx_payload = data.get("context")
        context_info: Optional[Dict[str, Any]] = ctx_payload if isinstance(ctx_payload, dict) else None

        conn: Optional[sqlite3.Connection] = None
        cursor: Optional[sqlite3.Cursor] = None
        broadcast_payload: Optional[Dict[str, Any]] = None
        try:
            conn = adapters.db_conn_factory()
            cursor = conn.cursor()
            cursor.execute(
                "UPDATE activity_history SET status=?, stdout=?, stderr=? WHERE id=?",
                (status, stdout, stderr, job_id),
            )
            if cursor.rowcount == 0:
                logger.debug("quick_job_result missing activity_history row for job_id=%s", job_id)
            # Commit now so the activity result survives a failure in the
            # scheduled-run bookkeeping below.
            conn.commit()

            run_id = _lookup_linked_run_id(cursor, job_id)
            scheduled_ts_ctx: Optional[int] = None
            if context_info:
                if run_id is None:
                    # Fall back to the run id echoed back in the job context.
                    run_id = _coerce_int(
                        context_info.get("scheduled_job_run_id") or context_info.get("run_id")
                    )
                # Read regardless of how run_id was resolved; the update only
                # fills a NULL scheduled_ts.
                scheduled_ts_ctx = _coerce_int(context_info.get("scheduled_ts"))

            if run_id is not None:
                _update_scheduled_run(cursor, conn, run_id, job_id, status, scheduled_ts_ctx)
            elif context_info:
                adapters.service_log(
                    "scheduled_jobs",
                    f"scheduled run update skipped (no run_id) activity_id={job_id} status={status} context={context_info}",
                    level="WARNING",
                )

            broadcast_payload = _build_activity_broadcast(cursor, job_id, status)
            adapters.service_log(
                "assemblies",
                f"quick_job_result processed job_id={job_id} status={status}",
            )
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.warning(
                "quick_job_result handler error for job_id=%s: %s",
                job_id,
                exc,
                exc_info=True,
            )
        finally:
            # Best-effort cleanup; a close failure must not mask the result above.
            for resource in (cursor, conn):
                if resource is None:
                    continue
                try:
                    resource.close()
                except Exception as close_exc:  # pragma: no cover - defensive guard
                    logger.debug("quick_job_result close failed for job_id=%s: %s", job_id, close_exc)

        if broadcast_payload:
            try:
                socket_server.emit("device_activity_changed", broadcast_payload)
            except Exception as exc:  # pragma: no cover - defensive guard
                logger.debug(
                    "Failed to emit device_activity_changed for job_id=%s: %s",
                    job_id,
                    exc,
                )

    # NOTE(review): the three default-namespace tunnel handlers below perform no
    # authentication and trust the client-supplied operator_id. Confirm access
    # is gated elsewhere (e.g. at connect time) before relying on them.

    @socket_server.on("tunnel_bridge_attach")
    def _tunnel_bridge_attach(data: Any) -> Any:
        """Placeholder operator bridge attach handler (no data channel yet).

        Returns ``{"status": "ok", ...}`` or ``{"error": <code>}`` as the ack.
        """
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        tunnel_id = str(data.get("tunnel_id") or "").strip()
        operator_id = str(data.get("operator_id") or "").strip() or None
        if not tunnel_id:
            return {"error": "tunnel_id_required"}
        try:
            tunnel_service.operator_attach(tunnel_id, operator_id)
        except ValueError as exc:
            return {"error": str(exc)}
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.debug("tunnel_bridge_attach failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "bridge_attach_failed"}
        return {"status": "ok", "tunnel_id": tunnel_id, "operator_id": operator_id or "-"}

    @socket_server.on("tunnel_operator_send")
    def _tunnel_operator_send(data: Any) -> Any:
        """Operator -> agent frame enqueue (placeholder queue)."""
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        tunnel_id = str(data.get("tunnel_id") or "").strip()
        frame_raw = data.get("frame")
        if not tunnel_id or frame_raw is None:
            return {"error": "tunnel_id_and_frame_required"}
        try:
            frame = _decode_frame_payload(frame_raw)
        except ValueError as exc:
            return {"error": str(exc)}
        except Exception:
            # Do not echo arbitrary internal exception text back to the client.
            return {"error": "invalid_frame"}
        bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        bridge.operator_to_agent(frame)
        return {"status": "ok"}

    @socket_server.on("tunnel_operator_poll")
    def _tunnel_operator_poll(data: Any) -> Any:
        """Return every frame currently queued from the agent for the operator."""
        tunnel_id = ""
        if isinstance(data, dict):
            tunnel_id = str(data.get("tunnel_id") or "").strip()
        if not tunnel_id:
            return {"error": "tunnel_id_required"}
        bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        return {"frames": _drain_operator_frames(bridge)}

    # WebUI operator bridge namespace for browser clients
    tunnel_namespace = "/tunnel"
    # Socket.IO session id -> tunnel id joined on that session.
    _operator_sessions: Dict[str, str] = {}

    def _current_operator() -> Optional[str]:
        """Identify the operator behind the current Socket.IO request.

        Prefers the Flask session ``username``. Otherwise it falls back to a
        bearer token from the Authorization header, then the ``borealis_auth``
        cookie. Returns ``None`` when none are present.
        """
        # NOTE(review): the token fallback is not validated here. Any non-empty
        # header/cookie value is accepted, and the raw token itself becomes the
        # operator id passed to operator_attach. Verify the token and map it to
        # a username before trusting it.
        username = session.get("username")
        if username:
            return str(username)
        auth_header = (request.headers.get("Authorization") or "").strip()
        token = None
        if auth_header.lower().startswith("bearer "):
            token = auth_header.split(" ", 1)[1].strip()
        if not token:
            token = request.cookies.get("borealis_auth")
        return token or None

    @socket_server.on("join", namespace=tunnel_namespace)
    def _ws_tunnel_join(data: Any) -> Any:
        """Attach the authenticated operator to a tunnel and bind it to this session."""
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        operator_id = _current_operator()
        if not operator_id:
            return {"error": "unauthorized"}
        tunnel_id = str(data.get("tunnel_id") or "").strip()
        if not tunnel_id:
            return {"error": "tunnel_id_required"}
        bridge = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        try:
            tunnel_service.operator_attach(tunnel_id, operator_id)
        except Exception as exc:
            logger.debug("ws_tunnel_join failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "attach_failed"}
        sid = request.sid
        _operator_sessions[sid] = tunnel_id
        return {"status": "ok", "tunnel_id": tunnel_id}

    @socket_server.on("send", namespace=tunnel_namespace)
    def _ws_tunnel_send(data: Any) -> Any:
        """Queue an operator frame for the agent on the tunnel this session joined."""
        sid = request.sid
        tunnel_id = _operator_sessions.get(sid)
        if not tunnel_id:
            return {"error": "not_joined"}
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        frame_raw = data.get("frame")
        if frame_raw is None:
            return {"error": "frame_required"}
        try:
            frame = _decode_frame_payload(frame_raw)
        except Exception:
            return {"error": "invalid_frame"}
        bridge = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        bridge.operator_to_agent(frame)
        return {"status": "ok"}

    @socket_server.on("poll", namespace=tunnel_namespace)
    def _ws_tunnel_poll(data: Any = None) -> Any:
        """Return queued agent frames for this session's tunnel.

        *data* is ignored. It is accepted so a client that emits ``poll`` with
        a payload does not trigger an argument-count TypeError.
        """
        sid = request.sid
        tunnel_id = _operator_sessions.get(sid)
        if not tunnel_id:
            return {"error": "not_joined"}
        bridge = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        return {"frames": _drain_operator_frames(bridge)}

    @socket_server.on("disconnect", namespace=tunnel_namespace)
    def _ws_tunnel_disconnect(reason: Any = None) -> None:
        """Forget the session -> tunnel binding when a browser operator disconnects.

        *reason* is optional so the handler works whether or not the
        Socket.IO server passes a disconnect reason.
        """
        # NOTE(review): no operator detach is issued on tunnel_service here;
        # confirm whether the service releases the attachment on its own.
        sid = request.sid
        _operator_sessions.pop(sid, None)