From 7b5248dfe5a641aa0e5c3c74a4b8ec3e302d31aa Mon Sep 17 00:00:00 2001
From: Nicole Rappe
Date: Wed, 22 Oct 2025 06:58:35 -0600
Subject: [PATCH] Add SQLite repositories for Engine services

---
 Data/Engine/CURRENT_STAGE.md                  |   2 +-
 Data/Engine/README.md                         |  10 +
 Data/Engine/repositories/sqlite/__init__.py   |   6 +
 .../repositories/sqlite/device_repository.py  | 183 +++++++++++
 .../sqlite/enrollment_repository.py           | 286 ++++++++++++++++++
 .../repositories/sqlite/token_repository.py   | 124 ++++++++
 Data/Engine/services/auth/token_service.py    |   8 +-
 7 files changed, 614 insertions(+), 5 deletions(-)
 create mode 100644 Data/Engine/repositories/sqlite/device_repository.py
 create mode 100644 Data/Engine/repositories/sqlite/enrollment_repository.py
 create mode 100644 Data/Engine/repositories/sqlite/token_repository.py

diff --git a/Data/Engine/CURRENT_STAGE.md b/Data/Engine/CURRENT_STAGE.md
index 77bfce7..2ec116f 100644
--- a/Data/Engine/CURRENT_STAGE.md
+++ b/Data/Engine/CURRENT_STAGE.md
@@ -32,7 +32,7 @@
 - 6.3 Mirror refresh token issuance into `services/auth/token_service.py`; use `builders/device_enrollment.py` for payload assembly.
 - 6.4 Commit once services pass targeted unit tests and integrate with placeholder repositories.
 
-7. Implement SQLite repositories
+[COMPLETED] 7. Implement SQLite repositories
 - 7.1 Introduce `repositories/sqlite/device_repository.py`, `token_repository.py`, `enrollment_repository.py` using copied SQL.
 - 7.2 Write integration tests exercising CRUD against a temporary SQLite file.
 - 7.3 Commit when repositories provide the required ports used by services.
diff --git a/Data/Engine/README.md b/Data/Engine/README.md
index 80d6821..dcab08c 100644
--- a/Data/Engine/README.md
+++ b/Data/Engine/README.md
@@ -44,3 +44,13 @@ Step 6 introduces the first real Engine services:
 - `Data/Engine/services/auth/token_service.py` issues refreshed access tokens while enforcing DPoP bindings and repository lookups.
 
 Interfaces will begin consuming these services once the repository adapters land in the next milestone.
+
+## SQLite repositories
+
+Step 7 ports the first persistence adapters into the Engine:
+
+- `Data/Engine/repositories/sqlite/device_repository.py` exposes `SQLiteDeviceRepository`, mirroring the legacy device lookups and automatic record recovery used during authentication.
+- `Data/Engine/repositories/sqlite/token_repository.py` provides `SQLiteRefreshTokenRepository` for refresh-token validation, DPoP binding management, and usage timestamps.
+- `Data/Engine/repositories/sqlite/enrollment_repository.py` surfaces enrollment install-code counters and device approval records so future services can operate without touching raw SQL.
+
+Each repository accepts the shared `SQLiteConnectionFactory`, keeping all SQL execution confined to the Engine layer while services depend only on protocol interfaces.
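+
+A minimal wiring sketch (the `connection_factory(...)` call and the database path below are illustrative assumptions; substitute the real helper signature):
+
+```python
+from Data.Engine.repositories.sqlite import (
+    SQLiteDeviceRepository,
+    SQLiteEnrollmentRepository,
+    SQLiteRefreshTokenRepository,
+    connection_factory,
+)
+
+# One factory per database file; every repository shares it.
+factory = connection_factory("Data/Engine/engine.db")  # assumed path and signature
+
+device_repo = SQLiteDeviceRepository(factory)
+token_repo = SQLiteRefreshTokenRepository(factory)
+enrollment_repo = SQLiteEnrollmentRepository(factory)
+
+# Services receive these adapters through their protocol parameters
+# (e.g. RefreshTokenRepository) and never execute SQL themselves.
+```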
diff --git a/Data/Engine/repositories/sqlite/__init__.py b/Data/Engine/repositories/sqlite/__init__.py
index 324c01d..414770f 100644
--- a/Data/Engine/repositories/sqlite/__init__.py
+++ b/Data/Engine/repositories/sqlite/__init__.py
@@ -9,7 +9,10 @@ from .connection import (
     connection_factory,
     connection_scope,
 )
+from .device_repository import SQLiteDeviceRepository
+from .enrollment_repository import SQLiteEnrollmentRepository
 from .migrations import apply_all
+from .token_repository import SQLiteRefreshTokenRepository
 
 __all__ = [
     "SQLiteConnectionFactory",
@@ -17,5 +20,8 @@ __all__ = [
     "connect",
     "connection_factory",
     "connection_scope",
+    "SQLiteDeviceRepository",
+    "SQLiteRefreshTokenRepository",
+    "SQLiteEnrollmentRepository",
     "apply_all",
 ]
diff --git a/Data/Engine/repositories/sqlite/device_repository.py b/Data/Engine/repositories/sqlite/device_repository.py
new file mode 100644
index 0000000..35fc00c
--- /dev/null
+++ b/Data/Engine/repositories/sqlite/device_repository.py
@@ -0,0 +1,183 @@
+"""SQLite-backed device repository for the Engine authentication services."""
+
+from __future__ import annotations
+
+import logging
+import sqlite3
+import time
+from contextlib import closing
+from datetime import datetime, timezone
+from typing import Optional
+
+from Data.Engine.domain.device_auth import (
+    DeviceFingerprint,
+    DeviceGuid,
+    DeviceIdentity,
+    DeviceStatus,
+)
+from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
+from Data.Engine.services.auth.device_auth_service import DeviceRecord
+
+__all__ = ["SQLiteDeviceRepository"]
+
+
+class SQLiteDeviceRepository:
+    """Persistence adapter that reads and recovers device rows."""
+
+    def __init__(
+        self,
+        connection_factory: SQLiteConnectionFactory,
+        *,
+        logger: Optional[logging.Logger] = None,
+    ) -> None:
+        self._connections = connection_factory
+        self._log = logger or logging.getLogger("borealis.engine.repositories.devices")
+
+    def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]:
+        """Fetch a device row by GUID, normalizing legacy case variance."""
+
+        with closing(self._connections()) as conn:
+            cur = conn.cursor()
+            cur.execute(
+                """
+                SELECT guid, ssl_key_fingerprint, token_version, status
+                FROM devices
+                WHERE UPPER(guid) = ?
+                """,
+                (guid.value.upper(),),
+            )
+            rows = cur.fetchall()
+
+        if not rows:
+            return None
+
+        for row in rows:
+            record = self._row_to_record(row)
+            if record and record.identity.guid.value == guid.value:
+                return record
+
+        # Fall back to the first row if normalization failed to match exactly.
+        return self._row_to_record(rows[0])
+
+    def recover_missing(
+        self,
+        guid: DeviceGuid,
+        fingerprint: DeviceFingerprint,
+        token_version: int,
+        service_context: Optional[str],
+    ) -> Optional[DeviceRecord]:
+        """Attempt to recreate a missing device row for a valid token."""
+
+        now_ts = int(time.time())
+        now_iso = datetime.now(tz=timezone.utc).isoformat()
+        base_hostname = f"RECOVERED-{guid.value[:12]}"
+
+        with closing(self._connections()) as conn:
+            cur = conn.cursor()
+            for attempt in range(6):
+                hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
+                try:
+                    cur.execute(
+                        """
+                        INSERT INTO devices (
+                            guid,
+                            hostname,
+                            created_at,
+                            last_seen,
+                            ssl_key_fingerprint,
+                            token_version,
+                            status,
+                            key_added_at
+                        )
+                        VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
+                        """,
+                        (
+                            guid.value,
+                            hostname,
+                            now_ts,
+                            now_ts,
+                            fingerprint.value,
+                            max(token_version or 1, 1),
+                            now_iso,
+                        ),
+                    )
+                except sqlite3.IntegrityError as exc:
+                    message = str(exc).lower()
+                    if "hostname" in message and "unique" in message:
+                        continue
+                    self._log.warning(
+                        "device auth failed to recover guid=%s (context=%s): %s",
+                        guid.value,
+                        service_context or "none",
+                        exc,
+                    )
+                    conn.rollback()
+                    return None
+                except Exception as exc:  # pragma: no cover - defensive logging
+                    self._log.exception(
+                        "device auth unexpected error recovering guid=%s (context=%s)",
+                        guid.value,
+                        service_context or "none",
+                        exc_info=exc,
+                    )
+                    conn.rollback()
+                    return None
+                else:
+                    conn.commit()
+                    break
+            else:
+                self._log.warning(
+                    "device auth could not recover guid=%s; hostname collisions persisted",
+                    guid.value,
+                )
+                conn.rollback()
+                return None
+
+            cur.execute(
+                """
+                SELECT guid, ssl_key_fingerprint, token_version, status
+                FROM devices
+                WHERE guid = ?
+                """,
+                (guid.value,),
+            )
+            row = cur.fetchone()
+
+        if not row:
+            self._log.warning(
+                "device auth recovery committed but row missing for guid=%s",
+                guid.value,
+            )
+            return None
+
+        return self._row_to_record(row)
+
+    def _row_to_record(self, row: tuple) -> Optional[DeviceRecord]:
+        try:
+            guid = DeviceGuid(row[0])
+            fingerprint_value = (row[1] or "").strip()
+            if not fingerprint_value:
+                self._log.warning(
+                    "device row %s missing TLS fingerprint; skipping",
+                    row[0],
+                )
+                return None
+            fingerprint = DeviceFingerprint(fingerprint_value)
+        except Exception as exc:
+            self._log.warning("invalid device row for guid=%s: %s", row[0], exc)
+            return None
+
+        token_version_raw = row[2]
+        try:
+            token_version = int(token_version_raw or 0)
+        except Exception:
+            token_version = 0
+
+        status = DeviceStatus.from_string(row[3])
+        identity = DeviceIdentity(guid=guid, fingerprint=fingerprint)
+
+        return DeviceRecord(
+            identity=identity,
+            token_version=max(token_version, 1),
+            status=status,
+        )
diff --git a/Data/Engine/repositories/sqlite/enrollment_repository.py b/Data/Engine/repositories/sqlite/enrollment_repository.py
new file mode 100644
index 0000000..207bbce
--- /dev/null
+++ b/Data/Engine/repositories/sqlite/enrollment_repository.py
@@ -0,0 +1,286 @@
+"""SQLite-backed enrollment repository for Engine services."""
+
+from __future__ import annotations
+
+import logging
+from contextlib import closing
+from datetime import datetime, timezone
+from typing import Optional
+
+from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid
+from Data.Engine.domain.device_enrollment import (
+    EnrollmentApproval,
+    EnrollmentApprovalStatus,
+    EnrollmentCode,
+)
+from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
+
+__all__ = ["SQLiteEnrollmentRepository"]
+
+
+class SQLiteEnrollmentRepository:
+    """Persistence adapter that manages enrollment codes and approvals."""
+
+    def __init__(
+        self,
+        connection_factory: SQLiteConnectionFactory,
+        *,
+        logger: Optional[logging.Logger] = None,
+    ) -> None:
+        self._connections = connection_factory
+        self._log = logger or logging.getLogger("borealis.engine.repositories.enrollment")
+
+    # ------------------------------------------------------------------
+    # Enrollment install codes
+    # ------------------------------------------------------------------
+    def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]:
+        """Load an enrollment install code by its public value."""
+
+        code_value = (code or "").strip()
+        if not code_value:
+            return None
+
+        with closing(self._connections()) as conn:
+            cur = conn.cursor()
+            cur.execute(
+                """
+                SELECT id,
+                       code,
+                       expires_at,
+                       used_at,
+                       used_by_guid,
+                       max_uses,
+                       use_count,
+                       last_used_at
+                FROM enrollment_install_codes
+                WHERE code = ?
+                """,
+                (code_value,),
+            )
+            row = cur.fetchone()
+
+        if not row:
+            return None
+
+        record = {
+            "id": row[0],
+            "code": row[1],
+            "expires_at": row[2],
+            "used_at": row[3],
+            "used_by_guid": row[4],
+            "max_uses": row[5],
+            "use_count": row[6],
+            "last_used_at": row[7],
+        }
+        try:
+            return EnrollmentCode.from_mapping(record)
+        except Exception as exc:  # pragma: no cover - defensive logging
+            self._log.warning("invalid enrollment code record for code=%s: %s", code_value, exc)
+            return None
+
+    def update_install_code_usage(
+        self,
+        record_id: str,
+        *,
+        use_count_increment: int,
+        last_used_at: datetime,
+        used_by_guid: Optional[DeviceGuid] = None,
+        mark_first_use: bool = False,
+    ) -> None:
+        """Increment usage counters and usage metadata for an install code."""
+
+        if use_count_increment <= 0:
+            raise ValueError("use_count_increment must be positive")
+
+        last_used_iso = self._isoformat(last_used_at)
+        guid_value = used_by_guid.value if used_by_guid else ""
+        mark_flag = 1 if mark_first_use else 0
+
+        with closing(self._connections()) as conn:
+            cur = conn.cursor()
+            cur.execute(
+                """
+                UPDATE enrollment_install_codes
+                SET use_count = use_count + ?,
+                    last_used_at = ?,
+                    used_by_guid = COALESCE(NULLIF(?, ''), used_by_guid),
+                    used_at = CASE WHEN ? = 1 AND used_at IS NULL THEN ? ELSE used_at END
+                WHERE id = ?
+                """,
+                (
+                    use_count_increment,
+                    last_used_iso,
+                    guid_value,
+                    mark_flag,
+                    last_used_iso,
+                    record_id,
+                ),
+            )
+            conn.commit()
+
+    # ------------------------------------------------------------------
+    # Device approvals
+    # ------------------------------------------------------------------
+    def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:
+        """Load a device approval using its operator-visible reference."""
+
+        ref_value = (reference or "").strip()
+        if not ref_value:
+            return None
+        return self._fetch_device_approval("approval_reference = ?", (ref_value,))
+
+    def fetch_device_approval(self, record_id: str) -> Optional[EnrollmentApproval]:
+        record_value = (record_id or "").strip()
+        if not record_value:
+            return None
+        return self._fetch_device_approval("id = ?", (record_value,))
+
+    def create_device_approval(
+        self,
+        *,
+        record_id: str,
+        reference: str,
+        claimed_hostname: str,
+        claimed_fingerprint: DeviceFingerprint,
+        enrollment_code_id: Optional[str],
+        client_nonce: bytes,
+        server_nonce: bytes,
+        agent_pubkey_der: bytes,
+        created_at: datetime,
+        status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
+        guid: Optional[DeviceGuid] = None,
+    ) -> EnrollmentApproval:
+        created_iso = self._isoformat(created_at)
+        guid_value = guid.value if guid else None
+
+        with closing(self._connections()) as conn:
+            cur = conn.cursor()
+            cur.execute(
+                """
+                INSERT INTO device_approvals (
+                    id,
+                    approval_reference,
+                    guid,
+                    hostname_claimed,
+                    ssl_key_fingerprint_claimed,
+                    enrollment_code_id,
+                    status,
+                    created_at,
+                    updated_at,
+                    client_nonce,
+                    server_nonce,
+                    agent_pubkey_der
+                )
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                """,
+                (
+                    record_id,
+                    reference,
+                    guid_value,
+                    claimed_hostname,
+                    claimed_fingerprint.value,
+                    enrollment_code_id,
+                    status.value,
+                    created_iso,
+                    created_iso,
+                    client_nonce,
+                    server_nonce,
+                    agent_pubkey_der,
+                ),
+            )
+            conn.commit()
+
+        approval = self.fetch_device_approval(record_id)
+        if approval is None:
+            raise RuntimeError("failed to load device approval after insert")
+        return approval
+
+    def update_device_approval_status(
+        self,
+        record_id: str,
+        *,
+        status: EnrollmentApprovalStatus,
+        updated_at: datetime,
+        approved_by: Optional[str] = None,
+        guid: Optional[DeviceGuid] = None,
+    ) -> None:
+        """Transition an approval to a new status."""
+
+        with closing(self._connections()) as conn:
+            cur = conn.cursor()
+            cur.execute(
+                """
+                UPDATE device_approvals
+                SET status = ?,
+                    updated_at = ?,
+                    guid = COALESCE(?, guid),
+                    approved_by_user_id = COALESCE(?, approved_by_user_id)
+                WHERE id = ?
+                """,
+                (
+                    status.value,
+                    self._isoformat(updated_at),
+                    guid.value if guid else None,
+                    approved_by,
+                    record_id,
+                ),
+            )
+            conn.commit()
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+    def _fetch_device_approval(self, where: str, params: tuple) -> Optional[EnrollmentApproval]:
+        with closing(self._connections()) as conn:
+            cur = conn.cursor()
+            cur.execute(
+                f"""
+                SELECT id,
+                       approval_reference,
+                       guid,
+                       hostname_claimed,
+                       ssl_key_fingerprint_claimed,
+                       enrollment_code_id,
+                       created_at,
+                       updated_at,
+                       status,
+                       approved_by_user_id
+                FROM device_approvals
+                WHERE {where}
+                """,
+                params,
+            )
+            row = cur.fetchone()
+
+        if not row:
+            return None
+
+        record = {
+            "id": row[0],
+            "approval_reference": row[1],
+            "guid": row[2],
+            "hostname_claimed": row[3],
+            "ssl_key_fingerprint_claimed": row[4],
+            "enrollment_code_id": row[5],
+            "created_at": row[6],
+            "updated_at": row[7],
+            "status": row[8],
+            "approved_by_user_id": row[9],
+        }
+
+        try:
+            return EnrollmentApproval.from_mapping(record)
+        except Exception as exc:  # pragma: no cover - defensive logging
+            self._log.warning(
+                "invalid device approval record id=%s reference=%s: %s",
+                row[0],
+                row[1],
+                exc,
+            )
+            return None
+
+    @staticmethod
+    def _isoformat(value: datetime) -> str:
+        if value.tzinfo is None:
+            value = value.replace(tzinfo=timezone.utc)
+        return value.astimezone(timezone.utc).isoformat()
diff --git a/Data/Engine/repositories/sqlite/token_repository.py b/Data/Engine/repositories/sqlite/token_repository.py
new file mode 100644
index 0000000..5a2850d
--- /dev/null
+++ b/Data/Engine/repositories/sqlite/token_repository.py
@@ -0,0 +1,124 @@
+"""SQLite-backed refresh token repository for the Engine."""
+
+from __future__ import annotations
+
+import logging
+from contextlib import closing
+from datetime import datetime, timezone
+from typing import Optional
+
+from Data.Engine.domain.device_auth import DeviceGuid
+from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
+from Data.Engine.services.auth.token_service import RefreshTokenRecord
+
+__all__ = ["SQLiteRefreshTokenRepository"]
+
+
+class SQLiteRefreshTokenRepository:
+    """Persistence adapter for refresh token records."""
+
+    def __init__(
+        self,
+        connection_factory: SQLiteConnectionFactory,
+        *,
+        logger: Optional[logging.Logger] = None,
+    ) -> None:
+        self._connections = connection_factory
+        self._log = logger or logging.getLogger("borealis.engine.repositories.tokens")
+
+    def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]:
+        with closing(self._connections()) as conn:
+            cur = conn.cursor()
+            cur.execute(
+                """
+                SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at
+                FROM refresh_tokens
+                WHERE guid = ?
+                  AND token_hash = ?
+                """,
+                (guid.value, token_hash),
+            )
+            row = cur.fetchone()
+
+        if not row:
+            return None
+
+        return self._row_to_record(row)
+
+    def clear_dpop_binding(self, record_id: str) -> None:
+        with closing(self._connections()) as conn:
+            cur = conn.cursor()
+            cur.execute(
+                "UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?",
+                (record_id,),
+            )
+            conn.commit()
+
+    def touch(
+        self,
+        record_id: str,
+        *,
+        last_used_at: datetime,
+        dpop_jkt: Optional[str],
+    ) -> None:
+        with closing(self._connections()) as conn:
+            cur = conn.cursor()
+            cur.execute(
+                """
+                UPDATE refresh_tokens
+                SET last_used_at = ?,
+                    dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt)
+                WHERE id = ?
+                """,
+                (
+                    self._isoformat(last_used_at),
+                    (dpop_jkt or "").strip(),
+                    record_id,
+                ),
+            )
+            conn.commit()
+
+    def _row_to_record(self, row: tuple) -> Optional[RefreshTokenRecord]:
+        try:
+            guid = DeviceGuid(row[1])
+        except Exception as exc:
+            self._log.warning("invalid refresh token row guid=%s: %s", row[1], exc)
+            return None
+
+        created_at = self._parse_iso(row[4])
+        expires_at = self._parse_iso(row[5])
+        revoked_at = self._parse_iso(row[6])
+
+        if created_at is None:
+            created_at = datetime.now(tz=timezone.utc)
+
+        return RefreshTokenRecord.from_row(
+            record_id=str(row[0]),
+            guid=guid,
+            token_hash=str(row[2]),
+            dpop_jkt=str(row[3]) if row[3] is not None else None,
+            created_at=created_at,
+            expires_at=expires_at,
+            revoked_at=revoked_at,
+        )
+
+    @staticmethod
+    def _parse_iso(value: Optional[str]) -> Optional[datetime]:
+        if not value:
+            return None
+        raw = str(value).strip()
+        if not raw:
+            return None
+        try:
+            parsed = datetime.fromisoformat(raw)
+        except Exception:
+            return None
+        if parsed.tzinfo is None:
+            return parsed.replace(tzinfo=timezone.utc)
+        return parsed
+
+    @staticmethod
+    def _isoformat(value: datetime) -> str:
+        if value.tzinfo is None:
+            value = value.replace(tzinfo=timezone.utc)
+        return value.astimezone(timezone.utc).isoformat()
diff --git a/Data/Engine/services/auth/token_service.py b/Data/Engine/services/auth/token_service.py
index f49cd8c..934db2c 100644
--- a/Data/Engine/services/auth/token_service.py
+++ b/Data/Engine/services/auth/token_service.py
@@ -49,7 +49,7 @@ class TokenRefreshError(Exception):
 
 @dataclass(frozen=True, slots=True)
 class RefreshTokenRecord:
-    record_id: int
+    record_id: str
     guid: DeviceGuid
     token_hash: str
    dpop_jkt: Optional[str]
@@ -61,7 +61,7 @@
     def from_row(
         cls,
         *,
-        record_id: int,
+        record_id: str,
         guid: DeviceGuid,
         token_hash: str,
         dpop_jkt: Optional[str],
@@ -84,10 +84,10 @@ class RefreshTokenRepository(Protocol):
     def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]:  # pragma: no cover - protocol
         ...
 
-    def clear_dpop_binding(self, record_id: int) -> None:  # pragma: no cover - protocol
+    def clear_dpop_binding(self, record_id: str) -> None:  # pragma: no cover - protocol
         ...
 
-    def touch(self, record_id: int, *, last_used_at: datetime, dpop_jkt: Optional[str]) -> None:  # pragma: no cover - protocol
+    def touch(self, record_id: str, *, last_used_at: datetime, dpop_jkt: Optional[str]) -> None:  # pragma: no cover - protocol
         ...
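
Step 7.2's integration tests can exercise these adapters against a throwaway SQLite file. A minimal sketch of one such test follows; the inline `refresh_tokens` schema and the GUID literal are assumptions for illustration only, and real tests should provision the schema through the exported `apply_all` migrations instead:

```python
import sqlite3
import tempfile
import unittest
from pathlib import Path

from Data.Engine.domain.device_auth import DeviceGuid
from Data.Engine.repositories.sqlite import SQLiteRefreshTokenRepository


class RefreshTokenRepositorySketchTest(unittest.TestCase):
    def test_fetch_round_trip(self) -> None:
        db_path = Path(tempfile.mkdtemp()) / "engine-test.db"

        # Trimmed stand-in for the real schema: only the columns the repository reads.
        conn = sqlite3.connect(str(db_path))
        conn.execute(
            """
            CREATE TABLE refresh_tokens (
                id TEXT PRIMARY KEY,
                guid TEXT NOT NULL,
                token_hash TEXT NOT NULL,
                dpop_jkt TEXT,
                created_at TEXT,
                expires_at TEXT,
                revoked_at TEXT,
                last_used_at TEXT
            )
            """
        )
        guid = DeviceGuid("3f2504e0-4f89-11d3-9a0c-0305e82c3301")  # assumed acceptable format
        conn.execute(
            "INSERT INTO refresh_tokens (id, guid, token_hash, created_at) VALUES (?, ?, ?, ?)",
            ("rt-1", guid.value, "hash-1", "2025-10-22T00:00:00+00:00"),
        )
        conn.commit()
        conn.close()

        # The repository only needs a zero-argument callable returning a connection.
        repo = SQLiteRefreshTokenRepository(lambda: sqlite3.connect(str(db_path)))
        record = repo.fetch(guid, "hash-1")

        self.assertIsNotNone(record)
        self.assertEqual(record.token_hash, "hash-1")


if __name__ == "__main__":
    unittest.main()
```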