Add SQLite repositories for Engine services

This commit is contained in:
2025-10-22 06:58:35 -06:00
parent 0ce11eac1a
commit 7b5248dfe5
7 changed files with 614 additions and 5 deletions

View File

@@ -32,7 +32,7 @@
- 6.3 Mirror refresh token issuance into `services/auth/token_service.py`; use `builders/device_enrollment.py` for payload assembly. - 6.3 Mirror refresh token issuance into `services/auth/token_service.py`; use `builders/device_enrollment.py` for payload assembly.
- 6.4 Commit once services pass targeted unit tests and integrate with placeholder repositories. - 6.4 Commit once services pass targeted unit tests and integrate with placeholder repositories.
7. Implement SQLite repositories [COMPLETED] 7. Implement SQLite repositories
- 7.1 Introduce `repositories/sqlite/device_repository.py`, `token_repository.py`, `enrollment_repository.py` using copied SQL. - 7.1 Introduce `repositories/sqlite/device_repository.py`, `token_repository.py`, `enrollment_repository.py` using copied SQL.
- 7.2 Write integration tests exercising CRUD against a temporary SQLite file. - 7.2 Write integration tests exercising CRUD against a temporary SQLite file.
- 7.3 Commit when repositories provide the required ports used by services. - 7.3 Commit when repositories provide the required ports used by services.

View File

@@ -44,3 +44,13 @@ Step6 introduces the first real Engine services:
- `Data/Engine/services/auth/token_service.py` issues refreshed access tokens while enforcing DPoP bindings and repository lookups. - `Data/Engine/services/auth/token_service.py` issues refreshed access tokens while enforcing DPoP bindings and repository lookups.
Interfaces will begin consuming these services once the repository adapters land in the next milestone. Interfaces will begin consuming these services once the repository adapters land in the next milestone.
## SQLite repositories
Step7 ports the first persistence adapters into the Engine:
- `Data/Engine/repositories/sqlite/device_repository.py` exposes `SQLiteDeviceRepository`, mirroring the legacy device lookups and automatic record recovery used during authentication.
- `Data/Engine/repositories/sqlite/token_repository.py` provides `SQLiteRefreshTokenRepository` for refresh-token validation, DPoP binding management, and usage timestamps.
- `Data/Engine/repositories/sqlite/enrollment_repository.py` surfaces enrollment install-code counters and device approval records so future services can operate without touching raw SQL.
Each repository accepts the shared `SQLiteConnectionFactory`, keeping all SQL execution confined to the Engine layer while services depend only on protocol interfaces.

View File

