# ======================================================
# Data\Engine\services\WebSocket\__init__.py
# Description: Socket.IO handlers for the Engine runtime: quick job updates,
#              reverse tunnel bridging, and realtime notifications.
#
# API Endpoints (if applicable): None
# ======================================================
"""WebSocket service registration for the Borealis Engine runtime."""

from __future__ import annotations

import base64
import sqlite3
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Optional

from flask import session, request
from flask_socketio import SocketIO

from ...database import initialise_engine_database
from ...server import EngineContext
from .Agent.reverse_tunnel_orchestrator import (
    ReverseTunnelService,
    TunnelBridge,
    decode_frame,
    TunnelFrame,
)


def _now_ts() -> int:
    return int(time.time())


def _normalize_text(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, bytes):
        try:
            return value.decode("utf-8")
        except Exception:
            return value.decode("utf-8", errors="replace")
    return str(value)


@dataclass
class EngineRealtimeAdapters:
    """Database and service-log helpers shared by the realtime handlers."""

    context: EngineContext
    db_conn_factory: Callable[[], sqlite3.Connection] = field(init=False)
    service_log: Callable[[str, str, Optional[str]], None] = field(init=False)

    def __post_init__(self) -> None:
        # Local import to avoid a circular import at module load.
        from ..API import _make_db_conn_factory, _make_service_logger

        initialise_engine_database(self.context.database_path, logger=self.context.logger)
        self.db_conn_factory = _make_db_conn_factory(self.context.database_path)
        log_file = str(
            self.context.config.get("log_file")
            or self.context.config.get("LOG_FILE")
            or ""
        ).strip()
        if log_file:
            base = Path(log_file).resolve().parent
        else:
            base = Path(self.context.database_path).resolve().parent
        self.service_log = _make_service_logger(base, self.context.logger)


def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
    """Register Socket.IO event handlers for the Engine runtime."""
    adapters = EngineRealtimeAdapters(context)
    logger = context.logger.getChild("realtime.quick_jobs")

    tunnel_service = getattr(context, "reverse_tunnel_service", None)
    if tunnel_service is None:
        tunnel_service = ReverseTunnelService(
            context,
            signer=None,
            db_conn_factory=adapters.db_conn_factory,
            socketio=socket_server,
        )
        tunnel_service.start()
        setattr(context, "reverse_tunnel_service", tunnel_service)

    @socket_server.on("quick_job_result")
    def _handle_quick_job_result(data: Any) -> None:
        """Persist a quick job result reported by an agent and broadcast the change."""
        if not isinstance(data, dict):
            logger.debug("quick_job_result payload ignored (non-dict): %r", data)
            return
        job_id_raw = data.get("job_id")
        try:
            job_id = int(job_id_raw)
        except (TypeError, ValueError):
            logger.debug("quick_job_result missing valid job_id: %r", job_id_raw)
            return

        status = str(data.get("status") or "").strip() or "Failed"
        stdout = _normalize_text(data.get("stdout"))
        stderr = _normalize_text(data.get("stderr"))

        conn: Optional[sqlite3.Connection] = None
        cursor = None
        broadcast_payload: Optional[Dict[str, Any]] = None
        ctx_payload = data.get("context")
        context_info: Optional[Dict[str, Any]] = ctx_payload if isinstance(ctx_payload, dict) else None
        try:
            conn = adapters.db_conn_factory()
            cursor = conn.cursor()
            cursor.execute(
                "UPDATE activity_history SET status=?, stdout=?, stderr=? WHERE id=?",
                (status, stdout, stderr, job_id),
            )
            if cursor.rowcount == 0:
                logger.debug("quick_job_result missing activity_history row for job_id=%s", job_id)
            conn.commit()

            # Link the activity row back to a scheduled job run, if one exists.
            try:
                cursor.execute(
                    "SELECT run_id FROM scheduled_job_run_activity WHERE activity_id=?",
                    (job_id,),
                )
                link = cursor.fetchone()
            except sqlite3.Error:
                link = None

            run_id: Optional[int] = None
            scheduled_ts_ctx: Optional[int] = None
            if link:
                try:
                    run_id = int(link[0])
                except Exception:
                    run_id = None
            if run_id is None and context_info:
                ctx_run = context_info.get("scheduled_job_run_id") or context_info.get("run_id")
                try:
                    if ctx_run is not None:
                        run_id = int(ctx_run)
                except (TypeError, ValueError):
                    run_id = None
                try:
                    if context_info.get("scheduled_ts") is not None:
                        scheduled_ts_ctx = int(context_info.get("scheduled_ts"))
                except (TypeError, ValueError):
                    scheduled_ts_ctx = None

            if run_id is not None:
                ts_now = _now_ts()
                try:
                    if status.lower() == "running":
                        cursor.execute(
                            "UPDATE scheduled_job_runs SET status='Running', updated_at=? WHERE id=?",
                            (ts_now, run_id),
                        )
                    else:
                        cursor.execute(
                            """
                            UPDATE scheduled_job_runs
                               SET status=?,
                                   finished_ts=COALESCE(finished_ts, ?),
                                   updated_at=?
                             WHERE id=?
                            """,
                            (status, ts_now, ts_now, run_id),
                        )
                    if scheduled_ts_ctx is not None:
                        cursor.execute(
                            "UPDATE scheduled_job_runs SET scheduled_ts=COALESCE(scheduled_ts, ?) WHERE id=?",
                            (scheduled_ts_ctx, run_id),
                        )
                    conn.commit()
                    adapters.service_log(
                        "scheduled_jobs",
                        f"scheduled run update run_id={run_id} activity_id={job_id} status={status}",
                    )
                except Exception as exc:  # pragma: no cover - defensive guard
                    logger.debug(
                        "quick_job_result failed to update scheduled_job_runs for job_id=%s run_id=%s: %s",
                        job_id,
                        run_id,
                        exc,
                    )
            elif context_info:
                adapters.service_log(
                    "scheduled_jobs",
                    f"scheduled run update skipped (no run_id) activity_id={job_id} status={status} context={context_info}",
                    level="WARNING",
                )

            try:
                cursor.execute(
                    "SELECT id, hostname, status FROM activity_history WHERE id=?",
                    (job_id,),
                )
                row = cursor.fetchone()
            except sqlite3.Error:
                row = None
            if row:
                hostname = (row[1] or "").strip()
                if hostname:
                    broadcast_payload = {
                        "activity_id": int(row[0]),
                        "hostname": hostname,
                        "status": row[2] or status,
                        "change": "updated",
                        "source": "quick_job",
                    }

            adapters.service_log(
                "assemblies",
                f"quick_job_result processed job_id={job_id} status={status}",
            )
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.warning(
                "quick_job_result handler error for job_id=%s: %s",
                job_id,
                exc,
                exc_info=True,
            )
        finally:
            if cursor is not None:
                try:
                    cursor.close()
                except Exception:
                    pass
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

        if broadcast_payload:
            try:
                socket_server.emit("device_activity_changed", broadcast_payload)
            except Exception as exc:  # pragma: no cover - defensive guard
                logger.debug(
                    "Failed to emit device_activity_changed for job_id=%s: %s",
                    job_id,
                    exc,
                )

    @socket_server.on("tunnel_bridge_attach")
    def _tunnel_bridge_attach(data: Any) -> Any:
        """Placeholder operator bridge attach handler (no data channel yet)."""
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        tunnel_id = str(data.get("tunnel_id") or "").strip()
        operator_id = str(data.get("operator_id") or "").strip() or None
        if not tunnel_id:
            return {"error": "tunnel_id_required"}
        try:
            tunnel_service.operator_attach(tunnel_id, operator_id)
        except ValueError as exc:
            return {"error": str(exc)}
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.debug("tunnel_bridge_attach failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "bridge_attach_failed"}
        return {"status": "ok", "tunnel_id": tunnel_id, "operator_id": operator_id or "-"}

    def _encode_frame(frame: TunnelFrame) -> str:
        return base64.b64encode(frame.encode()).decode("ascii")

    def _decode_frame_payload(raw: Any) -> TunnelFrame:
        if isinstance(raw, str):
            try:
                raw_bytes = base64.b64decode(raw)
            except Exception:
                raise ValueError("invalid_frame")
        elif isinstance(raw, (bytes, bytearray)):
            raw_bytes = bytes(raw)
        else:
            raise ValueError("invalid_frame")
        return decode_frame(raw_bytes)

    @socket_server.on("tunnel_operator_send")
    def _tunnel_operator_send(data: Any) -> Any:
        """Operator -> agent frame enqueue (placeholder queue)."""
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        tunnel_id = str(data.get("tunnel_id") or "").strip()
        frame_raw = data.get("frame")
        if not tunnel_id or frame_raw is None:
            return {"error": "tunnel_id_and_frame_required"}
        try:
            frame = _decode_frame_payload(frame_raw)
        except Exception as exc:
            return {"error": str(exc)}
        bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        bridge.operator_to_agent(frame)
        return {"status": "ok"}

    @socket_server.on("tunnel_operator_poll")
    def _tunnel_operator_poll(data: Any) -> Any:
        """Operator polls queued frames from the agent."""
        tunnel_id = ""
        if isinstance(data, dict):
            tunnel_id = str(data.get("tunnel_id") or "").strip()
        if not tunnel_id:
            return {"error": "tunnel_id_required"}
        bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        frames = []
        while True:
            frame = bridge.next_for_operator()
            if frame is None:
                break
            frames.append(_encode_frame(frame))
        return {"frames": frames}

    # WebUI operator bridge namespace for browser clients.
    tunnel_namespace = "/tunnel"
    _operator_sessions: Dict[str, str] = {}

    def _current_operator() -> Optional[str]:
        """Identify the operator via session username, bearer token, or auth cookie."""
        username = session.get("username")
        if username:
            return str(username)
        auth_header = (request.headers.get("Authorization") or "").strip()
        token = None
        if auth_header.lower().startswith("bearer "):
            token = auth_header.split(" ", 1)[1].strip()
        if not token:
            token = request.cookies.get("borealis_auth")
        return token or None

    @socket_server.on("join", namespace=tunnel_namespace)
    def _ws_tunnel_join(data: Any) -> Any:
        """Bind an authenticated operator socket to an existing tunnel bridge."""
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        operator_id = _current_operator()
        if not operator_id:
            return {"error": "unauthorized"}
        tunnel_id = str(data.get("tunnel_id") or "").strip()
        if not tunnel_id:
            return {"error": "tunnel_id_required"}
        bridge = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        try:
            tunnel_service.operator_attach(tunnel_id, operator_id)
        except Exception as exc:
            logger.debug("ws_tunnel_join failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "attach_failed"}
        sid = request.sid
        _operator_sessions[sid] = tunnel_id
        return {"status": "ok", "tunnel_id": tunnel_id}

    @socket_server.on("send", namespace=tunnel_namespace)
    def _ws_tunnel_send(data: Any) -> Any:
        """Forward an operator frame to the agent over the joined tunnel."""
        sid = request.sid
        tunnel_id = _operator_sessions.get(sid)
        if not tunnel_id:
            return {"error": "not_joined"}
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        frame_raw = data.get("frame")
        if frame_raw is None:
            return {"error": "frame_required"}
        try:
            frame = _decode_frame_payload(frame_raw)
        except Exception:
            return {"error": "invalid_frame"}
        bridge = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        bridge.operator_to_agent(frame)
        return {"status": "ok"}

    @socket_server.on("poll", namespace=tunnel_namespace)
    def _ws_tunnel_poll(data: Any = None) -> Any:
        """Drain frames queued for the operator on the joined tunnel."""
        # data is ignored; socketio passes it even when unused
        sid = request.sid
        tunnel_id = _operator_sessions.get(sid)
        if not tunnel_id:
            return {"error": "not_joined"}
        bridge = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        frames = []
        while True:
            frame = bridge.next_for_operator()
            if frame is None:
                break
            frames.append(_encode_frame(frame))
        return {"frames": frames}

    def _require_ps_server():
        """Resolve the PowerShell protocol server for the caller's joined tunnel."""
        sid = request.sid
        tunnel_id = _operator_sessions.get(sid)
        if not tunnel_id:
            return None, None, {"error": "not_joined"}
        server = tunnel_service.ensure_protocol_server(tunnel_id)
        if server is None or not hasattr(server, "open_channel"):
            return None, tunnel_id, {"error": "ps_unsupported"}
        return server, tunnel_id, None

    @socket_server.on("ps_open", namespace=tunnel_namespace)
    def _ws_ps_open(data: Any) -> Any:
        """Open a PowerShell channel with clamped terminal dimensions."""
        server, tunnel_id, error = _require_ps_server()
        if server is None:
            return error
        cols = 120
        rows = 32
        if isinstance(data, dict):
            try:
                cols = int(data.get("cols", cols))
                rows = int(data.get("rows", rows))
            except Exception:
                pass
        cols = max(20, min(cols, 300))
        rows = max(10, min(rows, 200))
        try:
            server.open_channel(cols=cols, rows=rows)
        except Exception as exc:
            logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "ps_open_failed"}
        return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows}

    @socket_server.on("ps_send", namespace=tunnel_namespace)
    def _ws_ps_send(data: Any) -> Any:
        """Send operator input to the PowerShell channel."""
        server, tunnel_id, error = _require_ps_server()
        if server is None:
            return error
        if data is None:
            return {"error": "payload_required"}
        text = data
        if isinstance(data, dict):
            text = data.get("data")
        if text is None:
            return {"error": "payload_required"}
        try:
            server.send_input(str(text))
        except Exception as exc:
            logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "ps_send_failed"}
        return {"status": "ok"}

    @socket_server.on("ps_resize", namespace=tunnel_namespace)
    def _ws_ps_resize(data: Any) -> Any:
        """Resize the PowerShell channel, clamping to supported bounds."""
        server, tunnel_id, error = _require_ps_server()
        if server is None:
            return error
        cols = None
        rows = None
        if isinstance(data, dict):
            cols = data.get("cols")
            rows = data.get("rows")
        try:
            cols_int = int(cols) if cols is not None else 120
            rows_int = int(rows) if rows is not None else 32
            cols_int = max(20, min(cols_int, 300))
            rows_int = max(10, min(rows_int, 200))
            server.send_resize(cols_int, rows_int)
            return {"status": "ok", "cols": cols_int, "rows": rows_int}
        except Exception as exc:
            logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "ps_resize_failed"}

    @socket_server.on("ps_poll", namespace=tunnel_namespace)
    def _ws_ps_poll(data: Any = None) -> Any:
        """Drain buffered PowerShell output and report channel status."""
        # data is ignored; socketio passes it even when unused
        server, tunnel_id, error = _require_ps_server()
        if server is None:
            return error
        try:
            output = server.drain_output()
            status = server.status()
            return {"output": output, "status": status}
        except Exception as exc:
            logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "ps_poll_failed"}

    @socket_server.on("disconnect", namespace=tunnel_namespace)
    def _ws_tunnel_disconnect():
        """Tear down the tunnel once its last operator socket disconnects."""
        sid = request.sid
        tunnel_id = _operator_sessions.pop(sid, None)
        if tunnel_id and tunnel_id not in _operator_sessions.values():
            try:
                tunnel_service.stop_tunnel(tunnel_id, reason="operator_socket_disconnect")
            except Exception as exc:
                logger.debug(
                    "ws_tunnel_disconnect stop_tunnel failed tunnel_id=%s: %s",
                    tunnel_id,
                    exc,
                    exc_info=True,
                )
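

# ----------------------------------------------------------------------
# Illustrative operator client flow (not part of the Engine runtime).
# A minimal sketch, assuming the standalone `python-socketio` client, a
# reachable Engine Socket.IO endpoint (the URL and bearer token below are
# placeholders), and a tunnel id obtained out of band; it exercises the
# "/tunnel" namespace handlers registered above.
#
#   import socketio
#
#   sio = socketio.Client()
#   sio.connect(
#       "http://127.0.0.1:5000",
#       headers={"Authorization": "Bearer <token>"},  # any bearer token satisfies _current_operator()
#       namespaces=["/tunnel"],
#   )
#
#   # Bind this operator socket to an existing tunnel bridge.
#   print(sio.call("join", {"tunnel_id": "<tunnel-id>"}, namespace="/tunnel"))
#
#   # Open a PowerShell channel sized to the operator terminal.
#   print(sio.call("ps_open", {"cols": 120, "rows": 32}, namespace="/tunnel"))
#
#   # Send input, then drain buffered output and channel status.
#   sio.call("ps_send", {"data": "Get-Date\r\n"}, namespace="/tunnel")
#   print(sio.call("ps_poll", {}, namespace="/tunnel"))
#
#   sio.disconnect()
# ----------------------------------------------------------------------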