Implement admin enrollment APIs

This commit is contained in:
2025-10-22 23:26:06 -06:00
parent b8e3ea2a62
commit d0fa6929b2
12 changed files with 1182 additions and 18 deletions

View File

@@ -18,6 +18,7 @@ __all__ = [
"AccessTokenClaims", "AccessTokenClaims",
"DeviceAuthContext", "DeviceAuthContext",
"sanitize_service_context", "sanitize_service_context",
"normalize_guid",
] ]
@@ -73,6 +74,12 @@ class DeviceGuid:
return self.value return self.value
def normalize_guid(value: Optional[str]) -> str:
"""Expose GUID normalization for administrative helpers."""
return _normalize_guid(value)
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)
class DeviceFingerprint: class DeviceFingerprint:
"""Normalized TLS key fingerprint associated with a device.""" """Normalized TLS key fingerprint associated with a device."""

View File

@@ -0,0 +1,206 @@
"""Administrative enrollment domain models."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Mapping, Optional
from Data.Engine.domain.device_auth import DeviceGuid, normalize_guid
__all__ = [
"EnrollmentCodeRecord",
"DeviceApprovalRecord",
"HostnameConflict",
]
def _parse_iso8601(value: Optional[str]) -> Optional[datetime]:
if not value:
return None
raw = str(value).strip()
if not raw:
return None
try:
dt = datetime.fromisoformat(raw)
except Exception as exc: # pragma: no cover - defensive parsing
raise ValueError(f"invalid ISO8601 timestamp: {raw}") from exc
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt.astimezone(timezone.utc)
def _isoformat(value: Optional[datetime]) -> Optional[str]:
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc).isoformat()
@dataclass(frozen=True, slots=True)
class EnrollmentCodeRecord:
"""Installer code metadata exposed to administrative clients."""
record_id: str
code: str
expires_at: datetime
max_uses: int
use_count: int
created_by_user_id: Optional[str]
used_at: Optional[datetime]
used_by_guid: Optional[DeviceGuid]
last_used_at: Optional[datetime]
@classmethod
def from_row(cls, row: Mapping[str, Any]) -> "EnrollmentCodeRecord":
record_id = str(row.get("id") or "").strip()
code = str(row.get("code") or "").strip()
if not record_id or not code:
raise ValueError("invalid enrollment install code record")
used_by = row.get("used_by_guid")
used_by_guid = DeviceGuid(str(used_by)) if used_by else None
return cls(
record_id=record_id,
code=code,
expires_at=_parse_iso8601(row.get("expires_at")) or datetime.now(tz=timezone.utc),
max_uses=int(row.get("max_uses") or 1),
use_count=int(row.get("use_count") or 0),
created_by_user_id=str(row.get("created_by_user_id") or "").strip() or None,
used_at=_parse_iso8601(row.get("used_at")),
used_by_guid=used_by_guid,
last_used_at=_parse_iso8601(row.get("last_used_at")),
)
def status(self, *, now: Optional[datetime] = None) -> str:
reference = now or datetime.now(tz=timezone.utc)
if self.use_count >= self.max_uses:
return "used"
if self.expires_at <= reference:
return "expired"
return "active"
def to_dict(self) -> dict[str, Any]:
return {
"id": self.record_id,
"code": self.code,
"expires_at": _isoformat(self.expires_at),
"max_uses": self.max_uses,
"use_count": self.use_count,
"created_by_user_id": self.created_by_user_id,
"used_at": _isoformat(self.used_at),
"used_by_guid": self.used_by_guid.value if self.used_by_guid else None,
"last_used_at": _isoformat(self.last_used_at),
"status": self.status(),
}
@dataclass(frozen=True, slots=True)
class HostnameConflict:
"""Existing device details colliding with a pending approval."""
guid: Optional[str]
ssl_key_fingerprint: Optional[str]
site_id: Optional[int]
site_name: str
fingerprint_match: bool
requires_prompt: bool
def to_dict(self) -> dict[str, Any]:
return {
"guid": self.guid,
"ssl_key_fingerprint": self.ssl_key_fingerprint,
"site_id": self.site_id,
"site_name": self.site_name,
"fingerprint_match": self.fingerprint_match,
"requires_prompt": self.requires_prompt,
}
@dataclass(frozen=True, slots=True)
class DeviceApprovalRecord:
"""Administrative projection of a device approval entry."""
record_id: str
reference: str
status: str
claimed_hostname: str
claimed_fingerprint: str
created_at: datetime
updated_at: datetime
enrollment_code_id: Optional[str]
guid: Optional[str]
approved_by_user_id: Optional[str]
approved_by_username: Optional[str]
client_nonce: str
server_nonce: str
hostname_conflict: Optional[HostnameConflict]
alternate_hostname: Optional[str]
conflict_requires_prompt: bool
fingerprint_match: bool
@classmethod
def from_row(
cls,
row: Mapping[str, Any],
*,
conflict: Optional[HostnameConflict] = None,
alternate_hostname: Optional[str] = None,
fingerprint_match: bool = False,
requires_prompt: bool = False,
) -> "DeviceApprovalRecord":
record_id = str(row.get("id") or "").strip()
reference = str(row.get("approval_reference") or "").strip()
hostname = str(row.get("hostname_claimed") or "").strip()
fingerprint = str(row.get("ssl_key_fingerprint_claimed") or "").strip().lower()
if not record_id or not reference or not hostname or not fingerprint:
raise ValueError("invalid device approval record")
guid_raw = normalize_guid(row.get("guid")) or None
return cls(
record_id=record_id,
reference=reference,
status=str(row.get("status") or "pending").strip().lower(),
claimed_hostname=hostname,
claimed_fingerprint=fingerprint,
created_at=_parse_iso8601(row.get("created_at")) or datetime.now(tz=timezone.utc),
updated_at=_parse_iso8601(row.get("updated_at")) or datetime.now(tz=timezone.utc),
enrollment_code_id=str(row.get("enrollment_code_id") or "").strip() or None,
guid=guid_raw,
approved_by_user_id=str(row.get("approved_by_user_id") or "").strip() or None,
approved_by_username=str(row.get("approved_by_username") or "").strip() or None,
client_nonce=str(row.get("client_nonce") or "").strip(),
server_nonce=str(row.get("server_nonce") or "").strip(),
hostname_conflict=conflict,
alternate_hostname=alternate_hostname,
conflict_requires_prompt=requires_prompt,
fingerprint_match=fingerprint_match,
)
def to_dict(self) -> dict[str, Any]:
payload: dict[str, Any] = {
"id": self.record_id,
"approval_reference": self.reference,
"status": self.status,
"hostname_claimed": self.claimed_hostname,
"ssl_key_fingerprint_claimed": self.claimed_fingerprint,
"created_at": _isoformat(self.created_at),
"updated_at": _isoformat(self.updated_at),
"enrollment_code_id": self.enrollment_code_id,
"guid": self.guid,
"approved_by_user_id": self.approved_by_user_id,
"approved_by_username": self.approved_by_username,
"client_nonce": self.client_nonce,
"server_nonce": self.server_nonce,
"conflict_requires_prompt": self.conflict_requires_prompt,
"fingerprint_match": self.fingerprint_match,
}
if self.hostname_conflict is not None:
payload["hostname_conflict"] = self.hostname_conflict.to_dict()
if self.alternate_hostname:
payload["alternate_hostname"] = self.alternate_hostname
return payload

View File

@@ -1,8 +1,8 @@
"""Administrative HTTP interface placeholders for the Engine.""" """Administrative HTTP endpoints for the Borealis Engine."""
from __future__ import annotations from __future__ import annotations
from flask import Blueprint, Flask from flask import Blueprint, Flask, current_app, jsonify, request, session
from Data.Engine.services.container import EngineServiceContainer from Data.Engine.services.container import EngineServiceContainer
@@ -11,13 +11,106 @@ blueprint = Blueprint("engine_admin", __name__, url_prefix="/api/admin")
def register(app: Flask, _services: EngineServiceContainer) -> None: def register(app: Flask, _services: EngineServiceContainer) -> None:
"""Attach administrative routes to *app*. """Attach administrative routes to *app*."""
Concrete endpoints will be migrated in subsequent phases.
"""
if "engine_admin" not in app.blueprints: if "engine_admin" not in app.blueprints:
app.register_blueprint(blueprint) app.register_blueprint(blueprint)
def _services() -> EngineServiceContainer:
services = current_app.extensions.get("engine_services")
if services is None: # pragma: no cover - defensive
raise RuntimeError("engine services not initialized")
return services
def _admin_service():
return _services().enrollment_admin_service
def _require_admin():
username = session.get("username")
role = (session.get("role") or "").strip().lower()
if not isinstance(username, str) or not username:
return jsonify({"error": "not_authenticated"}), 401
if role != "admin":
return jsonify({"error": "forbidden"}), 403
return None
@blueprint.route("/enrollment-codes", methods=["GET"])
def list_enrollment_codes() -> object:
guard = _require_admin()
if guard:
return guard
status = request.args.get("status")
records = _admin_service().list_install_codes(status=status)
return jsonify({"codes": [record.to_dict() for record in records]})
@blueprint.route("/enrollment-codes", methods=["POST"])
def create_enrollment_code() -> object:
guard = _require_admin()
if guard:
return guard
payload = request.get_json(silent=True) or {}
ttl_value = payload.get("ttl_hours")
if ttl_value is None:
ttl_value = payload.get("ttl") or 1
try:
ttl_hours = int(ttl_value)
except (TypeError, ValueError):
ttl_hours = 1
max_uses_value = payload.get("max_uses")
if max_uses_value is None:
max_uses_value = payload.get("allowed_uses", 2)
try:
max_uses = int(max_uses_value)
except (TypeError, ValueError):
max_uses = 2
creator = session.get("username") if isinstance(session.get("username"), str) else None
try:
record = _admin_service().create_install_code(
ttl_hours=ttl_hours,
max_uses=max_uses,
created_by=creator,
)
except ValueError as exc:
if str(exc) == "invalid_ttl":
return jsonify({"error": "invalid_ttl"}), 400
raise
response = jsonify(record.to_dict())
response.status_code = 201
return response
@blueprint.route("/enrollment-codes/<code_id>", methods=["DELETE"])
def delete_enrollment_code(code_id: str) -> object:
guard = _require_admin()
if guard:
return guard
if not _admin_service().delete_install_code(code_id):
return jsonify({"error": "not_found"}), 404
return jsonify({"status": "deleted"})
@blueprint.route("/device-approvals", methods=["GET"])
def list_device_approvals() -> object:
guard = _require_admin()
if guard:
return guard
status = request.args.get("status")
records = _admin_service().list_device_approvals(status=status)
return jsonify({"approvals": [record.to_dict() for record in records]})
__all__ = ["register", "blueprint"] __all__ = ["register", "blueprint"]

View File

@@ -5,14 +5,19 @@ from __future__ import annotations
import logging import logging
from contextlib import closing from contextlib import closing
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional from typing import Any, List, Optional, Tuple
from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid, normalize_guid
from Data.Engine.domain.device_enrollment import ( from Data.Engine.domain.device_enrollment import (
EnrollmentApproval, EnrollmentApproval,
EnrollmentApprovalStatus, EnrollmentApprovalStatus,
EnrollmentCode, EnrollmentCode,
) )
from Data.Engine.domain.enrollment_admin import (
DeviceApprovalRecord,
EnrollmentCodeRecord,
HostnameConflict,
)
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
__all__ = ["SQLiteEnrollmentRepository"] __all__ = ["SQLiteEnrollmentRepository"]
@@ -122,6 +127,158 @@ class SQLiteEnrollmentRepository:
self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc) self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc)
return None return None
def list_install_codes(
self,
*,
status: Optional[str] = None,
now: Optional[datetime] = None,
) -> List[EnrollmentCodeRecord]:
reference = now or datetime.now(tz=timezone.utc)
status_filter = (status or "").strip().lower()
params: List[str] = []
sql = """
SELECT id,
code,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at
FROM enrollment_install_codes
"""
if status_filter in {"active", "expired", "used"}:
sql += " WHERE "
if status_filter == "active":
sql += "use_count < max_uses AND expires_at > ?"
params.append(self._isoformat(reference))
elif status_filter == "expired":
sql += "use_count < max_uses AND expires_at <= ?"
params.append(self._isoformat(reference))
else: # used
sql += "use_count >= max_uses"
sql += " ORDER BY expires_at ASC"
rows: List[EnrollmentCodeRecord] = []
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(sql, params)
for raw in cur.fetchall():
record = {
"id": raw[0],
"code": raw[1],
"expires_at": raw[2],
"created_by_user_id": raw[3],
"used_at": raw[4],
"used_by_guid": raw[5],
"max_uses": raw[6],
"use_count": raw[7],
"last_used_at": raw[8],
}
try:
rows.append(EnrollmentCodeRecord.from_row(record))
except Exception as exc: # pragma: no cover - defensive logging
self._log.warning("invalid enrollment install code row id=%s: %s", record.get("id"), exc)
return rows
def get_install_code_record(self, record_id: str) -> Optional[EnrollmentCodeRecord]:
identifier = (record_id or "").strip()
if not identifier:
return None
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
SELECT id,
code,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at
FROM enrollment_install_codes
WHERE id = ?
""",
(identifier,),
)
row = cur.fetchone()
if not row:
return None
payload = {
"id": row[0],
"code": row[1],
"expires_at": row[2],
"created_by_user_id": row[3],
"used_at": row[4],
"used_by_guid": row[5],
"max_uses": row[6],
"use_count": row[7],
"last_used_at": row[8],
}
try:
return EnrollmentCodeRecord.from_row(payload)
except Exception as exc: # pragma: no cover - defensive logging
self._log.warning("invalid enrollment install code record id=%s: %s", identifier, exc)
return None
def insert_install_code(
self,
*,
record_id: str,
code: str,
expires_at: datetime,
created_by: Optional[str],
max_uses: int,
) -> EnrollmentCodeRecord:
expires_iso = self._isoformat(expires_at)
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
INSERT INTO enrollment_install_codes (
id,
code,
expires_at,
created_by_user_id,
max_uses,
use_count
) VALUES (?, ?, ?, ?, ?, 0)
""",
(record_id, code, expires_iso, created_by, max_uses),
)
conn.commit()
record = self.get_install_code_record(record_id)
if record is None:
raise RuntimeError("failed to load install code after insert")
return record
def delete_install_code_if_unused(self, record_id: str) -> bool:
identifier = (record_id or "").strip()
if not identifier:
return False
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
(identifier,),
)
deleted = cur.rowcount > 0
conn.commit()
return deleted
def update_install_code_usage( def update_install_code_usage(
self, self,
record_id: str, record_id: str,
@@ -165,6 +322,100 @@ class SQLiteEnrollmentRepository:
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Device approvals # Device approvals
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def list_device_approvals(
self,
*,
status: Optional[str] = None,
) -> List[DeviceApprovalRecord]:
status_filter = (status or "").strip().lower()
params: List[str] = []
sql = """
SELECT
da.id,
da.approval_reference,
da.guid,
da.hostname_claimed,
da.ssl_key_fingerprint_claimed,
da.enrollment_code_id,
da.status,
da.client_nonce,
da.server_nonce,
da.created_at,
da.updated_at,
da.approved_by_user_id,
u.username AS approved_by_username
FROM device_approvals AS da
LEFT JOIN users AS u
ON (
CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT)
OR LOWER(da.approved_by_user_id) = LOWER(u.username)
)
"""
if status_filter and status_filter not in {"all", "*"}:
sql += " WHERE LOWER(da.status) = ?"
params.append(status_filter)
sql += " ORDER BY da.created_at ASC"
approvals: List[DeviceApprovalRecord] = []
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(sql, params)
rows = cur.fetchall()
for raw in rows:
record = {
"id": raw[0],
"approval_reference": raw[1],
"guid": raw[2],
"hostname_claimed": raw[3],
"ssl_key_fingerprint_claimed": raw[4],
"enrollment_code_id": raw[5],
"status": raw[6],
"client_nonce": raw[7],
"server_nonce": raw[8],
"created_at": raw[9],
"updated_at": raw[10],
"approved_by_user_id": raw[11],
"approved_by_username": raw[12],
}
conflict, fingerprint_match, requires_prompt = self._compute_hostname_conflict(
conn,
record.get("hostname_claimed"),
record.get("guid"),
record.get("ssl_key_fingerprint_claimed") or "",
)
alternate = None
if conflict and requires_prompt:
alternate = self._suggest_alternate_hostname(
conn,
record.get("hostname_claimed"),
record.get("guid"),
)
try:
approvals.append(
DeviceApprovalRecord.from_row(
record,
conflict=conflict,
alternate_hostname=alternate,
fingerprint_match=fingerprint_match,
requires_prompt=requires_prompt,
)
)
except Exception as exc: # pragma: no cover - defensive logging
self._log.warning(
"invalid device approval record id=%s: %s",
record.get("id"),
exc,
)
return approvals
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:
"""Load a device approval using its operator-visible reference.""" """Load a device approval using its operator-visible reference."""
@@ -376,6 +627,98 @@ class SQLiteEnrollmentRepository:
) )
return None return None
def _compute_hostname_conflict(
self,
conn,
hostname: Optional[str],
pending_guid: Optional[str],
claimed_fp: str,
) -> Tuple[Optional[HostnameConflict], bool, bool]:
normalized_host = (hostname or "").strip()
if not normalized_host:
return None, False, False
try:
cur = conn.cursor()
cur.execute(
"""
SELECT d.guid,
d.ssl_key_fingerprint,
ds.site_id,
s.name
FROM devices AS d
LEFT JOIN device_sites AS ds ON ds.device_hostname = d.hostname
LEFT JOIN sites AS s ON s.id = ds.site_id
WHERE d.hostname = ?
""",
(normalized_host,),
)
row = cur.fetchone()
except Exception as exc: # pragma: no cover - defensive logging
self._log.warning("failed to inspect hostname conflict for %s: %s", normalized_host, exc)
return None, False, False
if not row:
return None, False, False
existing_guid = normalize_guid(row[0])
pending_norm = normalize_guid(pending_guid)
if existing_guid and pending_norm and existing_guid == pending_norm:
return None, False, False
stored_fp = (row[1] or "").strip().lower()
claimed_fp_normalized = (claimed_fp or "").strip().lower()
fingerprint_match = bool(stored_fp and claimed_fp_normalized and stored_fp == claimed_fp_normalized)
site_id = None
if row[2] is not None:
try:
site_id = int(row[2])
except (TypeError, ValueError): # pragma: no cover - defensive
site_id = None
site_name = str(row[3] or "").strip()
requires_prompt = not fingerprint_match
conflict = HostnameConflict(
guid=existing_guid or None,
ssl_key_fingerprint=stored_fp or None,
site_id=site_id,
site_name=site_name,
fingerprint_match=fingerprint_match,
requires_prompt=requires_prompt,
)
return conflict, fingerprint_match, requires_prompt
def _suggest_alternate_hostname(
self,
conn,
hostname: Optional[str],
pending_guid: Optional[str],
) -> Optional[str]:
base = (hostname or "").strip()
if not base:
return None
base = base[:253]
candidate = base
pending_norm = normalize_guid(pending_guid)
suffix = 1
cur = conn.cursor()
while True:
cur.execute("SELECT guid FROM devices WHERE hostname = ?", (candidate,))
row = cur.fetchone()
if not row:
return candidate
existing_guid = normalize_guid(row[0])
if pending_norm and existing_guid == pending_norm:
return candidate
candidate = f"{base}-{suffix}"
suffix += 1
if suffix > 50:
return pending_norm or candidate
@staticmethod @staticmethod
def _isoformat(value: datetime) -> str: def _isoformat(value: datetime) -> str:
if value.tzinfo is None: if value.tzinfo is None:

View File

@@ -31,6 +31,9 @@ def apply_all(conn: sqlite3.Connection) -> None:
_ensure_refresh_token_table(conn) _ensure_refresh_token_table(conn)
_ensure_install_code_table(conn) _ensure_install_code_table(conn)
_ensure_device_approval_table(conn) _ensure_device_approval_table(conn)
_ensure_device_list_views_table(conn)
_ensure_sites_tables(conn)
_ensure_credentials_table(conn)
_ensure_github_token_table(conn) _ensure_github_token_table(conn)
_ensure_scheduled_jobs_table(conn) _ensure_scheduled_jobs_table(conn)
_ensure_scheduled_job_run_tables(conn) _ensure_scheduled_job_run_tables(conn)
@@ -233,6 +236,73 @@ def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
) )
def _ensure_device_list_views_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS device_list_views (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
columns_json TEXT NOT NULL,
filters_json TEXT,
created_at INTEGER,
updated_at INTEGER
)
"""
)
def _ensure_sites_tables(conn: sqlite3.Connection) -> None:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS sites (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
description TEXT,
created_at INTEGER
)
"""
)
cur.execute(
"""
CREATE TABLE IF NOT EXISTS device_sites (
device_hostname TEXT UNIQUE NOT NULL,
site_id INTEGER NOT NULL,
assigned_at INTEGER,
FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE CASCADE
)
"""
)
def _ensure_credentials_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS credentials (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
description TEXT,
site_id INTEGER,
credential_type TEXT NOT NULL DEFAULT 'machine',
connection_type TEXT NOT NULL DEFAULT 'ssh',
username TEXT,
password_encrypted BLOB,
private_key_encrypted BLOB,
private_key_passphrase_encrypted BLOB,
become_method TEXT,
become_username TEXT,
become_password_encrypted BLOB,
metadata_json TEXT,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE SET NULL
)
"""
)
def _ensure_github_token_table(conn: sqlite3.Connection) -> None: def _ensure_github_token_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor() cur = conn.cursor()
cur.execute( cur.execute(

View File

@@ -71,6 +71,57 @@ class SQLiteUserRepository:
finally: finally:
conn.close() conn.close()
def resolve_identifier(self, username: str) -> Optional[str]:
normalized = (username or "").strip()
if not normalized:
return None
conn = self._connection_factory()
try:
cur = conn.cursor()
cur.execute(
"SELECT id FROM users WHERE LOWER(username) = LOWER(?)",
(normalized,),
)
row = cur.fetchone()
if not row:
return None
return str(row[0]) if row[0] is not None else None
except sqlite3.Error as exc: # pragma: no cover - defensive
self._log.error("failed to resolve identifier for %s: %s", username, exc)
return None
finally:
conn.close()
def username_for_identifier(self, identifier: str) -> Optional[str]:
token = (identifier or "").strip()
if not token:
return None
conn = self._connection_factory()
try:
cur = conn.cursor()
cur.execute(
"""
SELECT username
FROM users
WHERE CAST(id AS TEXT) = ?
OR LOWER(username) = LOWER(?)
LIMIT 1
""",
(token, token),
)
row = cur.fetchone()
if not row:
return None
username = str(row[0] or "").strip()
return username or None
except sqlite3.Error as exc: # pragma: no cover - defensive
self._log.error("failed to resolve username for %s: %s", identifier, exc)
return None
finally:
conn.close()
def list_accounts(self) -> list[OperatorAccount]: def list_accounts(self) -> list[OperatorAccount]:
conn = self._connection_factory() conn = self._connection_factory()
try: try:

View File

@@ -23,6 +23,7 @@ __all__ = [
"SchedulerService", "SchedulerService",
"GitHubService", "GitHubService",
"GitHubTokenPayload", "GitHubTokenPayload",
"EnrollmentAdminService",
] ]
_LAZY_TARGETS: Dict[str, Tuple[str, str]] = { _LAZY_TARGETS: Dict[str, Tuple[str, str]] = {
@@ -43,6 +44,10 @@ _LAZY_TARGETS: Dict[str, Tuple[str, str]] = {
"SchedulerService": ("Data.Engine.services.jobs.scheduler_service", "SchedulerService"), "SchedulerService": ("Data.Engine.services.jobs.scheduler_service", "SchedulerService"),
"GitHubService": ("Data.Engine.services.github.github_service", "GitHubService"), "GitHubService": ("Data.Engine.services.github.github_service", "GitHubService"),
"GitHubTokenPayload": ("Data.Engine.services.github.github_service", "GitHubTokenPayload"), "GitHubTokenPayload": ("Data.Engine.services.github.github_service", "GitHubTokenPayload"),
"EnrollmentAdminService": (
"Data.Engine.services.enrollment.admin_service",
"EnrollmentAdminService",
),
} }

View File

@@ -30,6 +30,7 @@ from Data.Engine.services.auth import (
) )
from Data.Engine.services.crypto.signing import ScriptSigner, load_signer from Data.Engine.services.crypto.signing import ScriptSigner, load_signer
from Data.Engine.services.enrollment import EnrollmentService from Data.Engine.services.enrollment import EnrollmentService
from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService
from Data.Engine.services.enrollment.nonce_cache import NonceCache from Data.Engine.services.enrollment.nonce_cache import NonceCache
from Data.Engine.services.github import GitHubService from Data.Engine.services.github import GitHubService
from Data.Engine.services.jobs import SchedulerService from Data.Engine.services.jobs import SchedulerService
@@ -44,6 +45,7 @@ class EngineServiceContainer:
device_auth: DeviceAuthService device_auth: DeviceAuthService
token_service: TokenService token_service: TokenService
enrollment_service: EnrollmentService enrollment_service: EnrollmentService
enrollment_admin_service: EnrollmentAdminService
jwt_service: JWTService jwt_service: JWTService
dpop_validator: DPoPValidator dpop_validator: DPoPValidator
agent_realtime: AgentRealtimeService agent_realtime: AgentRealtimeService
@@ -93,6 +95,12 @@ def build_service_container(
logger=log.getChild("enrollment"), logger=log.getChild("enrollment"),
) )
enrollment_admin_service = EnrollmentAdminService(
repository=enrollment_repo,
user_repository=user_repo,
logger=log.getChild("enrollment_admin"),
)
device_auth = DeviceAuthService( device_auth = DeviceAuthService(
device_repository=device_repo, device_repository=device_repo,
jwt_service=jwt_service, jwt_service=jwt_service,
@@ -139,6 +147,7 @@ def build_service_container(
device_auth=device_auth, device_auth=device_auth,
token_service=token_service, token_service=token_service,
enrollment_service=enrollment_service, enrollment_service=enrollment_service,
enrollment_admin_service=enrollment_admin_service,
jwt_service=jwt_service, jwt_service=jwt_service,
dpop_validator=dpop_validator, dpop_validator=dpop_validator,
agent_realtime=agent_realtime, agent_realtime=agent_realtime,

View File

@@ -2,20 +2,54 @@
from __future__ import annotations from __future__ import annotations
from .enrollment_service import ( from importlib import import_module
EnrollmentRequestResult, from typing import Any
EnrollmentService,
EnrollmentStatus,
EnrollmentTokenBundle,
PollingResult,
)
from Data.Engine.domain.device_enrollment import EnrollmentValidationError
__all__ = [ __all__ = [
"EnrollmentRequestResult",
"EnrollmentService", "EnrollmentService",
"EnrollmentRequestResult",
"EnrollmentStatus", "EnrollmentStatus",
"EnrollmentTokenBundle", "EnrollmentTokenBundle",
"EnrollmentValidationError",
"PollingResult", "PollingResult",
"EnrollmentValidationError",
"EnrollmentAdminService",
] ]
_LAZY: dict[str, tuple[str, str]] = {
"EnrollmentService": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentService"),
"EnrollmentRequestResult": (
"Data.Engine.services.enrollment.enrollment_service",
"EnrollmentRequestResult",
),
"EnrollmentStatus": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentStatus"),
"EnrollmentTokenBundle": (
"Data.Engine.services.enrollment.enrollment_service",
"EnrollmentTokenBundle",
),
"PollingResult": ("Data.Engine.services.enrollment.enrollment_service", "PollingResult"),
"EnrollmentValidationError": (
"Data.Engine.domain.device_enrollment",
"EnrollmentValidationError",
),
"EnrollmentAdminService": (
"Data.Engine.services.enrollment.admin_service",
"EnrollmentAdminService",
),
}
def __getattr__(name: str) -> Any:
try:
module_name, attribute = _LAZY[name]
except KeyError as exc: # pragma: no cover - defensive
raise AttributeError(name) from exc
module = import_module(module_name)
value = getattr(module, attribute)
globals()[name] = value
return value
def __dir__() -> list[str]: # pragma: no cover - interactive helper
return sorted(set(__all__))

View File

@@ -0,0 +1,113 @@
"""Administrative helpers for enrollment workflows."""
from __future__ import annotations
import logging
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from typing import Callable, List, Optional
from Data.Engine.domain.enrollment_admin import DeviceApprovalRecord, EnrollmentCodeRecord
from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
__all__ = ["EnrollmentAdminService"]
class EnrollmentAdminService:
"""Expose administrative enrollment operations."""
_VALID_TTL_HOURS = {1, 3, 6, 12, 24}
def __init__(
self,
*,
repository: SQLiteEnrollmentRepository,
user_repository: SQLiteUserRepository,
logger: Optional[logging.Logger] = None,
clock: Optional[Callable[[], datetime]] = None,
) -> None:
self._repository = repository
self._users = user_repository
self._log = logger or logging.getLogger("borealis.engine.services.enrollment_admin")
self._clock = clock or (lambda: datetime.now(tz=timezone.utc))
# ------------------------------------------------------------------
# Enrollment install codes
# ------------------------------------------------------------------
def list_install_codes(self, *, status: Optional[str] = None) -> List[EnrollmentCodeRecord]:
return self._repository.list_install_codes(status=status, now=self._clock())
def create_install_code(
self,
*,
ttl_hours: int,
max_uses: int,
created_by: Optional[str],
) -> EnrollmentCodeRecord:
if ttl_hours not in self._VALID_TTL_HOURS:
raise ValueError("invalid_ttl")
normalized_max = self._normalize_max_uses(max_uses)
now = self._clock()
expires_at = now + timedelta(hours=ttl_hours)
record_id = str(uuid.uuid4())
code = self._generate_install_code()
created_by_identifier = None
if created_by:
created_by_identifier = self._users.resolve_identifier(created_by)
if not created_by_identifier:
created_by_identifier = created_by.strip() or None
record = self._repository.insert_install_code(
record_id=record_id,
code=code,
expires_at=expires_at,
created_by=created_by_identifier,
max_uses=normalized_max,
)
self._log.info(
"install code created id=%s ttl=%sh max_uses=%s",
record.record_id,
ttl_hours,
normalized_max,
)
return record
def delete_install_code(self, record_id: str) -> bool:
deleted = self._repository.delete_install_code_if_unused(record_id)
if deleted:
self._log.info("install code deleted id=%s", record_id)
return deleted
# ------------------------------------------------------------------
# Device approvals
# ------------------------------------------------------------------
def list_device_approvals(self, *, status: Optional[str] = None) -> List[DeviceApprovalRecord]:
return self._repository.list_device_approvals(status=status)
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@staticmethod
def _generate_install_code() -> str:
raw = secrets.token_hex(16).upper()
return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4))
@staticmethod
def _normalize_max_uses(value: int) -> int:
try:
count = int(value)
except Exception:
count = 2
if count < 1:
return 1
if count > 10:
return 10
return count

View File

@@ -0,0 +1,122 @@
import base64
import sqlite3
from datetime import datetime, timezone
import pytest
from Data.Engine.repositories.sqlite import connection as sqlite_connection
from Data.Engine.repositories.sqlite import migrations as sqlite_migrations
from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService
def _build_service(tmp_path):
db_path = tmp_path / "admin.db"
conn = sqlite3.connect(db_path)
sqlite_migrations.apply_all(conn)
conn.close()
factory = sqlite_connection.connection_factory(db_path)
enrollment_repo = SQLiteEnrollmentRepository(factory)
user_repo = SQLiteUserRepository(factory)
fixed_now = datetime(2024, 1, 1, tzinfo=timezone.utc)
service = EnrollmentAdminService(
repository=enrollment_repo,
user_repository=user_repo,
clock=lambda: fixed_now,
)
return service, factory, fixed_now
def test_create_and_list_install_codes(tmp_path):
service, factory, fixed_now = _build_service(tmp_path)
record = service.create_install_code(ttl_hours=3, max_uses=5, created_by="admin")
assert record.code
assert record.max_uses == 5
assert record.status(now=fixed_now) == "active"
records = service.list_install_codes()
assert any(r.record_id == record.record_id for r in records)
# Invalid TTL should raise
with pytest.raises(ValueError):
service.create_install_code(ttl_hours=2, max_uses=1, created_by=None)
# Deleting should succeed and remove the record
assert service.delete_install_code(record.record_id) is True
remaining = service.list_install_codes()
assert all(r.record_id != record.record_id for r in remaining)
def test_list_device_approvals_includes_conflict(tmp_path):
service, factory, fixed_now = _build_service(tmp_path)
conn = factory()
cur = conn.cursor()
cur.execute(
"INSERT INTO sites (name, description, created_at) VALUES (?, ?, ?)",
("HQ", "Primary site", int(fixed_now.timestamp())),
)
site_id = cur.lastrowid
cur.execute(
"""
INSERT INTO devices (guid, hostname, created_at, last_seen, ssl_key_fingerprint, status)
VALUES (?, ?, ?, ?, ?, 'active')
""",
("11111111-1111-1111-1111-111111111111", "agent-one", int(fixed_now.timestamp()), int(fixed_now.timestamp()), "abc123",),
)
cur.execute(
"INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)",
("agent-one", site_id, int(fixed_now.timestamp())),
)
now_iso = fixed_now.isoformat()
cur.execute(
"""
INSERT INTO device_approvals (
id,
approval_reference,
guid,
hostname_claimed,
ssl_key_fingerprint_claimed,
enrollment_code_id,
status,
client_nonce,
server_nonce,
created_at,
updated_at,
approved_by_user_id,
agent_pubkey_der
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
"approval-1",
"REF123",
None,
"agent-one",
"abc123",
"code-1",
"pending",
base64.b64encode(b"client").decode(),
base64.b64encode(b"server").decode(),
now_iso,
now_iso,
None,
b"pubkey",
),
)
conn.commit()
conn.close()
approvals = service.list_device_approvals()
assert len(approvals) == 1
record = approvals[0]
assert record.hostname_conflict is not None
assert record.hostname_conflict.fingerprint_match is True
assert record.conflict_requires_prompt is False

View File

@@ -0,0 +1,111 @@
import base64
import sqlite3
from datetime import datetime, timezone
from .test_http_auth import _login, prepared_app
def test_enrollment_codes_require_authentication(prepared_app):
client = prepared_app.test_client()
resp = client.get("/api/admin/enrollment-codes")
assert resp.status_code == 401
def test_enrollment_code_workflow(prepared_app):
client = prepared_app.test_client()
_login(client)
payload = {"ttl_hours": 3, "max_uses": 4}
resp = client.post("/api/admin/enrollment-codes", json=payload)
assert resp.status_code == 201
created = resp.get_json()
assert created["max_uses"] == 4
assert created["status"] == "active"
resp = client.get("/api/admin/enrollment-codes")
assert resp.status_code == 200
codes = resp.get_json().get("codes", [])
assert any(code["id"] == created["id"] for code in codes)
resp = client.delete(f"/api/admin/enrollment-codes/{created['id']}")
assert resp.status_code == 200
def test_device_approvals_listing(prepared_app, engine_settings):
client = prepared_app.test_client()
_login(client)
conn = sqlite3.connect(engine_settings.database.path)
cur = conn.cursor()
now = datetime.now(tz=timezone.utc)
cur.execute(
"INSERT INTO sites (name, description, created_at) VALUES (?, ?, ?)",
("HQ", "Primary", int(now.timestamp())),
)
site_id = cur.lastrowid
cur.execute(
"""
INSERT INTO devices (guid, hostname, created_at, last_seen, ssl_key_fingerprint, status)
VALUES (?, ?, ?, ?, ?, 'active')
""",
(
"22222222-2222-2222-2222-222222222222",
"approval-host",
int(now.timestamp()),
int(now.timestamp()),
"deadbeef",
),
)
cur.execute(
"INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)",
("approval-host", site_id, int(now.timestamp())),
)
now_iso = now.isoformat()
cur.execute(
"""
INSERT INTO device_approvals (
id,
approval_reference,
guid,
hostname_claimed,
ssl_key_fingerprint_claimed,
enrollment_code_id,
status,
client_nonce,
server_nonce,
created_at,
updated_at,
approved_by_user_id,
agent_pubkey_der
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
"approval-http",
"REFHTTP",
None,
"approval-host",
"deadbeef",
"code-http",
"pending",
base64.b64encode(b"client").decode(),
base64.b64encode(b"server").decode(),
now_iso,
now_iso,
None,
b"pub",
),
)
conn.commit()
conn.close()
resp = client.get("/api/admin/device-approvals")
assert resp.status_code == 200
body = resp.get_json()
approvals = body.get("approvals", [])
assert any(a["id"] == "approval-http" for a in approvals)
record = next(a for a in approvals if a["id"] == "approval-http")
assert record.get("hostname_conflict", {}).get("fingerprint_match") is True