@@ -9,7 +9,10 @@ from .connection import (
connection_factory, connection_factory,
connection_scope, connection_scope,
) )
from .device_repository import SQLiteDeviceRepository
from .enrollment_repository import SQLiteEnrollmentRepository
from .migrations import apply_all from .migrations import apply_all
from .token_repository import SQLiteRefreshTokenRepository
__all__ = [ __all__ = [
"SQLiteConnectionFactory", "SQLiteConnectionFactory",
@@ -17,5 +20,8 @@ __all__ = [
"connect", "connect",
"connection_factory", "connection_factory",
"connection_scope", "connection_scope",
"SQLiteDeviceRepository",
"SQLiteRefreshTokenRepository",
"SQLiteEnrollmentRepository",
"apply_all", "apply_all",
] ]

View File

@@ -0,0 +1,183 @@
"""SQLite-backed device repository for the Engine authentication services."""
from __future__ import annotations
import logging
import sqlite3
import time
from contextlib import closing
from datetime import datetime, timezone
from typing import Optional
from Data.Engine.domain.device_auth import (
DeviceFingerprint,
DeviceGuid,
DeviceIdentity,
DeviceStatus,
)
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
from Data.Engine.services.auth.device_auth_service import DeviceRecord
__all__ = ["SQLiteDeviceRepository"]
class SQLiteDeviceRepository:
"""Persistence adapter that reads and recovers device rows."""
def __init__(
self,
connection_factory: SQLiteConnectionFactory,
*,
logger: Optional[logging.Logger] = None,
) -> None:
self._connections = connection_factory
self._log = logger or logging.getLogger("borealis.engine.repositories.devices")
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]:
"""Fetch a device row by GUID, normalizing legacy case variance."""
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
SELECT guid, ssl_key_fingerprint, token_version, status
FROM devices
WHERE UPPER(guid) = ?
""",
(guid.value.upper(),),
)
rows = cur.fetchall()
if not rows:
return None
for row in rows:
record = self._row_to_record(row)
if record and record.identity.guid.value == guid.value:
return record
# Fall back to the first row if normalization failed to match exactly.
return self._row_to_record(rows[0])
def recover_missing(
self,
guid: DeviceGuid,
fingerprint: DeviceFingerprint,
token_version: int,
service_context: Optional[str],
) -> Optional[DeviceRecord]:
"""Attempt to recreate a missing device row for a valid token."""
now_ts = int(time.time())
now_iso = datetime.now(tz=timezone.utc).isoformat()
base_hostname = f"RECOVERED-{guid.value[:12]}"
with closing(self._connections()) as conn:
cur = conn.cursor()
for attempt in range(6):
hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
try:
cur.execute(
"""
INSERT INTO devices (
guid,
hostname,
created_at,
last_seen,
ssl_key_fingerprint,
token_version,
status,
key_added_at
)
VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
""",
(
guid.value,
hostname,
now_ts,
now_ts,
fingerprint.value,
max(token_version or 1, 1),
now_iso,
),
)
except sqlite3.IntegrityError as exc:
message = str(exc).lower()
if "hostname" in message and "unique" in message:
continue
self._log.warning(
"device auth failed to recover guid=%s (context=%s): %s",
guid.value,
service_context or "none",
exc,
)
conn.rollback()
return None
except Exception as exc: # pragma: no cover - defensive logging
self._log.exception(
"device auth unexpected error recovering guid=%s (context=%s)",
guid.value,
service_context or "none",
exc_info=exc,
)
conn.rollback()
return None
else:
conn.commit()
break
else:
self._log.warning(
"device auth could not recover guid=%s; hostname collisions persisted",
guid.value,
)
conn.rollback()
return None
cur.execute(
"""
SELECT guid, ssl_key_fingerprint, token_version, status
FROM devices
WHERE guid = ?
""",
(guid.value,),
)
row = cur.fetchone()
if not row:
self._log.warning(
"device auth recovery committed but row missing for guid=%s",
guid.value,
)
return None
return self._row_to_record(row)
def _row_to_record(self, row: tuple) -> Optional[DeviceRecord]:
try:
guid = DeviceGuid(row[0])
fingerprint_value = (row[1] or "").strip()
if not fingerprint_value:
self._log.warning(
"device row %s missing TLS fingerprint; skipping",
row[0],
)
return None
fingerprint = DeviceFingerprint(fingerprint_value)
except Exception as exc:
self._log.warning("invalid device row for guid=%s: %s", row[0], exc)
return None
token_version_raw = row[2]
try:
token_version = int(token_version_raw or 0)
except Exception:
token_version = 0
status = DeviceStatus.from_string(row[3])
identity = DeviceIdentity(guid=guid, fingerprint=fingerprint)
return DeviceRecord(
identity=identity,
token_version=max(token_version, 1),
status=status,
)

View File

@@ -0,0 +1,286 @@
"""SQLite-backed enrollment repository for Engine services."""
from __future__ import annotations
import logging
from contextlib import closing
from datetime import datetime, timezone
from typing import Optional
from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid
from Data.Engine.domain.device_enrollment import (
EnrollmentApproval,
EnrollmentApprovalStatus,
EnrollmentCode,
)
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
__all__ = ["SQLiteEnrollmentRepository"]
class SQLiteEnrollmentRepository:
"""Persistence adapter that manages enrollment codes and approvals."""
def __init__(
self,
connection_factory: SQLiteConnectionFactory,
*,
logger: Optional[logging.Logger] = None,
) -> None:
self._connections = connection_factory
self._log = logger or logging.getLogger("borealis.engine.repositories.enrollment")
# ------------------------------------------------------------------
# Enrollment install codes
# ------------------------------------------------------------------
def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]:
"""Load an enrollment install code by its public value."""
code_value = (code or "").strip()
if not code_value:
return None
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
SELECT id,
code,
expires_at,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at
FROM enrollment_install_codes
WHERE code = ?
""",
(code_value,),
)
row = cur.fetchone()
if not row:
return None
record = {
"id": row[0],
"code": row[1],
"expires_at": row[2],
"used_at": row[3],
"used_by_guid": row[4],
"max_uses": row[5],
"use_count": row[6],
"last_used_at": row[7],
}
try:
return EnrollmentCode.from_mapping(record)
except Exception as exc: # pragma: no cover - defensive logging
self._log.warning("invalid enrollment code record for code=%s: %s", code_value, exc)
return None
def update_install_code_usage(
self,
record_id: str,
*,
use_count_increment: int,
last_used_at: datetime,
used_by_guid: Optional[DeviceGuid] = None,
mark_first_use: bool = False,
) -> None:
"""Increment usage counters and usage metadata for an install code."""
if use_count_increment <= 0:
raise ValueError("use_count_increment must be positive")
last_used_iso = self._isoformat(last_used_at)
guid_value = used_by_guid.value if used_by_guid else ""
mark_flag = 1 if mark_first_use else 0
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
UPDATE enrollment_install_codes
SET use_count = use_count + ?,
last_used_at = ?,
used_by_guid = COALESCE(NULLIF(?, ''), used_by_guid),
used_at = CASE WHEN ? = 1 AND used_at IS NULL THEN ? ELSE used_at END
WHERE id = ?
""",
(
use_count_increment,
last_used_iso,
guid_value,
mark_flag,
last_used_iso,
record_id,
),
)
conn.commit()
# ------------------------------------------------------------------
# Device approvals
# ------------------------------------------------------------------
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:
"""Load a device approval using its operator-visible reference."""
ref_value = (reference or "").strip()
if not ref_value:
return None
return self._fetch_device_approval("approval_reference = ?", (ref_value,))
def fetch_device_approval(self, record_id: str) -> Optional[EnrollmentApproval]:
record_value = (record_id or "").strip()
if not record_value:
return None
return self._fetch_device_approval("id = ?", (record_value,))
def create_device_approval(
self,
*,
record_id: str,
reference: str,
claimed_hostname: str,
claimed_fingerprint: DeviceFingerprint,
enrollment_code_id: Optional[str],
client_nonce: bytes,
server_nonce: bytes,
agent_pubkey_der: bytes,
created_at: datetime,
status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
guid: Optional[DeviceGuid] = None,
) -> EnrollmentApproval:
created_iso = self._isoformat(created_at)
guid_value = guid.value if guid else None
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
INSERT INTO device_approvals (
id,
approval_reference,
guid,
hostname_claimed,
ssl_key_fingerprint_claimed,
enrollment_code_id,
status,
created_at,
updated_at,
client_nonce,
server_nonce,
agent_pubkey_der
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
record_id,
reference,
guid_value,
claimed_hostname,
claimed_fingerprint.value,
enrollment_code_id,
status.value,
created_iso,
created_iso,
client_nonce,
server_nonce,
agent_pubkey_der,
),
)
conn.commit()
approval = self.fetch_device_approval(record_id)
if approval is None:
raise RuntimeError("failed to load device approval after insert")
return approval
def update_device_approval_status(
self,
record_id: str,
*,
status: EnrollmentApprovalStatus,
updated_at: datetime,
approved_by: Optional[str] = None,
guid: Optional[DeviceGuid] = None,
) -> None:
"""Transition an approval to a new status."""
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
UPDATE device_approvals
SET status = ?,
updated_at = ?,
guid = COALESCE(?, guid),
approved_by_user_id = COALESCE(?, approved_by_user_id)
WHERE id = ?
""",
(
status.value,
self._isoformat(updated_at),
guid.value if guid else None,
approved_by,
record_id,
),
)
conn.commit()
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _fetch_device_approval(self, where: str, params: tuple) -> Optional[EnrollmentApproval]:
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
f"""
SELECT id,
approval_reference,
guid,
hostname_claimed,
ssl_key_fingerprint_claimed,
enrollment_code_id,
created_at,
updated_at,
status,
approved_by_user_id
FROM device_approvals
WHERE {where}
""",
params,
)
row = cur.fetchone()
if not row:
return None
record = {
"id": row[0],
"approval_reference": row[1],
"guid": row[2],
"hostname_claimed": row[3],
"ssl_key_fingerprint_claimed": row[4],
"enrollment_code_id": row[5],
"created_at": row[6],
"updated_at": row[7],
"status": row[8],
"approved_by_user_id": row[9],
}
try:
return EnrollmentApproval.from_mapping(record)
except Exception as exc: # pragma: no cover - defensive logging
self._log.warning(
"invalid device approval record id=%s reference=%s: %s",
row[0],
row[1],
exc,
)
return None
@staticmethod
def _isoformat(value: datetime) -> str:
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc).isoformat()

