mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 22:01:59 -06:00
Implement admin enrollment APIs
This commit is contained in:
@@ -5,14 +5,19 @@ from __future__ import annotations
|
||||
import logging
|
||||
from contextlib import closing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid
|
||||
from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid, normalize_guid
|
||||
from Data.Engine.domain.device_enrollment import (
|
||||
EnrollmentApproval,
|
||||
EnrollmentApprovalStatus,
|
||||
EnrollmentCode,
|
||||
)
|
||||
from Data.Engine.domain.enrollment_admin import (
|
||||
DeviceApprovalRecord,
|
||||
EnrollmentCodeRecord,
|
||||
HostnameConflict,
|
||||
)
|
||||
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||
|
||||
__all__ = ["SQLiteEnrollmentRepository"]
|
||||
@@ -122,6 +127,158 @@ class SQLiteEnrollmentRepository:
|
||||
self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc)
|
||||
return None
|
||||
|
||||
def list_install_codes(
|
||||
self,
|
||||
*,
|
||||
status: Optional[str] = None,
|
||||
now: Optional[datetime] = None,
|
||||
) -> List[EnrollmentCodeRecord]:
|
||||
reference = now or datetime.now(tz=timezone.utc)
|
||||
status_filter = (status or "").strip().lower()
|
||||
params: List[str] = []
|
||||
|
||||
sql = """
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
created_by_user_id,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
"""
|
||||
|
||||
if status_filter in {"active", "expired", "used"}:
|
||||
sql += " WHERE "
|
||||
if status_filter == "active":
|
||||
sql += "use_count < max_uses AND expires_at > ?"
|
||||
params.append(self._isoformat(reference))
|
||||
elif status_filter == "expired":
|
||||
sql += "use_count < max_uses AND expires_at <= ?"
|
||||
params.append(self._isoformat(reference))
|
||||
else: # used
|
||||
sql += "use_count >= max_uses"
|
||||
|
||||
sql += " ORDER BY expires_at ASC"
|
||||
|
||||
rows: List[EnrollmentCodeRecord] = []
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql, params)
|
||||
for raw in cur.fetchall():
|
||||
record = {
|
||||
"id": raw[0],
|
||||
"code": raw[1],
|
||||
"expires_at": raw[2],
|
||||
"created_by_user_id": raw[3],
|
||||
"used_at": raw[4],
|
||||
"used_by_guid": raw[5],
|
||||
"max_uses": raw[6],
|
||||
"use_count": raw[7],
|
||||
"last_used_at": raw[8],
|
||||
}
|
||||
try:
|
||||
rows.append(EnrollmentCodeRecord.from_row(record))
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("invalid enrollment install code row id=%s: %s", record.get("id"), exc)
|
||||
return rows
|
||||
|
||||
def get_install_code_record(self, record_id: str) -> Optional[EnrollmentCodeRecord]:
|
||||
identifier = (record_id or "").strip()
|
||||
if not identifier:
|
||||
return None
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
created_by_user_id,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
WHERE id = ?
|
||||
""",
|
||||
(identifier,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
payload = {
|
||||
"id": row[0],
|
||||
"code": row[1],
|
||||
"expires_at": row[2],
|
||||
"created_by_user_id": row[3],
|
||||
"used_at": row[4],
|
||||
"used_by_guid": row[5],
|
||||
"max_uses": row[6],
|
||||
"use_count": row[7],
|
||||
"last_used_at": row[8],
|
||||
}
|
||||
|
||||
try:
|
||||
return EnrollmentCodeRecord.from_row(payload)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("invalid enrollment install code record id=%s: %s", identifier, exc)
|
||||
return None
|
||||
|
||||
def insert_install_code(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
code: str,
|
||||
expires_at: datetime,
|
||||
created_by: Optional[str],
|
||||
max_uses: int,
|
||||
) -> EnrollmentCodeRecord:
|
||||
expires_iso = self._isoformat(expires_at)
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO enrollment_install_codes (
|
||||
id,
|
||||
code,
|
||||
expires_at,
|
||||
created_by_user_id,
|
||||
max_uses,
|
||||
use_count
|
||||
) VALUES (?, ?, ?, ?, ?, 0)
|
||||
""",
|
||||
(record_id, code, expires_iso, created_by, max_uses),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
record = self.get_install_code_record(record_id)
|
||||
if record is None:
|
||||
raise RuntimeError("failed to load install code after insert")
|
||||
return record
|
||||
|
||||
def delete_install_code_if_unused(self, record_id: str) -> bool:
|
||||
identifier = (record_id or "").strip()
|
||||
if not identifier:
|
||||
return False
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
|
||||
(identifier,),
|
||||
)
|
||||
deleted = cur.rowcount > 0
|
||||
conn.commit()
|
||||
return deleted
|
||||
|
||||
def update_install_code_usage(
|
||||
self,
|
||||
record_id: str,
|
||||
@@ -165,6 +322,100 @@ class SQLiteEnrollmentRepository:
|
||||
# ------------------------------------------------------------------
|
||||
# Device approvals
|
||||
# ------------------------------------------------------------------
|
||||
def list_device_approvals(
|
||||
self,
|
||||
*,
|
||||
status: Optional[str] = None,
|
||||
) -> List[DeviceApprovalRecord]:
|
||||
status_filter = (status or "").strip().lower()
|
||||
params: List[str] = []
|
||||
|
||||
sql = """
|
||||
SELECT
|
||||
da.id,
|
||||
da.approval_reference,
|
||||
da.guid,
|
||||
da.hostname_claimed,
|
||||
da.ssl_key_fingerprint_claimed,
|
||||
da.enrollment_code_id,
|
||||
da.status,
|
||||
da.client_nonce,
|
||||
da.server_nonce,
|
||||
da.created_at,
|
||||
da.updated_at,
|
||||
da.approved_by_user_id,
|
||||
u.username AS approved_by_username
|
||||
FROM device_approvals AS da
|
||||
LEFT JOIN users AS u
|
||||
ON (
|
||||
CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT)
|
||||
OR LOWER(da.approved_by_user_id) = LOWER(u.username)
|
||||
)
|
||||
"""
|
||||
|
||||
if status_filter and status_filter not in {"all", "*"}:
|
||||
sql += " WHERE LOWER(da.status) = ?"
|
||||
params.append(status_filter)
|
||||
|
||||
sql += " ORDER BY da.created_at ASC"
|
||||
|
||||
approvals: List[DeviceApprovalRecord] = []
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql, params)
|
||||
rows = cur.fetchall()
|
||||
|
||||
for raw in rows:
|
||||
record = {
|
||||
"id": raw[0],
|
||||
"approval_reference": raw[1],
|
||||
"guid": raw[2],
|
||||
"hostname_claimed": raw[3],
|
||||
"ssl_key_fingerprint_claimed": raw[4],
|
||||
"enrollment_code_id": raw[5],
|
||||
"status": raw[6],
|
||||
"client_nonce": raw[7],
|
||||
"server_nonce": raw[8],
|
||||
"created_at": raw[9],
|
||||
"updated_at": raw[10],
|
||||
"approved_by_user_id": raw[11],
|
||||
"approved_by_username": raw[12],
|
||||
}
|
||||
|
||||
conflict, fingerprint_match, requires_prompt = self._compute_hostname_conflict(
|
||||
conn,
|
||||
record.get("hostname_claimed"),
|
||||
record.get("guid"),
|
||||
record.get("ssl_key_fingerprint_claimed") or "",
|
||||
)
|
||||
|
||||
alternate = None
|
||||
if conflict and requires_prompt:
|
||||
alternate = self._suggest_alternate_hostname(
|
||||
conn,
|
||||
record.get("hostname_claimed"),
|
||||
record.get("guid"),
|
||||
)
|
||||
|
||||
try:
|
||||
approvals.append(
|
||||
DeviceApprovalRecord.from_row(
|
||||
record,
|
||||
conflict=conflict,
|
||||
alternate_hostname=alternate,
|
||||
fingerprint_match=fingerprint_match,
|
||||
requires_prompt=requires_prompt,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning(
|
||||
"invalid device approval record id=%s: %s",
|
||||
record.get("id"),
|
||||
exc,
|
||||
)
|
||||
|
||||
return approvals
|
||||
|
||||
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:
|
||||
"""Load a device approval using its operator-visible reference."""
|
||||
|
||||
@@ -376,6 +627,98 @@ class SQLiteEnrollmentRepository:
|
||||
)
|
||||
return None
|
||||
|
||||
def _compute_hostname_conflict(
|
||||
self,
|
||||
conn,
|
||||
hostname: Optional[str],
|
||||
pending_guid: Optional[str],
|
||||
claimed_fp: str,
|
||||
) -> Tuple[Optional[HostnameConflict], bool, bool]:
|
||||
normalized_host = (hostname or "").strip()
|
||||
if not normalized_host:
|
||||
return None, False, False
|
||||
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT d.guid,
|
||||
d.ssl_key_fingerprint,
|
||||
ds.site_id,
|
||||
s.name
|
||||
FROM devices AS d
|
||||
LEFT JOIN device_sites AS ds ON ds.device_hostname = d.hostname
|
||||
LEFT JOIN sites AS s ON s.id = ds.site_id
|
||||
WHERE d.hostname = ?
|
||||
""",
|
||||
(normalized_host,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("failed to inspect hostname conflict for %s: %s", normalized_host, exc)
|
||||
return None, False, False
|
||||
|
||||
if not row:
|
||||
return None, False, False
|
||||
|
||||
existing_guid = normalize_guid(row[0])
|
||||
pending_norm = normalize_guid(pending_guid)
|
||||
if existing_guid and pending_norm and existing_guid == pending_norm:
|
||||
return None, False, False
|
||||
|
||||
stored_fp = (row[1] or "").strip().lower()
|
||||
claimed_fp_normalized = (claimed_fp or "").strip().lower()
|
||||
fingerprint_match = bool(stored_fp and claimed_fp_normalized and stored_fp == claimed_fp_normalized)
|
||||
|
||||
site_id = None
|
||||
if row[2] is not None:
|
||||
try:
|
||||
site_id = int(row[2])
|
||||
except (TypeError, ValueError): # pragma: no cover - defensive
|
||||
site_id = None
|
||||
|
||||
site_name = str(row[3] or "").strip()
|
||||
requires_prompt = not fingerprint_match
|
||||
|
||||
conflict = HostnameConflict(
|
||||
guid=existing_guid or None,
|
||||
ssl_key_fingerprint=stored_fp or None,
|
||||
site_id=site_id,
|
||||
site_name=site_name,
|
||||
fingerprint_match=fingerprint_match,
|
||||
requires_prompt=requires_prompt,
|
||||
)
|
||||
|
||||
return conflict, fingerprint_match, requires_prompt
|
||||
|
||||
def _suggest_alternate_hostname(
|
||||
self,
|
||||
conn,
|
||||
hostname: Optional[str],
|
||||
pending_guid: Optional[str],
|
||||
) -> Optional[str]:
|
||||
base = (hostname or "").strip()
|
||||
if not base:
|
||||
return None
|
||||
base = base[:253]
|
||||
candidate = base
|
||||
pending_norm = normalize_guid(pending_guid)
|
||||
suffix = 1
|
||||
|
||||
cur = conn.cursor()
|
||||
while True:
|
||||
cur.execute("SELECT guid FROM devices WHERE hostname = ?", (candidate,))
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return candidate
|
||||
existing_guid = normalize_guid(row[0])
|
||||
if pending_norm and existing_guid == pending_norm:
|
||||
return candidate
|
||||
candidate = f"{base}-{suffix}"
|
||||
suffix += 1
|
||||
if suffix > 50:
|
||||
return pending_norm or candidate
|
||||
|
||||
@staticmethod
|
||||
def _isoformat(value: datetime) -> str:
|
||||
if value.tzinfo is None:
|
||||
|
||||
@@ -31,6 +31,9 @@ def apply_all(conn: sqlite3.Connection) -> None:
|
||||
_ensure_refresh_token_table(conn)
|
||||
_ensure_install_code_table(conn)
|
||||
_ensure_device_approval_table(conn)
|
||||
_ensure_device_list_views_table(conn)
|
||||
_ensure_sites_tables(conn)
|
||||
_ensure_credentials_table(conn)
|
||||
_ensure_github_token_table(conn)
|
||||
_ensure_scheduled_jobs_table(conn)
|
||||
_ensure_scheduled_job_run_tables(conn)
|
||||
@@ -233,6 +236,73 @@ def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _ensure_device_list_views_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS device_list_views (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT UNIQUE NOT NULL,
|
||||
columns_json TEXT NOT NULL,
|
||||
filters_json TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_sites_tables(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS sites (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT UNIQUE NOT NULL,
|
||||
description TEXT,
|
||||
created_at INTEGER
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS device_sites (
|
||||
device_hostname TEXT UNIQUE NOT NULL,
|
||||
site_id INTEGER NOT NULL,
|
||||
assigned_at INTEGER,
|
||||
FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE CASCADE
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_credentials_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS credentials (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
description TEXT,
|
||||
site_id INTEGER,
|
||||
credential_type TEXT NOT NULL DEFAULT 'machine',
|
||||
connection_type TEXT NOT NULL DEFAULT 'ssh',
|
||||
username TEXT,
|
||||
password_encrypted BLOB,
|
||||
private_key_encrypted BLOB,
|
||||
private_key_passphrase_encrypted BLOB,
|
||||
become_method TEXT,
|
||||
become_username TEXT,
|
||||
become_password_encrypted BLOB,
|
||||
metadata_json TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE SET NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_github_token_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
|
||||
@@ -71,6 +71,57 @@ class SQLiteUserRepository:
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def resolve_identifier(self, username: str) -> Optional[str]:
|
||||
normalized = (username or "").strip()
|
||||
if not normalized:
|
||||
return None
|
||||
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"SELECT id FROM users WHERE LOWER(username) = LOWER(?)",
|
||||
(normalized,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
return str(row[0]) if row[0] is not None else None
|
||||
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||
self._log.error("failed to resolve identifier for %s: %s", username, exc)
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def username_for_identifier(self, identifier: str) -> Optional[str]:
|
||||
token = (identifier or "").strip()
|
||||
if not token:
|
||||
return None
|
||||
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT username
|
||||
FROM users
|
||||
WHERE CAST(id AS TEXT) = ?
|
||||
OR LOWER(username) = LOWER(?)
|
||||
LIMIT 1
|
||||
""",
|
||||
(token, token),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
username = str(row[0] or "").strip()
|
||||
return username or None
|
||||
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||
self._log.error("failed to resolve username for %s: %s", identifier, exc)
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def list_accounts(self) -> list[OperatorAccount]:
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user