Files
Borealis-Github-Replica/Data/Engine/services/WebSocket/__init__.py

493 lines
18 KiB
Python

# ======================================================
# Data\Engine\services\WebSocket\__init__.py
# Description: Socket.IO handlers for Engine runtime quick job updates and realtime notifications.
#
# API Endpoints (if applicable): None
# ======================================================
"""WebSocket service registration for the Borealis Engine runtime."""
from __future__ import annotations
import base64
import sqlite3
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Optional
from flask import session, request
from flask_socketio import SocketIO
from ...database import initialise_engine_database
from ...server import EngineContext
from .Agent.ReverseTunnel import (
ReverseTunnelService,
TunnelBridge,
decode_frame,
TunnelFrame,
)
def _now_ts() -> int:
    """Return the current Unix time truncated to whole seconds."""
    seconds_since_epoch = time.time()
    return int(seconds_since_epoch)
def _normalize_text(value: Any) -> str:
    """Coerce an arbitrary payload value to text.

    Args:
        value: Any value taken from an incoming event payload.

    Returns:
        ``""`` for ``None``; the UTF-8 decoding of ``bytes``/``bytearray``
        input, with undecodable bytes replaced by U+FFFD; ``str(value)`` for
        everything else.
    """
    if value is None:
        return ""
    if isinstance(value, (bytes, bytearray)):
        # errors="replace" yields the same text as a strict decode when the
        # input is valid UTF-8, so a separate strict attempt is unnecessary.
        return bytes(value).decode("utf-8", errors="replace")
    return str(value)
