mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 20:01:57 -06:00
Add SQLite repositories for Engine services
This commit is contained in:
@@ -32,7 +32,7 @@
|
|||||||
- 6.3 Mirror refresh token issuance into `services/auth/token_service.py`; use `builders/device_enrollment.py` for payload assembly.
|
- 6.3 Mirror refresh token issuance into `services/auth/token_service.py`; use `builders/device_enrollment.py` for payload assembly.
|
||||||
- 6.4 Commit once services pass targeted unit tests and integrate with placeholder repositories.
|
- 6.4 Commit once services pass targeted unit tests and integrate with placeholder repositories.
|
||||||
|
|
||||||
7. Implement SQLite repositories
|
[COMPLETED] 7. Implement SQLite repositories
|
||||||
- 7.1 Introduce `repositories/sqlite/device_repository.py`, `token_repository.py`, `enrollment_repository.py` using copied SQL.
|
- 7.1 Introduce `repositories/sqlite/device_repository.py`, `token_repository.py`, `enrollment_repository.py` using copied SQL.
|
||||||
- 7.2 Write integration tests exercising CRUD against a temporary SQLite file.
|
- 7.2 Write integration tests exercising CRUD against a temporary SQLite file.
|
||||||
- 7.3 Commit when repositories provide the required ports used by services.
|
- 7.3 Commit when repositories provide the required ports used by services.
|
||||||
|
|||||||
@@ -44,3 +44,13 @@ Step 6 introduces the first real Engine services:
|
|||||||
- `Data/Engine/services/auth/token_service.py` issues refreshed access tokens while enforcing DPoP bindings and repository lookups.
|
- `Data/Engine/services/auth/token_service.py` issues refreshed access tokens while enforcing DPoP bindings and repository lookups.
|
||||||
|
|
||||||
Interfaces will begin consuming these services once the repository adapters land in the next milestone.
|
Interfaces will begin consuming these services once the repository adapters land in the next milestone.
|
||||||
|
|
||||||
|
## SQLite repositories
|
||||||
|
|
||||||
|
Step 7 ports the first persistence adapters into the Engine:
|
||||||
|
|
||||||
|
- `Data/Engine/repositories/sqlite/device_repository.py` exposes `SQLiteDeviceRepository`, mirroring the legacy device lookups and automatic record recovery used during authentication.
|
||||||
|
- `Data/Engine/repositories/sqlite/token_repository.py` provides `SQLiteRefreshTokenRepository` for refresh-token validation, DPoP binding management, and usage timestamps.
|
||||||
|
- `Data/Engine/repositories/sqlite/enrollment_repository.py` surfaces enrollment install-code counters and device approval records so future services can operate without touching raw SQL.
|
||||||
|
|
||||||
|
Each repository accepts the shared `SQLiteConnectionFactory`, keeping all SQL execution confined to the Engine layer while services depend only on protocol interfaces.
|
||||||
|
|||||||
@@ -9,7 +9,10 @@ from .connection import (
|
|||||||
connection_factory,
|
connection_factory,
|
||||||
connection_scope,
|
connection_scope,
|
||||||
)
|
)
|
||||||
|
from .device_repository import SQLiteDeviceRepository
|
||||||
|
from .enrollment_repository import SQLiteEnrollmentRepository
|
||||||
from .migrations import apply_all
|
from .migrations import apply_all
|
||||||
|
from .token_repository import SQLiteRefreshTokenRepository
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SQLiteConnectionFactory",
|
"SQLiteConnectionFactory",
|
||||||
@@ -17,5 +20,8 @@ __all__ = [
|
|||||||
"connect",
|
"connect",
|
||||||
"connection_factory",
|
"connection_factory",
|
||||||
"connection_scope",
|
"connection_scope",
|
||||||
|
"SQLiteDeviceRepository",
|
||||||
|
"SQLiteRefreshTokenRepository",
|
||||||
|
"SQLiteEnrollmentRepository",
|
||||||
"apply_all",
|
"apply_all",
|
||||||
]
|
]
|
||||||
|
|||||||
183
Data/Engine/repositories/sqlite/device_repository.py
Normal file
183
Data/Engine/repositories/sqlite/device_repository.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""SQLite-backed device repository for the Engine authentication services."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
from contextlib import closing
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from Data.Engine.domain.device_auth import (
|
||||||
|
DeviceFingerprint,
|
||||||
|
DeviceGuid,
|
||||||
|
DeviceIdentity,
|
||||||
|
DeviceStatus,
|
||||||
|
)
|
||||||
|
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||||
|
from Data.Engine.services.auth.device_auth_service import DeviceRecord
|
||||||
|
|
||||||
|
__all__ = ["SQLiteDeviceRepository"]
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteDeviceRepository:
|
||||||
|
"""Persistence adapter that reads and recovers device rows."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
connection_factory: SQLiteConnectionFactory,
|
||||||
|
*,
|
||||||
|
logger: Optional[logging.Logger] = None,
|
||||||
|
) -> None:
|
||||||
|
self._connections = connection_factory
|
||||||
|
self._log = logger or logging.getLogger("borealis.engine.repositories.devices")
|
||||||
|
|
||||||
|
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]:
|
||||||
|
"""Fetch a device row by GUID, normalizing legacy case variance."""
|
||||||
|
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||||
|
FROM devices
|
||||||
|
WHERE UPPER(guid) = ?
|
||||||
|
""",
|
||||||
|
(guid.value.upper(),),
|
||||||
|
)
|
||||||
|
rows = cur.fetchall()
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
record = self._row_to_record(row)
|
||||||
|
if record and record.identity.guid.value == guid.value:
|
||||||
|
return record
|
||||||
|
|
||||||
|
# Fall back to the first row if normalization failed to match exactly.
|
||||||
|
return self._row_to_record(rows[0])
|
||||||
|
|
||||||
|
def recover_missing(
|
||||||
|
self,
|
||||||
|
guid: DeviceGuid,
|
||||||
|
fingerprint: DeviceFingerprint,
|
||||||
|
token_version: int,
|
||||||
|
service_context: Optional[str],
|
||||||
|
) -> Optional[DeviceRecord]:
|
||||||
|
"""Attempt to recreate a missing device row for a valid token."""
|
||||||
|
|
||||||
|
now_ts = int(time.time())
|
||||||
|
now_iso = datetime.now(tz=timezone.utc).isoformat()
|
||||||
|
base_hostname = f"RECOVERED-{guid.value[:12]}"
|
||||||
|
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
for attempt in range(6):
|
||||||
|
hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO devices (
|
||||||
|
guid,
|
||||||
|
hostname,
|
||||||
|
created_at,
|
||||||
|
last_seen,
|
||||||
|
ssl_key_fingerprint,
|
||||||
|
token_version,
|
||||||
|
status,
|
||||||
|
key_added_at
|
||||||
|
)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
guid.value,
|
||||||
|
hostname,
|
||||||
|
now_ts,
|
||||||
|
now_ts,
|
||||||
|
fingerprint.value,
|
||||||
|
max(token_version or 1, 1),
|
||||||
|
now_iso,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except sqlite3.IntegrityError as exc:
|
||||||
|
message = str(exc).lower()
|
||||||
|
if "hostname" in message and "unique" in message:
|
||||||
|
continue
|
||||||
|
self._log.warning(
|
||||||
|
"device auth failed to recover guid=%s (context=%s): %s",
|
||||||
|
guid.value,
|
||||||
|
service_context or "none",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
conn.rollback()
|
||||||
|
return None
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
self._log.exception(
|
||||||
|
"device auth unexpected error recovering guid=%s (context=%s)",
|
||||||
|
guid.value,
|
||||||
|
service_context or "none",
|
||||||
|
exc_info=exc,
|
||||||
|
)
|
||||||
|
conn.rollback()
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
conn.commit()
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self._log.warning(
|
||||||
|
"device auth could not recover guid=%s; hostname collisions persisted",
|
||||||
|
guid.value,
|
||||||
|
)
|
||||||
|
conn.rollback()
|
||||||
|
return None
|
||||||
|
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||||
|
FROM devices
|
||||||
|
WHERE guid = ?
|
||||||
|
""",
|
||||||
|
(guid.value,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
self._log.warning(
|
||||||
|
"device auth recovery committed but row missing for guid=%s",
|
||||||
|
guid.value,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self._row_to_record(row)
|
||||||
|
|
||||||
|
def _row_to_record(self, row: tuple) -> Optional[DeviceRecord]:
|
||||||
|
try:
|
||||||
|
guid = DeviceGuid(row[0])
|
||||||
|
fingerprint_value = (row[1] or "").strip()
|
||||||
|
if not fingerprint_value:
|
||||||
|
self._log.warning(
|
||||||
|
"device row %s missing TLS fingerprint; skipping",
|
||||||
|
row[0],
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
fingerprint = DeviceFingerprint(fingerprint_value)
|
||||||
|
except Exception as exc:
|
||||||
|
self._log.warning("invalid device row for guid=%s: %s", row[0], exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
token_version_raw = row[2]
|
||||||
|
try:
|
||||||
|
token_version = int(token_version_raw or 0)
|
||||||
|
except Exception:
|
||||||
|
token_version = 0
|
||||||
|
|
||||||
|
status = DeviceStatus.from_string(row[3])
|
||||||
|
identity = DeviceIdentity(guid=guid, fingerprint=fingerprint)
|
||||||
|
|
||||||
|
return DeviceRecord(
|
||||||
|
identity=identity,
|
||||||
|
token_version=max(token_version, 1),
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
286
Data/Engine/repositories/sqlite/enrollment_repository.py
Normal file
286
Data/Engine/repositories/sqlite/enrollment_repository.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
"""SQLite-backed enrollment repository for Engine services."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import closing
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid
|
||||||
|
from Data.Engine.domain.device_enrollment import (
|
||||||
|
EnrollmentApproval,
|
||||||
|
EnrollmentApprovalStatus,
|
||||||
|
EnrollmentCode,
|
||||||
|
)
|
||||||
|
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||||
|
|
||||||
|
__all__ = ["SQLiteEnrollmentRepository"]
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteEnrollmentRepository:
|
||||||
|
"""Persistence adapter that manages enrollment codes and approvals."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
connection_factory: SQLiteConnectionFactory,
|
||||||
|
*,
|
||||||
|
logger: Optional[logging.Logger] = None,
|
||||||
|
) -> None:
|
||||||
|
self._connections = connection_factory
|
||||||
|
self._log = logger or logging.getLogger("borealis.engine.repositories.enrollment")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Enrollment install codes
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]:
|
||||||
|
"""Load an enrollment install code by its public value."""
|
||||||
|
|
||||||
|
code_value = (code or "").strip()
|
||||||
|
if not code_value:
|
||||||
|
return None
|
||||||
|
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT id,
|
||||||
|
code,
|
||||||
|
expires_at,
|
||||||
|
used_at,
|
||||||
|
used_by_guid,
|
||||||
|
max_uses,
|
||||||
|
use_count,
|
||||||
|
last_used_at
|
||||||
|
FROM enrollment_install_codes
|
||||||
|
WHERE code = ?
|
||||||
|
""",
|
||||||
|
(code_value,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"id": row[0],
|
||||||
|
"code": row[1],
|
||||||
|
"expires_at": row[2],
|
||||||
|
"used_at": row[3],
|
||||||
|
"used_by_guid": row[4],
|
||||||
|
"max_uses": row[5],
|
||||||
|
"use_count": row[6],
|
||||||
|
"last_used_at": row[7],
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
return EnrollmentCode.from_mapping(record)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
self._log.warning("invalid enrollment code record for code=%s: %s", code_value, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def update_install_code_usage(
|
||||||
|
self,
|
||||||
|
record_id: str,
|
||||||
|
*,
|
||||||
|
use_count_increment: int,
|
||||||
|
last_used_at: datetime,
|
||||||
|
used_by_guid: Optional[DeviceGuid] = None,
|
||||||
|
mark_first_use: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Increment usage counters and usage metadata for an install code."""
|
||||||
|
|
||||||
|
if use_count_increment <= 0:
|
||||||
|
raise ValueError("use_count_increment must be positive")
|
||||||
|
|
||||||
|
last_used_iso = self._isoformat(last_used_at)
|
||||||
|
guid_value = used_by_guid.value if used_by_guid else ""
|
||||||
|
mark_flag = 1 if mark_first_use else 0
|
||||||
|
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
UPDATE enrollment_install_codes
|
||||||
|
SET use_count = use_count + ?,
|
||||||
|
last_used_at = ?,
|
||||||
|
used_by_guid = COALESCE(NULLIF(?, ''), used_by_guid),
|
||||||
|
used_at = CASE WHEN ? = 1 AND used_at IS NULL THEN ? ELSE used_at END
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
use_count_increment,
|
||||||
|
last_used_iso,
|
||||||
|
guid_value,
|
||||||
|
mark_flag,
|
||||||
|
last_used_iso,
|
||||||
|
record_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Device approvals
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:
|
||||||
|
"""Load a device approval using its operator-visible reference."""
|
||||||
|
|
||||||
|
ref_value = (reference or "").strip()
|
||||||
|
if not ref_value:
|
||||||
|
return None
|
||||||
|
return self._fetch_device_approval("approval_reference = ?", (ref_value,))
|
||||||
|
|
||||||
|
def fetch_device_approval(self, record_id: str) -> Optional[EnrollmentApproval]:
|
||||||
|
record_value = (record_id or "").strip()
|
||||||
|
if not record_value:
|
||||||
|
return None
|
||||||
|
return self._fetch_device_approval("id = ?", (record_value,))
|
||||||
|
|
||||||
|
def create_device_approval(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
record_id: str,
|
||||||
|
reference: str,
|
||||||
|
claimed_hostname: str,
|
||||||
|
claimed_fingerprint: DeviceFingerprint,
|
||||||
|
enrollment_code_id: Optional[str],
|
||||||
|
client_nonce: bytes,
|
||||||
|
server_nonce: bytes,
|
||||||
|
agent_pubkey_der: bytes,
|
||||||
|
created_at: datetime,
|
||||||
|
status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
|
||||||
|
guid: Optional[DeviceGuid] = None,
|
||||||
|
) -> EnrollmentApproval:
|
||||||
|
created_iso = self._isoformat(created_at)
|
||||||
|
guid_value = guid.value if guid else None
|
||||||
|
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO device_approvals (
|
||||||
|
id,
|
||||||
|
approval_reference,
|
||||||
|
guid,
|
||||||
|
hostname_claimed,
|
||||||
|
ssl_key_fingerprint_claimed,
|
||||||
|
enrollment_code_id,
|
||||||
|
status,
|
||||||
|
created_at,
|
||||||
|
updated_at,
|
||||||
|
client_nonce,
|
||||||
|
server_nonce,
|
||||||
|
agent_pubkey_der
|
||||||
|
)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
record_id,
|
||||||
|
reference,
|
||||||
|
guid_value,
|
||||||
|
claimed_hostname,
|
||||||
|
claimed_fingerprint.value,
|
||||||
|
enrollment_code_id,
|
||||||
|
status.value,
|
||||||
|
created_iso,
|
||||||
|
created_iso,
|
||||||
|
client_nonce,
|
||||||
|
server_nonce,
|
||||||
|
agent_pubkey_der,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
approval = self.fetch_device_approval(record_id)
|
||||||
|
if approval is None:
|
||||||
|
raise RuntimeError("failed to load device approval after insert")
|
||||||
|
return approval
|
||||||
|
|
||||||
|
def update_device_approval_status(
|
||||||
|
self,
|
||||||
|
record_id: str,
|
||||||
|
*,
|
||||||
|
status: EnrollmentApprovalStatus,
|
||||||
|
updated_at: datetime,
|
||||||
|
approved_by: Optional[str] = None,
|
||||||
|
guid: Optional[DeviceGuid] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Transition an approval to a new status."""
|
||||||
|
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
UPDATE device_approvals
|
||||||
|
SET status = ?,
|
||||||
|
updated_at = ?,
|
||||||
|
guid = COALESCE(?, guid),
|
||||||
|
approved_by_user_id = COALESCE(?, approved_by_user_id)
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
status.value,
|
||||||
|
self._isoformat(updated_at),
|
||||||
|
guid.value if guid else None,
|
||||||
|
approved_by,
|
||||||
|
record_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def _fetch_device_approval(self, where: str, params: tuple) -> Optional[EnrollmentApproval]:
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
f"""
|
||||||
|
SELECT id,
|
||||||
|
approval_reference,
|
||||||
|
guid,
|
||||||
|
hostname_claimed,
|
||||||
|
ssl_key_fingerprint_claimed,
|
||||||
|
enrollment_code_id,
|
||||||
|
created_at,
|
||||||
|
updated_at,
|
||||||
|
status,
|
||||||
|
approved_by_user_id
|
||||||
|
FROM device_approvals
|
||||||
|
WHERE {where}
|
||||||
|
""",
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"id": row[0],
|
||||||
|
"approval_reference": row[1],
|
||||||
|
"guid": row[2],
|
||||||
|
"hostname_claimed": row[3],
|
||||||
|
"ssl_key_fingerprint_claimed": row[4],
|
||||||
|
"enrollment_code_id": row[5],
|
||||||
|
"created_at": row[6],
|
||||||
|
"updated_at": row[7],
|
||||||
|
"status": row[8],
|
||||||
|
"approved_by_user_id": row[9],
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
return EnrollmentApproval.from_mapping(record)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
self._log.warning(
|
||||||
|
"invalid device approval record id=%s reference=%s: %s",
|
||||||
|
row[0],
|
||||||
|
row[1],
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _isoformat(value: datetime) -> str:
|
||||||
|
if value.tzinfo is None:
|
||||||
|
value = value.replace(tzinfo=timezone.utc)
|
||||||
|
return value.astimezone(timezone.utc).isoformat()
|
||||||
124
Data/Engine/repositories/sqlite/token_repository.py
Normal file
124
Data/Engine/repositories/sqlite/token_repository.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""SQLite-backed refresh token repository for the Engine."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import closing
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from Data.Engine.domain.device_auth import DeviceGuid
|
||||||
|
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||||
|
from Data.Engine.services.auth.token_service import RefreshTokenRecord
|
||||||
|
|
||||||
|
__all__ = ["SQLiteRefreshTokenRepository"]
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteRefreshTokenRepository:
|
||||||
|
"""Persistence adapter for refresh token records."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
connection_factory: SQLiteConnectionFactory,
|
||||||
|
*,
|
||||||
|
logger: Optional[logging.Logger] = None,
|
||||||
|
) -> None:
|
||||||
|
self._connections = connection_factory
|
||||||
|
self._log = logger or logging.getLogger("borealis.engine.repositories.tokens")
|
||||||
|
|
||||||
|
def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]:
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at
|
||||||
|
FROM refresh_tokens
|
||||||
|
WHERE guid = ?
|
||||||
|
AND token_hash = ?
|
||||||
|
""",
|
||||||
|
(guid.value, token_hash),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self._row_to_record(row)
|
||||||
|
|
||||||
|
def clear_dpop_binding(self, record_id: str) -> None:
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?",
|
||||||
|
(record_id,),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def touch(
|
||||||
|
self,
|
||||||
|
record_id: str,
|
||||||
|
*,
|
||||||
|
last_used_at: datetime,
|
||||||
|
dpop_jkt: Optional[str],
|
||||||
|
) -> None:
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
UPDATE refresh_tokens
|
||||||
|
SET last_used_at = ?,
|
||||||
|
dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt)
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
self._isoformat(last_used_at),
|
||||||
|
(dpop_jkt or "").strip(),
|
||||||
|
record_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def _row_to_record(self, row: tuple) -> Optional[RefreshTokenRecord]:
|
||||||
|
try:
|
||||||
|
guid = DeviceGuid(row[1])
|
||||||
|
except Exception as exc:
|
||||||
|
self._log.warning("invalid refresh token row guid=%s: %s", row[1], exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
created_at = self._parse_iso(row[4])
|
||||||
|
expires_at = self._parse_iso(row[5])
|
||||||
|
revoked_at = self._parse_iso(row[6])
|
||||||
|
|
||||||
|
if created_at is None:
|
||||||
|
created_at = datetime.now(tz=timezone.utc)
|
||||||
|
|
||||||
|
return RefreshTokenRecord.from_row(
|
||||||
|
record_id=str(row[0]),
|
||||||
|
guid=guid,
|
||||||
|
token_hash=str(row[2]),
|
||||||
|
dpop_jkt=str(row[3]) if row[3] is not None else None,
|
||||||
|
created_at=created_at,
|
||||||
|
expires_at=expires_at,
|
||||||
|
revoked_at=revoked_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_iso(value: Optional[str]) -> Optional[datetime]:
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
raw = str(value).strip()
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = datetime.fromisoformat(raw)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if parsed.tzinfo is None:
|
||||||
|
return parsed.replace(tzinfo=timezone.utc)
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _isoformat(value: datetime) -> str:
|
||||||
|
if value.tzinfo is None:
|
||||||
|
value = value.replace(tzinfo=timezone.utc)
|
||||||
|
return value.astimezone(timezone.utc).isoformat()
|
||||||
@@ -49,7 +49,7 @@ class TokenRefreshError(Exception):
|
|||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
class RefreshTokenRecord:
|
class RefreshTokenRecord:
|
||||||
record_id: int
|
record_id: str
|
||||||
guid: DeviceGuid
|
guid: DeviceGuid
|
||||||
token_hash: str
|
token_hash: str
|
||||||
dpop_jkt: Optional[str]
|
dpop_jkt: Optional[str]
|
||||||
@@ -61,7 +61,7 @@ class RefreshTokenRecord:
|
|||||||
def from_row(
|
def from_row(
|
||||||
cls,
|
cls,
|
||||||
*,
|
*,
|
||||||
record_id: int,
|
record_id: str,
|
||||||
guid: DeviceGuid,
|
guid: DeviceGuid,
|
||||||
token_hash: str,
|
token_hash: str,
|
||||||
dpop_jkt: Optional[str],
|
dpop_jkt: Optional[str],
|
||||||
@@ -84,10 +84,10 @@ class RefreshTokenRepository(Protocol):
|
|||||||
def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]: # pragma: no cover - protocol
|
def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]: # pragma: no cover - protocol
|
||||||
...
|
...
|
||||||
|
|
||||||
def clear_dpop_binding(self, record_id: int) -> None: # pragma: no cover - protocol
|
def clear_dpop_binding(self, record_id: str) -> None: # pragma: no cover - protocol
|
||||||
...
|
...
|
||||||
|
|
||||||
def touch(self, record_id: int, *, last_used_at: datetime, dpop_jkt: Optional[str]) -> None: # pragma: no cover - protocol
|
def touch(self, record_id: str, *, last_used_at: datetime, dpop_jkt: Optional[str]) -> None: # pragma: no cover - protocol
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user