mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 15:21:57 -06:00
727 lines
24 KiB
Python
727 lines
24 KiB
Python
"""SQLite-backed enrollment repository for Engine services."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from contextlib import closing
|
|
from datetime import datetime, timezone
|
|
from typing import Any, List, Optional, Tuple
|
|
|
|
from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid, normalize_guid
|
|
from Data.Engine.domain.device_enrollment import (
|
|
EnrollmentApproval,
|
|
EnrollmentApprovalStatus,
|
|
EnrollmentCode,
|
|
)
|
|
from Data.Engine.domain.enrollment_admin import (
|
|
DeviceApprovalRecord,
|
|
EnrollmentCodeRecord,
|
|
HostnameConflict,
|
|
)
|
|
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
|
|
|
# Names exported by ``from <module> import *``; the repository class is the
# only public symbol of this module.
__all__ = ["SQLiteEnrollmentRepository"]
|
|
|
|
|
|
class SQLiteEnrollmentRepository:
    """Persistence adapter that manages enrollment codes and approvals.

    Every operation obtains a connection from the injected factory, wraps it in
    ``contextlib.closing`` so it is closed when the operation finishes, and
    commits explicitly after writes.
    """

    # Columns read for install-code lookups consumed by
    # ``EnrollmentCode.from_mapping``; ``_row_to_mapping`` pairs them with row
    # values by position, so the order must match the SELECT list.
    _INSTALL_CODE_COLUMNS: Tuple[str, ...] = (
        "id",
        "code",
        "expires_at",
        "used_at",
        "used_by_guid",
        "max_uses",
        "use_count",
        "last_used_at",
    )

    # Columns read for admin-facing install-code records consumed by
    # ``EnrollmentCodeRecord.from_row`` (adds ``created_by_user_id``).
    _INSTALL_CODE_RECORD_COLUMNS: Tuple[str, ...] = (
        "id",
        "code",
        "expires_at",
        "created_by_user_id",
        "used_at",
        "used_by_guid",
        "max_uses",
        "use_count",
        "last_used_at",
    )

    # Columns read when loading a single approval for
    # ``EnrollmentApproval.from_mapping``.
    _APPROVAL_COLUMNS: Tuple[str, ...] = (
        "id",
        "approval_reference",
        "guid",
        "hostname_claimed",
        "ssl_key_fingerprint_claimed",
        "enrollment_code_id",
        "created_at",
        "updated_at",
        "status",
        "approved_by_user_id",
        "client_nonce",
        "server_nonce",
        "agent_pubkey_der",
    )

    # Keys for the rows produced by the ``list_device_approvals`` query, in
    # SELECT order.
    _APPROVAL_LIST_COLUMNS: Tuple[str, ...] = (
        "id",
        "approval_reference",
        "guid",
        "hostname_claimed",
        "ssl_key_fingerprint_claimed",
        "enrollment_code_id",
        "status",
        "client_nonce",
        "server_nonce",
        "created_at",
        "updated_at",
        "approved_by_user_id",
        "approved_by_username",
    )

    # Maximum length of a suggested alternate hostname (presumably the DNS
    # name length limit -- confirm).
    _MAX_HOSTNAME_LENGTH = 253

    # Highest numeric suffix ("<host>-N") tried when suggesting an alternate.
    _MAX_HOSTNAME_SUFFIX = 50

    def __init__(
        self,
        connection_factory: SQLiteConnectionFactory,
        *,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """Create the repository.

        Args:
            connection_factory: Callable returning a SQLite connection. It is
                invoked once per operation and the returned connection is
                closed when the operation finishes.
            logger: Logger for diagnostics; defaults to
                ``borealis.engine.repositories.enrollment``.
        """
        self._connections = connection_factory
        self._log = logger or logging.getLogger("borealis.engine.repositories.enrollment")

    # ------------------------------------------------------------------
    # Enrollment install codes
    # ------------------------------------------------------------------
    def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]:
        """Load an enrollment install code by its public value.

        Args:
            code: The install code; surrounding whitespace is ignored.

        Returns:
            The parsed ``EnrollmentCode``, or ``None`` when ``code`` is blank,
            no row matches, or the stored row cannot be parsed (logged).
        """
        code_value = (code or "").strip()
        if not code_value:
            return None

        record = self._select_one(
            "enrollment_install_codes",
            self._INSTALL_CODE_COLUMNS,
            "code = ?",
            (code_value,),
        )
        if record is None:
            return None

        try:
            return EnrollmentCode.from_mapping(record)
        except Exception as exc:  # pragma: no cover - defensive logging
            self._log.warning("invalid enrollment code record for code=%s: %s", code_value, exc)
            return None

    def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]:
        """Load an enrollment install code by its row id.

        Args:
            record_id: The install code's ``id``; surrounding whitespace is ignored.

        Returns:
            The parsed ``EnrollmentCode``, or ``None`` when the id is blank, no
            row matches, or the stored row cannot be parsed (logged).
        """
        record_value = (record_id or "").strip()
        if not record_value:
            return None

        record = self._select_one(
            "enrollment_install_codes",
            self._INSTALL_CODE_COLUMNS,
            "id = ?",
            (record_value,),
        )
        if record is None:
            return None

        try:
            return EnrollmentCode.from_mapping(record)
        except Exception as exc:  # pragma: no cover - defensive logging
            self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc)
            return None

    def list_install_codes(
        self,
        *,
        status: Optional[str] = None,
        now: Optional[datetime] = None,
    ) -> List[EnrollmentCodeRecord]:
        """List install codes for administration, optionally filtered by status.

        Args:
            status: Case-insensitive filter. ``"active"`` selects codes with
                uses remaining and ``expires_at`` after ``now``; ``"expired"``
                selects codes with uses remaining and ``expires_at`` at or
                before ``now``; ``"used"`` selects codes with
                ``use_count >= max_uses``. Any other value (including ``None``)
                applies no filter.
            now: Reference time for the expiry comparison; defaults to the
                current UTC time.

        Returns:
            Records ordered by ``expires_at`` ascending. Rows that fail to
            parse are logged and skipped.
        """
        reference = now or datetime.now(tz=timezone.utc)
        status_filter = (status or "").strip().lower()
        params: List[str] = []

        sql = (
            f"SELECT {', '.join(self._INSTALL_CODE_RECORD_COLUMNS)}"
            " FROM enrollment_install_codes"
        )

        # Expiry is compared as ISO-8601 text; this assumes stored values use
        # the same UTC format produced by ``_isoformat`` -- TODO confirm.
        if status_filter == "active":
            sql += " WHERE use_count < max_uses AND expires_at > ?"
            params.append(self._isoformat(reference))
        elif status_filter == "expired":
            sql += " WHERE use_count < max_uses AND expires_at <= ?"
            params.append(self._isoformat(reference))
        elif status_filter == "used":
            sql += " WHERE use_count >= max_uses"

        sql += " ORDER BY expires_at ASC"

        rows: List[EnrollmentCodeRecord] = []
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(sql, params)
            for raw in cur.fetchall():
                record = self._row_to_mapping(self._INSTALL_CODE_RECORD_COLUMNS, raw)
                try:
                    rows.append(EnrollmentCodeRecord.from_row(record))
                except Exception as exc:  # pragma: no cover - defensive logging
                    self._log.warning(
                        "invalid enrollment install code row id=%s: %s", record.get("id"), exc
                    )
        return rows

    def get_install_code_record(self, record_id: str) -> Optional[EnrollmentCodeRecord]:
        """Load the admin-facing record for one install code.

        Args:
            record_id: The install code's ``id``; surrounding whitespace is ignored.

        Returns:
            The parsed ``EnrollmentCodeRecord``, or ``None`` when the id is
            blank, no row matches, or the row cannot be parsed (logged).
        """
        identifier = (record_id or "").strip()
        if not identifier:
            return None

        payload = self._select_one(
            "enrollment_install_codes",
            self._INSTALL_CODE_RECORD_COLUMNS,
            "id = ?",
            (identifier,),
        )
        if payload is None:
            return None

        try:
            return EnrollmentCodeRecord.from_row(payload)
        except Exception as exc:  # pragma: no cover - defensive logging
            self._log.warning("invalid enrollment install code record id=%s: %s", identifier, exc)
            return None

    def insert_install_code(
        self,
        *,
        record_id: str,
        code: str,
        expires_at: datetime,
        created_by: Optional[str],
        max_uses: int,
    ) -> EnrollmentCodeRecord:
        """Insert a new install code with ``use_count`` 0 and return its record.

        Raises:
            RuntimeError: If the row cannot be read back (or parsed) after the
                insert.
        """
        expires_iso = self._isoformat(expires_at)

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                INSERT INTO enrollment_install_codes (
                    id,
                    code,
                    expires_at,
                    created_by_user_id,
                    max_uses,
                    use_count
                ) VALUES (?, ?, ?, ?, ?, 0)
                """,
                (record_id, code, expires_iso, created_by, max_uses),
            )
            conn.commit()

        record = self.get_install_code_record(record_id)
        if record is None:
            raise RuntimeError("failed to load install code after insert")
        return record

    def delete_install_code_if_unused(self, record_id: str) -> bool:
        """Delete an install code only if it has never been used.

        Returns:
            ``True`` when a row with ``use_count = 0`` was deleted; ``False``
            for a blank id, a missing row, or a code that has been used.
        """
        identifier = (record_id or "").strip()
        if not identifier:
            return False

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                "DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
                (identifier,),
            )
            deleted = cur.rowcount > 0
            conn.commit()
        return deleted

    def update_install_code_usage(
        self,
        record_id: str,
        *,
        use_count_increment: int,
        last_used_at: datetime,
        used_by_guid: Optional[DeviceGuid] = None,
        mark_first_use: bool = False,
    ) -> None:
        """Increment usage counters and usage metadata for an install code.

        ``used_by_guid`` overwrites the stored GUID only when provided. When
        ``mark_first_use`` is set, ``used_at`` is populated with
        ``last_used_at`` only if it is still NULL.

        Raises:
            ValueError: If ``use_count_increment`` is not positive.
        """
        if use_count_increment <= 0:
            raise ValueError("use_count_increment must be positive")

        last_used_iso = self._isoformat(last_used_at)
        # An empty string becomes NULL via NULLIF, letting COALESCE keep the
        # existing value when no GUID was supplied.
        guid_value = used_by_guid.value if used_by_guid else ""
        mark_flag = 1 if mark_first_use else 0

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                UPDATE enrollment_install_codes
                   SET use_count = use_count + ?,
                       last_used_at = ?,
                       used_by_guid = COALESCE(NULLIF(?, ''), used_by_guid),
                       used_at = CASE WHEN ? = 1 AND used_at IS NULL THEN ? ELSE used_at END
                 WHERE id = ?
                """,
                (
                    use_count_increment,
                    last_used_iso,
                    guid_value,
                    mark_flag,
                    last_used_iso,
                    record_id,
                ),
            )
            conn.commit()

    # ------------------------------------------------------------------
    # Device approvals
    # ------------------------------------------------------------------
    def list_device_approvals(
        self,
        *,
        status: Optional[str] = None,
    ) -> List[DeviceApprovalRecord]:
        """List device approvals with approver name and hostname-conflict details.

        Args:
            status: Case-insensitive status to match; ``None``, ``""``,
                ``"all"`` or ``"*"`` return every approval.

        Returns:
            Approvals ordered by ``created_at`` ascending. Each claimed hostname
            is checked against ``devices``; when a conflict requires an operator
            prompt, an alternate hostname is suggested. Rows that fail to parse
            are logged and skipped.
        """
        status_filter = (status or "").strip().lower()
        params: List[str] = []

        # The approver may be stored either as a user id or as a username. A
        # scalar subquery (instead of a LEFT JOIN with an OR condition) keeps
        # exactly one output row per approval even when several users match;
        # an id match is preferred over a username match.
        sql = """
            SELECT
                da.id,
                da.approval_reference,
                da.guid,
                da.hostname_claimed,
                da.ssl_key_fingerprint_claimed,
                da.enrollment_code_id,
                da.status,
                da.client_nonce,
                da.server_nonce,
                da.created_at,
                da.updated_at,
                da.approved_by_user_id,
                (
                    SELECT u.username
                      FROM users AS u
                     WHERE CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT)
                        OR LOWER(da.approved_by_user_id) = LOWER(u.username)
                     ORDER BY CASE
                                  WHEN CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT)
                                  THEN 0
                                  ELSE 1
                              END
                     LIMIT 1
                ) AS approved_by_username
              FROM device_approvals AS da
        """

        if status_filter and status_filter not in {"all", "*"}:
            sql += " WHERE LOWER(da.status) = ?"
            params.append(status_filter)

        sql += " ORDER BY da.created_at ASC"

        approvals: List[DeviceApprovalRecord] = []
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(sql, params)
            rows = cur.fetchall()

            # The conflict helpers reuse ``conn``, so this loop must stay inside
            # the ``with`` block (the connection is closed on exit).
            for raw in rows:
                record = self._row_to_mapping(self._APPROVAL_LIST_COLUMNS, raw)

                conflict, fingerprint_match, requires_prompt = self._compute_hostname_conflict(
                    conn,
                    record.get("hostname_claimed"),
                    record.get("guid"),
                    record.get("ssl_key_fingerprint_claimed") or "",
                )

                alternate = None
                if conflict and requires_prompt:
                    alternate = self._suggest_alternate_hostname(
                        conn,
                        record.get("hostname_claimed"),
                        record.get("guid"),
                    )

                try:
                    approvals.append(
                        DeviceApprovalRecord.from_row(
                            record,
                            conflict=conflict,
                            alternate_hostname=alternate,
                            fingerprint_match=fingerprint_match,
                            requires_prompt=requires_prompt,
                        )
                    )
                except Exception as exc:  # pragma: no cover - defensive logging
                    self._log.warning(
                        "invalid device approval record id=%s: %s",
                        record.get("id"),
                        exc,
                    )

        return approvals

    def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:
        """Load a device approval using its operator-visible reference.

        Returns ``None`` for a blank reference, a missing row, or an
        unparsable row (logged).
        """
        ref_value = (reference or "").strip()
        if not ref_value:
            return None
        return self._fetch_device_approval("approval_reference = ?", (ref_value,))

    def fetch_device_approval(self, record_id: str) -> Optional[EnrollmentApproval]:
        """Load a device approval by its row id.

        Returns ``None`` for a blank id, a missing row, or an unparsable row
        (logged).
        """
        record_value = (record_id or "").strip()
        if not record_value:
            return None
        return self._fetch_device_approval("id = ?", (record_value,))

    def fetch_pending_approval_by_fingerprint(
        self, fingerprint: DeviceFingerprint
    ) -> Optional[EnrollmentApproval]:
        """Load a pending approval whose claimed key fingerprint matches.

        If several pending rows match, the first row SQLite returns is used
        (no ORDER BY is applied).
        """
        return self._fetch_device_approval(
            "ssl_key_fingerprint_claimed = ? AND status = 'pending'",
            (fingerprint.value,),
        )

    def update_pending_approval(
        self,
        record_id: str,
        *,
        hostname: str,
        guid: Optional[DeviceGuid],
        enrollment_code_id: Optional[str],
        client_nonce_b64: str,
        server_nonce_b64: str,
        agent_pubkey_der: bytes,
        updated_at: datetime,
    ) -> None:
        """Overwrite the claim details of an existing approval row.

        Unlike ``update_device_approval_status``, a ``None`` guid is written as
        NULL rather than preserving the stored value.
        """
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                UPDATE device_approvals
                   SET hostname_claimed = ?,
                       guid = ?,
                       enrollment_code_id = ?,
                       client_nonce = ?,
                       server_nonce = ?,
                       agent_pubkey_der = ?,
                       updated_at = ?
                 WHERE id = ?
                """,
                (
                    hostname,
                    guid.value if guid else None,
                    enrollment_code_id,
                    client_nonce_b64,
                    server_nonce_b64,
                    agent_pubkey_der,
                    self._isoformat(updated_at),
                    record_id,
                ),
            )
            conn.commit()

    def create_device_approval(
        self,
        *,
        record_id: str,
        reference: str,
        claimed_hostname: str,
        claimed_fingerprint: DeviceFingerprint,
        enrollment_code_id: Optional[str],
        client_nonce_b64: str,
        server_nonce_b64: str,
        agent_pubkey_der: bytes,
        created_at: datetime,
        status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
        guid: Optional[DeviceGuid] = None,
    ) -> EnrollmentApproval:
        """Insert a new device approval and return it as loaded from the DB.

        ``created_at`` is stored as both ``created_at`` and ``updated_at``.

        Raises:
            RuntimeError: If the row cannot be read back (or parsed) after the
                insert.
        """
        created_iso = self._isoformat(created_at)
        guid_value = guid.value if guid else None

        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                INSERT INTO device_approvals (
                    id,
                    approval_reference,
                    guid,
                    hostname_claimed,
                    ssl_key_fingerprint_claimed,
                    enrollment_code_id,
                    status,
                    created_at,
                    updated_at,
                    client_nonce,
                    server_nonce,
                    agent_pubkey_der
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    record_id,
                    reference,
                    guid_value,
                    claimed_hostname,
                    claimed_fingerprint.value,
                    enrollment_code_id,
                    status.value,
                    created_iso,
                    created_iso,
                    client_nonce_b64,
                    server_nonce_b64,
                    agent_pubkey_der,
                ),
            )
            conn.commit()

        approval = self.fetch_device_approval(record_id)
        if approval is None:
            raise RuntimeError("failed to load device approval after insert")
        return approval

    def update_device_approval_status(
        self,
        record_id: str,
        *,
        status: EnrollmentApprovalStatus,
        updated_at: datetime,
        approved_by: Optional[str] = None,
        guid: Optional[DeviceGuid] = None,
    ) -> None:
        """Transition an approval to a new status.

        ``guid`` and ``approved_by`` overwrite the stored values only when they
        are not ``None`` (via COALESCE).
        """
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                UPDATE device_approvals
                   SET status = ?,
                       updated_at = ?,
                       guid = COALESCE(?, guid),
                       approved_by_user_id = COALESCE(?, approved_by_user_id)
                 WHERE id = ?
                """,
                (
                    status.value,
                    self._isoformat(updated_at),
                    guid.value if guid else None,
                    approved_by,
                    record_id,
                ),
            )
            conn.commit()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _row_to_mapping(columns: Tuple[str, ...], row: Any) -> dict[str, Any]:
        """Pair each column name with the value at the same index of ``row``."""
        return {name: row[index] for index, name in enumerate(columns)}

    def _select_one(
        self,
        table: str,
        columns: Tuple[str, ...],
        where: str,
        params: tuple,
    ) -> Optional[dict[str, Any]]:
        """Return the first row of ``table`` matching ``where`` as a dict.

        ``table``, ``columns`` and ``where`` are interpolated into the SQL
        text, so they must only ever be internal constants; all values must
        be passed through ``params``.

        Returns:
            A column-name-to-value dict, or ``None`` when no row matches.
        """
        sql = f"SELECT {', '.join(columns)} FROM {table} WHERE {where}"
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(sql, params)
            row = cur.fetchone()

        if not row:
            return None
        return self._row_to_mapping(columns, row)

    def _fetch_device_approval(self, where: str, params: tuple) -> Optional[EnrollmentApproval]:
        """Load the first device approval matching ``where``.

        ``where`` is interpolated into SQL and must be an internal constant;
        values go in ``params``. Unparsable rows are logged and yield ``None``.
        """
        record = self._select_one("device_approvals", self._APPROVAL_COLUMNS, where, params)
        if record is None:
            return None

        try:
            return EnrollmentApproval.from_mapping(record)
        except Exception as exc:  # pragma: no cover - defensive logging
            self._log.warning(
                "invalid device approval record id=%s reference=%s: %s",
                record["id"],
                record["approval_reference"],
                exc,
            )
            return None

    def _compute_hostname_conflict(
        self,
        conn,
        hostname: Optional[str],
        pending_guid: Optional[str],
        claimed_fp: str,
    ) -> Tuple[Optional[HostnameConflict], bool, bool]:
        """Check whether a claimed hostname is already used by another device.

        Args:
            conn: Open connection used for the lookup.
            hostname: Hostname claimed by the pending approval.
            pending_guid: GUID attached to the pending approval, if any.
            claimed_fp: Key fingerprint claimed by the pending approval.

        Returns:
            ``(conflict, fingerprint_match, requires_prompt)``. ``conflict`` is
            ``None`` (and both flags ``False``) when the hostname is blank, no
            device uses it, the lookup fails (logged), or the existing device
            has the same GUID. Fingerprints are compared trimmed and
            case-insensitively; a prompt is required when they do not match.
        """
        normalized_host = (hostname or "").strip()
        if not normalized_host:
            return None, False, False

        try:
            cur = conn.cursor()
            cur.execute(
                """
                SELECT d.guid,
                       d.ssl_key_fingerprint,
                       ds.site_id,
                       s.name
                  FROM devices AS d
                  LEFT JOIN device_sites AS ds ON ds.device_hostname = d.hostname
                  LEFT JOIN sites AS s ON s.id = ds.site_id
                 WHERE d.hostname = ?
                """,
                (normalized_host,),
            )
            row = cur.fetchone()
        except Exception as exc:  # pragma: no cover - defensive logging
            self._log.warning("failed to inspect hostname conflict for %s: %s", normalized_host, exc)
            return None, False, False

        if not row:
            return None, False, False

        existing_guid = normalize_guid(row[0])
        pending_norm = normalize_guid(pending_guid)
        if existing_guid and pending_norm and existing_guid == pending_norm:
            # Same GUID: the device already owns this hostname, so it is not
            # treated as a conflict.
            return None, False, False

        stored_fp = (row[1] or "").strip().lower()
        claimed_fp_normalized = (claimed_fp or "").strip().lower()
        fingerprint_match = bool(
            stored_fp and claimed_fp_normalized and stored_fp == claimed_fp_normalized
        )

        site_id = None
        if row[2] is not None:
            try:
                site_id = int(row[2])
            except (TypeError, ValueError):  # pragma: no cover - defensive
                site_id = None

        site_name = str(row[3] or "").strip()
        requires_prompt = not fingerprint_match

        conflict = HostnameConflict(
            guid=existing_guid or None,
            ssl_key_fingerprint=stored_fp or None,
            site_id=site_id,
            site_name=site_name,
            fingerprint_match=fingerprint_match,
            requires_prompt=requires_prompt,
        )

        return conflict, fingerprint_match, requires_prompt

    def _suggest_alternate_hostname(
        self,
        conn,
        hostname: Optional[str],
        pending_guid: Optional[str],
    ) -> Optional[str]:
        """Suggest a hostname not already claimed by a different device.

        Tries the hostname itself (truncated to ``_MAX_HOSTNAME_LENGTH``), then
        ``<host>-1`` through ``<host>-<_MAX_HOSTNAME_SUFFIX>``. A candidate is
        acceptable when no device uses it or the device using it has
        ``pending_guid``. Suffixed candidates shorten the base so the full
        name never exceeds ``_MAX_HOSTNAME_LENGTH``.

        Returns:
            ``None`` for a blank hostname; otherwise the first acceptable
            candidate. If every candidate is taken, the normalized pending GUID
            or, failing that, the last candidate tried.
        """
        base = (hostname or "").strip()
        if not base:
            return None
        base = base[: self._MAX_HOSTNAME_LENGTH]
        pending_norm = normalize_guid(pending_guid)
        cur = conn.cursor()

        def _is_available(candidate: str) -> bool:
            # Free, or already owned by the device that is enrolling.
            cur.execute("SELECT guid FROM devices WHERE hostname = ?", (candidate,))
            row = cur.fetchone()
            if not row:
                return True
            return bool(pending_norm) and normalize_guid(row[0]) == pending_norm

        if _is_available(base):
            return base

        candidate = base
        # Every suffix up to and including the maximum is actually checked.
        for suffix in range(1, self._MAX_HOSTNAME_SUFFIX + 1):
            tail = f"-{suffix}"
            candidate = base[: self._MAX_HOSTNAME_LENGTH - len(tail)] + tail
            if _is_available(candidate):
                return candidate

        # All candidates are taken; keep the historical fallback value.
        return pending_norm or candidate

    @staticmethod
    def _isoformat(value: datetime) -> str:
        """Render ``value`` as an ISO-8601 string in UTC.

        Naive datetimes are assumed to already be UTC.
        """
        if value.tzinfo is None:
            value = value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc).isoformat()