View File

@@ -0,0 +1,124 @@
"""SQLite-backed refresh token repository for the Engine."""
from __future__ import annotations
import logging
from contextlib import closing
from datetime import datetime, timezone
from typing import Optional
from Data.Engine.domain.device_auth import DeviceGuid
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
from Data.Engine.services.auth.token_service import RefreshTokenRecord
__all__ = ["SQLiteRefreshTokenRepository"]
class SQLiteRefreshTokenRepository:
"""Persistence adapter for refresh token records."""
def __init__(
self,
connection_factory: SQLiteConnectionFactory,
*,
logger: Optional[logging.Logger] = None,
) -> None:
self._connections = connection_factory
self._log = logger or logging.getLogger("borealis.engine.repositories.tokens")
def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]:
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at
FROM refresh_tokens
WHERE guid = ?
AND token_hash = ?
""",
(guid.value, token_hash),
)
row = cur.fetchone()
if not row:
return None
return self._row_to_record(row)
def clear_dpop_binding(self, record_id: str) -> None:
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?",
(record_id,),
)
conn.commit()
def touch(
self,
record_id: str,
*,
last_used_at: datetime,
dpop_jkt: Optional[str],
) -> None:
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
UPDATE refresh_tokens
SET last_used_at = ?,
dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt)
WHERE id = ?
""",
(
self._isoformat(last_used_at),
(dpop_jkt or "").strip(),
record_id,
),
)
conn.commit()
def _row_to_record(self, row: tuple) -> Optional[RefreshTokenRecord]:
try:
guid = DeviceGuid(row[1])
except Exception as exc:
self._log.warning("invalid refresh token row guid=%s: %s", row[1], exc)
return None
created_at = self._parse_iso(row[4])
expires_at = self._parse_iso(row[5])
revoked_at = self._parse_iso(row[6])
if created_at is None:
created_at = datetime.now(tz=timezone.utc)
return RefreshTokenRecord.from_row(
record_id=str(row[0]),
guid=guid,
token_hash=str(row[2]),
dpop_jkt=str(row[3]) if row[3] is not None else None,
created_at=created_at,
expires_at=expires_at,
revoked_at=revoked_at,
)
@staticmethod
def _parse_iso(value: Optional[str]) -> Optional[datetime]:
if not value:
return None
raw = str(value).strip()
if not raw:
return None
try:
parsed = datetime.fromisoformat(raw)
except Exception:
return None
if parsed.tzinfo is None:
return parsed.replace(tzinfo=timezone.utc)
return parsed
@staticmethod
def _isoformat(value: datetime) -> str:
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc).isoformat()

