mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-12-16 10:05:48 -07:00
493 lines
18 KiB
Python
493 lines
18 KiB
Python
# ======================================================
# Data\Engine\services\WebSocket\__init__.py
# Description: Socket.IO handlers for Engine runtime quick job updates, realtime notifications, and the reverse-tunnel operator bridge (default namespace and the `/tunnel` WebUI namespace).
#
# API Endpoints (if applicable): None
# ======================================================
|
|
|
|
"""WebSocket service registration for the Borealis Engine runtime."""
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import sqlite3
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, Optional
|
|
|
|
from flask import session, request
|
|
from flask_socketio import SocketIO
|
|
|
|
from ...database import initialise_engine_database
|
|
from ...server import EngineContext
|
|
from .Agent.ReverseTunnel import (
|
|
ReverseTunnelService,
|
|
TunnelBridge,
|
|
decode_frame,
|
|
TunnelFrame,
|
|
)
|
|
|
|
|
|
def _now_ts() -> int:
|
|
return int(time.time())
|
|
|
|
|
|
def _normalize_text(value: Any) -> str:
|
|
if value is None:
|
|
return ""
|
|
if isinstance(value, bytes):
|
|
try:
|
|
return value.decode("utf-8")
|
|
except Exception:
|
|
return value.decode("utf-8", errors="replace")
|
|
return str(value)
|
|
|
|
|
|
@dataclass
class EngineRealtimeAdapters:
    """Bundle the database and service-log helpers used by the realtime handlers.

    Creating an instance runs ``initialise_engine_database`` on
    ``context.database_path`` and fills in ``db_conn_factory`` and
    ``service_log`` from the ``..API`` helper factories.
    """

    context: EngineContext
    # Zero-argument factory returning a connection to the Engine database.
    db_conn_factory: Callable[[], sqlite3.Connection] = field(init=False)
    # Service logger. NOTE(review): handlers in this module also call it with a
    # ``level=`` keyword, which this Callable annotation does not express.
    service_log: Callable[[str, str, Optional[str]], None] = field(init=False)

    def __post_init__(self) -> None:
        # Deferred import to avoid a circular import at module load.
        from ..API import _make_db_conn_factory, _make_service_logger

        ctx = self.context
        db_path = ctx.database_path

        initialise_engine_database(db_path, logger=ctx.logger)
        self.db_conn_factory = _make_db_conn_factory(db_path)

        # The directory handed to the service logger is the configured log
        # file's parent when one is set ("log_file" or "LOG_FILE"), otherwise
        # the database file's parent.
        configured_log = ctx.config.get("log_file") or ctx.config.get("LOG_FILE") or ""
        log_file = str(configured_log).strip()
        anchor = log_file if log_file else db_path
        self.service_log = _make_service_logger(Path(anchor).resolve().parent, ctx.logger)
