mirror of
				https://github.com/bunny-lab-io/Borealis.git
				synced 2025-10-27 03:41:57 -06:00 
			
		
		
		
	Implement admin enrollment APIs
This commit is contained in:
		| @@ -5,14 +5,19 @@ from __future__ import annotations | ||||
| import logging | ||||
| from contextlib import closing | ||||
| from datetime import datetime, timezone | ||||
| from typing import Optional | ||||
| from typing import Any, List, Optional, Tuple | ||||
|  | ||||
| from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid | ||||
| from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid, normalize_guid | ||||
| from Data.Engine.domain.device_enrollment import ( | ||||
|     EnrollmentApproval, | ||||
|     EnrollmentApprovalStatus, | ||||
|     EnrollmentCode, | ||||
| ) | ||||
| from Data.Engine.domain.enrollment_admin import ( | ||||
|     DeviceApprovalRecord, | ||||
|     EnrollmentCodeRecord, | ||||
|     HostnameConflict, | ||||
| ) | ||||
| from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory | ||||
|  | ||||
| __all__ = ["SQLiteEnrollmentRepository"] | ||||
| @@ -122,6 +127,158 @@ class SQLiteEnrollmentRepository: | ||||
|             self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc) | ||||
|             return None | ||||
|  | ||||
|     def list_install_codes( | ||||
|         self, | ||||
|         *, | ||||
|         status: Optional[str] = None, | ||||
|         now: Optional[datetime] = None, | ||||
|     ) -> List[EnrollmentCodeRecord]: | ||||
|         reference = now or datetime.now(tz=timezone.utc) | ||||
|         status_filter = (status or "").strip().lower() | ||||
|         params: List[str] = [] | ||||
|  | ||||
|         sql = """ | ||||
|             SELECT id, | ||||
|                    code, | ||||
|                    expires_at, | ||||
|                    created_by_user_id, | ||||
|                    used_at, | ||||
|                    used_by_guid, | ||||
|                    max_uses, | ||||
|                    use_count, | ||||
|                    last_used_at | ||||
|               FROM enrollment_install_codes | ||||
|         """ | ||||
|  | ||||
|         if status_filter in {"active", "expired", "used"}: | ||||
|             sql += " WHERE " | ||||
|             if status_filter == "active": | ||||
|                 sql += "use_count < max_uses AND expires_at > ?" | ||||
|                 params.append(self._isoformat(reference)) | ||||
|             elif status_filter == "expired": | ||||
|                 sql += "use_count < max_uses AND expires_at <= ?" | ||||
|                 params.append(self._isoformat(reference)) | ||||
|             else:  # used | ||||
|                 sql += "use_count >= max_uses" | ||||
|  | ||||
|         sql += " ORDER BY expires_at ASC" | ||||
|  | ||||
|         rows: List[EnrollmentCodeRecord] = [] | ||||
|         with closing(self._connections()) as conn: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute(sql, params) | ||||
|             for raw in cur.fetchall(): | ||||
|                 record = { | ||||
|                     "id": raw[0], | ||||
|                     "code": raw[1], | ||||
|                     "expires_at": raw[2], | ||||
|                     "created_by_user_id": raw[3], | ||||
|                     "used_at": raw[4], | ||||
|                     "used_by_guid": raw[5], | ||||
|                     "max_uses": raw[6], | ||||
|                     "use_count": raw[7], | ||||
|                     "last_used_at": raw[8], | ||||
|                 } | ||||
|                 try: | ||||
|                     rows.append(EnrollmentCodeRecord.from_row(record)) | ||||
|                 except Exception as exc:  # pragma: no cover - defensive logging | ||||
|                     self._log.warning("invalid enrollment install code row id=%s: %s", record.get("id"), exc) | ||||
|         return rows | ||||
|  | ||||
|     def get_install_code_record(self, record_id: str) -> Optional[EnrollmentCodeRecord]: | ||||
|         identifier = (record_id or "").strip() | ||||
|         if not identifier: | ||||
|             return None | ||||
|  | ||||
|         with closing(self._connections()) as conn: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 """ | ||||
|                 SELECT id, | ||||
|                        code, | ||||
|                        expires_at, | ||||
|                        created_by_user_id, | ||||
|                        used_at, | ||||
|                        used_by_guid, | ||||
|                        max_uses, | ||||
|                        use_count, | ||||
|                        last_used_at | ||||
|                   FROM enrollment_install_codes | ||||
|                  WHERE id = ? | ||||
|                 """, | ||||
|                 (identifier,), | ||||
|             ) | ||||
|             row = cur.fetchone() | ||||
|  | ||||
|         if not row: | ||||
|             return None | ||||
|  | ||||
|         payload = { | ||||
|             "id": row[0], | ||||
|             "code": row[1], | ||||
|             "expires_at": row[2], | ||||
|             "created_by_user_id": row[3], | ||||
|             "used_at": row[4], | ||||
|             "used_by_guid": row[5], | ||||
|             "max_uses": row[6], | ||||
|             "use_count": row[7], | ||||
|             "last_used_at": row[8], | ||||
|         } | ||||
|  | ||||
|         try: | ||||
|             return EnrollmentCodeRecord.from_row(payload) | ||||
|         except Exception as exc:  # pragma: no cover - defensive logging | ||||
|             self._log.warning("invalid enrollment install code record id=%s: %s", identifier, exc) | ||||
|             return None | ||||
|  | ||||
|     def insert_install_code( | ||||
|         self, | ||||
|         *, | ||||
|         record_id: str, | ||||
|         code: str, | ||||
|         expires_at: datetime, | ||||
|         created_by: Optional[str], | ||||
|         max_uses: int, | ||||
|     ) -> EnrollmentCodeRecord: | ||||
|         expires_iso = self._isoformat(expires_at) | ||||
|  | ||||
|         with closing(self._connections()) as conn: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 """ | ||||
|                 INSERT INTO enrollment_install_codes ( | ||||
|                     id, | ||||
|                     code, | ||||
|                     expires_at, | ||||
|                     created_by_user_id, | ||||
|                     max_uses, | ||||
|                     use_count | ||||
|                 ) VALUES (?, ?, ?, ?, ?, 0) | ||||
|                 """, | ||||
|                 (record_id, code, expires_iso, created_by, max_uses), | ||||
|             ) | ||||
|             conn.commit() | ||||
|  | ||||
|         record = self.get_install_code_record(record_id) | ||||
|         if record is None: | ||||
|             raise RuntimeError("failed to load install code after insert") | ||||
|         return record | ||||
|  | ||||
|     def delete_install_code_if_unused(self, record_id: str) -> bool: | ||||
|         identifier = (record_id or "").strip() | ||||
|         if not identifier: | ||||
|             return False | ||||
|  | ||||
|         with closing(self._connections()) as conn: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 "DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0", | ||||
|                 (identifier,), | ||||
|             ) | ||||
|             deleted = cur.rowcount > 0 | ||||
|             conn.commit() | ||||
|             return deleted | ||||
|  | ||||
|     def update_install_code_usage( | ||||
|         self, | ||||
|         record_id: str, | ||||
| @@ -165,6 +322,100 @@ class SQLiteEnrollmentRepository: | ||||
|     # ------------------------------------------------------------------ | ||||
|     # Device approvals | ||||
|     # ------------------------------------------------------------------ | ||||
|     def list_device_approvals( | ||||
|         self, | ||||
|         *, | ||||
|         status: Optional[str] = None, | ||||
|     ) -> List[DeviceApprovalRecord]: | ||||
|         status_filter = (status or "").strip().lower() | ||||
|         params: List[str] = [] | ||||
|  | ||||
|         sql = """ | ||||
|             SELECT | ||||
|                 da.id, | ||||
|                 da.approval_reference, | ||||
|                 da.guid, | ||||
|                 da.hostname_claimed, | ||||
|                 da.ssl_key_fingerprint_claimed, | ||||
|                 da.enrollment_code_id, | ||||
|                 da.status, | ||||
|                 da.client_nonce, | ||||
|                 da.server_nonce, | ||||
|                 da.created_at, | ||||
|                 da.updated_at, | ||||
|                 da.approved_by_user_id, | ||||
|                 u.username AS approved_by_username | ||||
|               FROM device_approvals AS da | ||||
|          LEFT JOIN users AS u | ||||
|                 ON ( | ||||
|                     CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT) | ||||
|                     OR LOWER(da.approved_by_user_id) = LOWER(u.username) | ||||
|                 ) | ||||
|         """ | ||||
|  | ||||
|         if status_filter and status_filter not in {"all", "*"}: | ||||
|             sql += " WHERE LOWER(da.status) = ?" | ||||
|             params.append(status_filter) | ||||
|  | ||||
|         sql += " ORDER BY da.created_at ASC" | ||||
|  | ||||
|         approvals: List[DeviceApprovalRecord] = [] | ||||
|         with closing(self._connections()) as conn: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute(sql, params) | ||||
|             rows = cur.fetchall() | ||||
|  | ||||
|             for raw in rows: | ||||
|                 record = { | ||||
|                     "id": raw[0], | ||||
|                     "approval_reference": raw[1], | ||||
|                     "guid": raw[2], | ||||
|                     "hostname_claimed": raw[3], | ||||
|                     "ssl_key_fingerprint_claimed": raw[4], | ||||
|                     "enrollment_code_id": raw[5], | ||||
|                     "status": raw[6], | ||||
|                     "client_nonce": raw[7], | ||||
|                     "server_nonce": raw[8], | ||||
|                     "created_at": raw[9], | ||||
|                     "updated_at": raw[10], | ||||
|                     "approved_by_user_id": raw[11], | ||||
|                     "approved_by_username": raw[12], | ||||
|                 } | ||||
|  | ||||
|                 conflict, fingerprint_match, requires_prompt = self._compute_hostname_conflict( | ||||
|                     conn, | ||||
|                     record.get("hostname_claimed"), | ||||
|                     record.get("guid"), | ||||
|                     record.get("ssl_key_fingerprint_claimed") or "", | ||||
|                 ) | ||||
|  | ||||
|                 alternate = None | ||||
|                 if conflict and requires_prompt: | ||||
|                     alternate = self._suggest_alternate_hostname( | ||||
|                         conn, | ||||
|                         record.get("hostname_claimed"), | ||||
|                         record.get("guid"), | ||||
|                     ) | ||||
|  | ||||
|                 try: | ||||
|                     approvals.append( | ||||
|                         DeviceApprovalRecord.from_row( | ||||
|                             record, | ||||
|                             conflict=conflict, | ||||
|                             alternate_hostname=alternate, | ||||
|                             fingerprint_match=fingerprint_match, | ||||
|                             requires_prompt=requires_prompt, | ||||
|                         ) | ||||
|                     ) | ||||
|                 except Exception as exc:  # pragma: no cover - defensive logging | ||||
|                     self._log.warning( | ||||
|                         "invalid device approval record id=%s: %s", | ||||
|                         record.get("id"), | ||||
|                         exc, | ||||
|                     ) | ||||
|  | ||||
|         return approvals | ||||
|  | ||||
|     def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: | ||||
|         """Load a device approval using its operator-visible reference.""" | ||||
|  | ||||
| @@ -376,6 +627,98 @@ class SQLiteEnrollmentRepository: | ||||
|             ) | ||||
|             return None | ||||
|  | ||||
|     def _compute_hostname_conflict( | ||||
|         self, | ||||
|         conn, | ||||
|         hostname: Optional[str], | ||||
|         pending_guid: Optional[str], | ||||
|         claimed_fp: str, | ||||
|     ) -> Tuple[Optional[HostnameConflict], bool, bool]: | ||||
|         normalized_host = (hostname or "").strip() | ||||
|         if not normalized_host: | ||||
|             return None, False, False | ||||
|  | ||||
|         try: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 """ | ||||
|                 SELECT d.guid, | ||||
|                        d.ssl_key_fingerprint, | ||||
|                        ds.site_id, | ||||
|                        s.name | ||||
|                   FROM devices AS d | ||||
|              LEFT JOIN device_sites AS ds ON ds.device_hostname = d.hostname | ||||
|              LEFT JOIN sites AS s ON s.id = ds.site_id | ||||
|                  WHERE d.hostname = ? | ||||
|                 """, | ||||
|                 (normalized_host,), | ||||
|             ) | ||||
|             row = cur.fetchone() | ||||
|         except Exception as exc:  # pragma: no cover - defensive logging | ||||
|             self._log.warning("failed to inspect hostname conflict for %s: %s", normalized_host, exc) | ||||
|             return None, False, False | ||||
|  | ||||
|         if not row: | ||||
|             return None, False, False | ||||
|  | ||||
|         existing_guid = normalize_guid(row[0]) | ||||
|         pending_norm = normalize_guid(pending_guid) | ||||
|         if existing_guid and pending_norm and existing_guid == pending_norm: | ||||
|             return None, False, False | ||||
|  | ||||
|         stored_fp = (row[1] or "").strip().lower() | ||||
|         claimed_fp_normalized = (claimed_fp or "").strip().lower() | ||||
|         fingerprint_match = bool(stored_fp and claimed_fp_normalized and stored_fp == claimed_fp_normalized) | ||||
|  | ||||
|         site_id = None | ||||
|         if row[2] is not None: | ||||
|             try: | ||||
|                 site_id = int(row[2]) | ||||
|             except (TypeError, ValueError):  # pragma: no cover - defensive | ||||
|                 site_id = None | ||||
|  | ||||
|         site_name = str(row[3] or "").strip() | ||||
|         requires_prompt = not fingerprint_match | ||||
|  | ||||
|         conflict = HostnameConflict( | ||||
|             guid=existing_guid or None, | ||||
|             ssl_key_fingerprint=stored_fp or None, | ||||
|             site_id=site_id, | ||||
|             site_name=site_name, | ||||
|             fingerprint_match=fingerprint_match, | ||||
|             requires_prompt=requires_prompt, | ||||
|         ) | ||||
|  | ||||
|         return conflict, fingerprint_match, requires_prompt | ||||
|  | ||||
|     def _suggest_alternate_hostname( | ||||
|         self, | ||||
|         conn, | ||||
|         hostname: Optional[str], | ||||
|         pending_guid: Optional[str], | ||||
|     ) -> Optional[str]: | ||||
|         base = (hostname or "").strip() | ||||
|         if not base: | ||||
|             return None | ||||
|         base = base[:253] | ||||
|         candidate = base | ||||
|         pending_norm = normalize_guid(pending_guid) | ||||
|         suffix = 1 | ||||
|  | ||||
|         cur = conn.cursor() | ||||
|         while True: | ||||
|             cur.execute("SELECT guid FROM devices WHERE hostname = ?", (candidate,)) | ||||
|             row = cur.fetchone() | ||||
|             if not row: | ||||
|                 return candidate | ||||
|             existing_guid = normalize_guid(row[0]) | ||||
|             if pending_norm and existing_guid == pending_norm: | ||||
|                 return candidate | ||||
|             candidate = f"{base}-{suffix}" | ||||
|             suffix += 1 | ||||
|             if suffix > 50: | ||||
|                 return pending_norm or candidate | ||||
|  | ||||
|     @staticmethod | ||||
|     def _isoformat(value: datetime) -> str: | ||||
|         if value.tzinfo is None: | ||||
|   | ||||
| @@ -31,6 +31,9 @@ def apply_all(conn: sqlite3.Connection) -> None: | ||||
|     _ensure_refresh_token_table(conn) | ||||
|     _ensure_install_code_table(conn) | ||||
|     _ensure_device_approval_table(conn) | ||||
|     _ensure_device_list_views_table(conn) | ||||
|     _ensure_sites_tables(conn) | ||||
|     _ensure_credentials_table(conn) | ||||
|     _ensure_github_token_table(conn) | ||||
|     _ensure_scheduled_jobs_table(conn) | ||||
|     _ensure_scheduled_job_run_tables(conn) | ||||
| @@ -233,6 +236,73 @@ def _ensure_device_approval_table(conn: sqlite3.Connection) -> None: | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def _ensure_device_list_views_table(conn: sqlite3.Connection) -> None: | ||||
|     cur = conn.cursor() | ||||
|     cur.execute( | ||||
|         """ | ||||
|         CREATE TABLE IF NOT EXISTS device_list_views ( | ||||
|             id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||
|             name TEXT UNIQUE NOT NULL, | ||||
|             columns_json TEXT NOT NULL, | ||||
|             filters_json TEXT, | ||||
|             created_at INTEGER, | ||||
|             updated_at INTEGER | ||||
|         ) | ||||
|         """ | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def _ensure_sites_tables(conn: sqlite3.Connection) -> None: | ||||
|     cur = conn.cursor() | ||||
|     cur.execute( | ||||
|         """ | ||||
|         CREATE TABLE IF NOT EXISTS sites ( | ||||
|             id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||
|             name TEXT UNIQUE NOT NULL, | ||||
|             description TEXT, | ||||
|             created_at INTEGER | ||||
|         ) | ||||
|         """ | ||||
|     ) | ||||
|     cur.execute( | ||||
|         """ | ||||
|         CREATE TABLE IF NOT EXISTS device_sites ( | ||||
|             device_hostname TEXT UNIQUE NOT NULL, | ||||
|             site_id INTEGER NOT NULL, | ||||
|             assigned_at INTEGER, | ||||
|             FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE CASCADE | ||||
|         ) | ||||
|         """ | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def _ensure_credentials_table(conn: sqlite3.Connection) -> None: | ||||
|     cur = conn.cursor() | ||||
|     cur.execute( | ||||
|         """ | ||||
|         CREATE TABLE IF NOT EXISTS credentials ( | ||||
|             id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||
|             name TEXT NOT NULL UNIQUE, | ||||
|             description TEXT, | ||||
|             site_id INTEGER, | ||||
|             credential_type TEXT NOT NULL DEFAULT 'machine', | ||||
|             connection_type TEXT NOT NULL DEFAULT 'ssh', | ||||
|             username TEXT, | ||||
|             password_encrypted BLOB, | ||||
|             private_key_encrypted BLOB, | ||||
|             private_key_passphrase_encrypted BLOB, | ||||
|             become_method TEXT, | ||||
|             become_username TEXT, | ||||
|             become_password_encrypted BLOB, | ||||
|             metadata_json TEXT, | ||||
|             created_at INTEGER NOT NULL, | ||||
|             updated_at INTEGER NOT NULL, | ||||
|             FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE SET NULL | ||||
|         ) | ||||
|         """ | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def _ensure_github_token_table(conn: sqlite3.Connection) -> None: | ||||
|     cur = conn.cursor() | ||||
|     cur.execute( | ||||
|   | ||||
| @@ -71,6 +71,57 @@ class SQLiteUserRepository: | ||||
|         finally: | ||||
|             conn.close() | ||||
|  | ||||
|     def resolve_identifier(self, username: str) -> Optional[str]: | ||||
|         normalized = (username or "").strip() | ||||
|         if not normalized: | ||||
|             return None | ||||
|  | ||||
|         conn = self._connection_factory() | ||||
|         try: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 "SELECT id FROM users WHERE LOWER(username) = LOWER(?)", | ||||
|                 (normalized,), | ||||
|             ) | ||||
|             row = cur.fetchone() | ||||
|             if not row: | ||||
|                 return None | ||||
|             return str(row[0]) if row[0] is not None else None | ||||
|         except sqlite3.Error as exc:  # pragma: no cover - defensive | ||||
|             self._log.error("failed to resolve identifier for %s: %s", username, exc) | ||||
|             return None | ||||
|         finally: | ||||
|             conn.close() | ||||
|  | ||||
|     def username_for_identifier(self, identifier: str) -> Optional[str]: | ||||
|         token = (identifier or "").strip() | ||||
|         if not token: | ||||
|             return None | ||||
|  | ||||
|         conn = self._connection_factory() | ||||
|         try: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 """ | ||||
|                 SELECT username | ||||
|                   FROM users | ||||
|                  WHERE CAST(id AS TEXT) = ? | ||||
|                     OR LOWER(username) = LOWER(?) | ||||
|                  LIMIT 1 | ||||
|                 """, | ||||
|                 (token, token), | ||||
|             ) | ||||
|             row = cur.fetchone() | ||||
|             if not row: | ||||
|                 return None | ||||
|             username = str(row[0] or "").strip() | ||||
|             return username or None | ||||
|         except sqlite3.Error as exc:  # pragma: no cover - defensive | ||||
|             self._log.error("failed to resolve username for %s: %s", identifier, exc) | ||||
|             return None | ||||
|         finally: | ||||
|             conn.close() | ||||
|  | ||||
|     def list_accounts(self) -> list[OperatorAccount]: | ||||
|         conn = self._connection_factory() | ||||
|         try: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user