View File

@@ -49,7 +49,7 @@ class TokenRefreshError(Exception):
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)
class RefreshTokenRecord: class RefreshTokenRecord:
record_id: int record_id: str
guid: DeviceGuid guid: DeviceGuid
token_hash: str token_hash: str
dpop_jkt: Optional[str] dpop_jkt: Optional[str]
@@ -61,7 +61,7 @@ class RefreshTokenRecord:
def from_row( def from_row(
cls, cls,
*, *,
record_id: int, record_id: str,
guid: DeviceGuid, guid: DeviceGuid,
token_hash: str, token_hash: str,
dpop_jkt: Optional[str], dpop_jkt: Optional[str],
@@ -84,10 +84,10 @@ class RefreshTokenRepository(Protocol):
def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]: # pragma: no cover - protocol def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]: # pragma: no cover - protocol
... ...
def clear_dpop_binding(self, record_id: int) -> None: # pragma: no cover - protocol def clear_dpop_binding(self, record_id: str) -> None: # pragma: no cover - protocol
... ...
def touch(self, record_id: int, *, last_used_at: datetime, dpop_jkt: Optional[str]) -> None: # pragma: no cover - protocol def touch(self, record_id: str, *, last_used_at: datetime, dpop_jkt: Optional[str]) -> None: # pragma: no cover - protocol
... ...