From d0fa6929b23211d9bb8c2cd13ca8ba0ba6c988fb Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Wed, 22 Oct 2025 23:26:06 -0600 Subject: [PATCH] Implement admin enrollment APIs --- Data/Engine/domain/device_auth.py | 7 + Data/Engine/domain/enrollment_admin.py | 206 +++++++++++ Data/Engine/interfaces/http/admin.py | 105 +++++- .../sqlite/enrollment_repository.py | 347 +++++++++++++++++- Data/Engine/repositories/sqlite/migrations.py | 70 ++++ .../repositories/sqlite/user_repository.py | 51 +++ Data/Engine/services/__init__.py | 5 + Data/Engine/services/container.py | 9 + Data/Engine/services/enrollment/__init__.py | 54 ++- .../services/enrollment/admin_service.py | 113 ++++++ .../tests/test_enrollment_admin_service.py | 122 ++++++ Data/Engine/tests/test_http_admin.py | 111 ++++++ 12 files changed, 1182 insertions(+), 18 deletions(-) create mode 100644 Data/Engine/domain/enrollment_admin.py create mode 100644 Data/Engine/services/enrollment/admin_service.py create mode 100644 Data/Engine/tests/test_enrollment_admin_service.py create mode 100644 Data/Engine/tests/test_http_admin.py diff --git a/Data/Engine/domain/device_auth.py b/Data/Engine/domain/device_auth.py index d377e52..b4d0c52 100644 --- a/Data/Engine/domain/device_auth.py +++ b/Data/Engine/domain/device_auth.py @@ -18,6 +18,7 @@ __all__ = [ "AccessTokenClaims", "DeviceAuthContext", "sanitize_service_context", + "normalize_guid", ] @@ -73,6 +74,12 @@ class DeviceGuid: return self.value +def normalize_guid(value: Optional[str]) -> str: + """Expose GUID normalization for administrative helpers.""" + + return _normalize_guid(value) + + @dataclass(frozen=True, slots=True) class DeviceFingerprint: """Normalized TLS key fingerprint associated with a device.""" diff --git a/Data/Engine/domain/enrollment_admin.py b/Data/Engine/domain/enrollment_admin.py new file mode 100644 index 0000000..8b5f32e --- /dev/null +++ b/Data/Engine/domain/enrollment_admin.py @@ -0,0 +1,206 @@ +"""Administrative enrollment domain models.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Mapping, Optional + +from Data.Engine.domain.device_auth import DeviceGuid, normalize_guid + +__all__ = [ + "EnrollmentCodeRecord", + "DeviceApprovalRecord", + "HostnameConflict", +] + + +def _parse_iso8601(value: Optional[str]) -> Optional[datetime]: + if not value: + return None + raw = str(value).strip() + if not raw: + return None + try: + dt = datetime.fromisoformat(raw) + except Exception as exc: # pragma: no cover - defensive parsing + raise ValueError(f"invalid ISO8601 timestamp: {raw}") from exc + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + return dt.astimezone(timezone.utc) + + +def _isoformat(value: Optional[datetime]) -> Optional[str]: + if value is None: + return None + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc).isoformat() + + +@dataclass(frozen=True, slots=True) +class EnrollmentCodeRecord: + """Installer code metadata exposed to administrative clients.""" + + record_id: str + code: str + expires_at: datetime + max_uses: int + use_count: int + created_by_user_id: Optional[str] + used_at: Optional[datetime] + used_by_guid: Optional[DeviceGuid] + last_used_at: Optional[datetime] + + @classmethod + def from_row(cls, row: Mapping[str, Any]) -> "EnrollmentCodeRecord": + record_id = str(row.get("id") or "").strip() + code = str(row.get("code") or "").strip() + if not record_id or not code: + raise ValueError("invalid enrollment install code record") + + used_by = row.get("used_by_guid") + used_by_guid = DeviceGuid(str(used_by)) if used_by else None + + return cls( + record_id=record_id, + code=code, + expires_at=_parse_iso8601(row.get("expires_at")) or datetime.now(tz=timezone.utc), + max_uses=int(row.get("max_uses") or 1), + use_count=int(row.get("use_count") or 0), + created_by_user_id=str(row.get("created_by_user_id") or "").strip() or None, + used_at=_parse_iso8601(row.get("used_at")), + used_by_guid=used_by_guid, + last_used_at=_parse_iso8601(row.get("last_used_at")), + ) + + def status(self, *, now: Optional[datetime] = None) -> str: + reference = now or datetime.now(tz=timezone.utc) + if self.use_count >= self.max_uses: + return "used" + if self.expires_at <= reference: + return "expired" + return "active" + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.record_id, + "code": self.code, + "expires_at": _isoformat(self.expires_at), + "max_uses": self.max_uses, + "use_count": self.use_count, + "created_by_user_id": self.created_by_user_id, + "used_at": _isoformat(self.used_at), + "used_by_guid": self.used_by_guid.value if self.used_by_guid else None, + "last_used_at": _isoformat(self.last_used_at), + "status": self.status(), + } + + +@dataclass(frozen=True, slots=True) +class HostnameConflict: + """Existing device details colliding with a pending approval.""" + + guid: Optional[str] + ssl_key_fingerprint: Optional[str] + site_id: Optional[int] + site_name: str + fingerprint_match: bool + requires_prompt: bool + + def to_dict(self) -> dict[str, Any]: + return { + "guid": self.guid, + "ssl_key_fingerprint": self.ssl_key_fingerprint, + "site_id": self.site_id, + "site_name": self.site_name, + "fingerprint_match": self.fingerprint_match, + "requires_prompt": self.requires_prompt, + } + + +@dataclass(frozen=True, slots=True) +class DeviceApprovalRecord: + """Administrative projection of a device approval entry.""" + + record_id: str + reference: str + status: str + claimed_hostname: str + claimed_fingerprint: str + created_at: datetime + updated_at: datetime + enrollment_code_id: Optional[str] + guid: Optional[str] + approved_by_user_id: Optional[str] + approved_by_username: Optional[str] + client_nonce: str + server_nonce: str + hostname_conflict: Optional[HostnameConflict] + alternate_hostname: Optional[str] + conflict_requires_prompt: bool + fingerprint_match: bool + + @classmethod + def from_row( + cls, + row: Mapping[str, Any], + *, + conflict: Optional[HostnameConflict] = None, + alternate_hostname: Optional[str] = None, + fingerprint_match: bool = False, + requires_prompt: bool = False, + ) -> "DeviceApprovalRecord": + record_id = str(row.get("id") or "").strip() + reference = str(row.get("approval_reference") or "").strip() + hostname = str(row.get("hostname_claimed") or "").strip() + fingerprint = str(row.get("ssl_key_fingerprint_claimed") or "").strip().lower() + if not record_id or not reference or not hostname or not fingerprint: + raise ValueError("invalid device approval record") + + guid_raw = normalize_guid(row.get("guid")) or None + + return cls( + record_id=record_id, + reference=reference, + status=str(row.get("status") or "pending").strip().lower(), + claimed_hostname=hostname, + claimed_fingerprint=fingerprint, + created_at=_parse_iso8601(row.get("created_at")) or datetime.now(tz=timezone.utc), + updated_at=_parse_iso8601(row.get("updated_at")) or datetime.now(tz=timezone.utc), + enrollment_code_id=str(row.get("enrollment_code_id") or "").strip() or None, + guid=guid_raw, + approved_by_user_id=str(row.get("approved_by_user_id") or "").strip() or None, + approved_by_username=str(row.get("approved_by_username") or "").strip() or None, + client_nonce=str(row.get("client_nonce") or "").strip(), + server_nonce=str(row.get("server_nonce") or "").strip(), + hostname_conflict=conflict, + alternate_hostname=alternate_hostname, + conflict_requires_prompt=requires_prompt, + fingerprint_match=fingerprint_match, + ) + + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "id": self.record_id, + "approval_reference": self.reference, + "status": self.status, + "hostname_claimed": self.claimed_hostname, + "ssl_key_fingerprint_claimed": self.claimed_fingerprint, + "created_at": _isoformat(self.created_at), + "updated_at": _isoformat(self.updated_at), + "enrollment_code_id": self.enrollment_code_id, + "guid": self.guid, + "approved_by_user_id": self.approved_by_user_id, + "approved_by_username": self.approved_by_username, + "client_nonce": self.client_nonce, + "server_nonce": self.server_nonce, + "conflict_requires_prompt": self.conflict_requires_prompt, + "fingerprint_match": self.fingerprint_match, + } + if self.hostname_conflict is not None: + payload["hostname_conflict"] = self.hostname_conflict.to_dict() + if self.alternate_hostname: + payload["alternate_hostname"] = self.alternate_hostname + return payload + diff --git a/Data/Engine/interfaces/http/admin.py b/Data/Engine/interfaces/http/admin.py index 2da2ec2..30d7fd9 100644 --- a/Data/Engine/interfaces/http/admin.py +++ b/Data/Engine/interfaces/http/admin.py @@ -1,8 +1,8 @@ -"""Administrative HTTP interface placeholders for the Engine.""" +"""Administrative HTTP endpoints for the Borealis Engine.""" from __future__ import annotations -from flask import Blueprint, Flask +from flask import Blueprint, Flask, current_app, jsonify, request, session from Data.Engine.services.container import EngineServiceContainer @@ -11,13 +11,106 @@ blueprint = Blueprint("engine_admin", __name__, url_prefix="/api/admin") def register(app: Flask, _services: EngineServiceContainer) -> None: - """Attach administrative routes to *app*. - - Concrete endpoints will be migrated in subsequent phases. - """ + """Attach administrative routes to *app*.""" if "engine_admin" not in app.blueprints: app.register_blueprint(blueprint) +def _services() -> EngineServiceContainer: + services = current_app.extensions.get("engine_services") + if services is None: # pragma: no cover - defensive + raise RuntimeError("engine services not initialized") + return services + + +def _admin_service(): + return _services().enrollment_admin_service + + +def _require_admin(): + username = session.get("username") + role = (session.get("role") or "").strip().lower() + if not isinstance(username, str) or not username: + return jsonify({"error": "not_authenticated"}), 401 + if role != "admin": + return jsonify({"error": "forbidden"}), 403 + return None + + +@blueprint.route("/enrollment-codes", methods=["GET"]) +def list_enrollment_codes() -> object: + guard = _require_admin() + if guard: + return guard + + status = request.args.get("status") + records = _admin_service().list_install_codes(status=status) + return jsonify({"codes": [record.to_dict() for record in records]}) + + +@blueprint.route("/enrollment-codes", methods=["POST"]) +def create_enrollment_code() -> object: + guard = _require_admin() + if guard: + return guard + + payload = request.get_json(silent=True) or {} + + ttl_value = payload.get("ttl_hours") + if ttl_value is None: + ttl_value = payload.get("ttl") or 1 + try: + ttl_hours = int(ttl_value) + except (TypeError, ValueError): + ttl_hours = 1 + + max_uses_value = payload.get("max_uses") + if max_uses_value is None: + max_uses_value = payload.get("allowed_uses", 2) + try: + max_uses = int(max_uses_value) + except (TypeError, ValueError): + max_uses = 2 + + creator = session.get("username") if isinstance(session.get("username"), str) else None + + try: + record = _admin_service().create_install_code( + ttl_hours=ttl_hours, + max_uses=max_uses, + created_by=creator, + ) + except ValueError as exc: + if str(exc) == "invalid_ttl": + return jsonify({"error": "invalid_ttl"}), 400 + raise + + response = jsonify(record.to_dict()) + response.status_code = 201 + return response + + +@blueprint.route("/enrollment-codes/", methods=["DELETE"]) +def delete_enrollment_code(code_id: str) -> object: + guard = _require_admin() + if guard: + return guard + + if not _admin_service().delete_install_code(code_id): + return jsonify({"error": "not_found"}), 404 + return jsonify({"status": "deleted"}) + + +@blueprint.route("/device-approvals", methods=["GET"]) +def list_device_approvals() -> object: + guard = _require_admin() + if guard: + return guard + + status = request.args.get("status") + records = _admin_service().list_device_approvals(status=status) + return jsonify({"approvals": [record.to_dict() for record in records]}) + + __all__ = ["register", "blueprint"] diff --git a/Data/Engine/repositories/sqlite/enrollment_repository.py b/Data/Engine/repositories/sqlite/enrollment_repository.py index a6549ec..5733af9 100644 --- a/Data/Engine/repositories/sqlite/enrollment_repository.py +++ b/Data/Engine/repositories/sqlite/enrollment_repository.py @@ -5,14 +5,19 @@ from __future__ import annotations import logging from contextlib import closing from datetime import datetime, timezone -from typing import Optional +from typing import Any, List, Optional, Tuple -from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid +from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid, normalize_guid from Data.Engine.domain.device_enrollment import ( EnrollmentApproval, EnrollmentApprovalStatus, EnrollmentCode, ) +from Data.Engine.domain.enrollment_admin import ( + DeviceApprovalRecord, + EnrollmentCodeRecord, + HostnameConflict, +) from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory __all__ = ["SQLiteEnrollmentRepository"] @@ -122,6 +127,158 @@ class SQLiteEnrollmentRepository: self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc) return None + def list_install_codes( + self, + *, + status: Optional[str] = None, + now: Optional[datetime] = None, + ) -> List[EnrollmentCodeRecord]: + reference = now or datetime.now(tz=timezone.utc) + status_filter = (status or "").strip().lower() + params: List[str] = [] + + sql = """ + SELECT id, + code, + expires_at, + created_by_user_id, + used_at, + used_by_guid, + max_uses, + use_count, + last_used_at + FROM enrollment_install_codes + """ + + if status_filter in {"active", "expired", "used"}: + sql += " WHERE " + if status_filter == "active": + sql += "use_count < max_uses AND expires_at > ?" + params.append(self._isoformat(reference)) + elif status_filter == "expired": + sql += "use_count < max_uses AND expires_at <= ?" + params.append(self._isoformat(reference)) + else: # used + sql += "use_count >= max_uses" + + sql += " ORDER BY expires_at ASC" + + rows: List[EnrollmentCodeRecord] = [] + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute(sql, params) + for raw in cur.fetchall(): + record = { + "id": raw[0], + "code": raw[1], + "expires_at": raw[2], + "created_by_user_id": raw[3], + "used_at": raw[4], + "used_by_guid": raw[5], + "max_uses": raw[6], + "use_count": raw[7], + "last_used_at": raw[8], + } + try: + rows.append(EnrollmentCodeRecord.from_row(record)) + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning("invalid enrollment install code row id=%s: %s", record.get("id"), exc) + return rows + + def get_install_code_record(self, record_id: str) -> Optional[EnrollmentCodeRecord]: + identifier = (record_id or "").strip() + if not identifier: + return None + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT id, + code, + expires_at, + created_by_user_id, + used_at, + used_by_guid, + max_uses, + use_count, + last_used_at + FROM enrollment_install_codes + WHERE id = ? + """, + (identifier,), + ) + row = cur.fetchone() + + if not row: + return None + + payload = { + "id": row[0], + "code": row[1], + "expires_at": row[2], + "created_by_user_id": row[3], + "used_at": row[4], + "used_by_guid": row[5], + "max_uses": row[6], + "use_count": row[7], + "last_used_at": row[8], + } + + try: + return EnrollmentCodeRecord.from_row(payload) + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning("invalid enrollment install code record id=%s: %s", identifier, exc) + return None + + def insert_install_code( + self, + *, + record_id: str, + code: str, + expires_at: datetime, + created_by: Optional[str], + max_uses: int, + ) -> EnrollmentCodeRecord: + expires_iso = self._isoformat(expires_at) + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT INTO enrollment_install_codes ( + id, + code, + expires_at, + created_by_user_id, + max_uses, + use_count + ) VALUES (?, ?, ?, ?, ?, 0) + """, + (record_id, code, expires_iso, created_by, max_uses), + ) + conn.commit() + + record = self.get_install_code_record(record_id) + if record is None: + raise RuntimeError("failed to load install code after insert") + return record + + def delete_install_code_if_unused(self, record_id: str) -> bool: + identifier = (record_id or "").strip() + if not identifier: + return False + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + "DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0", + (identifier,), + ) + deleted = cur.rowcount > 0 + conn.commit() + return deleted + def update_install_code_usage( self, record_id: str, @@ -165,6 +322,100 @@ class SQLiteEnrollmentRepository: # ------------------------------------------------------------------ # Device approvals # ------------------------------------------------------------------ + def list_device_approvals( + self, + *, + status: Optional[str] = None, + ) -> List[DeviceApprovalRecord]: + status_filter = (status or "").strip().lower() + params: List[str] = [] + + sql = """ + SELECT + da.id, + da.approval_reference, + da.guid, + da.hostname_claimed, + da.ssl_key_fingerprint_claimed, + da.enrollment_code_id, + da.status, + da.client_nonce, + da.server_nonce, + da.created_at, + da.updated_at, + da.approved_by_user_id, + u.username AS approved_by_username + FROM device_approvals AS da + LEFT JOIN users AS u + ON ( + CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT) + OR LOWER(da.approved_by_user_id) = LOWER(u.username) + ) + """ + + if status_filter and status_filter not in {"all", "*"}: + sql += " WHERE LOWER(da.status) = ?" + params.append(status_filter) + + sql += " ORDER BY da.created_at ASC" + + approvals: List[DeviceApprovalRecord] = [] + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute(sql, params) + rows = cur.fetchall() + + for raw in rows: + record = { + "id": raw[0], + "approval_reference": raw[1], + "guid": raw[2], + "hostname_claimed": raw[3], + "ssl_key_fingerprint_claimed": raw[4], + "enrollment_code_id": raw[5], + "status": raw[6], + "client_nonce": raw[7], + "server_nonce": raw[8], + "created_at": raw[9], + "updated_at": raw[10], + "approved_by_user_id": raw[11], + "approved_by_username": raw[12], + } + + conflict, fingerprint_match, requires_prompt = self._compute_hostname_conflict( + conn, + record.get("hostname_claimed"), + record.get("guid"), + record.get("ssl_key_fingerprint_claimed") or "", + ) + + alternate = None + if conflict and requires_prompt: + alternate = self._suggest_alternate_hostname( + conn, + record.get("hostname_claimed"), + record.get("guid"), + ) + + try: + approvals.append( + DeviceApprovalRecord.from_row( + record, + conflict=conflict, + alternate_hostname=alternate, + fingerprint_match=fingerprint_match, + requires_prompt=requires_prompt, + ) + ) + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning( + "invalid device approval record id=%s: %s", + record.get("id"), + exc, + ) + + return approvals + def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: """Load a device approval using its operator-visible reference.""" @@ -376,6 +627,98 @@ class SQLiteEnrollmentRepository: ) return None + def _compute_hostname_conflict( + self, + conn, + hostname: Optional[str], + pending_guid: Optional[str], + claimed_fp: str, + ) -> Tuple[Optional[HostnameConflict], bool, bool]: + normalized_host = (hostname or "").strip() + if not normalized_host: + return None, False, False + + try: + cur = conn.cursor() + cur.execute( + """ + SELECT d.guid, + d.ssl_key_fingerprint, + ds.site_id, + s.name + FROM devices AS d + LEFT JOIN device_sites AS ds ON ds.device_hostname = d.hostname + LEFT JOIN sites AS s ON s.id = ds.site_id + WHERE d.hostname = ? + """, + (normalized_host,), + ) + row = cur.fetchone() + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning("failed to inspect hostname conflict for %s: %s", normalized_host, exc) + return None, False, False + + if not row: + return None, False, False + + existing_guid = normalize_guid(row[0]) + pending_norm = normalize_guid(pending_guid) + if existing_guid and pending_norm and existing_guid == pending_norm: + return None, False, False + + stored_fp = (row[1] or "").strip().lower() + claimed_fp_normalized = (claimed_fp or "").strip().lower() + fingerprint_match = bool(stored_fp and claimed_fp_normalized and stored_fp == claimed_fp_normalized) + + site_id = None + if row[2] is not None: + try: + site_id = int(row[2]) + except (TypeError, ValueError): # pragma: no cover - defensive + site_id = None + + site_name = str(row[3] or "").strip() + requires_prompt = not fingerprint_match + + conflict = HostnameConflict( + guid=existing_guid or None, + ssl_key_fingerprint=stored_fp or None, + site_id=site_id, + site_name=site_name, + fingerprint_match=fingerprint_match, + requires_prompt=requires_prompt, + ) + + return conflict, fingerprint_match, requires_prompt + + def _suggest_alternate_hostname( + self, + conn, + hostname: Optional[str], + pending_guid: Optional[str], + ) -> Optional[str]: + base = (hostname or "").strip() + if not base: + return None + base = base[:253] + candidate = base + pending_norm = normalize_guid(pending_guid) + suffix = 1 + + cur = conn.cursor() + while True: + cur.execute("SELECT guid FROM devices WHERE hostname = ?", (candidate,)) + row = cur.fetchone() + if not row: + return candidate + existing_guid = normalize_guid(row[0]) + if pending_norm and existing_guid == pending_norm: + return candidate + candidate = f"{base}-{suffix}" + suffix += 1 + if suffix > 50: + return pending_norm or candidate + @staticmethod def _isoformat(value: datetime) -> str: if value.tzinfo is None: diff --git a/Data/Engine/repositories/sqlite/migrations.py b/Data/Engine/repositories/sqlite/migrations.py index 34d3c77..535c78c 100644 --- a/Data/Engine/repositories/sqlite/migrations.py +++ b/Data/Engine/repositories/sqlite/migrations.py @@ -31,6 +31,9 @@ def apply_all(conn: sqlite3.Connection) -> None: _ensure_refresh_token_table(conn) _ensure_install_code_table(conn) _ensure_device_approval_table(conn) + _ensure_device_list_views_table(conn) + _ensure_sites_tables(conn) + _ensure_credentials_table(conn) _ensure_github_token_table(conn) _ensure_scheduled_jobs_table(conn) _ensure_scheduled_job_run_tables(conn) @@ -233,6 +236,73 @@ def _ensure_device_approval_table(conn: sqlite3.Connection) -> None: ) +def _ensure_device_list_views_table(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS device_list_views ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + columns_json TEXT NOT NULL, + filters_json TEXT, + created_at INTEGER, + updated_at INTEGER + ) + """ + ) + + +def _ensure_sites_tables(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS sites ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + description TEXT, + created_at INTEGER + ) + """ + ) + cur.execute( + """ + CREATE TABLE IF NOT EXISTS device_sites ( + device_hostname TEXT UNIQUE NOT NULL, + site_id INTEGER NOT NULL, + assigned_at INTEGER, + FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE CASCADE + ) + """ + ) + + +def _ensure_credentials_table(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS credentials ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + description TEXT, + site_id INTEGER, + credential_type TEXT NOT NULL DEFAULT 'machine', + connection_type TEXT NOT NULL DEFAULT 'ssh', + username TEXT, + password_encrypted BLOB, + private_key_encrypted BLOB, + private_key_passphrase_encrypted BLOB, + become_method TEXT, + become_username TEXT, + become_password_encrypted BLOB, + metadata_json TEXT, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE SET NULL + ) + """ + ) + + def _ensure_github_token_table(conn: sqlite3.Connection) -> None: cur = conn.cursor() cur.execute( diff --git a/Data/Engine/repositories/sqlite/user_repository.py b/Data/Engine/repositories/sqlite/user_repository.py index 9c61a4d..9c3002d 100644 --- a/Data/Engine/repositories/sqlite/user_repository.py +++ b/Data/Engine/repositories/sqlite/user_repository.py @@ -71,6 +71,57 @@ class SQLiteUserRepository: finally: conn.close() + def resolve_identifier(self, username: str) -> Optional[str]: + normalized = (username or "").strip() + if not normalized: + return None + + conn = self._connection_factory() + try: + cur = conn.cursor() + cur.execute( + "SELECT id FROM users WHERE LOWER(username) = LOWER(?)", + (normalized,), + ) + row = cur.fetchone() + if not row: + return None + return str(row[0]) if row[0] is not None else None + except sqlite3.Error as exc: # pragma: no cover - defensive + self._log.error("failed to resolve identifier for %s: %s", username, exc) + return None + finally: + conn.close() + + def username_for_identifier(self, identifier: str) -> Optional[str]: + token = (identifier or "").strip() + if not token: + return None + + conn = self._connection_factory() + try: + cur = conn.cursor() + cur.execute( + """ + SELECT username + FROM users + WHERE CAST(id AS TEXT) = ? + OR LOWER(username) = LOWER(?) + LIMIT 1 + """, + (token, token), + ) + row = cur.fetchone() + if not row: + return None + username = str(row[0] or "").strip() + return username or None + except sqlite3.Error as exc: # pragma: no cover - defensive + self._log.error("failed to resolve username for %s: %s", identifier, exc) + return None + finally: + conn.close() + def list_accounts(self) -> list[OperatorAccount]: conn = self._connection_factory() try: diff --git a/Data/Engine/services/__init__.py b/Data/Engine/services/__init__.py index 3e216c7..9c59917 100644 --- a/Data/Engine/services/__init__.py +++ b/Data/Engine/services/__init__.py @@ -23,6 +23,7 @@ __all__ = [ "SchedulerService", "GitHubService", "GitHubTokenPayload", + "EnrollmentAdminService", ] _LAZY_TARGETS: Dict[str, Tuple[str, str]] = { @@ -43,6 +44,10 @@ _LAZY_TARGETS: Dict[str, Tuple[str, str]] = { "SchedulerService": ("Data.Engine.services.jobs.scheduler_service", "SchedulerService"), "GitHubService": ("Data.Engine.services.github.github_service", "GitHubService"), "GitHubTokenPayload": ("Data.Engine.services.github.github_service", "GitHubTokenPayload"), + "EnrollmentAdminService": ( + "Data.Engine.services.enrollment.admin_service", + "EnrollmentAdminService", + ), } diff --git a/Data/Engine/services/container.py b/Data/Engine/services/container.py index 621e02a..bbb731b 100644 --- a/Data/Engine/services/container.py +++ b/Data/Engine/services/container.py @@ -30,6 +30,7 @@ from Data.Engine.services.auth import ( ) from Data.Engine.services.crypto.signing import ScriptSigner, load_signer from Data.Engine.services.enrollment import EnrollmentService +from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService from Data.Engine.services.enrollment.nonce_cache import NonceCache from Data.Engine.services.github import GitHubService from Data.Engine.services.jobs import SchedulerService @@ -44,6 +45,7 @@ class EngineServiceContainer: device_auth: DeviceAuthService token_service: TokenService enrollment_service: EnrollmentService + enrollment_admin_service: EnrollmentAdminService jwt_service: JWTService dpop_validator: DPoPValidator agent_realtime: AgentRealtimeService @@ -93,6 +95,12 @@ def build_service_container( logger=log.getChild("enrollment"), ) + enrollment_admin_service = EnrollmentAdminService( + repository=enrollment_repo, + user_repository=user_repo, + logger=log.getChild("enrollment_admin"), + ) + device_auth = DeviceAuthService( device_repository=device_repo, jwt_service=jwt_service, @@ -139,6 +147,7 @@ def build_service_container( device_auth=device_auth, token_service=token_service, enrollment_service=enrollment_service, + enrollment_admin_service=enrollment_admin_service, jwt_service=jwt_service, dpop_validator=dpop_validator, agent_realtime=agent_realtime, diff --git a/Data/Engine/services/enrollment/__init__.py b/Data/Engine/services/enrollment/__init__.py index 063cd7b..7277d59 100644 --- a/Data/Engine/services/enrollment/__init__.py +++ b/Data/Engine/services/enrollment/__init__.py @@ -2,20 +2,54 @@ from __future__ import annotations -from .enrollment_service import ( - EnrollmentRequestResult, - EnrollmentService, - EnrollmentStatus, - EnrollmentTokenBundle, - PollingResult, -) -from Data.Engine.domain.device_enrollment import EnrollmentValidationError +from importlib import import_module +from typing import Any __all__ = [ - "EnrollmentRequestResult", "EnrollmentService", + "EnrollmentRequestResult", "EnrollmentStatus", "EnrollmentTokenBundle", - "EnrollmentValidationError", "PollingResult", + "EnrollmentValidationError", + "EnrollmentAdminService", ] + +_LAZY: dict[str, tuple[str, str]] = { + "EnrollmentService": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentService"), + "EnrollmentRequestResult": ( + "Data.Engine.services.enrollment.enrollment_service", + "EnrollmentRequestResult", + ), + "EnrollmentStatus": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentStatus"), + "EnrollmentTokenBundle": ( + "Data.Engine.services.enrollment.enrollment_service", + "EnrollmentTokenBundle", + ), + "PollingResult": ("Data.Engine.services.enrollment.enrollment_service", "PollingResult"), + "EnrollmentValidationError": ( + "Data.Engine.domain.device_enrollment", + "EnrollmentValidationError", + ), + "EnrollmentAdminService": ( + "Data.Engine.services.enrollment.admin_service", + "EnrollmentAdminService", + ), +} + + +def __getattr__(name: str) -> Any: + try: + module_name, attribute = _LAZY[name] + except KeyError as exc: # pragma: no cover - defensive + raise AttributeError(name) from exc + + module = import_module(module_name) + value = getattr(module, attribute) + globals()[name] = value + return value + + +def __dir__() -> list[str]: # pragma: no cover - interactive helper + return sorted(set(__all__)) + diff --git a/Data/Engine/services/enrollment/admin_service.py b/Data/Engine/services/enrollment/admin_service.py new file mode 100644 index 0000000..de8193f --- /dev/null +++ b/Data/Engine/services/enrollment/admin_service.py @@ -0,0 +1,113 @@ +"""Administrative helpers for enrollment workflows.""" + +from __future__ import annotations + +import logging +import secrets +import uuid +from datetime import datetime, timedelta, timezone +from typing import Callable, List, Optional + +from Data.Engine.domain.enrollment_admin import DeviceApprovalRecord, EnrollmentCodeRecord +from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository +from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository + +__all__ = ["EnrollmentAdminService"] + + +class EnrollmentAdminService: + """Expose administrative enrollment operations.""" + + _VALID_TTL_HOURS = {1, 3, 6, 12, 24} + + def __init__( + self, + *, + repository: SQLiteEnrollmentRepository, + user_repository: SQLiteUserRepository, + logger: Optional[logging.Logger] = None, + clock: Optional[Callable[[], datetime]] = None, + ) -> None: + self._repository = repository + self._users = user_repository + self._log = logger or logging.getLogger("borealis.engine.services.enrollment_admin") + self._clock = clock or (lambda: datetime.now(tz=timezone.utc)) + + # ------------------------------------------------------------------ + # Enrollment install codes + # ------------------------------------------------------------------ + def list_install_codes(self, *, status: Optional[str] = None) -> List[EnrollmentCodeRecord]: + return self._repository.list_install_codes(status=status, now=self._clock()) + + def create_install_code( + self, + *, + ttl_hours: int, + max_uses: int, + created_by: Optional[str], + ) -> EnrollmentCodeRecord: + if ttl_hours not in self._VALID_TTL_HOURS: + raise ValueError("invalid_ttl") + + normalized_max = self._normalize_max_uses(max_uses) + + now = self._clock() + expires_at = now + timedelta(hours=ttl_hours) + record_id = str(uuid.uuid4()) + code = self._generate_install_code() + + created_by_identifier = None + if created_by: + created_by_identifier = self._users.resolve_identifier(created_by) + if not created_by_identifier: + created_by_identifier = created_by.strip() or None + + record = self._repository.insert_install_code( + record_id=record_id, + code=code, + expires_at=expires_at, + created_by=created_by_identifier, + max_uses=normalized_max, + ) + + self._log.info( + "install code created id=%s ttl=%sh max_uses=%s", + record.record_id, + ttl_hours, + normalized_max, + ) + + return record + + def delete_install_code(self, record_id: str) -> bool: + deleted = self._repository.delete_install_code_if_unused(record_id) + if deleted: + self._log.info("install code deleted id=%s", record_id) + return deleted + + # ------------------------------------------------------------------ + # Device approvals + # ------------------------------------------------------------------ + def list_device_approvals(self, *, status: Optional[str] = None) -> List[DeviceApprovalRecord]: + return self._repository.list_device_approvals(status=status) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + @staticmethod + def _generate_install_code() -> str: + raw = secrets.token_hex(16).upper() + return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4)) + + @staticmethod + def _normalize_max_uses(value: int) -> int: + try: + count = int(value) + except Exception: + count = 2 + if count < 1: + return 1 + if count > 10: + return 10 + return count + diff --git a/Data/Engine/tests/test_enrollment_admin_service.py b/Data/Engine/tests/test_enrollment_admin_service.py new file mode 100644 index 0000000..9fb3f64 --- /dev/null +++ b/Data/Engine/tests/test_enrollment_admin_service.py @@ -0,0 +1,122 @@ +import base64 +import sqlite3 +from datetime import datetime, timezone + +import pytest + +from Data.Engine.repositories.sqlite import connection as sqlite_connection +from Data.Engine.repositories.sqlite import migrations as sqlite_migrations +from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository +from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository +from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService + + +def _build_service(tmp_path): + db_path = tmp_path / "admin.db" + conn = sqlite3.connect(db_path) + sqlite_migrations.apply_all(conn) + conn.close() + + factory = sqlite_connection.connection_factory(db_path) + enrollment_repo = SQLiteEnrollmentRepository(factory) + user_repo = SQLiteUserRepository(factory) + + fixed_now = datetime(2024, 1, 1, tzinfo=timezone.utc) + service = EnrollmentAdminService( + repository=enrollment_repo, + user_repository=user_repo, + clock=lambda: fixed_now, + ) + return service, factory, fixed_now + + +def test_create_and_list_install_codes(tmp_path): + service, factory, fixed_now = _build_service(tmp_path) + + record = service.create_install_code(ttl_hours=3, max_uses=5, created_by="admin") + assert record.code + assert record.max_uses == 5 + assert record.status(now=fixed_now) == "active" + + records = service.list_install_codes() + assert any(r.record_id == record.record_id for r in records) + + # Invalid TTL should raise + with pytest.raises(ValueError): + service.create_install_code(ttl_hours=2, max_uses=1, created_by=None) + + # Deleting should succeed and remove the record + assert service.delete_install_code(record.record_id) is True + remaining = service.list_install_codes() + assert all(r.record_id != record.record_id for r in remaining) + + +def test_list_device_approvals_includes_conflict(tmp_path): + service, factory, fixed_now = _build_service(tmp_path) + + conn = factory() + cur = conn.cursor() + + cur.execute( + "INSERT INTO sites (name, description, created_at) VALUES (?, ?, ?)", + ("HQ", "Primary site", int(fixed_now.timestamp())), + ) + site_id = cur.lastrowid + + cur.execute( + """ + INSERT INTO devices (guid, hostname, created_at, last_seen, ssl_key_fingerprint, status) + VALUES (?, ?, ?, ?, ?, 'active') + """, + ("11111111-1111-1111-1111-111111111111", "agent-one", int(fixed_now.timestamp()), int(fixed_now.timestamp()), "abc123",), + ) + cur.execute( + "INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)", + ("agent-one", site_id, int(fixed_now.timestamp())), + ) + + now_iso = fixed_now.isoformat() + cur.execute( + """ + INSERT INTO device_approvals ( + id, + approval_reference, + guid, + hostname_claimed, + ssl_key_fingerprint_claimed, + enrollment_code_id, + status, + client_nonce, + server_nonce, + created_at, + updated_at, + approved_by_user_id, + agent_pubkey_der + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + "approval-1", + "REF123", + None, + "agent-one", + "abc123", + "code-1", + "pending", + base64.b64encode(b"client").decode(), + base64.b64encode(b"server").decode(), + now_iso, + now_iso, + None, + b"pubkey", + ), + ) + conn.commit() + conn.close() + + approvals = service.list_device_approvals() + assert len(approvals) == 1 + record = approvals[0] + assert record.hostname_conflict is not None + assert record.hostname_conflict.fingerprint_match is True + assert record.conflict_requires_prompt is False + diff --git a/Data/Engine/tests/test_http_admin.py b/Data/Engine/tests/test_http_admin.py new file mode 100644 index 0000000..f3e0cc4 --- /dev/null +++ b/Data/Engine/tests/test_http_admin.py @@ -0,0 +1,111 @@ +import base64 +import sqlite3 +from datetime import datetime, timezone + +from .test_http_auth import _login, prepared_app + + +def test_enrollment_codes_require_authentication(prepared_app): + client = prepared_app.test_client() + resp = client.get("/api/admin/enrollment-codes") + assert resp.status_code == 401 + + +def test_enrollment_code_workflow(prepared_app): + client = prepared_app.test_client() + _login(client) + + payload = {"ttl_hours": 3, "max_uses": 4} + resp = client.post("/api/admin/enrollment-codes", json=payload) + assert resp.status_code == 201 + created = resp.get_json() + assert created["max_uses"] == 4 + assert created["status"] == "active" + + resp = client.get("/api/admin/enrollment-codes") + assert resp.status_code == 200 + codes = resp.get_json().get("codes", []) + assert any(code["id"] == created["id"] for code in codes) + + resp = client.delete(f"/api/admin/enrollment-codes/{created['id']}") + assert resp.status_code == 200 + + +def test_device_approvals_listing(prepared_app, engine_settings): + client = prepared_app.test_client() + _login(client) + + conn = sqlite3.connect(engine_settings.database.path) + cur = conn.cursor() + + now = datetime.now(tz=timezone.utc) + cur.execute( + "INSERT INTO sites (name, description, created_at) VALUES (?, ?, ?)", + ("HQ", "Primary", int(now.timestamp())), + ) + site_id = cur.lastrowid + + cur.execute( + """ + INSERT INTO devices (guid, hostname, created_at, last_seen, ssl_key_fingerprint, status) + VALUES (?, ?, ?, ?, ?, 'active') + """, + ( + "22222222-2222-2222-2222-222222222222", + "approval-host", + int(now.timestamp()), + int(now.timestamp()), + "deadbeef", + ), + ) + cur.execute( + "INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)", + ("approval-host", site_id, int(now.timestamp())), + ) + + now_iso = now.isoformat() + cur.execute( + """ + INSERT INTO device_approvals ( + id, + approval_reference, + guid, + hostname_claimed, + ssl_key_fingerprint_claimed, + enrollment_code_id, + status, + client_nonce, + server_nonce, + created_at, + updated_at, + approved_by_user_id, + agent_pubkey_der + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + "approval-http", + "REFHTTP", + None, + "approval-host", + "deadbeef", + "code-http", + "pending", + base64.b64encode(b"client").decode(), + base64.b64encode(b"server").decode(), + now_iso, + now_iso, + None, + b"pub", + ), + ) + conn.commit() + conn.close() + + resp = client.get("/api/admin/device-approvals") + assert resp.status_code == 200 + body = resp.get_json() + approvals = body.get("approvals", []) + assert any(a["id"] == "approval-http" for a in approvals) + record = next(a for a in approvals if a["id"] == "approval-http") + assert record.get("hostname_conflict", {}).get("fingerprint_match") is True +