mirror of
				https://github.com/bunny-lab-io/Borealis.git
				synced 2025-10-26 15:21:57 -06:00 
			
		
		
		
	
		
			
				
	
	
		
			727 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			727 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """SQLite-backed enrollment repository for Engine services."""
 | |
| 
 | |
| from __future__ import annotations
 | |
| 
 | |
| import logging
 | |
| from contextlib import closing
 | |
| from datetime import datetime, timezone
 | |
| from typing import Any, List, Optional, Tuple
 | |
| 
 | |
| from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid, normalize_guid
 | |
| from Data.Engine.domain.device_enrollment import (
 | |
|     EnrollmentApproval,
 | |
|     EnrollmentApprovalStatus,
 | |
|     EnrollmentCode,
 | |
| )
 | |
| from Data.Engine.domain.enrollment_admin import (
 | |
|     DeviceApprovalRecord,
 | |
|     EnrollmentCodeRecord,
 | |
|     HostnameConflict,
 | |
| )
 | |
| from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
 | |
| 
 | |
# Names exported by ``from ... import *``; the repository class is this
# module's only public API.
__all__: List[str] = ["SQLiteEnrollmentRepository"]
 | |
| 
 | |
| 
 | |
class SQLiteEnrollmentRepository:
    """Persistence adapter that manages enrollment codes and approvals.

    Each public method obtains a new connection from the injected factory,
    runs its statements, commits when it writes, and closes the connection
    before returning.
    """

    # Columns loaded for the ``EnrollmentCode`` domain object, in SELECT order.
    _INSTALL_CODE_COLUMNS: Tuple[str, ...] = (
        "id",
        "code",
        "expires_at",
        "used_at",
        "used_by_guid",
        "max_uses",
        "use_count",
        "last_used_at",
    )
    # Columns loaded for the admin-facing ``EnrollmentCodeRecord`` (adds the creator).
    _INSTALL_CODE_RECORD_COLUMNS: Tuple[str, ...] = (
        "id",
        "code",
        "expires_at",
        "created_by_user_id",
        "used_at",
        "used_by_guid",
        "max_uses",
        "use_count",
        "last_used_at",
    )
    _INSTALL_CODE_SELECT = (
        "SELECT " + ", ".join(_INSTALL_CODE_COLUMNS) + " FROM enrollment_install_codes"
    )
    _INSTALL_CODE_RECORD_SELECT = (
        "SELECT " + ", ".join(_INSTALL_CODE_RECORD_COLUMNS) + " FROM enrollment_install_codes"
    )

    # Columns loaded by ``_fetch_device_approval`` for ``EnrollmentApproval``.
    _APPROVAL_COLUMNS: Tuple[str, ...] = (
        "id",
        "approval_reference",
        "guid",
        "hostname_claimed",
        "ssl_key_fingerprint_claimed",
        "enrollment_code_id",
        "created_at",
        "updated_at",
        "status",
        "approved_by_user_id",
        "client_nonce",
        "server_nonce",
        "agent_pubkey_der",
    )
    _APPROVAL_SELECT = "SELECT " + ", ".join(_APPROVAL_COLUMNS) + " FROM device_approvals"

    # Mapping keys for rows returned by ``_APPROVAL_LIST_SELECT``; the order must
    # match that statement's SELECT list.
    _APPROVAL_LIST_KEYS: Tuple[str, ...] = (
        "id",
        "approval_reference",
        "guid",
        "hostname_claimed",
        "ssl_key_fingerprint_claimed",
        "enrollment_code_id",
        "status",
        "client_nonce",
        "server_nonce",
        "created_at",
        "updated_at",
        "approved_by_user_id",
        "approved_by_username",
    )
    # The approver's username is resolved with a scalar subquery (LIMIT 1) rather
    # than a LEFT JOIN: a join on "id matches OR username matches" could return
    # the same approval twice when one value matches two different users.
    # An id match is preferred over a username match.
    _APPROVAL_LIST_SELECT = """
        SELECT
            da.id,
            da.approval_reference,
            da.guid,
            da.hostname_claimed,
            da.ssl_key_fingerprint_claimed,
            da.enrollment_code_id,
            da.status,
            da.client_nonce,
            da.server_nonce,
            da.created_at,
            da.updated_at,
            da.approved_by_user_id,
            (
                SELECT u.username
                  FROM users AS u
                 WHERE CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT)
                    OR LOWER(da.approved_by_user_id) = LOWER(u.username)
                 ORDER BY CASE
                              WHEN CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT)
                              THEN 0
                              ELSE 1
                          END
                 LIMIT 1
            ) AS approved_by_username
          FROM device_approvals AS da
    """

    # Hostname suggestions never exceed this length (253 is the conventional
    # maximum length of a textual DNS name).
    _MAX_HOSTNAME_LENGTH = 253
    # Highest numeric suffix ``_suggest_alternate_hostname`` will try.
    _MAX_HOSTNAME_SUFFIX = 50

    def __init__(
        self,
        connection_factory: SQLiteConnectionFactory,
        *,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """Store the connection factory and logger.

        Args:
            connection_factory: Zero-argument callable returning a new
                connection; every repository call opens one and closes it.
            logger: Optional logger; defaults to
                ``borealis.engine.repositories.enrollment``.
        """
        self._connections = connection_factory
        self._log = logger or logging.getLogger("borealis.engine.repositories.enrollment")

    # ------------------------------------------------------------------
    # Enrollment install codes
    # ------------------------------------------------------------------
    def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]:
        """Load an enrollment install code by its public value.

        Args:
            code: Install code value; surrounding whitespace is ignored.

        Returns:
            The matching ``EnrollmentCode``, or ``None`` when ``code`` is blank,
            no row matches, or the stored row fails domain validation.
        """
        code_value = (code or "").strip()
        if not code_value:
            return None
        return self._fetch_install_code("code = ?", (code_value,))

    def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]:
        """Load an enrollment install code by its primary key.

        Args:
            record_id: Row id; surrounding whitespace is ignored.

        Returns:
            The matching ``EnrollmentCode``, or ``None`` when ``record_id`` is
            blank, no row matches, or the stored row fails domain validation.
        """
        record_value = (record_id or "").strip()
        if not record_value:
            return None
        return self._fetch_install_code("id = ?", (record_value,))

    def list_install_codes(
        self,
        *,
        status: Optional[str] = None,
        now: Optional[datetime] = None,
    ) -> List[EnrollmentCodeRecord]:
        """List install codes, optionally filtered by lifecycle status.

        Args:
            status: Case-insensitive filter. ``"active"`` selects codes with
                uses left that expire after ``now``; ``"expired"`` selects codes
                with uses left whose expiry is at or before ``now``; ``"used"``
                selects codes with ``use_count >= max_uses``. Any other value
                (including ``None``) applies no filter.
            now: Reference time for the expiry comparison; defaults to the
                current UTC time.

        Returns:
            Records ordered by ``expires_at`` ascending. Rows that fail
            ``EnrollmentCodeRecord.from_row`` are logged and skipped.
        """
        reference = now if now is not None else datetime.now(tz=timezone.utc)
        status_filter = (status or "").strip().lower()
        params: List[str] = []
        sql = self._INSTALL_CODE_RECORD_SELECT

        # NOTE(review): expiry is compared as text. This assumes expires_at was
        # written via ``_isoformat`` (UTC ISO-8601) so text order == time order.
        if status_filter == "active":
            sql += " WHERE use_count < max_uses AND expires_at > ?"
            params.append(self._isoformat(reference))
        elif status_filter == "expired":
            sql += " WHERE use_count < max_uses AND expires_at <= ?"
            params.append(self._isoformat(reference))
        elif status_filter == "used":
            sql += " WHERE use_count >= max_uses"
        sql += " ORDER BY expires_at ASC"

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(sql, params)
            raw_rows = cur.fetchall()

        records: List[EnrollmentCodeRecord] = []
        for raw in raw_rows:
            record = self._row_to_mapping(self._INSTALL_CODE_RECORD_COLUMNS, raw)
            try:
                records.append(EnrollmentCodeRecord.from_row(record))
            except Exception as exc:  # pragma: no cover - defensive logging
                self._log.warning("invalid enrollment install code row id=%s: %s", record.get("id"), exc)
        return records

    def get_install_code_record(self, record_id: str) -> Optional[EnrollmentCodeRecord]:
        """Load the admin-facing record for one install code.

        Args:
            record_id: Row id; surrounding whitespace is ignored.

        Returns:
            The ``EnrollmentCodeRecord``, or ``None`` when ``record_id`` is
            blank, no row matches, or the row fails validation.
        """
        identifier = (record_id or "").strip()
        if not identifier:
            return None

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(f"{self._INSTALL_CODE_RECORD_SELECT} WHERE id = ?", (identifier,))
            row = cur.fetchone()

        if not row:
            return None

        payload = self._row_to_mapping(self._INSTALL_CODE_RECORD_COLUMNS, row)
        try:
            return EnrollmentCodeRecord.from_row(payload)
        except Exception as exc:  # pragma: no cover - defensive logging
            self._log.warning("invalid enrollment install code record id=%s: %s", identifier, exc)
            return None

    def insert_install_code(
        self,
        *,
        record_id: str,
        code: str,
        expires_at: datetime,
        created_by: Optional[str],
        max_uses: int,
    ) -> EnrollmentCodeRecord:
        """Insert a new install code with ``use_count`` 0 and return its record.

        Raises:
            RuntimeError: if the row cannot be read back after the commit.
        """
        expires_iso = self._isoformat(expires_at)

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                INSERT INTO enrollment_install_codes (
                    id,
                    code,
                    expires_at,
                    created_by_user_id,
                    max_uses,
                    use_count
                ) VALUES (?, ?, ?, ?, ?, 0)
                """,
                (record_id, code, expires_iso, created_by, max_uses),
            )
            conn.commit()

        record = self.get_install_code_record(record_id)
        if record is None:
            raise RuntimeError("failed to load install code after insert")
        return record

    def delete_install_code_if_unused(self, record_id: str) -> bool:
        """Delete an install code only if it has never been used.

        Returns:
            ``True`` when a row was deleted; ``False`` for a blank id, an
            unknown id, or a code whose ``use_count`` is not 0.
        """
        identifier = (record_id or "").strip()
        if not identifier:
            return False

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                "DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
                (identifier,),
            )
            deleted = cur.rowcount > 0
            conn.commit()
            return deleted

    def update_install_code_usage(
        self,
        record_id: str,
        *,
        use_count_increment: int,
        last_used_at: datetime,
        used_by_guid: Optional[DeviceGuid] = None,
        mark_first_use: bool = False,
    ) -> None:
        """Increment usage counters and usage metadata for an install code.

        ``used_by_guid`` overwrites the stored GUID only when provided.
        ``used_at`` is set to ``last_used_at`` only when ``mark_first_use`` is
        true and ``used_at`` is still NULL.

        Raises:
            ValueError: if ``use_count_increment`` is not positive.
        """
        if use_count_increment <= 0:
            raise ValueError("use_count_increment must be positive")

        last_used_iso = self._isoformat(last_used_at)
        # An empty string becomes NULL via NULLIF, so COALESCE keeps the old GUID.
        guid_value = used_by_guid.value if used_by_guid else ""
        mark_flag = 1 if mark_first_use else 0

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                UPDATE enrollment_install_codes
                   SET use_count = use_count + ?,
                       last_used_at = ?,
                       used_by_guid = COALESCE(NULLIF(?, ''), used_by_guid),
                       used_at = CASE WHEN ? = 1 AND used_at IS NULL THEN ? ELSE used_at END
                 WHERE id = ?
                """,
                (
                    use_count_increment,
                    last_used_iso,
                    guid_value,
                    mark_flag,
                    last_used_iso,
                    record_id,
                ),
            )
            conn.commit()

    # ------------------------------------------------------------------
    # Device approvals
    # ------------------------------------------------------------------
    def list_device_approvals(
        self,
        *,
        status: Optional[str] = None,
    ) -> List[DeviceApprovalRecord]:
        """List device approvals annotated with hostname-conflict details.

        Args:
            status: Case-insensitive status filter; blank, ``"all"`` or ``"*"``
                return every approval.

        Returns:
            Approvals ordered by ``created_at`` ascending. An alternate hostname
            is suggested only for conflicts that require an operator prompt.
            Rows that fail ``DeviceApprovalRecord.from_row`` are logged and
            skipped.
        """
        status_filter = (status or "").strip().lower()
        params: List[str] = []
        sql = self._APPROVAL_LIST_SELECT

        if status_filter and status_filter not in {"all", "*"}:
            sql += " WHERE LOWER(da.status) = ?"
            params.append(status_filter)
        sql += " ORDER BY da.created_at ASC"

        approvals: List[DeviceApprovalRecord] = []
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(sql, params)
            rows = cur.fetchall()

            # Conflict inspection reuses the open connection, so it stays inside the block.
            for raw in rows:
                record = self._row_to_mapping(self._APPROVAL_LIST_KEYS, raw)
                hostname = record.get("hostname_claimed")
                pending_guid = record.get("guid")

                conflict, fingerprint_match, requires_prompt = self._compute_hostname_conflict(
                    conn,
                    hostname,
                    pending_guid,
                    record.get("ssl_key_fingerprint_claimed") or "",
                )

                alternate = None
                if conflict and requires_prompt:
                    alternate = self._suggest_alternate_hostname(conn, hostname, pending_guid)

                try:
                    approvals.append(
                        DeviceApprovalRecord.from_row(
                            record,
                            conflict=conflict,
                            alternate_hostname=alternate,
                            fingerprint_match=fingerprint_match,
                            requires_prompt=requires_prompt,
                        )
                    )
                except Exception as exc:  # pragma: no cover - defensive logging
                    self._log.warning(
                        "invalid device approval record id=%s: %s",
                        record.get("id"),
                        exc,
                    )

        return approvals

    def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:
        """Load a device approval using its operator-visible reference."""
        ref_value = (reference or "").strip()
        if not ref_value:
            return None
        return self._fetch_device_approval("approval_reference = ?", (ref_value,))

    def fetch_device_approval(self, record_id: str) -> Optional[EnrollmentApproval]:
        """Load a device approval by primary key; ``None`` for a blank id or no match."""
        record_value = (record_id or "").strip()
        if not record_value:
            return None
        return self._fetch_device_approval("id = ?", (record_value,))

    def fetch_pending_approval_by_fingerprint(
        self, fingerprint: DeviceFingerprint
    ) -> Optional[EnrollmentApproval]:
        """Load an approval with status ``'pending'`` claiming ``fingerprint``.

        NOTE(review): there is no ORDER BY, so if several pending approvals share
        a fingerprint, the row returned is whichever SQLite yields first.
        """
        return self._fetch_device_approval(
            "ssl_key_fingerprint_claimed = ? AND status = 'pending'",
            (fingerprint.value,),
        )

    def update_pending_approval(
        self,
        record_id: str,
        *,
        hostname: str,
        guid: Optional[DeviceGuid],
        enrollment_code_id: Optional[str],
        client_nonce_b64: str,
        server_nonce_b64: str,
        agent_pubkey_der: bytes,
        updated_at: datetime,
    ) -> None:
        """Overwrite the claim details of an existing approval row.

        Unlike ``update_device_approval_status``, a ``None`` guid clears the
        stored GUID.
        """
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                UPDATE device_approvals
                   SET hostname_claimed = ?,
                       guid = ?,
                       enrollment_code_id = ?,
                       client_nonce = ?,
                       server_nonce = ?,
                       agent_pubkey_der = ?,
                       updated_at = ?
                 WHERE id = ?
                """,
                (
                    hostname,
                    guid.value if guid else None,
                    enrollment_code_id,
                    client_nonce_b64,
                    server_nonce_b64,
                    agent_pubkey_der,
                    self._isoformat(updated_at),
                    record_id,
                ),
            )
            conn.commit()

    def create_device_approval(
        self,
        *,
        record_id: str,
        reference: str,
        claimed_hostname: str,
        claimed_fingerprint: DeviceFingerprint,
        enrollment_code_id: Optional[str],
        client_nonce_b64: str,
        server_nonce_b64: str,
        agent_pubkey_der: bytes,
        created_at: datetime,
        status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
        guid: Optional[DeviceGuid] = None,
    ) -> EnrollmentApproval:
        """Insert a device approval and return it as stored.

        ``created_at`` is written to both ``created_at`` and ``updated_at``.

        Raises:
            RuntimeError: if the row cannot be read back after the commit.
        """
        created_iso = self._isoformat(created_at)
        guid_value = guid.value if guid else None

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                INSERT INTO device_approvals (
                    id,
                    approval_reference,
                    guid,
                    hostname_claimed,
                    ssl_key_fingerprint_claimed,
                    enrollment_code_id,
                    status,
                    created_at,
                    updated_at,
                    client_nonce,
                    server_nonce,
                    agent_pubkey_der
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    record_id,
                    reference,
                    guid_value,
                    claimed_hostname,
                    claimed_fingerprint.value,
                    enrollment_code_id,
                    status.value,
                    created_iso,
                    created_iso,
                    client_nonce_b64,
                    server_nonce_b64,
                    agent_pubkey_der,
                ),
            )
            conn.commit()

        approval = self.fetch_device_approval(record_id)
        if approval is None:
            raise RuntimeError("failed to load device approval after insert")
        return approval

    def update_device_approval_status(
        self,
        record_id: str,
        *,
        status: EnrollmentApprovalStatus,
        updated_at: datetime,
        approved_by: Optional[str] = None,
        guid: Optional[DeviceGuid] = None,
    ) -> None:
        """Transition an approval to a new status.

        ``guid`` and ``approved_by`` overwrite the stored values only when
        provided; ``None`` keeps the existing column value (COALESCE).
        """
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                UPDATE device_approvals
                   SET status = ?,
                       updated_at = ?,
                       guid = COALESCE(?, guid),
                       approved_by_user_id = COALESCE(?, approved_by_user_id)
                 WHERE id = ?
                """,
                (
                    status.value,
                    self._isoformat(updated_at),
                    guid.value if guid else None,
                    approved_by,
                    record_id,
                ),
            )
            conn.commit()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _row_to_mapping(columns: Tuple[str, ...], row: Any) -> dict:
        """Pair positional row values with ``columns`` (row must support ``row[i]``)."""
        return {name: row[index] for index, name in enumerate(columns)}

    def _fetch_install_code(self, where: str, params: Tuple[Any, ...]) -> Optional[EnrollmentCode]:
        """Load one install code matching ``where`` as an ``EnrollmentCode``.

        ``where`` is interpolated into the SQL, so it must be a constant
        fragment chosen by this class; values travel through ``params``.
        """
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(f"{self._INSTALL_CODE_SELECT} WHERE {where}", params)
            row = cur.fetchone()

        if not row:
            return None

        record = self._row_to_mapping(self._INSTALL_CODE_COLUMNS, row)
        try:
            return EnrollmentCode.from_mapping(record)
        except Exception as exc:  # pragma: no cover - defensive logging
            # Log the row id, never the code value: install codes are enrollment
            # secrets and must not leak into log files.
            self._log.warning("invalid enrollment code record for id=%s: %s", record.get("id"), exc)
            return None

    def _fetch_device_approval(self, where: str, params: tuple) -> Optional[EnrollmentApproval]:
        """Load one approval matching ``where`` as an ``EnrollmentApproval``.

        ``where`` is interpolated into the SQL, so it must be a constant
        fragment chosen by this class; values travel through ``params``.
        Returns ``None`` when no row matches or validation fails.
        """
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(f"{self._APPROVAL_SELECT} WHERE {where}", params)
            row = cur.fetchone()

        if not row:
            return None

        record = self._row_to_mapping(self._APPROVAL_COLUMNS, row)
        try:
            return EnrollmentApproval.from_mapping(record)
        except Exception as exc:  # pragma: no cover - defensive logging
            self._log.warning(
                "invalid device approval record id=%s reference=%s: %s",
                record["id"],
                record["approval_reference"],
                exc,
            )
            return None

    def _compute_hostname_conflict(
        self,
        conn,
        hostname: Optional[str],
        pending_guid: Optional[str],
        claimed_fp: str,
    ) -> Tuple[Optional[HostnameConflict], bool, bool]:
        """Check whether ``hostname`` already belongs to a different device.

        Returns:
            ``(conflict, fingerprint_match, requires_prompt)``. The result is
            ``(None, False, False)`` when the hostname is blank, the lookup
            fails, no device has the hostname, or that device's GUID equals
            ``pending_guid``. Fingerprints are compared after strip + lowercase;
            a prompt is required whenever they do not match.
        """
        normalized_host = (hostname or "").strip()
        if not normalized_host:
            return None, False, False

        try:
            cur = conn.cursor()
            cur.execute(
                """
                SELECT d.guid,
                       d.ssl_key_fingerprint,
                       ds.site_id,
                       s.name
                  FROM devices AS d
             LEFT JOIN device_sites AS ds ON ds.device_hostname = d.hostname
             LEFT JOIN sites AS s ON s.id = ds.site_id
                 WHERE d.hostname = ?
                """,
                (normalized_host,),
            )
            row = cur.fetchone()
        except Exception as exc:  # pragma: no cover - defensive logging
            self._log.warning("failed to inspect hostname conflict for %s: %s", normalized_host, exc)
            return None, False, False

        if not row:
            return None, False, False

        existing_guid = normalize_guid(row[0])
        pending_norm = normalize_guid(pending_guid)
        if existing_guid and pending_norm and existing_guid == pending_norm:
            # Same device re-enrolling under its own hostname: not a conflict.
            return None, False, False

        stored_fp = (row[1] or "").strip().lower()
        claimed_fp_normalized = (claimed_fp or "").strip().lower()
        fingerprint_match = bool(stored_fp and claimed_fp_normalized and stored_fp == claimed_fp_normalized)

        site_id = None
        if row[2] is not None:
            try:
                site_id = int(row[2])
            except (TypeError, ValueError):  # pragma: no cover - defensive
                site_id = None

        site_name = str(row[3] or "").strip()
        requires_prompt = not fingerprint_match

        conflict = HostnameConflict(
            guid=existing_guid or None,
            ssl_key_fingerprint=stored_fp or None,
            site_id=site_id,
            site_name=site_name,
            fingerprint_match=fingerprint_match,
            requires_prompt=requires_prompt,
        )

        return conflict, fingerprint_match, requires_prompt

    def _suggest_alternate_hostname(
        self,
        conn,
        hostname: Optional[str],
        pending_guid: Optional[str],
    ) -> Optional[str]:
        """Suggest a hostname that does not collide with another device.

        Candidates are tried in order: ``hostname`` itself, then ``hostname-1``
        through ``hostname-<_MAX_HOSTNAME_SUFFIX>``. The base is trimmed so every
        candidate fits in ``_MAX_HOSTNAME_LENGTH`` characters. A candidate is
        accepted when no device uses it, or when the device using it has
        ``pending_guid``.

        Every candidate, including the last one, is checked before it is
        returned. If all candidates are taken, the normalized pending GUID is
        returned, or ``None`` when there is none. ``None`` is also returned for a
        blank hostname or a failed lookup.
        """
        base = (hostname or "").strip()
        if not base:
            return None
        base = base[: self._MAX_HOSTNAME_LENGTH]
        pending_norm = normalize_guid(pending_guid)

        try:
            cur = conn.cursor()
            for suffix in range(self._MAX_HOSTNAME_SUFFIX + 1):
                if suffix == 0:
                    candidate = base
                else:
                    tail = f"-{suffix}"
                    # Trim the base so the suffixed name still respects the length cap.
                    candidate = f"{base[: self._MAX_HOSTNAME_LENGTH - len(tail)]}{tail}"

                cur.execute("SELECT guid FROM devices WHERE hostname = ?", (candidate,))
                row = cur.fetchone()
                if not row:
                    return candidate
                if pending_norm and normalize_guid(row[0]) == pending_norm:
                    return candidate
        except Exception as exc:  # pragma: no cover - defensive logging
            # Mirror _compute_hostname_conflict: a failed suggestion must not
            # abort the whole approval listing.
            self._log.warning("failed to suggest alternate hostname for %s: %s", base, exc)
            return None

        return pending_norm or None

    @staticmethod
    def _isoformat(value: datetime) -> str:
        """Render ``value`` as a UTC ISO-8601 string; naive values are treated as UTC."""
        if value.tzinfo is None:
            value = value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc).isoformat()
 |