Files
Borealis-Github-Replica/Data/Engine/repositories/sqlite/device_repository.py

411 lines
13 KiB
Python

"""SQLite-backed device repository for the Engine authentication services."""
from __future__ import annotations
import logging
import sqlite3
import time
import uuid
from contextlib import closing
from datetime import datetime, timezone
from typing import Optional
from Data.Engine.domain.device_auth import (
DeviceFingerprint,
DeviceGuid,
DeviceIdentity,
DeviceStatus,
)
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
from Data.Engine.services.auth.device_auth_service import DeviceRecord
__all__ = ["SQLiteDeviceRepository"]
class SQLiteDeviceRepository:
    """Persistence adapter that reads and recovers device rows.

    Every public method opens a fresh connection from the injected factory
    and closes it before returning, so no connection state is kept on the
    instance between calls.
    """

    def __init__(
        self,
        connection_factory: SQLiteConnectionFactory,
        *,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """Store the connection factory and the logger used for diagnostics.

        Args:
            connection_factory: Zero-argument callable returning an open
                SQLite connection.
            logger: Optional logger; a named default is used when omitted.
        """
        self._connections = connection_factory
        self._log = logger or logging.getLogger("borealis.engine.repositories.devices")

    def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]:
        """Fetch a device row by GUID, normalizing legacy case variance.

        The lookup is case-insensitive. When several rows match, the one whose
        stored GUID equals ``guid.value`` exactly is preferred; otherwise the
        record built from the first row is returned (which may be ``None``
        when that row is invalid).

        Returns:
            The matching record, or ``None`` when no usable row exists.
        """
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                SELECT guid, ssl_key_fingerprint, token_version, status
                FROM devices
                WHERE UPPER(guid) = ?
                """,
                (guid.value.upper(),),
            )
            rows = cur.fetchall()

        if not rows:
            return None

        # Convert each row exactly once so an invalid first row is not
        # converted (and warned about) a second time by the fallback.
        first_record: Optional[DeviceRecord] = None
        for index, row in enumerate(rows):
            record = self._row_to_record(row)
            if index == 0:
                first_record = record
            if record and record.identity.guid.value == guid.value:
                return record

        # Fall back to the first row if normalization failed to match exactly.
        return first_record

    def recover_missing(
        self,
        guid: DeviceGuid,
        fingerprint: DeviceFingerprint,
        token_version: int,
        service_context: Optional[str],
    ) -> Optional[DeviceRecord]:
        """Attempt to recreate a missing device row for a valid token.

        A placeholder hostname ``RECOVERED-<guid prefix>`` is used; on a
        hostname uniqueness collision up to five numbered variants are tried.

        Args:
            guid: Identifier of the device to recreate.
            fingerprint: Fingerprint stored on the new row.
            token_version: Token version to store; values below 1 (or falsy)
                are raised to 1.
            service_context: Free-form context included in log messages only.

        Returns:
            The recreated record, or ``None`` when the insert failed, every
            hostname candidate collided, or the row could not be read back.
        """
        now_ts = int(time.time())
        now_iso = datetime.now(tz=timezone.utc).isoformat()
        base_hostname = f"RECOVERED-{guid.value[:12]}"

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            for attempt in range(6):
                hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
                try:
                    cur.execute(
                        """
                        INSERT INTO devices (
                            guid,
                            hostname,
                            created_at,
                            last_seen,
                            ssl_key_fingerprint,
                            token_version,
                            status,
                            key_added_at
                        )
                        VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
                        """,
                        (
                            guid.value,
                            hostname,
                            now_ts,
                            now_ts,
                            fingerprint.value,
                            max(token_version or 1, 1),
                            now_iso,
                        ),
                    )
                except sqlite3.IntegrityError as exc:
                    message = str(exc).lower()
                    # Only a hostname uniqueness clash is retryable; any other
                    # constraint failure aborts the recovery.
                    if "hostname" in message and "unique" in message:
                        continue
                    self._log.warning(
                        "device auth failed to recover guid=%s (context=%s): %s",
                        guid.value,
                        service_context or "none",
                        exc,
                    )
                    conn.rollback()
                    return None
                except Exception as exc:  # pragma: no cover - defensive logging
                    self._log.exception(
                        "device auth unexpected error recovering guid=%s (context=%s)",
                        guid.value,
                        service_context or "none",
                        exc_info=exc,
                    )
                    conn.rollback()
                    return None
                else:
                    conn.commit()
                    break
            else:
                # Loop finished without a successful insert.
                self._log.warning(
                    "device auth could not recover guid=%s; hostname collisions persisted",
                    guid.value,
                )
                conn.rollback()
                return None

            cur.execute(
                """
                SELECT guid, ssl_key_fingerprint, token_version, status
                FROM devices
                WHERE guid = ?
                """,
                (guid.value,),
            )
            row = cur.fetchone()

        if not row:
            self._log.warning(
                "device auth recovery committed but row missing for guid=%s",
                guid.value,
            )
            return None
        return self._row_to_record(row)

    def ensure_device_record(
        self,
        *,
        guid: DeviceGuid,
        hostname: str,
        fingerprint: DeviceFingerprint,
    ) -> DeviceRecord:
        """Create the device row if absent, or reconcile its fingerprint.

        For an existing row: an empty stored fingerprint is filled in; a
        differing fingerprint replaces the stored one, bumps ``token_version``,
        sets the status to ``active`` and revokes the device's outstanding
        refresh tokens. For a missing row a new active row is inserted with a
        hostname made unique by ``_resolve_hostname``.

        Returns:
            The record as stored after the operation.

        Raises:
            RuntimeError: If the row cannot be read back or is invalid.
        """
        now_iso = datetime.now(tz=timezone.utc).isoformat()
        now_ts = int(time.time())

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
                FROM devices
                WHERE UPPER(guid) = ?
                """,
                (guid.value.upper(),),
            )
            row = cur.fetchone()

            if row:
                stored_guid = row[0]
                stored_fp = (row[4] or "").strip().lower()
                # NOTE(review): the stored value is lower-cased but the new
                # one is used as-is; presumably DeviceFingerprint already
                # normalizes case - confirm.
                new_fp = fingerprint.value
                if not stored_fp:
                    cur.execute(
                        "UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
                        (new_fp, now_iso, stored_guid),
                    )
                elif stored_fp != new_fp:
                    token_version = self._coerce_int(row[2], default=1) + 1
                    cur.execute(
                        """
                        UPDATE devices
                        SET ssl_key_fingerprint = ?,
                            key_added_at = ?,
                            token_version = ?,
                            status = 'active'
                        WHERE guid = ?
                        """,
                        (new_fp, now_iso, token_version, stored_guid),
                    )
                    # The key changed, so refresh tokens issued for the old
                    # key are revoked.
                    cur.execute(
                        """
                        UPDATE refresh_tokens
                        SET revoked_at = ?
                        WHERE guid = ?
                          AND revoked_at IS NULL
                        """,
                        (now_iso, stored_guid),
                    )
            else:
                resolved_hostname = self._resolve_hostname(cur, hostname, guid)
                cur.execute(
                    """
                    INSERT INTO devices (
                        guid,
                        hostname,
                        created_at,
                        last_seen,
                        ssl_key_fingerprint,
                        token_version,
                        status,
                        key_added_at
                    )
                    VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
                    """,
                    (
                        guid.value,
                        resolved_hostname,
                        now_ts,
                        now_ts,
                        fingerprint.value,
                        now_iso,
                    ),
                )
            conn.commit()

            cur.execute(
                """
                SELECT guid, ssl_key_fingerprint, token_version, status
                FROM devices
                WHERE UPPER(guid) = ?
                """,
                (guid.value.upper(),),
            )
            latest = cur.fetchone()

        if not latest:
            raise RuntimeError("device record could not be ensured")
        record = self._row_to_record(latest)
        if record is None:
            raise RuntimeError("device record invalid after ensure")
        return record

    def record_device_key(
        self,
        *,
        guid: DeviceGuid,
        fingerprint: DeviceFingerprint,
        added_at: datetime,
    ) -> None:
        """Record a key for the device and retire its other active keys.

        Inserts a ``device_keys`` row (ignored if it conflicts with an
        existing one) and stamps ``retired_at`` on every other un-retired key
        of the same device, using ``added_at`` converted to UTC.
        """
        added_iso = added_at.astimezone(timezone.utc).isoformat()
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
                VALUES (?, ?, ?, ?)
                """,
                (str(uuid.uuid4()), guid.value, fingerprint.value, added_iso),
            )
            cur.execute(
                """
                UPDATE device_keys
                SET retired_at = ?
                WHERE guid = ?
                  AND ssl_key_fingerprint != ?
                  AND retired_at IS NULL
                """,
                (added_iso, guid.value, fingerprint.value),
            )
            conn.commit()

    def update_device_summary(
        self,
        *,
        hostname: Optional[str],
        last_seen: Optional[int] = None,
        agent_id: Optional[str] = None,
        operating_system: Optional[str] = None,
        last_user: Optional[str] = None,
    ) -> None:
        """Update summary columns for the device with the given hostname.

        Only the supplied, non-empty values are written. The row is located by
        case-insensitive hostname; if nothing matches and an agent id was
        given, the same update is retried keyed on ``agent_id``.

        Args:
            hostname: Hostname of the device; blank values make this a no-op.
            last_seen: Timestamp to store; skipped when not int-coercible.
            agent_id: Agent identifier, stored with whitespace trimmed.
            operating_system: Operating system text, stored trimmed.
            last_user: Last user text, stored trimmed.
        """
        normalized_hostname = (hostname or "").strip()
        if not normalized_hostname:
            return

        fields = []
        params = []

        if last_seen is not None:
            # Coerce before touching ``fields`` so a bad value cannot leave a
            # placeholder without a matching parameter.
            try:
                seen_ts = int(last_seen)
            except (TypeError, ValueError, OverflowError):
                self._log.debug("ignoring non-numeric last_seen=%r", last_seen)
            else:
                fields.append("last_seen = ?")
                params.append(seen_ts)

        agent_value = self._clean_text(agent_id)
        text_columns = (
            ("agent_id", agent_value),
            ("operating_system", self._clean_text(operating_system)),
            ("last_user", self._clean_text(last_user)),
        )
        for column, value in text_columns:
            if value is not None:
                fields.append(f"{column} = ?")
                params.append(value)

        if not fields:
            return

        # ``fields`` only ever holds the fixed fragments built above, so
        # interpolating them into the statement is safe; values stay bound.
        assignments = ", ".join(fields)
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                f"UPDATE devices SET {assignments} WHERE LOWER(hostname) = LOWER(?)",
                params + [normalized_hostname],
            )
            if cur.rowcount == 0 and agent_value is not None:
                # Match on the same trimmed agent id that gets stored.
                cur.execute(
                    f"UPDATE devices SET {assignments} WHERE agent_id = ?",
                    params + [agent_value],
                )
            conn.commit()

    def _row_to_record(self, row: tuple) -> Optional[DeviceRecord]:
        """Convert a ``(guid, fingerprint, token_version, status)`` row.

        Returns ``None`` (after logging a warning) when the fingerprint is
        blank or the GUID/fingerprint value objects reject the stored data.
        An unparsable token version is treated as 0 and then clamped to 1.
        """
        try:
            guid = DeviceGuid(row[0])
            fingerprint_value = (row[1] or "").strip()
            if not fingerprint_value:
                self._log.warning(
                    "device row %s missing TLS fingerprint; skipping",
                    row[0],
                )
                return None
            fingerprint = DeviceFingerprint(fingerprint_value)
        except Exception as exc:
            self._log.warning("invalid device row for guid=%s: %s", row[0], exc)
            return None

        token_version = self._coerce_int(row[2], default=0)
        status = DeviceStatus.from_string(row[3])
        identity = DeviceIdentity(guid=guid, fingerprint=fingerprint)
        return DeviceRecord(
            identity=identity,
            token_version=max(token_version, 1),
            status=status,
        )

    @staticmethod
    def _coerce_int(value: object, *, default: int = 0) -> int:
        """Return ``int(value)``, or ``default`` when it cannot be converted."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _clean_text(value: object) -> Optional[object]:
        """Return ``value`` trimmed, or ``None`` when it is empty or falsy.

        Values without a usable ``strip`` method are passed through as-is.
        """
        if not value:
            return None
        try:
            cleaned = value.strip()
        except (AttributeError, TypeError):
            cleaned = value
        return cleaned or None

    def _resolve_hostname(self, cur: sqlite3.Cursor, hostname: str, guid: DeviceGuid) -> str:
        """Pick a hostname that is not held by a different device.

        Tries the trimmed hostname (or the GUID when blank), truncated to 253
        characters, then ``-1`` .. ``-49`` suffixed variants. A candidate is
        accepted when no row uses it or the row using it belongs to ``guid``.
        When every candidate is taken, the GUID itself is returned.
        """
        base = ((hostname or "").strip() or guid.value)[:253]
        # Compare case-insensitively, like the other GUID lookups in this class.
        wanted_guid = guid.value.upper()

        for suffix in range(50):
            candidate = base if suffix == 0 else f"{base}-{suffix}"
            cur.execute(
                "SELECT guid FROM devices WHERE hostname = ?",
                (candidate,),
            )
            row = cur.fetchone()
            if not row:
                return candidate
            existing = (row[0] or "").strip().upper()
            if existing == wanted_guid:
                return candidate
        return guid.value