|
|
|
|
|
|
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
    """Register Socket.IO event handlers for the Engine runtime.

    Default namespace:
      * ``quick_job_result`` - store an agent's quick-job result in
        ``activity_history``, update any linked ``scheduled_job_runs`` row and
        broadcast ``device_activity_changed``.
      * ``tunnel_bridge_attach`` / ``tunnel_operator_send`` /
        ``tunnel_operator_poll`` - placeholder operator bridge handlers.

    ``/tunnel`` namespace (WebUI operator bridge):
      * ``join`` / ``send`` / ``poll`` - frame exchange for the joined tunnel.
      * ``ps_open`` / ``ps_send`` / ``ps_resize`` / ``ps_poll`` - operations on
        the server returned by ``tunnel_service.ensure_ps_server``.
      * ``disconnect`` - forget the session's joined tunnel.

    If ``context`` has no ``reverse_tunnel_service`` yet, a new
    ``ReverseTunnelService`` is created, started and attached to it.
    Acknowledgement dicts use ``{"status": "ok", ...}`` on success and
    ``{"error": <code>}`` on failure.
    """
    adapters = EngineRealtimeAdapters(context)
    logger = context.logger.getChild("realtime.quick_jobs")

    tunnel_service = getattr(context, "reverse_tunnel_service", None)
    if tunnel_service is None:
        tunnel_service = ReverseTunnelService(
            context,
            signer=None,
            db_conn_factory=adapters.db_conn_factory,
            socketio=socket_server,
        )
        tunnel_service.start()
        setattr(context, "reverse_tunnel_service", tunnel_service)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    # Terminal size used by ps_open / ps_resize when the client omits one.
    default_cols, default_rows = 120, 32

    def _clamp_terminal_size(cols: int, rows: int) -> tuple[int, int]:
        """Clamp a terminal size to 20-300 columns and 10-200 rows."""
        return max(20, min(cols, 300)), max(10, min(rows, 200))

    def _encode_frame(frame: TunnelFrame) -> str:
        """Serialise a tunnel frame as ASCII base64 for JSON transport."""
        return base64.b64encode(frame.encode()).decode("ascii")

    def _decode_frame_payload(raw: Any) -> TunnelFrame:
        """Decode a client-supplied frame (base64 ``str`` or raw bytes).

        Raises:
            ValueError: ``"invalid_frame"`` when ``raw`` is not decodable
                base64 or has an unsupported type. Errors raised by
                ``decode_frame`` itself propagate unchanged.
        """
        if isinstance(raw, str):
            try:
                raw_bytes = base64.b64decode(raw)
            except Exception:
                raise ValueError("invalid_frame") from None
        elif isinstance(raw, (bytes, bytearray)):
            raw_bytes = bytes(raw)
        else:
            raise ValueError("invalid_frame")
        return decode_frame(raw_bytes)

    def _drain_operator_frames(bridge: TunnelBridge) -> list[str]:
        """Pop every frame queued for the operator, base64-encoded, in the
        order ``bridge.next_for_operator`` returns them (until it yields None)."""
        frames: list[str] = []
        frame = bridge.next_for_operator()
        while frame is not None:
            frames.append(_encode_frame(frame))
            frame = bridge.next_for_operator()
        return frames

    def _forward_to_agent(bridge: TunnelBridge, frame: TunnelFrame, tunnel_id: str) -> Dict[str, Any]:
        """Queue ``frame`` for the agent and return the Socket.IO ack dict.

        A failure is reported as ``{"error": "send_failed"}`` instead of
        escaping the handler (which would leave the client without an ack).
        """
        try:
            bridge.operator_to_agent(frame)
        except Exception as exc:
            logger.debug("operator_to_agent failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "send_failed"}
        return {"status": "ok"}

    # ------------------------------------------------------------------
    # Quick job results (default namespace)
    # ------------------------------------------------------------------

    @socket_server.on("quick_job_result")
    def _handle_quick_job_result(data: Any = None) -> None:
        """Persist an agent-reported quick-job result and broadcast the change.

        ``data`` is expected to be a dict with ``job_id`` (an
        ``activity_history`` row id) plus optional ``status`` (blank becomes
        ``"Failed"``), ``stdout``, ``stderr`` and a ``context`` dict that may
        carry ``scheduled_job_run_id``/``run_id`` and ``scheduled_ts``.
        Errors inside the database section are logged, not raised.
        """
        if not isinstance(data, dict):
            logger.debug("quick_job_result payload ignored (non-dict): %r", data)
            return

        job_id_raw = data.get("job_id")
        try:
            job_id = int(job_id_raw)
        except (TypeError, ValueError):
            logger.debug("quick_job_result missing valid job_id: %r", job_id_raw)
            return

        status = str(data.get("status") or "").strip() or "Failed"
        stdout = _normalize_text(data.get("stdout"))
        stderr = _normalize_text(data.get("stderr"))

        conn: Optional[sqlite3.Connection] = None
        cursor = None
        broadcast_payload: Optional[Dict[str, Any]] = None

        ctx_payload = data.get("context")
        context_info: Optional[Dict[str, Any]] = ctx_payload if isinstance(ctx_payload, dict) else None

        try:
            conn = adapters.db_conn_factory()
            cursor = conn.cursor()

            # 1. Record the result on the activity row itself.
            cursor.execute(
                "UPDATE activity_history SET status=?, stdout=?, stderr=? WHERE id=?",
                (status, stdout, stderr, job_id),
            )
            if cursor.rowcount == 0:
                logger.debug("quick_job_result missing activity_history row for job_id=%s", job_id)
            conn.commit()

            # 2. Resolve the scheduled run this activity belongs to. The DB link
            #    table wins; the payload context is only consulted as a fallback.
            try:
                cursor.execute(
                    "SELECT run_id FROM scheduled_job_run_activity WHERE activity_id=?",
                    (job_id,),
                )
                link = cursor.fetchone()
            except sqlite3.Error:
                link = None

            run_id: Optional[int] = None
            scheduled_ts_ctx: Optional[int] = None
            if link:
                try:
                    run_id = int(link[0])
                except Exception:
                    run_id = None

            if run_id is None and context_info:
                ctx_run = context_info.get("scheduled_job_run_id") or context_info.get("run_id")
                try:
                    if ctx_run is not None:
                        run_id = int(ctx_run)
                except (TypeError, ValueError):
                    run_id = None
                try:
                    if context_info.get("scheduled_ts") is not None:
                        scheduled_ts_ctx = int(context_info.get("scheduled_ts"))
                except (TypeError, ValueError):
                    scheduled_ts_ctx = None

            # 3. Mirror the status onto the scheduled run, if one was found.
            if run_id is not None:
                ts_now = _now_ts()
                try:
                    if status.lower() == "running":
                        cursor.execute(
                            "UPDATE scheduled_job_runs SET status='Running', updated_at=? WHERE id=?",
                            (ts_now, run_id),
                        )
                    else:
                        # Terminal status: stamp finished_ts only if not already set.
                        cursor.execute(
                            """
                            UPDATE scheduled_job_runs
                            SET status=?,
                                finished_ts=COALESCE(finished_ts, ?),
                                updated_at=?
                            WHERE id=?
                            """,
                            (status, ts_now, ts_now, run_id),
                        )
                    if scheduled_ts_ctx is not None:
                        cursor.execute(
                            "UPDATE scheduled_job_runs SET scheduled_ts=COALESCE(scheduled_ts, ?) WHERE id=?",
                            (scheduled_ts_ctx, run_id),
                        )
                    conn.commit()
                    adapters.service_log(
                        "scheduled_jobs",
                        f"scheduled run update run_id={run_id} activity_id={job_id} status={status}",
                    )
                except Exception as exc:  # pragma: no cover - defensive guard
                    logger.debug(
                        "quick_job_result failed to update scheduled_job_runs for job_id=%s run_id=%s: %s",
                        job_id,
                        run_id,
                        exc,
                    )
            elif context_info:
                adapters.service_log(
                    "scheduled_jobs",
                    f"scheduled run update skipped (no run_id) activity_id={job_id} status={status} context={context_info}",
                    level="WARNING",
                )

            # 4. Build the realtime notification from the stored row.
            try:
                cursor.execute(
                    "SELECT id, hostname, status FROM activity_history WHERE id=?",
                    (job_id,),
                )
                row = cursor.fetchone()
            except sqlite3.Error:
                row = None

            if row:
                hostname = (row[1] or "").strip()
                if hostname:
                    broadcast_payload = {
                        "activity_id": int(row[0]),
                        "hostname": hostname,
                        "status": row[2] or status,
                        "change": "updated",
                        "source": "quick_job",
                    }

            adapters.service_log(
                "assemblies",
                f"quick_job_result processed job_id={job_id} status={status}",
            )
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.warning(
                "quick_job_result handler error for job_id=%s: %s",
                job_id,
                exc,
                exc_info=True,
            )
        finally:
            if cursor is not None:
                try:
                    cursor.close()
                except Exception:
                    pass
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

        # Emit only after the connection is released.
        if broadcast_payload:
            try:
                socket_server.emit("device_activity_changed", broadcast_payload)
            except Exception as exc:  # pragma: no cover - defensive guard
                logger.debug(
                    "Failed to emit device_activity_changed for job_id=%s: %s",
                    job_id,
                    exc,
                )

    # ------------------------------------------------------------------
    # Placeholder operator bridge (default namespace)
    # NOTE(review): unlike the /tunnel namespace handlers below, these perform
    # no operator authentication and trust client-supplied identifiers.
    # Confirm that is acceptable before they carry real traffic.
    # ------------------------------------------------------------------

    @socket_server.on("tunnel_bridge_attach")
    def _tunnel_bridge_attach(data: Any = None) -> Any:
        """Placeholder operator bridge attach handler (no data channel yet)."""
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}

        tunnel_id = str(data.get("tunnel_id") or "").strip()
        operator_id = str(data.get("operator_id") or "").strip() or None
        if not tunnel_id:
            return {"error": "tunnel_id_required"}

        try:
            tunnel_service.operator_attach(tunnel_id, operator_id)
        except ValueError as exc:
            return {"error": str(exc)}
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.debug("tunnel_bridge_attach failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "bridge_attach_failed"}

        return {"status": "ok", "tunnel_id": tunnel_id, "operator_id": operator_id or "-"}

    @socket_server.on("tunnel_operator_send")
    def _tunnel_operator_send(data: Any = None) -> Any:
        """Operator -> agent frame enqueue (placeholder queue).

        Expects ``{"tunnel_id": str, "frame": <base64 str | bytes>}``.
        """
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        tunnel_id = str(data.get("tunnel_id") or "").strip()
        frame_raw = data.get("frame")
        if not tunnel_id or frame_raw is None:
            return {"error": "tunnel_id_and_frame_required"}
        try:
            frame = _decode_frame_payload(frame_raw)
        except ValueError as exc:
            return {"error": str(exc) or "invalid_frame"}
        except Exception:
            # Don't echo arbitrary internal exception text back to the client.
            return {"error": "invalid_frame"}

        bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        return _forward_to_agent(bridge, frame, tunnel_id)

    @socket_server.on("tunnel_operator_poll")
    def _tunnel_operator_poll(data: Any = None) -> Any:
        """Operator polls queued frames from agent."""
        tunnel_id = ""
        if isinstance(data, dict):
            tunnel_id = str(data.get("tunnel_id") or "").strip()
        if not tunnel_id:
            return {"error": "tunnel_id_required"}
        bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        return {"frames": _drain_operator_frames(bridge)}

    # ------------------------------------------------------------------
    # WebUI operator bridge namespace for browser clients
    # ------------------------------------------------------------------

    tunnel_namespace = "/tunnel"
    # Socket.IO session id on the /tunnel namespace -> tunnel_id it joined.
    _operator_sessions: Dict[str, str] = {}

    def _current_operator() -> Optional[str]:
        """Identify the caller from the Flask session, bearer header or cookie.

        NOTE(review): when no session username exists, the raw bearer token
        (or ``borealis_auth`` cookie) is returned as the operator identifier
        without any validation here, so any non-empty token passes. The token
        string also becomes the id passed to ``operator_attach``. Confirm the
        token is validated elsewhere and is safe to use as an identifier.
        """
        username = session.get("username")
        if username:
            return str(username)
        auth_header = (request.headers.get("Authorization") or "").strip()
        token = None
        if auth_header.lower().startswith("bearer "):
            token = auth_header.split(" ", 1)[1].strip()
        if not token:
            token = request.cookies.get("borealis_auth")
        return token or None

    @socket_server.on("join", namespace=tunnel_namespace)
    def _ws_tunnel_join(data: Any = None) -> Any:
        """Attach the calling operator to ``data["tunnel_id"]`` for this session."""
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        operator_id = _current_operator()
        if not operator_id:
            return {"error": "unauthorized"}
        tunnel_id = str(data.get("tunnel_id") or "").strip()
        if not tunnel_id:
            return {"error": "tunnel_id_required"}
        bridge = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        try:
            tunnel_service.operator_attach(tunnel_id, operator_id)
        except Exception as exc:
            logger.debug("ws_tunnel_join failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "attach_failed"}
        sid = request.sid
        _operator_sessions[sid] = tunnel_id
        return {"status": "ok", "tunnel_id": tunnel_id}

    @socket_server.on("send", namespace=tunnel_namespace)
    def _ws_tunnel_send(data: Any = None) -> Any:
        """Queue a frame from the joined operator session to its agent."""
        sid = request.sid
        tunnel_id = _operator_sessions.get(sid)
        if not tunnel_id:
            return {"error": "not_joined"}
        if not isinstance(data, dict):
            return {"error": "invalid_payload"}
        frame_raw = data.get("frame")
        if frame_raw is None:
            return {"error": "frame_required"}
        try:
            frame = _decode_frame_payload(frame_raw)
        except Exception:
            return {"error": "invalid_frame"}
        bridge = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        return _forward_to_agent(bridge, frame, tunnel_id)

    @socket_server.on("poll", namespace=tunnel_namespace)
    def _ws_tunnel_poll(data: Any = None) -> Any:
        """Return every frame queued for the joined operator session.

        ``data`` is ignored; the optional parameter only lets a client emit
        ``poll`` with a payload without the call failing on arity.
        """
        sid = request.sid
        tunnel_id = _operator_sessions.get(sid)
        if not tunnel_id:
            return {"error": "not_joined"}
        bridge = tunnel_service.get_bridge(tunnel_id)
        if bridge is None:
            return {"error": "unknown_tunnel"}
        return {"frames": _drain_operator_frames(bridge)}

    def _require_ps_server():
        """Resolve the current session's ps server.

        Returns ``(server, tunnel_id, error)``; ``server`` is None whenever
        ``error`` holds the ack dict to return.
        """
        sid = request.sid
        tunnel_id = _operator_sessions.get(sid)
        if not tunnel_id:
            return None, None, {"error": "not_joined"}
        server = tunnel_service.ensure_ps_server(tunnel_id)
        if server is None:
            return None, tunnel_id, {"error": "ps_unsupported"}
        return server, tunnel_id, None

    @socket_server.on("ps_open", namespace=tunnel_namespace)
    def _ws_ps_open(data: Any = None) -> Any:
        """Open a ps channel sized from optional ``cols``/``rows`` in ``data``."""
        server, tunnel_id, error = _require_ps_server()
        if server is None:
            return error
        cols, rows = default_cols, default_rows
        if isinstance(data, dict):
            try:
                cols = int(data.get("cols", cols))
                rows = int(data.get("rows", rows))
            except Exception:
                # On a parse failure, any dimension not yet parsed keeps its default.
                pass
        cols, rows = _clamp_terminal_size(cols, rows)
        try:
            server.open_channel(cols=cols, rows=rows)
        except Exception as exc:
            logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "ps_open_failed"}
        return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows}

    @socket_server.on("ps_send", namespace=tunnel_namespace)
    def _ws_ps_send(data: Any = None) -> Any:
        """Send input text (``data`` itself, or ``data["data"]`` for dicts)."""
        server, tunnel_id, error = _require_ps_server()
        if server is None:
            return error
        if data is None:
            return {"error": "payload_required"}
        text = data.get("data") if isinstance(data, dict) else data
        if text is None:
            return {"error": "payload_required"}
        try:
            server.send_input(str(text))
        except Exception as exc:
            logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "ps_send_failed"}
        return {"status": "ok"}

    @socket_server.on("ps_resize", namespace=tunnel_namespace)
    def _ws_ps_resize(data: Any = None) -> Any:
        """Resize the ps channel; missing dimensions fall back to the defaults."""
        server, tunnel_id, error = _require_ps_server()
        if server is None:
            return error
        cols = None
        rows = None
        if isinstance(data, dict):
            cols = data.get("cols")
            rows = data.get("rows")
        try:
            cols_int = int(cols) if cols is not None else default_cols
            rows_int = int(rows) if rows is not None else default_rows
            cols_int, rows_int = _clamp_terminal_size(cols_int, rows_int)
            server.send_resize(cols_int, rows_int)
            return {"status": "ok", "cols": cols_int, "rows": rows_int}
        except Exception as exc:
            logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "ps_resize_failed"}

    @socket_server.on("ps_poll", namespace=tunnel_namespace)
    def _ws_ps_poll(data: Any = None) -> Any:
        """Return drained ps output and the server status (``data`` is ignored)."""
        server, tunnel_id, error = _require_ps_server()
        if server is None:
            return error
        try:
            output = server.drain_output()
            status = server.status()
            return {"output": output, "status": status}
        except Exception as exc:
            logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
            return {"error": "ps_poll_failed"}

    @socket_server.on("disconnect", namespace=tunnel_namespace)
    def _ws_tunnel_disconnect(*_args: Any) -> None:
        """Forget the disconnecting session's joined tunnel.

        Extra positional arguments (e.g. a disconnect reason supplied by newer
        Socket.IO server versions) are accepted and ignored.
        """
        sid = request.sid
        _operator_sessions.pop(sid, None)
|