@dataclass
class EngineRealtimeAdapters:
    """Database and service-log helpers built from an ``EngineContext``.

    Only ``context`` is supplied by the caller; ``db_conn_factory`` and
    ``service_log`` are populated in ``__post_init__``.
    """

    context: EngineContext
    # Zero-argument callable returning a new SQLite connection.
    db_conn_factory: Callable[[], sqlite3.Connection] = field(init=False)
    # Service logger produced by ``_make_service_logger``.
    service_log: Callable[[str, str, Optional[str]], None] = field(init=False)

    def __post_init__(self) -> None:
        """Initialise the engine database and build both helper callables."""
        # Imported here rather than at module load to avoid a circular import.
        from ..API import _make_db_conn_factory, _make_service_logger

        ctx = self.context
        db_path = ctx.database_path
        initialise_engine_database(db_path, logger=ctx.logger)
        self.db_conn_factory = _make_db_conn_factory(db_path)

        # Service logs live next to the configured log file when one is set,
        # otherwise next to the database file.
        settings = ctx.config
        configured_log = settings.get("log_file") or settings.get("LOG_FILE") or ""
        log_path = str(configured_log).strip()
        anchor = log_path if log_path else db_path
        log_dir = Path(anchor).resolve().parent
        self.service_log = _make_service_logger(log_dir, ctx.logger)
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
    """Register Socket.IO event handlers for the Engine runtime.

    Handlers registered on ``socket_server``:

    * ``quick_job_result`` (default namespace): stores a job result in
      ``activity_history``, updates the linked ``scheduled_job_runs`` row when
      one can be resolved, and emits ``device_activity_changed``.
    * ``tunnel_bridge_attach`` / ``tunnel_operator_send`` /
      ``tunnel_operator_poll`` (default namespace): operator bridge calls
      addressed by an explicit ``tunnel_id`` in the payload.
    * ``/tunnel`` namespace: ``join``, ``send``, ``poll``, ``ps_open``,
      ``ps_send``, ``ps_resize``, ``ps_poll`` and ``disconnect``; the tunnel is
      remembered per Socket.IO session id after ``join``.

    If ``context`` has no ``reverse_tunnel_service`` attribute (or it is
    ``None``), a ``ReverseTunnelService`` is created, started and stored there.

    Args:
        socket_server: The Flask-SocketIO server to attach handlers to.
        context: Engine context providing the logger, config and database path.
    """
    adapters = EngineRealtimeAdapters(context)
    logger = context.logger.getChild("realtime.quick_jobs")

    tunnel_service = getattr(context, "reverse_tunnel_service", None)
    if tunnel_service is None:
        tunnel_service = ReverseTunnelService(
            context,
            signer=None,
            db_conn_factory=adapters.db_conn_factory,
            socketio=socket_server,
        )
        tunnel_service.start()
        setattr(context, "reverse_tunnel_service", tunnel_service)

    @socket_server.on("quick_job_result")
    def _handle_quick_job_result(data: Any) -> None:
        """Persist a quick-job result and broadcast the activity change."""
        if not isinstance(data, dict):
            logger.debug("quick_job_result payload ignored (non-dict): %r", data)
            return

        job_id_raw = data.get("job_id")
        try:
            job_id = int(job_id_raw)
        except (TypeError, ValueError):
            logger.debug("quick_job_result missing valid job_id: %r", job_id_raw)
            return

        # A missing or blank status is recorded as "Failed".
        status = str(data.get("status") or "").strip() or "Failed"
        stdout = _normalize_text(data.get("stdout"))
        stderr = _normalize_text(data.get("stderr"))

        conn: Optional[sqlite3.Connection] = None
        cursor = None
        broadcast_payload: Optional[Dict[str, Any]] = None
        ctx_payload = data.get("context")
        context_info: Optional[Dict[str, Any]] = ctx_payload if isinstance(ctx_payload, dict) else None

        try:
            conn = adapters.db_conn_factory()
            cursor = conn.cursor()
            cursor.execute(
                "UPDATE activity_history SET status=?, stdout=?, stderr=? WHERE id=?",
                (status, stdout, stderr, job_id),
            )
            if cursor.rowcount == 0:
                logger.debug("quick_job_result missing activity_history row for job_id=%s", job_id)
            conn.commit()

            # Resolve the scheduled run linked to this activity, if any.
            try:
                cursor.execute(
                    "SELECT run_id FROM scheduled_job_run_activity WHERE activity_id=?",
                    (job_id,),
                )
                link = cursor.fetchone()
            except sqlite3.Error:
                link = None

            run_id: Optional[int] = None
            scheduled_ts_ctx: Optional[int] = None
            if link:
                try:
                    run_id = int(link[0])
                except Exception:
                    run_id = None

            # Fall back to the identifiers carried in the payload's context.
            if run_id is None and context_info:
                ctx_run = context_info.get("scheduled_job_run_id") or context_info.get("run_id")
                try:
                    if ctx_run is not None:
                        run_id = int(ctx_run)
                except (TypeError, ValueError):
                    run_id = None
                ctx_scheduled = context_info.get("scheduled_ts")
                try:
                    if ctx_scheduled is not None:
                        scheduled_ts_ctx = int(ctx_scheduled)
                except (TypeError, ValueError):
                    scheduled_ts_ctx = None

            if run_id is not None:
                ts_now = _now_ts()
                try:
                    if status.lower() == "running":
                        cursor.execute(
                            "UPDATE scheduled_job_runs SET status='Running', updated_at=? WHERE id=?",
                            (ts_now, run_id),
                        )
                    else:
                        # COALESCE keeps an already-recorded finished_ts.
                        cursor.execute(
                            """
                            UPDATE scheduled_job_runs
                            SET status=?,
                            finished_ts=COALESCE(finished_ts, ?),
                            updated_at=?
                            WHERE id=?
                            """,
                            (status, ts_now, ts_now, run_id),
                        )
                    if scheduled_ts_ctx is not None:
                        # Only fills scheduled_ts when the row has none yet.
                        cursor.execute(
                            "UPDATE scheduled_job_runs SET scheduled_ts=COALESCE(scheduled_ts, ?) WHERE id=?",
                            (scheduled_ts_ctx, run_id),
                        )
                    conn.commit()
                    adapters.service_log(
                        "scheduled_jobs",
                        f"scheduled run update run_id={run_id} activity_id={job_id} status={status}",
                    )
                except Exception as exc:  # pragma: no cover - defensive guard
                    logger.debug(
                        "quick_job_result failed to update scheduled_job_runs for job_id=%s run_id=%s: %s",
                        job_id,
                        run_id,
                        exc,
                    )
            elif context_info:
                adapters.service_log(
                    "scheduled_jobs",
                    f"scheduled run update skipped (no run_id) activity_id={job_id} status={status} context={context_info}",
                    level="WARNING",
                )

            # Re-read the row so the broadcast reflects what is stored.
            try:
                cursor.execute(
                    "SELECT id, hostname, status FROM activity_history WHERE id=?",
                    (job_id,),
                )
                row = cursor.fetchone()
            except sqlite3.Error:
                row = None

            if row:
                hostname = (row[1] or "").strip()
                if hostname:
                    broadcast_payload = {
                        "activity_id": int(row[0]),
                        "hostname": hostname,
                        "status": row[2] or status,
                        "change": "updated",
                        "source": "quick_job",
                    }

            adapters.service_log(
                "assemblies",
                f"quick_job_result processed job_id={job_id} status={status}",
            )
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.warning(
                "quick_job_result handler error for job_id=%s: %s",
                job_id,
                exc,
                exc_info=True,
            )
        finally:
            # Best-effort cleanup; close failures must not mask the result.
            if cursor is not None:
                try:
                    cursor.close()
                except Exception:
                    pass
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

        if broadcast_payload:
            try:
                socket_server.emit("device_activity_changed", broadcast_payload)
            except Exception as exc:  # pragma: no cover - defensive guard
                logger.debug(
                    "Failed to emit device_activity_changed for job_id=%s: %s",
                    job_id,
                    exc,
                )

    # NOTE(review): the three default-namespace tunnel_* handlers below trust
    # the tunnel_id/operator_id supplied in the payload and perform no
    # operator check of their own — confirm this is intended.
    @socket_server.on("tunnel_bridge_attach")
    def _tunnel_bridge_attach(data: Any) -> Any:
        """Placeholder operator bridge attach handler (no data channel yet)."""
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        tunnel_id = str(data.get("tunnel_id") or "").strip()
        operator_id = str(data.get("operator_id") or "").strip() or None
        if not tunnel_id:
            return {"error": "tunnel_id_required"}
        try:
            tunnel_service.operator_attach(tunnel_id, operator_id)
        except ValueError as exc:
            return {"error": str(exc)}
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.debug("tunnel_bridge_attach failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "bridge_attach_failed"}
        return {"status": "ok", "tunnel_id": tunnel_id, "operator_id": operator_id or "-"}

    def _encode_frame(frame: TunnelFrame) -> str:
        """Return the frame's encoded bytes as an ASCII base64 string."""
        return base64.b64encode(frame.encode()).decode("ascii")

    def _decode_frame_payload(raw: Any) -> TunnelFrame:
        """Decode a frame received as base64 text or raw bytes.

        Raises:
            ValueError: ``"invalid_frame"`` when ``raw`` is neither text nor
                bytes, or when the base64 text cannot be decoded. Errors
                raised by ``decode_frame`` propagate unchanged.
        """
        if isinstance(raw, str):
            try:
                raw_bytes = base64.b64decode(raw)
            except Exception as exc:
                raise ValueError("invalid_frame") from exc
        elif isinstance(raw, (bytes, bytearray)):
            raw_bytes = bytes(raw)
        else:
            raise ValueError("invalid_frame")
        return decode_frame(raw_bytes)

    def _drain_operator_frames(bridge: TunnelBridge) -> list:
        """Pop every frame queued for the operator and return them encoded."""
        frames = []
        while True:
            frame = bridge.next_for_operator()
            if frame is None:
                break
            frames.append(_encode_frame(frame))
        return frames

    @socket_server.on("tunnel_operator_send")
    def _tunnel_operator_send(data: Any) -> Any:
        """Operator -> agent frame enqueue (placeholder queue)."""
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        tunnel_id = str(data.get("tunnel_id") or "").strip()
        frame_raw = data.get("frame")
        if not tunnel_id or frame_raw is None:
            return {"error": "tunnel_id_and_frame_required"}
        try:
            frame = _decode_frame_payload(frame_raw)
        except Exception as exc:
            return {"error": str(exc)}
        bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        bridge.operator_to_agent(frame)
        return {"status": "ok"}

    @socket_server.on("tunnel_operator_poll")
    def _tunnel_operator_poll(data: Any) -> Any:
        """Operator polls queued frames from agent."""
        tunnel_id = ""
        if isinstance(data, dict):
            tunnel_id = str(data.get("tunnel_id") or "").strip()
        if not tunnel_id:
            return {"error": "tunnel_id_required"}
        bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        return {"frames": _drain_operator_frames(bridge)}

    # WebUI operator bridge namespace for browser clients
    tunnel_namespace = "/tunnel"
    # Maps a Socket.IO session id to the tunnel it joined.
    _operator_sessions: Dict[str, str] = {}

    # Terminal geometry defaults and limits shared by ps_open and ps_resize.
    default_cols, default_rows = 120, 32

    def _clamp_terminal_size(cols: int, rows: int) -> tuple[int, int]:
        """Clamp columns to 20..300 and rows to 10..200."""
        return max(20, min(cols, 300)), max(10, min(rows, 200))

    def _current_operator() -> Optional[str]:
        """Identify the caller from the session, bearer header or auth cookie.

        Returns the session ``username`` when present; otherwise the bearer
        token from the ``Authorization`` header, falling back to the
        ``borealis_auth`` cookie. Returns ``None`` when none is available.
        """
        username = session.get("username")
        if username:
            return str(username)
        auth_header = (request.headers.get("Authorization") or "").strip()
        token = None
        if auth_header.lower().startswith("bearer "):
            token = auth_header.split(" ", 1)[1].strip()
        if not token:
            token = request.cookies.get("borealis_auth")
        # NOTE(review): the token is returned as-is and not validated here —
        # confirm validation happens elsewhere.
        return token or None

    @socket_server.on("join", namespace=tunnel_namespace)
    def _ws_tunnel_join(data: Any) -> Any:
        """Attach the current operator to a tunnel and bind it to this session."""
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        operator_id = _current_operator()
        if not operator_id:
            return {"error": "unauthorized"}
        tunnel_id = str(data.get("tunnel_id") or "").strip()
        if not tunnel_id:
            return {"error": "tunnel_id_required"}
        bridge = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        try:
            tunnel_service.operator_attach(tunnel_id, operator_id)
        except Exception as exc:
            logger.debug("ws_tunnel_join failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "attach_failed"}
        sid = request.sid
        _operator_sessions[sid] = tunnel_id
        return {"status": "ok", "tunnel_id": tunnel_id}

    @socket_server.on("send", namespace=tunnel_namespace)
    def _ws_tunnel_send(data: Any) -> Any:
        """Queue an operator frame for the agent on the joined tunnel."""
        sid = request.sid
        tunnel_id = _operator_sessions.get(sid)
        if not tunnel_id:
            return {"error": "not_joined"}
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        frame_raw = data.get("frame")
        if frame_raw is None:
            return {"error": "frame_required"}
        try:
            frame = _decode_frame_payload(frame_raw)
        except Exception:
            return {"error": "invalid_frame"}
        bridge = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        bridge.operator_to_agent(frame)
        return {"status": "ok"}

    @socket_server.on("poll", namespace=tunnel_namespace)
    def _ws_tunnel_poll(data: Any = None) -> Any:  # data is ignored; clients may still send a payload
        """Return all frames queued for the operator on the joined tunnel."""
        sid = request.sid
        tunnel_id = _operator_sessions.get(sid)
        if not tunnel_id:
            return {"error": "not_joined"}
        bridge = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        return {"frames": _drain_operator_frames(bridge)}

    def _require_ps_server():
        """Resolve the PS server for this session's tunnel.

        Returns:
            ``(server, tunnel_id, error)``. ``server`` is ``None`` whenever
            ``error`` is set: ``not_joined`` when the session has no tunnel,
            ``ps_unsupported`` when ``ensure_ps_server`` returns ``None``.
        """
        sid = request.sid
        tunnel_id = _operator_sessions.get(sid)
        if not tunnel_id:
            return None, None, {"error": "not_joined"}
        server = tunnel_service.ensure_ps_server(tunnel_id)
        if server is None:
            return None, tunnel_id, {"error": "ps_unsupported"}
        return server, tunnel_id, None

    @socket_server.on("ps_open", namespace=tunnel_namespace)
    def _ws_ps_open(data: Any) -> Any:
        """Open a PS channel, using clamped ``cols``/``rows`` from the payload."""
        server, tunnel_id, error = _require_ps_server()
        if server is None:
            return error
        cols = default_cols
        rows = default_rows
        if isinstance(data, dict):
            # Unparseable sizes are ignored; values assigned before the
            # failure are kept and the rest stay at their defaults.
            try:
                cols = int(data.get("cols", cols))
                rows = int(data.get("rows", rows))
            except Exception:
                pass
        cols, rows = _clamp_terminal_size(cols, rows)
        try:
            server.open_channel(cols=cols, rows=rows)
        except Exception as exc:
            logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "ps_open_failed"}
        return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows}

    @socket_server.on("ps_send", namespace=tunnel_namespace)
    def _ws_ps_send(data: Any) -> Any:
        """Send input text to the PS server.

        The payload is either the text itself or a dict whose ``data`` key
        holds it.
        """
        server, tunnel_id, error = _require_ps_server()
        if server is None:
            return error
        if data is None:
            return {"error": "payload_required"}
        text = data
        if isinstance(data, dict):
            text = data.get("data")
        if text is None:
            return {"error": "payload_required"}
        try:
            server.send_input(str(text))
        except Exception as exc:
            logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "ps_send_failed"}
        return {"status": "ok"}

    @socket_server.on("ps_resize", namespace=tunnel_namespace)
    def _ws_ps_resize(data: Any) -> Any:
        """Resize the PS channel; unlike ps_open, invalid sizes are an error."""
        server, tunnel_id, error = _require_ps_server()
        if server is None:
            return error
        cols = None
        rows = None
        if isinstance(data, dict):
            cols = data.get("cols")
            rows = data.get("rows")
        try:
            cols_int = int(cols) if cols is not None else default_cols
            rows_int = int(rows) if rows is not None else default_rows
            cols_int, rows_int = _clamp_terminal_size(cols_int, rows_int)
            server.send_resize(cols_int, rows_int)
            return {"status": "ok", "cols": cols_int, "rows": rows_int}
        except Exception as exc:
            logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "ps_resize_failed"}

    @socket_server.on("ps_poll", namespace=tunnel_namespace)
    def _ws_ps_poll(data: Any = None) -> Any:  # data is ignored; socketio passes it even when unused
        """Return drained PS output together with the server status."""
        server, tunnel_id, error = _require_ps_server()
        if server is None:
            return error
        try:
            output = server.drain_output()
            status = server.status()
            return {"output": output, "status": status}
        except Exception as exc:
            logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "ps_poll_failed"}

    @socket_server.on("disconnect", namespace=tunnel_namespace)
    def _ws_tunnel_disconnect(reason: Any = None) -> None:  # reason is ignored; newer Flask-SocketIO passes it
        """Forget the tunnel bound to the disconnecting session."""
        sid = request.sid
        _operator_sessions.pop(sid, None)