mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 15:41:58 -06:00
411 lines
13 KiB
Python
"""SQLite-backed device repository for the Engine authentication services."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import sqlite3
|
|
import time
|
|
import uuid
|
|
from contextlib import closing
|
|
from datetime import datetime, timezone
|
|
from typing import Optional
|
|
|
|
from Data.Engine.domain.device_auth import (
|
|
DeviceFingerprint,
|
|
DeviceGuid,
|
|
DeviceIdentity,
|
|
DeviceStatus,
|
|
)
|
|
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
|
from Data.Engine.services.auth.device_auth_service import DeviceRecord
|
|
|
|
# Public names exported by ``from <module> import *``.
__all__ = [
    "SQLiteDeviceRepository",
]
|
|
|
|
|
|
class SQLiteDeviceRepository:
    """Persistence adapter that reads and recovers device rows.

    Each public method obtains a connection from the injected factory and
    closes it before returning.  ``contextlib.closing`` only closes the
    connection (it never commits), so every write path commits explicitly.
    """

    # Columns selected by every query whose result is turned into a
    # ``DeviceRecord``.  The order matches the indexes ``_row_to_record``
    # reads: (guid, fingerprint, token_version, status).
    _RECORD_COLUMNS = "guid, ssl_key_fingerprint, token_version, status"

    # Upper bound on hostnames produced by ``_resolve_hostname``, including
    # any ``-N`` suffix appended to break a collision.
    _MAX_HOSTNAME_LENGTH = 253

    # INSERT attempts made by ``recover_missing``: the base recovery hostname
    # followed by the ``-1`` ... ``-5`` variants.
    _RECOVERY_ATTEMPTS = 6

    def __init__(
        self,
        connection_factory: SQLiteConnectionFactory,
        *,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """Create the repository.

        Args:
            connection_factory: Called with no arguments once per operation.
                The returned connection is closed when the operation ends.
            logger: Optional logger.  Defaults to
                ``borealis.engine.repositories.devices``.
        """
        self._connections = connection_factory
        self._log = logger or logging.getLogger("borealis.engine.repositories.devices")

    def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]:
        """Fetch a device row by GUID, normalizing legacy case variance.

        Rows are matched case-insensitively.  If several rows match, the first
        one whose parsed GUID equals ``guid.value`` exactly is returned.
        Otherwise the record parsed from the first row is returned.

        Returns:
            The matching ``DeviceRecord``.  ``None`` when no row matches or
            when the chosen row cannot be parsed.
        """
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                f"""
                SELECT {self._RECORD_COLUMNS}
                FROM devices
                WHERE UPPER(guid) = ?
                """,
                (guid.value.upper(),),
            )
            rows = cur.fetchall()

        first_record: Optional[DeviceRecord] = None
        for index, row in enumerate(rows):
            record = self._row_to_record(row)
            if index == 0:
                # Keep the first parse as the fallback rather than re-parsing
                # (and re-logging warnings for) that row afterwards.
                first_record = record
            if record is not None and record.identity.guid.value == guid.value:
                return record

        # Fall back to the first row if normalization failed to match exactly
        # (stays None when there were no rows at all).
        return first_record

    def recover_missing(
        self,
        guid: DeviceGuid,
        fingerprint: DeviceFingerprint,
        token_version: int,
        service_context: Optional[str],
    ) -> Optional[DeviceRecord]:
        """Attempt to recreate a missing device row for a valid token.

        Inserts an ``active`` row named ``RECOVERED-<first 12 GUID chars>``.
        A hostname uniqueness collision is retried with a numeric suffix
        (``-1``, ``-2``, ...).  Any other integrity error, or any unexpected
        failure, is logged and rolled back.

        Args:
            guid: GUID of the device row to recreate.
            fingerprint: TLS key fingerprint stored on the new row.
            token_version: Token version to persist.  Falsy or smaller values
                are raised to 1.
            service_context: Free-form label included in log messages.

        Returns:
            The recreated record, or ``None`` if recovery failed.
        """
        now_ts = int(time.time())
        now_iso = datetime.now(tz=timezone.utc).isoformat()
        base_hostname = f"RECOVERED-{guid.value[:12]}"

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            for attempt in range(self._RECOVERY_ATTEMPTS):
                hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
                try:
                    cur.execute(
                        """
                        INSERT INTO devices (
                            guid,
                            hostname,
                            created_at,
                            last_seen,
                            ssl_key_fingerprint,
                            token_version,
                            status,
                            key_added_at
                        )
                        VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
                        """,
                        (
                            guid.value,
                            hostname,
                            now_ts,
                            now_ts,
                            fingerprint.value,
                            max(token_version or 1, 1),
                            now_iso,
                        ),
                    )
                except sqlite3.IntegrityError as exc:
                    message = str(exc).lower()
                    if "hostname" in message and "unique" in message:
                        # Only the failing statement is backed out; try the
                        # next suffixed hostname.
                        continue
                    self._log.warning(
                        "device auth failed to recover guid=%s (context=%s): %s",
                        guid.value,
                        service_context or "none",
                        exc,
                    )
                    conn.rollback()
                    return None
                except Exception as exc:  # pragma: no cover - defensive logging
                    self._log.exception(
                        "device auth unexpected error recovering guid=%s (context=%s)",
                        guid.value,
                        service_context or "none",
                        exc_info=exc,
                    )
                    conn.rollback()
                    return None
                else:
                    conn.commit()
                    break
            else:
                # Every candidate hostname collided.
                self._log.warning(
                    "device auth could not recover guid=%s; hostname collisions persisted",
                    guid.value,
                )
                conn.rollback()
                return None

            cur.execute(
                f"""
                SELECT {self._RECORD_COLUMNS}
                FROM devices
                WHERE guid = ?
                """,
                (guid.value,),
            )
            row = cur.fetchone()

        if not row:
            self._log.warning(
                "device auth recovery committed but row missing for guid=%s",
                guid.value,
            )
            return None

        return self._row_to_record(row)

    def ensure_device_record(
        self,
        *,
        guid: DeviceGuid,
        hostname: str,
        fingerprint: DeviceFingerprint,
    ) -> DeviceRecord:
        """Create or reconcile the device row for ``guid`` and return it.

        * No row: insert one with a collision-free hostname, token version 1
          and status ``active``.
        * Row without a stored fingerprint: store ``fingerprint``.
        * Row with a different fingerprint (compared case-insensitively):
          store the new fingerprint, bump the token version, set status to
          ``active``, and revoke the device's unrevoked refresh tokens.

        Raises:
            RuntimeError: if the row cannot be read back or parsed afterwards.
        """
        now_iso = datetime.now(tz=timezone.utc).isoformat()
        now_ts = int(time.time())

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
                FROM devices
                WHERE UPPER(guid) = ?
                """,
                (guid.value.upper(),),
            )
            row = cur.fetchone()

            if row:
                stored_guid = row[0]
                stored_fp = (row[4] or "").strip().lower()
                new_fp = fingerprint.value
                if not stored_fp:
                    # Row has no key yet: adopt the presented one.
                    cur.execute(
                        "UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
                        (new_fp, now_iso, stored_guid),
                    )
                elif stored_fp != new_fp.strip().lower():
                    # Key rotation.  Bump from the *effective* version:
                    # _row_to_record reports at least 1, so a stored 0/NULL
                    # must become 2 (not 1) for the version to change.
                    token_version = max(self._coerce_int(row[2], default=1), 1) + 1
                    cur.execute(
                        """
                        UPDATE devices
                        SET ssl_key_fingerprint = ?,
                            key_added_at = ?,
                            token_version = ?,
                            status = 'active'
                        WHERE guid = ?
                        """,
                        (new_fp, now_iso, token_version, stored_guid),
                    )
                    cur.execute(
                        """
                        UPDATE refresh_tokens
                        SET revoked_at = ?
                        WHERE guid = ?
                          AND revoked_at IS NULL
                        """,
                        (now_iso, stored_guid),
                    )
                # Commit whichever update ran.  Closing without a commit
                # would discard it.
                conn.commit()
            else:
                resolved_hostname = self._resolve_hostname(cur, hostname, guid)
                cur.execute(
                    """
                    INSERT INTO devices (
                        guid,
                        hostname,
                        created_at,
                        last_seen,
                        ssl_key_fingerprint,
                        token_version,
                        status,
                        key_added_at
                    )
                    VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
                    """,
                    (
                        guid.value,
                        resolved_hostname,
                        now_ts,
                        now_ts,
                        fingerprint.value,
                        now_iso,
                    ),
                )
                conn.commit()

            cur.execute(
                f"""
                SELECT {self._RECORD_COLUMNS}
                FROM devices
                WHERE UPPER(guid) = ?
                """,
                (guid.value.upper(),),
            )
            latest = cur.fetchone()

        if not latest:
            raise RuntimeError("device record could not be ensured")

        record = self._row_to_record(latest)
        if record is None:
            raise RuntimeError("device record invalid after ensure")
        return record

    def record_device_key(
        self,
        *,
        guid: DeviceGuid,
        fingerprint: DeviceFingerprint,
        added_at: datetime,
    ) -> None:
        """Record ``fingerprint`` as the device's current key and retire the rest.

        Inserts a ``device_keys`` row with a random UUID id.  The insert is
        ``OR IGNORE``, presumably relying on a uniqueness constraint over
        (guid, fingerprint); verify against the schema.  It then stamps
        ``retired_at`` on every other not-yet-retired key of the device.

        Args:
            guid: Device the key belongs to.
            fingerprint: Fingerprint of the newly recorded key.
            added_at: Timestamp, converted to UTC.  A naive value is treated
                as local time by ``datetime.astimezone``.
        """
        added_iso = added_at.astimezone(timezone.utc).isoformat()
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
                VALUES (?, ?, ?, ?)
                """,
                (str(uuid.uuid4()), guid.value, fingerprint.value, added_iso),
            )
            cur.execute(
                """
                UPDATE device_keys
                SET retired_at = ?
                WHERE guid = ?
                  AND ssl_key_fingerprint != ?
                  AND retired_at IS NULL
                """,
                (added_iso, guid.value, fingerprint.value),
            )
            conn.commit()

    def update_device_summary(
        self,
        *,
        hostname: Optional[str],
        last_seen: Optional[int] = None,
        agent_id: Optional[str] = None,
        operating_system: Optional[str] = None,
        last_user: Optional[str] = None,
    ) -> None:
        """Update summary columns on the device row matching ``hostname``.

        The row is matched case-insensitively by hostname.  If no row matches
        and a non-blank ``agent_id`` was supplied, the row with that (stripped)
        agent id is updated instead.

        Only usable values are written:
        * string fields are stripped and skipped when empty;
        * ``last_seen`` is skipped when it cannot be converted to ``int``.

        Nothing is written when ``hostname`` is blank or no field has a usable
        value.
        """
        normalized_hostname = (hostname or "").strip()
        if not normalized_hostname:
            return

        fields = []
        params = []

        if last_seen is not None:
            # Convert before appending anything, so a bad value can never
            # leave a placeholder without a matching binding.
            try:
                last_seen_value = int(last_seen)
            except (TypeError, ValueError, OverflowError):
                last_seen_value = None
            if last_seen_value is not None:
                fields.append("last_seen = ?")
                params.append(last_seen_value)

        agent_value = self._clean_text(agent_id)
        for column, value in (
            ("agent_id", agent_value),
            ("operating_system", self._clean_text(operating_system)),
            ("last_user", self._clean_text(last_user)),
        ):
            if value:
                # Column names come from the fixed tuple above and never
                # from caller input.
                fields.append(f"{column} = ?")
                params.append(value)

        if not fields:
            return

        assignments = ", ".join(fields)
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                f"UPDATE devices SET {assignments} WHERE LOWER(hostname) = LOWER(?)",
                (*params, normalized_hostname),
            )
            if cur.rowcount == 0 and agent_value:
                # Hostname unknown or changed: fall back to the agent id,
                # using the same normalized value the SET clause stores.
                cur.execute(
                    f"UPDATE devices SET {assignments} WHERE agent_id = ?",
                    (*params, agent_value),
                )
            conn.commit()

    @staticmethod
    def _clean_text(value: object) -> object:
        """Return ``value`` stripped if it supports ``strip()``.

        Returns ``None`` for falsy input.  Other non-strippable values are
        returned unchanged.
        """
        if not value:
            return None
        try:
            return value.strip()  # type: ignore[attr-defined]
        except AttributeError:
            return value

    def _row_to_record(self, row: tuple) -> Optional[DeviceRecord]:
        """Convert a ``(guid, fingerprint, token_version, status)`` row.

        Returns ``None``, after logging a warning, when the fingerprint is
        blank or the domain constructors reject the GUID or fingerprint.  The
        token version is coerced to ``int`` (0 when unparsable) and floored
        at 1.
        """
        try:
            guid = DeviceGuid(row[0])
            fingerprint_value = (row[1] or "").strip()
            if not fingerprint_value:
                self._log.warning(
                    "device row %s missing TLS fingerprint; skipping",
                    row[0],
                )
                return None
            fingerprint = DeviceFingerprint(fingerprint_value)
        except Exception as exc:
            self._log.warning("invalid device row for guid=%s: %s", row[0], exc)
            return None

        token_version = self._coerce_int(row[2] or 0, default=0)
        status = DeviceStatus.from_string(row[3])
        identity = DeviceIdentity(guid=guid, fingerprint=fingerprint)

        return DeviceRecord(
            identity=identity,
            token_version=max(token_version, 1),
            status=status,
        )

    @staticmethod
    def _coerce_int(value: object, *, default: int = 0) -> int:
        """Return ``int(value)``, or ``default`` if the conversion raises."""
        try:
            return int(value)
        except Exception:
            return default

    def _resolve_hostname(self, cur: sqlite3.Cursor, hostname: str, guid: DeviceGuid) -> str:
        """Pick a hostname for ``guid`` that no other device row uses.

        Candidates are tried in this order:
        1. the stripped ``hostname`` (or ``guid.value`` when it is blank);
        2. ``<base>-1`` through ``<base>-49``.

        A candidate already owned by ``guid`` itself is accepted.  If every
        candidate is taken, ``guid.value`` is returned.  Candidates never
        exceed ``_MAX_HOSTNAME_LENGTH`` characters, because the base is
        trimmed to leave room for the suffix.
        """
        limit = self._MAX_HOSTNAME_LENGTH
        base = ((hostname or "").strip() or guid.value)[:limit]
        candidate = base
        for suffix in range(1, 51):
            cur.execute(
                "SELECT guid FROM devices WHERE hostname = ?",
                (candidate,),
            )
            row = cur.fetchone()
            if not row:
                return candidate
            existing = (row[0] or "").strip().upper()
            if existing == guid.value:
                return candidate
            tail = f"-{suffix}"
            candidate = f"{base[: limit - len(tail)]}{tail}"
        return guid.value
|