diff --git a/Data/Engine/CODE_MIGRATION_TRACKER.md b/Data/Engine/CODE_MIGRATION_TRACKER.md index d20c2a7a..c21d8cba 100644 --- a/Data/Engine/CODE_MIGRATION_TRACKER.md +++ b/Data/Engine/CODE_MIGRATION_TRACKER.md @@ -37,7 +37,7 @@ Lastly, everytime that you complete a stage, you will create a pull request name - [x] Preserve TLS-aware URL generation and caching. - [ ] Add migration switch in the legacy server for WebUI delegation. - [x] Extend tests to cover critical WebUI routes. - - [ ] Port device API endpoints into Engine services (in progress). + - [ ] Port device API endpoints into Engine services (device + admin coverage in progress). - [ ] **Stage 7 — Plan WebSocket migration** - [ ] Extract Socket.IO handlers into Data/Engine/services/WebSocket. - [ ] Provide register_realtime hook for the Engine factory. diff --git a/Data/Engine/Unit_Tests/conftest.py b/Data/Engine/Unit_Tests/conftest.py index 6590d805..a89059f6 100644 --- a/Data/Engine/Unit_Tests/conftest.py +++ b/Data/Engine/Unit_Tests/conftest.py @@ -55,6 +55,7 @@ CREATE TABLE IF NOT EXISTS enrollment_install_codes ( id TEXT PRIMARY KEY, code TEXT UNIQUE, expires_at TEXT, + created_by_user_id TEXT, used_at TEXT, used_by_guid TEXT, max_uses INTEGER, @@ -94,17 +95,28 @@ CREATE TABLE IF NOT EXISTS device_list_views ( CREATE TABLE IF NOT EXISTS sites ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT, - description TEXT + description TEXT, + created_at INTEGER ); CREATE TABLE IF NOT EXISTS device_sites ( - device_hostname TEXT, + device_hostname TEXT PRIMARY KEY, site_id INTEGER, - PRIMARY KEY (device_hostname, site_id) + assigned_at INTEGER ); CREATE TABLE IF NOT EXISTS github_token ( id INTEGER PRIMARY KEY, token TEXT ); +CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE, + display_name TEXT, + password_sha512 TEXT, + role TEXT, + last_login INTEGER, + created_at INTEGER, + updated_at INTEGER +); """ @@ -210,12 +222,54 @@ def engine_harness(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator[ ), ) cur.execute( - "INSERT INTO sites (id, name, description) VALUES (?, ?, ?)", - (1, "Main Lab", "Primary integration site"), + "INSERT INTO sites (id, name, description, created_at) VALUES (?, ?, ?, ?)", + (1, "Main Lab", "Primary integration site", 1_700_000_000), ) cur.execute( - "INSERT INTO device_sites (device_hostname, site_id) VALUES (?, ?)", - ("test-device", 1), + "INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)", + ("test-device", 1, 1_700_000_500), + ) + cur.execute( + """ + INSERT INTO users (id, username, display_name, password_sha512, role, last_login, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (1, "admin", "Administrator", "test", "Admin", 0, 0, 0), + ) + cur.execute( + """ + INSERT INTO device_approvals ( + id, + approval_reference, + guid, + hostname_claimed, + ssl_key_fingerprint_claimed, + enrollment_code_id, + status, + client_nonce, + server_nonce, + agent_pubkey_der, + created_at, + updated_at, + approved_by_user_id + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + "approval-1", + "APP-REF-1", + None, + "pending-device", + "aa:bb:cc:dd", + None, + "pending", + "client-nonce", + "server-nonce", + None, + "2025-01-01T00:00:00Z", + "2025-01-01T00:00:00Z", + None, + ), ) conn.commit() finally: diff --git a/Data/Engine/Unit_Tests/test_devices_api.py b/Data/Engine/Unit_Tests/test_devices_api.py index 1f8a7f79..ce5f6e0a 100644 --- a/Data/Engine/Unit_Tests/test_devices_api.py +++ b/Data/Engine/Unit_Tests/test_devices_api.py @@ -120,3 +120,67 @@ def test_agent_hash_list_permissions(engine_harness: EngineTestHarness) -> None: assert allowed.status_code == 200 agents = allowed.get_json()["agents"] assert agents and agents[0]["hostname"] == "test-device" + + +def test_sites_lifecycle(engine_harness: EngineTestHarness) -> None: + client = _client_with_admin_session(engine_harness) + create_resp = client.post( + "/api/sites", + json={"name": "Edge", "description": "Edge location"}, + ) + assert create_resp.status_code == 201 + site_id = create_resp.get_json()["id"] + + list_resp = client.get("/api/sites") + sites = list_resp.get_json()["sites"] + assert any(site["id"] == site_id for site in sites) + + assign_resp = client.post( + "/api/sites/assign", + json={"site_id": site_id, "hostnames": ["test-device"]}, + ) + assert assign_resp.status_code == 200 + + mapping_resp = client.get("/api/sites/device_map") + mapping = mapping_resp.get_json()["mapping"] + assert mapping["test-device"]["site_id"] == site_id + + rename_resp = client.post( + "/api/sites/rename", + json={"id": site_id, "new_name": "Edge-Renamed"}, + ) + assert rename_resp.status_code == 200 + assert rename_resp.get_json()["name"] == "Edge-Renamed" + + delete_resp = client.post("/api/sites/delete", json={"ids": [site_id]}) + assert delete_resp.status_code == 200 + + +def test_admin_enrollment_code_flow(engine_harness: EngineTestHarness) -> None: + client = _client_with_admin_session(engine_harness) + create_resp = client.post( + "/api/admin/enrollment-codes", + json={"ttl_hours": 1, "max_uses": 2}, + ) + assert create_resp.status_code == 201 + code_id = create_resp.get_json()["id"] + + list_resp = client.get("/api/admin/enrollment-codes") + codes = list_resp.get_json()["codes"] + assert any(code["id"] == code_id for code in codes) + + delete_resp = client.delete(f"/api/admin/enrollment-codes/{code_id}") + assert delete_resp.status_code == 200 + + +def test_admin_device_approvals(engine_harness: EngineTestHarness) -> None: + client = _client_with_admin_session(engine_harness) + list_resp = client.get("/api/admin/device-approvals") + approvals = list_resp.get_json()["approvals"] + assert approvals and approvals[0]["status"] == "pending" + + approve_resp = client.post( + "/api/admin/device-approvals/approval-1/approve", + json={"conflict_resolution": "overwrite"}, + ) + assert approve_resp.status_code == 200 diff --git a/Data/Engine/services/API/__init__.py b/Data/Engine/services/API/__init__.py index c7995509..bf520bbd 100644 --- a/Data/Engine/services/API/__init__.py +++ b/Data/Engine/services/API/__init__.py @@ -22,6 +22,7 @@ from Modules.tokens import routes as token_routes from ...server import EngineContext from .access_management.login import register_auth +from .devices.approval import register_admin_endpoints from .devices.management import register_management DEFAULT_API_GROUPS: Sequence[str] = ("auth", "tokens", "enrollment", "devices") @@ -180,11 +181,16 @@ def _register_enrollment(app: Flask, adapters: LegacyServiceAdapters) -> None: ) +def _register_devices(app: Flask, adapters: LegacyServiceAdapters) -> None: + register_management(app, adapters) + register_admin_endpoints(app, adapters) + + _GROUP_REGISTRARS: Mapping[str, Callable[[Flask, LegacyServiceAdapters], None]] = { "auth": register_auth, "tokens": _register_tokens, "enrollment": _register_enrollment, - "devices": register_management, + "devices": _register_devices, } diff --git a/Data/Engine/services/API/devices/approval.py b/Data/Engine/services/API/devices/approval.py index 2102559c..150d7d2d 100644 --- a/Data/Engine/services/API/devices/approval.py +++ b/Data/Engine/services/API/devices/approval.py @@ -1 +1,523 @@ -"Placeholder for API module devices/approval.py." +"""Admin-focused device enrollment and approval endpoints.""" +from __future__ import annotations + +import os +import secrets +import sqlite3 +import uuid +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +from flask import Blueprint, jsonify, request, session +from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer + +from Modules.guid_utils import normalize_guid + +if TYPE_CHECKING: # pragma: no cover - typing helper + from .. import LegacyServiceAdapters + + +VALID_TTL_HOURS = {1, 3, 6, 12, 24} + + +def _now() -> datetime: + return datetime.now(tz=timezone.utc) + + +def _iso(dt: datetime) -> str: + return dt.isoformat() + + +def _generate_install_code() -> str: + raw = secrets.token_hex(16).upper() + return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4)) + + +class AdminDeviceService: + """Utility wrapper for admin device APIs.""" + + def __init__(self, app, adapters: "LegacyServiceAdapters") -> None: + self.app = app + self.adapters = adapters + self.db_conn_factory = adapters.db_conn_factory + self.service_log = adapters.service_log + self.logger = adapters.context.logger + + def _db_conn(self) -> sqlite3.Connection: + return self.db_conn_factory() + + def _token_serializer(self) -> URLSafeTimedSerializer: + secret = self.app.secret_key or "borealis-dev-secret" + return URLSafeTimedSerializer(secret, salt="borealis-auth") + + def _current_user(self) -> Optional[Dict[str, Any]]: + username = session.get("username") + role = session.get("role") or "User" + if username: + return {"username": username, "role": role} + + token = None + auth_header = request.headers.get("Authorization") or "" + if auth_header.lower().startswith("bearer "): + token = auth_header.split(" ", 1)[1].strip() + if not token: + token = request.cookies.get("borealis_auth") + if not token: + return None + try: + data = self._token_serializer().loads( + token, + max_age=int(os.environ.get("BOREALIS_TOKEN_TTL_SECONDS", 60 * 60 * 24 * 30)), + ) + username = data.get("u") + role = data.get("r") or "User" + if username: + return {"username": username, "role": role} + except (BadSignature, SignatureExpired, Exception): + return None + return None + + def require_admin(self) -> Optional[Tuple[Dict[str, Any], int]]: + user = self._current_user() + if not user: + return {"error": "unauthorized"}, 401 + if (user.get("role") or "").lower() != "admin": + return {"error": "forbidden"}, 403 + return None + + def _lookup_user_id(self, cur: sqlite3.Cursor, username: str) -> Optional[str]: + if not username: + return None + cur.execute( + "SELECT id FROM users WHERE LOWER(username) = LOWER(?)", + (username,), + ) + row = cur.fetchone() + if row: + return str(row[0]) + return None + + def _hostname_conflict( + self, + cur: sqlite3.Cursor, + hostname: Optional[str], + pending_guid: Optional[str], + ) -> Optional[Dict[str, Any]]: + if not hostname: + return None + cur.execute( + """ + SELECT d.guid, d.ssl_key_fingerprint, ds.site_id, s.name + FROM devices d + LEFT JOIN device_sites ds ON ds.device_hostname = d.hostname + LEFT JOIN sites s ON s.id = ds.site_id + WHERE d.hostname = ? + """, + (hostname,), + ) + row = cur.fetchone() + if not row: + return None + existing_guid = normalize_guid(row[0]) + existing_fingerprint = (row[1] or "").strip().lower() + pending_norm = normalize_guid(pending_guid) + if existing_guid and pending_norm and existing_guid == pending_norm: + return None + site_id_raw = row[2] + try: + site_id = int(site_id_raw) if site_id_raw is not None else None + except Exception: + site_id = None + site_name = row[3] or "" + return { + "guid": existing_guid or None, + "ssl_key_fingerprint": existing_fingerprint or None, + "site_id": site_id, + "site_name": site_name, + } + + def _suggest_alternate_hostname( + self, + cur: sqlite3.Cursor, + hostname: Optional[str], + pending_guid: Optional[str], + ) -> Optional[str]: + base = (hostname or "").strip() + if not base: + return None + base = base[:253] + candidate = base + pending_norm = normalize_guid(pending_guid) + suffix = 1 + while True: + cur.execute( + "SELECT guid FROM devices WHERE hostname = ?", + (candidate,), + ) + row = cur.fetchone() + if not row: + return candidate + existing_guid = normalize_guid(row[0]) + if pending_norm and existing_guid == pending_norm: + return candidate + candidate = f"{base}-{suffix}" + suffix += 1 + if suffix > 50: + return pending_norm or candidate + + # ------------------------------------------------------------------ # + # Enrollment code management + # ------------------------------------------------------------------ # + + def list_enrollment_codes(self, status_filter: Optional[str]) -> Tuple[Dict[str, Any], int]: + conn = self._db_conn() + try: + cur = conn.cursor() + sql = """ + SELECT id, + code, + expires_at, + created_by_user_id, + used_at, + used_by_guid, + max_uses, + use_count, + last_used_at + FROM enrollment_install_codes + """ + params: List[str] = [] + now_iso = _iso(_now()) + if status_filter == "active": + sql += " WHERE use_count < max_uses AND expires_at > ?" + params.append(now_iso) + elif status_filter == "expired": + sql += " WHERE use_count < max_uses AND expires_at <= ?" + params.append(now_iso) + elif status_filter == "used": + sql += " WHERE use_count >= max_uses" + sql += " ORDER BY expires_at ASC" + cur.execute(sql, params) + rows = cur.fetchall() + finally: + conn.close() + + records = [] + for row in rows: + records.append( + { + "id": row[0], + "code": row[1], + "expires_at": row[2], + "created_by_user_id": row[3], + "used_at": row[4], + "used_by_guid": row[5], + "max_uses": row[6], + "use_count": row[7], + "last_used_at": row[8], + } + ) + return {"codes": records}, 200 + + def create_enrollment_code(self, ttl_hours: int, max_uses: int) -> Tuple[Dict[str, Any], int]: + if ttl_hours not in VALID_TTL_HOURS: + return {"error": "invalid_ttl"}, 400 + max_uses = max(1, min(int(max_uses or 1), 10)) + + user = self._current_user() or {} + username = user.get("username") or "" + + conn = self._db_conn() + try: + cur = conn.cursor() + created_by = self._lookup_user_id(cur, username) or username or "system" + code_value = _generate_install_code() + expires_at = _now() + timedelta(hours=ttl_hours) + record_id = str(uuid.uuid4()) + cur.execute( + """ + INSERT INTO enrollment_install_codes ( + id, code, expires_at, created_by_user_id, max_uses, use_count + ) + VALUES (?, ?, ?, ?, ?, 0) + """, + (record_id, code_value, _iso(expires_at), created_by, max_uses), + ) + conn.commit() + finally: + conn.close() + + self.service_log( + "server", + f"installer code created id={record_id} by={username} ttl={ttl_hours}h max_uses={max_uses}", + ) + return ( + { + "id": record_id, + "code": code_value, + "expires_at": _iso(expires_at), + "max_uses": max_uses, + "use_count": 0, + "last_used_at": None, + }, + 201, + ) + + def delete_enrollment_code(self, code_id: str) -> Tuple[Dict[str, Any], int]: + conn = self._db_conn() + try: + cur = conn.cursor() + cur.execute( + "DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0", + (code_id,), + ) + deleted = cur.rowcount + conn.commit() + finally: + conn.close() + + if not deleted: + return {"error": "not_found"}, 404 + self.service_log("server", f"installer code deleted id={code_id}") + return {"status": "deleted"}, 200 + + # ------------------------------------------------------------------ # + # Device approval helpers + # ------------------------------------------------------------------ # + + def list_device_approvals(self, status_filter: Optional[str]) -> Tuple[Dict[str, Any], int]: + approvals: List[Dict[str, Any]] = [] + conn = self._db_conn() + try: + cur = conn.cursor() + params: List[str] = [] + sql = """ + SELECT + da.id, + da.approval_reference, + da.guid, + da.hostname_claimed, + da.ssl_key_fingerprint_claimed, + da.enrollment_code_id, + da.status, + da.client_nonce, + da.server_nonce, + da.created_at, + da.updated_at, + da.approved_by_user_id, + u.username AS approved_by_username + FROM device_approvals AS da + LEFT JOIN users AS u + ON ( + CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT) + OR LOWER(da.approved_by_user_id) = LOWER(u.username) + ) + """ + status_norm = (status_filter or "").strip().lower() + if status_norm and status_norm != "all": + sql += " WHERE LOWER(da.status) = ?" + params.append(status_norm) + sql += " ORDER BY da.created_at ASC" + cur.execute(sql, params) + rows = cur.fetchall() + for row in rows: + record_guid = row[2] + hostname = row[3] + fingerprint_claimed = row[4] + claimed_fp_norm = (fingerprint_claimed or "").strip().lower() + conflict_raw = self._hostname_conflict(cur, hostname, record_guid) + fingerprint_match = False + requires_prompt = False + conflict = None + if conflict_raw: + conflict_fp = (conflict_raw.get("ssl_key_fingerprint") or "").strip().lower() + fingerprint_match = bool(conflict_fp and claimed_fp_norm) and conflict_fp == claimed_fp_norm + requires_prompt = not fingerprint_match + conflict = { + **conflict_raw, + "fingerprint_match": fingerprint_match, + "requires_prompt": requires_prompt, + } + alternate = ( + self._suggest_alternate_hostname(cur, hostname, record_guid) + if conflict_raw and requires_prompt + else None + ) + approvals.append( + { + "id": row[0], + "approval_reference": row[1], + "guid": record_guid, + "hostname_claimed": hostname, + "ssl_key_fingerprint_claimed": fingerprint_claimed, + "enrollment_code_id": row[5], + "status": row[6], + "client_nonce": row[7], + "server_nonce": row[8], + "created_at": row[9], + "updated_at": row[10], + "approved_by_user_id": row[11], + "hostname_conflict": conflict, + "alternate_hostname": alternate, + "conflict_requires_prompt": requires_prompt, + "fingerprint_match": fingerprint_match, + "approved_by_username": row[12], + } + ) + finally: + conn.close() + return {"approvals": approvals}, 200 + + def _set_approval_status( + self, + approval_id: str, + status: str, + *, + guid: Optional[str] = None, + resolution: Optional[str] = None, + ) -> Tuple[Dict[str, Any], int]: + user = self._current_user() or {} + username = user.get("username") or "" + + conn = self._db_conn() + try: + cur = conn.cursor() + cur.execute( + """ + SELECT status, + guid, + hostname_claimed, + ssl_key_fingerprint_claimed + FROM device_approvals + WHERE id = ? + """, + (approval_id,), + ) + row = cur.fetchone() + if not row: + return {"error": "not_found"}, 404 + existing_status = (row[0] or "").strip().lower() + if existing_status != "pending": + return {"error": "approval_not_pending"}, 409 + stored_guid = row[1] + hostname_claimed = row[2] + fingerprint_claimed = (row[3] or "").strip().lower() + + guid_effective = normalize_guid(guid) if guid else normalize_guid(stored_guid) + resolution_effective = (resolution.strip().lower() if isinstance(resolution, str) else None) + + if status == "approved": + conflict = self._hostname_conflict(cur, hostname_claimed, guid_effective) + if conflict: + conflict_fp = (conflict.get("ssl_key_fingerprint") or "").strip().lower() + fingerprint_match = bool(conflict_fp and fingerprint_claimed) and conflict_fp == fingerprint_claimed + if fingerprint_match: + guid_effective = conflict.get("guid") or guid_effective + if not resolution_effective: + resolution_effective = "auto_merge_fingerprint" + elif resolution_effective == "overwrite": + guid_effective = conflict.get("guid") or guid_effective + elif resolution_effective == "coexist": + pass + else: + return { + "error": "conflict_resolution_required", + "hostname": hostname_claimed, + }, 409 + + guid_to_store = guid_effective or normalize_guid(stored_guid) or None + approved_by = self._lookup_user_id(cur, username) or username or "system" + cur.execute( + """ + UPDATE device_approvals + SET status = ?, + guid = ?, + approved_by_user_id = ?, + updated_at = ? + WHERE id = ? + """, + ( + status, + guid_to_store, + approved_by, + _iso(_now()), + approval_id, + ), + ) + conn.commit() + finally: + conn.close() + + resolution_note = f" ({resolution_effective})" if resolution_effective else "" + self.service_log("server", f"device approval {approval_id} -> {status}{resolution_note} by {username}") + payload: Dict[str, Any] = {"status": status} + if resolution_effective: + payload["conflict_resolution"] = resolution_effective + return payload, 200 + + def approve_device(self, approval_id: str, payload: Dict[str, Any]) -> Tuple[Dict[str, Any], int]: + guid = (payload.get("guid") or "").strip() or None + resolution_raw = payload.get("conflict_resolution") + resolution = resolution_raw.strip() if isinstance(resolution_raw, str) else None + return self._set_approval_status(approval_id, "approved", guid=guid, resolution=resolution) + + def deny_device(self, approval_id: str) -> Tuple[Dict[str, Any], int]: + return self._set_approval_status(approval_id, "denied") + + +def register_admin_endpoints(app, adapters: "LegacyServiceAdapters") -> None: + """Register admin enrollment + approval endpoints.""" + + service = AdminDeviceService(app, adapters) + blueprint = Blueprint("device_admin", __name__) + + @blueprint.before_request + def _ensure_admin(): + requirement = service.require_admin() + if requirement: + payload, status = requirement + return jsonify(payload), status + return None + + @blueprint.route("/api/admin/enrollment-codes", methods=["GET"]) + def _admin_enrollment_codes(): + payload, status = service.list_enrollment_codes(request.args.get("status")) + return jsonify(payload), status + + @blueprint.route("/api/admin/enrollment-codes", methods=["POST"]) + def _admin_create_enrollment_code(): + data = request.get_json(force=True, silent=True) or {} + ttl_hours = int(data.get("ttl_hours") or 1) + max_uses_value = data.get("max_uses") + if max_uses_value is None: + max_uses_value = data.get("allowed_uses") + try: + max_uses = int(max_uses_value) if max_uses_value is not None else 2 + except Exception: + max_uses = 2 + payload, status = service.create_enrollment_code(ttl_hours, max_uses) + return jsonify(payload), status + + @blueprint.route("/api/admin/enrollment-codes/", methods=["DELETE"]) + def _admin_delete_enrollment_code(code_id: str): + payload, status = service.delete_enrollment_code(code_id) + return jsonify(payload), status + + @blueprint.route("/api/admin/device-approvals", methods=["GET"]) + def _admin_list_device_approvals(): + payload, status = service.list_device_approvals(request.args.get("status")) + return jsonify(payload), status + + @blueprint.route("/api/admin/device-approvals//approve", methods=["POST"]) + def _admin_approve_device(approval_id: str): + data = request.get_json(force=True, silent=True) or {} + payload, status = service.approve_device(approval_id, data) + return jsonify(payload), status + + @blueprint.route("/api/admin/device-approvals//deny", methods=["POST"]) + def _admin_deny_device(approval_id: str): + payload, status = service.deny_device(approval_id) + return jsonify(payload), status + + app.register_blueprint(blueprint) + adapters.context.logger.info("Engine registered API group 'devices.admin'.") + diff --git a/Data/Engine/services/API/devices/management.py b/Data/Engine/services/API/devices/management.py index 9434f1d7..38b14fd3 100644 --- a/Data/Engine/services/API/devices/management.py +++ b/Data/Engine/services/API/devices/management.py @@ -24,7 +24,7 @@ except ImportError: # pragma: no cover - fallback for minimal test environments """Stand-in exception when the requests module is unavailable.""" def get(self, *args: Any, **kwargs: Any) -> Any: - raise RuntimeError("The 'requests' library is required for repository hash lookups.") + raise self.RequestException("The 'requests' library is required for repository hash lookups.") requests = _RequestsStub() # type: ignore @@ -83,6 +83,16 @@ def _is_internal_request(remote_addr: Optional[str]) -> bool: return False +def _row_to_site(row: Tuple[Any, ...]) -> Dict[str, Any]: + return { + "id": row[0], + "name": row[1], + "description": row[2] or "", + "created_at": row[3] or 0, + "device_count": row[4] or 0, + } + + class RepositoryHashCache: """Lightweight GitHub head cache with on-disk persistence.""" @@ -338,6 +348,14 @@ class DeviceManagementService: return {"error": "unauthorized"}, 401 return None + def _require_admin(self) -> Optional[Tuple[Dict[str, Any], int]]: + user = self._current_user() + if not user: + return {"error": "unauthorized"}, 401 + if (user.get("role") or "").lower() != "admin": + return {"error": "forbidden"}, 403 + return None + def _build_device_payload( self, row: Tuple[Any, ...], @@ -737,6 +755,224 @@ class DeviceManagementService: finally: conn.close() + # ------------------------------------------------------------------ + # Site management helpers + # ------------------------------------------------------------------ + + def list_sites(self) -> Tuple[Dict[str, Any], int]: + conn = self._db_conn() + try: + cur = conn.cursor() + cur.execute( + """ + SELECT s.id, + s.name, + s.description, + s.created_at, + COALESCE(ds.cnt, 0) AS device_count + FROM sites AS s + LEFT JOIN ( + SELECT site_id, COUNT(*) AS cnt + FROM device_sites + GROUP BY site_id + ) AS ds ON ds.site_id = s.id + ORDER BY LOWER(s.name) ASC + """ + ) + rows = cur.fetchall() + sites = [_row_to_site(row) for row in rows] + return {"sites": sites}, 200 + except Exception as exc: + self.logger.debug("Failed to list sites", exc_info=True) + return {"error": str(exc)}, 500 + finally: + conn.close() + + def create_site(self, name: str, description: str) -> Tuple[Dict[str, Any], int]: + if not name: + return {"error": "name is required"}, 400 + now = int(time.time()) + conn = self._db_conn() + try: + cur = conn.cursor() + cur.execute( + "INSERT INTO sites(name, description, created_at) VALUES (?, ?, ?)", + (name, description, now), + ) + site_id = cur.lastrowid + conn.commit() + cur.execute( + "SELECT id, name, description, created_at, 0 FROM sites WHERE id = ?", + (site_id,), + ) + row = cur.fetchone() + if not row: + return {"error": "creation_failed"}, 500 + return _row_to_site(row), 201 + except sqlite3.IntegrityError: + conn.rollback() + return {"error": "name already exists"}, 409 + except Exception as exc: + conn.rollback() + self.logger.debug("Failed to create site", exc_info=True) + return {"error": str(exc)}, 500 + finally: + conn.close() + + def delete_sites(self, ids: List[Any]) -> Tuple[Dict[str, Any], int]: + if not isinstance(ids, list) or not all(isinstance(x, (int, str)) for x in ids): + return {"error": "ids must be a list"}, 400 + norm_ids: List[int] = [] + for value in ids: + try: + norm_ids.append(int(value)) + except Exception: + continue + if not norm_ids: + return {"status": "ok", "deleted": 0}, 200 + conn = self._db_conn() + try: + cur = conn.cursor() + placeholders = ",".join("?" * len(norm_ids)) + cur.execute( + f"DELETE FROM device_sites WHERE site_id IN ({placeholders})", + tuple(norm_ids), + ) + cur.execute( + f"DELETE FROM sites WHERE id IN ({placeholders})", + tuple(norm_ids), + ) + deleted = cur.rowcount + conn.commit() + return {"status": "ok", "deleted": deleted}, 200 + except Exception as exc: + conn.rollback() + self.logger.debug("Failed to delete sites", exc_info=True) + return {"error": str(exc)}, 500 + finally: + conn.close() + + def sites_device_map(self, hostnames: Optional[str]) -> Tuple[Dict[str, Any], int]: + filter_set: set[str] = set() + if hostnames: + for part in hostnames.split(","): + candidate = part.strip() + if candidate: + filter_set.add(candidate) + conn = self._db_conn() + try: + cur = conn.cursor() + if filter_set: + placeholders = ",".join("?" * len(filter_set)) + cur.execute( + f""" + SELECT ds.device_hostname, s.id, s.name + FROM device_sites ds + JOIN sites s ON s.id = ds.site_id + WHERE ds.device_hostname IN ({placeholders}) + """, + tuple(filter_set), + ) + else: + cur.execute( + """ + SELECT ds.device_hostname, s.id, s.name + FROM device_sites ds + JOIN sites s ON s.id = ds.site_id + """ + ) + mapping: Dict[str, Dict[str, Any]] = {} + for hostname, site_id, site_name in cur.fetchall(): + mapping[str(hostname)] = {"site_id": site_id, "site_name": site_name} + return {"mapping": mapping}, 200 + except Exception as exc: + self.logger.debug("Failed to build site device map", exc_info=True) + return {"error": str(exc)}, 500 + finally: + conn.close() + + def assign_devices(self, site_id: Any, hostnames: List[str]) -> Tuple[Dict[str, Any], int]: + try: + site_id_int = int(site_id) + except Exception: + return {"error": "invalid site_id"}, 400 + if not isinstance(hostnames, list) or not all(isinstance(h, str) and h.strip() for h in hostnames): + return {"error": "hostnames must be a list of strings"}, 400 + now = int(time.time()) + conn = self._db_conn() + try: + cur = conn.cursor() + cur.execute("SELECT 1 FROM sites WHERE id = ?", (site_id_int,)) + if not cur.fetchone(): + return {"error": "site not found"}, 404 + for hostname in hostnames: + hn = hostname.strip() + if not hn: + continue + cur.execute( + """ + INSERT INTO device_sites(device_hostname, site_id, assigned_at) + VALUES (?, ?, ?) + ON CONFLICT(device_hostname) + DO UPDATE SET site_id=excluded.site_id, assigned_at=excluded.assigned_at + """, + (hn, site_id_int, now), + ) + conn.commit() + return {"status": "ok"}, 200 + except Exception as exc: + conn.rollback() + self.logger.debug("Failed to assign devices to site", exc_info=True) + return {"error": str(exc)}, 500 + finally: + conn.close() + + def rename_site(self, site_id: Any, new_name: str) -> Tuple[Dict[str, Any], int]: + try: + site_id_int = int(site_id) + except Exception: + return {"error": "invalid id"}, 400 + if not new_name: + return {"error": "new_name is required"}, 400 + conn = self._db_conn() + try: + cur = conn.cursor() + cur.execute("UPDATE sites SET name = ? WHERE id = ?", (new_name, site_id_int)) + if cur.rowcount == 0: + conn.rollback() + return {"error": "site not found"}, 404 + conn.commit() + cur.execute( + """ + SELECT s.id, + s.name, + s.description, + s.created_at, + COALESCE(ds.cnt, 0) AS device_count + FROM sites AS s + LEFT JOIN ( + SELECT site_id, COUNT(*) AS cnt + FROM device_sites + GROUP BY site_id + ) ds ON ds.site_id = s.id + WHERE s.id = ? + """, + (site_id_int,), + ) + row = cur.fetchone() + if not row: + return {"error": "site not found"}, 404 + return _row_to_site(row), 200 + except sqlite3.IntegrityError: + conn.rollback() + return {"error": "name already exists"}, 409 + except Exception as exc: + conn.rollback() + self.logger.debug("Failed to rename site", exc_info=True) + return {"error": str(exc)}, 500 + finally: + conn.close() + def repo_current_hash(self) -> Tuple[Dict[str, Any], int]: repo = (request.args.get("repo") or "bunny-lab-io/Borealis").strip() branch = (request.args.get("branch") or "main").strip() @@ -882,6 +1118,59 @@ def register_management(app, adapters: "LegacyServiceAdapters") -> None: payload, status = service.delete_view(view_id) return jsonify(payload), status + @blueprint.route("/api/sites", methods=["GET"]) + def _sites_list(): + payload, status = service.list_sites() + return jsonify(payload), status + + @blueprint.route("/api/sites", methods=["POST"]) + def _sites_create(): + requirement = service._require_admin() + if requirement: + payload, status = requirement + return jsonify(payload), status + data = request.get_json(silent=True) or {} + name = (data.get("name") or "").strip() + description = (data.get("description") or "").strip() + payload, status = service.create_site(name, description) + return jsonify(payload), status + + @blueprint.route("/api/sites/delete", methods=["POST"]) + def _sites_delete(): + requirement = service._require_admin() + if requirement: + payload, status = requirement + return jsonify(payload), status + data = request.get_json(silent=True) or {} + ids = data.get("ids") or [] + payload, status = service.delete_sites(ids) + return jsonify(payload), status + + @blueprint.route("/api/sites/device_map", methods=["GET"]) + def _sites_device_map(): + payload, status = service.sites_device_map(request.args.get("hostnames")) + return jsonify(payload), status + + @blueprint.route("/api/sites/assign", methods=["POST"]) + def _sites_assign(): + requirement = service._require_admin() + if requirement: + payload, status = requirement + return jsonify(payload), status + data = request.get_json(silent=True) or {} + payload, status = service.assign_devices(data.get("site_id"), data.get("hostnames") or []) + return jsonify(payload), status + + @blueprint.route("/api/sites/rename", methods=["POST"]) + def _sites_rename(): + requirement = service._require_admin() + if requirement: + payload, status = requirement + return jsonify(payload), status + data = request.get_json(silent=True) or {} + payload, status = service.rename_site(data.get("id"), (data.get("new_name") or "").strip()) + return jsonify(payload), status + @blueprint.route("/api/repo/current_hash", methods=["GET"]) def _repo_current_hash(): payload, status = service.repo_current_hash()