mirror of
				https://github.com/bunny-lab-io/Borealis.git
				synced 2025-10-26 15:41:58 -06:00 
			
		
		
		
	Implement admin enrollment APIs
This commit is contained in:
		| @@ -18,6 +18,7 @@ __all__ = [ | ||||
|     "AccessTokenClaims", | ||||
|     "DeviceAuthContext", | ||||
|     "sanitize_service_context", | ||||
|     "normalize_guid", | ||||
| ] | ||||
|  | ||||
|  | ||||
| @@ -73,6 +74,12 @@ class DeviceGuid: | ||||
|         return self.value | ||||
|  | ||||
|  | ||||
| def normalize_guid(value: Optional[str]) -> str: | ||||
|     """Expose GUID normalization for administrative helpers.""" | ||||
|  | ||||
|     return _normalize_guid(value) | ||||
|  | ||||
|  | ||||
| @dataclass(frozen=True, slots=True) | ||||
| class DeviceFingerprint: | ||||
|     """Normalized TLS key fingerprint associated with a device.""" | ||||
|   | ||||
							
								
								
									
										206
									
								
								Data/Engine/domain/enrollment_admin.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										206
									
								
								Data/Engine/domain/enrollment_admin.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,206 @@ | ||||
| """Administrative enrollment domain models.""" | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| from dataclasses import dataclass | ||||
| from datetime import datetime, timezone | ||||
| from typing import Any, Mapping, Optional | ||||
|  | ||||
| from Data.Engine.domain.device_auth import DeviceGuid, normalize_guid | ||||
|  | ||||
| __all__ = [ | ||||
|     "EnrollmentCodeRecord", | ||||
|     "DeviceApprovalRecord", | ||||
|     "HostnameConflict", | ||||
| ] | ||||
|  | ||||
|  | ||||
| def _parse_iso8601(value: Optional[str]) -> Optional[datetime]: | ||||
|     if not value: | ||||
|         return None | ||||
|     raw = str(value).strip() | ||||
|     if not raw: | ||||
|         return None | ||||
|     try: | ||||
|         dt = datetime.fromisoformat(raw) | ||||
|     except Exception as exc:  # pragma: no cover - defensive parsing | ||||
|         raise ValueError(f"invalid ISO8601 timestamp: {raw}") from exc | ||||
|     if dt.tzinfo is None: | ||||
|         return dt.replace(tzinfo=timezone.utc) | ||||
|     return dt.astimezone(timezone.utc) | ||||
|  | ||||
|  | ||||
| def _isoformat(value: Optional[datetime]) -> Optional[str]: | ||||
|     if value is None: | ||||
|         return None | ||||
|     if value.tzinfo is None: | ||||
|         value = value.replace(tzinfo=timezone.utc) | ||||
|     return value.astimezone(timezone.utc).isoformat() | ||||
|  | ||||
|  | ||||
| @dataclass(frozen=True, slots=True) | ||||
| class EnrollmentCodeRecord: | ||||
|     """Installer code metadata exposed to administrative clients.""" | ||||
|  | ||||
|     record_id: str | ||||
|     code: str | ||||
|     expires_at: datetime | ||||
|     max_uses: int | ||||
|     use_count: int | ||||
|     created_by_user_id: Optional[str] | ||||
|     used_at: Optional[datetime] | ||||
|     used_by_guid: Optional[DeviceGuid] | ||||
|     last_used_at: Optional[datetime] | ||||
|  | ||||
|     @classmethod | ||||
|     def from_row(cls, row: Mapping[str, Any]) -> "EnrollmentCodeRecord": | ||||
|         record_id = str(row.get("id") or "").strip() | ||||
|         code = str(row.get("code") or "").strip() | ||||
|         if not record_id or not code: | ||||
|             raise ValueError("invalid enrollment install code record") | ||||
|  | ||||
|         used_by = row.get("used_by_guid") | ||||
|         used_by_guid = DeviceGuid(str(used_by)) if used_by else None | ||||
|  | ||||
|         return cls( | ||||
|             record_id=record_id, | ||||
|             code=code, | ||||
|             expires_at=_parse_iso8601(row.get("expires_at")) or datetime.now(tz=timezone.utc), | ||||
|             max_uses=int(row.get("max_uses") or 1), | ||||
|             use_count=int(row.get("use_count") or 0), | ||||
|             created_by_user_id=str(row.get("created_by_user_id") or "").strip() or None, | ||||
|             used_at=_parse_iso8601(row.get("used_at")), | ||||
|             used_by_guid=used_by_guid, | ||||
|             last_used_at=_parse_iso8601(row.get("last_used_at")), | ||||
|         ) | ||||
|  | ||||
|     def status(self, *, now: Optional[datetime] = None) -> str: | ||||
|         reference = now or datetime.now(tz=timezone.utc) | ||||
|         if self.use_count >= self.max_uses: | ||||
|             return "used" | ||||
|         if self.expires_at <= reference: | ||||
|             return "expired" | ||||
|         return "active" | ||||
|  | ||||
|     def to_dict(self) -> dict[str, Any]: | ||||
|         return { | ||||
|             "id": self.record_id, | ||||
|             "code": self.code, | ||||
|             "expires_at": _isoformat(self.expires_at), | ||||
|             "max_uses": self.max_uses, | ||||
|             "use_count": self.use_count, | ||||
|             "created_by_user_id": self.created_by_user_id, | ||||
|             "used_at": _isoformat(self.used_at), | ||||
|             "used_by_guid": self.used_by_guid.value if self.used_by_guid else None, | ||||
|             "last_used_at": _isoformat(self.last_used_at), | ||||
|             "status": self.status(), | ||||
|         } | ||||
|  | ||||
|  | ||||
| @dataclass(frozen=True, slots=True) | ||||
| class HostnameConflict: | ||||
|     """Existing device details colliding with a pending approval.""" | ||||
|  | ||||
|     guid: Optional[str] | ||||
|     ssl_key_fingerprint: Optional[str] | ||||
|     site_id: Optional[int] | ||||
|     site_name: str | ||||
|     fingerprint_match: bool | ||||
|     requires_prompt: bool | ||||
|  | ||||
|     def to_dict(self) -> dict[str, Any]: | ||||
|         return { | ||||
|             "guid": self.guid, | ||||
|             "ssl_key_fingerprint": self.ssl_key_fingerprint, | ||||
|             "site_id": self.site_id, | ||||
|             "site_name": self.site_name, | ||||
|             "fingerprint_match": self.fingerprint_match, | ||||
|             "requires_prompt": self.requires_prompt, | ||||
|         } | ||||
|  | ||||
|  | ||||
| @dataclass(frozen=True, slots=True) | ||||
| class DeviceApprovalRecord: | ||||
|     """Administrative projection of a device approval entry.""" | ||||
|  | ||||
|     record_id: str | ||||
|     reference: str | ||||
|     status: str | ||||
|     claimed_hostname: str | ||||
|     claimed_fingerprint: str | ||||
|     created_at: datetime | ||||
|     updated_at: datetime | ||||
|     enrollment_code_id: Optional[str] | ||||
|     guid: Optional[str] | ||||
|     approved_by_user_id: Optional[str] | ||||
|     approved_by_username: Optional[str] | ||||
|     client_nonce: str | ||||
|     server_nonce: str | ||||
|     hostname_conflict: Optional[HostnameConflict] | ||||
|     alternate_hostname: Optional[str] | ||||
|     conflict_requires_prompt: bool | ||||
|     fingerprint_match: bool | ||||
|  | ||||
|     @classmethod | ||||
|     def from_row( | ||||
|         cls, | ||||
|         row: Mapping[str, Any], | ||||
|         *, | ||||
|         conflict: Optional[HostnameConflict] = None, | ||||
|         alternate_hostname: Optional[str] = None, | ||||
|         fingerprint_match: bool = False, | ||||
|         requires_prompt: bool = False, | ||||
|     ) -> "DeviceApprovalRecord": | ||||
|         record_id = str(row.get("id") or "").strip() | ||||
|         reference = str(row.get("approval_reference") or "").strip() | ||||
|         hostname = str(row.get("hostname_claimed") or "").strip() | ||||
|         fingerprint = str(row.get("ssl_key_fingerprint_claimed") or "").strip().lower() | ||||
|         if not record_id or not reference or not hostname or not fingerprint: | ||||
|             raise ValueError("invalid device approval record") | ||||
|  | ||||
|         guid_raw = normalize_guid(row.get("guid")) or None | ||||
|  | ||||
|         return cls( | ||||
|             record_id=record_id, | ||||
|             reference=reference, | ||||
|             status=str(row.get("status") or "pending").strip().lower(), | ||||
|             claimed_hostname=hostname, | ||||
|             claimed_fingerprint=fingerprint, | ||||
|             created_at=_parse_iso8601(row.get("created_at")) or datetime.now(tz=timezone.utc), | ||||
|             updated_at=_parse_iso8601(row.get("updated_at")) or datetime.now(tz=timezone.utc), | ||||
|             enrollment_code_id=str(row.get("enrollment_code_id") or "").strip() or None, | ||||
|             guid=guid_raw, | ||||
|             approved_by_user_id=str(row.get("approved_by_user_id") or "").strip() or None, | ||||
|             approved_by_username=str(row.get("approved_by_username") or "").strip() or None, | ||||
|             client_nonce=str(row.get("client_nonce") or "").strip(), | ||||
|             server_nonce=str(row.get("server_nonce") or "").strip(), | ||||
|             hostname_conflict=conflict, | ||||
|             alternate_hostname=alternate_hostname, | ||||
|             conflict_requires_prompt=requires_prompt, | ||||
|             fingerprint_match=fingerprint_match, | ||||
|         ) | ||||
|  | ||||
|     def to_dict(self) -> dict[str, Any]: | ||||
|         payload: dict[str, Any] = { | ||||
|             "id": self.record_id, | ||||
|             "approval_reference": self.reference, | ||||
|             "status": self.status, | ||||
|             "hostname_claimed": self.claimed_hostname, | ||||
|             "ssl_key_fingerprint_claimed": self.claimed_fingerprint, | ||||
|             "created_at": _isoformat(self.created_at), | ||||
|             "updated_at": _isoformat(self.updated_at), | ||||
|             "enrollment_code_id": self.enrollment_code_id, | ||||
|             "guid": self.guid, | ||||
|             "approved_by_user_id": self.approved_by_user_id, | ||||
|             "approved_by_username": self.approved_by_username, | ||||
|             "client_nonce": self.client_nonce, | ||||
|             "server_nonce": self.server_nonce, | ||||
|             "conflict_requires_prompt": self.conflict_requires_prompt, | ||||
|             "fingerprint_match": self.fingerprint_match, | ||||
|         } | ||||
|         if self.hostname_conflict is not None: | ||||
|             payload["hostname_conflict"] = self.hostname_conflict.to_dict() | ||||
|         if self.alternate_hostname: | ||||
|             payload["alternate_hostname"] = self.alternate_hostname | ||||
|         return payload | ||||
|  | ||||
| @@ -1,8 +1,8 @@ | ||||
| """Administrative HTTP interface placeholders for the Engine.""" | ||||
| """Administrative HTTP endpoints for the Borealis Engine.""" | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| from flask import Blueprint, Flask | ||||
| from flask import Blueprint, Flask, current_app, jsonify, request, session | ||||
|  | ||||
| from Data.Engine.services.container import EngineServiceContainer | ||||
|  | ||||
| @@ -11,13 +11,106 @@ blueprint = Blueprint("engine_admin", __name__, url_prefix="/api/admin") | ||||
|  | ||||
|  | ||||
| def register(app: Flask, _services: EngineServiceContainer) -> None: | ||||
|     """Attach administrative routes to *app*. | ||||
|  | ||||
|     Concrete endpoints will be migrated in subsequent phases. | ||||
|     """ | ||||
|     """Attach administrative routes to *app*.""" | ||||
|  | ||||
|     if "engine_admin" not in app.blueprints: | ||||
|         app.register_blueprint(blueprint) | ||||
|  | ||||
|  | ||||
| def _services() -> EngineServiceContainer: | ||||
|     services = current_app.extensions.get("engine_services") | ||||
|     if services is None:  # pragma: no cover - defensive | ||||
|         raise RuntimeError("engine services not initialized") | ||||
|     return services | ||||
|  | ||||
|  | ||||
| def _admin_service(): | ||||
|     return _services().enrollment_admin_service | ||||
|  | ||||
|  | ||||
| def _require_admin(): | ||||
|     username = session.get("username") | ||||
|     role = (session.get("role") or "").strip().lower() | ||||
|     if not isinstance(username, str) or not username: | ||||
|         return jsonify({"error": "not_authenticated"}), 401 | ||||
|     if role != "admin": | ||||
|         return jsonify({"error": "forbidden"}), 403 | ||||
|     return None | ||||
|  | ||||
|  | ||||
| @blueprint.route("/enrollment-codes", methods=["GET"]) | ||||
| def list_enrollment_codes() -> object: | ||||
|     guard = _require_admin() | ||||
|     if guard: | ||||
|         return guard | ||||
|  | ||||
|     status = request.args.get("status") | ||||
|     records = _admin_service().list_install_codes(status=status) | ||||
|     return jsonify({"codes": [record.to_dict() for record in records]}) | ||||
|  | ||||
|  | ||||
| @blueprint.route("/enrollment-codes", methods=["POST"]) | ||||
| def create_enrollment_code() -> object: | ||||
|     guard = _require_admin() | ||||
|     if guard: | ||||
|         return guard | ||||
|  | ||||
|     payload = request.get_json(silent=True) or {} | ||||
|  | ||||
|     ttl_value = payload.get("ttl_hours") | ||||
|     if ttl_value is None: | ||||
|         ttl_value = payload.get("ttl") or 1 | ||||
|     try: | ||||
|         ttl_hours = int(ttl_value) | ||||
|     except (TypeError, ValueError): | ||||
|         ttl_hours = 1 | ||||
|  | ||||
|     max_uses_value = payload.get("max_uses") | ||||
|     if max_uses_value is None: | ||||
|         max_uses_value = payload.get("allowed_uses", 2) | ||||
|     try: | ||||
|         max_uses = int(max_uses_value) | ||||
|     except (TypeError, ValueError): | ||||
|         max_uses = 2 | ||||
|  | ||||
|     creator = session.get("username") if isinstance(session.get("username"), str) else None | ||||
|  | ||||
|     try: | ||||
|         record = _admin_service().create_install_code( | ||||
|             ttl_hours=ttl_hours, | ||||
|             max_uses=max_uses, | ||||
|             created_by=creator, | ||||
|         ) | ||||
|     except ValueError as exc: | ||||
|         if str(exc) == "invalid_ttl": | ||||
|             return jsonify({"error": "invalid_ttl"}), 400 | ||||
|         raise | ||||
|  | ||||
|     response = jsonify(record.to_dict()) | ||||
|     response.status_code = 201 | ||||
|     return response | ||||
|  | ||||
|  | ||||
| @blueprint.route("/enrollment-codes/<code_id>", methods=["DELETE"]) | ||||
| def delete_enrollment_code(code_id: str) -> object: | ||||
|     guard = _require_admin() | ||||
|     if guard: | ||||
|         return guard | ||||
|  | ||||
|     if not _admin_service().delete_install_code(code_id): | ||||
|         return jsonify({"error": "not_found"}), 404 | ||||
|     return jsonify({"status": "deleted"}) | ||||
|  | ||||
|  | ||||
| @blueprint.route("/device-approvals", methods=["GET"]) | ||||
| def list_device_approvals() -> object: | ||||
|     guard = _require_admin() | ||||
|     if guard: | ||||
|         return guard | ||||
|  | ||||
|     status = request.args.get("status") | ||||
|     records = _admin_service().list_device_approvals(status=status) | ||||
|     return jsonify({"approvals": [record.to_dict() for record in records]}) | ||||
|  | ||||
|  | ||||
| __all__ = ["register", "blueprint"] | ||||
|   | ||||
| @@ -5,14 +5,19 @@ from __future__ import annotations | ||||
| import logging | ||||
| from contextlib import closing | ||||
| from datetime import datetime, timezone | ||||
| from typing import Optional | ||||
| from typing import Any, List, Optional, Tuple | ||||
|  | ||||
| from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid | ||||
| from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid, normalize_guid | ||||
| from Data.Engine.domain.device_enrollment import ( | ||||
|     EnrollmentApproval, | ||||
|     EnrollmentApprovalStatus, | ||||
|     EnrollmentCode, | ||||
| ) | ||||
| from Data.Engine.domain.enrollment_admin import ( | ||||
|     DeviceApprovalRecord, | ||||
|     EnrollmentCodeRecord, | ||||
|     HostnameConflict, | ||||
| ) | ||||
| from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory | ||||
|  | ||||
| __all__ = ["SQLiteEnrollmentRepository"] | ||||
| @@ -122,6 +127,158 @@ class SQLiteEnrollmentRepository: | ||||
|             self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc) | ||||
|             return None | ||||
|  | ||||
|     def list_install_codes( | ||||
|         self, | ||||
|         *, | ||||
|         status: Optional[str] = None, | ||||
|         now: Optional[datetime] = None, | ||||
|     ) -> List[EnrollmentCodeRecord]: | ||||
|         reference = now or datetime.now(tz=timezone.utc) | ||||
|         status_filter = (status or "").strip().lower() | ||||
|         params: List[str] = [] | ||||
|  | ||||
|         sql = """ | ||||
|             SELECT id, | ||||
|                    code, | ||||
|                    expires_at, | ||||
|                    created_by_user_id, | ||||
|                    used_at, | ||||
|                    used_by_guid, | ||||
|                    max_uses, | ||||
|                    use_count, | ||||
|                    last_used_at | ||||
|               FROM enrollment_install_codes | ||||
|         """ | ||||
|  | ||||
|         if status_filter in {"active", "expired", "used"}: | ||||
|             sql += " WHERE " | ||||
|             if status_filter == "active": | ||||
|                 sql += "use_count < max_uses AND expires_at > ?" | ||||
|                 params.append(self._isoformat(reference)) | ||||
|             elif status_filter == "expired": | ||||
|                 sql += "use_count < max_uses AND expires_at <= ?" | ||||
|                 params.append(self._isoformat(reference)) | ||||
|             else:  # used | ||||
|                 sql += "use_count >= max_uses" | ||||
|  | ||||
|         sql += " ORDER BY expires_at ASC" | ||||
|  | ||||
|         rows: List[EnrollmentCodeRecord] = [] | ||||
|         with closing(self._connections()) as conn: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute(sql, params) | ||||
|             for raw in cur.fetchall(): | ||||
|                 record = { | ||||
|                     "id": raw[0], | ||||
|                     "code": raw[1], | ||||
|                     "expires_at": raw[2], | ||||
|                     "created_by_user_id": raw[3], | ||||
|                     "used_at": raw[4], | ||||
|                     "used_by_guid": raw[5], | ||||
|                     "max_uses": raw[6], | ||||
|                     "use_count": raw[7], | ||||
|                     "last_used_at": raw[8], | ||||
|                 } | ||||
|                 try: | ||||
|                     rows.append(EnrollmentCodeRecord.from_row(record)) | ||||
|                 except Exception as exc:  # pragma: no cover - defensive logging | ||||
|                     self._log.warning("invalid enrollment install code row id=%s: %s", record.get("id"), exc) | ||||
|         return rows | ||||
|  | ||||
|     def get_install_code_record(self, record_id: str) -> Optional[EnrollmentCodeRecord]: | ||||
|         identifier = (record_id or "").strip() | ||||
|         if not identifier: | ||||
|             return None | ||||
|  | ||||
|         with closing(self._connections()) as conn: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 """ | ||||
|                 SELECT id, | ||||
|                        code, | ||||
|                        expires_at, | ||||
|                        created_by_user_id, | ||||
|                        used_at, | ||||
|                        used_by_guid, | ||||
|                        max_uses, | ||||
|                        use_count, | ||||
|                        last_used_at | ||||
|                   FROM enrollment_install_codes | ||||
|                  WHERE id = ? | ||||
|                 """, | ||||
|                 (identifier,), | ||||
|             ) | ||||
|             row = cur.fetchone() | ||||
|  | ||||
|         if not row: | ||||
|             return None | ||||
|  | ||||
|         payload = { | ||||
|             "id": row[0], | ||||
|             "code": row[1], | ||||
|             "expires_at": row[2], | ||||
|             "created_by_user_id": row[3], | ||||
|             "used_at": row[4], | ||||
|             "used_by_guid": row[5], | ||||
|             "max_uses": row[6], | ||||
|             "use_count": row[7], | ||||
|             "last_used_at": row[8], | ||||
|         } | ||||
|  | ||||
|         try: | ||||
|             return EnrollmentCodeRecord.from_row(payload) | ||||
|         except Exception as exc:  # pragma: no cover - defensive logging | ||||
|             self._log.warning("invalid enrollment install code record id=%s: %s", identifier, exc) | ||||
|             return None | ||||
|  | ||||
|     def insert_install_code( | ||||
|         self, | ||||
|         *, | ||||
|         record_id: str, | ||||
|         code: str, | ||||
|         expires_at: datetime, | ||||
|         created_by: Optional[str], | ||||
|         max_uses: int, | ||||
|     ) -> EnrollmentCodeRecord: | ||||
|         expires_iso = self._isoformat(expires_at) | ||||
|  | ||||
|         with closing(self._connections()) as conn: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 """ | ||||
|                 INSERT INTO enrollment_install_codes ( | ||||
|                     id, | ||||
|                     code, | ||||
|                     expires_at, | ||||
|                     created_by_user_id, | ||||
|                     max_uses, | ||||
|                     use_count | ||||
|                 ) VALUES (?, ?, ?, ?, ?, 0) | ||||
|                 """, | ||||
|                 (record_id, code, expires_iso, created_by, max_uses), | ||||
|             ) | ||||
|             conn.commit() | ||||
|  | ||||
|         record = self.get_install_code_record(record_id) | ||||
|         if record is None: | ||||
|             raise RuntimeError("failed to load install code after insert") | ||||
|         return record | ||||
|  | ||||
|     def delete_install_code_if_unused(self, record_id: str) -> bool: | ||||
|         identifier = (record_id or "").strip() | ||||
|         if not identifier: | ||||
|             return False | ||||
|  | ||||
|         with closing(self._connections()) as conn: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 "DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0", | ||||
|                 (identifier,), | ||||
|             ) | ||||
|             deleted = cur.rowcount > 0 | ||||
|             conn.commit() | ||||
|             return deleted | ||||
|  | ||||
|     def update_install_code_usage( | ||||
|         self, | ||||
|         record_id: str, | ||||
| @@ -165,6 +322,100 @@ class SQLiteEnrollmentRepository: | ||||
|     # ------------------------------------------------------------------ | ||||
|     # Device approvals | ||||
|     # ------------------------------------------------------------------ | ||||
|     def list_device_approvals( | ||||
|         self, | ||||
|         *, | ||||
|         status: Optional[str] = None, | ||||
|     ) -> List[DeviceApprovalRecord]: | ||||
|         status_filter = (status or "").strip().lower() | ||||
|         params: List[str] = [] | ||||
|  | ||||
|         sql = """ | ||||
|             SELECT | ||||
|                 da.id, | ||||
|                 da.approval_reference, | ||||
|                 da.guid, | ||||
|                 da.hostname_claimed, | ||||
|                 da.ssl_key_fingerprint_claimed, | ||||
|                 da.enrollment_code_id, | ||||
|                 da.status, | ||||
|                 da.client_nonce, | ||||
|                 da.server_nonce, | ||||
|                 da.created_at, | ||||
|                 da.updated_at, | ||||
|                 da.approved_by_user_id, | ||||
|                 u.username AS approved_by_username | ||||
|               FROM device_approvals AS da | ||||
|          LEFT JOIN users AS u | ||||
|                 ON ( | ||||
|                     CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT) | ||||
|                     OR LOWER(da.approved_by_user_id) = LOWER(u.username) | ||||
|                 ) | ||||
|         """ | ||||
|  | ||||
|         if status_filter and status_filter not in {"all", "*"}: | ||||
|             sql += " WHERE LOWER(da.status) = ?" | ||||
|             params.append(status_filter) | ||||
|  | ||||
|         sql += " ORDER BY da.created_at ASC" | ||||
|  | ||||
|         approvals: List[DeviceApprovalRecord] = [] | ||||
|         with closing(self._connections()) as conn: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute(sql, params) | ||||
|             rows = cur.fetchall() | ||||
|  | ||||
|             for raw in rows: | ||||
|                 record = { | ||||
|                     "id": raw[0], | ||||
|                     "approval_reference": raw[1], | ||||
|                     "guid": raw[2], | ||||
|                     "hostname_claimed": raw[3], | ||||
|                     "ssl_key_fingerprint_claimed": raw[4], | ||||
|                     "enrollment_code_id": raw[5], | ||||
|                     "status": raw[6], | ||||
|                     "client_nonce": raw[7], | ||||
|                     "server_nonce": raw[8], | ||||
|                     "created_at": raw[9], | ||||
|                     "updated_at": raw[10], | ||||
|                     "approved_by_user_id": raw[11], | ||||
|                     "approved_by_username": raw[12], | ||||
|                 } | ||||
|  | ||||
|                 conflict, fingerprint_match, requires_prompt = self._compute_hostname_conflict( | ||||
|                     conn, | ||||
|                     record.get("hostname_claimed"), | ||||
|                     record.get("guid"), | ||||
|                     record.get("ssl_key_fingerprint_claimed") or "", | ||||
|                 ) | ||||
|  | ||||
|                 alternate = None | ||||
|                 if conflict and requires_prompt: | ||||
|                     alternate = self._suggest_alternate_hostname( | ||||
|                         conn, | ||||
|                         record.get("hostname_claimed"), | ||||
|                         record.get("guid"), | ||||
|                     ) | ||||
|  | ||||
|                 try: | ||||
|                     approvals.append( | ||||
|                         DeviceApprovalRecord.from_row( | ||||
|                             record, | ||||
|                             conflict=conflict, | ||||
|                             alternate_hostname=alternate, | ||||
|                             fingerprint_match=fingerprint_match, | ||||
|                             requires_prompt=requires_prompt, | ||||
|                         ) | ||||
|                     ) | ||||
|                 except Exception as exc:  # pragma: no cover - defensive logging | ||||
|                     self._log.warning( | ||||
|                         "invalid device approval record id=%s: %s", | ||||
|                         record.get("id"), | ||||
|                         exc, | ||||
|                     ) | ||||
|  | ||||
|         return approvals | ||||
|  | ||||
|     def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: | ||||
|         """Load a device approval using its operator-visible reference.""" | ||||
|  | ||||
| @@ -376,6 +627,98 @@ class SQLiteEnrollmentRepository: | ||||
|             ) | ||||
|             return None | ||||
|  | ||||
|     def _compute_hostname_conflict( | ||||
|         self, | ||||
|         conn, | ||||
|         hostname: Optional[str], | ||||
|         pending_guid: Optional[str], | ||||
|         claimed_fp: str, | ||||
|     ) -> Tuple[Optional[HostnameConflict], bool, bool]: | ||||
|         normalized_host = (hostname or "").strip() | ||||
|         if not normalized_host: | ||||
|             return None, False, False | ||||
|  | ||||
|         try: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 """ | ||||
|                 SELECT d.guid, | ||||
|                        d.ssl_key_fingerprint, | ||||
|                        ds.site_id, | ||||
|                        s.name | ||||
|                   FROM devices AS d | ||||
|              LEFT JOIN device_sites AS ds ON ds.device_hostname = d.hostname | ||||
|              LEFT JOIN sites AS s ON s.id = ds.site_id | ||||
|                  WHERE d.hostname = ? | ||||
|                 """, | ||||
|                 (normalized_host,), | ||||
|             ) | ||||
|             row = cur.fetchone() | ||||
|         except Exception as exc:  # pragma: no cover - defensive logging | ||||
|             self._log.warning("failed to inspect hostname conflict for %s: %s", normalized_host, exc) | ||||
|             return None, False, False | ||||
|  | ||||
|         if not row: | ||||
|             return None, False, False | ||||
|  | ||||
|         existing_guid = normalize_guid(row[0]) | ||||
|         pending_norm = normalize_guid(pending_guid) | ||||
|         if existing_guid and pending_norm and existing_guid == pending_norm: | ||||
|             return None, False, False | ||||
|  | ||||
|         stored_fp = (row[1] or "").strip().lower() | ||||
|         claimed_fp_normalized = (claimed_fp or "").strip().lower() | ||||
|         fingerprint_match = bool(stored_fp and claimed_fp_normalized and stored_fp == claimed_fp_normalized) | ||||
|  | ||||
|         site_id = None | ||||
|         if row[2] is not None: | ||||
|             try: | ||||
|                 site_id = int(row[2]) | ||||
|             except (TypeError, ValueError):  # pragma: no cover - defensive | ||||
|                 site_id = None | ||||
|  | ||||
|         site_name = str(row[3] or "").strip() | ||||
|         requires_prompt = not fingerprint_match | ||||
|  | ||||
|         conflict = HostnameConflict( | ||||
|             guid=existing_guid or None, | ||||
|             ssl_key_fingerprint=stored_fp or None, | ||||
|             site_id=site_id, | ||||
|             site_name=site_name, | ||||
|             fingerprint_match=fingerprint_match, | ||||
|             requires_prompt=requires_prompt, | ||||
|         ) | ||||
|  | ||||
|         return conflict, fingerprint_match, requires_prompt | ||||
|  | ||||
|     def _suggest_alternate_hostname( | ||||
|         self, | ||||
|         conn, | ||||
|         hostname: Optional[str], | ||||
|         pending_guid: Optional[str], | ||||
|     ) -> Optional[str]: | ||||
|         base = (hostname or "").strip() | ||||
|         if not base: | ||||
|             return None | ||||
|         base = base[:253] | ||||
|         candidate = base | ||||
|         pending_norm = normalize_guid(pending_guid) | ||||
|         suffix = 1 | ||||
|  | ||||
|         cur = conn.cursor() | ||||
|         while True: | ||||
|             cur.execute("SELECT guid FROM devices WHERE hostname = ?", (candidate,)) | ||||
|             row = cur.fetchone() | ||||
|             if not row: | ||||
|                 return candidate | ||||
|             existing_guid = normalize_guid(row[0]) | ||||
|             if pending_norm and existing_guid == pending_norm: | ||||
|                 return candidate | ||||
|             candidate = f"{base}-{suffix}" | ||||
|             suffix += 1 | ||||
|             if suffix > 50: | ||||
|                 return pending_norm or candidate | ||||
|  | ||||
|     @staticmethod | ||||
|     def _isoformat(value: datetime) -> str: | ||||
|         if value.tzinfo is None: | ||||
|   | ||||
| @@ -31,6 +31,9 @@ def apply_all(conn: sqlite3.Connection) -> None: | ||||
|     _ensure_refresh_token_table(conn) | ||||
|     _ensure_install_code_table(conn) | ||||
|     _ensure_device_approval_table(conn) | ||||
|     _ensure_device_list_views_table(conn) | ||||
|     _ensure_sites_tables(conn) | ||||
|     _ensure_credentials_table(conn) | ||||
|     _ensure_github_token_table(conn) | ||||
|     _ensure_scheduled_jobs_table(conn) | ||||
|     _ensure_scheduled_job_run_tables(conn) | ||||
| @@ -233,6 +236,73 @@ def _ensure_device_approval_table(conn: sqlite3.Connection) -> None: | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def _ensure_device_list_views_table(conn: sqlite3.Connection) -> None: | ||||
|     cur = conn.cursor() | ||||
|     cur.execute( | ||||
|         """ | ||||
|         CREATE TABLE IF NOT EXISTS device_list_views ( | ||||
|             id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||
|             name TEXT UNIQUE NOT NULL, | ||||
|             columns_json TEXT NOT NULL, | ||||
|             filters_json TEXT, | ||||
|             created_at INTEGER, | ||||
|             updated_at INTEGER | ||||
|         ) | ||||
|         """ | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def _ensure_sites_tables(conn: sqlite3.Connection) -> None: | ||||
|     cur = conn.cursor() | ||||
|     cur.execute( | ||||
|         """ | ||||
|         CREATE TABLE IF NOT EXISTS sites ( | ||||
|             id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||
|             name TEXT UNIQUE NOT NULL, | ||||
|             description TEXT, | ||||
|             created_at INTEGER | ||||
|         ) | ||||
|         """ | ||||
|     ) | ||||
|     cur.execute( | ||||
|         """ | ||||
|         CREATE TABLE IF NOT EXISTS device_sites ( | ||||
|             device_hostname TEXT UNIQUE NOT NULL, | ||||
|             site_id INTEGER NOT NULL, | ||||
|             assigned_at INTEGER, | ||||
|             FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE CASCADE | ||||
|         ) | ||||
|         """ | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def _ensure_credentials_table(conn: sqlite3.Connection) -> None: | ||||
|     cur = conn.cursor() | ||||
|     cur.execute( | ||||
|         """ | ||||
|         CREATE TABLE IF NOT EXISTS credentials ( | ||||
|             id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||
|             name TEXT NOT NULL UNIQUE, | ||||
|             description TEXT, | ||||
|             site_id INTEGER, | ||||
|             credential_type TEXT NOT NULL DEFAULT 'machine', | ||||
|             connection_type TEXT NOT NULL DEFAULT 'ssh', | ||||
|             username TEXT, | ||||
|             password_encrypted BLOB, | ||||
|             private_key_encrypted BLOB, | ||||
|             private_key_passphrase_encrypted BLOB, | ||||
|             become_method TEXT, | ||||
|             become_username TEXT, | ||||
|             become_password_encrypted BLOB, | ||||
|             metadata_json TEXT, | ||||
|             created_at INTEGER NOT NULL, | ||||
|             updated_at INTEGER NOT NULL, | ||||
|             FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE SET NULL | ||||
|         ) | ||||
|         """ | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def _ensure_github_token_table(conn: sqlite3.Connection) -> None: | ||||
|     cur = conn.cursor() | ||||
|     cur.execute( | ||||
|   | ||||
| @@ -71,6 +71,57 @@ class SQLiteUserRepository: | ||||
|         finally: | ||||
|             conn.close() | ||||
|  | ||||
|     def resolve_identifier(self, username: str) -> Optional[str]: | ||||
|         normalized = (username or "").strip() | ||||
|         if not normalized: | ||||
|             return None | ||||
|  | ||||
|         conn = self._connection_factory() | ||||
|         try: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 "SELECT id FROM users WHERE LOWER(username) = LOWER(?)", | ||||
|                 (normalized,), | ||||
|             ) | ||||
|             row = cur.fetchone() | ||||
|             if not row: | ||||
|                 return None | ||||
|             return str(row[0]) if row[0] is not None else None | ||||
|         except sqlite3.Error as exc:  # pragma: no cover - defensive | ||||
|             self._log.error("failed to resolve identifier for %s: %s", username, exc) | ||||
|             return None | ||||
|         finally: | ||||
|             conn.close() | ||||
|  | ||||
|     def username_for_identifier(self, identifier: str) -> Optional[str]: | ||||
|         token = (identifier or "").strip() | ||||
|         if not token: | ||||
|             return None | ||||
|  | ||||
|         conn = self._connection_factory() | ||||
|         try: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 """ | ||||
|                 SELECT username | ||||
|                   FROM users | ||||
|                  WHERE CAST(id AS TEXT) = ? | ||||
|                     OR LOWER(username) = LOWER(?) | ||||
|                  LIMIT 1 | ||||
|                 """, | ||||
|                 (token, token), | ||||
|             ) | ||||
|             row = cur.fetchone() | ||||
|             if not row: | ||||
|                 return None | ||||
|             username = str(row[0] or "").strip() | ||||
|             return username or None | ||||
|         except sqlite3.Error as exc:  # pragma: no cover - defensive | ||||
|             self._log.error("failed to resolve username for %s: %s", identifier, exc) | ||||
|             return None | ||||
|         finally: | ||||
|             conn.close() | ||||
|  | ||||
|     def list_accounts(self) -> list[OperatorAccount]: | ||||
|         conn = self._connection_factory() | ||||
|         try: | ||||
|   | ||||
| @@ -23,6 +23,7 @@ __all__ = [ | ||||
|     "SchedulerService", | ||||
|     "GitHubService", | ||||
|     "GitHubTokenPayload", | ||||
|     "EnrollmentAdminService", | ||||
| ] | ||||
|  | ||||
| _LAZY_TARGETS: Dict[str, Tuple[str, str]] = { | ||||
| @@ -43,6 +44,10 @@ _LAZY_TARGETS: Dict[str, Tuple[str, str]] = { | ||||
|     "SchedulerService": ("Data.Engine.services.jobs.scheduler_service", "SchedulerService"), | ||||
|     "GitHubService": ("Data.Engine.services.github.github_service", "GitHubService"), | ||||
|     "GitHubTokenPayload": ("Data.Engine.services.github.github_service", "GitHubTokenPayload"), | ||||
|     "EnrollmentAdminService": ( | ||||
|         "Data.Engine.services.enrollment.admin_service", | ||||
|         "EnrollmentAdminService", | ||||
|     ), | ||||
| } | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -30,6 +30,7 @@ from Data.Engine.services.auth import ( | ||||
| ) | ||||
| from Data.Engine.services.crypto.signing import ScriptSigner, load_signer | ||||
| from Data.Engine.services.enrollment import EnrollmentService | ||||
| from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService | ||||
| from Data.Engine.services.enrollment.nonce_cache import NonceCache | ||||
| from Data.Engine.services.github import GitHubService | ||||
| from Data.Engine.services.jobs import SchedulerService | ||||
| @@ -44,6 +45,7 @@ class EngineServiceContainer: | ||||
|     device_auth: DeviceAuthService | ||||
|     token_service: TokenService | ||||
|     enrollment_service: EnrollmentService | ||||
|     enrollment_admin_service: EnrollmentAdminService | ||||
|     jwt_service: JWTService | ||||
|     dpop_validator: DPoPValidator | ||||
|     agent_realtime: AgentRealtimeService | ||||
| @@ -93,6 +95,12 @@ def build_service_container( | ||||
|         logger=log.getChild("enrollment"), | ||||
|     ) | ||||
|  | ||||
|     enrollment_admin_service = EnrollmentAdminService( | ||||
|         repository=enrollment_repo, | ||||
|         user_repository=user_repo, | ||||
|         logger=log.getChild("enrollment_admin"), | ||||
|     ) | ||||
|  | ||||
|     device_auth = DeviceAuthService( | ||||
|         device_repository=device_repo, | ||||
|         jwt_service=jwt_service, | ||||
| @@ -139,6 +147,7 @@ def build_service_container( | ||||
|         device_auth=device_auth, | ||||
|         token_service=token_service, | ||||
|         enrollment_service=enrollment_service, | ||||
|         enrollment_admin_service=enrollment_admin_service, | ||||
|         jwt_service=jwt_service, | ||||
|         dpop_validator=dpop_validator, | ||||
|         agent_realtime=agent_realtime, | ||||
|   | ||||
| @@ -2,20 +2,54 @@ | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| from .enrollment_service import ( | ||||
|     EnrollmentRequestResult, | ||||
|     EnrollmentService, | ||||
|     EnrollmentStatus, | ||||
|     EnrollmentTokenBundle, | ||||
|     PollingResult, | ||||
| ) | ||||
| from Data.Engine.domain.device_enrollment import EnrollmentValidationError | ||||
| from importlib import import_module | ||||
| from typing import Any | ||||
|  | ||||
| __all__ = [ | ||||
|     "EnrollmentRequestResult", | ||||
|     "EnrollmentService", | ||||
|     "EnrollmentRequestResult", | ||||
|     "EnrollmentStatus", | ||||
|     "EnrollmentTokenBundle", | ||||
|     "EnrollmentValidationError", | ||||
|     "PollingResult", | ||||
|     "EnrollmentValidationError", | ||||
|     "EnrollmentAdminService", | ||||
| ] | ||||
|  | ||||
| _LAZY: dict[str, tuple[str, str]] = { | ||||
|     "EnrollmentService": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentService"), | ||||
|     "EnrollmentRequestResult": ( | ||||
|         "Data.Engine.services.enrollment.enrollment_service", | ||||
|         "EnrollmentRequestResult", | ||||
|     ), | ||||
|     "EnrollmentStatus": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentStatus"), | ||||
|     "EnrollmentTokenBundle": ( | ||||
|         "Data.Engine.services.enrollment.enrollment_service", | ||||
|         "EnrollmentTokenBundle", | ||||
|     ), | ||||
|     "PollingResult": ("Data.Engine.services.enrollment.enrollment_service", "PollingResult"), | ||||
|     "EnrollmentValidationError": ( | ||||
|         "Data.Engine.domain.device_enrollment", | ||||
|         "EnrollmentValidationError", | ||||
|     ), | ||||
|     "EnrollmentAdminService": ( | ||||
|         "Data.Engine.services.enrollment.admin_service", | ||||
|         "EnrollmentAdminService", | ||||
|     ), | ||||
| } | ||||
|  | ||||
|  | ||||
| def __getattr__(name: str) -> Any: | ||||
|     try: | ||||
|         module_name, attribute = _LAZY[name] | ||||
|     except KeyError as exc:  # pragma: no cover - defensive | ||||
|         raise AttributeError(name) from exc | ||||
|  | ||||
|     module = import_module(module_name) | ||||
|     value = getattr(module, attribute) | ||||
|     globals()[name] = value | ||||
|     return value | ||||
|  | ||||
|  | ||||
| def __dir__() -> list[str]:  # pragma: no cover - interactive helper | ||||
|     return sorted(set(__all__)) | ||||
|  | ||||
|   | ||||
							
								
								
									
										113
									
								
								Data/Engine/services/enrollment/admin_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								Data/Engine/services/enrollment/admin_service.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,113 @@ | ||||
| """Administrative helpers for enrollment workflows.""" | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| import logging | ||||
| import secrets | ||||
| import uuid | ||||
| from datetime import datetime, timedelta, timezone | ||||
| from typing import Callable, List, Optional | ||||
|  | ||||
| from Data.Engine.domain.enrollment_admin import DeviceApprovalRecord, EnrollmentCodeRecord | ||||
| from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository | ||||
| from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository | ||||
|  | ||||
| __all__ = ["EnrollmentAdminService"] | ||||
|  | ||||
|  | ||||
| class EnrollmentAdminService: | ||||
|     """Expose administrative enrollment operations.""" | ||||
|  | ||||
|     _VALID_TTL_HOURS = {1, 3, 6, 12, 24} | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         *, | ||||
|         repository: SQLiteEnrollmentRepository, | ||||
|         user_repository: SQLiteUserRepository, | ||||
|         logger: Optional[logging.Logger] = None, | ||||
|         clock: Optional[Callable[[], datetime]] = None, | ||||
|     ) -> None: | ||||
|         self._repository = repository | ||||
|         self._users = user_repository | ||||
|         self._log = logger or logging.getLogger("borealis.engine.services.enrollment_admin") | ||||
|         self._clock = clock or (lambda: datetime.now(tz=timezone.utc)) | ||||
|  | ||||
|     # ------------------------------------------------------------------ | ||||
|     # Enrollment install codes | ||||
|     # ------------------------------------------------------------------ | ||||
|     def list_install_codes(self, *, status: Optional[str] = None) -> List[EnrollmentCodeRecord]: | ||||
|         return self._repository.list_install_codes(status=status, now=self._clock()) | ||||
|  | ||||
|     def create_install_code( | ||||
|         self, | ||||
|         *, | ||||
|         ttl_hours: int, | ||||
|         max_uses: int, | ||||
|         created_by: Optional[str], | ||||
|     ) -> EnrollmentCodeRecord: | ||||
|         if ttl_hours not in self._VALID_TTL_HOURS: | ||||
|             raise ValueError("invalid_ttl") | ||||
|  | ||||
|         normalized_max = self._normalize_max_uses(max_uses) | ||||
|  | ||||
|         now = self._clock() | ||||
|         expires_at = now + timedelta(hours=ttl_hours) | ||||
|         record_id = str(uuid.uuid4()) | ||||
|         code = self._generate_install_code() | ||||
|  | ||||
|         created_by_identifier = None | ||||
|         if created_by: | ||||
|             created_by_identifier = self._users.resolve_identifier(created_by) | ||||
|             if not created_by_identifier: | ||||
|                 created_by_identifier = created_by.strip() or None | ||||
|  | ||||
|         record = self._repository.insert_install_code( | ||||
|             record_id=record_id, | ||||
|             code=code, | ||||
|             expires_at=expires_at, | ||||
|             created_by=created_by_identifier, | ||||
|             max_uses=normalized_max, | ||||
|         ) | ||||
|  | ||||
|         self._log.info( | ||||
|             "install code created id=%s ttl=%sh max_uses=%s", | ||||
|             record.record_id, | ||||
|             ttl_hours, | ||||
|             normalized_max, | ||||
|         ) | ||||
|  | ||||
|         return record | ||||
|  | ||||
|     def delete_install_code(self, record_id: str) -> bool: | ||||
|         deleted = self._repository.delete_install_code_if_unused(record_id) | ||||
|         if deleted: | ||||
|             self._log.info("install code deleted id=%s", record_id) | ||||
|         return deleted | ||||
|  | ||||
|     # ------------------------------------------------------------------ | ||||
|     # Device approvals | ||||
|     # ------------------------------------------------------------------ | ||||
|     def list_device_approvals(self, *, status: Optional[str] = None) -> List[DeviceApprovalRecord]: | ||||
|         return self._repository.list_device_approvals(status=status) | ||||
|  | ||||
|     # ------------------------------------------------------------------ | ||||
|     # Helpers | ||||
|     # ------------------------------------------------------------------ | ||||
|     @staticmethod | ||||
|     def _generate_install_code() -> str: | ||||
|         raw = secrets.token_hex(16).upper() | ||||
|         return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4)) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _normalize_max_uses(value: int) -> int: | ||||
|         try: | ||||
|             count = int(value) | ||||
|         except Exception: | ||||
|             count = 2 | ||||
|         if count < 1: | ||||
|             return 1 | ||||
|         if count > 10: | ||||
|             return 10 | ||||
|         return count | ||||
|  | ||||
							
								
								
									
										122
									
								
								Data/Engine/tests/test_enrollment_admin_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										122
									
								
								Data/Engine/tests/test_enrollment_admin_service.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,122 @@ | ||||
| import base64 | ||||
| import sqlite3 | ||||
| from datetime import datetime, timezone | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| from Data.Engine.repositories.sqlite import connection as sqlite_connection | ||||
| from Data.Engine.repositories.sqlite import migrations as sqlite_migrations | ||||
| from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository | ||||
| from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository | ||||
| from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService | ||||
|  | ||||
|  | ||||
| def _build_service(tmp_path): | ||||
|     db_path = tmp_path / "admin.db" | ||||
|     conn = sqlite3.connect(db_path) | ||||
|     sqlite_migrations.apply_all(conn) | ||||
|     conn.close() | ||||
|  | ||||
|     factory = sqlite_connection.connection_factory(db_path) | ||||
|     enrollment_repo = SQLiteEnrollmentRepository(factory) | ||||
|     user_repo = SQLiteUserRepository(factory) | ||||
|  | ||||
|     fixed_now = datetime(2024, 1, 1, tzinfo=timezone.utc) | ||||
|     service = EnrollmentAdminService( | ||||
|         repository=enrollment_repo, | ||||
|         user_repository=user_repo, | ||||
|         clock=lambda: fixed_now, | ||||
|     ) | ||||
|     return service, factory, fixed_now | ||||
|  | ||||
|  | ||||
| def test_create_and_list_install_codes(tmp_path): | ||||
|     service, factory, fixed_now = _build_service(tmp_path) | ||||
|  | ||||
|     record = service.create_install_code(ttl_hours=3, max_uses=5, created_by="admin") | ||||
|     assert record.code | ||||
|     assert record.max_uses == 5 | ||||
|     assert record.status(now=fixed_now) == "active" | ||||
|  | ||||
|     records = service.list_install_codes() | ||||
|     assert any(r.record_id == record.record_id for r in records) | ||||
|  | ||||
|     # Invalid TTL should raise | ||||
|     with pytest.raises(ValueError): | ||||
|         service.create_install_code(ttl_hours=2, max_uses=1, created_by=None) | ||||
|  | ||||
|     # Deleting should succeed and remove the record | ||||
|     assert service.delete_install_code(record.record_id) is True | ||||
|     remaining = service.list_install_codes() | ||||
|     assert all(r.record_id != record.record_id for r in remaining) | ||||
|  | ||||
|  | ||||
| def test_list_device_approvals_includes_conflict(tmp_path): | ||||
|     service, factory, fixed_now = _build_service(tmp_path) | ||||
|  | ||||
|     conn = factory() | ||||
|     cur = conn.cursor() | ||||
|  | ||||
|     cur.execute( | ||||
|         "INSERT INTO sites (name, description, created_at) VALUES (?, ?, ?)", | ||||
|         ("HQ", "Primary site", int(fixed_now.timestamp())), | ||||
|     ) | ||||
|     site_id = cur.lastrowid | ||||
|  | ||||
|     cur.execute( | ||||
|         """ | ||||
|         INSERT INTO devices (guid, hostname, created_at, last_seen, ssl_key_fingerprint, status) | ||||
|         VALUES (?, ?, ?, ?, ?, 'active') | ||||
|         """, | ||||
|         ("11111111-1111-1111-1111-111111111111", "agent-one", int(fixed_now.timestamp()), int(fixed_now.timestamp()), "abc123",), | ||||
|     ) | ||||
|     cur.execute( | ||||
|         "INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)", | ||||
|         ("agent-one", site_id, int(fixed_now.timestamp())), | ||||
|     ) | ||||
|  | ||||
|     now_iso = fixed_now.isoformat() | ||||
|     cur.execute( | ||||
|         """ | ||||
|         INSERT INTO device_approvals ( | ||||
|             id, | ||||
|             approval_reference, | ||||
|             guid, | ||||
|             hostname_claimed, | ||||
|             ssl_key_fingerprint_claimed, | ||||
|             enrollment_code_id, | ||||
|             status, | ||||
|             client_nonce, | ||||
|             server_nonce, | ||||
|             created_at, | ||||
|             updated_at, | ||||
|             approved_by_user_id, | ||||
|             agent_pubkey_der | ||||
|         ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) | ||||
|         """, | ||||
|         ( | ||||
|             "approval-1", | ||||
|             "REF123", | ||||
|             None, | ||||
|             "agent-one", | ||||
|             "abc123", | ||||
|             "code-1", | ||||
|             "pending", | ||||
|             base64.b64encode(b"client").decode(), | ||||
|             base64.b64encode(b"server").decode(), | ||||
|             now_iso, | ||||
|             now_iso, | ||||
|             None, | ||||
|             b"pubkey", | ||||
|         ), | ||||
|     ) | ||||
|     conn.commit() | ||||
|     conn.close() | ||||
|  | ||||
|     approvals = service.list_device_approvals() | ||||
|     assert len(approvals) == 1 | ||||
|     record = approvals[0] | ||||
|     assert record.hostname_conflict is not None | ||||
|     assert record.hostname_conflict.fingerprint_match is True | ||||
|     assert record.conflict_requires_prompt is False | ||||
|  | ||||
							
								
								
									
										111
									
								
								Data/Engine/tests/test_http_admin.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								Data/Engine/tests/test_http_admin.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,111 @@ | ||||
| import base64 | ||||
| import sqlite3 | ||||
| from datetime import datetime, timezone | ||||
|  | ||||
| from .test_http_auth import _login, prepared_app | ||||
|  | ||||
|  | ||||
| def test_enrollment_codes_require_authentication(prepared_app): | ||||
|     client = prepared_app.test_client() | ||||
|     resp = client.get("/api/admin/enrollment-codes") | ||||
|     assert resp.status_code == 401 | ||||
|  | ||||
|  | ||||
| def test_enrollment_code_workflow(prepared_app): | ||||
|     client = prepared_app.test_client() | ||||
|     _login(client) | ||||
|  | ||||
|     payload = {"ttl_hours": 3, "max_uses": 4} | ||||
|     resp = client.post("/api/admin/enrollment-codes", json=payload) | ||||
|     assert resp.status_code == 201 | ||||
|     created = resp.get_json() | ||||
|     assert created["max_uses"] == 4 | ||||
|     assert created["status"] == "active" | ||||
|  | ||||
|     resp = client.get("/api/admin/enrollment-codes") | ||||
|     assert resp.status_code == 200 | ||||
|     codes = resp.get_json().get("codes", []) | ||||
|     assert any(code["id"] == created["id"] for code in codes) | ||||
|  | ||||
|     resp = client.delete(f"/api/admin/enrollment-codes/{created['id']}") | ||||
|     assert resp.status_code == 200 | ||||
|  | ||||
|  | ||||
| def test_device_approvals_listing(prepared_app, engine_settings): | ||||
|     client = prepared_app.test_client() | ||||
|     _login(client) | ||||
|  | ||||
|     conn = sqlite3.connect(engine_settings.database.path) | ||||
|     cur = conn.cursor() | ||||
|  | ||||
|     now = datetime.now(tz=timezone.utc) | ||||
|     cur.execute( | ||||
|         "INSERT INTO sites (name, description, created_at) VALUES (?, ?, ?)", | ||||
|         ("HQ", "Primary", int(now.timestamp())), | ||||
|     ) | ||||
|     site_id = cur.lastrowid | ||||
|  | ||||
|     cur.execute( | ||||
|         """ | ||||
|         INSERT INTO devices (guid, hostname, created_at, last_seen, ssl_key_fingerprint, status) | ||||
|         VALUES (?, ?, ?, ?, ?, 'active') | ||||
|         """, | ||||
|         ( | ||||
|             "22222222-2222-2222-2222-222222222222", | ||||
|             "approval-host", | ||||
|             int(now.timestamp()), | ||||
|             int(now.timestamp()), | ||||
|             "deadbeef", | ||||
|         ), | ||||
|     ) | ||||
|     cur.execute( | ||||
|         "INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)", | ||||
|         ("approval-host", site_id, int(now.timestamp())), | ||||
|     ) | ||||
|  | ||||
|     now_iso = now.isoformat() | ||||
|     cur.execute( | ||||
|         """ | ||||
|         INSERT INTO device_approvals ( | ||||
|             id, | ||||
|             approval_reference, | ||||
|             guid, | ||||
|             hostname_claimed, | ||||
|             ssl_key_fingerprint_claimed, | ||||
|             enrollment_code_id, | ||||
|             status, | ||||
|             client_nonce, | ||||
|             server_nonce, | ||||
|             created_at, | ||||
|             updated_at, | ||||
|             approved_by_user_id, | ||||
|             agent_pubkey_der | ||||
|         ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) | ||||
|         """, | ||||
|         ( | ||||
|             "approval-http", | ||||
|             "REFHTTP", | ||||
|             None, | ||||
|             "approval-host", | ||||
|             "deadbeef", | ||||
|             "code-http", | ||||
|             "pending", | ||||
|             base64.b64encode(b"client").decode(), | ||||
|             base64.b64encode(b"server").decode(), | ||||
|             now_iso, | ||||
|             now_iso, | ||||
|             None, | ||||
|             b"pub", | ||||
|         ), | ||||
|     ) | ||||
|     conn.commit() | ||||
|     conn.close() | ||||
|  | ||||
|     resp = client.get("/api/admin/device-approvals") | ||||
|     assert resp.status_code == 200 | ||||
|     body = resp.get_json() | ||||
|     approvals = body.get("approvals", []) | ||||
|     assert any(a["id"] == "approval-http" for a in approvals) | ||||
|     record = next(a for a in approvals if a["id"] == "approval-http") | ||||
|     assert record.get("hostname_conflict", {}).get("fingerprint_match") is True | ||||
|  | ||||
		Reference in New Issue
	
	Block a user