From 102e77f676d33cc386481a1591efddaefc7d0918 Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Wed, 22 Oct 2025 19:07:56 -0600 Subject: [PATCH 01/12] Fix Engine static root fallback for legacy WebUI --- Data/Engine/config/environment.py | 5 +++++ Data/Engine/tests/test_config_environment.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/Data/Engine/config/environment.py b/Data/Engine/config/environment.py index 58ccf66..460b4ea 100644 --- a/Data/Engine/config/environment.py +++ b/Data/Engine/config/environment.py @@ -114,10 +114,15 @@ def _resolve_static_root(project_root: Path) -> Path: candidates = ( project_root / "Engine" / "web-interface" / "build", project_root / "Engine" / "web-interface" / "dist", + project_root / "Engine" / "web-interface", project_root / "Data" / "Engine" / "WebUI" / "build", + project_root / "Data" / "Engine" / "WebUI", project_root / "Data" / "Server" / "web-interface" / "build", + project_root / "Data" / "Server" / "web-interface", project_root / "Data" / "Server" / "WebUI" / "build", + project_root / "Data" / "Server" / "WebUI", project_root / "Data" / "WebUI" / "build", + project_root / "Data" / "WebUI", ) for path in candidates: resolved = path.resolve() diff --git a/Data/Engine/tests/test_config_environment.py b/Data/Engine/tests/test_config_environment.py index b5b46a1..c89ef01 100644 --- a/Data/Engine/tests/test_config_environment.py +++ b/Data/Engine/tests/test_config_environment.py @@ -42,3 +42,20 @@ def test_static_root_env_override(tmp_path, monkeypatch): monkeypatch.delenv("BOREALIS_STATIC_ROOT", raising=False) monkeypatch.delenv("BOREALIS_ROOT", raising=False) + + +def test_static_root_falls_back_to_legacy_source(tmp_path, monkeypatch): + """Legacy WebUI source should be served when no build assets exist.""" + + legacy_source = tmp_path / "Data" / "Server" / "WebUI" + legacy_source.mkdir(parents=True) + (legacy_source / "index.html").write_text("", encoding="utf-8") + + monkeypatch.setenv("BOREALIS_ROOT", str(tmp_path)) + monkeypatch.delenv("BOREALIS_STATIC_ROOT", raising=False) + + settings = load_environment() + + assert settings.flask.static_root == legacy_source.resolve() + + monkeypatch.delenv("BOREALIS_ROOT", raising=False) From f361c51a5e8a9e303d9f1707cfcc3dfc86ec36c3 Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Wed, 22 Oct 2025 19:23:38 -0600 Subject: [PATCH 02/12] Implement operator login service and fix static root --- Data/Engine/builders/__init__.py | 10 + Data/Engine/builders/operator_auth.py | 72 ++++++ Data/Engine/config/environment.py | 7 +- Data/Engine/domain/__init__.py | 8 + Data/Engine/domain/operator.py | 51 +++++ Data/Engine/interfaces/http/__init__.py | 3 +- Data/Engine/interfaces/http/auth.py | 165 ++++++++++++++ Data/Engine/repositories/sqlite/__init__.py | 2 + .../repositories/sqlite/user_repository.py | 123 +++++++++++ Data/Engine/requirements.txt | 2 + Data/Engine/services/auth/__init__.py | 14 ++ .../services/auth/operator_auth_service.py | 209 ++++++++++++++++++ Data/Engine/services/container.py | 10 + Data/Engine/tests/test_config_environment.py | 13 ++ .../tests/test_operator_auth_builders.py | 63 ++++++ .../tests/test_operator_auth_service.py | 197 +++++++++++++++++ 16 files changed, 947 insertions(+), 2 deletions(-) create mode 100644 Data/Engine/builders/operator_auth.py create mode 100644 Data/Engine/domain/operator.py create mode 100644 Data/Engine/interfaces/http/auth.py create mode 100644 Data/Engine/repositories/sqlite/user_repository.py create mode 100644 Data/Engine/services/auth/operator_auth_service.py create mode 100644 Data/Engine/tests/test_operator_auth_builders.py create mode 100644 Data/Engine/tests/test_operator_auth_service.py diff --git a/Data/Engine/builders/__init__.py b/Data/Engine/builders/__init__.py index 0f9b02a..6dd2fc9 100644 --- a/Data/Engine/builders/__init__.py +++ b/Data/Engine/builders/__init__.py @@ -8,12 +8,22 @@ from .device_auth import ( RefreshTokenRequest, RefreshTokenRequestBuilder, ) +from .operator_auth import ( + OperatorLoginRequest, + OperatorMFAVerificationRequest, + build_login_request, + build_mfa_request, +) __all__ = [ "DeviceAuthRequest", "DeviceAuthRequestBuilder", "RefreshTokenRequest", "RefreshTokenRequestBuilder", + "OperatorLoginRequest", + "OperatorMFAVerificationRequest", + "build_login_request", + "build_mfa_request", ] try: # pragma: no cover - optional dependency shim diff --git a/Data/Engine/builders/operator_auth.py b/Data/Engine/builders/operator_auth.py new file mode 100644 index 0000000..9153897 --- /dev/null +++ b/Data/Engine/builders/operator_auth.py @@ -0,0 +1,72 @@ +"""Builders for operator authentication payloads.""" + +from __future__ import annotations + +import hashlib +from dataclasses import dataclass +from typing import Mapping + + +@dataclass(frozen=True, slots=True) +class OperatorLoginRequest: + """Normalized operator login credentials.""" + + username: str + password_sha512: str + + +@dataclass(frozen=True, slots=True) +class OperatorMFAVerificationRequest: + """Normalized MFA verification payload.""" + + pending_token: str + code: str + + +def _sha512_hex(raw: str) -> str: + digest = hashlib.sha512() + digest.update(raw.encode("utf-8")) + return digest.hexdigest() + + +def build_login_request(payload: Mapping[str, object]) -> OperatorLoginRequest: + """Validate and normalize the login *payload*.""" + + username = str(payload.get("username") or "").strip() + password_sha512 = str(payload.get("password_sha512") or "").strip().lower() + password = payload.get("password") + + if not username: + raise ValueError("username is required") + + if password_sha512: + normalized_hash = password_sha512 + else: + if not isinstance(password, str) or not password: + raise ValueError("password is required") + normalized_hash = _sha512_hex(password) + + return OperatorLoginRequest(username=username, password_sha512=normalized_hash) + + +def build_mfa_request(payload: Mapping[str, object]) -> OperatorMFAVerificationRequest: + """Validate and normalize the MFA verification *payload*.""" + + pending_token = str(payload.get("pending_token") or "").strip() + raw_code = str(payload.get("code") or "").strip() + digits = "".join(ch for ch in raw_code if ch.isdigit()) + + if not pending_token: + raise ValueError("pending_token is required") + if len(digits) < 6: + raise ValueError("code must contain 6 digits") + + return OperatorMFAVerificationRequest(pending_token=pending_token, code=digits) + + +__all__ = [ + "OperatorLoginRequest", + "OperatorMFAVerificationRequest", + "build_login_request", + "build_mfa_request", +] diff --git a/Data/Engine/config/environment.py b/Data/Engine/config/environment.py index 460b4ea..14cde00 100644 --- a/Data/Engine/config/environment.py +++ b/Data/Engine/config/environment.py @@ -91,7 +91,12 @@ def _resolve_project_root() -> Path: candidate = os.getenv("BOREALIS_ROOT") if candidate: return Path(candidate).expanduser().resolve() - return Path(__file__).resolve().parents[2] + # ``environment.py`` lives under ``Data/Engine/config``. The project + # root is three levels above this module (the repository checkout). The + # previous implementation only walked up two levels which incorrectly + # treated ``Data/`` as the root, breaking all filesystem discovery logic + # that expects peers such as ``Data/Server`` to be available. + return Path(__file__).resolve().parents[3] def _resolve_database_path(project_root: Path) -> Path: diff --git a/Data/Engine/domain/__init__.py b/Data/Engine/domain/__init__.py index 077ce2f..8f36e8e 100644 --- a/Data/Engine/domain/__init__.py +++ b/Data/Engine/domain/__init__.py @@ -26,6 +26,11 @@ from .github import ( # noqa: F401 GitHubTokenStatus, RepoHeadSnapshot, ) +from .operator import ( # noqa: F401 + OperatorAccount, + OperatorLoginSuccess, + OperatorMFAChallenge, +) __all__ = [ "AccessTokenClaims", @@ -45,5 +50,8 @@ __all__ = [ "GitHubRepoRef", "GitHubTokenStatus", "RepoHeadSnapshot", + "OperatorAccount", + "OperatorLoginSuccess", + "OperatorMFAChallenge", "sanitize_service_context", ] diff --git a/Data/Engine/domain/operator.py b/Data/Engine/domain/operator.py new file mode 100644 index 0000000..6e0211c --- /dev/null +++ b/Data/Engine/domain/operator.py @@ -0,0 +1,51 @@ +"""Domain models for operator authentication.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal, Optional + + +@dataclass(frozen=True, slots=True) +class OperatorAccount: + """Snapshot of an operator account stored in SQLite.""" + + username: str + display_name: str + password_sha512: str + role: str + last_login: int + created_at: int + updated_at: int + mfa_enabled: bool + mfa_secret: Optional[str] + + +@dataclass(frozen=True, slots=True) +class OperatorLoginSuccess: + """Successful login payload for the caller.""" + + username: str + role: str + token: str + + +@dataclass(frozen=True, slots=True) +class OperatorMFAChallenge: + """Details describing an in-progress MFA challenge.""" + + username: str + role: str + stage: Literal["setup", "verify"] + pending_token: str + expires_at: int + secret: Optional[str] = None + otpauth_url: Optional[str] = None + qr_image: Optional[str] = None + + +__all__ = [ + "OperatorAccount", + "OperatorLoginSuccess", + "OperatorMFAChallenge", +] diff --git a/Data/Engine/interfaces/http/__init__.py b/Data/Engine/interfaces/http/__init__.py index ce80a82..e388b81 100644 --- a/Data/Engine/interfaces/http/__init__.py +++ b/Data/Engine/interfaces/http/__init__.py @@ -6,7 +6,7 @@ from flask import Flask from Data.Engine.services.container import EngineServiceContainer -from . import admin, agents, enrollment, github, health, job_management, tokens +from . import admin, agents, auth, enrollment, github, health, job_management, tokens _REGISTRARS = ( health.register, @@ -15,6 +15,7 @@ _REGISTRARS = ( tokens.register, job_management.register, github.register, + auth.register, admin.register, ) diff --git a/Data/Engine/interfaces/http/auth.py b/Data/Engine/interfaces/http/auth.py new file mode 100644 index 0000000..d91d9c7 --- /dev/null +++ b/Data/Engine/interfaces/http/auth.py @@ -0,0 +1,165 @@ +"""Operator authentication HTTP endpoints.""" + +from __future__ import annotations + +from typing import Any, Dict + +from flask import Blueprint, Flask, current_app, jsonify, request, session + +from Data.Engine.builders import build_login_request, build_mfa_request +from Data.Engine.domain import OperatorLoginSuccess, OperatorMFAChallenge +from Data.Engine.services.auth import ( + InvalidCredentialsError, + InvalidMFACodeError, + MFAUnavailableError, + MFASessionError, + OperatorAuthService, +) +from Data.Engine.services.container import EngineServiceContainer + + +def _service(container: EngineServiceContainer) -> OperatorAuthService: + return container.operator_auth_service + + +def register(app: Flask, services: EngineServiceContainer) -> None: + bp = Blueprint("auth", __name__) + + @bp.route("/api/auth/login", methods=["POST"]) + def login() -> Any: + payload = request.get_json(silent=True) or {} + try: + login_request = build_login_request(payload) + except ValueError as exc: + return jsonify({"error": str(exc)}), 400 + + service = _service(services) + + try: + result = service.authenticate(login_request) + except InvalidCredentialsError: + return jsonify({"error": "invalid username or password"}), 401 + except MFAUnavailableError as exc: + current_app.logger.error("mfa unavailable: %s", exc) + return jsonify({"error": str(exc)}), 500 + + session.pop("username", None) + session.pop("role", None) + + if isinstance(result, OperatorLoginSuccess): + session.pop("mfa_pending", None) + session["username"] = result.username + session["role"] = result.role or "User" + response = jsonify( + {"status": "ok", "username": result.username, "role": result.role, "token": result.token} + ) + _set_auth_cookie(response, result.token) + return response + + challenge = result + session["mfa_pending"] = { + "username": challenge.username, + "role": challenge.role, + "stage": challenge.stage, + "token": challenge.pending_token, + "expires": challenge.expires_at, + "secret": challenge.secret, + } + session.modified = True + + payload: Dict[str, Any] = { + "status": "mfa_required", + "stage": challenge.stage, + "pending_token": challenge.pending_token, + "username": challenge.username, + "role": challenge.role, + } + if challenge.stage == "setup": + if challenge.secret: + payload["secret"] = challenge.secret + if challenge.otpauth_url: + payload["otpauth_url"] = challenge.otpauth_url + if challenge.qr_image: + payload["qr_image"] = challenge.qr_image + return jsonify(payload) + + @bp.route("/api/auth/logout", methods=["POST"]) + def logout() -> Any: + session.clear() + response = jsonify({"status": "ok"}) + _set_auth_cookie(response, "", expires=0) + return response + + @bp.route("/api/auth/mfa/verify", methods=["POST"]) + def verify_mfa() -> Any: + pending = session.get("mfa_pending") + if not isinstance(pending, dict): + return jsonify({"error": "mfa_pending"}), 401 + + try: + request_payload = build_mfa_request(request.get_json(silent=True) or {}) + except ValueError as exc: + return jsonify({"error": str(exc)}), 400 + + challenge = OperatorMFAChallenge( + username=str(pending.get("username") or ""), + role=str(pending.get("role") or "User"), + stage=str(pending.get("stage") or "verify"), + pending_token=str(pending.get("token") or ""), + expires_at=int(pending.get("expires") or 0), + secret=str(pending.get("secret") or "") or None, + ) + + service = _service(services) + + try: + result = service.verify_mfa(challenge, request_payload) + except MFASessionError as exc: + error_key = str(exc) + status = 401 if error_key != "mfa_not_configured" else 403 + if error_key not in {"expired", "invalid_session", "mfa_not_configured"}: + error_key = "invalid_session" + session.pop("mfa_pending", None) + return jsonify({"error": error_key}), status + except InvalidMFACodeError as exc: + return jsonify({"error": str(exc) or "invalid_code"}), 401 + except MFAUnavailableError as exc: + current_app.logger.error("mfa unavailable: %s", exc) + return jsonify({"error": str(exc)}), 500 + except InvalidCredentialsError: + session.pop("mfa_pending", None) + return jsonify({"error": "invalid username or password"}), 401 + + session.pop("mfa_pending", None) + session["username"] = result.username + session["role"] = result.role or "User" + payload = { + "status": "ok", + "username": result.username, + "role": result.role, + "token": result.token, + } + response = jsonify(payload) + _set_auth_cookie(response, result.token) + return response + + app.register_blueprint(bp) + + +def _set_auth_cookie(response, value: str, *, expires: int | None = None) -> None: + same_site = current_app.config.get("SESSION_COOKIE_SAMESITE", "Lax") + secure = bool(current_app.config.get("SESSION_COOKIE_SECURE", False)) + domain = current_app.config.get("SESSION_COOKIE_DOMAIN", None) + response.set_cookie( + "borealis_auth", + value, + httponly=False, + samesite=same_site, + secure=secure, + domain=domain, + path="/", + expires=expires, + ) + + +__all__ = ["register"] diff --git a/Data/Engine/repositories/sqlite/__init__.py b/Data/Engine/repositories/sqlite/__init__.py index ceef224..869829f 100644 --- a/Data/Engine/repositories/sqlite/__init__.py +++ b/Data/Engine/repositories/sqlite/__init__.py @@ -26,6 +26,7 @@ try: # pragma: no cover - optional dependency shim from .github_repository import SQLiteGitHubRepository from .job_repository import SQLiteJobRepository from .token_repository import SQLiteRefreshTokenRepository + from .user_repository import SQLiteUserRepository except ModuleNotFoundError as exc: # pragma: no cover - triggered when auth deps missing def _missing_repo(*_args: object, **_kwargs: object) -> None: raise ModuleNotFoundError( @@ -44,4 +45,5 @@ else: "SQLiteJobRepository", "SQLiteEnrollmentRepository", "SQLiteGitHubRepository", + "SQLiteUserRepository", ] diff --git a/Data/Engine/repositories/sqlite/user_repository.py b/Data/Engine/repositories/sqlite/user_repository.py new file mode 100644 index 0000000..14708e5 --- /dev/null +++ b/Data/Engine/repositories/sqlite/user_repository.py @@ -0,0 +1,123 @@ +"""SQLite repository for operator accounts.""" + +from __future__ import annotations + +import logging +import sqlite3 +from dataclasses import dataclass +from typing import Optional + +from Data.Engine.domain import OperatorAccount + +from .connection import SQLiteConnectionFactory + + +@dataclass(frozen=True, slots=True) +class _UserRow: + id: str + username: str + display_name: str + password_sha512: str + role: str + last_login: int + created_at: int + updated_at: int + mfa_enabled: int + mfa_secret: str + + +class SQLiteUserRepository: + """Expose CRUD helpers for operator accounts stored in SQLite.""" + + def __init__( + self, + connection_factory: SQLiteConnectionFactory, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._connection_factory = connection_factory + self._log = logger or logging.getLogger("borealis.engine.repositories.users") + + def fetch_by_username(self, username: str) -> Optional[OperatorAccount]: + conn = self._connection_factory() + try: + cur = conn.cursor() + cur.execute( + """ + SELECT + id, + username, + display_name, + COALESCE(password_sha512, '') as password_sha512, + COALESCE(role, 'User') as role, + COALESCE(last_login, 0) as last_login, + COALESCE(created_at, 0) as created_at, + COALESCE(updated_at, 0) as updated_at, + COALESCE(mfa_enabled, 0) as mfa_enabled, + COALESCE(mfa_secret, '') as mfa_secret + FROM users + WHERE LOWER(username) = LOWER(?) + """, + (username,), + ) + row = cur.fetchone() + if not row: + return None + record = _UserRow(*row) + return OperatorAccount( + username=record.username, + display_name=record.display_name or record.username, + password_sha512=(record.password_sha512 or "").lower(), + role=record.role or "User", + last_login=int(record.last_login or 0), + created_at=int(record.created_at or 0), + updated_at=int(record.updated_at or 0), + mfa_enabled=bool(record.mfa_enabled), + mfa_secret=(record.mfa_secret or "") or None, + ) + except sqlite3.Error as exc: # pragma: no cover - defensive + self._log.error("failed to load user %s: %s", username, exc) + return None + finally: + conn.close() + + def update_last_login(self, username: str, timestamp: int) -> None: + conn = self._connection_factory() + try: + cur = conn.cursor() + cur.execute( + """ + UPDATE users + SET last_login = ?, + updated_at = ? + WHERE LOWER(username) = LOWER(?) + """, + (timestamp, timestamp, username), + ) + conn.commit() + except sqlite3.Error as exc: # pragma: no cover - defensive + self._log.warning("failed to update last_login for %s: %s", username, exc) + finally: + conn.close() + + def store_mfa_secret(self, username: str, secret: str, *, timestamp: int) -> None: + conn = self._connection_factory() + try: + cur = conn.cursor() + cur.execute( + """ + UPDATE users + SET mfa_secret = ?, + updated_at = ? + WHERE LOWER(username) = LOWER(?) + """, + (secret, timestamp, username), + ) + conn.commit() + except sqlite3.Error as exc: # pragma: no cover - defensive + self._log.warning("failed to persist MFA secret for %s: %s", username, exc) + finally: + conn.close() + + +__all__ = ["SQLiteUserRepository"] diff --git a/Data/Engine/requirements.txt b/Data/Engine/requirements.txt index 6f7693c..b1f10c7 100644 --- a/Data/Engine/requirements.txt +++ b/Data/Engine/requirements.txt @@ -9,3 +9,5 @@ requests # Auth & security PyJWT[crypto] cryptography +pyotp +qrcode diff --git a/Data/Engine/services/auth/__init__.py b/Data/Engine/services/auth/__init__.py index f24d072..98e66cd 100644 --- a/Data/Engine/services/auth/__init__.py +++ b/Data/Engine/services/auth/__init__.py @@ -11,6 +11,14 @@ from .token_service import ( TokenRefreshErrorCode, TokenService, ) +from .operator_auth_service import ( + InvalidCredentialsError, + InvalidMFACodeError, + MFAUnavailableError, + MFASessionError, + OperatorAuthError, + OperatorAuthService, +) __all__ = [ "DeviceAuthService", @@ -24,4 +32,10 @@ __all__ = [ "TokenRefreshError", "TokenRefreshErrorCode", "TokenService", + "OperatorAuthService", + "OperatorAuthError", + "InvalidCredentialsError", + "InvalidMFACodeError", + "MFAUnavailableError", + "MFASessionError", ] diff --git a/Data/Engine/services/auth/operator_auth_service.py b/Data/Engine/services/auth/operator_auth_service.py new file mode 100644 index 0000000..d3c1163 --- /dev/null +++ b/Data/Engine/services/auth/operator_auth_service.py @@ -0,0 +1,209 @@ +"""Operator authentication service.""" + +from __future__ import annotations + +import base64 +import io +import logging +import os +import time +import uuid +from typing import Optional + +try: # pragma: no cover - optional dependencies mirror legacy server behaviour + import pyotp # type: ignore +except Exception: # pragma: no cover - gracefully degrade when unavailable + pyotp = None # type: ignore + +try: # pragma: no cover - optional dependency + import qrcode # type: ignore +except Exception: # pragma: no cover - gracefully degrade when unavailable + qrcode = None # type: ignore + +from itsdangerous import URLSafeTimedSerializer + +from Data.Engine.builders.operator_auth import ( + OperatorLoginRequest, + OperatorMFAVerificationRequest, +) +from Data.Engine.domain import ( + OperatorAccount, + OperatorLoginSuccess, + OperatorMFAChallenge, +) +from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository + + +class OperatorAuthError(Exception): + """Base class for operator authentication errors.""" + + +class InvalidCredentialsError(OperatorAuthError): + """Raised when username/password verification fails.""" + + +class MFAUnavailableError(OperatorAuthError): + """Raised when MFA functionality is requested but dependencies are missing.""" + + +class InvalidMFACodeError(OperatorAuthError): + """Raised when the submitted MFA code is invalid.""" + + +class MFASessionError(OperatorAuthError): + """Raised when the MFA session state cannot be validated.""" + + +class OperatorAuthService: + """Authenticate operator accounts and manage MFA challenges.""" + + def __init__( + self, + repository: SQLiteUserRepository, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._repository = repository + self._log = logger or logging.getLogger("borealis.engine.services.operator_auth") + + def authenticate( + self, request: OperatorLoginRequest + ) -> OperatorLoginSuccess | OperatorMFAChallenge: + account = self._repository.fetch_by_username(request.username) + if not account: + raise InvalidCredentialsError("invalid username or password") + + if not self._password_matches(account, request.password_sha512): + raise InvalidCredentialsError("invalid username or password") + + if not account.mfa_enabled: + return self._finalize_login(account) + + stage = "verify" if account.mfa_secret else "setup" + return self._build_mfa_challenge(account, stage) + + def verify_mfa( + self, + challenge: OperatorMFAChallenge, + request: OperatorMFAVerificationRequest, + ) -> OperatorLoginSuccess: + now = int(time.time()) + if challenge.pending_token != request.pending_token: + raise MFASessionError("invalid_session") + if challenge.expires_at < now: + raise MFASessionError("expired") + + if challenge.stage == "setup": + secret = (challenge.secret or "").strip() + if not secret: + raise MFASessionError("mfa_not_configured") + totp = self._totp_for_secret(secret) + if not totp.verify(request.code, valid_window=1): + raise InvalidMFACodeError("invalid_code") + self._repository.store_mfa_secret(challenge.username, secret, timestamp=now) + else: + account = self._repository.fetch_by_username(challenge.username) + if not account or not account.mfa_secret: + raise MFASessionError("mfa_not_configured") + totp = self._totp_for_secret(account.mfa_secret) + if not totp.verify(request.code, valid_window=1): + raise InvalidMFACodeError("invalid_code") + + account = self._repository.fetch_by_username(challenge.username) + if not account: + raise InvalidCredentialsError("invalid username or password") + return self._finalize_login(account) + + def issue_token(self, username: str, role: str) -> str: + serializer = self._token_serializer() + payload = {"u": username, "r": role or "User", "ts": int(time.time())} + return serializer.dumps(payload) + + def _finalize_login(self, account: OperatorAccount) -> OperatorLoginSuccess: + now = int(time.time()) + self._repository.update_last_login(account.username, now) + token = self.issue_token(account.username, account.role) + return OperatorLoginSuccess(username=account.username, role=account.role, token=token) + + def _password_matches(self, account: OperatorAccount, provided_hash: str) -> bool: + expected = (account.password_sha512 or "").strip().lower() + candidate = (provided_hash or "").strip().lower() + return bool(expected and candidate and expected == candidate) + + def _build_mfa_challenge( + self, + account: OperatorAccount, + stage: str, + ) -> OperatorMFAChallenge: + now = int(time.time()) + pending_token = uuid.uuid4().hex + secret = None + otpauth_url = None + qr_image = None + + if stage == "setup": + secret = self._generate_totp_secret() + otpauth_url = self._totp_provisioning_uri(secret, account.username) + qr_image = self._totp_qr_data_uri(otpauth_url) if otpauth_url else None + + return OperatorMFAChallenge( + username=account.username, + role=account.role, + stage="verify" if stage == "verify" else "setup", + pending_token=pending_token, + expires_at=now + 300, + secret=secret, + otpauth_url=otpauth_url, + qr_image=qr_image, + ) + + def _token_serializer(self) -> URLSafeTimedSerializer: + secret = os.getenv("BOREALIS_FLASK_SECRET_KEY") or "change-me" + return URLSafeTimedSerializer(secret, salt="borealis-auth") + + def _generate_totp_secret(self) -> str: + if not pyotp: + raise MFAUnavailableError("pyotp is not installed; MFA unavailable") + return pyotp.random_base32() # type: ignore[no-any-return] + + def _totp_for_secret(self, secret: str): + if not pyotp: + raise MFAUnavailableError("pyotp is not installed; MFA unavailable") + normalized = secret.replace(" ", "").strip().upper() + if not normalized: + raise MFASessionError("mfa_not_configured") + return pyotp.TOTP(normalized, digits=6, interval=30) + + def _totp_provisioning_uri(self, secret: str, username: str) -> Optional[str]: + try: + totp = self._totp_for_secret(secret) + except OperatorAuthError: + return None + issuer = os.getenv("BOREALIS_MFA_ISSUER", "Borealis") + try: + return totp.provisioning_uri(name=username, issuer_name=issuer) + except Exception: # pragma: no cover - defensive + return None + + def _totp_qr_data_uri(self, payload: str) -> Optional[str]: + if not payload or qrcode is None: + return None + try: + img = qrcode.make(payload, box_size=6, border=4) + buf = io.BytesIO() + img.save(buf, format="PNG") + encoded = base64.b64encode(buf.getvalue()).decode("ascii") + return f"data:image/png;base64,{encoded}" + except Exception: # pragma: no cover - defensive + self._log.warning("failed to generate MFA QR code", exc_info=True) + return None + + +__all__ = [ + "OperatorAuthService", + "OperatorAuthError", + "InvalidCredentialsError", + "MFAUnavailableError", + "InvalidMFACodeError", + "MFASessionError", +] diff --git a/Data/Engine/services/container.py b/Data/Engine/services/container.py index 756c44f..714b686 100644 --- a/Data/Engine/services/container.py +++ b/Data/Engine/services/container.py @@ -17,10 +17,12 @@ from Data.Engine.repositories.sqlite import ( SQLiteGitHubRepository, SQLiteJobRepository, SQLiteRefreshTokenRepository, + SQLiteUserRepository, ) from Data.Engine.services.auth import ( DeviceAuthService, DPoPValidator, + OperatorAuthService, JWTService, TokenService, load_jwt_service, @@ -46,6 +48,7 @@ class EngineServiceContainer: agent_realtime: AgentRealtimeService scheduler_service: SchedulerService github_service: GitHubService + operator_auth_service: OperatorAuthService def build_service_container( @@ -61,6 +64,7 @@ def build_service_container( enrollment_repo = SQLiteEnrollmentRepository(db_factory, logger=log.getChild("enrollment")) job_repo = SQLiteJobRepository(db_factory, logger=log.getChild("jobs")) github_repo = SQLiteGitHubRepository(db_factory, logger=log.getChild("github_repo")) + user_repo = SQLiteUserRepository(db_factory, logger=log.getChild("users")) jwt_service = load_jwt_service() dpop_validator = DPoPValidator() @@ -106,6 +110,11 @@ def build_service_container( logger=log.getChild("scheduler"), ) + operator_auth_service = OperatorAuthService( + repository=user_repo, + logger=log.getChild("operator_auth"), + ) + github_provider = GitHubArtifactProvider( cache_file=settings.github.cache_file, default_repo=settings.github.default_repo, @@ -129,6 +138,7 @@ def build_service_container( agent_realtime=agent_realtime, scheduler_service=scheduler_service, github_service=github_service, + operator_auth_service=operator_auth_service, ) diff --git a/Data/Engine/tests/test_config_environment.py b/Data/Engine/tests/test_config_environment.py index c89ef01..03ff2ba 100644 --- a/Data/Engine/tests/test_config_environment.py +++ b/Data/Engine/tests/test_config_environment.py @@ -2,6 +2,8 @@ from __future__ import annotations +from pathlib import Path + from Data.Engine.config.environment import load_environment @@ -59,3 +61,14 @@ def test_static_root_falls_back_to_legacy_source(tmp_path, monkeypatch): assert settings.flask.static_root == legacy_source.resolve() monkeypatch.delenv("BOREALIS_ROOT", raising=False) + + +def test_resolve_project_root_defaults_to_repository(monkeypatch): + """The project root should resolve to the repository checkout.""" + + monkeypatch.delenv("BOREALIS_ROOT", raising=False) + from Data.Engine.config import environment as env_module + + expected = Path(env_module.__file__).resolve().parents[3] + + assert env_module._resolve_project_root() == expected diff --git a/Data/Engine/tests/test_operator_auth_builders.py b/Data/Engine/tests/test_operator_auth_builders.py new file mode 100644 index 0000000..80deb14 --- /dev/null +++ b/Data/Engine/tests/test_operator_auth_builders.py @@ -0,0 +1,63 @@ +"""Tests for operator authentication builders.""" + +from __future__ import annotations + +import pytest + +from Data.Engine.builders import ( + OperatorLoginRequest, + OperatorMFAVerificationRequest, + build_login_request, + build_mfa_request, +) + + +def test_build_login_request_uses_explicit_hash(): + payload = {"username": "Admin", "password_sha512": "abc123"} + + result = build_login_request(payload) + + assert isinstance(result, OperatorLoginRequest) + assert result.username == "Admin" + assert result.password_sha512 == "abc123" + + +def test_build_login_request_hashes_plain_password(): + payload = {"username": "user", "password": "secret"} + + result = build_login_request(payload) + + assert isinstance(result, OperatorLoginRequest) + assert result.username == "user" + assert result.password_sha512 + assert result.password_sha512 != "secret" + + +@pytest.mark.parametrize( + "payload", + [ + {"password": "secret"}, + {"username": ""}, + {"username": "user"}, + ], +) +def test_build_login_request_validation(payload): + with pytest.raises(ValueError): + build_login_request(payload) + + +def test_build_mfa_request_normalizes_code(): + payload = {"pending_token": "token", "code": "12 34-56"} + + result = build_mfa_request(payload) + + assert isinstance(result, OperatorMFAVerificationRequest) + assert result.pending_token == "token" + assert result.code == "123456" + + +def test_build_mfa_request_requires_token_and_code(): + with pytest.raises(ValueError): + build_mfa_request({"code": "123"}) + with pytest.raises(ValueError): + build_mfa_request({"pending_token": "token", "code": "12"}) diff --git a/Data/Engine/tests/test_operator_auth_service.py b/Data/Engine/tests/test_operator_auth_service.py new file mode 100644 index 0000000..921441e --- /dev/null +++ b/Data/Engine/tests/test_operator_auth_service.py @@ -0,0 +1,197 @@ +"""Tests for the operator authentication service.""" + +from __future__ import annotations + +import hashlib +import sqlite3 +from pathlib import Path +from typing import Callable + +import pytest + +pyotp = pytest.importorskip("pyotp") + +from Data.Engine.builders import ( + OperatorLoginRequest, + OperatorMFAVerificationRequest, +) +from Data.Engine.repositories.sqlite.connection import connection_factory +from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository +from Data.Engine.services.auth.operator_auth_service import ( + InvalidCredentialsError, + InvalidMFACodeError, + OperatorAuthService, +) + + +def _prepare_db(path: Path) -> Callable[[], sqlite3.Connection]: + conn = sqlite3.connect(path) + conn.execute( + """ + CREATE TABLE users ( + id TEXT PRIMARY KEY, + username TEXT, + display_name TEXT, + password_sha512 TEXT, + role TEXT, + last_login INTEGER, + created_at INTEGER, + updated_at INTEGER, + mfa_enabled INTEGER, + mfa_secret TEXT + ) + """ + ) + conn.commit() + conn.close() + return connection_factory(path) + + +def _insert_user( + factory: Callable[[], sqlite3.Connection], + *, + user_id: str, + username: str, + password_hash: str, + role: str = "Admin", + mfa_enabled: int = 0, + mfa_secret: str = "", +) -> None: + conn = factory() + conn.execute( + """ + INSERT INTO users ( + id, username, display_name, password_sha512, role, + last_login, created_at, updated_at, mfa_enabled, mfa_secret + ) VALUES (?, ?, ?, ?, ?, 0, 0, 0, ?, ?) + """, + (user_id, username, username, password_hash, role, mfa_enabled, mfa_secret), + ) + conn.commit() + conn.close() + + +def test_authenticate_success_updates_last_login(tmp_path): + db_path = tmp_path / "auth.db" + factory = _prepare_db(db_path) + password_hash = hashlib.sha512(b"password").hexdigest() + _insert_user(factory, user_id="1", username="admin", password_hash=password_hash) + + repo = SQLiteUserRepository(factory) + service = OperatorAuthService(repo) + + request = OperatorLoginRequest(username="admin", password_sha512=password_hash) + result = service.authenticate(request) + + assert result.username == "admin" + + conn = factory() + row = conn.execute("SELECT last_login FROM users WHERE username=?", ("admin",)).fetchone() + conn.close() + assert row[0] > 0 + + +def test_authenticate_invalid_credentials(tmp_path): + db_path = tmp_path / "auth.db" + factory = _prepare_db(db_path) + repo = SQLiteUserRepository(factory) + service = OperatorAuthService(repo) + + request = OperatorLoginRequest(username="missing", password_sha512="abc") + with pytest.raises(InvalidCredentialsError): + service.authenticate(request) + + +def test_mfa_verify_flow(tmp_path): + db_path = tmp_path / "auth.db" + factory = _prepare_db(db_path) + secret = pyotp.random_base32() + password_hash = hashlib.sha512(b"password").hexdigest() + _insert_user( + factory, + user_id="1", + username="admin", + password_hash=password_hash, + mfa_enabled=1, + mfa_secret=secret, + ) + + repo = SQLiteUserRepository(factory) + service = OperatorAuthService(repo) + login_request = OperatorLoginRequest(username="admin", password_sha512=password_hash) + + challenge = service.authenticate(login_request) + assert challenge.stage == "verify" + + totp = pyotp.TOTP(secret) + verify_request = OperatorMFAVerificationRequest( + pending_token=challenge.pending_token, + code=totp.now(), + ) + + result = service.verify_mfa(challenge, verify_request) + assert result.username == "admin" + + +def test_mfa_setup_flow_persists_secret(tmp_path): + db_path = tmp_path / "auth.db" + factory = _prepare_db(db_path) + password_hash = hashlib.sha512(b"password").hexdigest() + _insert_user( + factory, + user_id="1", + username="admin", + password_hash=password_hash, + mfa_enabled=1, + mfa_secret="", + ) + + repo = SQLiteUserRepository(factory) + service = OperatorAuthService(repo) + + challenge = service.authenticate(OperatorLoginRequest(username="admin", password_sha512=password_hash)) + assert challenge.stage == "setup" + assert challenge.secret + + totp = pyotp.TOTP(challenge.secret) + verify_request = OperatorMFAVerificationRequest( + pending_token=challenge.pending_token, + code=totp.now(), + ) + + result = service.verify_mfa(challenge, verify_request) + assert result.username == "admin" + + conn = factory() + stored_secret = conn.execute( + "SELECT mfa_secret FROM users WHERE username=?", ("admin",) + ).fetchone()[0] + conn.close() + assert stored_secret + + +def test_mfa_invalid_code_raises(tmp_path): + db_path = tmp_path / "auth.db" + factory = _prepare_db(db_path) + secret = pyotp.random_base32() + password_hash = hashlib.sha512(b"password").hexdigest() + _insert_user( + factory, + user_id="1", + username="admin", + password_hash=password_hash, + mfa_enabled=1, + mfa_secret=secret, + ) + + repo = SQLiteUserRepository(factory) + service = OperatorAuthService(repo) + challenge = service.authenticate(OperatorLoginRequest(username="admin", password_sha512=password_hash)) + + verify_request = OperatorMFAVerificationRequest( + pending_token=challenge.pending_token, + code="000000", + ) + + with pytest.raises(InvalidMFACodeError): + service.verify_mfa(challenge, verify_request) From 7a9feebde54f0f7cb812ba905ecfbe41602f04f4 Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Wed, 22 Oct 2025 19:37:47 -0600 Subject: [PATCH 03/12] Fix static asset fallback and seed default admin --- Data/Engine/bootstrapper.py | 4 + Data/Engine/config/environment.py | 4 + Data/Engine/repositories/sqlite/__init__.py | 3 +- Data/Engine/repositories/sqlite/migrations.py | 90 ++++++++++++++++++- Data/Engine/tests/test_config_environment.py | 17 ++++ Data/Engine/tests/test_sqlite_migrations.py | 51 +++++++++++ 6 files changed, 167 insertions(+), 2 deletions(-) diff --git a/Data/Engine/bootstrapper.py b/Data/Engine/bootstrapper.py index e16b272..d70a60a 100644 --- a/Data/Engine/bootstrapper.py +++ b/Data/Engine/bootstrapper.py @@ -66,6 +66,10 @@ def bootstrap() -> EngineRuntime: else: logger.info("migrations-skipped") + with sqlite_connection.connection_scope(settings.database_path) as conn: + sqlite_migrations.ensure_default_admin(conn) + logger.info("default-admin-ensured") + app = create_app(settings, db_factory=db_factory) services = build_service_container(settings, db_factory=db_factory, logger=logger.getChild("services")) app.extensions["engine_services"] = services diff --git a/Data/Engine/config/environment.py b/Data/Engine/config/environment.py index 14cde00..04211f8 100644 --- a/Data/Engine/config/environment.py +++ b/Data/Engine/config/environment.py @@ -122,6 +122,10 @@ def _resolve_static_root(project_root: Path) -> Path: project_root / "Engine" / "web-interface", project_root / "Data" / "Engine" / "WebUI" / "build", project_root / "Data" / "Engine" / "WebUI", + project_root / "Server" / "web-interface" / "build", + project_root / "Server" / "web-interface", + project_root / "Server" / "WebUI" / "build", + project_root / "Server" / "WebUI", project_root / "Data" / "Server" / "web-interface" / "build", project_root / "Data" / "Server" / "web-interface", project_root / "Data" / "Server" / "WebUI" / "build", diff --git a/Data/Engine/repositories/sqlite/__init__.py b/Data/Engine/repositories/sqlite/__init__.py index 869829f..8b44e59 100644 --- a/Data/Engine/repositories/sqlite/__init__.py +++ b/Data/Engine/repositories/sqlite/__init__.py @@ -9,7 +9,7 @@ from .connection import ( connection_factory, connection_scope, ) -from .migrations import apply_all +from .migrations import apply_all, ensure_default_admin __all__ = [ "SQLiteConnectionFactory", @@ -18,6 +18,7 @@ __all__ = [ "connection_factory", "connection_scope", "apply_all", + "ensure_default_admin", ] try: # pragma: no cover - optional dependency shim diff --git a/Data/Engine/repositories/sqlite/migrations.py b/Data/Engine/repositories/sqlite/migrations.py index 4dddca0..34d3c77 100644 --- a/Data/Engine/repositories/sqlite/migrations.py +++ b/Data/Engine/repositories/sqlite/migrations.py @@ -15,6 +15,10 @@ from typing import List, Optional, Sequence, Tuple DEVICE_TABLE = "devices" +_DEFAULT_ADMIN_USERNAME = "admin" +_DEFAULT_ADMIN_PASSWORD_SHA512 = ( + "e6c83b282aeb2e022844595721cc00bbda47cb24537c1779f9bb84f04039e1676e6ba8573e588da1052510e3aa0a32a9e55879ae22b0c2d62136fc0a3e85f8bb" +) def apply_all(conn: sqlite3.Connection) -> None: @@ -30,6 +34,8 @@ def apply_all(conn: sqlite3.Connection) -> None: _ensure_github_token_table(conn) _ensure_scheduled_jobs_table(conn) _ensure_scheduled_job_run_tables(conn) + _ensure_users_table(conn) + _ensure_default_admin(conn) conn.commit() @@ -504,4 +510,86 @@ def _normalized_guid(value: Optional[str]) -> str: return "" return str(value).strip() -__all__ = ["apply_all"] +def _ensure_users_table(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE NOT NULL, + display_name TEXT, + password_sha512 TEXT NOT NULL, + role TEXT NOT NULL DEFAULT 'Admin', + last_login INTEGER, + created_at INTEGER, + updated_at INTEGER, + mfa_enabled INTEGER NOT NULL DEFAULT 0, + mfa_secret TEXT + ) + """ + ) + + try: + cur.execute("PRAGMA table_info(users)") + columns = [row[1] for row in cur.fetchall()] + if "mfa_enabled" not in columns: + cur.execute("ALTER TABLE users ADD COLUMN mfa_enabled INTEGER NOT NULL DEFAULT 0") + if "mfa_secret" not in columns: + cur.execute("ALTER TABLE users ADD COLUMN mfa_secret TEXT") + except sqlite3.Error: + # Aligning the schema is best-effort; older deployments may lack ALTER + # TABLE privileges but can continue using existing columns. + pass + + +def _ensure_default_admin(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute("SELECT COUNT(*) FROM users WHERE LOWER(role)='admin'") + row = cur.fetchone() + if row and (row[0] or 0): + return + + now = int(datetime.now(timezone.utc).timestamp()) + cur.execute( + "SELECT COUNT(*) FROM users WHERE LOWER(username)=LOWER(?)", + (_DEFAULT_ADMIN_USERNAME,), + ) + existing = cur.fetchone() + if not existing or not (existing[0] or 0): + cur.execute( + """ + INSERT INTO users ( + username, display_name, password_sha512, role, + last_login, created_at, updated_at, mfa_enabled, mfa_secret + ) VALUES (?, ?, ?, 'Admin', 0, ?, ?, 0, NULL) + """, + ( + _DEFAULT_ADMIN_USERNAME, + "Administrator", + _DEFAULT_ADMIN_PASSWORD_SHA512, + now, + now, + ), + ) + else: + cur.execute( + """ + UPDATE users + SET role='Admin', + updated_at=? + WHERE LOWER(username)=LOWER(?) + AND LOWER(role)!='admin' + """, + (now, _DEFAULT_ADMIN_USERNAME), + ) + + +def ensure_default_admin(conn: sqlite3.Connection) -> None: + """Guarantee that at least one admin account exists.""" + + _ensure_users_table(conn) + _ensure_default_admin(conn) + conn.commit() + + +__all__ = ["apply_all", "ensure_default_admin"] diff --git a/Data/Engine/tests/test_config_environment.py b/Data/Engine/tests/test_config_environment.py index 03ff2ba..7631925 100644 --- a/Data/Engine/tests/test_config_environment.py +++ b/Data/Engine/tests/test_config_environment.py @@ -63,6 +63,23 @@ def test_static_root_falls_back_to_legacy_source(tmp_path, monkeypatch): monkeypatch.delenv("BOREALIS_ROOT", raising=False) +def test_static_root_considers_runtime_copy(tmp_path, monkeypatch): + """Runtime Server/WebUI copies should be considered when Data assets are missing.""" + + runtime_source = tmp_path / "Server" / "WebUI" + runtime_source.mkdir(parents=True) + (runtime_source / "index.html").write_text("runtime", encoding="utf-8") + + monkeypatch.setenv("BOREALIS_ROOT", str(tmp_path)) + monkeypatch.delenv("BOREALIS_STATIC_ROOT", raising=False) + + settings = load_environment() + + assert settings.flask.static_root == runtime_source.resolve() + + monkeypatch.delenv("BOREALIS_ROOT", raising=False) + + def test_resolve_project_root_defaults_to_repository(monkeypatch): """The project root should resolve to the repository checkout.""" diff --git a/Data/Engine/tests/test_sqlite_migrations.py b/Data/Engine/tests/test_sqlite_migrations.py index 6361616..56d4b1f 100644 --- a/Data/Engine/tests/test_sqlite_migrations.py +++ b/Data/Engine/tests/test_sqlite_migrations.py @@ -1,3 +1,4 @@ +import hashlib import sqlite3 import unittest @@ -24,6 +25,56 @@ class MigrationTests(unittest.TestCase): self.assertIn("scheduled_jobs", tables) self.assertIn("scheduled_job_runs", tables) self.assertIn("github_token", tables) + self.assertIn("users", tables) + + cursor.execute( + "SELECT username, role, password_sha512 FROM users WHERE LOWER(username)=LOWER(?)", + ("admin",), + ) + row = cursor.fetchone() + self.assertIsNotNone(row) + if row: + self.assertEqual(row[0], "admin") + self.assertEqual(row[1].lower(), "admin") + self.assertEqual(row[2], hashlib.sha512(b"Password").hexdigest()) + finally: + conn.close() + + def test_ensure_default_admin_promotes_existing_user(self) -> None: + conn = sqlite3.connect(":memory:") + try: + conn.execute( + """ + CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE NOT NULL, + display_name TEXT, + password_sha512 TEXT, + role TEXT, + last_login INTEGER, + created_at INTEGER, + updated_at INTEGER, + mfa_enabled INTEGER DEFAULT 0, + mfa_secret TEXT + ) + """ + ) + conn.execute( + "INSERT INTO users (username, display_name, password_sha512, role) VALUES (?, ?, ?, ?)", + ("admin", "Custom", "hash", "user"), + ) + conn.commit() + + migrations.ensure_default_admin(conn) + + cursor = conn.cursor() + cursor.execute( + "SELECT role, password_sha512 FROM users WHERE LOWER(username)=LOWER(?)", + ("admin",), + ) + role, password_hash = cursor.fetchone() + self.assertEqual(role.lower(), "admin") + self.assertEqual(password_hash, "hash") finally: conn.close() From da4cb501e0b7f4af389122dad6cfd282591740ac Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Wed, 22 Oct 2025 19:59:09 -0600 Subject: [PATCH 04/12] Bridge legacy Flask APIs through Engine fallback --- Data/Engine/bootstrapper.py | 15 +- Data/Engine/interfaces/http/__init__.py | 6 +- Data/Engine/server.py | 138 +++++++++++++++++- .../Engine/tests/test_server_legacy_bridge.py | 60 ++++++++ 4 files changed, 212 insertions(+), 7 deletions(-) create mode 100644 Data/Engine/tests/test_server_legacy_bridge.py diff --git a/Data/Engine/bootstrapper.py b/Data/Engine/bootstrapper.py index d70a60a..46d0194 100644 --- a/Data/Engine/bootstrapper.py +++ b/Data/Engine/bootstrapper.py @@ -18,7 +18,7 @@ from .interfaces import ( from .interfaces.eventlet_compat import apply_eventlet_patches from .repositories.sqlite import connection as sqlite_connection from .repositories.sqlite import migrations as sqlite_migrations -from .server import create_app +from .server import attach_legacy_bridge, create_app from .services.container import build_service_container from .services.crypto.certificates import ensure_certificate @@ -71,12 +71,19 @@ def bootstrap() -> EngineRuntime: logger.info("default-admin-ensured") app = create_app(settings, db_factory=db_factory) + attach_legacy_bridge(app, settings, logger=logger) services = build_service_container(settings, db_factory=db_factory, logger=logger.getChild("services")) app.extensions["engine_services"] = services register_http_interfaces(app, services) - socketio = create_socket_server(app, settings.socketio) - register_ws_interfaces(socketio, services) - services.scheduler_service.start(socketio) + + legacy_active = bool(app.config.get("ENGINE_LEGACY_BRIDGE_ACTIVE")) + if legacy_active: + socketio = None + logger.info("legacy-ws-deferred") + else: + socketio = create_socket_server(app, settings.socketio) + register_ws_interfaces(socketio, services) + services.scheduler_service.start(socketio) logger.info("bootstrap-complete") return EngineRuntime( app=app, diff --git a/Data/Engine/interfaces/http/__init__.py b/Data/Engine/interfaces/http/__init__.py index e388b81..c8a568a 100644 --- a/Data/Engine/interfaces/http/__init__.py +++ b/Data/Engine/interfaces/http/__init__.py @@ -26,7 +26,11 @@ def register_http_interfaces(app: Flask, services: EngineServiceContainer) -> No The implementation is intentionally minimal for the initial scaffolding. """ - for registrar in _REGISTRARS: + registrars = list(_REGISTRARS) + if app.config.get("ENGINE_LEGACY_BRIDGE_ACTIVE"): + registrars = [r for r in registrars if r is not job_management.register] + + for registrar in registrars: registrar(app, services) diff --git a/Data/Engine/server.py b/Data/Engine/server.py index 77fb8ea..3cb7376 100644 --- a/Data/Engine/server.py +++ b/Data/Engine/server.py @@ -2,8 +2,11 @@ from __future__ import annotations +import importlib +import logging +import os from pathlib import Path -from typing import Optional +from typing import Any, Iterable, Optional from flask import Flask, request, send_from_directory from flask_cors import CORS @@ -100,4 +103,135 @@ def create_app( return app -__all__ = ["create_app"] +def attach_legacy_bridge( + app: Flask, + settings: EngineSettings, + *, + logger: Optional[logging.Logger] = None, +) -> None: + """Attach the legacy Flask application as a fallback dispatcher. + + Borealis ships a mature API surface inside ``Data/Server/server.py``. The + Engine will eventually supersede it, but during the migration the React + frontend still expects the historical endpoints to exist. This helper + attempts to load the legacy application and wires it as a fallback WSGI + dispatcher so any route the Engine does not yet implement transparently + defers to the proven implementation. + """ + + log = logger or logging.getLogger("borealis.engine.legacy") + + if not _legacy_bridge_enabled(): + log.info("legacy-bridge-disabled") + return + + legacy = _load_legacy_app(settings, app, log) + if legacy is None: + log.warning("legacy-bridge-unavailable") + return + + app.config["ENGINE_LEGACY_BRIDGE_ACTIVE"] = True + app.wsgi_app = _FallbackDispatcher(app.wsgi_app, legacy.wsgi_app) # type: ignore[assignment] + app.extensions["legacy_flask_app"] = legacy + log.info("legacy-bridge-active") + + +def _legacy_bridge_enabled() -> bool: + raw = os.getenv("BOREALIS_ENGINE_ENABLE_LEGACY_BRIDGE", "1") + return raw.strip().lower() in {"1", "true", "yes", "on"} + + +def _load_legacy_app( + settings: EngineSettings, + engine_app: Flask, + logger: logging.Logger, +) -> Optional[Flask]: + try: + legacy_module = importlib.import_module("Data.Server.server") + except Exception as exc: # pragma: no cover - defensive + logger.exception("legacy-import-failed", exc_info=exc) + return None + + legacy_app = getattr(legacy_module, "app", None) + if not isinstance(legacy_app, Flask): + logger.error("legacy-app-missing") + return None + + # Align runtime configuration so both applications share database and + # session state. + try: + setattr(legacy_module, "DB_PATH", str(settings.database_path)) + except Exception: # pragma: no cover - defensive + logger.warning("legacy-db-path-sync-failed", extra={"path": str(settings.database_path)}) + + _synchronise_session_config(engine_app, legacy_app) + + return legacy_app + + +def _synchronise_session_config(engine_app: Flask, legacy_app: Flask) -> None: + legacy_app.secret_key = engine_app.config.get("SECRET_KEY", legacy_app.secret_key) + for key in ( + "SESSION_COOKIE_HTTPONLY", + "SESSION_COOKIE_SECURE", + "SESSION_COOKIE_SAMESITE", + "SESSION_COOKIE_DOMAIN", + ): + value = engine_app.config.get(key) + if value is not None: + legacy_app.config[key] = value + + +class _FallbackDispatcher: + """WSGI dispatcher that retries a secondary app when the primary 404s.""" + + __slots__ = ("_primary", "_fallback", "_retry_statuses") + + def __init__( + self, + primary: Any, + fallback: Any, + *, + retry_statuses: Iterable[int] = (404,), + ) -> None: + self._primary = primary + self._fallback = fallback + self._retry_statuses = {int(status) for status in retry_statuses} + + def __call__(self, environ: dict[str, Any], start_response: Any) -> Iterable[bytes]: + captured_body: list[bytes] = [] + captured_status: dict[str, Any] = {} + + def _capture_start_response(status: str, headers: list[tuple[str, str]], exc_info: Any = None): + captured_status["status"] = status + captured_status["headers"] = headers + captured_status["exc_info"] = exc_info + + def _write(data: bytes) -> None: + captured_body.append(data) + + return _write + + primary_iterable = self._primary(environ, _capture_start_response) + try: + for chunk in primary_iterable: + captured_body.append(chunk) + finally: + close = getattr(primary_iterable, "close", None) + if callable(close): + close() + + status_line = str(captured_status.get("status") or "500 Internal Server Error") + try: + status_code = int(status_line.split()[0]) + except Exception: # pragma: no cover - defensive + status_code = 500 + + if status_code not in self._retry_statuses: + start_response(status_line, captured_status.get("headers", []), captured_status.get("exc_info")) + return captured_body + + return self._fallback(environ, start_response) + + +__all__ = ["attach_legacy_bridge", "create_app"] diff --git a/Data/Engine/tests/test_server_legacy_bridge.py b/Data/Engine/tests/test_server_legacy_bridge.py new file mode 100644 index 0000000..1c3f77f --- /dev/null +++ b/Data/Engine/tests/test_server_legacy_bridge.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import Callable, Iterable + +import pytest + +pytest.importorskip("flask") + +from Data.Engine.server import _FallbackDispatcher + + +def _wsgi_app(status: str, body: bytes) -> Callable: + def _app(environ, start_response): # type: ignore[override] + start_response(status, [("Content-Type", "text/plain"), ("Content-Length", str(len(body)))]) + return [body] + + return _app + + +def _invoke(app: Callable, path: str = "/") -> tuple[str, bytes]: + status_holder: dict[str, str] = {} + body_parts: list[bytes] = [] + + def _start_response(status: str, headers: Iterable[tuple[str, str]], exc_info=None): # type: ignore[override] + status_holder["status"] = status + return body_parts.append + + environ = {"PATH_INFO": path, "REQUEST_METHOD": "GET", "wsgi.input": None} + result = app(environ, _start_response) + try: + for chunk in result: + body_parts.append(chunk) + finally: + close = getattr(result, "close", None) + if callable(close): + close() + + return status_holder.get("status", ""), b"".join(body_parts) + + +def test_fallback_dispatcher_primary_wins() -> None: + primary = _wsgi_app("200 OK", b"engine") + fallback = _wsgi_app("200 OK", b"legacy") + dispatcher = _FallbackDispatcher(primary, fallback) + + status, body = _invoke(dispatcher) + + assert status == "200 OK" + assert body == b"engine" + + +def test_fallback_dispatcher_uses_fallback_on_404() -> None: + primary = _wsgi_app("404 Not Found", b"missing") + fallback = _wsgi_app("200 OK", b"legacy") + dispatcher = _FallbackDispatcher(primary, fallback) + + status, body = _invoke(dispatcher) + + assert status == "200 OK" + assert body == b"legacy" From e1e63ec34698cad6d267efc4a1da8f4e2ca1354e Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Wed, 22 Oct 2025 20:18:09 -0600 Subject: [PATCH 05/12] Remove legacy bridge and expose auth session endpoint --- Data/Engine/bootstrapper.py | 14 +- Data/Engine/interfaces/http/__init__.py | 6 +- Data/Engine/interfaces/http/auth.py | 30 ++++ Data/Engine/server.py | 138 +----------------- .../services/auth/operator_auth_service.py | 29 +++- Data/Engine/tests/test_http_auth.py | 120 +++++++++++++++ .../Engine/tests/test_server_legacy_bridge.py | 60 -------- 7 files changed, 185 insertions(+), 212 deletions(-) create mode 100644 Data/Engine/tests/test_http_auth.py delete mode 100644 Data/Engine/tests/test_server_legacy_bridge.py diff --git a/Data/Engine/bootstrapper.py b/Data/Engine/bootstrapper.py index 46d0194..c613504 100644 --- a/Data/Engine/bootstrapper.py +++ b/Data/Engine/bootstrapper.py @@ -18,7 +18,7 @@ from .interfaces import ( from .interfaces.eventlet_compat import apply_eventlet_patches from .repositories.sqlite import connection as sqlite_connection from .repositories.sqlite import migrations as sqlite_migrations -from .server import attach_legacy_bridge, create_app +from .server import create_app from .services.container import build_service_container from .services.crypto.certificates import ensure_certificate @@ -71,19 +71,13 @@ def bootstrap() -> EngineRuntime: logger.info("default-admin-ensured") app = create_app(settings, db_factory=db_factory) - attach_legacy_bridge(app, settings, logger=logger) services = build_service_container(settings, db_factory=db_factory, logger=logger.getChild("services")) app.extensions["engine_services"] = services register_http_interfaces(app, services) - legacy_active = bool(app.config.get("ENGINE_LEGACY_BRIDGE_ACTIVE")) - if legacy_active: - socketio = None - logger.info("legacy-ws-deferred") - else: - socketio = create_socket_server(app, settings.socketio) - register_ws_interfaces(socketio, services) - services.scheduler_service.start(socketio) + socketio = create_socket_server(app, settings.socketio) + register_ws_interfaces(socketio, services) + services.scheduler_service.start(socketio) logger.info("bootstrap-complete") return EngineRuntime( app=app, diff --git a/Data/Engine/interfaces/http/__init__.py b/Data/Engine/interfaces/http/__init__.py index c8a568a..e388b81 100644 --- a/Data/Engine/interfaces/http/__init__.py +++ b/Data/Engine/interfaces/http/__init__.py @@ -26,11 +26,7 @@ def register_http_interfaces(app: Flask, services: EngineServiceContainer) -> No The implementation is intentionally minimal for the initial scaffolding. """ - registrars = list(_REGISTRARS) - if app.config.get("ENGINE_LEGACY_BRIDGE_ACTIVE"): - registrars = [r for r in registrars if r is not job_management.register] - - for registrar in registrars: + for registrar in _REGISTRARS: registrar(app, services) diff --git a/Data/Engine/interfaces/http/auth.py b/Data/Engine/interfaces/http/auth.py index d91d9c7..1ba5bdb 100644 --- a/Data/Engine/interfaces/http/auth.py +++ b/Data/Engine/interfaces/http/auth.py @@ -90,6 +90,36 @@ def register(app: Flask, services: EngineServiceContainer) -> None: _set_auth_cookie(response, "", expires=0) return response + @bp.route("/api/auth/me", methods=["GET"]) + def me() -> Any: + service = _service(services) + + account = None + username = session.get("username") + if isinstance(username, str) and username: + account = service.fetch_account(username) + + if account is None: + token = request.cookies.get("borealis_auth", "") + if not token: + auth_header = request.headers.get("Authorization", "") + if auth_header.lower().startswith("bearer "): + token = auth_header.split(None, 1)[1] + account = service.resolve_token(token) + if account is not None: + session["username"] = account.username + session["role"] = account.role or "User" + + if account is None: + return jsonify({"error": "not_authenticated"}), 401 + + payload = { + "username": account.username, + "display_name": account.display_name or account.username, + "role": account.role, + } + return jsonify(payload) + @bp.route("/api/auth/mfa/verify", methods=["POST"]) def verify_mfa() -> Any: pending = session.get("mfa_pending") diff --git a/Data/Engine/server.py b/Data/Engine/server.py index 3cb7376..77fb8ea 100644 --- a/Data/Engine/server.py +++ b/Data/Engine/server.py @@ -2,11 +2,8 @@ from __future__ import annotations -import importlib -import logging -import os from pathlib import Path -from typing import Any, Iterable, Optional +from typing import Optional from flask import Flask, request, send_from_directory from flask_cors import CORS @@ -103,135 +100,4 @@ def create_app( return app -def attach_legacy_bridge( - app: Flask, - settings: EngineSettings, - *, - logger: Optional[logging.Logger] = None, -) -> None: - """Attach the legacy Flask application as a fallback dispatcher. - - Borealis ships a mature API surface inside ``Data/Server/server.py``. The - Engine will eventually supersede it, but during the migration the React - frontend still expects the historical endpoints to exist. This helper - attempts to load the legacy application and wires it as a fallback WSGI - dispatcher so any route the Engine does not yet implement transparently - defers to the proven implementation. - """ - - log = logger or logging.getLogger("borealis.engine.legacy") - - if not _legacy_bridge_enabled(): - log.info("legacy-bridge-disabled") - return - - legacy = _load_legacy_app(settings, app, log) - if legacy is None: - log.warning("legacy-bridge-unavailable") - return - - app.config["ENGINE_LEGACY_BRIDGE_ACTIVE"] = True - app.wsgi_app = _FallbackDispatcher(app.wsgi_app, legacy.wsgi_app) # type: ignore[assignment] - app.extensions["legacy_flask_app"] = legacy - log.info("legacy-bridge-active") - - -def _legacy_bridge_enabled() -> bool: - raw = os.getenv("BOREALIS_ENGINE_ENABLE_LEGACY_BRIDGE", "1") - return raw.strip().lower() in {"1", "true", "yes", "on"} - - -def _load_legacy_app( - settings: EngineSettings, - engine_app: Flask, - logger: logging.Logger, -) -> Optional[Flask]: - try: - legacy_module = importlib.import_module("Data.Server.server") - except Exception as exc: # pragma: no cover - defensive - logger.exception("legacy-import-failed", exc_info=exc) - return None - - legacy_app = getattr(legacy_module, "app", None) - if not isinstance(legacy_app, Flask): - logger.error("legacy-app-missing") - return None - - # Align runtime configuration so both applications share database and - # session state. - try: - setattr(legacy_module, "DB_PATH", str(settings.database_path)) - except Exception: # pragma: no cover - defensive - logger.warning("legacy-db-path-sync-failed", extra={"path": str(settings.database_path)}) - - _synchronise_session_config(engine_app, legacy_app) - - return legacy_app - - -def _synchronise_session_config(engine_app: Flask, legacy_app: Flask) -> None: - legacy_app.secret_key = engine_app.config.get("SECRET_KEY", legacy_app.secret_key) - for key in ( - "SESSION_COOKIE_HTTPONLY", - "SESSION_COOKIE_SECURE", - "SESSION_COOKIE_SAMESITE", - "SESSION_COOKIE_DOMAIN", - ): - value = engine_app.config.get(key) - if value is not None: - legacy_app.config[key] = value - - -class _FallbackDispatcher: - """WSGI dispatcher that retries a secondary app when the primary 404s.""" - - __slots__ = ("_primary", "_fallback", "_retry_statuses") - - def __init__( - self, - primary: Any, - fallback: Any, - *, - retry_statuses: Iterable[int] = (404,), - ) -> None: - self._primary = primary - self._fallback = fallback - self._retry_statuses = {int(status) for status in retry_statuses} - - def __call__(self, environ: dict[str, Any], start_response: Any) -> Iterable[bytes]: - captured_body: list[bytes] = [] - captured_status: dict[str, Any] = {} - - def _capture_start_response(status: str, headers: list[tuple[str, str]], exc_info: Any = None): - captured_status["status"] = status - captured_status["headers"] = headers - captured_status["exc_info"] = exc_info - - def _write(data: bytes) -> None: - captured_body.append(data) - - return _write - - primary_iterable = self._primary(environ, _capture_start_response) - try: - for chunk in primary_iterable: - captured_body.append(chunk) - finally: - close = getattr(primary_iterable, "close", None) - if callable(close): - close() - - status_line = str(captured_status.get("status") or "500 Internal Server Error") - try: - status_code = int(status_line.split()[0]) - except Exception: # pragma: no cover - defensive - status_code = 500 - - if status_code not in self._retry_statuses: - start_response(status_line, captured_status.get("headers", []), captured_status.get("exc_info")) - return captured_body - - return self._fallback(environ, start_response) - - -__all__ = ["attach_legacy_bridge", "create_app"] +__all__ = ["create_app"] diff --git a/Data/Engine/services/auth/operator_auth_service.py b/Data/Engine/services/auth/operator_auth_service.py index d3c1163..d9256bf 100644 --- a/Data/Engine/services/auth/operator_auth_service.py +++ b/Data/Engine/services/auth/operator_auth_service.py @@ -20,7 +20,7 @@ try: # pragma: no cover - optional dependency except Exception: # pragma: no cover - gracefully degrade when unavailable qrcode = None # type: ignore -from itsdangerous import URLSafeTimedSerializer +from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer from Data.Engine.builders.operator_auth import ( OperatorLoginRequest, @@ -119,6 +119,33 @@ class OperatorAuthService: payload = {"u": username, "r": role or "User", "ts": int(time.time())} return serializer.dumps(payload) + def resolve_token(self, token: str, *, max_age: int = 30 * 24 * 3600) -> Optional[OperatorAccount]: + """Return the account associated with *token* if it is valid.""" + + token = (token or "").strip() + if not token: + return None + + serializer = self._token_serializer() + try: + payload = serializer.loads(token, max_age=max_age) + except (BadSignature, SignatureExpired): + return None + + username = str(payload.get("u") or "").strip() + if not username: + return None + + return self._repository.fetch_by_username(username) + + def fetch_account(self, username: str) -> Optional[OperatorAccount]: + """Return the operator account for *username* if it exists.""" + + username = (username or "").strip() + if not username: + return None + return self._repository.fetch_by_username(username) + def _finalize_login(self, account: OperatorAccount) -> OperatorLoginSuccess: now = int(time.time()) self._repository.update_last_login(account.username, now) diff --git a/Data/Engine/tests/test_http_auth.py b/Data/Engine/tests/test_http_auth.py new file mode 100644 index 0000000..b5fc39c --- /dev/null +++ b/Data/Engine/tests/test_http_auth.py @@ -0,0 +1,120 @@ +import hashlib +from pathlib import Path + +import pytest + +pytest.importorskip("flask") + +from Data.Engine.config.environment import ( + DatabaseSettings, + EngineSettings, + FlaskSettings, + GitHubSettings, + ServerSettings, + SocketIOSettings, +) +from Data.Engine.interfaces.http import register_http_interfaces +from Data.Engine.repositories.sqlite import connection as sqlite_connection +from Data.Engine.repositories.sqlite import migrations as sqlite_migrations +from Data.Engine.server import create_app +from Data.Engine.services.container import build_service_container + + +@pytest.fixture() +def engine_settings(tmp_path: Path) -> EngineSettings: + project_root = tmp_path + static_root = project_root / "static" + static_root.mkdir() + (static_root / "index.html").write_text("", encoding="utf-8") + + database_path = project_root / "database.db" + + return EngineSettings( + project_root=project_root, + debug=False, + database=DatabaseSettings(path=database_path, apply_migrations=False), + flask=FlaskSettings( + secret_key="test-key", + static_root=static_root, + cors_allowed_origins=("https://localhost",), + ), + socketio=SocketIOSettings(cors_allowed_origins=("https://localhost",)), + server=ServerSettings(host="127.0.0.1", port=5000), + github=GitHubSettings( + default_repo="owner/repo", + default_branch="main", + refresh_interval_seconds=60, + cache_root=project_root / "cache", + ), + ) + + +@pytest.fixture() +def prepared_app(engine_settings: EngineSettings): + settings = engine_settings + settings.github.cache_root.mkdir(exist_ok=True, parents=True) + + db_factory = sqlite_connection.connection_factory(settings.database.path) + with sqlite_connection.connection_scope(settings.database.path) as conn: + sqlite_migrations.apply_all(conn) + + app = create_app(settings, db_factory=db_factory) + services = build_service_container(settings, db_factory=db_factory) + app.extensions["engine_services"] = services + register_http_interfaces(app, services) + app.config.update(TESTING=True) + return app + + +def _login(client) -> dict: + payload = { + "username": "admin", + "password_sha512": hashlib.sha512("Password".encode()).hexdigest(), + } + resp = client.post("/api/auth/login", json=payload) + assert resp.status_code == 200 + data = resp.get_json() + assert isinstance(data, dict) + return data + + +def test_auth_me_returns_session_user(prepared_app): + client = prepared_app.test_client() + + _login(client) + resp = client.get("/api/auth/me") + assert resp.status_code == 200 + body = resp.get_json() + assert body == { + "username": "admin", + "display_name": "admin", + "role": "Admin", + } + + +def test_auth_me_uses_token_when_session_missing(prepared_app): + client = prepared_app.test_client() + login_data = _login(client) + token = login_data.get("token") + assert token + + # New client without session + other_client = prepared_app.test_client() + other_client.set_cookie(server_name="localhost", key="borealis_auth", value=token) + + resp = other_client.get("/api/auth/me") + assert resp.status_code == 200 + body = resp.get_json() + assert body == { + "username": "admin", + "display_name": "admin", + "role": "Admin", + } + + +def test_auth_me_requires_authentication(prepared_app): + client = prepared_app.test_client() + resp = client.get("/api/auth/me") + assert resp.status_code == 401 + body = resp.get_json() + assert body == {"error": "not_authenticated"} diff --git a/Data/Engine/tests/test_server_legacy_bridge.py b/Data/Engine/tests/test_server_legacy_bridge.py deleted file mode 100644 index 1c3f77f..0000000 --- a/Data/Engine/tests/test_server_legacy_bridge.py +++ /dev/null @@ -1,60 +0,0 @@ -from __future__ import annotations - -from typing import Callable, Iterable - -import pytest - -pytest.importorskip("flask") - -from Data.Engine.server import _FallbackDispatcher - - -def _wsgi_app(status: str, body: bytes) -> Callable: - def _app(environ, start_response): # type: ignore[override] - start_response(status, [("Content-Type", "text/plain"), ("Content-Length", str(len(body)))]) - return [body] - - return _app - - -def _invoke(app: Callable, path: str = "/") -> tuple[str, bytes]: - status_holder: dict[str, str] = {} - body_parts: list[bytes] = [] - - def _start_response(status: str, headers: Iterable[tuple[str, str]], exc_info=None): # type: ignore[override] - status_holder["status"] = status - return body_parts.append - - environ = {"PATH_INFO": path, "REQUEST_METHOD": "GET", "wsgi.input": None} - result = app(environ, _start_response) - try: - for chunk in result: - body_parts.append(chunk) - finally: - close = getattr(result, "close", None) - if callable(close): - close() - - return status_holder.get("status", ""), b"".join(body_parts) - - -def test_fallback_dispatcher_primary_wins() -> None: - primary = _wsgi_app("200 OK", b"engine") - fallback = _wsgi_app("200 OK", b"legacy") - dispatcher = _FallbackDispatcher(primary, fallback) - - status, body = _invoke(dispatcher) - - assert status == "200 OK" - assert body == b"engine" - - -def test_fallback_dispatcher_uses_fallback_on_404() -> None: - primary = _wsgi_app("404 Not Found", b"missing") - fallback = _wsgi_app("200 OK", b"legacy") - dispatcher = _FallbackDispatcher(primary, fallback) - - status, body = _invoke(dispatcher) - - assert status == "200 OK" - assert body == b"legacy" From b8e3ea2a6272c013723ae630884da7d555dd364b Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Wed, 22 Oct 2025 20:57:09 -0600 Subject: [PATCH 06/12] Add operator account management API --- Data/Engine/interfaces/http/__init__.py | 3 +- Data/Engine/interfaces/http/users.py | 185 +++++++++++++++ .../repositories/sqlite/user_repository.py | 190 +++++++++++++++- Data/Engine/services/auth/__init__.py | 22 ++ .../services/auth/operator_account_service.py | 211 ++++++++++++++++++ Data/Engine/services/container.py | 7 + Data/Engine/tests/test_http_auth.py | 1 + Data/Engine/tests/test_http_users.py | 120 ++++++++++ .../tests/test_operator_account_service.py | 191 ++++++++++++++++ 9 files changed, 917 insertions(+), 13 deletions(-) create mode 100644 Data/Engine/interfaces/http/users.py create mode 100644 Data/Engine/services/auth/operator_account_service.py create mode 100644 Data/Engine/tests/test_http_users.py create mode 100644 Data/Engine/tests/test_operator_account_service.py diff --git a/Data/Engine/interfaces/http/__init__.py b/Data/Engine/interfaces/http/__init__.py index e388b81..fc88e26 100644 --- a/Data/Engine/interfaces/http/__init__.py +++ b/Data/Engine/interfaces/http/__init__.py @@ -6,7 +6,7 @@ from flask import Flask from Data.Engine.services.container import EngineServiceContainer -from . import admin, agents, auth, enrollment, github, health, job_management, tokens +from . import admin, agents, auth, enrollment, github, health, job_management, tokens, users _REGISTRARS = ( health.register, @@ -17,6 +17,7 @@ _REGISTRARS = ( github.register, auth.register, admin.register, + users.register, ) diff --git a/Data/Engine/interfaces/http/users.py b/Data/Engine/interfaces/http/users.py new file mode 100644 index 0000000..ae1d350 --- /dev/null +++ b/Data/Engine/interfaces/http/users.py @@ -0,0 +1,185 @@ +"""HTTP endpoints for operator account management.""" + +from __future__ import annotations + +from flask import Blueprint, Flask, jsonify, request, session + +from Data.Engine.services.auth import ( + AccountNotFoundError, + CannotModifySelfError, + InvalidPasswordHashError, + InvalidRoleError, + LastAdminError, + LastUserError, + OperatorAccountService, + UsernameAlreadyExistsError, +) +from Data.Engine.services.container import EngineServiceContainer + +blueprint = Blueprint("engine_users", __name__) + + +def register(app: Flask, services: EngineServiceContainer) -> None: + blueprint.services = services # type: ignore[attr-defined] + app.register_blueprint(blueprint) + + +def _services() -> EngineServiceContainer: + svc = getattr(blueprint, "services", None) + if svc is None: # pragma: no cover - defensive + raise RuntimeError("user blueprint not initialized") + return svc + + +def _accounts() -> OperatorAccountService: + return _services().operator_account_service + + +def _require_admin(): + username = session.get("username") + role = (session.get("role") or "").strip().lower() + if not isinstance(username, str) or not username: + return jsonify({"error": "not_authenticated"}), 401 + if role != "admin": + return jsonify({"error": "forbidden"}), 403 + return None + + +def _format_user(record) -> dict[str, object]: + return { + "username": record.username, + "display_name": record.display_name, + "role": record.role, + "last_login": record.last_login, + "created_at": record.created_at, + "updated_at": record.updated_at, + "mfa_enabled": 1 if record.mfa_enabled else 0, + } + + +@blueprint.route("/api/users", methods=["GET"]) +def list_users() -> object: + guard = _require_admin() + if guard: + return guard + + records = _accounts().list_accounts() + return jsonify({"users": [_format_user(record) for record in records]}) + + +@blueprint.route("/api/users", methods=["POST"]) +def create_user() -> object: + guard = _require_admin() + if guard: + return guard + + payload = request.get_json(silent=True) or {} + username = str(payload.get("username") or "").strip() + password_sha512 = str(payload.get("password_sha512") or "").strip() + role = str(payload.get("role") or "User") + display_name = str(payload.get("display_name") or username) + + try: + _accounts().create_account( + username=username, + password_sha512=password_sha512, + role=role, + display_name=display_name, + ) + except UsernameAlreadyExistsError as exc: + return jsonify({"error": str(exc)}), 409 + except (InvalidPasswordHashError, InvalidRoleError) as exc: + return jsonify({"error": str(exc)}), 400 + + return jsonify({"status": "ok"}) + + +@blueprint.route("/api/users/", methods=["DELETE"]) +def delete_user(username: str) -> object: + guard = _require_admin() + if guard: + return guard + + actor = session.get("username") if isinstance(session.get("username"), str) else None + + try: + _accounts().delete_account(username, actor=actor) + except CannotModifySelfError as exc: + return jsonify({"error": str(exc)}), 400 + except LastUserError as exc: + return jsonify({"error": str(exc)}), 400 + except LastAdminError as exc: + return jsonify({"error": str(exc)}), 400 + except AccountNotFoundError as exc: + return jsonify({"error": str(exc)}), 404 + + return jsonify({"status": "ok"}) + + +@blueprint.route("/api/users//reset_password", methods=["POST"]) +def reset_password(username: str) -> object: + guard = _require_admin() + if guard: + return guard + + payload = request.get_json(silent=True) or {} + password_sha512 = str(payload.get("password_sha512") or "").strip() + + try: + _accounts().reset_password(username, password_sha512) + except InvalidPasswordHashError as exc: + return jsonify({"error": str(exc)}), 400 + except AccountNotFoundError as exc: + return jsonify({"error": str(exc)}), 404 + + return jsonify({"status": "ok"}) + + +@blueprint.route("/api/users//role", methods=["POST"]) +def change_role(username: str) -> object: + guard = _require_admin() + if guard: + return guard + + payload = request.get_json(silent=True) or {} + role = str(payload.get("role") or "").strip() + actor = session.get("username") if isinstance(session.get("username"), str) else None + + try: + record = _accounts().change_role(username, role, actor=actor) + except InvalidRoleError as exc: + return jsonify({"error": str(exc)}), 400 + except LastAdminError as exc: + return jsonify({"error": str(exc)}), 400 + except AccountNotFoundError as exc: + return jsonify({"error": str(exc)}), 404 + + if actor and actor.strip().lower() == username.strip().lower(): + session["role"] = record.role + + return jsonify({"status": "ok"}) + + +@blueprint.route("/api/users//mfa", methods=["POST"]) +def update_mfa(username: str) -> object: + guard = _require_admin() + if guard: + return guard + + payload = request.get_json(silent=True) or {} + enabled = bool(payload.get("enabled", False)) + reset_secret = bool(payload.get("reset_secret", False)) + + try: + _accounts().update_mfa(username, enabled=enabled, reset_secret=reset_secret) + except AccountNotFoundError as exc: + return jsonify({"error": str(exc)}), 404 + + actor = session.get("username") if isinstance(session.get("username"), str) else None + if actor and actor.strip().lower() == username.strip().lower() and not enabled: + session.pop("mfa_pending", None) + + return jsonify({"status": "ok"}) + + +__all__ = ["register", "blueprint"] diff --git a/Data/Engine/repositories/sqlite/user_repository.py b/Data/Engine/repositories/sqlite/user_repository.py index 14708e5..9c61a4d 100644 --- a/Data/Engine/repositories/sqlite/user_repository.py +++ b/Data/Engine/repositories/sqlite/user_repository.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging import sqlite3 from dataclasses import dataclass -from typing import Optional +from typing import Iterable, Optional from Data.Engine.domain import OperatorAccount @@ -64,23 +64,175 @@ class SQLiteUserRepository: if not row: return None record = _UserRow(*row) - return OperatorAccount( - username=record.username, - display_name=record.display_name or record.username, - password_sha512=(record.password_sha512 or "").lower(), - role=record.role or "User", - last_login=int(record.last_login or 0), - created_at=int(record.created_at or 0), - updated_at=int(record.updated_at or 0), - mfa_enabled=bool(record.mfa_enabled), - mfa_secret=(record.mfa_secret or "") or None, - ) + return _row_to_account(record) except sqlite3.Error as exc: # pragma: no cover - defensive self._log.error("failed to load user %s: %s", username, exc) return None finally: conn.close() + def list_accounts(self) -> list[OperatorAccount]: + conn = self._connection_factory() + try: + cur = conn.cursor() + cur.execute( + """ + SELECT + id, + username, + display_name, + COALESCE(password_sha512, '') as password_sha512, + COALESCE(role, 'User') as role, + COALESCE(last_login, 0) as last_login, + COALESCE(created_at, 0) as created_at, + COALESCE(updated_at, 0) as updated_at, + COALESCE(mfa_enabled, 0) as mfa_enabled, + COALESCE(mfa_secret, '') as mfa_secret + FROM users + ORDER BY LOWER(username) ASC + """ + ) + rows = [_UserRow(*row) for row in cur.fetchall()] + return [_row_to_account(row) for row in rows] + except sqlite3.Error as exc: # pragma: no cover - defensive + self._log.error("failed to enumerate users: %s", exc) + return [] + finally: + conn.close() + + def create_account( + self, + *, + username: str, + display_name: str, + password_sha512: str, + role: str, + timestamp: int, + ) -> None: + conn = self._connection_factory() + try: + cur = conn.cursor() + cur.execute( + """ + INSERT INTO users ( + username, + display_name, + password_sha512, + role, + created_at, + updated_at + ) VALUES (?, ?, ?, ?, ?, ?) + """, + (username, display_name, password_sha512, role, timestamp, timestamp), + ) + conn.commit() + finally: + conn.close() + + def delete_account(self, username: str) -> bool: + conn = self._connection_factory() + try: + cur = conn.cursor() + cur.execute("DELETE FROM users WHERE LOWER(username) = LOWER(?)", (username,)) + deleted = cur.rowcount > 0 + conn.commit() + return deleted + except sqlite3.Error as exc: # pragma: no cover - defensive + self._log.error("failed to delete user %s: %s", username, exc) + return False + finally: + conn.close() + + def update_password(self, username: str, password_sha512: str, *, timestamp: int) -> bool: + conn = self._connection_factory() + try: + cur = conn.cursor() + cur.execute( + """ + UPDATE users + SET password_sha512 = ?, + updated_at = ? + WHERE LOWER(username) = LOWER(?) + """, + (password_sha512, timestamp, username), + ) + conn.commit() + return cur.rowcount > 0 + except sqlite3.Error as exc: # pragma: no cover - defensive + self._log.error("failed to update password for %s: %s", username, exc) + return False + finally: + conn.close() + + def update_role(self, username: str, role: str, *, timestamp: int) -> bool: + conn = self._connection_factory() + try: + cur = conn.cursor() + cur.execute( + """ + UPDATE users + SET role = ?, + updated_at = ? + WHERE LOWER(username) = LOWER(?) + """, + (role, timestamp, username), + ) + conn.commit() + return cur.rowcount > 0 + except sqlite3.Error as exc: # pragma: no cover - defensive + self._log.error("failed to update role for %s: %s", username, exc) + return False + finally: + conn.close() + + def update_mfa( + self, + username: str, + *, + enabled: bool, + reset_secret: bool, + timestamp: int, + ) -> bool: + conn = self._connection_factory() + try: + cur = conn.cursor() + secret_clause = "mfa_secret = NULL" if reset_secret else None + assignments: list[str] = ["mfa_enabled = ?", "updated_at = ?"] + params: list[object] = [1 if enabled else 0, timestamp] + if secret_clause is not None: + assignments.append(secret_clause) + query = "UPDATE users SET " + ", ".join(assignments) + " WHERE LOWER(username) = LOWER(?)" + params.append(username) + cur.execute(query, tuple(params)) + conn.commit() + return cur.rowcount > 0 + except sqlite3.Error as exc: # pragma: no cover - defensive + self._log.error("failed to update MFA for %s: %s", username, exc) + return False + finally: + conn.close() + + def count_accounts(self) -> int: + return self._scalar("SELECT COUNT(*) FROM users", ()) + + def count_admins(self) -> int: + return self._scalar("SELECT COUNT(*) FROM users WHERE LOWER(role) = 'admin'", ()) + + def _scalar(self, query: str, params: Iterable[object]) -> int: + conn = self._connection_factory() + try: + cur = conn.cursor() + cur.execute(query, tuple(params)) + row = cur.fetchone() + if not row: + return 0 + return int(row[0] or 0) + except sqlite3.Error as exc: # pragma: no cover - defensive + self._log.error("scalar query failed: %s", exc) + return 0 + finally: + conn.close() + def update_last_login(self, username: str, timestamp: int) -> None: conn = self._connection_factory() try: @@ -121,3 +273,17 @@ class SQLiteUserRepository: __all__ = ["SQLiteUserRepository"] + + +def _row_to_account(record: _UserRow) -> OperatorAccount: + return OperatorAccount( + username=record.username, + display_name=record.display_name or record.username, + password_sha512=(record.password_sha512 or "").lower(), + role=record.role or "User", + last_login=int(record.last_login or 0), + created_at=int(record.created_at or 0), + updated_at=int(record.updated_at or 0), + mfa_enabled=bool(record.mfa_enabled), + mfa_secret=(record.mfa_secret or "") or None, + ) diff --git a/Data/Engine/services/auth/__init__.py b/Data/Engine/services/auth/__init__.py index 98e66cd..103eb65 100644 --- a/Data/Engine/services/auth/__init__.py +++ b/Data/Engine/services/auth/__init__.py @@ -11,6 +11,18 @@ from .token_service import ( TokenRefreshErrorCode, TokenService, ) +from .operator_account_service import ( + AccountNotFoundError, + CannotModifySelfError, + InvalidPasswordHashError, + InvalidRoleError, + LastAdminError, + LastUserError, + OperatorAccountError, + OperatorAccountRecord, + OperatorAccountService, + UsernameAlreadyExistsError, +) from .operator_auth_service import ( InvalidCredentialsError, InvalidMFACodeError, @@ -32,6 +44,16 @@ __all__ = [ "TokenRefreshError", "TokenRefreshErrorCode", "TokenService", + "OperatorAccountService", + "OperatorAccountError", + "OperatorAccountRecord", + "UsernameAlreadyExistsError", + "AccountNotFoundError", + "LastAdminError", + "LastUserError", + "CannotModifySelfError", + "InvalidRoleError", + "InvalidPasswordHashError", "OperatorAuthService", "OperatorAuthError", "InvalidCredentialsError", diff --git a/Data/Engine/services/auth/operator_account_service.py b/Data/Engine/services/auth/operator_account_service.py new file mode 100644 index 0000000..b93d27f --- /dev/null +++ b/Data/Engine/services/auth/operator_account_service.py @@ -0,0 +1,211 @@ +"""Operator account management service.""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass +from typing import Optional + +from Data.Engine.domain import OperatorAccount +from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository + + +class OperatorAccountError(Exception): + """Base class for operator account management failures.""" + + +class UsernameAlreadyExistsError(OperatorAccountError): + """Raised when attempting to create an operator with a duplicate username.""" + + +class AccountNotFoundError(OperatorAccountError): + """Raised when the requested operator account cannot be located.""" + + +class LastAdminError(OperatorAccountError): + """Raised when attempting to demote or delete the last remaining admin.""" + + +class LastUserError(OperatorAccountError): + """Raised when attempting to delete the final operator account.""" + + +class CannotModifySelfError(OperatorAccountError): + """Raised when the caller attempts to delete themselves.""" + + +class InvalidRoleError(OperatorAccountError): + """Raised when a role value is invalid.""" + + +class InvalidPasswordHashError(OperatorAccountError): + """Raised when a password hash is malformed.""" + + +@dataclass(frozen=True, slots=True) +class OperatorAccountRecord: + username: str + display_name: str + role: str + last_login: int + created_at: int + updated_at: int + mfa_enabled: bool + + +class OperatorAccountService: + """High-level operations for managing operator accounts.""" + + def __init__( + self, + repository: SQLiteUserRepository, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._repository = repository + self._log = logger or logging.getLogger("borealis.engine.services.operator_accounts") + + def list_accounts(self) -> list[OperatorAccountRecord]: + return [_to_record(account) for account in self._repository.list_accounts()] + + def create_account( + self, + *, + username: str, + password_sha512: str, + role: str, + display_name: Optional[str] = None, + ) -> OperatorAccountRecord: + normalized_role = self._normalize_role(role) + username = (username or "").strip() + password_sha512 = (password_sha512 or "").strip().lower() + display_name = (display_name or username or "").strip() + + if not username or not password_sha512: + raise InvalidPasswordHashError("username and password are required") + if len(password_sha512) != 128: + raise InvalidPasswordHashError("password hash must be 128 hex characters") + + now = int(time.time()) + try: + self._repository.create_account( + username=username, + display_name=display_name or username, + password_sha512=password_sha512, + role=normalized_role, + timestamp=now, + ) + except Exception as exc: # pragma: no cover - sqlite integrity errors are deterministic + import sqlite3 + + if isinstance(exc, sqlite3.IntegrityError): + raise UsernameAlreadyExistsError("username already exists") from exc + raise + + account = self._repository.fetch_by_username(username) + if not account: # pragma: no cover - sanity guard + raise AccountNotFoundError("account creation failed") + return _to_record(account) + + def delete_account(self, username: str, *, actor: Optional[str] = None) -> None: + username = (username or "").strip() + if not username: + raise AccountNotFoundError("invalid username") + + if actor and actor.strip().lower() == username.lower(): + raise CannotModifySelfError("cannot delete yourself") + + total_accounts = self._repository.count_accounts() + if total_accounts <= 1: + raise LastUserError("cannot delete the last user") + + target = self._repository.fetch_by_username(username) + if not target: + raise AccountNotFoundError("user not found") + + if target.role.lower() == "admin" and self._repository.count_admins() <= 1: + raise LastAdminError("cannot delete the last admin") + + if not self._repository.delete_account(username): + raise AccountNotFoundError("user not found") + + def reset_password(self, username: str, password_sha512: str) -> None: + username = (username or "").strip() + password_sha512 = (password_sha512 or "").strip().lower() + if len(password_sha512) != 128: + raise InvalidPasswordHashError("invalid password hash") + + now = int(time.time()) + if not self._repository.update_password(username, password_sha512, timestamp=now): + raise AccountNotFoundError("user not found") + + def change_role(self, username: str, role: str, *, actor: Optional[str] = None) -> OperatorAccountRecord: + username = (username or "").strip() + normalized_role = self._normalize_role(role) + + account = self._repository.fetch_by_username(username) + if not account: + raise AccountNotFoundError("user not found") + + if account.role.lower() == "admin" and normalized_role.lower() != "admin": + if self._repository.count_admins() <= 1: + raise LastAdminError("cannot demote the last admin") + + now = int(time.time()) + if not self._repository.update_role(username, normalized_role, timestamp=now): + raise AccountNotFoundError("user not found") + + updated = self._repository.fetch_by_username(username) + if not updated: # pragma: no cover - guard + raise AccountNotFoundError("user not found") + + record = _to_record(updated) + if actor and actor.strip().lower() == username.lower(): + self._log.info("actor-role-updated", extra={"username": username, "role": record.role}) + return record + + def update_mfa(self, username: str, *, enabled: bool, reset_secret: bool) -> None: + username = (username or "").strip() + if not username: + raise AccountNotFoundError("invalid username") + + now = int(time.time()) + if not self._repository.update_mfa(username, enabled=enabled, reset_secret=reset_secret, timestamp=now): + raise AccountNotFoundError("user not found") + + def fetch_account(self, username: str) -> Optional[OperatorAccountRecord]: + account = self._repository.fetch_by_username(username) + return _to_record(account) if account else None + + def _normalize_role(self, role: str) -> str: + normalized = (role or "").strip().title() or "User" + if normalized not in {"User", "Admin"}: + raise InvalidRoleError("invalid role") + return normalized + + +def _to_record(account: OperatorAccount) -> OperatorAccountRecord: + return OperatorAccountRecord( + username=account.username, + display_name=account.display_name or account.username, + role=account.role or "User", + last_login=int(account.last_login or 0), + created_at=int(account.created_at or 0), + updated_at=int(account.updated_at or 0), + mfa_enabled=bool(account.mfa_enabled), + ) + + +__all__ = [ + "OperatorAccountService", + "OperatorAccountError", + "UsernameAlreadyExistsError", + "AccountNotFoundError", + "LastAdminError", + "LastUserError", + "CannotModifySelfError", + "InvalidRoleError", + "InvalidPasswordHashError", + "OperatorAccountRecord", +] diff --git a/Data/Engine/services/container.py b/Data/Engine/services/container.py index 714b686..621e02a 100644 --- a/Data/Engine/services/container.py +++ b/Data/Engine/services/container.py @@ -22,6 +22,7 @@ from Data.Engine.repositories.sqlite import ( from Data.Engine.services.auth import ( DeviceAuthService, DPoPValidator, + OperatorAccountService, OperatorAuthService, JWTService, TokenService, @@ -49,6 +50,7 @@ class EngineServiceContainer: scheduler_service: SchedulerService github_service: GitHubService operator_auth_service: OperatorAuthService + operator_account_service: OperatorAccountService def build_service_container( @@ -114,6 +116,10 @@ def build_service_container( repository=user_repo, logger=log.getChild("operator_auth"), ) + operator_account_service = OperatorAccountService( + repository=user_repo, + logger=log.getChild("operator_accounts"), + ) github_provider = GitHubArtifactProvider( cache_file=settings.github.cache_file, @@ -139,6 +145,7 @@ def build_service_container( scheduler_service=scheduler_service, github_service=github_service, operator_auth_service=operator_auth_service, + operator_account_service=operator_account_service, ) diff --git a/Data/Engine/tests/test_http_auth.py b/Data/Engine/tests/test_http_auth.py index b5fc39c..2f811c6 100644 --- a/Data/Engine/tests/test_http_auth.py +++ b/Data/Engine/tests/test_http_auth.py @@ -4,6 +4,7 @@ from pathlib import Path import pytest pytest.importorskip("flask") +pytest.importorskip("jwt") from Data.Engine.config.environment import ( DatabaseSettings, diff --git a/Data/Engine/tests/test_http_users.py b/Data/Engine/tests/test_http_users.py new file mode 100644 index 0000000..e30fadb --- /dev/null +++ b/Data/Engine/tests/test_http_users.py @@ -0,0 +1,120 @@ +"""HTTP integration tests for operator account endpoints.""" + +from __future__ import annotations + +import hashlib + +from .test_http_auth import _login, prepared_app + + +def test_list_users_requires_authentication(prepared_app): + client = prepared_app.test_client() + resp = client.get("/api/users") + assert resp.status_code == 401 + + +def test_list_users_returns_accounts(prepared_app): + client = prepared_app.test_client() + _login(client) + + resp = client.get("/api/users") + assert resp.status_code == 200 + payload = resp.get_json() + assert isinstance(payload, dict) + assert "users" in payload + assert any(user["username"] == "admin" for user in payload["users"]) + + +def test_create_user_validates_payload(prepared_app): + client = prepared_app.test_client() + _login(client) + + resp = client.post("/api/users", json={"username": "bob"}) + assert resp.status_code == 400 + + payload = { + "username": "bob", + "password_sha512": hashlib.sha512(b"pw").hexdigest(), + "role": "User", + } + resp = client.post("/api/users", json=payload) + assert resp.status_code == 200 + + # Duplicate username should conflict + resp = client.post("/api/users", json=payload) + assert resp.status_code == 409 + + +def test_delete_user_handles_edge_cases(prepared_app): + client = prepared_app.test_client() + _login(client) + + # cannot delete the only user + resp = client.delete("/api/users/admin") + assert resp.status_code == 400 + + # create another user then delete them successfully + payload = { + "username": "alice", + "password_sha512": hashlib.sha512(b"pw").hexdigest(), + "role": "User", + } + client.post("/api/users", json=payload) + + resp = client.delete("/api/users/alice") + assert resp.status_code == 200 + + +def test_delete_user_prevents_self_deletion(prepared_app): + client = prepared_app.test_client() + _login(client) + + payload = { + "username": "charlie", + "password_sha512": hashlib.sha512(b"pw").hexdigest(), + "role": "User", + } + client.post("/api/users", json=payload) + + resp = client.delete("/api/users/admin") + assert resp.status_code == 400 + + +def test_change_role_updates_session(prepared_app): + client = prepared_app.test_client() + _login(client) + + payload = { + "username": "backup", + "password_sha512": hashlib.sha512(b"pw").hexdigest(), + "role": "Admin", + } + client.post("/api/users", json=payload) + + resp = client.post("/api/users/backup/role", json={"role": "User"}) + assert resp.status_code == 200 + + resp = client.post("/api/users/admin/role", json={"role": "User"}) + assert resp.status_code == 400 + + +def test_reset_password_requires_valid_hash(prepared_app): + client = prepared_app.test_client() + _login(client) + + resp = client.post("/api/users/admin/reset_password", json={"password_sha512": "abc"}) + assert resp.status_code == 400 + + resp = client.post( + "/api/users/admin/reset_password", + json={"password_sha512": hashlib.sha512(b"new").hexdigest()}, + ) + assert resp.status_code == 200 + + +def test_update_mfa_returns_not_found_for_unknown_user(prepared_app): + client = prepared_app.test_client() + _login(client) + + resp = client.post("/api/users/missing/mfa", json={"enabled": True}) + assert resp.status_code == 404 diff --git a/Data/Engine/tests/test_operator_account_service.py b/Data/Engine/tests/test_operator_account_service.py new file mode 100644 index 0000000..1f0118a --- /dev/null +++ b/Data/Engine/tests/test_operator_account_service.py @@ -0,0 +1,191 @@ +"""Tests for the operator account management service.""" + +from __future__ import annotations + +import hashlib +import sqlite3 +from pathlib import Path +from typing import Callable + +import pytest + +pytest.importorskip("jwt") + +from Data.Engine.repositories.sqlite.connection import connection_factory +from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository +from Data.Engine.services.auth.operator_account_service import ( + AccountNotFoundError, + CannotModifySelfError, + InvalidPasswordHashError, + InvalidRoleError, + LastAdminError, + LastUserError, + OperatorAccountService, + UsernameAlreadyExistsError, +) + + +def _prepare_db(path: Path) -> Callable[[], sqlite3.Connection]: + conn = sqlite3.connect(path) + conn.execute( + """ + CREATE TABLE users ( + id TEXT PRIMARY KEY, + username TEXT UNIQUE, + display_name TEXT, + password_sha512 TEXT, + role TEXT, + last_login INTEGER, + created_at INTEGER, + updated_at INTEGER, + mfa_enabled INTEGER, + mfa_secret TEXT + ) + """ + ) + conn.commit() + conn.close() + return connection_factory(path) + + +def _insert_user( + factory: Callable[[], sqlite3.Connection], + *, + user_id: str, + username: str, + password_hash: str, + role: str = "Admin", + mfa_enabled: int = 0, + mfa_secret: str = "", +) -> None: + conn = factory() + conn.execute( + """ + INSERT INTO users ( + id, username, display_name, password_sha512, role, + last_login, created_at, updated_at, mfa_enabled, mfa_secret + ) VALUES (?, ?, ?, ?, ?, 0, 0, 0, ?, ?) + """, + (user_id, username, username, password_hash, role, mfa_enabled, mfa_secret), + ) + conn.commit() + conn.close() + + +def _service(factory: Callable[[], sqlite3.Connection]) -> OperatorAccountService: + repo = SQLiteUserRepository(factory) + return OperatorAccountService(repo) + + +def test_list_accounts_returns_users(tmp_path): + db = tmp_path / "users.db" + factory = _prepare_db(db) + password_hash = hashlib.sha512(b"password").hexdigest() + _insert_user(factory, user_id="1", username="admin", password_hash=password_hash) + + service = _service(factory) + records = service.list_accounts() + + assert len(records) == 1 + assert records[0].username == "admin" + assert records[0].role == "Admin" + + +def test_create_account_enforces_uniqueness(tmp_path): + db = tmp_path / "users.db" + factory = _prepare_db(db) + service = _service(factory) + password_hash = hashlib.sha512(b"pw").hexdigest() + + service.create_account(username="admin", password_sha512=password_hash, role="Admin") + + with pytest.raises(UsernameAlreadyExistsError): + service.create_account(username="admin", password_sha512=password_hash, role="Admin") + + +def test_create_account_validates_password_hash(tmp_path): + db = tmp_path / "users.db" + factory = _prepare_db(db) + service = _service(factory) + + with pytest.raises(InvalidPasswordHashError): + service.create_account(username="user", password_sha512="abc", role="User") + + +def test_delete_account_protects_last_user(tmp_path): + db = tmp_path / "users.db" + factory = _prepare_db(db) + password_hash = hashlib.sha512(b"pw").hexdigest() + _insert_user(factory, user_id="1", username="admin", password_hash=password_hash) + + service = _service(factory) + + with pytest.raises(LastUserError): + service.delete_account("admin") + + +def test_delete_account_prevents_self_deletion(tmp_path): + db = tmp_path / "users.db" + factory = _prepare_db(db) + password_hash = hashlib.sha512(b"pw").hexdigest() + _insert_user(factory, user_id="1", username="admin", password_hash=password_hash) + _insert_user(factory, user_id="2", username="user", password_hash=password_hash, role="User") + + service = _service(factory) + + with pytest.raises(CannotModifySelfError): + service.delete_account("admin", actor="admin") + + +def test_delete_account_prevents_last_admin_removal(tmp_path): + db = tmp_path / "users.db" + factory = _prepare_db(db) + password_hash = hashlib.sha512(b"pw").hexdigest() + _insert_user(factory, user_id="1", username="admin", password_hash=password_hash) + _insert_user(factory, user_id="2", username="user", password_hash=password_hash, role="User") + + service = _service(factory) + + with pytest.raises(LastAdminError): + service.delete_account("admin") + + +def test_change_role_demotes_only_when_valid(tmp_path): + db = tmp_path / "users.db" + factory = _prepare_db(db) + password_hash = hashlib.sha512(b"pw").hexdigest() + _insert_user(factory, user_id="1", username="admin", password_hash=password_hash) + _insert_user(factory, user_id="2", username="backup", password_hash=password_hash) + + service = _service(factory) + service.change_role("backup", "User") + + with pytest.raises(LastAdminError): + service.change_role("admin", "User") + + with pytest.raises(InvalidRoleError): + service.change_role("admin", "invalid") + + +def test_reset_password_validates_hash(tmp_path): + db = tmp_path / "users.db" + factory = _prepare_db(db) + password_hash = hashlib.sha512(b"pw").hexdigest() + _insert_user(factory, user_id="1", username="admin", password_hash=password_hash) + + service = _service(factory) + + with pytest.raises(InvalidPasswordHashError): + service.reset_password("admin", "abc") + + new_hash = hashlib.sha512(b"new").hexdigest() + service.reset_password("admin", new_hash) + + +def test_update_mfa_raises_for_unknown_user(tmp_path): + db = tmp_path / "users.db" + factory = _prepare_db(db) + service = _service(factory) + + with pytest.raises(AccountNotFoundError): + service.update_mfa("missing", enabled=True, reset_secret=False) From d0fa6929b23211d9bb8c2cd13ca8ba0ba6c988fb Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Wed, 22 Oct 2025 23:26:06 -0600 Subject: [PATCH 07/12] Implement admin enrollment APIs --- Data/Engine/domain/device_auth.py | 7 + Data/Engine/domain/enrollment_admin.py | 206 +++++++++++ Data/Engine/interfaces/http/admin.py | 105 +++++- .../sqlite/enrollment_repository.py | 347 +++++++++++++++++- Data/Engine/repositories/sqlite/migrations.py | 70 ++++ .../repositories/sqlite/user_repository.py | 51 +++ Data/Engine/services/__init__.py | 5 + Data/Engine/services/container.py | 9 + Data/Engine/services/enrollment/__init__.py | 54 ++- .../services/enrollment/admin_service.py | 113 ++++++ .../tests/test_enrollment_admin_service.py | 122 ++++++ Data/Engine/tests/test_http_admin.py | 111 ++++++ 12 files changed, 1182 insertions(+), 18 deletions(-) create mode 100644 Data/Engine/domain/enrollment_admin.py create mode 100644 Data/Engine/services/enrollment/admin_service.py create mode 100644 Data/Engine/tests/test_enrollment_admin_service.py create mode 100644 Data/Engine/tests/test_http_admin.py diff --git a/Data/Engine/domain/device_auth.py b/Data/Engine/domain/device_auth.py index d377e52..b4d0c52 100644 --- a/Data/Engine/domain/device_auth.py +++ b/Data/Engine/domain/device_auth.py @@ -18,6 +18,7 @@ __all__ = [ "AccessTokenClaims", "DeviceAuthContext", "sanitize_service_context", + "normalize_guid", ] @@ -73,6 +74,12 @@ class DeviceGuid: return self.value +def normalize_guid(value: Optional[str]) -> str: + """Expose GUID normalization for administrative helpers.""" + + return _normalize_guid(value) + + @dataclass(frozen=True, slots=True) class DeviceFingerprint: """Normalized TLS key fingerprint associated with a device.""" diff --git a/Data/Engine/domain/enrollment_admin.py b/Data/Engine/domain/enrollment_admin.py new file mode 100644 index 0000000..8b5f32e --- /dev/null +++ b/Data/Engine/domain/enrollment_admin.py @@ -0,0 +1,206 @@ +"""Administrative enrollment domain models.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Mapping, Optional + +from Data.Engine.domain.device_auth import DeviceGuid, normalize_guid + +__all__ = [ + "EnrollmentCodeRecord", + "DeviceApprovalRecord", + "HostnameConflict", +] + + +def _parse_iso8601(value: Optional[str]) -> Optional[datetime]: + if not value: + return None + raw = str(value).strip() + if not raw: + return None + try: + dt = datetime.fromisoformat(raw) + except Exception as exc: # pragma: no cover - defensive parsing + raise ValueError(f"invalid ISO8601 timestamp: {raw}") from exc + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + return dt.astimezone(timezone.utc) + + +def _isoformat(value: Optional[datetime]) -> Optional[str]: + if value is None: + return None + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc).isoformat() + + +@dataclass(frozen=True, slots=True) +class EnrollmentCodeRecord: + """Installer code metadata exposed to administrative clients.""" + + record_id: str + code: str + expires_at: datetime + max_uses: int + use_count: int + created_by_user_id: Optional[str] + used_at: Optional[datetime] + used_by_guid: Optional[DeviceGuid] + last_used_at: Optional[datetime] + + @classmethod + def from_row(cls, row: Mapping[str, Any]) -> "EnrollmentCodeRecord": + record_id = str(row.get("id") or "").strip() + code = str(row.get("code") or "").strip() + if not record_id or not code: + raise ValueError("invalid enrollment install code record") + + used_by = row.get("used_by_guid") + used_by_guid = DeviceGuid(str(used_by)) if used_by else None + + return cls( + record_id=record_id, + code=code, + expires_at=_parse_iso8601(row.get("expires_at")) or datetime.now(tz=timezone.utc), + max_uses=int(row.get("max_uses") or 1), + use_count=int(row.get("use_count") or 0), + created_by_user_id=str(row.get("created_by_user_id") or "").strip() or None, + used_at=_parse_iso8601(row.get("used_at")), + used_by_guid=used_by_guid, + last_used_at=_parse_iso8601(row.get("last_used_at")), + ) + + def status(self, *, now: Optional[datetime] = None) -> str: + reference = now or datetime.now(tz=timezone.utc) + if self.use_count >= self.max_uses: + return "used" + if self.expires_at <= reference: + return "expired" + return "active" + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.record_id, + "code": self.code, + "expires_at": _isoformat(self.expires_at), + "max_uses": self.max_uses, + "use_count": self.use_count, + "created_by_user_id": self.created_by_user_id, + "used_at": _isoformat(self.used_at), + "used_by_guid": self.used_by_guid.value if self.used_by_guid else None, + "last_used_at": _isoformat(self.last_used_at), + "status": self.status(), + } + + +@dataclass(frozen=True, slots=True) +class HostnameConflict: + """Existing device details colliding with a pending approval.""" + + guid: Optional[str] + ssl_key_fingerprint: Optional[str] + site_id: Optional[int] + site_name: str + fingerprint_match: bool + requires_prompt: bool + + def to_dict(self) -> dict[str, Any]: + return { + "guid": self.guid, + "ssl_key_fingerprint": self.ssl_key_fingerprint, + "site_id": self.site_id, + "site_name": self.site_name, + "fingerprint_match": self.fingerprint_match, + "requires_prompt": self.requires_prompt, + } + + +@dataclass(frozen=True, slots=True) +class DeviceApprovalRecord: + """Administrative projection of a device approval entry.""" + + record_id: str + reference: str + status: str + claimed_hostname: str + claimed_fingerprint: str + created_at: datetime + updated_at: datetime + enrollment_code_id: Optional[str] + guid: Optional[str] + approved_by_user_id: Optional[str] + approved_by_username: Optional[str] + client_nonce: str + server_nonce: str + hostname_conflict: Optional[HostnameConflict] + alternate_hostname: Optional[str] + conflict_requires_prompt: bool + fingerprint_match: bool + + @classmethod + def from_row( + cls, + row: Mapping[str, Any], + *, + conflict: Optional[HostnameConflict] = None, + alternate_hostname: Optional[str] = None, + fingerprint_match: bool = False, + requires_prompt: bool = False, + ) -> "DeviceApprovalRecord": + record_id = str(row.get("id") or "").strip() + reference = str(row.get("approval_reference") or "").strip() + hostname = str(row.get("hostname_claimed") or "").strip() + fingerprint = str(row.get("ssl_key_fingerprint_claimed") or "").strip().lower() + if not record_id or not reference or not hostname or not fingerprint: + raise ValueError("invalid device approval record") + + guid_raw = normalize_guid(row.get("guid")) or None + + return cls( + record_id=record_id, + reference=reference, + status=str(row.get("status") or "pending").strip().lower(), + claimed_hostname=hostname, + claimed_fingerprint=fingerprint, + created_at=_parse_iso8601(row.get("created_at")) or datetime.now(tz=timezone.utc), + updated_at=_parse_iso8601(row.get("updated_at")) or datetime.now(tz=timezone.utc), + enrollment_code_id=str(row.get("enrollment_code_id") or "").strip() or None, + guid=guid_raw, + approved_by_user_id=str(row.get("approved_by_user_id") or "").strip() or None, + approved_by_username=str(row.get("approved_by_username") or "").strip() or None, + client_nonce=str(row.get("client_nonce") or "").strip(), + server_nonce=str(row.get("server_nonce") or "").strip(), + hostname_conflict=conflict, + alternate_hostname=alternate_hostname, + conflict_requires_prompt=requires_prompt, + fingerprint_match=fingerprint_match, + ) + + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "id": self.record_id, + "approval_reference": self.reference, + "status": self.status, + "hostname_claimed": self.claimed_hostname, + "ssl_key_fingerprint_claimed": self.claimed_fingerprint, + "created_at": _isoformat(self.created_at), + "updated_at": _isoformat(self.updated_at), + "enrollment_code_id": self.enrollment_code_id, + "guid": self.guid, + "approved_by_user_id": self.approved_by_user_id, + "approved_by_username": self.approved_by_username, + "client_nonce": self.client_nonce, + "server_nonce": self.server_nonce, + "conflict_requires_prompt": self.conflict_requires_prompt, + "fingerprint_match": self.fingerprint_match, + } + if self.hostname_conflict is not None: + payload["hostname_conflict"] = self.hostname_conflict.to_dict() + if self.alternate_hostname: + payload["alternate_hostname"] = self.alternate_hostname + return payload + diff --git a/Data/Engine/interfaces/http/admin.py b/Data/Engine/interfaces/http/admin.py index 2da2ec2..30d7fd9 100644 --- a/Data/Engine/interfaces/http/admin.py +++ b/Data/Engine/interfaces/http/admin.py @@ -1,8 +1,8 @@ -"""Administrative HTTP interface placeholders for the Engine.""" +"""Administrative HTTP endpoints for the Borealis Engine.""" from __future__ import annotations -from flask import Blueprint, Flask +from flask import Blueprint, Flask, current_app, jsonify, request, session from Data.Engine.services.container import EngineServiceContainer @@ -11,13 +11,106 @@ blueprint = Blueprint("engine_admin", __name__, url_prefix="/api/admin") def register(app: Flask, _services: EngineServiceContainer) -> None: - """Attach administrative routes to *app*. - - Concrete endpoints will be migrated in subsequent phases. - """ + """Attach administrative routes to *app*.""" if "engine_admin" not in app.blueprints: app.register_blueprint(blueprint) +def _services() -> EngineServiceContainer: + services = current_app.extensions.get("engine_services") + if services is None: # pragma: no cover - defensive + raise RuntimeError("engine services not initialized") + return services + + +def _admin_service(): + return _services().enrollment_admin_service + + +def _require_admin(): + username = session.get("username") + role = (session.get("role") or "").strip().lower() + if not isinstance(username, str) or not username: + return jsonify({"error": "not_authenticated"}), 401 + if role != "admin": + return jsonify({"error": "forbidden"}), 403 + return None + + +@blueprint.route("/enrollment-codes", methods=["GET"]) +def list_enrollment_codes() -> object: + guard = _require_admin() + if guard: + return guard + + status = request.args.get("status") + records = _admin_service().list_install_codes(status=status) + return jsonify({"codes": [record.to_dict() for record in records]}) + + +@blueprint.route("/enrollment-codes", methods=["POST"]) +def create_enrollment_code() -> object: + guard = _require_admin() + if guard: + return guard + + payload = request.get_json(silent=True) or {} + + ttl_value = payload.get("ttl_hours") + if ttl_value is None: + ttl_value = payload.get("ttl") or 1 + try: + ttl_hours = int(ttl_value) + except (TypeError, ValueError): + ttl_hours = 1 + + max_uses_value = payload.get("max_uses") + if max_uses_value is None: + max_uses_value = payload.get("allowed_uses", 2) + try: + max_uses = int(max_uses_value) + except (TypeError, ValueError): + max_uses = 2 + + creator = session.get("username") if isinstance(session.get("username"), str) else None + + try: + record = _admin_service().create_install_code( + ttl_hours=ttl_hours, + max_uses=max_uses, + created_by=creator, + ) + except ValueError as exc: + if str(exc) == "invalid_ttl": + return jsonify({"error": "invalid_ttl"}), 400 + raise + + response = jsonify(record.to_dict()) + response.status_code = 201 + return response + + +@blueprint.route("/enrollment-codes/", methods=["DELETE"]) +def delete_enrollment_code(code_id: str) -> object: + guard = _require_admin() + if guard: + return guard + + if not _admin_service().delete_install_code(code_id): + return jsonify({"error": "not_found"}), 404 + return jsonify({"status": "deleted"}) + + +@blueprint.route("/device-approvals", methods=["GET"]) +def list_device_approvals() -> object: + guard = _require_admin() + if guard: + return guard + + status = request.args.get("status") + records = _admin_service().list_device_approvals(status=status) + return jsonify({"approvals": [record.to_dict() for record in records]}) + + __all__ = ["register", "blueprint"] diff --git a/Data/Engine/repositories/sqlite/enrollment_repository.py b/Data/Engine/repositories/sqlite/enrollment_repository.py index a6549ec..5733af9 100644 --- a/Data/Engine/repositories/sqlite/enrollment_repository.py +++ b/Data/Engine/repositories/sqlite/enrollment_repository.py @@ -5,14 +5,19 @@ from __future__ import annotations import logging from contextlib import closing from datetime import datetime, timezone -from typing import Optional +from typing import Any, List, Optional, Tuple -from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid +from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid, normalize_guid from Data.Engine.domain.device_enrollment import ( EnrollmentApproval, EnrollmentApprovalStatus, EnrollmentCode, ) +from Data.Engine.domain.enrollment_admin import ( + DeviceApprovalRecord, + EnrollmentCodeRecord, + HostnameConflict, +) from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory __all__ = ["SQLiteEnrollmentRepository"] @@ -122,6 +127,158 @@ class SQLiteEnrollmentRepository: self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc) return None + def list_install_codes( + self, + *, + status: Optional[str] = None, + now: Optional[datetime] = None, + ) -> List[EnrollmentCodeRecord]: + reference = now or datetime.now(tz=timezone.utc) + status_filter = (status or "").strip().lower() + params: List[str] = [] + + sql = """ + SELECT id, + code, + expires_at, + created_by_user_id, + used_at, + used_by_guid, + max_uses, + use_count, + last_used_at + FROM enrollment_install_codes + """ + + if status_filter in {"active", "expired", "used"}: + sql += " WHERE " + if status_filter == "active": + sql += "use_count < max_uses AND expires_at > ?" + params.append(self._isoformat(reference)) + elif status_filter == "expired": + sql += "use_count < max_uses AND expires_at <= ?" + params.append(self._isoformat(reference)) + else: # used + sql += "use_count >= max_uses" + + sql += " ORDER BY expires_at ASC" + + rows: List[EnrollmentCodeRecord] = [] + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute(sql, params) + for raw in cur.fetchall(): + record = { + "id": raw[0], + "code": raw[1], + "expires_at": raw[2], + "created_by_user_id": raw[3], + "used_at": raw[4], + "used_by_guid": raw[5], + "max_uses": raw[6], + "use_count": raw[7], + "last_used_at": raw[8], + } + try: + rows.append(EnrollmentCodeRecord.from_row(record)) + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning("invalid enrollment install code row id=%s: %s", record.get("id"), exc) + return rows + + def get_install_code_record(self, record_id: str) -> Optional[EnrollmentCodeRecord]: + identifier = (record_id or "").strip() + if not identifier: + return None + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT id, + code, + expires_at, + created_by_user_id, + used_at, + used_by_guid, + max_uses, + use_count, + last_used_at + FROM enrollment_install_codes + WHERE id = ? + """, + (identifier,), + ) + row = cur.fetchone() + + if not row: + return None + + payload = { + "id": row[0], + "code": row[1], + "expires_at": row[2], + "created_by_user_id": row[3], + "used_at": row[4], + "used_by_guid": row[5], + "max_uses": row[6], + "use_count": row[7], + "last_used_at": row[8], + } + + try: + return EnrollmentCodeRecord.from_row(payload) + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning("invalid enrollment install code record id=%s: %s", identifier, exc) + return None + + def insert_install_code( + self, + *, + record_id: str, + code: str, + expires_at: datetime, + created_by: Optional[str], + max_uses: int, + ) -> EnrollmentCodeRecord: + expires_iso = self._isoformat(expires_at) + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT INTO enrollment_install_codes ( + id, + code, + expires_at, + created_by_user_id, + max_uses, + use_count + ) VALUES (?, ?, ?, ?, ?, 0) + """, + (record_id, code, expires_iso, created_by, max_uses), + ) + conn.commit() + + record = self.get_install_code_record(record_id) + if record is None: + raise RuntimeError("failed to load install code after insert") + return record + + def delete_install_code_if_unused(self, record_id: str) -> bool: + identifier = (record_id or "").strip() + if not identifier: + return False + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + "DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0", + (identifier,), + ) + deleted = cur.rowcount > 0 + conn.commit() + return deleted + def update_install_code_usage( self, record_id: str, @@ -165,6 +322,100 @@ class SQLiteEnrollmentRepository: # ------------------------------------------------------------------ # Device approvals # ------------------------------------------------------------------ + def list_device_approvals( + self, + *, + status: Optional[str] = None, + ) -> List[DeviceApprovalRecord]: + status_filter = (status or "").strip().lower() + params: List[str] = [] + + sql = """ + SELECT + da.id, + da.approval_reference, + da.guid, + da.hostname_claimed, + da.ssl_key_fingerprint_claimed, + da.enrollment_code_id, + da.status, + da.client_nonce, + da.server_nonce, + da.created_at, + da.updated_at, + da.approved_by_user_id, + u.username AS approved_by_username + FROM device_approvals AS da + LEFT JOIN users AS u + ON ( + CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT) + OR LOWER(da.approved_by_user_id) = LOWER(u.username) + ) + """ + + if status_filter and status_filter not in {"all", "*"}: + sql += " WHERE LOWER(da.status) = ?" + params.append(status_filter) + + sql += " ORDER BY da.created_at ASC" + + approvals: List[DeviceApprovalRecord] = [] + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute(sql, params) + rows = cur.fetchall() + + for raw in rows: + record = { + "id": raw[0], + "approval_reference": raw[1], + "guid": raw[2], + "hostname_claimed": raw[3], + "ssl_key_fingerprint_claimed": raw[4], + "enrollment_code_id": raw[5], + "status": raw[6], + "client_nonce": raw[7], + "server_nonce": raw[8], + "created_at": raw[9], + "updated_at": raw[10], + "approved_by_user_id": raw[11], + "approved_by_username": raw[12], + } + + conflict, fingerprint_match, requires_prompt = self._compute_hostname_conflict( + conn, + record.get("hostname_claimed"), + record.get("guid"), + record.get("ssl_key_fingerprint_claimed") or "", + ) + + alternate = None + if conflict and requires_prompt: + alternate = self._suggest_alternate_hostname( + conn, + record.get("hostname_claimed"), + record.get("guid"), + ) + + try: + approvals.append( + DeviceApprovalRecord.from_row( + record, + conflict=conflict, + alternate_hostname=alternate, + fingerprint_match=fingerprint_match, + requires_prompt=requires_prompt, + ) + ) + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning( + "invalid device approval record id=%s: %s", + record.get("id"), + exc, + ) + + return approvals + def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: """Load a device approval using its operator-visible reference.""" @@ -376,6 +627,98 @@ class SQLiteEnrollmentRepository: ) return None + def _compute_hostname_conflict( + self, + conn, + hostname: Optional[str], + pending_guid: Optional[str], + claimed_fp: str, + ) -> Tuple[Optional[HostnameConflict], bool, bool]: + normalized_host = (hostname or "").strip() + if not normalized_host: + return None, False, False + + try: + cur = conn.cursor() + cur.execute( + """ + SELECT d.guid, + d.ssl_key_fingerprint, + ds.site_id, + s.name + FROM devices AS d + LEFT JOIN device_sites AS ds ON ds.device_hostname = d.hostname + LEFT JOIN sites AS s ON s.id = ds.site_id + WHERE d.hostname = ? + """, + (normalized_host,), + ) + row = cur.fetchone() + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning("failed to inspect hostname conflict for %s: %s", normalized_host, exc) + return None, False, False + + if not row: + return None, False, False + + existing_guid = normalize_guid(row[0]) + pending_norm = normalize_guid(pending_guid) + if existing_guid and pending_norm and existing_guid == pending_norm: + return None, False, False + + stored_fp = (row[1] or "").strip().lower() + claimed_fp_normalized = (claimed_fp or "").strip().lower() + fingerprint_match = bool(stored_fp and claimed_fp_normalized and stored_fp == claimed_fp_normalized) + + site_id = None + if row[2] is not None: + try: + site_id = int(row[2]) + except (TypeError, ValueError): # pragma: no cover - defensive + site_id = None + + site_name = str(row[3] or "").strip() + requires_prompt = not fingerprint_match + + conflict = HostnameConflict( + guid=existing_guid or None, + ssl_key_fingerprint=stored_fp or None, + site_id=site_id, + site_name=site_name, + fingerprint_match=fingerprint_match, + requires_prompt=requires_prompt, + ) + + return conflict, fingerprint_match, requires_prompt + + def _suggest_alternate_hostname( + self, + conn, + hostname: Optional[str], + pending_guid: Optional[str], + ) -> Optional[str]: + base = (hostname or "").strip() + if not base: + return None + base = base[:253] + candidate = base + pending_norm = normalize_guid(pending_guid) + suffix = 1 + + cur = conn.cursor() + while True: + cur.execute("SELECT guid FROM devices WHERE hostname = ?", (candidate,)) + row = cur.fetchone() + if not row: + return candidate + existing_guid = normalize_guid(row[0]) + if pending_norm and existing_guid == pending_norm: + return candidate + candidate = f"{base}-{suffix}" + suffix += 1 + if suffix > 50: + return pending_norm or candidate + @staticmethod def _isoformat(value: datetime) -> str: if value.tzinfo is None: diff --git a/Data/Engine/repositories/sqlite/migrations.py b/Data/Engine/repositories/sqlite/migrations.py index 34d3c77..535c78c 100644 --- a/Data/Engine/repositories/sqlite/migrations.py +++ b/Data/Engine/repositories/sqlite/migrations.py @@ -31,6 +31,9 @@ def apply_all(conn: sqlite3.Connection) -> None: _ensure_refresh_token_table(conn) _ensure_install_code_table(conn) _ensure_device_approval_table(conn) + _ensure_device_list_views_table(conn) + _ensure_sites_tables(conn) + _ensure_credentials_table(conn) _ensure_github_token_table(conn) _ensure_scheduled_jobs_table(conn) _ensure_scheduled_job_run_tables(conn) @@ -233,6 +236,73 @@ def _ensure_device_approval_table(conn: sqlite3.Connection) -> None: ) +def _ensure_device_list_views_table(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS device_list_views ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + columns_json TEXT NOT NULL, + filters_json TEXT, + created_at INTEGER, + updated_at INTEGER + ) + """ + ) + + +def _ensure_sites_tables(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS sites ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + description TEXT, + created_at INTEGER + ) + """ + ) + cur.execute( + """ + CREATE TABLE IF NOT EXISTS device_sites ( + device_hostname TEXT UNIQUE NOT NULL, + site_id INTEGER NOT NULL, + assigned_at INTEGER, + FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE CASCADE + ) + """ + ) + + +def _ensure_credentials_table(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS credentials ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + description TEXT, + site_id INTEGER, + credential_type TEXT NOT NULL DEFAULT 'machine', + connection_type TEXT NOT NULL DEFAULT 'ssh', + username TEXT, + password_encrypted BLOB, + private_key_encrypted BLOB, + private_key_passphrase_encrypted BLOB, + become_method TEXT, + become_username TEXT, + become_password_encrypted BLOB, + metadata_json TEXT, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE SET NULL + ) + """ + ) + + def _ensure_github_token_table(conn: sqlite3.Connection) -> None: cur = conn.cursor() cur.execute( diff --git a/Data/Engine/repositories/sqlite/user_repository.py b/Data/Engine/repositories/sqlite/user_repository.py index 9c61a4d..9c3002d 100644 --- a/Data/Engine/repositories/sqlite/user_repository.py +++ b/Data/Engine/repositories/sqlite/user_repository.py @@ -71,6 +71,57 @@ class SQLiteUserRepository: finally: conn.close() + def resolve_identifier(self, username: str) -> Optional[str]: + normalized = (username or "").strip() + if not normalized: + return None + + conn = self._connection_factory() + try: + cur = conn.cursor() + cur.execute( + "SELECT id FROM users WHERE LOWER(username) = LOWER(?)", + (normalized,), + ) + row = cur.fetchone() + if not row: + return None + return str(row[0]) if row[0] is not None else None + except sqlite3.Error as exc: # pragma: no cover - defensive + self._log.error("failed to resolve identifier for %s: %s", username, exc) + return None + finally: + conn.close() + + def username_for_identifier(self, identifier: str) -> Optional[str]: + token = (identifier or "").strip() + if not token: + return None + + conn = self._connection_factory() + try: + cur = conn.cursor() + cur.execute( + """ + SELECT username + FROM users + WHERE CAST(id AS TEXT) = ? + OR LOWER(username) = LOWER(?) + LIMIT 1 + """, + (token, token), + ) + row = cur.fetchone() + if not row: + return None + username = str(row[0] or "").strip() + return username or None + except sqlite3.Error as exc: # pragma: no cover - defensive + self._log.error("failed to resolve username for %s: %s", identifier, exc) + return None + finally: + conn.close() + def list_accounts(self) -> list[OperatorAccount]: conn = self._connection_factory() try: diff --git a/Data/Engine/services/__init__.py b/Data/Engine/services/__init__.py index 3e216c7..9c59917 100644 --- a/Data/Engine/services/__init__.py +++ b/Data/Engine/services/__init__.py @@ -23,6 +23,7 @@ __all__ = [ "SchedulerService", "GitHubService", "GitHubTokenPayload", + "EnrollmentAdminService", ] _LAZY_TARGETS: Dict[str, Tuple[str, str]] = { @@ -43,6 +44,10 @@ _LAZY_TARGETS: Dict[str, Tuple[str, str]] = { "SchedulerService": ("Data.Engine.services.jobs.scheduler_service", "SchedulerService"), "GitHubService": ("Data.Engine.services.github.github_service", "GitHubService"), "GitHubTokenPayload": ("Data.Engine.services.github.github_service", "GitHubTokenPayload"), + "EnrollmentAdminService": ( + "Data.Engine.services.enrollment.admin_service", + "EnrollmentAdminService", + ), } diff --git a/Data/Engine/services/container.py b/Data/Engine/services/container.py index 621e02a..bbb731b 100644 --- a/Data/Engine/services/container.py +++ b/Data/Engine/services/container.py @@ -30,6 +30,7 @@ from Data.Engine.services.auth import ( ) from Data.Engine.services.crypto.signing import ScriptSigner, load_signer from Data.Engine.services.enrollment import EnrollmentService +from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService from Data.Engine.services.enrollment.nonce_cache import NonceCache from Data.Engine.services.github import GitHubService from Data.Engine.services.jobs import SchedulerService @@ -44,6 +45,7 @@ class EngineServiceContainer: device_auth: DeviceAuthService token_service: TokenService enrollment_service: EnrollmentService + enrollment_admin_service: EnrollmentAdminService jwt_service: JWTService dpop_validator: DPoPValidator agent_realtime: AgentRealtimeService @@ -93,6 +95,12 @@ def build_service_container( logger=log.getChild("enrollment"), ) + enrollment_admin_service = EnrollmentAdminService( + repository=enrollment_repo, + user_repository=user_repo, + logger=log.getChild("enrollment_admin"), + ) + device_auth = DeviceAuthService( device_repository=device_repo, jwt_service=jwt_service, @@ -139,6 +147,7 @@ def build_service_container( device_auth=device_auth, token_service=token_service, enrollment_service=enrollment_service, + enrollment_admin_service=enrollment_admin_service, jwt_service=jwt_service, dpop_validator=dpop_validator, agent_realtime=agent_realtime, diff --git a/Data/Engine/services/enrollment/__init__.py b/Data/Engine/services/enrollment/__init__.py index 063cd7b..7277d59 100644 --- a/Data/Engine/services/enrollment/__init__.py +++ b/Data/Engine/services/enrollment/__init__.py @@ -2,20 +2,54 @@ from __future__ import annotations -from .enrollment_service import ( - EnrollmentRequestResult, - EnrollmentService, - EnrollmentStatus, - EnrollmentTokenBundle, - PollingResult, -) -from Data.Engine.domain.device_enrollment import EnrollmentValidationError +from importlib import import_module +from typing import Any __all__ = [ - "EnrollmentRequestResult", "EnrollmentService", + "EnrollmentRequestResult", "EnrollmentStatus", "EnrollmentTokenBundle", - "EnrollmentValidationError", "PollingResult", + "EnrollmentValidationError", + "EnrollmentAdminService", ] + +_LAZY: dict[str, tuple[str, str]] = { + "EnrollmentService": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentService"), + "EnrollmentRequestResult": ( + "Data.Engine.services.enrollment.enrollment_service", + "EnrollmentRequestResult", + ), + "EnrollmentStatus": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentStatus"), + "EnrollmentTokenBundle": ( + "Data.Engine.services.enrollment.enrollment_service", + "EnrollmentTokenBundle", + ), + "PollingResult": ("Data.Engine.services.enrollment.enrollment_service", "PollingResult"), + "EnrollmentValidationError": ( + "Data.Engine.domain.device_enrollment", + "EnrollmentValidationError", + ), + "EnrollmentAdminService": ( + "Data.Engine.services.enrollment.admin_service", + "EnrollmentAdminService", + ), +} + + +def __getattr__(name: str) -> Any: + try: + module_name, attribute = _LAZY[name] + except KeyError as exc: # pragma: no cover - defensive + raise AttributeError(name) from exc + + module = import_module(module_name) + value = getattr(module, attribute) + globals()[name] = value + return value + + +def __dir__() -> list[str]: # pragma: no cover - interactive helper + return sorted(set(__all__)) + diff --git a/Data/Engine/services/enrollment/admin_service.py b/Data/Engine/services/enrollment/admin_service.py new file mode 100644 index 0000000..de8193f --- /dev/null +++ b/Data/Engine/services/enrollment/admin_service.py @@ -0,0 +1,113 @@ +"""Administrative helpers for enrollment workflows.""" + +from __future__ import annotations + +import logging +import secrets +import uuid +from datetime import datetime, timedelta, timezone +from typing import Callable, List, Optional + +from Data.Engine.domain.enrollment_admin import DeviceApprovalRecord, EnrollmentCodeRecord +from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository +from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository + +__all__ = ["EnrollmentAdminService"] + + +class EnrollmentAdminService: + """Expose administrative enrollment operations.""" + + _VALID_TTL_HOURS = {1, 3, 6, 12, 24} + + def __init__( + self, + *, + repository: SQLiteEnrollmentRepository, + user_repository: SQLiteUserRepository, + logger: Optional[logging.Logger] = None, + clock: Optional[Callable[[], datetime]] = None, + ) -> None: + self._repository = repository + self._users = user_repository + self._log = logger or logging.getLogger("borealis.engine.services.enrollment_admin") + self._clock = clock or (lambda: datetime.now(tz=timezone.utc)) + + # ------------------------------------------------------------------ + # Enrollment install codes + # ------------------------------------------------------------------ + def list_install_codes(self, *, status: Optional[str] = None) -> List[EnrollmentCodeRecord]: + return self._repository.list_install_codes(status=status, now=self._clock()) + + def create_install_code( + self, + *, + ttl_hours: int, + max_uses: int, + created_by: Optional[str], + ) -> EnrollmentCodeRecord: + if ttl_hours not in self._VALID_TTL_HOURS: + raise ValueError("invalid_ttl") + + normalized_max = self._normalize_max_uses(max_uses) + + now = self._clock() + expires_at = now + timedelta(hours=ttl_hours) + record_id = str(uuid.uuid4()) + code = self._generate_install_code() + + created_by_identifier = None + if created_by: + created_by_identifier = self._users.resolve_identifier(created_by) + if not created_by_identifier: + created_by_identifier = created_by.strip() or None + + record = self._repository.insert_install_code( + record_id=record_id, + code=code, + expires_at=expires_at, + created_by=created_by_identifier, + max_uses=normalized_max, + ) + + self._log.info( + "install code created id=%s ttl=%sh max_uses=%s", + record.record_id, + ttl_hours, + normalized_max, + ) + + return record + + def delete_install_code(self, record_id: str) -> bool: + deleted = self._repository.delete_install_code_if_unused(record_id) + if deleted: + self._log.info("install code deleted id=%s", record_id) + return deleted + + # ------------------------------------------------------------------ + # Device approvals + # ------------------------------------------------------------------ + def list_device_approvals(self, *, status: Optional[str] = None) -> List[DeviceApprovalRecord]: + return self._repository.list_device_approvals(status=status) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + @staticmethod + def _generate_install_code() -> str: + raw = secrets.token_hex(16).upper() + return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4)) + + @staticmethod + def _normalize_max_uses(value: int) -> int: + try: + count = int(value) + except Exception: + count = 2 + if count < 1: + return 1 + if count > 10: + return 10 + return count + diff --git a/Data/Engine/tests/test_enrollment_admin_service.py b/Data/Engine/tests/test_enrollment_admin_service.py new file mode 100644 index 0000000..9fb3f64 --- /dev/null +++ b/Data/Engine/tests/test_enrollment_admin_service.py @@ -0,0 +1,122 @@ +import base64 +import sqlite3 +from datetime import datetime, timezone + +import pytest + +from Data.Engine.repositories.sqlite import connection as sqlite_connection +from Data.Engine.repositories.sqlite import migrations as sqlite_migrations +from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository +from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository +from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService + + +def _build_service(tmp_path): + db_path = tmp_path / "admin.db" + conn = sqlite3.connect(db_path) + sqlite_migrations.apply_all(conn) + conn.close() + + factory = sqlite_connection.connection_factory(db_path) + enrollment_repo = SQLiteEnrollmentRepository(factory) + user_repo = SQLiteUserRepository(factory) + + fixed_now = datetime(2024, 1, 1, tzinfo=timezone.utc) + service = EnrollmentAdminService( + repository=enrollment_repo, + user_repository=user_repo, + clock=lambda: fixed_now, + ) + return service, factory, fixed_now + + +def test_create_and_list_install_codes(tmp_path): + service, factory, fixed_now = _build_service(tmp_path) + + record = service.create_install_code(ttl_hours=3, max_uses=5, created_by="admin") + assert record.code + assert record.max_uses == 5 + assert record.status(now=fixed_now) == "active" + + records = service.list_install_codes() + assert any(r.record_id == record.record_id for r in records) + + # Invalid TTL should raise + with pytest.raises(ValueError): + service.create_install_code(ttl_hours=2, max_uses=1, created_by=None) + + # Deleting should succeed and remove the record + assert service.delete_install_code(record.record_id) is True + remaining = service.list_install_codes() + assert all(r.record_id != record.record_id for r in remaining) + + +def test_list_device_approvals_includes_conflict(tmp_path): + service, factory, fixed_now = _build_service(tmp_path) + + conn = factory() + cur = conn.cursor() + + cur.execute( + "INSERT INTO sites (name, description, created_at) VALUES (?, ?, ?)", + ("HQ", "Primary site", int(fixed_now.timestamp())), + ) + site_id = cur.lastrowid + + cur.execute( + """ + INSERT INTO devices (guid, hostname, created_at, last_seen, ssl_key_fingerprint, status) + VALUES (?, ?, ?, ?, ?, 'active') + """, + ("11111111-1111-1111-1111-111111111111", "agent-one", int(fixed_now.timestamp()), int(fixed_now.timestamp()), "abc123",), + ) + cur.execute( + "INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)", + ("agent-one", site_id, int(fixed_now.timestamp())), + ) + + now_iso = fixed_now.isoformat() + cur.execute( + """ + INSERT INTO device_approvals ( + id, + approval_reference, + guid, + hostname_claimed, + ssl_key_fingerprint_claimed, + enrollment_code_id, + status, + client_nonce, + server_nonce, + created_at, + updated_at, + approved_by_user_id, + agent_pubkey_der + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + "approval-1", + "REF123", + None, + "agent-one", + "abc123", + "code-1", + "pending", + base64.b64encode(b"client").decode(), + base64.b64encode(b"server").decode(), + now_iso, + now_iso, + None, + b"pubkey", + ), + ) + conn.commit() + conn.close() + + approvals = service.list_device_approvals() + assert len(approvals) == 1 + record = approvals[0] + assert record.hostname_conflict is not None + assert record.hostname_conflict.fingerprint_match is True + assert record.conflict_requires_prompt is False + diff --git a/Data/Engine/tests/test_http_admin.py b/Data/Engine/tests/test_http_admin.py new file mode 100644 index 0000000..f3e0cc4 --- /dev/null +++ b/Data/Engine/tests/test_http_admin.py @@ -0,0 +1,111 @@ +import base64 +import sqlite3 +from datetime import datetime, timezone + +from .test_http_auth import _login, prepared_app + + +def test_enrollment_codes_require_authentication(prepared_app): + client = prepared_app.test_client() + resp = client.get("/api/admin/enrollment-codes") + assert resp.status_code == 401 + + +def test_enrollment_code_workflow(prepared_app): + client = prepared_app.test_client() + _login(client) + + payload = {"ttl_hours": 3, "max_uses": 4} + resp = client.post("/api/admin/enrollment-codes", json=payload) + assert resp.status_code == 201 + created = resp.get_json() + assert created["max_uses"] == 4 + assert created["status"] == "active" + + resp = client.get("/api/admin/enrollment-codes") + assert resp.status_code == 200 + codes = resp.get_json().get("codes", []) + assert any(code["id"] == created["id"] for code in codes) + + resp = client.delete(f"/api/admin/enrollment-codes/{created['id']}") + assert resp.status_code == 200 + + +def test_device_approvals_listing(prepared_app, engine_settings): + client = prepared_app.test_client() + _login(client) + + conn = sqlite3.connect(engine_settings.database.path) + cur = conn.cursor() + + now = datetime.now(tz=timezone.utc) + cur.execute( + "INSERT INTO sites (name, description, created_at) VALUES (?, ?, ?)", + ("HQ", "Primary", int(now.timestamp())), + ) + site_id = cur.lastrowid + + cur.execute( + """ + INSERT INTO devices (guid, hostname, created_at, last_seen, ssl_key_fingerprint, status) + VALUES (?, ?, ?, ?, ?, 'active') + """, + ( + "22222222-2222-2222-2222-222222222222", + "approval-host", + int(now.timestamp()), + int(now.timestamp()), + "deadbeef", + ), + ) + cur.execute( + "INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)", + ("approval-host", site_id, int(now.timestamp())), + ) + + now_iso = now.isoformat() + cur.execute( + """ + INSERT INTO device_approvals ( + id, + approval_reference, + guid, + hostname_claimed, + ssl_key_fingerprint_claimed, + enrollment_code_id, + status, + client_nonce, + server_nonce, + created_at, + updated_at, + approved_by_user_id, + agent_pubkey_der + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + "approval-http", + "REFHTTP", + None, + "approval-host", + "deadbeef", + "code-http", + "pending", + base64.b64encode(b"client").decode(), + base64.b64encode(b"server").decode(), + now_iso, + now_iso, + None, + b"pub", + ), + ) + conn.commit() + conn.close() + + resp = client.get("/api/admin/device-approvals") + assert resp.status_code == 200 + body = resp.get_json() + approvals = body.get("approvals", []) + assert any(a["id"] == "approval-http" for a in approvals) + record = next(a for a in approvals if a["id"] == "approval-http") + assert record.get("hostname_conflict", {}).get("fingerprint_match") is True + From 4bc529aaf460dcbb9fea9d545c742efddd4c267c Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Wed, 22 Oct 2025 23:43:16 -0600 Subject: [PATCH 08/12] Port core API routes for sites and devices --- Data/Engine/domain/device_views.py | 28 ++ Data/Engine/domain/devices.py | 291 +++++++++++++++++ Data/Engine/domain/sites.py | 43 +++ Data/Engine/interfaces/http/__init__.py | 18 +- Data/Engine/interfaces/http/credentials.py | 70 ++++ Data/Engine/interfaces/http/devices.py | 301 ++++++++++++++++++ Data/Engine/interfaces/http/sites.py | 112 +++++++ Data/Engine/repositories/sqlite/__init__.py | 12 + .../sqlite/credential_repository.py | 103 ++++++ .../sqlite/device_inventory_repository.py | 253 +++++++++++++++ .../sqlite/device_view_repository.py | 143 +++++++++ .../repositories/sqlite/site_repository.py | 189 +++++++++++ Data/Engine/services/__init__.py | 17 + Data/Engine/services/container.py | 42 +++ Data/Engine/services/credentials/__init__.py | 3 + .../credentials/credential_service.py | 29 ++ Data/Engine/services/devices/__init__.py | 4 + .../devices/device_inventory_service.py | 178 +++++++++++ .../services/devices/device_view_service.py | 73 +++++ Data/Engine/services/sites/__init__.py | 3 + Data/Engine/services/sites/site_service.py | 73 +++++ Data/Engine/tests/test_http_sites_devices.py | 108 +++++++ 22 files changed, 2092 insertions(+), 1 deletion(-) create mode 100644 Data/Engine/domain/device_views.py create mode 100644 Data/Engine/domain/devices.py create mode 100644 Data/Engine/domain/sites.py create mode 100644 Data/Engine/interfaces/http/credentials.py create mode 100644 Data/Engine/interfaces/http/devices.py create mode 100644 Data/Engine/interfaces/http/sites.py create mode 100644 Data/Engine/repositories/sqlite/credential_repository.py create mode 100644 Data/Engine/repositories/sqlite/device_inventory_repository.py create mode 100644 Data/Engine/repositories/sqlite/device_view_repository.py create mode 100644 Data/Engine/repositories/sqlite/site_repository.py create mode 100644 Data/Engine/services/credentials/__init__.py create mode 100644 Data/Engine/services/credentials/credential_service.py create mode 100644 Data/Engine/services/devices/__init__.py create mode 100644 Data/Engine/services/devices/device_inventory_service.py create mode 100644 Data/Engine/services/devices/device_view_service.py create mode 100644 Data/Engine/services/sites/__init__.py create mode 100644 Data/Engine/services/sites/site_service.py create mode 100644 Data/Engine/tests/test_http_sites_devices.py diff --git a/Data/Engine/domain/device_views.py b/Data/Engine/domain/device_views.py new file mode 100644 index 0000000..a208692 --- /dev/null +++ b/Data/Engine/domain/device_views.py @@ -0,0 +1,28 @@ +"""Domain objects for saved device list views.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List + +__all__ = ["DeviceListView"] + + +@dataclass(frozen=True, slots=True) +class DeviceListView: + id: int + name: str + columns: List[str] + filters: Dict[str, object] + created_at: int + updated_at: int + + def to_dict(self) -> Dict[str, object]: + return { + "id": self.id, + "name": self.name, + "columns": self.columns, + "filters": self.filters, + "created_at": self.created_at, + "updated_at": self.updated_at, + } diff --git a/Data/Engine/domain/devices.py b/Data/Engine/domain/devices.py new file mode 100644 index 0000000..5c292c2 --- /dev/null +++ b/Data/Engine/domain/devices.py @@ -0,0 +1,291 @@ +"""Device domain helpers mirroring the legacy server payloads.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Dict, List, Mapping, Optional, Sequence + +__all__ = [ + "DEVICE_TABLE_COLUMNS", + "DEVICE_TABLE", + "DeviceSnapshot", + "assemble_device_snapshot", + "row_to_device_dict", + "serialize_device_json", + "clean_device_str", + "coerce_int", + "ts_to_iso", + "device_column_sql", + "ts_to_human", +] + + +DEVICE_TABLE = "devices" + +DEVICE_JSON_LIST_FIELDS: Mapping[str, List[Any]] = { + "memory": [], + "network": [], + "software": [], + "storage": [], +} + +DEVICE_JSON_OBJECT_FIELDS: Mapping[str, Dict[str, Any]] = { + "cpu": {}, +} + +DEVICE_TABLE_COLUMNS: Sequence[str] = ( + "guid", + "hostname", + "description", + "created_at", + "agent_hash", + "memory", + "network", + "software", + "storage", + "cpu", + "device_type", + "domain", + "external_ip", + "internal_ip", + "last_reboot", + "last_seen", + "last_user", + "operating_system", + "uptime", + "agent_id", + "ansible_ee_ver", + "connection_type", + "connection_endpoint", + "ssl_key_fingerprint", + "token_version", + "status", + "key_added_at", +) + + +@dataclass(frozen=True) +class DeviceSnapshot: + hostname: str + description: str + created_at: int + created_at_iso: str + agent_hash: str + agent_guid: str + guid: str + memory: List[Dict[str, Any]] + network: List[Dict[str, Any]] + software: List[Dict[str, Any]] + storage: List[Dict[str, Any]] + cpu: Dict[str, Any] + device_type: str + domain: str + external_ip: str + internal_ip: str + last_reboot: str + last_seen: int + last_seen_iso: str + last_user: str + operating_system: str + uptime: int + agent_id: str + connection_type: str + connection_endpoint: str + details: Dict[str, Any] + summary: Dict[str, Any] + + def to_dict(self) -> Dict[str, Any]: + return { + "hostname": self.hostname, + "description": self.description, + "created_at": self.created_at, + "created_at_iso": self.created_at_iso, + "agent_hash": self.agent_hash, + "agent_guid": self.agent_guid, + "guid": self.guid, + "memory": self.memory, + "network": self.network, + "software": self.software, + "storage": self.storage, + "cpu": self.cpu, + "device_type": self.device_type, + "domain": self.domain, + "external_ip": self.external_ip, + "internal_ip": self.internal_ip, + "last_reboot": self.last_reboot, + "last_seen": self.last_seen, + "last_seen_iso": self.last_seen_iso, + "last_user": self.last_user, + "operating_system": self.operating_system, + "uptime": self.uptime, + "agent_id": self.agent_id, + "connection_type": self.connection_type, + "connection_endpoint": self.connection_endpoint, + "details": self.details, + "summary": self.summary, + } + + +def ts_to_iso(ts: Optional[int]) -> str: + if not ts: + return "" + try: + return datetime.fromtimestamp(int(ts), timezone.utc).isoformat() + except Exception: + return "" + + +def _ts_to_human(ts: Optional[int]) -> str: + if not ts: + return "" + try: + return datetime.utcfromtimestamp(int(ts)).strftime("%Y-%m-%d %H:%M:%S") + except Exception: + return "" + + +def _parse_device_json(raw: Optional[str], default: Any) -> Any: + if raw is None: + return json.loads(json.dumps(default)) if isinstance(default, (list, dict)) else default + try: + data = json.loads(raw) + except Exception: + data = None + if isinstance(default, list): + if isinstance(data, list): + return data + return [] + if isinstance(default, dict): + if isinstance(data, dict): + return data + return {} + return default + + +def serialize_device_json(value: Any, default: Any) -> str: + candidate = value + if candidate is None: + candidate = default + if not isinstance(candidate, (list, dict)): + candidate = default + try: + return json.dumps(candidate) + except Exception: + try: + return json.dumps(default) + except Exception: + return "{}" if isinstance(default, dict) else "[]" + + +def clean_device_str(value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, (int, float)) and not isinstance(value, bool): + text = str(value) + elif isinstance(value, str): + text = value + else: + try: + text = str(value) + except Exception: + return None + text = text.strip() + return text or None + + +def coerce_int(value: Any) -> Optional[int]: + if value is None: + return None + try: + if isinstance(value, str) and value.strip() == "": + return None + return int(float(value)) + except (ValueError, TypeError): + return None + + +def row_to_device_dict(row: Sequence[Any], columns: Sequence[str]) -> Dict[str, Any]: + return {columns[idx]: row[idx] for idx in range(min(len(row), len(columns)))} + + +def assemble_device_snapshot(record: Mapping[str, Any]) -> Dict[str, Any]: + summary = { + "hostname": record.get("hostname") or "", + "description": record.get("description") or "", + "device_type": record.get("device_type") or "", + "domain": record.get("domain") or "", + "external_ip": record.get("external_ip") or "", + "internal_ip": record.get("internal_ip") or "", + "last_reboot": record.get("last_reboot") or "", + "last_seen": record.get("last_seen") or 0, + "last_user": record.get("last_user") or "", + "operating_system": record.get("operating_system") or "", + "uptime": record.get("uptime") or 0, + "agent_id": record.get("agent_id") or "", + "agent_hash": record.get("agent_hash") or "", + "agent_guid": record.get("guid") or record.get("agent_guid") or "", + "connection_type": record.get("connection_type") or "", + "connection_endpoint": record.get("connection_endpoint") or "", + "created_at": record.get("created_at") or 0, + } + + created_ts = coerce_int(summary.get("created_at")) or 0 + last_seen_ts = coerce_int(summary.get("last_seen")) or 0 + uptime_val = coerce_int(summary.get("uptime")) or 0 + + parsed_lists = { + key: _parse_device_json(record.get(key), default) + for key, default in DEVICE_JSON_LIST_FIELDS.items() + } + cpu_obj = _parse_device_json(record.get("cpu"), DEVICE_JSON_OBJECT_FIELDS["cpu"]) + + details = { + "memory": parsed_lists["memory"], + "network": parsed_lists["network"], + "software": parsed_lists["software"], + "storage": parsed_lists["storage"], + "cpu": cpu_obj, + } + + payload: Dict[str, Any] = { + "hostname": summary["hostname"], + "description": summary.get("description", ""), + "created_at": created_ts, + "created_at_iso": ts_to_iso(created_ts), + "agent_hash": summary.get("agent_hash", ""), + "agent_guid": summary.get("agent_guid", ""), + "guid": summary.get("agent_guid", ""), + "memory": parsed_lists["memory"], + "network": parsed_lists["network"], + "software": parsed_lists["software"], + "storage": parsed_lists["storage"], + "cpu": cpu_obj, + "device_type": summary.get("device_type", ""), + "domain": summary.get("domain", ""), + "external_ip": summary.get("external_ip", ""), + "internal_ip": summary.get("internal_ip", ""), + "last_reboot": summary.get("last_reboot", ""), + "last_seen": last_seen_ts, + "last_seen_iso": ts_to_iso(last_seen_ts), + "last_user": summary.get("last_user", ""), + "operating_system": summary.get("operating_system", ""), + "uptime": uptime_val, + "agent_id": summary.get("agent_id", ""), + "connection_type": summary.get("connection_type", ""), + "connection_endpoint": summary.get("connection_endpoint", ""), + "details": details, + "summary": summary, + } + return payload + + +def device_column_sql(alias: Optional[str] = None) -> str: + if alias: + return ", ".join(f"{alias}.{col}" for col in DEVICE_TABLE_COLUMNS) + return ", ".join(DEVICE_TABLE_COLUMNS) + + +def ts_to_human(ts: Optional[int]) -> str: + return _ts_to_human(ts) diff --git a/Data/Engine/domain/sites.py b/Data/Engine/domain/sites.py new file mode 100644 index 0000000..7b1c2c0 --- /dev/null +++ b/Data/Engine/domain/sites.py @@ -0,0 +1,43 @@ +"""Domain models for operator site management.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Optional + +__all__ = ["SiteSummary", "SiteDeviceMapping"] + + +@dataclass(frozen=True, slots=True) +class SiteSummary: + """Representation of a site record including device counts.""" + + id: int + name: str + description: str + created_at: int + device_count: int + + def to_dict(self) -> Dict[str, object]: + return { + "id": self.id, + "name": self.name, + "description": self.description, + "created_at": self.created_at, + "device_count": self.device_count, + } + + +@dataclass(frozen=True, slots=True) +class SiteDeviceMapping: + """Mapping entry describing which site a device belongs to.""" + + hostname: str + site_id: Optional[int] + site_name: str + + def to_dict(self) -> Dict[str, object]: + return { + "site_id": self.site_id, + "site_name": self.site_name, + } diff --git a/Data/Engine/interfaces/http/__init__.py b/Data/Engine/interfaces/http/__init__.py index fc88e26..47d62fa 100644 --- a/Data/Engine/interfaces/http/__init__.py +++ b/Data/Engine/interfaces/http/__init__.py @@ -6,7 +6,20 @@ from flask import Flask from Data.Engine.services.container import EngineServiceContainer -from . import admin, agents, auth, enrollment, github, health, job_management, tokens, users +from . import ( + admin, + agents, + auth, + enrollment, + github, + health, + job_management, + tokens, + users, + sites, + devices, + credentials, +) _REGISTRARS = ( health.register, @@ -18,6 +31,9 @@ _REGISTRARS = ( auth.register, admin.register, users.register, + sites.register, + devices.register, + credentials.register, ) diff --git a/Data/Engine/interfaces/http/credentials.py b/Data/Engine/interfaces/http/credentials.py new file mode 100644 index 0000000..6e65719 --- /dev/null +++ b/Data/Engine/interfaces/http/credentials.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from flask import Blueprint, Flask, current_app, jsonify, request, session + +from Data.Engine.services.container import EngineServiceContainer + +blueprint = Blueprint("engine_credentials", __name__) + + +def register(app: Flask, _services: EngineServiceContainer) -> None: + if "engine_credentials" not in app.blueprints: + app.register_blueprint(blueprint) + + +def _services() -> EngineServiceContainer: + services = current_app.extensions.get("engine_services") + if services is None: # pragma: no cover - defensive + raise RuntimeError("engine services not initialized") + return services + + +def _credentials_service(): + return _services().credential_service + + +def _require_admin(): + username = session.get("username") + role = (session.get("role") or "").strip().lower() + if not isinstance(username, str) or not username: + return jsonify({"error": "not_authenticated"}), 401 + if role != "admin": + return jsonify({"error": "forbidden"}), 403 + return None + + +@blueprint.route("/api/credentials", methods=["GET"]) +def list_credentials() -> object: + guard = _require_admin() + if guard: + return guard + + site_id_param = request.args.get("site_id") + connection_type = (request.args.get("connection_type") or "").strip() or None + try: + site_id = int(site_id_param) if site_id_param not in (None, "") else None + except (TypeError, ValueError): + site_id = None + + records = _credentials_service().list_credentials( + site_id=site_id, + connection_type=connection_type, + ) + return jsonify({"credentials": records}) + + +@blueprint.route("/api/credentials", methods=["POST"]) +def create_credential() -> object: # pragma: no cover - placeholder + return jsonify({"error": "not implemented"}), 501 + + +@blueprint.route("/api/credentials/", methods=["GET", "PUT", "DELETE"]) +def credential_detail(credential_id: int) -> object: # pragma: no cover - placeholder + if request.method == "GET": + return jsonify({"error": "not implemented"}), 501 + if request.method == "DELETE": + return jsonify({"error": "not implemented"}), 501 + return jsonify({"error": "not implemented"}), 501 + + +__all__ = ["register", "blueprint"] diff --git a/Data/Engine/interfaces/http/devices.py b/Data/Engine/interfaces/http/devices.py new file mode 100644 index 0000000..e618aa8 --- /dev/null +++ b/Data/Engine/interfaces/http/devices.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +from ipaddress import ip_address + +from flask import Blueprint, Flask, current_app, jsonify, request, session + +from Data.Engine.services.container import EngineServiceContainer +from Data.Engine.services.devices import RemoteDeviceError + +blueprint = Blueprint("engine_devices", __name__) + + +def register(app: Flask, _services: EngineServiceContainer) -> None: + if "engine_devices" not in app.blueprints: + app.register_blueprint(blueprint) + + +def _services() -> EngineServiceContainer: + services = current_app.extensions.get("engine_services") + if services is None: # pragma: no cover - defensive + raise RuntimeError("engine services not initialized") + return services + + +def _inventory(): + return _services().device_inventory + + +def _views(): + return _services().device_view_service + + +def _require_admin(): + username = session.get("username") + role = (session.get("role") or "").strip().lower() + if not isinstance(username, str) or not username: + return jsonify({"error": "not_authenticated"}), 401 + if role != "admin": + return jsonify({"error": "forbidden"}), 403 + return None + + +def _is_internal_request(req: request) -> bool: + remote = (req.remote_addr or "").strip() + if not remote: + return False + try: + return ip_address(remote).is_loopback + except ValueError: + return remote in {"localhost"} + + +@blueprint.route("/api/devices", methods=["GET"]) +def list_devices() -> object: + devices = _inventory().list_devices() + return jsonify({"devices": devices}) + + +@blueprint.route("/api/devices/", methods=["GET"]) +def get_device_by_guid(guid: str) -> object: + device = _inventory().get_device_by_guid(guid) + if not device: + return jsonify({"error": "not found"}), 404 + return jsonify(device) + + +@blueprint.route("/api/agent_devices", methods=["GET"]) +def list_agent_devices() -> object: + guard = _require_admin() + if guard: + return guard + devices = _inventory().list_agent_devices() + return jsonify({"devices": devices}) + + +@blueprint.route("/api/ssh_devices", methods=["GET", "POST"]) +def ssh_devices() -> object: + return _remote_devices_endpoint("ssh") + + +@blueprint.route("/api/winrm_devices", methods=["GET", "POST"]) +def winrm_devices() -> object: + return _remote_devices_endpoint("winrm") + + +@blueprint.route("/api/ssh_devices/", methods=["PUT", "DELETE"]) +def ssh_device_detail(hostname: str) -> object: + return _remote_device_detail("ssh", hostname) + + +@blueprint.route("/api/winrm_devices/", methods=["PUT", "DELETE"]) +def winrm_device_detail(hostname: str) -> object: + return _remote_device_detail("winrm", hostname) + + +@blueprint.route("/api/agent/hash_list", methods=["GET"]) +def agent_hash_list() -> object: + if not _is_internal_request(request): + remote_addr = (request.remote_addr or "unknown").strip() or "unknown" + current_app.logger.warning( + "/api/agent/hash_list denied non-local request from %s", remote_addr + ) + return jsonify({"error": "forbidden"}), 403 + try: + records = _inventory().collect_agent_hash_records() + except Exception as exc: # pragma: no cover - defensive logging + current_app.logger.exception("/api/agent/hash_list error: %s", exc) + return jsonify({"error": "internal error"}), 500 + return jsonify({"agents": records}) + + +@blueprint.route("/api/device_list_views", methods=["GET"]) +def list_device_list_views() -> object: + views = _views().list_views() + return jsonify({"views": [view.to_dict() for view in views]}) + + +@blueprint.route("/api/device_list_views/", methods=["GET"]) +def get_device_list_view(view_id: int) -> object: + view = _views().get_view(view_id) + if not view: + return jsonify({"error": "not found"}), 404 + return jsonify(view.to_dict()) + + +@blueprint.route("/api/device_list_views", methods=["POST"]) +def create_device_list_view() -> object: + payload = request.get_json(silent=True) or {} + name = (payload.get("name") or "").strip() + columns = payload.get("columns") or [] + filters = payload.get("filters") or {} + + if not name: + return jsonify({"error": "name is required"}), 400 + if name.lower() == "default view": + return jsonify({"error": "reserved name"}), 400 + if not isinstance(columns, list) or not all(isinstance(x, str) for x in columns): + return jsonify({"error": "columns must be a list of strings"}), 400 + if not isinstance(filters, dict): + return jsonify({"error": "filters must be an object"}), 400 + + try: + view = _views().create_view(name, columns, filters) + except ValueError as exc: + if str(exc) == "duplicate": + return jsonify({"error": "name already exists"}), 409 + raise + response = jsonify(view.to_dict()) + response.status_code = 201 + return response + + +@blueprint.route("/api/device_list_views/", methods=["PUT"]) +def update_device_list_view(view_id: int) -> object: + payload = request.get_json(silent=True) or {} + updates: dict = {} + if "name" in payload: + name_val = payload.get("name") + if name_val is None: + return jsonify({"error": "name cannot be empty"}), 400 + normalized = (str(name_val) or "").strip() + if not normalized: + return jsonify({"error": "name cannot be empty"}), 400 + if normalized.lower() == "default view": + return jsonify({"error": "reserved name"}), 400 + updates["name"] = normalized + if "columns" in payload: + columns_val = payload.get("columns") + if not isinstance(columns_val, list) or not all(isinstance(x, str) for x in columns_val): + return jsonify({"error": "columns must be a list of strings"}), 400 + updates["columns"] = columns_val + if "filters" in payload: + filters_val = payload.get("filters") + if filters_val is not None and not isinstance(filters_val, dict): + return jsonify({"error": "filters must be an object"}), 400 + if filters_val is not None: + updates["filters"] = filters_val + if not updates: + return jsonify({"error": "no fields to update"}), 400 + + try: + view = _views().update_view( + view_id, + name=updates.get("name"), + columns=updates.get("columns"), + filters=updates.get("filters"), + ) + except ValueError as exc: + code = str(exc) + if code == "duplicate": + return jsonify({"error": "name already exists"}), 409 + if code == "missing_name": + return jsonify({"error": "name cannot be empty"}), 400 + if code == "reserved": + return jsonify({"error": "reserved name"}), 400 + return jsonify({"error": "invalid payload"}), 400 + except LookupError: + return jsonify({"error": "not found"}), 404 + return jsonify(view.to_dict()) + + +@blueprint.route("/api/device_list_views/", methods=["DELETE"]) +def delete_device_list_view(view_id: int) -> object: + if not _views().delete_view(view_id): + return jsonify({"error": "not found"}), 404 + return jsonify({"status": "ok"}) + + +def _remote_devices_endpoint(connection_type: str) -> object: + guard = _require_admin() + if guard: + return guard + if request.method == "GET": + devices = _inventory().list_remote_devices(connection_type) + return jsonify({"devices": devices}) + + payload = request.get_json(silent=True) or {} + hostname = (payload.get("hostname") or "").strip() + address = ( + payload.get("address") + or payload.get("connection_endpoint") + or payload.get("endpoint") + or payload.get("host") + ) + description = payload.get("description") + os_hint = payload.get("operating_system") or payload.get("os") + + if not hostname: + return jsonify({"error": "hostname is required"}), 400 + if not (address or "").strip(): + return jsonify({"error": "address is required"}), 400 + + try: + device = _inventory().upsert_remote_device( + connection_type, + hostname, + address, + description, + os_hint, + ensure_existing_type=None, + ) + except RemoteDeviceError as exc: + status = 409 if exc.code in {"conflict", "address_required"} else 500 + if exc.code == "conflict": + return jsonify({"error": str(exc)}), 409 + if exc.code == "address_required": + return jsonify({"error": "address is required"}), 400 + return jsonify({"error": str(exc)}), status + return jsonify({"device": device}), 201 + + +def _remote_device_detail(connection_type: str, hostname: str) -> object: + guard = _require_admin() + if guard: + return guard + normalized_host = (hostname or "").strip() + if not normalized_host: + return jsonify({"error": "invalid hostname"}), 400 + + if request.method == "DELETE": + try: + _inventory().delete_remote_device(connection_type, normalized_host) + except RemoteDeviceError as exc: + if exc.code == "not_found": + return jsonify({"error": "device not found"}), 404 + if exc.code == "invalid_hostname": + return jsonify({"error": "invalid hostname"}), 400 + return jsonify({"error": str(exc)}), 500 + return jsonify({"status": "ok"}) + + payload = request.get_json(silent=True) or {} + address = ( + payload.get("address") + or payload.get("connection_endpoint") + or payload.get("endpoint") + ) + description = payload.get("description") + os_hint = payload.get("operating_system") or payload.get("os") + + if address is None and description is None and os_hint is None: + return jsonify({"error": "no fields to update"}), 400 + + try: + device = _inventory().upsert_remote_device( + connection_type, + normalized_host, + address if address is not None else "", + description, + os_hint, + ensure_existing_type=connection_type, + ) + except RemoteDeviceError as exc: + if exc.code == "not_found": + return jsonify({"error": "device not found"}), 404 + if exc.code == "address_required": + return jsonify({"error": "address is required"}), 400 + return jsonify({"error": str(exc)}), 500 + return jsonify({"device": device}) + + +__all__ = ["register", "blueprint"] diff --git a/Data/Engine/interfaces/http/sites.py b/Data/Engine/interfaces/http/sites.py new file mode 100644 index 0000000..20b82fc --- /dev/null +++ b/Data/Engine/interfaces/http/sites.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from flask import Blueprint, Flask, current_app, jsonify, request + +from Data.Engine.services.container import EngineServiceContainer + +blueprint = Blueprint("engine_sites", __name__) + + +def register(app: Flask, _services: EngineServiceContainer) -> None: + if "engine_sites" not in app.blueprints: + app.register_blueprint(blueprint) + + +def _services() -> EngineServiceContainer: + services = current_app.extensions.get("engine_services") + if services is None: # pragma: no cover - defensive + raise RuntimeError("engine services not initialized") + return services + + +def _site_service(): + return _services().site_service + + +@blueprint.route("/api/sites", methods=["GET"]) +def list_sites() -> object: + records = _site_service().list_sites() + return jsonify({"sites": [record.to_dict() for record in records]}) + + +@blueprint.route("/api/sites", methods=["POST"]) +def create_site() -> object: + payload = request.get_json(silent=True) or {} + name = payload.get("name") + description = payload.get("description") + try: + record = _site_service().create_site(name or "", description or "") + except ValueError as exc: + if str(exc) == "missing_name": + return jsonify({"error": "name is required"}), 400 + if str(exc) == "duplicate": + return jsonify({"error": "name already exists"}), 409 + raise + response = jsonify(record.to_dict()) + response.status_code = 201 + return response + + +@blueprint.route("/api/sites/delete", methods=["POST"]) +def delete_sites() -> object: + payload = request.get_json(silent=True) or {} + ids = payload.get("ids") or [] + if not isinstance(ids, list): + return jsonify({"error": "ids must be a list"}), 400 + deleted = _site_service().delete_sites(ids) + return jsonify({"status": "ok", "deleted": deleted}) + + +@blueprint.route("/api/sites/device_map", methods=["GET"]) +def sites_device_map() -> object: + host_param = (request.args.get("hostnames") or "").strip() + filter_set = [] + if host_param: + for part in host_param.split(","): + normalized = part.strip() + if normalized: + filter_set.append(normalized) + mapping = _site_service().map_devices(filter_set or None) + return jsonify({"mapping": {hostname: entry.to_dict() for hostname, entry in mapping.items()}}) + + +@blueprint.route("/api/sites/assign", methods=["POST"]) +def assign_devices_to_site() -> object: + payload = request.get_json(silent=True) or {} + site_id = payload.get("site_id") + hostnames = payload.get("hostnames") or [] + if not isinstance(hostnames, list): + return jsonify({"error": "hostnames must be a list of strings"}), 400 + try: + _site_service().assign_devices(site_id, hostnames) + except ValueError as exc: + message = str(exc) + if message == "invalid_site_id": + return jsonify({"error": "invalid site_id"}), 400 + if message == "invalid_hostnames": + return jsonify({"error": "hostnames must be a list of strings"}), 400 + raise + except LookupError: + return jsonify({"error": "site not found"}), 404 + return jsonify({"status": "ok"}) + + +@blueprint.route("/api/sites/rename", methods=["POST"]) +def rename_site() -> object: + payload = request.get_json(silent=True) or {} + site_id = payload.get("id") + new_name = payload.get("new_name") or "" + try: + record = _site_service().rename_site(site_id, new_name) + except ValueError as exc: + if str(exc) == "missing_name": + return jsonify({"error": "new_name is required"}), 400 + if str(exc) == "duplicate": + return jsonify({"error": "name already exists"}), 409 + raise + except LookupError: + return jsonify({"error": "site not found"}), 404 + return jsonify(record.to_dict()) + + +__all__ = ["register", "blueprint"] diff --git a/Data/Engine/repositories/sqlite/__init__.py b/Data/Engine/repositories/sqlite/__init__.py index 8b44e59..98e6296 100644 --- a/Data/Engine/repositories/sqlite/__init__.py +++ b/Data/Engine/repositories/sqlite/__init__.py @@ -24,8 +24,12 @@ __all__ = [ try: # pragma: no cover - optional dependency shim from .device_repository import SQLiteDeviceRepository from .enrollment_repository import SQLiteEnrollmentRepository + from .device_inventory_repository import SQLiteDeviceInventoryRepository + from .device_view_repository import SQLiteDeviceViewRepository + from .credential_repository import SQLiteCredentialRepository from .github_repository import SQLiteGitHubRepository from .job_repository import SQLiteJobRepository + from .site_repository import SQLiteSiteRepository from .token_repository import SQLiteRefreshTokenRepository from .user_repository import SQLiteUserRepository except ModuleNotFoundError as exc: # pragma: no cover - triggered when auth deps missing @@ -36,8 +40,12 @@ except ModuleNotFoundError as exc: # pragma: no cover - triggered when auth dep SQLiteDeviceRepository = _missing_repo # type: ignore[assignment] SQLiteEnrollmentRepository = _missing_repo # type: ignore[assignment] + SQLiteDeviceInventoryRepository = _missing_repo # type: ignore[assignment] + SQLiteDeviceViewRepository = _missing_repo # type: ignore[assignment] + SQLiteCredentialRepository = _missing_repo # type: ignore[assignment] SQLiteGitHubRepository = _missing_repo # type: ignore[assignment] SQLiteJobRepository = _missing_repo # type: ignore[assignment] + SQLiteSiteRepository = _missing_repo # type: ignore[assignment] SQLiteRefreshTokenRepository = _missing_repo # type: ignore[assignment] else: __all__ += [ @@ -45,6 +53,10 @@ else: "SQLiteRefreshTokenRepository", "SQLiteJobRepository", "SQLiteEnrollmentRepository", + "SQLiteDeviceInventoryRepository", + "SQLiteDeviceViewRepository", + "SQLiteCredentialRepository", "SQLiteGitHubRepository", "SQLiteUserRepository", + "SQLiteSiteRepository", ] diff --git a/Data/Engine/repositories/sqlite/credential_repository.py b/Data/Engine/repositories/sqlite/credential_repository.py new file mode 100644 index 0000000..bde1c67 --- /dev/null +++ b/Data/Engine/repositories/sqlite/credential_repository.py @@ -0,0 +1,103 @@ +"""SQLite access for operator credential metadata.""" + +from __future__ import annotations + +import json +import logging +import sqlite3 +from contextlib import closing +from typing import Dict, List, Optional + +from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory + +__all__ = ["SQLiteCredentialRepository"] + + +class SQLiteCredentialRepository: + def __init__( + self, + connection_factory: SQLiteConnectionFactory, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._connections = connection_factory + self._log = logger or logging.getLogger("borealis.engine.repositories.credentials") + + def list_credentials( + self, + *, + site_id: Optional[int] = None, + connection_type: Optional[str] = None, + ) -> List[Dict[str, object]]: + sql = """ + SELECT c.id, + c.name, + c.description, + c.credential_type, + c.connection_type, + c.username, + c.site_id, + s.name AS site_name, + c.become_method, + c.become_username, + c.metadata_json, + c.created_at, + c.updated_at, + c.password_encrypted, + c.private_key_encrypted, + c.private_key_passphrase_encrypted, + c.become_password_encrypted + FROM credentials c + LEFT JOIN sites s ON s.id = c.site_id + """ + clauses: List[str] = [] + params: List[object] = [] + if site_id is not None: + clauses.append("c.site_id = ?") + params.append(site_id) + if connection_type: + clauses.append("LOWER(c.connection_type) = LOWER(?)") + params.append(connection_type) + if clauses: + sql += " WHERE " + " AND ".join(clauses) + sql += " ORDER BY LOWER(c.name) ASC" + + with closing(self._connections()) as conn: + conn.row_factory = sqlite3.Row # type: ignore[attr-defined] + cur = conn.cursor() + cur.execute(sql, params) + rows = cur.fetchall() + + results: List[Dict[str, object]] = [] + for row in rows: + metadata_json = row["metadata_json"] if "metadata_json" in row.keys() else None + metadata = {} + if metadata_json: + try: + candidate = json.loads(metadata_json) + if isinstance(candidate, dict): + metadata = candidate + except Exception: + metadata = {} + results.append( + { + "id": row["id"], + "name": row["name"], + "description": row["description"] or "", + "credential_type": row["credential_type"] or "machine", + "connection_type": row["connection_type"] or "ssh", + "site_id": row["site_id"], + "site_name": row["site_name"], + "username": row["username"] or "", + "become_method": row["become_method"] or "", + "become_username": row["become_username"] or "", + "metadata": metadata, + "created_at": int(row["created_at"] or 0), + "updated_at": int(row["updated_at"] or 0), + "has_password": bool(row["password_encrypted"]), + "has_private_key": bool(row["private_key_encrypted"]), + "has_private_key_passphrase": bool(row["private_key_passphrase_encrypted"]), + "has_become_password": bool(row["become_password_encrypted"]), + } + ) + return results diff --git a/Data/Engine/repositories/sqlite/device_inventory_repository.py b/Data/Engine/repositories/sqlite/device_inventory_repository.py new file mode 100644 index 0000000..8ae5767 --- /dev/null +++ b/Data/Engine/repositories/sqlite/device_inventory_repository.py @@ -0,0 +1,253 @@ +"""Device inventory operations backed by SQLite.""" + +from __future__ import annotations + +import logging +import sqlite3 +import time +from contextlib import closing +from typing import Any, Dict, List, Optional, Tuple + +from Data.Engine.domain.devices import ( + DEVICE_TABLE, + DEVICE_TABLE_COLUMNS, + assemble_device_snapshot, + clean_device_str, + coerce_int, + device_column_sql, + row_to_device_dict, + serialize_device_json, +) +from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory + +__all__ = ["SQLiteDeviceInventoryRepository"] + + +class SQLiteDeviceInventoryRepository: + def __init__( + self, + connection_factory: SQLiteConnectionFactory, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._connections = connection_factory + self._log = logger or logging.getLogger("borealis.engine.repositories.device_inventory") + + def fetch_devices( + self, + *, + connection_type: Optional[str] = None, + hostname: Optional[str] = None, + only_agents: bool = False, + ) -> List[Dict[str, Any]]: + sql = f""" + SELECT {device_column_sql('d')}, s.id, s.name, s.description + FROM {DEVICE_TABLE} d + LEFT JOIN device_sites ds ON ds.device_hostname = d.hostname + LEFT JOIN sites s ON s.id = ds.site_id + """ + clauses: List[str] = [] + params: List[Any] = [] + if connection_type: + clauses.append("LOWER(d.connection_type) = LOWER(?)") + params.append(connection_type) + if hostname: + clauses.append("LOWER(d.hostname) = LOWER(?)") + params.append(hostname.lower()) + if only_agents: + clauses.append("(d.connection_type IS NULL OR TRIM(d.connection_type) = '')") + if clauses: + sql += " WHERE " + " AND ".join(clauses) + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute(sql, params) + rows = cur.fetchall() + + now = time.time() + devices: List[Dict[str, Any]] = [] + for row in rows: + core = row[: len(DEVICE_TABLE_COLUMNS)] + site_id, site_name, site_description = row[len(DEVICE_TABLE_COLUMNS) :] + record = row_to_device_dict(core, DEVICE_TABLE_COLUMNS) + snapshot = assemble_device_snapshot(record) + summary = snapshot.get("summary", {}) + last_seen = snapshot.get("last_seen") or 0 + status = "Offline" + try: + if last_seen and (now - float(last_seen)) <= 300: + status = "Online" + except Exception: + pass + devices.append( + { + **snapshot, + "site_id": site_id, + "site_name": site_name or "", + "site_description": site_description or "", + "status": status, + } + ) + return devices + + def load_snapshot(self, *, hostname: Optional[str] = None, guid: Optional[str] = None) -> Optional[Dict[str, Any]]: + if not hostname and not guid: + return None + sql = None + params: Tuple[Any, ...] + if hostname: + sql = f"SELECT {device_column_sql()} FROM {DEVICE_TABLE} WHERE hostname = ?" + params = (hostname,) + else: + sql = f"SELECT {device_column_sql()} FROM {DEVICE_TABLE} WHERE LOWER(guid) = LOWER(?)" + params = (guid,) + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute(sql, params) + row = cur.fetchone() + if not row: + return None + record = row_to_device_dict(row, DEVICE_TABLE_COLUMNS) + return assemble_device_snapshot(record) + + def upsert_device( + self, + hostname: str, + description: Optional[str], + merged_details: Dict[str, Any], + created_at: Optional[int], + *, + agent_hash: Optional[str] = None, + guid: Optional[str] = None, + ) -> None: + if not hostname: + return + + column_values = self._extract_device_columns(merged_details or {}) + normalized_description = description if description is not None else "" + try: + normalized_description = str(normalized_description) + except Exception: + normalized_description = "" + + normalized_hash = clean_device_str(agent_hash) or None + normalized_guid = clean_device_str(guid) or None + created_ts = coerce_int(created_at) or int(time.time()) + + sql = f""" + INSERT INTO {DEVICE_TABLE}( + hostname, + description, + created_at, + agent_hash, + guid, + memory, + network, + software, + storage, + cpu, + device_type, + domain, + external_ip, + internal_ip, + last_reboot, + last_seen, + last_user, + operating_system, + uptime, + agent_id, + ansible_ee_ver, + connection_type, + connection_endpoint + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) + ON CONFLICT(hostname) DO UPDATE SET + description=excluded.description, + created_at=COALESCE({DEVICE_TABLE}.created_at, excluded.created_at), + agent_hash=COALESCE(NULLIF(excluded.agent_hash, ''), {DEVICE_TABLE}.agent_hash), + guid=COALESCE(NULLIF(excluded.guid, ''), {DEVICE_TABLE}.guid), + memory=excluded.memory, + network=excluded.network, + software=excluded.software, + storage=excluded.storage, + cpu=excluded.cpu, + device_type=COALESCE(NULLIF(excluded.device_type, ''), {DEVICE_TABLE}.device_type), + domain=COALESCE(NULLIF(excluded.domain, ''), {DEVICE_TABLE}.domain), + external_ip=COALESCE(NULLIF(excluded.external_ip, ''), {DEVICE_TABLE}.external_ip), + internal_ip=COALESCE(NULLIF(excluded.internal_ip, ''), {DEVICE_TABLE}.internal_ip), + last_reboot=COALESCE(NULLIF(excluded.last_reboot, ''), {DEVICE_TABLE}.last_reboot), + last_seen=COALESCE(NULLIF(excluded.last_seen, 0), {DEVICE_TABLE}.last_seen), + last_user=COALESCE(NULLIF(excluded.last_user, ''), {DEVICE_TABLE}.last_user), + operating_system=COALESCE(NULLIF(excluded.operating_system, ''), {DEVICE_TABLE}.operating_system), + uptime=COALESCE(NULLIF(excluded.uptime, 0), {DEVICE_TABLE}.uptime), + agent_id=COALESCE(NULLIF(excluded.agent_id, ''), {DEVICE_TABLE}.agent_id), + ansible_ee_ver=COALESCE(NULLIF(excluded.ansible_ee_ver, ''), {DEVICE_TABLE}.ansible_ee_ver), + connection_type=COALESCE(NULLIF(excluded.connection_type, ''), {DEVICE_TABLE}.connection_type), + connection_endpoint=COALESCE(NULLIF(excluded.connection_endpoint, ''), {DEVICE_TABLE}.connection_endpoint) + """ + + params: List[Any] = [ + hostname, + normalized_description, + created_ts, + normalized_hash, + normalized_guid, + column_values.get("memory"), + column_values.get("network"), + column_values.get("software"), + column_values.get("storage"), + column_values.get("cpu"), + column_values.get("device_type"), + column_values.get("domain"), + column_values.get("external_ip"), + column_values.get("internal_ip"), + column_values.get("last_reboot"), + column_values.get("last_seen"), + column_values.get("last_user"), + column_values.get("operating_system"), + column_values.get("uptime"), + column_values.get("agent_id"), + column_values.get("ansible_ee_ver"), + column_values.get("connection_type"), + column_values.get("connection_endpoint"), + ] + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute(sql, params) + conn.commit() + + def delete_device_by_hostname(self, hostname: str) -> None: + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute("DELETE FROM device_sites WHERE device_hostname = ?", (hostname,)) + cur.execute(f"DELETE FROM {DEVICE_TABLE} WHERE hostname = ?", (hostname,)) + conn.commit() + + def _extract_device_columns(self, details: Dict[str, Any]) -> Dict[str, Any]: + summary = details.get("summary") or {} + payload: Dict[str, Any] = {} + for field in ("memory", "network", "software", "storage"): + payload[field] = serialize_device_json(details.get(field), []) + payload["cpu"] = serialize_device_json(summary.get("cpu") or details.get("cpu"), {}) + payload["device_type"] = clean_device_str(summary.get("device_type") or summary.get("type")) + payload["domain"] = clean_device_str(summary.get("domain")) + payload["external_ip"] = clean_device_str(summary.get("external_ip") or summary.get("public_ip")) + payload["internal_ip"] = clean_device_str(summary.get("internal_ip") or summary.get("private_ip")) + payload["last_reboot"] = clean_device_str(summary.get("last_reboot") or summary.get("last_boot")) + payload["last_seen"] = coerce_int(summary.get("last_seen")) + payload["last_user"] = clean_device_str( + summary.get("last_user") + or summary.get("last_user_name") + or summary.get("logged_in_user") + ) + payload["operating_system"] = clean_device_str( + summary.get("operating_system") or summary.get("os") + ) + payload["uptime"] = coerce_int(summary.get("uptime")) + payload["agent_id"] = clean_device_str(summary.get("agent_id")) + payload["ansible_ee_ver"] = clean_device_str(summary.get("ansible_ee_ver")) + payload["connection_type"] = clean_device_str(summary.get("connection_type")) + payload["connection_endpoint"] = clean_device_str( + summary.get("connection_endpoint") or summary.get("endpoint") + ) + return payload diff --git a/Data/Engine/repositories/sqlite/device_view_repository.py b/Data/Engine/repositories/sqlite/device_view_repository.py new file mode 100644 index 0000000..f579f1e --- /dev/null +++ b/Data/Engine/repositories/sqlite/device_view_repository.py @@ -0,0 +1,143 @@ +"""SQLite persistence for device list views.""" + +from __future__ import annotations + +import json +import logging +import sqlite3 +import time +from contextlib import closing +from typing import Dict, Iterable, List, Optional + +from Data.Engine.domain.device_views import DeviceListView +from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory + +__all__ = ["SQLiteDeviceViewRepository"] + + +class SQLiteDeviceViewRepository: + def __init__( + self, + connection_factory: SQLiteConnectionFactory, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._connections = connection_factory + self._log = logger or logging.getLogger("borealis.engine.repositories.device_views") + + def list_views(self) -> List[DeviceListView]: + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + "SELECT id, name, columns_json, filters_json, created_at, updated_at\n" + " FROM device_list_views ORDER BY name COLLATE NOCASE ASC" + ) + rows = cur.fetchall() + return [self._row_to_view(row) for row in rows] + + def get_view(self, view_id: int) -> Optional[DeviceListView]: + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + "SELECT id, name, columns_json, filters_json, created_at, updated_at\n" + " FROM device_list_views WHERE id = ?", + (view_id,), + ) + row = cur.fetchone() + return self._row_to_view(row) if row else None + + def create_view(self, name: str, columns: List[str], filters: Dict[str, object]) -> DeviceListView: + now = int(time.time()) + with closing(self._connections()) as conn: + cur = conn.cursor() + try: + cur.execute( + "INSERT INTO device_list_views(name, columns_json, filters_json, created_at, updated_at)\n" + "VALUES (?, ?, ?, ?, ?)", + (name, json.dumps(columns), json.dumps(filters), now, now), + ) + except sqlite3.IntegrityError as exc: + raise ValueError("duplicate") from exc + view_id = cur.lastrowid + conn.commit() + cur.execute( + "SELECT id, name, columns_json, filters_json, created_at, updated_at FROM device_list_views WHERE id = ?", + (view_id,), + ) + row = cur.fetchone() + if not row: + raise RuntimeError("view missing after insert") + return self._row_to_view(row) + + def update_view( + self, + view_id: int, + *, + name: Optional[str] = None, + columns: Optional[List[str]] = None, + filters: Optional[Dict[str, object]] = None, + ) -> DeviceListView: + fields: List[str] = [] + params: List[object] = [] + if name is not None: + fields.append("name = ?") + params.append(name) + if columns is not None: + fields.append("columns_json = ?") + params.append(json.dumps(columns)) + if filters is not None: + fields.append("filters_json = ?") + params.append(json.dumps(filters)) + fields.append("updated_at = ?") + params.append(int(time.time())) + params.append(view_id) + + with closing(self._connections()) as conn: + cur = conn.cursor() + try: + cur.execute( + f"UPDATE device_list_views SET {', '.join(fields)} WHERE id = ?", + params, + ) + except sqlite3.IntegrityError as exc: + raise ValueError("duplicate") from exc + if cur.rowcount == 0: + raise LookupError("not_found") + conn.commit() + cur.execute( + "SELECT id, name, columns_json, filters_json, created_at, updated_at FROM device_list_views WHERE id = ?", + (view_id,), + ) + row = cur.fetchone() + if not row: + raise LookupError("not_found") + return self._row_to_view(row) + + def delete_view(self, view_id: int) -> bool: + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute("DELETE FROM device_list_views WHERE id = ?", (view_id,)) + deleted = cur.rowcount + conn.commit() + return bool(deleted) + + def _row_to_view(self, row: Optional[Iterable[object]]) -> DeviceListView: + if row is None: + raise ValueError("row required") + view_id, name, columns_json, filters_json, created_at, updated_at = row + try: + columns = json.loads(columns_json or "[]") + except Exception: + columns = [] + try: + filters = json.loads(filters_json or "{}") + except Exception: + filters = {} + return DeviceListView( + id=int(view_id), + name=str(name or ""), + columns=list(columns) if isinstance(columns, list) else [], + filters=dict(filters) if isinstance(filters, dict) else {}, + created_at=int(created_at or 0), + updated_at=int(updated_at or 0), + ) diff --git a/Data/Engine/repositories/sqlite/site_repository.py b/Data/Engine/repositories/sqlite/site_repository.py new file mode 100644 index 0000000..25a9967 --- /dev/null +++ b/Data/Engine/repositories/sqlite/site_repository.py @@ -0,0 +1,189 @@ +"""SQLite persistence for site management.""" + +from __future__ import annotations + +import logging +import sqlite3 +import time +from contextlib import closing +from typing import Dict, Iterable, List, Optional, Sequence + +from Data.Engine.domain.sites import SiteDeviceMapping, SiteSummary +from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory + +__all__ = ["SQLiteSiteRepository"] + + +class SQLiteSiteRepository: + """Repository exposing site CRUD and device assignment helpers.""" + + def __init__( + self, + connection_factory: SQLiteConnectionFactory, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._connections = connection_factory + self._log = logger or logging.getLogger("borealis.engine.repositories.sites") + + def list_sites(self) -> List[SiteSummary]: + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT s.id, s.name, s.description, s.created_at, + COALESCE(ds.cnt, 0) AS device_count + FROM sites s + LEFT JOIN ( + SELECT site_id, COUNT(*) AS cnt + FROM device_sites + GROUP BY site_id + ) ds + ON ds.site_id = s.id + ORDER BY LOWER(s.name) ASC + """ + ) + rows = cur.fetchall() + return [self._row_to_site(row) for row in rows] + + def create_site(self, name: str, description: str) -> SiteSummary: + now = int(time.time()) + with closing(self._connections()) as conn: + cur = conn.cursor() + try: + cur.execute( + "INSERT INTO sites(name, description, created_at) VALUES (?, ?, ?)", + (name, description, now), + ) + except sqlite3.IntegrityError as exc: + raise ValueError("duplicate") from exc + site_id = cur.lastrowid + conn.commit() + + cur.execute( + "SELECT id, name, description, created_at, 0 FROM sites WHERE id = ?", + (site_id,), + ) + row = cur.fetchone() + if not row: + raise RuntimeError("site not found after insert") + return self._row_to_site(row) + + def delete_sites(self, ids: Sequence[int]) -> int: + if not ids: + return 0 + with closing(self._connections()) as conn: + cur = conn.cursor() + placeholders = ",".join("?" for _ in ids) + try: + cur.execute( + f"DELETE FROM device_sites WHERE site_id IN ({placeholders})", + tuple(ids), + ) + cur.execute( + f"DELETE FROM sites WHERE id IN ({placeholders})", + tuple(ids), + ) + except sqlite3.DatabaseError as exc: + conn.rollback() + raise + deleted = cur.rowcount + conn.commit() + return deleted + + def rename_site(self, site_id: int, new_name: str) -> SiteSummary: + with closing(self._connections()) as conn: + cur = conn.cursor() + try: + cur.execute("UPDATE sites SET name = ? WHERE id = ?", (new_name, site_id)) + except sqlite3.IntegrityError as exc: + raise ValueError("duplicate") from exc + if cur.rowcount == 0: + raise LookupError("not_found") + conn.commit() + cur.execute( + """ + SELECT s.id, s.name, s.description, s.created_at, + COALESCE(ds.cnt, 0) AS device_count + FROM sites s + LEFT JOIN ( + SELECT site_id, COUNT(*) AS cnt + FROM device_sites + GROUP BY site_id + ) ds + ON ds.site_id = s.id + WHERE s.id = ? + """, + (site_id,), + ) + row = cur.fetchone() + if not row: + raise LookupError("not_found") + return self._row_to_site(row) + + def map_devices(self, hostnames: Optional[Iterable[str]] = None) -> Dict[str, SiteDeviceMapping]: + with closing(self._connections()) as conn: + cur = conn.cursor() + if hostnames: + normalized = [hn.strip() for hn in hostnames if hn and hn.strip()] + if not normalized: + return {} + placeholders = ",".join("?" for _ in normalized) + cur.execute( + f""" + SELECT ds.device_hostname, s.id, s.name + FROM device_sites ds + INNER JOIN sites s ON s.id = ds.site_id + WHERE ds.device_hostname IN ({placeholders}) + """, + tuple(normalized), + ) + else: + cur.execute( + """ + SELECT ds.device_hostname, s.id, s.name + FROM device_sites ds + INNER JOIN sites s ON s.id = ds.site_id + """ + ) + rows = cur.fetchall() + mapping: Dict[str, SiteDeviceMapping] = {} + for hostname, site_id, site_name in rows: + mapping[str(hostname)] = SiteDeviceMapping( + hostname=str(hostname), + site_id=int(site_id) if site_id is not None else None, + site_name=str(site_name or ""), + ) + return mapping + + def assign_devices(self, site_id: int, hostnames: Sequence[str]) -> None: + now = int(time.time()) + normalized = [hn.strip() for hn in hostnames if isinstance(hn, str) and hn.strip()] + if not normalized: + return + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute("SELECT 1 FROM sites WHERE id = ?", (site_id,)) + if not cur.fetchone(): + raise LookupError("not_found") + for hostname in normalized: + cur.execute( + """ + INSERT INTO device_sites(device_hostname, site_id, assigned_at) + VALUES (?, ?, ?) + ON CONFLICT(device_hostname) + DO UPDATE SET site_id = excluded.site_id, + assigned_at = excluded.assigned_at + """, + (hostname, site_id, now), + ) + conn.commit() + + def _row_to_site(self, row: Sequence[object]) -> SiteSummary: + return SiteSummary( + id=int(row[0]), + name=str(row[1] or ""), + description=str(row[2] or ""), + created_at=int(row[3] or 0), + device_count=int(row[4] or 0), + ) diff --git a/Data/Engine/services/__init__.py b/Data/Engine/services/__init__.py index 9c59917..22d8e14 100644 --- a/Data/Engine/services/__init__.py +++ b/Data/Engine/services/__init__.py @@ -24,6 +24,10 @@ __all__ = [ "GitHubService", "GitHubTokenPayload", "EnrollmentAdminService", + "SiteService", + "DeviceInventoryService", + "DeviceViewService", + "CredentialService", ] _LAZY_TARGETS: Dict[str, Tuple[str, str]] = { @@ -48,6 +52,19 @@ _LAZY_TARGETS: Dict[str, Tuple[str, str]] = { "Data.Engine.services.enrollment.admin_service", "EnrollmentAdminService", ), + "SiteService": ("Data.Engine.services.sites.site_service", "SiteService"), + "DeviceInventoryService": ( + "Data.Engine.services.devices.device_inventory_service", + "DeviceInventoryService", + ), + "DeviceViewService": ( + "Data.Engine.services.devices.device_view_service", + "DeviceViewService", + ), + "CredentialService": ( + "Data.Engine.services.credentials.credential_service", + "CredentialService", + ), } diff --git a/Data/Engine/services/container.py b/Data/Engine/services/container.py index bbb731b..a544b7c 100644 --- a/Data/Engine/services/container.py +++ b/Data/Engine/services/container.py @@ -13,10 +13,14 @@ from Data.Engine.integrations.github import GitHubArtifactProvider from Data.Engine.repositories.sqlite import ( SQLiteConnectionFactory, SQLiteDeviceRepository, + SQLiteDeviceInventoryRepository, + SQLiteDeviceViewRepository, + SQLiteCredentialRepository, SQLiteEnrollmentRepository, SQLiteGitHubRepository, SQLiteJobRepository, SQLiteRefreshTokenRepository, + SQLiteSiteRepository, SQLiteUserRepository, ) from Data.Engine.services.auth import ( @@ -32,10 +36,14 @@ from Data.Engine.services.crypto.signing import ScriptSigner, load_signer from Data.Engine.services.enrollment import EnrollmentService from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService from Data.Engine.services.enrollment.nonce_cache import NonceCache +from Data.Engine.services.devices import DeviceInventoryService +from Data.Engine.services.devices import DeviceViewService +from Data.Engine.services.credentials import CredentialService from Data.Engine.services.github import GitHubService from Data.Engine.services.jobs import SchedulerService from Data.Engine.services.rate_limit import SlidingWindowRateLimiter from Data.Engine.services.realtime import AgentRealtimeService +from Data.Engine.services.sites import SiteService __all__ = ["EngineServiceContainer", "build_service_container"] @@ -43,9 +51,13 @@ __all__ = ["EngineServiceContainer", "build_service_container"] @dataclass(frozen=True, slots=True) class EngineServiceContainer: device_auth: DeviceAuthService + device_inventory: DeviceInventoryService + device_view_service: DeviceViewService + credential_service: CredentialService token_service: TokenService enrollment_service: EnrollmentService enrollment_admin_service: EnrollmentAdminService + site_service: SiteService jwt_service: JWTService dpop_validator: DPoPValidator agent_realtime: AgentRealtimeService @@ -64,10 +76,20 @@ def build_service_container( log = logger or logging.getLogger("borealis.engine.services") device_repo = SQLiteDeviceRepository(db_factory, logger=log.getChild("devices")) + device_inventory_repo = SQLiteDeviceInventoryRepository( + db_factory, logger=log.getChild("devices.inventory") + ) + device_view_repo = SQLiteDeviceViewRepository( + db_factory, logger=log.getChild("devices.views") + ) + credential_repo = SQLiteCredentialRepository( + db_factory, logger=log.getChild("credentials.repo") + ) token_repo = SQLiteRefreshTokenRepository(db_factory, logger=log.getChild("tokens")) enrollment_repo = SQLiteEnrollmentRepository(db_factory, logger=log.getChild("enrollment")) job_repo = SQLiteJobRepository(db_factory, logger=log.getChild("jobs")) github_repo = SQLiteGitHubRepository(db_factory, logger=log.getChild("github_repo")) + site_repo = SQLiteSiteRepository(db_factory, logger=log.getChild("sites.repo")) user_repo = SQLiteUserRepository(db_factory, logger=log.getChild("users")) jwt_service = load_jwt_service() @@ -128,6 +150,22 @@ def build_service_container( repository=user_repo, logger=log.getChild("operator_accounts"), ) + device_inventory = DeviceInventoryService( + repository=device_inventory_repo, + logger=log.getChild("device_inventory"), + ) + device_view_service = DeviceViewService( + repository=device_view_repo, + logger=log.getChild("device_views"), + ) + credential_service = CredentialService( + repository=credential_repo, + logger=log.getChild("credentials"), + ) + site_service = SiteService( + repository=site_repo, + logger=log.getChild("sites"), + ) github_provider = GitHubArtifactProvider( cache_file=settings.github.cache_file, @@ -155,6 +193,10 @@ def build_service_container( github_service=github_service, operator_auth_service=operator_auth_service, operator_account_service=operator_account_service, + device_inventory=device_inventory, + device_view_service=device_view_service, + credential_service=credential_service, + site_service=site_service, ) diff --git a/Data/Engine/services/credentials/__init__.py b/Data/Engine/services/credentials/__init__.py new file mode 100644 index 0000000..1d6ce8b --- /dev/null +++ b/Data/Engine/services/credentials/__init__.py @@ -0,0 +1,3 @@ +from .credential_service import CredentialService + +__all__ = ["CredentialService"] diff --git a/Data/Engine/services/credentials/credential_service.py b/Data/Engine/services/credentials/credential_service.py new file mode 100644 index 0000000..e141293 --- /dev/null +++ b/Data/Engine/services/credentials/credential_service.py @@ -0,0 +1,29 @@ +"""Expose read access to stored credentials.""" + +from __future__ import annotations + +import logging +from typing import List, Optional + +from Data.Engine.repositories.sqlite.credential_repository import SQLiteCredentialRepository + +__all__ = ["CredentialService"] + + +class CredentialService: + def __init__( + self, + repository: SQLiteCredentialRepository, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._repo = repository + self._log = logger or logging.getLogger("borealis.engine.services.credentials") + + def list_credentials( + self, + *, + site_id: Optional[int] = None, + connection_type: Optional[str] = None, + ) -> List[dict]: + return self._repo.list_credentials(site_id=site_id, connection_type=connection_type) diff --git a/Data/Engine/services/devices/__init__.py b/Data/Engine/services/devices/__init__.py new file mode 100644 index 0000000..d659909 --- /dev/null +++ b/Data/Engine/services/devices/__init__.py @@ -0,0 +1,4 @@ +from .device_inventory_service import DeviceInventoryService, RemoteDeviceError +from .device_view_service import DeviceViewService + +__all__ = ["DeviceInventoryService", "RemoteDeviceError", "DeviceViewService"] diff --git a/Data/Engine/services/devices/device_inventory_service.py b/Data/Engine/services/devices/device_inventory_service.py new file mode 100644 index 0000000..031e789 --- /dev/null +++ b/Data/Engine/services/devices/device_inventory_service.py @@ -0,0 +1,178 @@ +"""Mirrors the legacy device inventory HTTP behaviour.""" + +from __future__ import annotations + +import logging +import sqlite3 +from typing import Dict, List, Optional + +from Data.Engine.repositories.sqlite.device_inventory_repository import ( + SQLiteDeviceInventoryRepository, +) + +__all__ = ["DeviceInventoryService", "RemoteDeviceError"] + + +class RemoteDeviceError(Exception): + def __init__(self, code: str, message: Optional[str] = None) -> None: + super().__init__(message or code) + self.code = code + + +class DeviceInventoryService: + def __init__( + self, + repository: SQLiteDeviceInventoryRepository, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._repo = repository + self._log = logger or logging.getLogger("borealis.engine.services.devices") + + def list_devices(self) -> List[Dict[str, object]]: + return self._repo.fetch_devices() + + def list_agent_devices(self) -> List[Dict[str, object]]: + return self._repo.fetch_devices(only_agents=True) + + def list_remote_devices(self, connection_type: str) -> List[Dict[str, object]]: + return self._repo.fetch_devices(connection_type=connection_type) + + def get_device_by_guid(self, guid: str) -> Optional[Dict[str, object]]: + snapshot = self._repo.load_snapshot(guid=guid) + if not snapshot: + return None + devices = self._repo.fetch_devices(hostname=snapshot.get("hostname")) + return devices[0] if devices else None + + def collect_agent_hash_records(self) -> List[Dict[str, object]]: + records: List[Dict[str, object]] = [] + key_to_index: Dict[str, int] = {} + + for device in self._repo.fetch_devices(): + summary = device.get("summary", {}) if isinstance(device, dict) else {} + agent_id = (summary.get("agent_id") or "").strip() + agent_guid = (summary.get("agent_guid") or "").strip() + hostname = (summary.get("hostname") or device.get("hostname") or "").strip() + agent_hash = (summary.get("agent_hash") or device.get("agent_hash") or "").strip() + + keys: List[str] = [] + if agent_id: + keys.append(f"id:{agent_id.lower()}") + if agent_guid: + keys.append(f"guid:{agent_guid.lower()}") + if hostname: + keys.append(f"host:{hostname.lower()}") + + payload = { + "agent_id": agent_id or None, + "agent_guid": agent_guid or None, + "hostname": hostname or None, + "agent_hash": agent_hash or None, + "source": "database", + } + + if not keys: + records.append(payload) + continue + + existing_index = None + for key in keys: + if key in key_to_index: + existing_index = key_to_index[key] + break + + if existing_index is None: + existing_index = len(records) + records.append(payload) + for key in keys: + key_to_index[key] = existing_index + continue + + merged = records[existing_index] + for key in ("agent_id", "agent_guid", "hostname", "agent_hash"): + if not merged.get(key) and payload.get(key): + merged[key] = payload[key] + + return records + + def upsert_remote_device( + self, + connection_type: str, + hostname: str, + address: Optional[str], + description: Optional[str], + os_hint: Optional[str], + *, + ensure_existing_type: Optional[str], + ) -> Dict[str, object]: + normalized_type = (connection_type or "").strip().lower() + if not normalized_type: + raise RemoteDeviceError("invalid_type", "connection type required") + normalized_host = (hostname or "").strip() + if not normalized_host: + raise RemoteDeviceError("invalid_hostname", "hostname is required") + + existing = self._repo.load_snapshot(hostname=normalized_host) + existing_type = (existing or {}).get("summary", {}).get("connection_type") or "" + existing_type = existing_type.strip().lower() + + if ensure_existing_type and existing_type != ensure_existing_type.lower(): + raise RemoteDeviceError("not_found", "device not found") + if ensure_existing_type is None and existing_type and existing_type != normalized_type: + raise RemoteDeviceError("conflict", "device already exists with different connection type") + + created_ts = None + if existing: + created_ts = existing.get("summary", {}).get("created_at") + + endpoint = (address or "").strip() or (existing or {}).get("summary", {}).get("connection_endpoint") or "" + if not endpoint: + raise RemoteDeviceError("address_required", "address is required") + + description_val = description if description is not None else (existing or {}).get("summary", {}).get("description") + os_value = os_hint or (existing or {}).get("summary", {}).get("operating_system") + os_value = (os_value or "").strip() + + device_type_label = "SSH Remote" if normalized_type == "ssh" else "WinRM Remote" + + summary_payload = { + "connection_type": normalized_type, + "connection_endpoint": endpoint, + "internal_ip": endpoint, + "external_ip": endpoint, + "device_type": device_type_label, + "operating_system": os_value or "", + "last_seen": 0, + "description": (description_val or ""), + } + + try: + self._repo.upsert_device( + normalized_host, + description_val, + {"summary": summary_payload}, + created_ts, + ) + except sqlite3.DatabaseError as exc: # type: ignore[name-defined] + raise RemoteDeviceError("storage_error", str(exc)) from exc + except Exception as exc: # pragma: no cover - defensive + raise RemoteDeviceError("storage_error", str(exc)) from exc + + devices = self._repo.fetch_devices(hostname=normalized_host) + if not devices: + raise RemoteDeviceError("reload_failed", "failed to load device after upsert") + return devices[0] + + def delete_remote_device(self, connection_type: str, hostname: str) -> None: + normalized_host = (hostname or "").strip() + if not normalized_host: + raise RemoteDeviceError("invalid_hostname", "invalid hostname") + existing = self._repo.load_snapshot(hostname=normalized_host) + if not existing: + raise RemoteDeviceError("not_found", "device not found") + existing_type = (existing.get("summary", {}) or {}).get("connection_type") or "" + if (existing_type or "").strip().lower() != (connection_type or "").strip().lower(): + raise RemoteDeviceError("not_found", "device not found") + self._repo.delete_device_by_hostname(normalized_host) + diff --git a/Data/Engine/services/devices/device_view_service.py b/Data/Engine/services/devices/device_view_service.py new file mode 100644 index 0000000..fc4c70f --- /dev/null +++ b/Data/Engine/services/devices/device_view_service.py @@ -0,0 +1,73 @@ +"""Service exposing CRUD for saved device list views.""" + +from __future__ import annotations + +import logging +from typing import List, Optional + +from Data.Engine.domain.device_views import DeviceListView +from Data.Engine.repositories.sqlite.device_view_repository import SQLiteDeviceViewRepository + +__all__ = ["DeviceViewService"] + + +class DeviceViewService: + def __init__( + self, + repository: SQLiteDeviceViewRepository, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._repo = repository + self._log = logger or logging.getLogger("borealis.engine.services.device_views") + + def list_views(self) -> List[DeviceListView]: + return self._repo.list_views() + + def get_view(self, view_id: int) -> Optional[DeviceListView]: + return self._repo.get_view(view_id) + + def create_view(self, name: str, columns: List[str], filters: dict) -> DeviceListView: + normalized_name = (name or "").strip() + if not normalized_name: + raise ValueError("missing_name") + if normalized_name.lower() == "default view": + raise ValueError("reserved") + return self._repo.create_view(normalized_name, list(columns), dict(filters)) + + def update_view( + self, + view_id: int, + *, + name: Optional[str] = None, + columns: Optional[List[str]] = None, + filters: Optional[dict] = None, + ) -> DeviceListView: + updates: dict = {} + if name is not None: + normalized = (name or "").strip() + if not normalized: + raise ValueError("missing_name") + if normalized.lower() == "default view": + raise ValueError("reserved") + updates["name"] = normalized + if columns is not None: + if not isinstance(columns, list) or not all(isinstance(col, str) for col in columns): + raise ValueError("invalid_columns") + updates["columns"] = list(columns) + if filters is not None: + if not isinstance(filters, dict): + raise ValueError("invalid_filters") + updates["filters"] = dict(filters) + if not updates: + raise ValueError("no_fields") + return self._repo.update_view( + view_id, + name=updates.get("name"), + columns=updates.get("columns"), + filters=updates.get("filters"), + ) + + def delete_view(self, view_id: int) -> bool: + return self._repo.delete_view(view_id) + diff --git a/Data/Engine/services/sites/__init__.py b/Data/Engine/services/sites/__init__.py new file mode 100644 index 0000000..6285f25 --- /dev/null +++ b/Data/Engine/services/sites/__init__.py @@ -0,0 +1,3 @@ +from .site_service import SiteService + +__all__ = ["SiteService"] diff --git a/Data/Engine/services/sites/site_service.py b/Data/Engine/services/sites/site_service.py new file mode 100644 index 0000000..1694096 --- /dev/null +++ b/Data/Engine/services/sites/site_service.py @@ -0,0 +1,73 @@ +"""Site management service that mirrors the legacy Flask behaviour.""" + +from __future__ import annotations + +import logging +from typing import Dict, Iterable, List, Optional + +from Data.Engine.domain.sites import SiteDeviceMapping, SiteSummary +from Data.Engine.repositories.sqlite.site_repository import SQLiteSiteRepository + +__all__ = ["SiteService"] + + +class SiteService: + def __init__(self, repository: SQLiteSiteRepository, *, logger: Optional[logging.Logger] = None) -> None: + self._repo = repository + self._log = logger or logging.getLogger("borealis.engine.services.sites") + + def list_sites(self) -> List[SiteSummary]: + return self._repo.list_sites() + + def create_site(self, name: str, description: str) -> SiteSummary: + normalized_name = (name or "").strip() + normalized_description = (description or "").strip() + if not normalized_name: + raise ValueError("missing_name") + try: + return self._repo.create_site(normalized_name, normalized_description) + except ValueError as exc: + if str(exc) == "duplicate": + raise ValueError("duplicate") from exc + raise + + def delete_sites(self, ids: Iterable[int]) -> int: + normalized = [] + for value in ids: + try: + normalized.append(int(value)) + except Exception: + continue + if not normalized: + return 0 + return self._repo.delete_sites(tuple(normalized)) + + def rename_site(self, site_id: int, new_name: str) -> SiteSummary: + normalized_name = (new_name or "").strip() + if not normalized_name: + raise ValueError("missing_name") + try: + return self._repo.rename_site(int(site_id), normalized_name) + except ValueError as exc: + if str(exc) == "duplicate": + raise ValueError("duplicate") from exc + raise + + def map_devices(self, hostnames: Optional[Iterable[str]] = None) -> Dict[str, SiteDeviceMapping]: + return self._repo.map_devices(hostnames) + + def assign_devices(self, site_id: int, hostnames: Iterable[str]) -> None: + try: + numeric_id = int(site_id) + except Exception as exc: + raise ValueError("invalid_site_id") from exc + normalized = [hn for hn in hostnames if isinstance(hn, str) and hn.strip()] + if not normalized: + raise ValueError("invalid_hostnames") + try: + self._repo.assign_devices(numeric_id, normalized) + except LookupError as exc: + if str(exc) == "not_found": + raise LookupError("not_found") from exc + raise + diff --git a/Data/Engine/tests/test_http_sites_devices.py b/Data/Engine/tests/test_http_sites_devices.py new file mode 100644 index 0000000..486d82c --- /dev/null +++ b/Data/Engine/tests/test_http_sites_devices.py @@ -0,0 +1,108 @@ +import sqlite3 +from datetime import datetime, timezone + +import pytest + +pytest.importorskip("flask") + +from .test_http_auth import _login, prepared_app, engine_settings + + +def _ensure_admin_session(client): + _login(client) + + +def test_sites_crud_flow(prepared_app): + client = prepared_app.test_client() + _ensure_admin_session(client) + + resp = client.get("/api/sites") + assert resp.status_code == 200 + assert resp.get_json() == {"sites": []} + + create = client.post("/api/sites", json={"name": "HQ", "description": "Primary"}) + assert create.status_code == 201 + created = create.get_json() + assert created["name"] == "HQ" + + listing = client.get("/api/sites") + sites = listing.get_json()["sites"] + assert len(sites) == 1 + + resp = client.post("/api/sites/assign", json={"site_id": created["id"], "hostnames": ["device-1"]}) + assert resp.status_code == 200 + + mapping = client.get("/api/sites/device_map?hostnames=device-1") + data = mapping.get_json()["mapping"] + assert data["device-1"]["site_id"] == created["id"] + + rename = client.post("/api/sites/rename", json={"id": created["id"], "new_name": "Main"}) + assert rename.status_code == 200 + assert rename.get_json()["name"] == "Main" + + delete = client.post("/api/sites/delete", json={"ids": [created["id"]]}) + assert delete.status_code == 200 + assert delete.get_json()["deleted"] == 1 + + +def test_devices_listing(prepared_app, engine_settings): + client = prepared_app.test_client() + _ensure_admin_session(client) + + now = datetime.now(tz=timezone.utc) + conn = sqlite3.connect(engine_settings.database.path) + cur = conn.cursor() + cur.execute( + """ + INSERT INTO devices ( + guid, + hostname, + description, + created_at, + agent_hash, + last_seen, + connection_type, + connection_endpoint + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + "11111111-1111-1111-1111-111111111111", + "test-device", + "Test Device", + int(now.timestamp()), + "hashvalue", + int(now.timestamp()), + "", + "", + ), + ) + conn.commit() + conn.close() + + resp = client.get("/api/devices") + assert resp.status_code == 200 + devices = resp.get_json()["devices"] + assert any(device["hostname"] == "test-device" for device in devices) + + +def test_agent_hash_list_requires_local_request(prepared_app): + client = prepared_app.test_client() + _ensure_admin_session(client) + + resp = client.get("/api/agent/hash_list", environ_overrides={"REMOTE_ADDR": "203.0.113.5"}) + assert resp.status_code == 403 + + resp = client.get("/api/agent/hash_list", environ_overrides={"REMOTE_ADDR": "127.0.0.1"}) + assert resp.status_code == 200 + assert resp.get_json() == {"agents": []} + + +def test_credentials_list_requires_admin(prepared_app): + client = prepared_app.test_client() + resp = client.get("/api/credentials") + assert resp.status_code == 401 + + _ensure_admin_session(client) + resp = client.get("/api/credentials") + assert resp.status_code == 200 + assert resp.get_json() == {"credentials": []} From 82210408caaed9e359c19c57338efd11677d862a Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Thu, 23 Oct 2025 00:31:48 -0600 Subject: [PATCH 09/12] Add assembly endpoints and approval flows --- Data/Engine/interfaces/http/__init__.py | 4 + Data/Engine/interfaces/http/admin.py | 57 ++ Data/Engine/interfaces/http/assemblies.py | 182 +++++ Data/Engine/interfaces/http/server_info.py | 53 ++ Data/Engine/services/__init__.py | 20 + Data/Engine/services/assemblies/__init__.py | 10 + .../services/assemblies/assembly_service.py | 715 ++++++++++++++++++ Data/Engine/services/container.py | 8 + .../services/enrollment/admin_service.py | 134 +++- Data/Engine/tests/test_http_admin.py | 242 ++++++ Data/Engine/tests/test_http_assemblies.py | 86 +++ 11 files changed, 1510 insertions(+), 1 deletion(-) create mode 100644 Data/Engine/interfaces/http/assemblies.py create mode 100644 Data/Engine/interfaces/http/server_info.py create mode 100644 Data/Engine/services/assemblies/__init__.py create mode 100644 Data/Engine/services/assemblies/assembly_service.py create mode 100644 Data/Engine/tests/test_http_assemblies.py diff --git a/Data/Engine/interfaces/http/__init__.py b/Data/Engine/interfaces/http/__init__.py index 47d62fa..43bfc9c 100644 --- a/Data/Engine/interfaces/http/__init__.py +++ b/Data/Engine/interfaces/http/__init__.py @@ -19,6 +19,8 @@ from . import ( sites, devices, credentials, + assemblies, + server_info, ) _REGISTRARS = ( @@ -34,6 +36,8 @@ _REGISTRARS = ( sites.register, devices.register, credentials.register, + assemblies.register, + server_info.register, ) diff --git a/Data/Engine/interfaces/http/admin.py b/Data/Engine/interfaces/http/admin.py index 30d7fd9..8a52a68 100644 --- a/Data/Engine/interfaces/http/admin.py +++ b/Data/Engine/interfaces/http/admin.py @@ -113,4 +113,61 @@ def list_device_approvals() -> object: return jsonify({"approvals": [record.to_dict() for record in records]}) +@blueprint.route("/device-approvals//approve", methods=["POST"]) +def approve_device_approval(approval_id: str) -> object: + guard = _require_admin() + if guard: + return guard + + payload = request.get_json(silent=True) or {} + guid = payload.get("guid") + resolution_raw = payload.get("conflict_resolution") or payload.get("resolution") + resolution = resolution_raw.strip().lower() if isinstance(resolution_raw, str) else None + + actor = session.get("username") if isinstance(session.get("username"), str) else None + + try: + result = _admin_service().approve_device_approval( + approval_id, + actor=actor, + guid=guid, + conflict_resolution=resolution, + ) + except LookupError: + return jsonify({"error": "not_found"}), 404 + except ValueError as exc: + code = str(exc) + if code == "approval_not_pending": + return jsonify({"error": "approval_not_pending"}), 409 + if code == "conflict_resolution_required": + return jsonify({"error": "conflict_resolution_required"}), 409 + if code == "invalid_guid": + return jsonify({"error": "invalid_guid"}), 400 + raise + + response = jsonify(result.to_dict()) + response.status_code = 200 + return response + + +@blueprint.route("/device-approvals//deny", methods=["POST"]) +def deny_device_approval(approval_id: str) -> object: + guard = _require_admin() + if guard: + return guard + + actor = session.get("username") if isinstance(session.get("username"), str) else None + + try: + result = _admin_service().deny_device_approval(approval_id, actor=actor) + except LookupError: + return jsonify({"error": "not_found"}), 404 + except ValueError as exc: + if str(exc) == "approval_not_pending": + return jsonify({"error": "approval_not_pending"}), 409 + raise + + return jsonify(result.to_dict()) + + __all__ = ["register", "blueprint"] diff --git a/Data/Engine/interfaces/http/assemblies.py b/Data/Engine/interfaces/http/assemblies.py new file mode 100644 index 0000000..7108860 --- /dev/null +++ b/Data/Engine/interfaces/http/assemblies.py @@ -0,0 +1,182 @@ +"""HTTP endpoints for assembly management.""" + +from __future__ import annotations + +from flask import Blueprint, Flask, current_app, jsonify, request + +from Data.Engine.services.container import EngineServiceContainer + +blueprint = Blueprint("engine_assemblies", __name__) + + +def register(app: Flask, _services: EngineServiceContainer) -> None: + if "engine_assemblies" not in app.blueprints: + app.register_blueprint(blueprint) + + +def _services() -> EngineServiceContainer: + services = current_app.extensions.get("engine_services") + if services is None: # pragma: no cover - defensive + raise RuntimeError("engine services not initialized") + return services + + +def _assembly_service(): + return _services().assembly_service + + +def _value_error_response(exc: ValueError): + code = str(exc) + if code == "invalid_island": + return jsonify({"error": "invalid island"}), 400 + if code == "path_required": + return jsonify({"error": "path required"}), 400 + if code == "invalid_kind": + return jsonify({"error": "invalid kind"}), 400 + if code == "invalid_destination": + return jsonify({"error": "invalid destination"}), 400 + if code == "invalid_path": + return jsonify({"error": "invalid path"}), 400 + if code == "cannot_delete_root": + return jsonify({"error": "cannot delete root"}), 400 + return jsonify({"error": code or "invalid request"}), 400 + + +def _not_found_response(exc: FileNotFoundError): + code = str(exc) + if code == "file_not_found": + return jsonify({"error": "file not found"}), 404 + if code == "folder_not_found": + return jsonify({"error": "folder not found"}), 404 + return jsonify({"error": "not found"}), 404 + + +@blueprint.route("/api/assembly/list", methods=["GET"]) +def list_assemblies() -> object: + island = (request.args.get("island") or "").strip() + try: + listing = _assembly_service().list_items(island) + except ValueError as exc: + return _value_error_response(exc) + return jsonify(listing.to_dict()) + + +@blueprint.route("/api/assembly/load", methods=["GET"]) +def load_assembly() -> object: + island = (request.args.get("island") or "").strip() + rel_path = (request.args.get("path") or "").strip() + try: + result = _assembly_service().load_item(island, rel_path) + except ValueError as exc: + return _value_error_response(exc) + except FileNotFoundError as exc: + return _not_found_response(exc) + return jsonify(result.to_dict()) + + +@blueprint.route("/api/assembly/create", methods=["POST"]) +def create_assembly() -> object: + payload = request.get_json(silent=True) or {} + island = (payload.get("island") or "").strip() + kind = (payload.get("kind") or "").strip().lower() + rel_path = (payload.get("path") or "").strip() + content = payload.get("content") + item_type = payload.get("type") + try: + result = _assembly_service().create_item( + island, + kind=kind, + rel_path=rel_path, + content=content, + item_type=item_type if isinstance(item_type, str) else None, + ) + except ValueError as exc: + return _value_error_response(exc) + return jsonify(result.to_dict()) + + +@blueprint.route("/api/assembly/edit", methods=["POST"]) +def edit_assembly() -> object: + payload = request.get_json(silent=True) or {} + island = (payload.get("island") or "").strip() + rel_path = (payload.get("path") or "").strip() + content = payload.get("content") + item_type = payload.get("type") + try: + result = _assembly_service().edit_item( + island, + rel_path=rel_path, + content=content, + item_type=item_type if isinstance(item_type, str) else None, + ) + except ValueError as exc: + return _value_error_response(exc) + except FileNotFoundError as exc: + return _not_found_response(exc) + return jsonify(result.to_dict()) + + +@blueprint.route("/api/assembly/rename", methods=["POST"]) +def rename_assembly() -> object: + payload = request.get_json(silent=True) or {} + island = (payload.get("island") or "").strip() + kind = (payload.get("kind") or "").strip().lower() + rel_path = (payload.get("path") or "").strip() + new_name = (payload.get("new_name") or "").strip() + item_type = payload.get("type") + try: + result = _assembly_service().rename_item( + island, + kind=kind, + rel_path=rel_path, + new_name=new_name, + item_type=item_type if isinstance(item_type, str) else None, + ) + except ValueError as exc: + return _value_error_response(exc) + except FileNotFoundError as exc: + return _not_found_response(exc) + return jsonify(result.to_dict()) + + +@blueprint.route("/api/assembly/move", methods=["POST"]) +def move_assembly() -> object: + payload = request.get_json(silent=True) or {} + island = (payload.get("island") or "").strip() + rel_path = (payload.get("path") or "").strip() + new_path = (payload.get("new_path") or "").strip() + kind = (payload.get("kind") or "").strip().lower() + try: + result = _assembly_service().move_item( + island, + rel_path=rel_path, + new_path=new_path, + kind=kind, + ) + except ValueError as exc: + return _value_error_response(exc) + except FileNotFoundError as exc: + return _not_found_response(exc) + return jsonify(result.to_dict()) + + +@blueprint.route("/api/assembly/delete", methods=["POST"]) +def delete_assembly() -> object: + payload = request.get_json(silent=True) or {} + island = (payload.get("island") or "").strip() + rel_path = (payload.get("path") or "").strip() + kind = (payload.get("kind") or "").strip().lower() + try: + result = _assembly_service().delete_item( + island, + rel_path=rel_path, + kind=kind, + ) + except ValueError as exc: + return _value_error_response(exc) + except FileNotFoundError as exc: + return _not_found_response(exc) + return jsonify(result.to_dict()) + + +__all__ = ["register", "blueprint"] diff --git a/Data/Engine/interfaces/http/server_info.py b/Data/Engine/interfaces/http/server_info.py new file mode 100644 index 0000000..840c53d --- /dev/null +++ b/Data/Engine/interfaces/http/server_info.py @@ -0,0 +1,53 @@ +"""Server metadata endpoints.""" + +from __future__ import annotations + +from datetime import datetime, timezone + +from flask import Blueprint, Flask, jsonify + +from Data.Engine.services.container import EngineServiceContainer + +blueprint = Blueprint("engine_server_info", __name__) + + +def register(app: Flask, _services: EngineServiceContainer) -> None: + if "engine_server_info" not in app.blueprints: + app.register_blueprint(blueprint) + + +@blueprint.route("/api/server/time", methods=["GET"]) +def server_time() -> object: + now_local = datetime.now().astimezone() + now_utc = datetime.now(timezone.utc) + tzinfo = now_local.tzinfo + offset = tzinfo.utcoffset(now_local) if tzinfo else None + + def _ordinal(n: int) -> str: + if 11 <= (n % 100) <= 13: + suffix = "th" + else: + suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th") + return f"{n}{suffix}" + + month = now_local.strftime("%B") + day_disp = _ordinal(now_local.day) + year = now_local.strftime("%Y") + hour24 = now_local.hour + hour12 = hour24 % 12 or 12 + minute = now_local.minute + ampm = "AM" if hour24 < 12 else "PM" + display = f"{month} {day_disp} {year} @ {hour12}:{minute:02d}{ampm}" + + payload = { + "epoch": int(now_local.timestamp()), + "iso": now_local.isoformat(), + "utc_iso": now_utc.isoformat().replace("+00:00", "Z"), + "timezone": str(tzinfo) if tzinfo else "", + "offset_seconds": int(offset.total_seconds()) if offset else 0, + "display": display, + } + return jsonify(payload) + + +__all__ = ["register", "blueprint"] diff --git a/Data/Engine/services/__init__.py b/Data/Engine/services/__init__.py index 22d8e14..8a78a85 100644 --- a/Data/Engine/services/__init__.py +++ b/Data/Engine/services/__init__.py @@ -28,6 +28,10 @@ __all__ = [ "DeviceInventoryService", "DeviceViewService", "CredentialService", + "AssemblyService", + "AssemblyListing", + "AssemblyLoadResult", + "AssemblyMutationResult", ] _LAZY_TARGETS: Dict[str, Tuple[str, str]] = { @@ -65,6 +69,22 @@ _LAZY_TARGETS: Dict[str, Tuple[str, str]] = { "Data.Engine.services.credentials.credential_service", "CredentialService", ), + "AssemblyService": ( + "Data.Engine.services.assemblies.assembly_service", + "AssemblyService", + ), + "AssemblyListing": ( + "Data.Engine.services.assemblies.assembly_service", + "AssemblyListing", + ), + "AssemblyLoadResult": ( + "Data.Engine.services.assemblies.assembly_service", + "AssemblyLoadResult", + ), + "AssemblyMutationResult": ( + "Data.Engine.services.assemblies.assembly_service", + "AssemblyMutationResult", + ), } diff --git a/Data/Engine/services/assemblies/__init__.py b/Data/Engine/services/assemblies/__init__.py new file mode 100644 index 0000000..a49adf0 --- /dev/null +++ b/Data/Engine/services/assemblies/__init__.py @@ -0,0 +1,10 @@ +"""Assembly management services.""" + +from .assembly_service import AssemblyService, AssemblyMutationResult, AssemblyLoadResult, AssemblyListing + +__all__ = [ + "AssemblyService", + "AssemblyMutationResult", + "AssemblyLoadResult", + "AssemblyListing", +] diff --git a/Data/Engine/services/assemblies/assembly_service.py b/Data/Engine/services/assemblies/assembly_service.py new file mode 100644 index 0000000..f5250f6 --- /dev/null +++ b/Data/Engine/services/assemblies/assembly_service.py @@ -0,0 +1,715 @@ +"""Filesystem-backed assembly management service.""" + +from __future__ import annotations + +import base64 +import json +import logging +import os +import re +import shutil +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +__all__ = [ + "AssemblyService", + "AssemblyListing", + "AssemblyLoadResult", + "AssemblyMutationResult", +] + + +@dataclass(frozen=True, slots=True) +class AssemblyListing: + """Listing payload for an assembly island.""" + + root: Path + items: List[Dict[str, Any]] + folders: List[str] + + def to_dict(self) -> dict[str, Any]: + return { + "root": str(self.root), + "items": self.items, + "folders": self.folders, + } + + +@dataclass(frozen=True, slots=True) +class AssemblyLoadResult: + """Container describing a loaded assembly artifact.""" + + payload: Dict[str, Any] + + def to_dict(self) -> Dict[str, Any]: + return dict(self.payload) + + +@dataclass(frozen=True, slots=True) +class AssemblyMutationResult: + """Mutation acknowledgement for create/edit/rename operations.""" + + status: str = "ok" + rel_path: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"status": self.status} + if self.rel_path: + payload["rel_path"] = self.rel_path + return payload + + +class AssemblyService: + """Provide CRUD helpers for workflow/script/ansible assemblies.""" + + _ISLAND_DIR_MAP = { + "workflows": "Workflows", + "workflow": "Workflows", + "scripts": "Scripts", + "script": "Scripts", + "ansible": "Ansible_Playbooks", + "ansible_playbooks": "Ansible_Playbooks", + "ansible-playbooks": "Ansible_Playbooks", + "playbooks": "Ansible_Playbooks", + } + + _SCRIPT_EXTENSIONS = (".json", ".ps1", ".bat", ".sh") + _ANSIBLE_EXTENSIONS = (".json", ".yml") + + def __init__(self, *, root: Path, logger: Optional[logging.Logger] = None) -> None: + self._root = root.resolve() + self._log = logger or logging.getLogger("borealis.engine.services.assemblies") + try: + self._root.mkdir(parents=True, exist_ok=True) + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning("failed to ensure assemblies root %s: %s", self._root, exc) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def list_items(self, island: str) -> AssemblyListing: + root = self._resolve_island_root(island) + root.mkdir(parents=True, exist_ok=True) + + items: List[Dict[str, Any]] = [] + folders: List[str] = [] + + isl = (island or "").strip().lower() + if isl in {"workflows", "workflow"}: + for dirpath, dirnames, filenames in os.walk(root): + rel_root = os.path.relpath(dirpath, root) + if rel_root != ".": + folders.append(rel_root.replace(os.sep, "/")) + for fname in filenames: + if not fname.lower().endswith(".json"): + continue + abs_path = Path(dirpath) / fname + rel_path = abs_path.relative_to(root).as_posix() + try: + mtime = abs_path.stat().st_mtime + except OSError: + mtime = 0.0 + obj = self._safe_read_json(abs_path) + tab = self._extract_tab_name(obj) + items.append( + { + "file_name": fname, + "rel_path": rel_path, + "type": "workflow", + "tab_name": tab, + "last_edited": time.strftime( + "%Y-%m-%dT%H:%M:%S", time.localtime(mtime) + ), + "last_edited_epoch": mtime, + } + ) + elif isl in {"scripts", "script"}: + for dirpath, dirnames, filenames in os.walk(root): + rel_root = os.path.relpath(dirpath, root) + if rel_root != ".": + folders.append(rel_root.replace(os.sep, "/")) + for fname in filenames: + if not fname.lower().endswith(self._SCRIPT_EXTENSIONS): + continue + abs_path = Path(dirpath) / fname + rel_path = abs_path.relative_to(root).as_posix() + try: + mtime = abs_path.stat().st_mtime + except OSError: + mtime = 0.0 + script_type = self._detect_script_type(abs_path) + doc = self._load_assembly_document(abs_path, "scripts", script_type) + items.append( + { + "file_name": fname, + "rel_path": rel_path, + "type": doc.get("type", script_type), + "name": doc.get("name"), + "category": doc.get("category"), + "description": doc.get("description"), + "last_edited": time.strftime( + "%Y-%m-%dT%H:%M:%S", time.localtime(mtime) + ), + "last_edited_epoch": mtime, + } + ) + elif isl in { + "ansible", + "ansible_playbooks", + "ansible-playbooks", + "playbooks", + }: + for dirpath, dirnames, filenames in os.walk(root): + rel_root = os.path.relpath(dirpath, root) + if rel_root != ".": + folders.append(rel_root.replace(os.sep, "/")) + for fname in filenames: + if not fname.lower().endswith(self._ANSIBLE_EXTENSIONS): + continue + abs_path = Path(dirpath) / fname + rel_path = abs_path.relative_to(root).as_posix() + try: + mtime = abs_path.stat().st_mtime + except OSError: + mtime = 0.0 + script_type = self._detect_script_type(abs_path) + doc = self._load_assembly_document(abs_path, "ansible", script_type) + items.append( + { + "file_name": fname, + "rel_path": rel_path, + "type": doc.get("type", "ansible"), + "name": doc.get("name"), + "category": doc.get("category"), + "description": doc.get("description"), + "last_edited": time.strftime( + "%Y-%m-%dT%H:%M:%S", time.localtime(mtime) + ), + "last_edited_epoch": mtime, + } + ) + else: + raise ValueError("invalid_island") + + items.sort(key=lambda entry: entry.get("last_edited_epoch", 0.0), reverse=True) + return AssemblyListing(root=root, items=items, folders=folders) + + def load_item(self, island: str, rel_path: str) -> AssemblyLoadResult: + root, abs_path, _ = self._resolve_assembly_path(island, rel_path) + if not abs_path.is_file(): + raise FileNotFoundError("file_not_found") + + isl = (island or "").strip().lower() + if isl in {"workflows", "workflow"}: + payload = self._safe_read_json(abs_path) + return AssemblyLoadResult(payload=payload) + + doc = self._load_assembly_document(abs_path, island) + rel = abs_path.relative_to(root).as_posix() + payload = { + "file_name": abs_path.name, + "rel_path": rel, + "type": doc.get("type"), + "assembly": doc, + "content": doc.get("script"), + } + return AssemblyLoadResult(payload=payload) + + def create_item( + self, + island: str, + *, + kind: str, + rel_path: str, + content: Any, + item_type: Optional[str] = None, + ) -> AssemblyMutationResult: + root, abs_path, rel_norm = self._resolve_assembly_path(island, rel_path) + if not rel_norm: + raise ValueError("path_required") + + normalized_kind = (kind or "").strip().lower() + if normalized_kind == "folder": + abs_path.mkdir(parents=True, exist_ok=True) + return AssemblyMutationResult() + if normalized_kind != "file": + raise ValueError("invalid_kind") + + target_path = abs_path + if not target_path.suffix: + target_path = target_path.with_suffix( + self._default_ext_for_island(island, item_type or "") + ) + target_path.parent.mkdir(parents=True, exist_ok=True) + + isl = (island or "").strip().lower() + if isl in {"workflows", "workflow"}: + payload = self._ensure_workflow_document(content) + base_name = target_path.stem + payload.setdefault("tab_name", base_name) + self._write_json(target_path, payload) + else: + document = self._normalize_assembly_document( + content, + self._default_type_for_island(island, item_type or ""), + target_path.stem, + ) + self._write_json(target_path, self._prepare_assembly_storage(document)) + + rel_new = target_path.relative_to(root).as_posix() + return AssemblyMutationResult(rel_path=rel_new) + + def edit_item( + self, + island: str, + *, + rel_path: str, + content: Any, + item_type: Optional[str] = None, + ) -> AssemblyMutationResult: + root, abs_path, _ = self._resolve_assembly_path(island, rel_path) + if not abs_path.exists(): + raise FileNotFoundError("file_not_found") + + target_path = abs_path + if not target_path.suffix: + target_path = target_path.with_suffix( + self._default_ext_for_island(island, item_type or "") + ) + + isl = (island or "").strip().lower() + if isl in {"workflows", "workflow"}: + payload = self._ensure_workflow_document(content) + self._write_json(target_path, payload) + else: + document = self._normalize_assembly_document( + content, + self._default_type_for_island(island, item_type or ""), + target_path.stem, + ) + self._write_json(target_path, self._prepare_assembly_storage(document)) + + if target_path != abs_path and abs_path.exists(): + try: + abs_path.unlink() + except OSError: # pragma: no cover - best effort cleanup + pass + + rel_new = target_path.relative_to(root).as_posix() + return AssemblyMutationResult(rel_path=rel_new) + + def rename_item( + self, + island: str, + *, + kind: str, + rel_path: str, + new_name: str, + item_type: Optional[str] = None, + ) -> AssemblyMutationResult: + root, old_path, _ = self._resolve_assembly_path(island, rel_path) + + normalized_kind = (kind or "").strip().lower() + if normalized_kind not in {"file", "folder"}: + raise ValueError("invalid_kind") + + if normalized_kind == "folder": + if not old_path.is_dir(): + raise FileNotFoundError("folder_not_found") + destination = old_path.parent / new_name + else: + if not old_path.is_file(): + raise FileNotFoundError("file_not_found") + candidate = Path(new_name) + if not candidate.suffix: + candidate = candidate.with_suffix( + self._default_ext_for_island(island, item_type or "") + ) + destination = old_path.parent / candidate.name + + destination = destination.resolve() + if not str(destination).startswith(str(root)): + raise ValueError("invalid_destination") + + old_path.rename(destination) + + isl = (island or "").strip().lower() + if normalized_kind == "file" and isl in {"workflows", "workflow"}: + try: + obj = self._safe_read_json(destination) + base_name = destination.stem + for key in ["tabName", "tab_name", "name", "title"]: + if key in obj: + obj[key] = base_name + obj.setdefault("tab_name", base_name) + self._write_json(destination, obj) + except Exception: # pragma: no cover - best effort update + self._log.debug("failed to normalize workflow metadata for %s", destination) + + rel_new = destination.relative_to(root).as_posix() + return AssemblyMutationResult(rel_path=rel_new) + + def move_item( + self, + island: str, + *, + rel_path: str, + new_path: str, + kind: Optional[str] = None, + ) -> AssemblyMutationResult: + root, old_path, _ = self._resolve_assembly_path(island, rel_path) + _, dest_path, _ = self._resolve_assembly_path(island, new_path) + + normalized_kind = (kind or "").strip().lower() + if normalized_kind == "folder": + if not old_path.is_dir(): + raise FileNotFoundError("folder_not_found") + else: + if not old_path.exists(): + raise FileNotFoundError("file_not_found") + + dest_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(old_path), str(dest_path)) + return AssemblyMutationResult() + + def delete_item( + self, + island: str, + *, + rel_path: str, + kind: str, + ) -> AssemblyMutationResult: + _, abs_path, rel_norm = self._resolve_assembly_path(island, rel_path) + if not rel_norm: + raise ValueError("cannot_delete_root") + + normalized_kind = (kind or "").strip().lower() + if normalized_kind == "folder": + if not abs_path.is_dir(): + raise FileNotFoundError("folder_not_found") + shutil.rmtree(abs_path) + elif normalized_kind == "file": + if not abs_path.is_file(): + raise FileNotFoundError("file_not_found") + abs_path.unlink() + else: + raise ValueError("invalid_kind") + + return AssemblyMutationResult() + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _resolve_island_root(self, island: str) -> Path: + key = (island or "").strip().lower() + subdir = self._ISLAND_DIR_MAP.get(key) + if not subdir: + raise ValueError("invalid_island") + root = (self._root / subdir).resolve() + root.mkdir(parents=True, exist_ok=True) + return root + + def _resolve_assembly_path(self, island: str, rel_path: str) -> Tuple[Path, Path, str]: + root = self._resolve_island_root(island) + rel_norm = self._normalize_relpath(rel_path) + abs_path = (root / rel_norm).resolve() + if not str(abs_path).startswith(str(root)): + raise ValueError("invalid_path") + return root, abs_path, rel_norm + + @staticmethod + def _normalize_relpath(value: str) -> str: + return (value or "").replace("\\", "/").strip("/") + + @staticmethod + def _default_ext_for_island(island: str, item_type: str) -> str: + isl = (island or "").strip().lower() + if isl in {"workflows", "workflow"}: + return ".json" + if isl in {"ansible", "ansible_playbooks", "ansible-playbooks", "playbooks"}: + return ".json" + if isl in {"scripts", "script"}: + return ".json" + typ = (item_type or "").strip().lower() + if typ in {"bash", "batch", "powershell"}: + return ".json" + return ".json" + + @staticmethod + def _default_type_for_island(island: str, item_type: str) -> str: + isl = (island or "").strip().lower() + if isl in {"ansible", "ansible_playbooks", "ansible-playbooks", "playbooks"}: + return "ansible" + typ = (item_type or "").strip().lower() + if typ in {"powershell", "batch", "bash", "ansible"}: + return typ + return "powershell" + + @staticmethod + def _empty_assembly_document(default_type: str) -> Dict[str, Any]: + return { + "version": 1, + "name": "", + "description": "", + "category": "application" if default_type.lower() == "ansible" else "script", + "type": default_type or "powershell", + "script": "", + "timeout_seconds": 3600, + "sites": {"mode": "all", "values": []}, + "variables": [], + "files": [], + } + + @staticmethod + def _decode_base64_text(value: Any) -> Optional[str]: + if not isinstance(value, str): + return None + stripped = value.strip() + if not stripped: + return "" + try: + cleaned = re.sub(r"\s+", "", stripped) + except Exception: + cleaned = stripped + try: + decoded = base64.b64decode(cleaned, validate=True) + except Exception: + return None + try: + return decoded.decode("utf-8") + except Exception: + return decoded.decode("utf-8", errors="replace") + + def _decode_script_content(self, value: Any, encoding_hint: str = "") -> str: + encoding = (encoding_hint or "").strip().lower() + if isinstance(value, str): + if encoding in {"base64", "b64", "base-64"}: + decoded = self._decode_base64_text(value) + if decoded is not None: + return decoded.replace("\r\n", "\n") + decoded = self._decode_base64_text(value) + if decoded is not None: + return decoded.replace("\r\n", "\n") + return value.replace("\r\n", "\n") + return "" + + @staticmethod + def _encode_script_content(script_text: Any) -> str: + if not isinstance(script_text, str): + if script_text is None: + script_text = "" + else: + script_text = str(script_text) + normalized = script_text.replace("\r\n", "\n") + if not normalized: + return "" + encoded = base64.b64encode(normalized.encode("utf-8")) + return encoded.decode("ascii") + + def _prepare_assembly_storage(self, document: Dict[str, Any]) -> Dict[str, Any]: + stored: Dict[str, Any] = {} + for key, value in (document or {}).items(): + if key == "script": + stored[key] = self._encode_script_content(value) + else: + stored[key] = value + stored["script_encoding"] = "base64" + return stored + + def _normalize_assembly_document( + self, + obj: Any, + default_type: str, + base_name: str, + ) -> Dict[str, Any]: + doc = self._empty_assembly_document(default_type) + if not isinstance(obj, dict): + obj = {} + base = (base_name or "assembly").strip() + doc["name"] = str(obj.get("name") or obj.get("display_name") or base) + doc["description"] = str(obj.get("description") or "") + category = str(obj.get("category") or doc["category"]).strip().lower() + if category in {"script", "application"}: + doc["category"] = category + typ = str(obj.get("type") or obj.get("script_type") or default_type or "powershell").strip().lower() + if typ in {"powershell", "batch", "bash", "ansible"}: + doc["type"] = typ + script_val = obj.get("script") + content_val = obj.get("content") + script_lines = obj.get("script_lines") + if isinstance(script_lines, list): + try: + doc["script"] = "\n".join(str(line) for line in script_lines) + except Exception: + doc["script"] = "" + elif isinstance(script_val, str): + doc["script"] = script_val + elif isinstance(content_val, str): + doc["script"] = content_val + encoding_hint = str( + obj.get("script_encoding") or obj.get("scriptEncoding") or "" + ).strip().lower() + doc["script"] = self._decode_script_content(doc.get("script"), encoding_hint) + if encoding_hint in {"base64", "b64", "base-64"}: + doc["script_encoding"] = "base64" + else: + probe_source = "" + if isinstance(script_val, str) and script_val: + probe_source = script_val + elif isinstance(content_val, str) and content_val: + probe_source = content_val + decoded_probe = self._decode_base64_text(probe_source) if probe_source else None + if decoded_probe is not None: + doc["script_encoding"] = "base64" + doc["script"] = decoded_probe.replace("\r\n", "\n") + else: + doc["script_encoding"] = "plain" + timeout_val = obj.get("timeout_seconds", obj.get("timeout")) + if timeout_val is not None: + try: + doc["timeout_seconds"] = max(0, int(timeout_val)) + except Exception: + pass + sites = obj.get("sites") if isinstance(obj.get("sites"), dict) else {} + values = sites.get("values") if isinstance(sites.get("values"), list) else [] + mode = str(sites.get("mode") or ("specific" if values else "all")).strip().lower() + if mode not in {"all", "specific"}: + mode = "all" + doc["sites"] = { + "mode": mode, + "values": [ + str(v).strip() + for v in values + if isinstance(v, (str, int, float)) and str(v).strip() + ], + } + vars_in = obj.get("variables") if isinstance(obj.get("variables"), list) else [] + doc_vars: List[Dict[str, Any]] = [] + for entry in vars_in: + if not isinstance(entry, dict): + continue + name = str(entry.get("name") or entry.get("key") or "").strip() + if not name: + continue + vtype = str(entry.get("type") or "string").strip().lower() + if vtype not in {"string", "number", "boolean", "credential"}: + vtype = "string" + default_val = entry.get("default", entry.get("default_value")) + doc_vars.append( + { + "name": name, + "label": str(entry.get("label") or ""), + "type": vtype, + "default": default_val, + "required": bool(entry.get("required")), + "description": str(entry.get("description") or ""), + } + ) + doc["variables"] = doc_vars + files_in = obj.get("files") if isinstance(obj.get("files"), list) else [] + doc_files: List[Dict[str, Any]] = [] + for record in files_in: + if not isinstance(record, dict): + continue + fname = record.get("file_name") or record.get("name") + data = record.get("data") + if not fname or not isinstance(data, str): + continue + size_val = record.get("size") + try: + size_int = int(size_val) + except Exception: + size_int = 0 + doc_files.append( + { + "file_name": str(fname), + "size": size_int, + "mime_type": str(record.get("mime_type") or record.get("mimeType") or ""), + "data": data, + } + ) + doc["files"] = doc_files + try: + doc["version"] = int(obj.get("version") or doc["version"]) + except Exception: + pass + return doc + + def _load_assembly_document( + self, + abs_path: Path, + island: str, + type_hint: str = "", + ) -> Dict[str, Any]: + base_name = abs_path.stem + default_type = self._default_type_for_island(island, type_hint) + if abs_path.suffix.lower() == ".json": + data = self._safe_read_json(abs_path) + return self._normalize_assembly_document(data, default_type, base_name) + try: + content = abs_path.read_text(encoding="utf-8", errors="replace") + except Exception: + content = "" + document = self._empty_assembly_document(default_type) + document["name"] = base_name + document["script"] = (content or "").replace("\r\n", "\n") + if default_type == "ansible": + document["category"] = "application" + return document + + @staticmethod + def _safe_read_json(path: Path) -> Dict[str, Any]: + try: + return json.loads(path.read_text(encoding="utf-8")) + except Exception: + return {} + + @staticmethod + def _extract_tab_name(obj: Dict[str, Any]) -> str: + if not isinstance(obj, dict): + return "" + for key in ["tabName", "tab_name", "name", "title"]: + value = obj.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return "" + + def _detect_script_type(self, path: Path) -> str: + lower = path.name.lower() + if lower.endswith(".json") and path.is_file(): + obj = self._safe_read_json(path) + if isinstance(obj, dict): + typ = str( + obj.get("type") or obj.get("script_type") or "" + ).strip().lower() + if typ in {"powershell", "batch", "bash", "ansible"}: + return typ + return "powershell" + if lower.endswith(".yml"): + return "ansible" + if lower.endswith(".ps1"): + return "powershell" + if lower.endswith(".bat"): + return "batch" + if lower.endswith(".sh"): + return "bash" + return "unknown" + + @staticmethod + def _ensure_workflow_document(content: Any) -> Dict[str, Any]: + payload = content + if isinstance(payload, str): + try: + payload = json.loads(payload) + except Exception: + payload = {} + if not isinstance(payload, dict): + payload = {} + return payload + + @staticmethod + def _write_json(path: Path, payload: Dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2), encoding="utf-8") diff --git a/Data/Engine/services/container.py b/Data/Engine/services/container.py index a544b7c..4b8df94 100644 --- a/Data/Engine/services/container.py +++ b/Data/Engine/services/container.py @@ -44,6 +44,7 @@ from Data.Engine.services.jobs import SchedulerService from Data.Engine.services.rate_limit import SlidingWindowRateLimiter from Data.Engine.services.realtime import AgentRealtimeService from Data.Engine.services.sites import SiteService +from Data.Engine.services.assemblies import AssemblyService __all__ = ["EngineServiceContainer", "build_service_container"] @@ -65,6 +66,7 @@ class EngineServiceContainer: github_service: GitHubService operator_auth_service: OperatorAuthService operator_account_service: OperatorAccountService + assembly_service: AssemblyService def build_service_container( @@ -167,6 +169,11 @@ def build_service_container( logger=log.getChild("sites"), ) + assembly_service = AssemblyService( + root=settings.project_root / "Assemblies", + logger=log.getChild("assemblies"), + ) + github_provider = GitHubArtifactProvider( cache_file=settings.github.cache_file, default_repo=settings.github.default_repo, @@ -197,6 +204,7 @@ def build_service_container( device_view_service=device_view_service, credential_service=credential_service, site_service=site_service, + assembly_service=assembly_service, ) diff --git a/Data/Engine/services/enrollment/admin_service.py b/Data/Engine/services/enrollment/admin_service.py index de8193f..b506205 100644 --- a/Data/Engine/services/enrollment/admin_service.py +++ b/Data/Engine/services/enrollment/admin_service.py @@ -8,11 +8,29 @@ import uuid from datetime import datetime, timedelta, timezone from typing import Callable, List, Optional +from dataclasses import dataclass + +from Data.Engine.domain.device_auth import DeviceGuid, normalize_guid +from Data.Engine.domain.device_enrollment import EnrollmentApprovalStatus from Data.Engine.domain.enrollment_admin import DeviceApprovalRecord, EnrollmentCodeRecord from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository -__all__ = ["EnrollmentAdminService"] +__all__ = ["EnrollmentAdminService", "DeviceApprovalActionResult"] + + +@dataclass(frozen=True, slots=True) +class DeviceApprovalActionResult: + """Outcome metadata returned after mutating an approval.""" + + status: str + conflict_resolution: Optional[str] = None + + def to_dict(self) -> dict[str, str]: + payload = {"status": self.status} + if self.conflict_resolution: + payload["conflict_resolution"] = self.conflict_resolution + return payload class EnrollmentAdminService: @@ -91,6 +109,36 @@ class EnrollmentAdminService: def list_device_approvals(self, *, status: Optional[str] = None) -> List[DeviceApprovalRecord]: return self._repository.list_device_approvals(status=status) + def approve_device_approval( + self, + record_id: str, + *, + actor: Optional[str], + guid: Optional[str] = None, + conflict_resolution: Optional[str] = None, + ) -> DeviceApprovalActionResult: + return self._set_device_approval_status( + record_id, + EnrollmentApprovalStatus.APPROVED, + actor=actor, + guid=guid, + conflict_resolution=conflict_resolution, + ) + + def deny_device_approval( + self, + record_id: str, + *, + actor: Optional[str], + ) -> DeviceApprovalActionResult: + return self._set_device_approval_status( + record_id, + EnrollmentApprovalStatus.DENIED, + actor=actor, + guid=None, + conflict_resolution=None, + ) + # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ @@ -111,3 +159,87 @@ class EnrollmentAdminService: return 10 return count + def _set_device_approval_status( + self, + record_id: str, + status: EnrollmentApprovalStatus, + *, + actor: Optional[str], + guid: Optional[str], + conflict_resolution: Optional[str], + ) -> DeviceApprovalActionResult: + approval = self._repository.fetch_device_approval(record_id) + if approval is None: + raise LookupError("not_found") + + if approval.status is not EnrollmentApprovalStatus.PENDING: + raise ValueError("approval_not_pending") + + normalized_guid = normalize_guid(guid) or (approval.guid.value if approval.guid else "") + resolution_normalized = (conflict_resolution or "").strip().lower() or None + + fingerprint_match = False + conflict_guid: Optional[str] = None + + if status is EnrollmentApprovalStatus.APPROVED: + pending_records = self._repository.list_device_approvals(status="pending") + current_record = next( + (record for record in pending_records if record.record_id == approval.record_id), + None, + ) + + conflict = current_record.hostname_conflict if current_record else None + if conflict: + conflict_guid = normalize_guid(conflict.guid) + fingerprint_match = bool(conflict.fingerprint_match) + + if fingerprint_match: + normalized_guid = conflict_guid or normalized_guid or "" + if resolution_normalized is None: + resolution_normalized = "auto_merge_fingerprint" + elif resolution_normalized == "overwrite": + normalized_guid = conflict_guid or normalized_guid or "" + elif resolution_normalized == "coexist": + pass + else: + raise ValueError("conflict_resolution_required") + + if normalized_guid: + try: + guid_value = DeviceGuid(normalized_guid) + except ValueError as exc: + raise ValueError("invalid_guid") from exc + else: + guid_value = None + + actor_identifier = None + if actor: + actor_identifier = self._users.resolve_identifier(actor) + if not actor_identifier: + actor_identifier = actor.strip() or None + if not actor_identifier: + actor_identifier = "system" + + self._repository.update_device_approval_status( + approval.record_id, + status=status, + updated_at=self._clock(), + approved_by=actor_identifier, + guid=guid_value, + ) + + if status is EnrollmentApprovalStatus.APPROVED: + self._log.info( + "device approval %s approved resolution=%s guid=%s", + approval.record_id, + resolution_normalized or "", + guid_value.value if guid_value else normalized_guid or "", + ) + else: + self._log.info("device approval %s denied", approval.record_id) + + return DeviceApprovalActionResult( + status=status.value, + conflict_resolution=resolution_normalized, + ) + diff --git a/Data/Engine/tests/test_http_admin.py b/Data/Engine/tests/test_http_admin.py index f3e0cc4..aea1b61 100644 --- a/Data/Engine/tests/test_http_admin.py +++ b/Data/Engine/tests/test_http_admin.py @@ -109,3 +109,245 @@ def test_device_approvals_listing(prepared_app, engine_settings): record = next(a for a in approvals if a["id"] == "approval-http") assert record.get("hostname_conflict", {}).get("fingerprint_match") is True + +def test_device_approval_requires_resolution(prepared_app, engine_settings): + client = prepared_app.test_client() + _login(client) + + now = datetime.now(tz=timezone.utc) + conn = sqlite3.connect(engine_settings.database.path) + cur = conn.cursor() + + cur.execute( + """ + INSERT INTO devices ( + guid, + hostname, + created_at, + last_seen, + ssl_key_fingerprint, + status + ) VALUES (?, ?, ?, ?, ?, 'active') + """, + ( + "33333333-3333-3333-3333-333333333333", + "conflict-host", + int(now.timestamp()), + int(now.timestamp()), + "existingfp", + ), + ) + + now_iso = now.isoformat() + cur.execute( + """ + INSERT INTO device_approvals ( + id, + approval_reference, + guid, + hostname_claimed, + ssl_key_fingerprint_claimed, + enrollment_code_id, + status, + client_nonce, + server_nonce, + created_at, + updated_at, + approved_by_user_id, + agent_pubkey_der + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + "approval-conflict", + "REF-CONFLICT", + None, + "conflict-host", + "newfinger", + "code-conflict", + "pending", + base64.b64encode(b"client").decode(), + base64.b64encode(b"server").decode(), + now_iso, + now_iso, + None, + b"pub", + ), + ) + conn.commit() + conn.close() + + resp = client.post("/api/admin/device-approvals/approval-conflict/approve", json={}) + assert resp.status_code == 409 + assert resp.get_json().get("error") == "conflict_resolution_required" + + resp = client.post( + "/api/admin/device-approvals/approval-conflict/approve", + json={"conflict_resolution": "overwrite"}, + ) + assert resp.status_code == 200 + body = resp.get_json() + assert body == {"status": "approved", "conflict_resolution": "overwrite"} + + conn = sqlite3.connect(engine_settings.database.path) + cur = conn.cursor() + cur.execute( + "SELECT status, guid, approved_by_user_id FROM device_approvals WHERE id = ?", + ("approval-conflict",), + ) + row = cur.fetchone() + conn.close() + assert row[0] == "approved" + assert row[1] == "33333333-3333-3333-3333-333333333333" + assert row[2] + + resp = client.post( + "/api/admin/device-approvals/approval-conflict/approve", + json={"conflict_resolution": "overwrite"}, + ) + assert resp.status_code == 409 + assert resp.get_json().get("error") == "approval_not_pending" + + +def test_device_approval_auto_merge(prepared_app, engine_settings): + client = prepared_app.test_client() + _login(client) + + now = datetime.now(tz=timezone.utc) + conn = sqlite3.connect(engine_settings.database.path) + cur = conn.cursor() + + cur.execute( + """ + INSERT INTO devices ( + guid, + hostname, + created_at, + last_seen, + ssl_key_fingerprint, + status + ) VALUES (?, ?, ?, ?, ?, 'active') + """, + ( + "44444444-4444-4444-4444-444444444444", + "merge-host", + int(now.timestamp()), + int(now.timestamp()), + "deadbeef", + ), + ) + + now_iso = now.isoformat() + cur.execute( + """ + INSERT INTO device_approvals ( + id, + approval_reference, + guid, + hostname_claimed, + ssl_key_fingerprint_claimed, + enrollment_code_id, + status, + client_nonce, + server_nonce, + created_at, + updated_at, + approved_by_user_id, + agent_pubkey_der + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + "approval-merge", + "REF-MERGE", + None, + "merge-host", + "deadbeef", + "code-merge", + "pending", + base64.b64encode(b"client").decode(), + base64.b64encode(b"server").decode(), + now_iso, + now_iso, + None, + b"pub", + ), + ) + conn.commit() + conn.close() + + resp = client.post("/api/admin/device-approvals/approval-merge/approve", json={}) + assert resp.status_code == 200 + body = resp.get_json() + assert body.get("status") == "approved" + assert body.get("conflict_resolution") == "auto_merge_fingerprint" + + conn = sqlite3.connect(engine_settings.database.path) + cur = conn.cursor() + cur.execute( + "SELECT guid, status FROM device_approvals WHERE id = ?", + ("approval-merge",), + ) + row = cur.fetchone() + conn.close() + assert row[1] == "approved" + assert row[0] == "44444444-4444-4444-4444-444444444444" + + +def test_device_approval_deny(prepared_app, engine_settings): + client = prepared_app.test_client() + _login(client) + + now = datetime.now(tz=timezone.utc) + conn = sqlite3.connect(engine_settings.database.path) + cur = conn.cursor() + + now_iso = now.isoformat() + cur.execute( + """ + INSERT INTO device_approvals ( + id, + approval_reference, + guid, + hostname_claimed, + ssl_key_fingerprint_claimed, + enrollment_code_id, + status, + client_nonce, + server_nonce, + created_at, + updated_at, + approved_by_user_id, + agent_pubkey_der + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + "approval-deny", + "REF-DENY", + None, + "deny-host", + "cafebabe", + "code-deny", + "pending", + base64.b64encode(b"client").decode(), + base64.b64encode(b"server").decode(), + now_iso, + now_iso, + None, + b"pub", + ), + ) + conn.commit() + conn.close() + + resp = client.post("/api/admin/device-approvals/approval-deny/deny", json={}) + assert resp.status_code == 200 + assert resp.get_json() == {"status": "denied"} + + conn = sqlite3.connect(engine_settings.database.path) + cur = conn.cursor() + cur.execute( + "SELECT status FROM device_approvals WHERE id = ?", + ("approval-deny",), + ) + row = cur.fetchone() + conn.close() + assert row[0] == "denied" diff --git a/Data/Engine/tests/test_http_assemblies.py b/Data/Engine/tests/test_http_assemblies.py new file mode 100644 index 0000000..81fd57e --- /dev/null +++ b/Data/Engine/tests/test_http_assemblies.py @@ -0,0 +1,86 @@ +import pytest + +pytest.importorskip("flask") + +from .test_http_auth import _login, prepared_app + + +def test_assembly_crud_flow(prepared_app, engine_settings): + client = prepared_app.test_client() + _login(client) + + resp = client.post( + "/api/assembly/create", + json={"island": "scripts", "kind": "folder", "path": "Utilities"}, + ) + assert resp.status_code == 200 + + resp = client.post( + "/api/assembly/create", + json={ + "island": "scripts", + "kind": "file", + "path": "Utilities/sample", + "content": {"name": "Sample", "script": "Write-Output 'Hello'", "type": "powershell"}, + }, + ) + assert resp.status_code == 200 + body = resp.get_json() + rel_path = body.get("rel_path") + assert rel_path and rel_path.endswith(".json") + + resp = client.get("/api/assembly/list?island=scripts") + assert resp.status_code == 200 + listing = resp.get_json() + assert any(item["rel_path"] == rel_path for item in listing.get("items", [])) + + resp = client.get(f"/api/assembly/load?island=scripts&path={rel_path}") + assert resp.status_code == 200 + loaded = resp.get_json() + assert loaded.get("assembly", {}).get("name") == "Sample" + + resp = client.post( + "/api/assembly/rename", + json={ + "island": "scripts", + "kind": "file", + "path": rel_path, + "new_name": "renamed", + }, + ) + assert resp.status_code == 200 + renamed_rel = resp.get_json().get("rel_path") + assert renamed_rel and renamed_rel.endswith(".json") + + resp = client.post( + "/api/assembly/move", + json={ + "island": "scripts", + "path": renamed_rel, + "new_path": "Utilities/Nested/renamed.json", + "kind": "file", + }, + ) + assert resp.status_code == 200 + + resp = client.post( + "/api/assembly/delete", + json={ + "island": "scripts", + "path": "Utilities/Nested/renamed.json", + "kind": "file", + }, + ) + assert resp.status_code == 200 + + resp = client.get("/api/assembly/list?island=scripts") + remaining = resp.get_json().get("items", []) + assert all(item["rel_path"] != "Utilities/Nested/renamed.json" for item in remaining) + + +def test_server_time_endpoint(prepared_app): + client = prepared_app.test_client() + resp = client.get("/api/server/time") + assert resp.status_code == 200 + body = resp.get_json() + assert set(["epoch", "iso", "utc_iso", "timezone", "offset_seconds", "display"]).issubset(body) From fddf0230e2e61753e0fc8a576a5580290d662138 Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Thu, 23 Oct 2025 01:01:15 -0600 Subject: [PATCH 10/12] Add agent REST endpoints and heartbeat handling --- Data/Engine/interfaces/http/__init__.py | 2 + Data/Engine/interfaces/http/agent.py | 113 +++++++++ Data/Engine/services/container.py | 6 +- .../devices/device_inventory_service.py | 128 +++++++++- Data/Engine/tests/test_http_agent.py | 234 ++++++++++++++++++ 5 files changed, 480 insertions(+), 3 deletions(-) create mode 100644 Data/Engine/interfaces/http/agent.py create mode 100644 Data/Engine/tests/test_http_agent.py diff --git a/Data/Engine/interfaces/http/__init__.py b/Data/Engine/interfaces/http/__init__.py index 43bfc9c..a428f1b 100644 --- a/Data/Engine/interfaces/http/__init__.py +++ b/Data/Engine/interfaces/http/__init__.py @@ -8,6 +8,7 @@ from Data.Engine.services.container import EngineServiceContainer from . import ( admin, + agent, agents, auth, enrollment, @@ -25,6 +26,7 @@ from . import ( _REGISTRARS = ( health.register, + agent.register, agents.register, enrollment.register, tokens.register, diff --git a/Data/Engine/interfaces/http/agent.py b/Data/Engine/interfaces/http/agent.py new file mode 100644 index 0000000..1d415db --- /dev/null +++ b/Data/Engine/interfaces/http/agent.py @@ -0,0 +1,113 @@ +"""Agent REST endpoints for device communication.""" + +from __future__ import annotations + +import math +from functools import wraps +from typing import Any, Callable, Optional, TypeVar, cast + +from flask import Blueprint, Flask, current_app, g, jsonify, request + +from Data.Engine.builders.device_auth import DeviceAuthRequestBuilder +from Data.Engine.domain.device_auth import DeviceAuthContext, DeviceAuthFailure +from Data.Engine.services.container import EngineServiceContainer +from Data.Engine.services.devices.device_inventory_service import DeviceHeartbeatError + +AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context" + +blueprint = Blueprint("engine_agent", __name__) + +F = TypeVar("F", bound=Callable[..., Any]) + + +def _services() -> EngineServiceContainer: + return cast(EngineServiceContainer, current_app.extensions["engine_services"]) + + +def require_device_auth(func: F) -> F: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any): + services = _services() + builder = ( + DeviceAuthRequestBuilder() + .with_authorization(request.headers.get("Authorization")) + .with_http_method(request.method) + .with_htu(request.url) + .with_service_context(request.headers.get(AGENT_CONTEXT_HEADER)) + .with_dpop_proof(request.headers.get("DPoP")) + ) + try: + auth_request = builder.build() + context = services.device_auth.authenticate(auth_request, path=request.path) + except DeviceAuthFailure as exc: + payload = exc.to_dict() + response = jsonify(payload) + if exc.retry_after is not None: + response.headers["Retry-After"] = str(int(math.ceil(exc.retry_after))) + return response, exc.http_status + + g.device_auth = context + try: + return func(*args, **kwargs) + finally: + g.pop("device_auth", None) + + return cast(F, wrapper) + + +def register(app: Flask, _services: EngineServiceContainer) -> None: + if "engine_agent" not in app.blueprints: + app.register_blueprint(blueprint) + + +@blueprint.route("/api/agent/heartbeat", methods=["POST"]) +@require_device_auth +def heartbeat() -> Any: + services = _services() + payload = request.get_json(force=True, silent=True) or {} + context = cast(DeviceAuthContext, g.device_auth) + + try: + services.device_inventory.record_heartbeat(context=context, payload=payload) + except DeviceHeartbeatError as exc: + error_payload = {"error": exc.code} + if exc.code == "device_not_registered": + return jsonify(error_payload), 404 + if exc.code == "storage_conflict": + return jsonify(error_payload), 409 + current_app.logger.exception( + "device-heartbeat-error guid=%s code=%s", context.identity.guid.value, exc.code + ) + return jsonify(error_payload), 500 + + return jsonify({"status": "ok", "poll_after_ms": 15000}) + + +@blueprint.route("/api/agent/script/request", methods=["POST"]) +@require_device_auth +def script_request() -> Any: + services = _services() + context = cast(DeviceAuthContext, g.device_auth) + + signing_key: Optional[str] = None + signer = services.script_signer + if signer is not None: + try: + signing_key = signer.public_base64_spki() + except Exception as exc: # pragma: no cover - defensive logging + current_app.logger.warning("script-signer-unavailable: %s", exc) + + status = "quarantined" if context.is_quarantined else "idle" + poll_after = 60000 if context.is_quarantined else 30000 + + response = { + "status": status, + "poll_after_ms": poll_after, + "sig_alg": "ed25519", + } + if signing_key: + response["signing_key"] = signing_key + return jsonify(response) + + +__all__ = ["register", "blueprint", "heartbeat", "script_request", "require_device_auth"] diff --git a/Data/Engine/services/container.py b/Data/Engine/services/container.py index 4b8df94..9a0d06b 100644 --- a/Data/Engine/services/container.py +++ b/Data/Engine/services/container.py @@ -67,6 +67,7 @@ class EngineServiceContainer: operator_auth_service: OperatorAuthService operator_account_service: OperatorAccountService assembly_service: AssemblyService + script_signer: Optional[ScriptSigner] def build_service_container( @@ -106,6 +107,8 @@ def build_service_container( logger=log.getChild("token_service"), ) + script_signer = _load_script_signer(log) + enrollment_service = EnrollmentService( device_repository=device_repo, enrollment_repository=enrollment_repo, @@ -115,7 +118,7 @@ def build_service_container( ip_rate_limiter=SlidingWindowRateLimiter(), fingerprint_rate_limiter=SlidingWindowRateLimiter(), nonce_cache=NonceCache(), - script_signer=_load_script_signer(log), + script_signer=script_signer, logger=log.getChild("enrollment"), ) @@ -205,6 +208,7 @@ def build_service_container( credential_service=credential_service, site_service=site_service, assembly_service=assembly_service, + script_signer=script_signer, ) diff --git a/Data/Engine/services/devices/device_inventory_service.py b/Data/Engine/services/devices/device_inventory_service.py index 031e789..ec146f0 100644 --- a/Data/Engine/services/devices/device_inventory_service.py +++ b/Data/Engine/services/devices/device_inventory_service.py @@ -4,13 +4,17 @@ from __future__ import annotations import logging import sqlite3 -from typing import Dict, List, Optional +import time +from collections.abc import Mapping +from typing import Any, Dict, List, Optional from Data.Engine.repositories.sqlite.device_inventory_repository import ( SQLiteDeviceInventoryRepository, ) +from Data.Engine.domain.device_auth import DeviceAuthContext +from Data.Engine.domain.devices import clean_device_str, coerce_int -__all__ = ["DeviceInventoryService", "RemoteDeviceError"] +__all__ = ["DeviceInventoryService", "RemoteDeviceError", "DeviceHeartbeatError"] class RemoteDeviceError(Exception): @@ -19,6 +23,12 @@ class RemoteDeviceError(Exception): self.code = code +class DeviceHeartbeatError(Exception): + def __init__(self, code: str, message: Optional[str] = None) -> None: + super().__init__(message or code) + self.code = code + + class DeviceInventoryService: def __init__( self, @@ -176,3 +186,117 @@ class DeviceInventoryService: raise RemoteDeviceError("not_found", "device not found") self._repo.delete_device_by_hostname(normalized_host) + # ------------------------------------------------------------------ + # Agent heartbeats + # ------------------------------------------------------------------ + def record_heartbeat( + self, + *, + context: DeviceAuthContext, + payload: Mapping[str, Any], + ) -> None: + guid = context.identity.guid.value + snapshot = self._repo.load_snapshot(guid=guid) + if not snapshot: + raise DeviceHeartbeatError("device_not_registered", "device not registered") + + summary = dict(snapshot.get("summary") or {}) + details = dict(snapshot.get("details") or {}) + + now_ts = int(time.time()) + summary["last_seen"] = now_ts + summary["agent_guid"] = guid + + existing_hostname = clean_device_str(summary.get("hostname")) or clean_device_str( + snapshot.get("hostname") + ) + incoming_hostname = clean_device_str(payload.get("hostname")) + raw_metrics = payload.get("metrics") + metrics = raw_metrics if isinstance(raw_metrics, Mapping) else {} + metrics_hostname = clean_device_str(metrics.get("hostname")) if metrics else None + hostname = incoming_hostname or metrics_hostname or existing_hostname + if not hostname: + hostname = f"RECOVERED-{guid[:12]}" + summary["hostname"] = hostname + + if metrics: + last_user = metrics.get("last_user") or metrics.get("username") + if last_user: + cleaned_user = clean_device_str(last_user) + if cleaned_user: + summary["last_user"] = cleaned_user + operating_system = metrics.get("operating_system") + if operating_system: + cleaned_os = clean_device_str(operating_system) + if cleaned_os: + summary["operating_system"] = cleaned_os + uptime = metrics.get("uptime") + if uptime is not None: + coerced = coerce_int(uptime) + if coerced is not None: + summary["uptime"] = coerced + agent_id = metrics.get("agent_id") + if agent_id: + cleaned_agent = clean_device_str(agent_id) + if cleaned_agent: + summary["agent_id"] = cleaned_agent + + for field in ("external_ip", "internal_ip", "device_type"): + value = payload.get(field) + cleaned = clean_device_str(value) + if cleaned: + summary[field] = cleaned + + summary.setdefault("description", summary.get("description") or "") + created_at = coerce_int(summary.get("created_at")) + if created_at is None: + created_at = coerce_int(snapshot.get("created_at")) + if created_at is None: + created_at = now_ts + summary["created_at"] = created_at + + raw_inventory = payload.get("inventory") + inventory = raw_inventory if isinstance(raw_inventory, Mapping) else {} + memory = inventory.get("memory") if isinstance(inventory.get("memory"), list) else details.get("memory") + network = inventory.get("network") if isinstance(inventory.get("network"), list) else details.get("network") + software = ( + inventory.get("software") if isinstance(inventory.get("software"), list) else details.get("software") + ) + storage = inventory.get("storage") if isinstance(inventory.get("storage"), list) else details.get("storage") + cpu = inventory.get("cpu") if isinstance(inventory.get("cpu"), Mapping) else details.get("cpu") + + merged_details: Dict[str, Any] = { + "summary": summary, + "memory": memory, + "network": network, + "software": software, + "storage": storage, + "cpu": cpu, + } + + try: + self._repo.upsert_device( + summary["hostname"], + summary.get("description"), + merged_details, + summary.get("created_at"), + agent_hash=clean_device_str(summary.get("agent_hash")), + guid=guid, + ) + except sqlite3.IntegrityError as exc: + self._log.warning( + "device-heartbeat-conflict guid=%s hostname=%s error=%s", + guid, + summary["hostname"], + exc, + ) + raise DeviceHeartbeatError("storage_conflict", str(exc)) from exc + except Exception as exc: # pragma: no cover - defensive + self._log.exception( + "device-heartbeat-failure guid=%s hostname=%s", + guid, + summary["hostname"], + exc_info=exc, + ) + raise DeviceHeartbeatError("storage_error", "failed to persist heartbeat") from exc + diff --git a/Data/Engine/tests/test_http_agent.py b/Data/Engine/tests/test_http_agent.py new file mode 100644 index 0000000..8ca499e --- /dev/null +++ b/Data/Engine/tests/test_http_agent.py @@ -0,0 +1,234 @@ +import pytest + +pytest.importorskip("jwt") + +import json +import sqlite3 +import time +from datetime import datetime, timezone +from pathlib import Path + +from Data.Engine.config.environment import ( + DatabaseSettings, + EngineSettings, + FlaskSettings, + GitHubSettings, + ServerSettings, + SocketIOSettings, +) +from Data.Engine.domain.device_auth import ( + AccessTokenClaims, + DeviceAuthContext, + DeviceFingerprint, + DeviceGuid, + DeviceIdentity, + DeviceStatus, +) +from Data.Engine.interfaces.http import register_http_interfaces +from Data.Engine.repositories.sqlite import connection as sqlite_connection +from Data.Engine.repositories.sqlite import migrations as sqlite_migrations +from Data.Engine.server import create_app +from Data.Engine.services.container import build_service_container + + +@pytest.fixture() +def engine_settings(tmp_path: Path) -> EngineSettings: + project_root = tmp_path + static_root = project_root / "static" + static_root.mkdir() + (static_root / "index.html").write_text("", encoding="utf-8") + + database_path = project_root / "database.db" + + return EngineSettings( + project_root=project_root, + debug=False, + database=DatabaseSettings(path=database_path, apply_migrations=False), + flask=FlaskSettings( + secret_key="test-key", + static_root=static_root, + cors_allowed_origins=("https://localhost",), + ), + socketio=SocketIOSettings(cors_allowed_origins=("https://localhost",)), + server=ServerSettings(host="127.0.0.1", port=5000), + github=GitHubSettings( + default_repo="owner/repo", + default_branch="main", + refresh_interval_seconds=60, + cache_root=project_root / "cache", + ), + ) + + +@pytest.fixture() +def prepared_app(engine_settings: EngineSettings): + settings = engine_settings + settings.github.cache_root.mkdir(exist_ok=True, parents=True) + + db_factory = sqlite_connection.connection_factory(settings.database.path) + with sqlite_connection.connection_scope(settings.database.path) as conn: + sqlite_migrations.apply_all(conn) + + app = create_app(settings, db_factory=db_factory) + services = build_service_container(settings, db_factory=db_factory) + app.extensions["engine_services"] = services + register_http_interfaces(app, services) + app.config.update(TESTING=True) + return app + + +def _insert_device(app, guid: str, fingerprint: str, hostname: str) -> None: + db_path = Path(app.config["ENGINE_DATABASE_PATH"]) + now = int(time.time()) + with sqlite3.connect(db_path) as conn: + conn.execute( + """ + INSERT INTO devices ( + guid, + hostname, + created_at, + last_seen, + ssl_key_fingerprint, + token_version, + status, + key_added_at + ) VALUES (?, ?, ?, ?, ?, ?, 'active', ?) + """, + ( + guid, + hostname, + now, + now, + fingerprint.lower(), + 1, + datetime.now(timezone.utc).isoformat(), + ), + ) + conn.commit() + + +def _build_context(guid: str, fingerprint: str, *, status: DeviceStatus = DeviceStatus.ACTIVE) -> DeviceAuthContext: + now = int(time.time()) + claims = AccessTokenClaims( + subject="device", + guid=DeviceGuid(guid), + fingerprint=DeviceFingerprint(fingerprint), + token_version=1, + issued_at=now, + not_before=now, + expires_at=now + 600, + raw={"sub": "device"}, + ) + identity = DeviceIdentity(DeviceGuid(guid), DeviceFingerprint(fingerprint)) + return DeviceAuthContext( + identity=identity, + access_token="token", + claims=claims, + status=status, + service_context="SYSTEM", + ) + + +def test_heartbeat_updates_device(prepared_app, monkeypatch): + client = prepared_app.test_client() + guid = "DE305D54-75B4-431B-ADB2-EB6B9E546014" + fingerprint = "aa:bb:cc" + hostname = "device-heartbeat" + _insert_device(prepared_app, guid, fingerprint, hostname) + + services = prepared_app.extensions["engine_services"] + context = _build_context(guid, fingerprint) + monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context) + + payload = { + "hostname": hostname, + "inventory": {"memory": [{"total": "16GB"}], "cpu": {"cores": 8}}, + "metrics": {"operating_system": "Windows", "last_user": "Admin", "uptime": 120}, + "external_ip": "1.2.3.4", + } + + start = int(time.time()) + resp = client.post( + "/api/agent/heartbeat", + json=payload, + headers={"Authorization": "Bearer token"}, + ) + assert resp.status_code == 200 + body = resp.get_json() + assert body == {"status": "ok", "poll_after_ms": 15000} + + db_path = Path(prepared_app.config["ENGINE_DATABASE_PATH"]) + with sqlite3.connect(db_path) as conn: + row = conn.execute( + "SELECT last_seen, external_ip, memory, cpu FROM devices WHERE guid = ?", + (guid,), + ).fetchone() + + assert row is not None + last_seen, external_ip, memory_json, cpu_json = row + assert last_seen >= start + assert external_ip == "1.2.3.4" + assert json.loads(memory_json)[0]["total"] == "16GB" + assert json.loads(cpu_json)["cores"] == 8 + + +def test_heartbeat_returns_404_when_device_missing(prepared_app, monkeypatch): + client = prepared_app.test_client() + guid = "9E295C27-8339-40C8-AD1A-6ED95C164A4A" + fingerprint = "11:22:33" + services = prepared_app.extensions["engine_services"] + context = _build_context(guid, fingerprint) + monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context) + + resp = client.post( + "/api/agent/heartbeat", + json={"hostname": "missing-device"}, + headers={"Authorization": "Bearer token"}, + ) + assert resp.status_code == 404 + assert resp.get_json() == {"error": "device_not_registered"} + + +def test_script_request_reports_status_and_signing_key(prepared_app, monkeypatch): + client = prepared_app.test_client() + guid = "2F8D76C0-38D4-4700-B247-3E90C03A67D7" + fingerprint = "44:55:66" + hostname = "device-script" + _insert_device(prepared_app, guid, fingerprint, hostname) + + services = prepared_app.extensions["engine_services"] + context = _build_context(guid, fingerprint) + monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context) + + class DummySigner: + def public_base64_spki(self) -> str: + return "PUBKEY" + + object.__setattr__(services, "script_signer", DummySigner()) + + resp = client.post( + "/api/agent/script/request", + json={"guid": guid}, + headers={"Authorization": "Bearer token"}, + ) + assert resp.status_code == 200 + body = resp.get_json() + assert body == { + "status": "idle", + "poll_after_ms": 30000, + "sig_alg": "ed25519", + "signing_key": "PUBKEY", + } + + quarantined_context = _build_context(guid, fingerprint, status=DeviceStatus.QUARANTINED) + monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: quarantined_context) + + resp = client.post( + "/api/agent/script/request", + json={}, + headers={"Authorization": "Bearer token"}, + ) + assert resp.status_code == 200 + assert resp.get_json()["status"] == "quarantined" + assert resp.get_json()["poll_after_ms"] == 60000 + From 40cab79f2186645320e1ea50372c4614890d5fd9 Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Thu, 23 Oct 2025 01:51:27 -0600 Subject: [PATCH 11/12] Restore agent detail ingestion and device description updates --- Data/Engine/domain/devices.py | 4 + Data/Engine/interfaces/http/agent.py | 39 +++- Data/Engine/interfaces/http/devices.py | 20 +- .../sqlite/device_inventory_repository.py | 59 ++++- Data/Engine/services/devices/__init__.py | 15 +- .../devices/device_inventory_service.py | 205 +++++++++++++++++- Data/Engine/tests/test_http_agent.py | 98 +++++++++ Data/Engine/tests/test_http_sites_devices.py | 45 +++- 8 files changed, 473 insertions(+), 12 deletions(-) diff --git a/Data/Engine/domain/devices.py b/Data/Engine/domain/devices.py index 5c292c2..b369169 100644 --- a/Data/Engine/domain/devices.py +++ b/Data/Engine/domain/devices.py @@ -228,6 +228,10 @@ def assemble_device_snapshot(record: Mapping[str, Any]) -> Dict[str, Any]: "agent_guid": record.get("guid") or record.get("agent_guid") or "", "connection_type": record.get("connection_type") or "", "connection_endpoint": record.get("connection_endpoint") or "", + "ssl_key_fingerprint": record.get("ssl_key_fingerprint") or "", + "status": record.get("status") or "", + "token_version": record.get("token_version") or 0, + "key_added_at": record.get("key_added_at") or "", "created_at": record.get("created_at") or 0, } diff --git a/Data/Engine/interfaces/http/agent.py b/Data/Engine/interfaces/http/agent.py index 1d415db..811a939 100644 --- a/Data/Engine/interfaces/http/agent.py +++ b/Data/Engine/interfaces/http/agent.py @@ -11,7 +11,10 @@ from flask import Blueprint, Flask, current_app, g, jsonify, request from Data.Engine.builders.device_auth import DeviceAuthRequestBuilder from Data.Engine.domain.device_auth import DeviceAuthContext, DeviceAuthFailure from Data.Engine.services.container import EngineServiceContainer -from Data.Engine.services.devices.device_inventory_service import DeviceHeartbeatError +from Data.Engine.services.devices.device_inventory_service import ( + DeviceDetailsError, + DeviceHeartbeatError, +) AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context" @@ -110,4 +113,36 @@ def script_request() -> Any: return jsonify(response) -__all__ = ["register", "blueprint", "heartbeat", "script_request", "require_device_auth"] +@blueprint.route("/api/agent/details", methods=["POST"]) +@require_device_auth +def save_details() -> Any: + services = _services() + payload = request.get_json(force=True, silent=True) or {} + context = cast(DeviceAuthContext, g.device_auth) + + try: + services.device_inventory.save_agent_details(context=context, payload=payload) + except DeviceDetailsError as exc: + error_payload = {"error": exc.code} + if exc.code == "invalid_payload": + return jsonify(error_payload), 400 + if exc.code in {"fingerprint_mismatch", "guid_mismatch"}: + return jsonify(error_payload), 403 + if exc.code == "device_not_registered": + return jsonify(error_payload), 404 + current_app.logger.exception( + "device-details-error guid=%s code=%s", context.identity.guid.value, exc.code + ) + return jsonify(error_payload), 500 + + return jsonify({"status": "ok"}) + + +__all__ = [ + "register", + "blueprint", + "heartbeat", + "script_request", + "save_details", + "require_device_auth", +] diff --git a/Data/Engine/interfaces/http/devices.py b/Data/Engine/interfaces/http/devices.py index e618aa8..4c10c2c 100644 --- a/Data/Engine/interfaces/http/devices.py +++ b/Data/Engine/interfaces/http/devices.py @@ -5,7 +5,7 @@ from ipaddress import ip_address from flask import Blueprint, Flask, current_app, jsonify, request, session from Data.Engine.services.container import EngineServiceContainer -from Data.Engine.services.devices import RemoteDeviceError +from Data.Engine.services.devices import DeviceDescriptionError, RemoteDeviceError blueprint = Blueprint("engine_devices", __name__) @@ -64,6 +64,24 @@ def get_device_by_guid(guid: str) -> object: return jsonify(device) +@blueprint.route("/api/device/description/", methods=["POST"]) +def set_device_description(hostname: str) -> object: + payload = request.get_json(silent=True) or {} + description = payload.get("description") + try: + _inventory().update_device_description(hostname, description) + except DeviceDescriptionError as exc: + if exc.code == "invalid_hostname": + return jsonify({"error": "invalid hostname"}), 400 + if exc.code == "not_found": + return jsonify({"error": "not found"}), 404 + current_app.logger.exception( + "device-description-error host=%s code=%s", hostname, exc.code + ) + return jsonify({"error": "internal error"}), 500 + return jsonify({"status": "ok"}) + + @blueprint.route("/api/agent_devices", methods=["GET"]) def list_agent_devices() -> object: guard = _require_admin() diff --git a/Data/Engine/repositories/sqlite/device_inventory_repository.py b/Data/Engine/repositories/sqlite/device_inventory_repository.py index 8ae5767..9a50a9e 100644 --- a/Data/Engine/repositories/sqlite/device_inventory_repository.py +++ b/Data/Engine/repositories/sqlite/device_inventory_repository.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging import sqlite3 import time +import uuid from contextlib import closing from typing import Any, Dict, List, Optional, Tuple @@ -158,8 +159,12 @@ class SQLiteDeviceInventoryRepository: agent_id, ansible_ee_ver, connection_type, - connection_endpoint - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) + connection_endpoint, + ssl_key_fingerprint, + token_version, + status, + key_added_at + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) ON CONFLICT(hostname) DO UPDATE SET description=excluded.description, created_at=COALESCE({DEVICE_TABLE}.created_at, excluded.created_at), @@ -182,7 +187,11 @@ class SQLiteDeviceInventoryRepository: agent_id=COALESCE(NULLIF(excluded.agent_id, ''), {DEVICE_TABLE}.agent_id), ansible_ee_ver=COALESCE(NULLIF(excluded.ansible_ee_ver, ''), {DEVICE_TABLE}.ansible_ee_ver), connection_type=COALESCE(NULLIF(excluded.connection_type, ''), {DEVICE_TABLE}.connection_type), - connection_endpoint=COALESCE(NULLIF(excluded.connection_endpoint, ''), {DEVICE_TABLE}.connection_endpoint) + connection_endpoint=COALESCE(NULLIF(excluded.connection_endpoint, ''), {DEVICE_TABLE}.connection_endpoint), + ssl_key_fingerprint=COALESCE(NULLIF(excluded.ssl_key_fingerprint, ''), {DEVICE_TABLE}.ssl_key_fingerprint), + token_version=COALESCE(NULLIF(excluded.token_version, 0), {DEVICE_TABLE}.token_version), + status=COALESCE(NULLIF(excluded.status, ''), {DEVICE_TABLE}.status), + key_added_at=COALESCE(NULLIF(excluded.key_added_at, ''), {DEVICE_TABLE}.key_added_at) """ params: List[Any] = [ @@ -209,6 +218,10 @@ class SQLiteDeviceInventoryRepository: column_values.get("ansible_ee_ver"), column_values.get("connection_type"), column_values.get("connection_endpoint"), + column_values.get("ssl_key_fingerprint"), + column_values.get("token_version"), + column_values.get("status"), + column_values.get("key_added_at"), ] with closing(self._connections()) as conn: @@ -223,6 +236,42 @@ class SQLiteDeviceInventoryRepository: cur.execute(f"DELETE FROM {DEVICE_TABLE} WHERE hostname = ?", (hostname,)) conn.commit() + def record_device_fingerprint(self, guid: Optional[str], fingerprint: Optional[str], added_at: str) -> None: + normalized_guid = clean_device_str(guid) + normalized_fp = clean_device_str(fingerprint) + if not normalized_guid or not normalized_fp: + return + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at) + VALUES (?, ?, ?, ?) + """, + (str(uuid.uuid4()), normalized_guid, normalized_fp.lower(), added_at), + ) + cur.execute( + """ + UPDATE device_keys + SET retired_at = ? + WHERE guid = ? + AND ssl_key_fingerprint != ? + AND retired_at IS NULL + """, + (added_at, normalized_guid, normalized_fp.lower()), + ) + cur.execute( + """ + UPDATE devices + SET ssl_key_fingerprint = COALESCE(LOWER(?), ssl_key_fingerprint), + key_added_at = COALESCE(key_added_at, ?) + WHERE LOWER(guid) = LOWER(?) + """, + (normalized_fp, added_at, normalized_guid), + ) + conn.commit() + def _extract_device_columns(self, details: Dict[str, Any]) -> Dict[str, Any]: summary = details.get("summary") or {} payload: Dict[str, Any] = {} @@ -250,4 +299,8 @@ class SQLiteDeviceInventoryRepository: payload["connection_endpoint"] = clean_device_str( summary.get("connection_endpoint") or summary.get("endpoint") ) + payload["ssl_key_fingerprint"] = clean_device_str(summary.get("ssl_key_fingerprint")) + payload["token_version"] = coerce_int(summary.get("token_version")) or 0 + payload["status"] = clean_device_str(summary.get("status")) + payload["key_added_at"] = clean_device_str(summary.get("key_added_at")) return payload diff --git a/Data/Engine/services/devices/__init__.py b/Data/Engine/services/devices/__init__.py index d659909..ddef61e 100644 --- a/Data/Engine/services/devices/__init__.py +++ b/Data/Engine/services/devices/__init__.py @@ -1,4 +1,15 @@ -from .device_inventory_service import DeviceInventoryService, RemoteDeviceError +from .device_inventory_service import ( + DeviceDescriptionError, + DeviceDetailsError, + DeviceInventoryService, + RemoteDeviceError, +) from .device_view_service import DeviceViewService -__all__ = ["DeviceInventoryService", "RemoteDeviceError", "DeviceViewService"] +__all__ = [ + "DeviceInventoryService", + "RemoteDeviceError", + "DeviceViewService", + "DeviceDetailsError", + "DeviceDescriptionError", +] diff --git a/Data/Engine/services/devices/device_inventory_service.py b/Data/Engine/services/devices/device_inventory_service.py index ec146f0..9252494 100644 --- a/Data/Engine/services/devices/device_inventory_service.py +++ b/Data/Engine/services/devices/device_inventory_service.py @@ -2,19 +2,27 @@ from __future__ import annotations +import json import logging import sqlite3 import time +from datetime import datetime, timezone from collections.abc import Mapping from typing import Any, Dict, List, Optional from Data.Engine.repositories.sqlite.device_inventory_repository import ( SQLiteDeviceInventoryRepository, ) -from Data.Engine.domain.device_auth import DeviceAuthContext +from Data.Engine.domain.device_auth import DeviceAuthContext, normalize_guid from Data.Engine.domain.devices import clean_device_str, coerce_int -__all__ = ["DeviceInventoryService", "RemoteDeviceError", "DeviceHeartbeatError"] +__all__ = [ + "DeviceInventoryService", + "RemoteDeviceError", + "DeviceHeartbeatError", + "DeviceDetailsError", + "DeviceDescriptionError", +] class RemoteDeviceError(Exception): @@ -29,6 +37,18 @@ class DeviceHeartbeatError(Exception): self.code = code +class DeviceDetailsError(Exception): + def __init__(self, code: str, message: Optional[str] = None) -> None: + super().__init__(message or code) + self.code = code + + +class DeviceDescriptionError(Exception): + def __init__(self, code: str, message: Optional[str] = None) -> None: + super().__init__(message or code) + self.code = code + + class DeviceInventoryService: def __init__( self, @@ -220,7 +240,7 @@ class DeviceInventoryService: summary["hostname"] = hostname if metrics: - last_user = metrics.get("last_user") or metrics.get("username") + last_user = metrics.get("last_user") if last_user: cleaned_user = clean_device_str(last_user) if cleaned_user: @@ -300,3 +320,182 @@ class DeviceInventoryService: ) raise DeviceHeartbeatError("storage_error", "failed to persist heartbeat") from exc + # ------------------------------------------------------------------ + # Agent details + # ------------------------------------------------------------------ + @staticmethod + def _is_empty(value: Any) -> bool: + return value in (None, "", [], {}) + + @classmethod + def _deep_merge_preserve(cls, prev: Dict[str, Any], incoming: Dict[str, Any]) -> Dict[str, Any]: + merged: Dict[str, Any] = dict(prev or {}) + for key, value in (incoming or {}).items(): + if isinstance(value, Mapping): + existing = merged.get(key) + if not isinstance(existing, Mapping): + existing = {} + merged[key] = cls._deep_merge_preserve(dict(existing), dict(value)) + elif isinstance(value, list): + if value: + merged[key] = value + else: + if cls._is_empty(value): + continue + merged[key] = value + return merged + + def save_agent_details( + self, + *, + context: DeviceAuthContext, + payload: Mapping[str, Any], + ) -> None: + hostname = clean_device_str(payload.get("hostname")) + details_raw = payload.get("details") + agent_id = clean_device_str(payload.get("agent_id")) + agent_hash = clean_device_str(payload.get("agent_hash")) + + if not isinstance(details_raw, Mapping): + raise DeviceDetailsError("invalid_payload", "details object required") + + details_dict: Dict[str, Any] + try: + details_dict = json.loads(json.dumps(details_raw)) + except Exception: + details_dict = dict(details_raw) + + incoming_summary = dict(details_dict.get("summary") or {}) + if not hostname: + hostname = clean_device_str(incoming_summary.get("hostname")) + if not hostname: + raise DeviceDetailsError("invalid_payload", "hostname required") + + snapshot = self._repo.load_snapshot(hostname=hostname) + if not snapshot: + snapshot = {} + + previous_details = snapshot.get("details") + if isinstance(previous_details, Mapping): + try: + prev_details = json.loads(json.dumps(previous_details)) + except Exception: + prev_details = dict(previous_details) + else: + prev_details = {} + + prev_summary = dict(prev_details.get("summary") or {}) + + existing_guid = clean_device_str(snapshot.get("guid") or snapshot.get("summary", {}).get("agent_guid")) + normalized_existing_guid = normalize_guid(existing_guid) + auth_guid = context.identity.guid.value + + if normalized_existing_guid and normalized_existing_guid != auth_guid: + raise DeviceDetailsError("guid_mismatch", "device guid mismatch") + + fingerprint = context.identity.fingerprint.value.lower() + stored_fp = clean_device_str(snapshot.get("summary", {}).get("ssl_key_fingerprint")) + if stored_fp and stored_fp.lower() != fingerprint: + raise DeviceDetailsError("fingerprint_mismatch", "device fingerprint mismatch") + + incoming_summary.setdefault("hostname", hostname) + if agent_id and not incoming_summary.get("agent_id"): + incoming_summary["agent_id"] = agent_id + if agent_hash: + incoming_summary["agent_hash"] = agent_hash + incoming_summary["agent_guid"] = auth_guid + if fingerprint: + incoming_summary["ssl_key_fingerprint"] = fingerprint + if not incoming_summary.get("last_seen") and prev_summary.get("last_seen"): + incoming_summary["last_seen"] = prev_summary.get("last_seen") + + details_dict["summary"] = incoming_summary + merged_details = self._deep_merge_preserve(prev_details, details_dict) + merged_summary = merged_details.setdefault("summary", {}) + + if not merged_summary.get("last_user") and prev_summary.get("last_user"): + merged_summary["last_user"] = prev_summary.get("last_user") + + created_at = coerce_int(merged_summary.get("created_at")) + if created_at is None: + created_at = coerce_int(snapshot.get("created_at")) + if created_at is None: + created_at = int(time.time()) + merged_summary["created_at"] = created_at + + if fingerprint: + merged_summary["ssl_key_fingerprint"] = fingerprint + if not merged_summary.get("key_added_at"): + merged_summary["key_added_at"] = datetime.now(timezone.utc).isoformat() + if merged_summary.get("token_version") is None: + merged_summary["token_version"] = 1 + if not merged_summary.get("status") and snapshot.get("summary", {}).get("status"): + merged_summary["status"] = snapshot.get("summary", {}).get("status") + + description = clean_device_str(merged_summary.get("description")) + existing_description = snapshot.get("description") if snapshot else "" + description_to_store = description if description is not None else (existing_description or "") + + existing_hash = clean_device_str(snapshot.get("agent_hash") or snapshot.get("summary", {}).get("agent_hash")) + effective_hash = agent_hash or existing_hash + + try: + self._repo.upsert_device( + hostname, + description_to_store, + merged_details, + created_at, + agent_hash=effective_hash, + guid=auth_guid, + ) + except sqlite3.DatabaseError as exc: + raise DeviceDetailsError("storage_error", str(exc)) from exc + + added_at = merged_summary.get("key_added_at") or datetime.now(timezone.utc).isoformat() + self._repo.record_device_fingerprint(auth_guid, fingerprint, added_at) + + # ------------------------------------------------------------------ + # Description management + # ------------------------------------------------------------------ + def update_device_description(self, hostname: str, description: Optional[str]) -> None: + normalized_host = clean_device_str(hostname) + if not normalized_host: + raise DeviceDescriptionError("invalid_hostname", "invalid hostname") + + snapshot = self._repo.load_snapshot(hostname=normalized_host) + if not snapshot: + raise DeviceDescriptionError("not_found", "device not found") + + details = snapshot.get("details") + if isinstance(details, Mapping): + try: + existing = json.loads(json.dumps(details)) + except Exception: + existing = dict(details) + else: + existing = {} + + summary = dict(existing.get("summary") or {}) + summary["description"] = description or "" + existing["summary"] = summary + + created_at = coerce_int(summary.get("created_at")) + if created_at is None: + created_at = coerce_int(snapshot.get("created_at")) + if created_at is None: + created_at = int(time.time()) + + agent_hash = clean_device_str(summary.get("agent_hash") or snapshot.get("agent_hash")) + guid = clean_device_str(summary.get("agent_guid") or snapshot.get("guid")) + + try: + self._repo.upsert_device( + normalized_host, + description or (snapshot.get("description") or ""), + existing, + created_at, + agent_hash=agent_hash, + guid=guid, + ) + except sqlite3.DatabaseError as exc: + raise DeviceDescriptionError("storage_error", str(exc)) from exc diff --git a/Data/Engine/tests/test_http_agent.py b/Data/Engine/tests/test_http_agent.py index 8ca499e..885ca6c 100644 --- a/Data/Engine/tests/test_http_agent.py +++ b/Data/Engine/tests/test_http_agent.py @@ -232,3 +232,101 @@ def test_script_request_reports_status_and_signing_key(prepared_app, monkeypatch assert resp.get_json()["status"] == "quarantined" assert resp.get_json()["poll_after_ms"] == 60000 + +def test_agent_details_persists_inventory(prepared_app, monkeypatch): + client = prepared_app.test_client() + guid = "5C9D76E4-4C5A-4A5D-9B5D-1C2E3F4A5B6C" + fingerprint = "aa:bb:cc:dd" + hostname = "device-details" + _insert_device(prepared_app, guid, fingerprint, hostname) + + services = prepared_app.extensions["engine_services"] + context = _build_context(guid, fingerprint) + monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context) + + payload = { + "hostname": hostname, + "agent_id": "AGENT-01", + "agent_hash": "hash-value", + "details": { + "summary": { + "hostname": hostname, + "device_type": "Laptop", + "last_user": "BUNNY-LAB\\nicole.rappe", + "operating_system": "Windows 11", + "description": "Primary workstation", + }, + "memory": [{"slot": "DIMM0", "capacity": 17179869184}], + "storage": [{"model": "NVMe", "size": 512}], + "network": [{"adapter": "Ethernet", "ips": ["192.168.1.50"]}], + }, + } + + resp = client.post( + "/api/agent/details", + json=payload, + headers={"Authorization": "Bearer token"}, + ) + + assert resp.status_code == 200 + assert resp.get_json() == {"status": "ok"} + + db_path = Path(prepared_app.config["ENGINE_DATABASE_PATH"]) + with sqlite3.connect(db_path) as conn: + row = conn.execute( + """ + SELECT device_type, last_user, memory, storage, network, description + FROM devices + WHERE guid = ? + """, + (guid,), + ).fetchone() + + assert row is not None + device_type, last_user, memory_json, storage_json, network_json, description = row + assert device_type == "Laptop" + assert last_user == "BUNNY-LAB\\nicole.rappe" + assert description == "Primary workstation" + assert json.loads(memory_json)[0]["capacity"] == 17179869184 + assert json.loads(storage_json)[0]["model"] == "NVMe" + assert json.loads(network_json)[0]["ips"][0] == "192.168.1.50" + + +def test_heartbeat_preserves_last_user_from_details(prepared_app, monkeypatch): + client = prepared_app.test_client() + guid = "7E8F90A1-B2C3-4D5E-8F90-A1B2C3D4E5F6" + fingerprint = "11:22:33:44" + hostname = "device-preserve" + _insert_device(prepared_app, guid, fingerprint, hostname) + + services = prepared_app.extensions["engine_services"] + context = _build_context(guid, fingerprint) + monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context) + + client.post( + "/api/agent/details", + json={ + "hostname": hostname, + "details": { + "summary": {"hostname": hostname, "last_user": "BUNNY-LAB\\nicole.rappe"} + }, + }, + headers={"Authorization": "Bearer token"}, + ) + + client.post( + "/api/agent/heartbeat", + json={"hostname": hostname, "metrics": {"uptime": 120}}, + headers={"Authorization": "Bearer token"}, + ) + + db_path = Path(prepared_app.config["ENGINE_DATABASE_PATH"]) + with sqlite3.connect(db_path) as conn: + row = conn.execute( + "SELECT last_user FROM devices WHERE guid = ?", + (guid,), + ).fetchone() + + assert row is not None + assert row[0] == "BUNNY-LAB\\nicole.rappe" + diff --git a/Data/Engine/tests/test_http_sites_devices.py b/Data/Engine/tests/test_http_sites_devices.py index 486d82c..0925449 100644 --- a/Data/Engine/tests/test_http_sites_devices.py +++ b/Data/Engine/tests/test_http_sites_devices.py @@ -1,5 +1,6 @@ -import sqlite3 from datetime import datetime, timezone +import sqlite3 +import time import pytest @@ -106,3 +107,45 @@ def test_credentials_list_requires_admin(prepared_app): resp = client.get("/api/credentials") assert resp.status_code == 200 assert resp.get_json() == {"credentials": []} + + +def test_device_description_update(prepared_app, engine_settings): + client = prepared_app.test_client() + hostname = "device-desc" + guid = "A3D3F1E5-9B8C-4C6F-80F1-4D5E6F7A8B9C" + + now = int(time.time()) + conn = sqlite3.connect(engine_settings.database.path) + cur = conn.cursor() + cur.execute( + """ + INSERT INTO devices ( + guid, + hostname, + description, + created_at, + last_seen + ) VALUES (?, ?, '', ?, ?) + """, + (guid, hostname, now, now), + ) + conn.commit() + conn.close() + + resp = client.post( + f"/api/device/description/{hostname}", + json={"description": "Primary workstation"}, + ) + + assert resp.status_code == 200 + assert resp.get_json() == {"status": "ok"} + + conn = sqlite3.connect(engine_settings.database.path) + row = conn.execute( + "SELECT description FROM devices WHERE hostname = ?", + (hostname,), + ).fetchone() + conn.close() + + assert row is not None + assert row[0] == "Primary workstation" From 0a9a626c5665efe4d39c5399325199bd64d0f299 Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Thu, 23 Oct 2025 04:36:24 -0600 Subject: [PATCH 12/12] Restore device summary fields and assembly data flow --- Data/Engine/domain/devices.py | 86 ++++++++++++------- .../sqlite/device_inventory_repository.py | 56 +++++++++--- .../devices/device_inventory_service.py | 14 ++- Data/Engine/tests/test_http_agent.py | 53 ++++++++++++ 4 files changed, 166 insertions(+), 43 deletions(-) diff --git a/Data/Engine/domain/devices.py b/Data/Engine/domain/devices.py index b369169..0264f9a 100644 --- a/Data/Engine/domain/devices.py +++ b/Data/Engine/domain/devices.py @@ -7,6 +7,8 @@ from dataclasses import dataclass from datetime import datetime, timezone from typing import Any, Dict, List, Mapping, Optional, Sequence +from Data.Engine.domain.device_auth import normalize_guid + __all__ = [ "DEVICE_TABLE_COLUMNS", "DEVICE_TABLE", @@ -91,8 +93,13 @@ class DeviceSnapshot: operating_system: str uptime: int agent_id: str + ansible_ee_ver: str connection_type: str connection_endpoint: str + ssl_key_fingerprint: str + token_version: int + status: str + key_added_at: str details: Dict[str, Any] summary: Dict[str, Any] @@ -121,8 +128,13 @@ class DeviceSnapshot: "operating_system": self.operating_system, "uptime": self.uptime, "agent_id": self.agent_id, + "ansible_ee_ver": self.ansible_ee_ver, "connection_type": self.connection_type, "connection_endpoint": self.connection_endpoint, + "ssl_key_fingerprint": self.ssl_key_fingerprint, + "token_version": self.token_version, + "status": self.status, + "key_added_at": self.key_added_at, "details": self.details, "summary": self.summary, } @@ -211,33 +223,16 @@ def row_to_device_dict(row: Sequence[Any], columns: Sequence[str]) -> Dict[str, def assemble_device_snapshot(record: Mapping[str, Any]) -> Dict[str, Any]: - summary = { - "hostname": record.get("hostname") or "", - "description": record.get("description") or "", - "device_type": record.get("device_type") or "", - "domain": record.get("domain") or "", - "external_ip": record.get("external_ip") or "", - "internal_ip": record.get("internal_ip") or "", - "last_reboot": record.get("last_reboot") or "", - "last_seen": record.get("last_seen") or 0, - "last_user": record.get("last_user") or "", - "operating_system": record.get("operating_system") or "", - "uptime": record.get("uptime") or 0, - "agent_id": record.get("agent_id") or "", - "agent_hash": record.get("agent_hash") or "", - "agent_guid": record.get("guid") or record.get("agent_guid") or "", - "connection_type": record.get("connection_type") or "", - "connection_endpoint": record.get("connection_endpoint") or "", - "ssl_key_fingerprint": record.get("ssl_key_fingerprint") or "", - "status": record.get("status") or "", - "token_version": record.get("token_version") or 0, - "key_added_at": record.get("key_added_at") or "", - "created_at": record.get("created_at") or 0, - } + hostname = clean_device_str(record.get("hostname")) or "" + description = clean_device_str(record.get("description")) or "" + agent_hash = clean_device_str(record.get("agent_hash")) or "" + raw_guid = clean_device_str(record.get("guid")) + normalized_guid = normalize_guid(raw_guid) - created_ts = coerce_int(summary.get("created_at")) or 0 - last_seen_ts = coerce_int(summary.get("last_seen")) or 0 - uptime_val = coerce_int(summary.get("uptime")) or 0 + created_ts = coerce_int(record.get("created_at")) or 0 + last_seen_ts = coerce_int(record.get("last_seen")) or 0 + uptime_val = coerce_int(record.get("uptime")) or 0 + token_version = coerce_int(record.get("token_version")) or 0 parsed_lists = { key: _parse_device_json(record.get(key), default) @@ -245,20 +240,48 @@ def assemble_device_snapshot(record: Mapping[str, Any]) -> Dict[str, Any]: } cpu_obj = _parse_device_json(record.get("cpu"), DEVICE_JSON_OBJECT_FIELDS["cpu"]) + summary: Dict[str, Any] = { + "hostname": hostname, + "description": description, + "agent_hash": agent_hash, + "agent_guid": normalized_guid or "", + "agent_id": clean_device_str(record.get("agent_id")) or "", + "device_type": clean_device_str(record.get("device_type")) or "", + "domain": clean_device_str(record.get("domain")) or "", + "external_ip": clean_device_str(record.get("external_ip")) or "", + "internal_ip": clean_device_str(record.get("internal_ip")) or "", + "last_reboot": clean_device_str(record.get("last_reboot")) or "", + "last_seen": last_seen_ts, + "last_user": clean_device_str(record.get("last_user")) or "", + "operating_system": clean_device_str(record.get("operating_system")) or "", + "uptime": uptime_val, + "uptime_sec": uptime_val, + "ansible_ee_ver": clean_device_str(record.get("ansible_ee_ver")) or "", + "connection_type": clean_device_str(record.get("connection_type")) or "", + "connection_endpoint": clean_device_str(record.get("connection_endpoint")) or "", + "ssl_key_fingerprint": clean_device_str(record.get("ssl_key_fingerprint")) or "", + "status": clean_device_str(record.get("status")) or "", + "token_version": token_version, + "key_added_at": clean_device_str(record.get("key_added_at")) or "", + "created_at": created_ts, + "created": ts_to_human(created_ts), + } + details = { "memory": parsed_lists["memory"], "network": parsed_lists["network"], "software": parsed_lists["software"], "storage": parsed_lists["storage"], "cpu": cpu_obj, + "summary": dict(summary), } payload: Dict[str, Any] = { - "hostname": summary["hostname"], - "description": summary.get("description", ""), + "hostname": hostname, + "description": description, "created_at": created_ts, "created_at_iso": ts_to_iso(created_ts), - "agent_hash": summary.get("agent_hash", ""), + "agent_hash": agent_hash, "agent_guid": summary.get("agent_guid", ""), "guid": summary.get("agent_guid", ""), "memory": parsed_lists["memory"], @@ -277,8 +300,13 @@ def assemble_device_snapshot(record: Mapping[str, Any]) -> Dict[str, Any]: "operating_system": summary.get("operating_system", ""), "uptime": uptime_val, "agent_id": summary.get("agent_id", ""), + "ansible_ee_ver": summary.get("ansible_ee_ver", ""), "connection_type": summary.get("connection_type", ""), "connection_endpoint": summary.get("connection_endpoint", ""), + "ssl_key_fingerprint": summary.get("ssl_key_fingerprint", ""), + "token_version": summary.get("token_version", 0), + "status": summary.get("status", ""), + "key_added_at": summary.get("key_added_at", ""), "details": details, "summary": summary, } diff --git a/Data/Engine/repositories/sqlite/device_inventory_repository.py b/Data/Engine/repositories/sqlite/device_inventory_repository.py index 9a50a9e..6aa839d 100644 --- a/Data/Engine/repositories/sqlite/device_inventory_repository.py +++ b/Data/Engine/repositories/sqlite/device_inventory_repository.py @@ -278,28 +278,60 @@ class SQLiteDeviceInventoryRepository: for field in ("memory", "network", "software", "storage"): payload[field] = serialize_device_json(details.get(field), []) payload["cpu"] = serialize_device_json(summary.get("cpu") or details.get("cpu"), {}) - payload["device_type"] = clean_device_str(summary.get("device_type") or summary.get("type")) - payload["domain"] = clean_device_str(summary.get("domain")) - payload["external_ip"] = clean_device_str(summary.get("external_ip") or summary.get("public_ip")) - payload["internal_ip"] = clean_device_str(summary.get("internal_ip") or summary.get("private_ip")) - payload["last_reboot"] = clean_device_str(summary.get("last_reboot") or summary.get("last_boot")) - payload["last_seen"] = coerce_int(summary.get("last_seen")) + payload["device_type"] = clean_device_str( + summary.get("device_type") + or summary.get("type") + or summary.get("device_class") + ) + payload["domain"] = clean_device_str( + summary.get("domain") or summary.get("domain_name") + ) + payload["external_ip"] = clean_device_str( + summary.get("external_ip") or summary.get("public_ip") + ) + payload["internal_ip"] = clean_device_str( + summary.get("internal_ip") or summary.get("private_ip") + ) + payload["last_reboot"] = clean_device_str( + summary.get("last_reboot") or summary.get("last_boot") + ) + payload["last_seen"] = coerce_int( + summary.get("last_seen") or summary.get("last_seen_epoch") + ) payload["last_user"] = clean_device_str( summary.get("last_user") or summary.get("last_user_name") or summary.get("logged_in_user") + or summary.get("username") + or summary.get("user") ) payload["operating_system"] = clean_device_str( - summary.get("operating_system") or summary.get("os") + summary.get("operating_system") + or summary.get("agent_operating_system") + or summary.get("os") ) - payload["uptime"] = coerce_int(summary.get("uptime")) + uptime_value = ( + summary.get("uptime_sec") + or summary.get("uptime_seconds") + or summary.get("uptime") + ) + payload["uptime"] = coerce_int(uptime_value) payload["agent_id"] = clean_device_str(summary.get("agent_id")) payload["ansible_ee_ver"] = clean_device_str(summary.get("ansible_ee_ver")) - payload["connection_type"] = clean_device_str(summary.get("connection_type")) - payload["connection_endpoint"] = clean_device_str( - summary.get("connection_endpoint") or summary.get("endpoint") + payload["connection_type"] = clean_device_str( + summary.get("connection_type") or summary.get("remote_type") + ) + payload["connection_endpoint"] = clean_device_str( + summary.get("connection_endpoint") + or summary.get("endpoint") + or summary.get("connection_address") + or summary.get("address") + or summary.get("external_ip") + or summary.get("internal_ip") + ) + payload["ssl_key_fingerprint"] = clean_device_str( + summary.get("ssl_key_fingerprint") ) - payload["ssl_key_fingerprint"] = clean_device_str(summary.get("ssl_key_fingerprint")) payload["token_version"] = coerce_int(summary.get("token_version")) or 0 payload["status"] = clean_device_str(summary.get("status")) payload["key_added_at"] = clean_device_str(summary.get("key_added_at")) diff --git a/Data/Engine/services/devices/device_inventory_service.py b/Data/Engine/services/devices/device_inventory_service.py index 9252494..e06208e 100644 --- a/Data/Engine/services/devices/device_inventory_service.py +++ b/Data/Engine/services/devices/device_inventory_service.py @@ -14,7 +14,7 @@ from Data.Engine.repositories.sqlite.device_inventory_repository import ( SQLiteDeviceInventoryRepository, ) from Data.Engine.domain.device_auth import DeviceAuthContext, normalize_guid -from Data.Engine.domain.devices import clean_device_str, coerce_int +from Data.Engine.domain.devices import clean_device_str, coerce_int, ts_to_human __all__ = [ "DeviceInventoryService", @@ -240,7 +240,7 @@ class DeviceInventoryService: summary["hostname"] = hostname if metrics: - last_user = metrics.get("last_user") + last_user = metrics.get("last_user") or metrics.get("username") or metrics.get("user") if last_user: cleaned_user = clean_device_str(last_user) if cleaned_user: @@ -422,6 +422,8 @@ class DeviceInventoryService: if created_at is None: created_at = int(time.time()) merged_summary["created_at"] = created_at + if not merged_summary.get("created"): + merged_summary["created"] = ts_to_human(created_at) if fingerprint: merged_summary["ssl_key_fingerprint"] = fingerprint @@ -431,6 +433,14 @@ class DeviceInventoryService: merged_summary["token_version"] = 1 if not merged_summary.get("status") and snapshot.get("summary", {}).get("status"): merged_summary["status"] = snapshot.get("summary", {}).get("status") + uptime_val = merged_summary.get("uptime") + if merged_summary.get("uptime_sec") is None and uptime_val is not None: + coerced = coerce_int(uptime_val) + if coerced is not None: + merged_summary["uptime_sec"] = coerced + merged_summary.setdefault("uptime_seconds", coerced) + if merged_summary.get("uptime_seconds") is None and merged_summary.get("uptime_sec") is not None: + merged_summary["uptime_seconds"] = merged_summary.get("uptime_sec") description = clean_device_str(merged_summary.get("description")) existing_description = snapshot.get("description") if snapshot else "" diff --git a/Data/Engine/tests/test_http_agent.py b/Data/Engine/tests/test_http_agent.py index 885ca6c..0d16e9f 100644 --- a/Data/Engine/tests/test_http_agent.py +++ b/Data/Engine/tests/test_http_agent.py @@ -255,10 +255,14 @@ def test_agent_details_persists_inventory(prepared_app, monkeypatch): "last_user": "BUNNY-LAB\\nicole.rappe", "operating_system": "Windows 11", "description": "Primary workstation", + "last_reboot": "2025-10-01 10:00:00", + "uptime": 3600, }, "memory": [{"slot": "DIMM0", "capacity": 17179869184}], "storage": [{"model": "NVMe", "size": 512}], "network": [{"adapter": "Ethernet", "ips": ["192.168.1.50"]}], + "software": [{"name": "Borealis Agent", "version": "2.0"}], + "cpu": {"name": "Intel Core i7", "logical_cores": 8, "base_clock_ghz": 3.4}, }, } @@ -291,6 +295,26 @@ def test_agent_details_persists_inventory(prepared_app, monkeypatch): assert json.loads(storage_json)[0]["model"] == "NVMe" assert json.loads(network_json)[0]["ips"][0] == "192.168.1.50" + resp = client.get("/api/devices") + assert resp.status_code == 200 + listing = resp.get_json() + device = next((dev for dev in listing.get("devices", []) if dev["hostname"] == hostname), None) + assert device is not None + summary = device["summary"] + details = device["details"] + + assert summary["device_type"] == "Laptop" + assert summary["last_user"] == "BUNNY-LAB\\nicole.rappe" + assert summary["created"] + assert summary.get("uptime_sec") == 3600 + assert details["summary"]["device_type"] == "Laptop" + assert details["summary"]["last_reboot"] == "2025-10-01 10:00:00" + assert details["summary"]["created"] == summary["created"] + assert details["software"][0]["name"] == "Borealis Agent" + assert device["storage"][0]["model"] == "NVMe" + assert device["memory"][0]["capacity"] == 17179869184 + assert device["cpu"]["name"] == "Intel Core i7" + def test_heartbeat_preserves_last_user_from_details(prepared_app, monkeypatch): client = prepared_app.test_client() @@ -330,3 +354,32 @@ def test_heartbeat_preserves_last_user_from_details(prepared_app, monkeypatch): assert row is not None assert row[0] == "BUNNY-LAB\\nicole.rappe" + +def test_heartbeat_uses_username_when_last_user_missing(prepared_app, monkeypatch): + client = prepared_app.test_client() + guid = "802A4E5F-1B2C-4D5E-8F90-A1B2C3D4E5F7" + fingerprint = "55:66:77:88" + hostname = "device-username" + _insert_device(prepared_app, guid, fingerprint, hostname) + + services = prepared_app.extensions["engine_services"] + context = _build_context(guid, fingerprint) + monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context) + + resp = client.post( + "/api/agent/heartbeat", + json={"hostname": hostname, "metrics": {"username": "BUNNY-LAB\\alice.smith"}}, + headers={"Authorization": "Bearer token"}, + ) + assert resp.status_code == 200 + + db_path = Path(prepared_app.config["ENGINE_DATABASE_PATH"]) + with sqlite3.connect(db_path) as conn: + row = conn.execute( + "SELECT last_user FROM devices WHERE guid = ?", + (guid,), + ).fetchone() + + assert row is not None + assert row[0] == "BUNNY-LAB\\alice.smith" +