mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 15:21:57 -06:00
Implement admin enrollment APIs
This commit is contained in:
@@ -18,6 +18,7 @@ __all__ = [
|
|||||||
"AccessTokenClaims",
|
"AccessTokenClaims",
|
||||||
"DeviceAuthContext",
|
"DeviceAuthContext",
|
||||||
"sanitize_service_context",
|
"sanitize_service_context",
|
||||||
|
"normalize_guid",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -73,6 +74,12 @@ class DeviceGuid:
|
|||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_guid(value: Optional[str]) -> str:
|
||||||
|
"""Expose GUID normalization for administrative helpers."""
|
||||||
|
|
||||||
|
return _normalize_guid(value)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
class DeviceFingerprint:
|
class DeviceFingerprint:
|
||||||
"""Normalized TLS key fingerprint associated with a device."""
|
"""Normalized TLS key fingerprint associated with a device."""
|
||||||
|
|||||||
206
Data/Engine/domain/enrollment_admin.py
Normal file
206
Data/Engine/domain/enrollment_admin.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
"""Administrative enrollment domain models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
|
from Data.Engine.domain.device_auth import DeviceGuid, normalize_guid
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EnrollmentCodeRecord",
|
||||||
|
"DeviceApprovalRecord",
|
||||||
|
"HostnameConflict",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_iso8601(value: Optional[str]) -> Optional[datetime]:
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
raw = str(value).strip()
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
dt = datetime.fromisoformat(raw)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive parsing
|
||||||
|
raise ValueError(f"invalid ISO8601 timestamp: {raw}") from exc
|
||||||
|
if dt.tzinfo is None:
|
||||||
|
return dt.replace(tzinfo=timezone.utc)
|
||||||
|
return dt.astimezone(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def _isoformat(value: Optional[datetime]) -> Optional[str]:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if value.tzinfo is None:
|
||||||
|
value = value.replace(tzinfo=timezone.utc)
|
||||||
|
return value.astimezone(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class EnrollmentCodeRecord:
|
||||||
|
"""Installer code metadata exposed to administrative clients."""
|
||||||
|
|
||||||
|
record_id: str
|
||||||
|
code: str
|
||||||
|
expires_at: datetime
|
||||||
|
max_uses: int
|
||||||
|
use_count: int
|
||||||
|
created_by_user_id: Optional[str]
|
||||||
|
used_at: Optional[datetime]
|
||||||
|
used_by_guid: Optional[DeviceGuid]
|
||||||
|
last_used_at: Optional[datetime]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_row(cls, row: Mapping[str, Any]) -> "EnrollmentCodeRecord":
|
||||||
|
record_id = str(row.get("id") or "").strip()
|
||||||
|
code = str(row.get("code") or "").strip()
|
||||||
|
if not record_id or not code:
|
||||||
|
raise ValueError("invalid enrollment install code record")
|
||||||
|
|
||||||
|
used_by = row.get("used_by_guid")
|
||||||
|
used_by_guid = DeviceGuid(str(used_by)) if used_by else None
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
record_id=record_id,
|
||||||
|
code=code,
|
||||||
|
expires_at=_parse_iso8601(row.get("expires_at")) or datetime.now(tz=timezone.utc),
|
||||||
|
max_uses=int(row.get("max_uses") or 1),
|
||||||
|
use_count=int(row.get("use_count") or 0),
|
||||||
|
created_by_user_id=str(row.get("created_by_user_id") or "").strip() or None,
|
||||||
|
used_at=_parse_iso8601(row.get("used_at")),
|
||||||
|
used_by_guid=used_by_guid,
|
||||||
|
last_used_at=_parse_iso8601(row.get("last_used_at")),
|
||||||
|
)
|
||||||
|
|
||||||
|
def status(self, *, now: Optional[datetime] = None) -> str:
|
||||||
|
reference = now or datetime.now(tz=timezone.utc)
|
||||||
|
if self.use_count >= self.max_uses:
|
||||||
|
return "used"
|
||||||
|
if self.expires_at <= reference:
|
||||||
|
return "expired"
|
||||||
|
return "active"
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": self.record_id,
|
||||||
|
"code": self.code,
|
||||||
|
"expires_at": _isoformat(self.expires_at),
|
||||||
|
"max_uses": self.max_uses,
|
||||||
|
"use_count": self.use_count,
|
||||||
|
"created_by_user_id": self.created_by_user_id,
|
||||||
|
"used_at": _isoformat(self.used_at),
|
||||||
|
"used_by_guid": self.used_by_guid.value if self.used_by_guid else None,
|
||||||
|
"last_used_at": _isoformat(self.last_used_at),
|
||||||
|
"status": self.status(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class HostnameConflict:
|
||||||
|
"""Existing device details colliding with a pending approval."""
|
||||||
|
|
||||||
|
guid: Optional[str]
|
||||||
|
ssl_key_fingerprint: Optional[str]
|
||||||
|
site_id: Optional[int]
|
||||||
|
site_name: str
|
||||||
|
fingerprint_match: bool
|
||||||
|
requires_prompt: bool
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"guid": self.guid,
|
||||||
|
"ssl_key_fingerprint": self.ssl_key_fingerprint,
|
||||||
|
"site_id": self.site_id,
|
||||||
|
"site_name": self.site_name,
|
||||||
|
"fingerprint_match": self.fingerprint_match,
|
||||||
|
"requires_prompt": self.requires_prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class DeviceApprovalRecord:
|
||||||
|
"""Administrative projection of a device approval entry."""
|
||||||
|
|
||||||
|
record_id: str
|
||||||
|
reference: str
|
||||||
|
status: str
|
||||||
|
claimed_hostname: str
|
||||||
|
claimed_fingerprint: str
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
enrollment_code_id: Optional[str]
|
||||||
|
guid: Optional[str]
|
||||||
|
approved_by_user_id: Optional[str]
|
||||||
|
approved_by_username: Optional[str]
|
||||||
|
client_nonce: str
|
||||||
|
server_nonce: str
|
||||||
|
hostname_conflict: Optional[HostnameConflict]
|
||||||
|
alternate_hostname: Optional[str]
|
||||||
|
conflict_requires_prompt: bool
|
||||||
|
fingerprint_match: bool
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_row(
|
||||||
|
cls,
|
||||||
|
row: Mapping[str, Any],
|
||||||
|
*,
|
||||||
|
conflict: Optional[HostnameConflict] = None,
|
||||||
|
alternate_hostname: Optional[str] = None,
|
||||||
|
fingerprint_match: bool = False,
|
||||||
|
requires_prompt: bool = False,
|
||||||
|
) -> "DeviceApprovalRecord":
|
||||||
|
record_id = str(row.get("id") or "").strip()
|
||||||
|
reference = str(row.get("approval_reference") or "").strip()
|
||||||
|
hostname = str(row.get("hostname_claimed") or "").strip()
|
||||||
|
fingerprint = str(row.get("ssl_key_fingerprint_claimed") or "").strip().lower()
|
||||||
|
if not record_id or not reference or not hostname or not fingerprint:
|
||||||
|
raise ValueError("invalid device approval record")
|
||||||
|
|
||||||
|
guid_raw = normalize_guid(row.get("guid")) or None
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
record_id=record_id,
|
||||||
|
reference=reference,
|
||||||
|
status=str(row.get("status") or "pending").strip().lower(),
|
||||||
|
claimed_hostname=hostname,
|
||||||
|
claimed_fingerprint=fingerprint,
|
||||||
|
created_at=_parse_iso8601(row.get("created_at")) or datetime.now(tz=timezone.utc),
|
||||||
|
updated_at=_parse_iso8601(row.get("updated_at")) or datetime.now(tz=timezone.utc),
|
||||||
|
enrollment_code_id=str(row.get("enrollment_code_id") or "").strip() or None,
|
||||||
|
guid=guid_raw,
|
||||||
|
approved_by_user_id=str(row.get("approved_by_user_id") or "").strip() or None,
|
||||||
|
approved_by_username=str(row.get("approved_by_username") or "").strip() or None,
|
||||||
|
client_nonce=str(row.get("client_nonce") or "").strip(),
|
||||||
|
server_nonce=str(row.get("server_nonce") or "").strip(),
|
||||||
|
hostname_conflict=conflict,
|
||||||
|
alternate_hostname=alternate_hostname,
|
||||||
|
conflict_requires_prompt=requires_prompt,
|
||||||
|
fingerprint_match=fingerprint_match,
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"id": self.record_id,
|
||||||
|
"approval_reference": self.reference,
|
||||||
|
"status": self.status,
|
||||||
|
"hostname_claimed": self.claimed_hostname,
|
||||||
|
"ssl_key_fingerprint_claimed": self.claimed_fingerprint,
|
||||||
|
"created_at": _isoformat(self.created_at),
|
||||||
|
"updated_at": _isoformat(self.updated_at),
|
||||||
|
"enrollment_code_id": self.enrollment_code_id,
|
||||||
|
"guid": self.guid,
|
||||||
|
"approved_by_user_id": self.approved_by_user_id,
|
||||||
|
"approved_by_username": self.approved_by_username,
|
||||||
|
"client_nonce": self.client_nonce,
|
||||||
|
"server_nonce": self.server_nonce,
|
||||||
|
"conflict_requires_prompt": self.conflict_requires_prompt,
|
||||||
|
"fingerprint_match": self.fingerprint_match,
|
||||||
|
}
|
||||||
|
if self.hostname_conflict is not None:
|
||||||
|
payload["hostname_conflict"] = self.hostname_conflict.to_dict()
|
||||||
|
if self.alternate_hostname:
|
||||||
|
payload["alternate_hostname"] = self.alternate_hostname
|
||||||
|
return payload
|
||||||
|
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Administrative HTTP interface placeholders for the Engine."""
|
"""Administrative HTTP endpoints for the Borealis Engine."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from flask import Blueprint, Flask
|
from flask import Blueprint, Flask, current_app, jsonify, request, session
|
||||||
|
|
||||||
from Data.Engine.services.container import EngineServiceContainer
|
from Data.Engine.services.container import EngineServiceContainer
|
||||||
|
|
||||||
@@ -11,13 +11,106 @@ blueprint = Blueprint("engine_admin", __name__, url_prefix="/api/admin")
|
|||||||
|
|
||||||
|
|
||||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||||
"""Attach administrative routes to *app*.
|
"""Attach administrative routes to *app*."""
|
||||||
|
|
||||||
Concrete endpoints will be migrated in subsequent phases.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if "engine_admin" not in app.blueprints:
|
if "engine_admin" not in app.blueprints:
|
||||||
app.register_blueprint(blueprint)
|
app.register_blueprint(blueprint)
|
||||||
|
|
||||||
|
|
||||||
|
def _services() -> EngineServiceContainer:
|
||||||
|
services = current_app.extensions.get("engine_services")
|
||||||
|
if services is None: # pragma: no cover - defensive
|
||||||
|
raise RuntimeError("engine services not initialized")
|
||||||
|
return services
|
||||||
|
|
||||||
|
|
||||||
|
def _admin_service():
|
||||||
|
return _services().enrollment_admin_service
|
||||||
|
|
||||||
|
|
||||||
|
def _require_admin():
|
||||||
|
username = session.get("username")
|
||||||
|
role = (session.get("role") or "").strip().lower()
|
||||||
|
if not isinstance(username, str) or not username:
|
||||||
|
return jsonify({"error": "not_authenticated"}), 401
|
||||||
|
if role != "admin":
|
||||||
|
return jsonify({"error": "forbidden"}), 403
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@blueprint.route("/enrollment-codes", methods=["GET"])
|
||||||
|
def list_enrollment_codes() -> object:
|
||||||
|
guard = _require_admin()
|
||||||
|
if guard:
|
||||||
|
return guard
|
||||||
|
|
||||||
|
status = request.args.get("status")
|
||||||
|
records = _admin_service().list_install_codes(status=status)
|
||||||
|
return jsonify({"codes": [record.to_dict() for record in records]})
|
||||||
|
|
||||||
|
|
||||||
|
@blueprint.route("/enrollment-codes", methods=["POST"])
|
||||||
|
def create_enrollment_code() -> object:
|
||||||
|
guard = _require_admin()
|
||||||
|
if guard:
|
||||||
|
return guard
|
||||||
|
|
||||||
|
payload = request.get_json(silent=True) or {}
|
||||||
|
|
||||||
|
ttl_value = payload.get("ttl_hours")
|
||||||
|
if ttl_value is None:
|
||||||
|
ttl_value = payload.get("ttl") or 1
|
||||||
|
try:
|
||||||
|
ttl_hours = int(ttl_value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
ttl_hours = 1
|
||||||
|
|
||||||
|
max_uses_value = payload.get("max_uses")
|
||||||
|
if max_uses_value is None:
|
||||||
|
max_uses_value = payload.get("allowed_uses", 2)
|
||||||
|
try:
|
||||||
|
max_uses = int(max_uses_value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
max_uses = 2
|
||||||
|
|
||||||
|
creator = session.get("username") if isinstance(session.get("username"), str) else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
record = _admin_service().create_install_code(
|
||||||
|
ttl_hours=ttl_hours,
|
||||||
|
max_uses=max_uses,
|
||||||
|
created_by=creator,
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
if str(exc) == "invalid_ttl":
|
||||||
|
return jsonify({"error": "invalid_ttl"}), 400
|
||||||
|
raise
|
||||||
|
|
||||||
|
response = jsonify(record.to_dict())
|
||||||
|
response.status_code = 201
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@blueprint.route("/enrollment-codes/<code_id>", methods=["DELETE"])
|
||||||
|
def delete_enrollment_code(code_id: str) -> object:
|
||||||
|
guard = _require_admin()
|
||||||
|
if guard:
|
||||||
|
return guard
|
||||||
|
|
||||||
|
if not _admin_service().delete_install_code(code_id):
|
||||||
|
return jsonify({"error": "not_found"}), 404
|
||||||
|
return jsonify({"status": "deleted"})
|
||||||
|
|
||||||
|
|
||||||
|
@blueprint.route("/device-approvals", methods=["GET"])
|
||||||
|
def list_device_approvals() -> object:
|
||||||
|
guard = _require_admin()
|
||||||
|
if guard:
|
||||||
|
return guard
|
||||||
|
|
||||||
|
status = request.args.get("status")
|
||||||
|
records = _admin_service().list_device_approvals(status=status)
|
||||||
|
return jsonify({"approvals": [record.to_dict() for record in records]})
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["register", "blueprint"]
|
__all__ = ["register", "blueprint"]
|
||||||
|
|||||||
@@ -5,14 +5,19 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid
|
from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid, normalize_guid
|
||||||
from Data.Engine.domain.device_enrollment import (
|
from Data.Engine.domain.device_enrollment import (
|
||||||
EnrollmentApproval,
|
EnrollmentApproval,
|
||||||
EnrollmentApprovalStatus,
|
EnrollmentApprovalStatus,
|
||||||
EnrollmentCode,
|
EnrollmentCode,
|
||||||
)
|
)
|
||||||
|
from Data.Engine.domain.enrollment_admin import (
|
||||||
|
DeviceApprovalRecord,
|
||||||
|
EnrollmentCodeRecord,
|
||||||
|
HostnameConflict,
|
||||||
|
)
|
||||||
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||||
|
|
||||||
__all__ = ["SQLiteEnrollmentRepository"]
|
__all__ = ["SQLiteEnrollmentRepository"]
|
||||||
@@ -122,6 +127,158 @@ class SQLiteEnrollmentRepository:
|
|||||||
self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc)
|
self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def list_install_codes(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
status: Optional[str] = None,
|
||||||
|
now: Optional[datetime] = None,
|
||||||
|
) -> List[EnrollmentCodeRecord]:
|
||||||
|
reference = now or datetime.now(tz=timezone.utc)
|
||||||
|
status_filter = (status or "").strip().lower()
|
||||||
|
params: List[str] = []
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT id,
|
||||||
|
code,
|
||||||
|
expires_at,
|
||||||
|
created_by_user_id,
|
||||||
|
used_at,
|
||||||
|
used_by_guid,
|
||||||
|
max_uses,
|
||||||
|
use_count,
|
||||||
|
last_used_at
|
||||||
|
FROM enrollment_install_codes
|
||||||
|
"""
|
||||||
|
|
||||||
|
if status_filter in {"active", "expired", "used"}:
|
||||||
|
sql += " WHERE "
|
||||||
|
if status_filter == "active":
|
||||||
|
sql += "use_count < max_uses AND expires_at > ?"
|
||||||
|
params.append(self._isoformat(reference))
|
||||||
|
elif status_filter == "expired":
|
||||||
|
sql += "use_count < max_uses AND expires_at <= ?"
|
||||||
|
params.append(self._isoformat(reference))
|
||||||
|
else: # used
|
||||||
|
sql += "use_count >= max_uses"
|
||||||
|
|
||||||
|
sql += " ORDER BY expires_at ASC"
|
||||||
|
|
||||||
|
rows: List[EnrollmentCodeRecord] = []
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(sql, params)
|
||||||
|
for raw in cur.fetchall():
|
||||||
|
record = {
|
||||||
|
"id": raw[0],
|
||||||
|
"code": raw[1],
|
||||||
|
"expires_at": raw[2],
|
||||||
|
"created_by_user_id": raw[3],
|
||||||
|
"used_at": raw[4],
|
||||||
|
"used_by_guid": raw[5],
|
||||||
|
"max_uses": raw[6],
|
||||||
|
"use_count": raw[7],
|
||||||
|
"last_used_at": raw[8],
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
rows.append(EnrollmentCodeRecord.from_row(record))
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
self._log.warning("invalid enrollment install code row id=%s: %s", record.get("id"), exc)
|
||||||
|
return rows
|
||||||
|
|
||||||
|
def get_install_code_record(self, record_id: str) -> Optional[EnrollmentCodeRecord]:
|
||||||
|
identifier = (record_id or "").strip()
|
||||||
|
if not identifier:
|
||||||
|
return None
|
||||||
|
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT id,
|
||||||
|
code,
|
||||||
|
expires_at,
|
||||||
|
created_by_user_id,
|
||||||
|
used_at,
|
||||||
|
used_by_guid,
|
||||||
|
max_uses,
|
||||||
|
use_count,
|
||||||
|
last_used_at
|
||||||
|
FROM enrollment_install_codes
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(identifier,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"id": row[0],
|
||||||
|
"code": row[1],
|
||||||
|
"expires_at": row[2],
|
||||||
|
"created_by_user_id": row[3],
|
||||||
|
"used_at": row[4],
|
||||||
|
"used_by_guid": row[5],
|
||||||
|
"max_uses": row[6],
|
||||||
|
"use_count": row[7],
|
||||||
|
"last_used_at": row[8],
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
return EnrollmentCodeRecord.from_row(payload)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
self._log.warning("invalid enrollment install code record id=%s: %s", identifier, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def insert_install_code(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
record_id: str,
|
||||||
|
code: str,
|
||||||
|
expires_at: datetime,
|
||||||
|
created_by: Optional[str],
|
||||||
|
max_uses: int,
|
||||||
|
) -> EnrollmentCodeRecord:
|
||||||
|
expires_iso = self._isoformat(expires_at)
|
||||||
|
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO enrollment_install_codes (
|
||||||
|
id,
|
||||||
|
code,
|
||||||
|
expires_at,
|
||||||
|
created_by_user_id,
|
||||||
|
max_uses,
|
||||||
|
use_count
|
||||||
|
) VALUES (?, ?, ?, ?, ?, 0)
|
||||||
|
""",
|
||||||
|
(record_id, code, expires_iso, created_by, max_uses),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
record = self.get_install_code_record(record_id)
|
||||||
|
if record is None:
|
||||||
|
raise RuntimeError("failed to load install code after insert")
|
||||||
|
return record
|
||||||
|
|
||||||
|
def delete_install_code_if_unused(self, record_id: str) -> bool:
|
||||||
|
identifier = (record_id or "").strip()
|
||||||
|
if not identifier:
|
||||||
|
return False
|
||||||
|
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
|
||||||
|
(identifier,),
|
||||||
|
)
|
||||||
|
deleted = cur.rowcount > 0
|
||||||
|
conn.commit()
|
||||||
|
return deleted
|
||||||
|
|
||||||
def update_install_code_usage(
|
def update_install_code_usage(
|
||||||
self,
|
self,
|
||||||
record_id: str,
|
record_id: str,
|
||||||
@@ -165,6 +322,100 @@ class SQLiteEnrollmentRepository:
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Device approvals
|
# Device approvals
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
def list_device_approvals(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
status: Optional[str] = None,
|
||||||
|
) -> List[DeviceApprovalRecord]:
|
||||||
|
status_filter = (status or "").strip().lower()
|
||||||
|
params: List[str] = []
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT
|
||||||
|
da.id,
|
||||||
|
da.approval_reference,
|
||||||
|
da.guid,
|
||||||
|
da.hostname_claimed,
|
||||||
|
da.ssl_key_fingerprint_claimed,
|
||||||
|
da.enrollment_code_id,
|
||||||
|
da.status,
|
||||||
|
da.client_nonce,
|
||||||
|
da.server_nonce,
|
||||||
|
da.created_at,
|
||||||
|
da.updated_at,
|
||||||
|
da.approved_by_user_id,
|
||||||
|
u.username AS approved_by_username
|
||||||
|
FROM device_approvals AS da
|
||||||
|
LEFT JOIN users AS u
|
||||||
|
ON (
|
||||||
|
CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT)
|
||||||
|
OR LOWER(da.approved_by_user_id) = LOWER(u.username)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if status_filter and status_filter not in {"all", "*"}:
|
||||||
|
sql += " WHERE LOWER(da.status) = ?"
|
||||||
|
params.append(status_filter)
|
||||||
|
|
||||||
|
sql += " ORDER BY da.created_at ASC"
|
||||||
|
|
||||||
|
approvals: List[DeviceApprovalRecord] = []
|
||||||
|
with closing(self._connections()) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(sql, params)
|
||||||
|
rows = cur.fetchall()
|
||||||
|
|
||||||
|
for raw in rows:
|
||||||
|
record = {
|
||||||
|
"id": raw[0],
|
||||||
|
"approval_reference": raw[1],
|
||||||
|
"guid": raw[2],
|
||||||
|
"hostname_claimed": raw[3],
|
||||||
|
"ssl_key_fingerprint_claimed": raw[4],
|
||||||
|
"enrollment_code_id": raw[5],
|
||||||
|
"status": raw[6],
|
||||||
|
"client_nonce": raw[7],
|
||||||
|
"server_nonce": raw[8],
|
||||||
|
"created_at": raw[9],
|
||||||
|
"updated_at": raw[10],
|
||||||
|
"approved_by_user_id": raw[11],
|
||||||
|
"approved_by_username": raw[12],
|
||||||
|
}
|
||||||
|
|
||||||
|
conflict, fingerprint_match, requires_prompt = self._compute_hostname_conflict(
|
||||||
|
conn,
|
||||||
|
record.get("hostname_claimed"),
|
||||||
|
record.get("guid"),
|
||||||
|
record.get("ssl_key_fingerprint_claimed") or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
alternate = None
|
||||||
|
if conflict and requires_prompt:
|
||||||
|
alternate = self._suggest_alternate_hostname(
|
||||||
|
conn,
|
||||||
|
record.get("hostname_claimed"),
|
||||||
|
record.get("guid"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
approvals.append(
|
||||||
|
DeviceApprovalRecord.from_row(
|
||||||
|
record,
|
||||||
|
conflict=conflict,
|
||||||
|
alternate_hostname=alternate,
|
||||||
|
fingerprint_match=fingerprint_match,
|
||||||
|
requires_prompt=requires_prompt,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
self._log.warning(
|
||||||
|
"invalid device approval record id=%s: %s",
|
||||||
|
record.get("id"),
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
return approvals
|
||||||
|
|
||||||
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:
|
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:
|
||||||
"""Load a device approval using its operator-visible reference."""
|
"""Load a device approval using its operator-visible reference."""
|
||||||
|
|
||||||
@@ -376,6 +627,98 @@ class SQLiteEnrollmentRepository:
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _compute_hostname_conflict(
|
||||||
|
self,
|
||||||
|
conn,
|
||||||
|
hostname: Optional[str],
|
||||||
|
pending_guid: Optional[str],
|
||||||
|
claimed_fp: str,
|
||||||
|
) -> Tuple[Optional[HostnameConflict], bool, bool]:
|
||||||
|
normalized_host = (hostname or "").strip()
|
||||||
|
if not normalized_host:
|
||||||
|
return None, False, False
|
||||||
|
|
||||||
|
try:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT d.guid,
|
||||||
|
d.ssl_key_fingerprint,
|
||||||
|
ds.site_id,
|
||||||
|
s.name
|
||||||
|
FROM devices AS d
|
||||||
|
LEFT JOIN device_sites AS ds ON ds.device_hostname = d.hostname
|
||||||
|
LEFT JOIN sites AS s ON s.id = ds.site_id
|
||||||
|
WHERE d.hostname = ?
|
||||||
|
""",
|
||||||
|
(normalized_host,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
self._log.warning("failed to inspect hostname conflict for %s: %s", normalized_host, exc)
|
||||||
|
return None, False, False
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
return None, False, False
|
||||||
|
|
||||||
|
existing_guid = normalize_guid(row[0])
|
||||||
|
pending_norm = normalize_guid(pending_guid)
|
||||||
|
if existing_guid and pending_norm and existing_guid == pending_norm:
|
||||||
|
return None, False, False
|
||||||
|
|
||||||
|
stored_fp = (row[1] or "").strip().lower()
|
||||||
|
claimed_fp_normalized = (claimed_fp or "").strip().lower()
|
||||||
|
fingerprint_match = bool(stored_fp and claimed_fp_normalized and stored_fp == claimed_fp_normalized)
|
||||||
|
|
||||||
|
site_id = None
|
||||||
|
if row[2] is not None:
|
||||||
|
try:
|
||||||
|
site_id = int(row[2])
|
||||||
|
except (TypeError, ValueError): # pragma: no cover - defensive
|
||||||
|
site_id = None
|
||||||
|
|
||||||
|
site_name = str(row[3] or "").strip()
|
||||||
|
requires_prompt = not fingerprint_match
|
||||||
|
|
||||||
|
conflict = HostnameConflict(
|
||||||
|
guid=existing_guid or None,
|
||||||
|
ssl_key_fingerprint=stored_fp or None,
|
||||||
|
site_id=site_id,
|
||||||
|
site_name=site_name,
|
||||||
|
fingerprint_match=fingerprint_match,
|
||||||
|
requires_prompt=requires_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
return conflict, fingerprint_match, requires_prompt
|
||||||
|
|
||||||
|
def _suggest_alternate_hostname(
|
||||||
|
self,
|
||||||
|
conn,
|
||||||
|
hostname: Optional[str],
|
||||||
|
pending_guid: Optional[str],
|
||||||
|
) -> Optional[str]:
|
||||||
|
base = (hostname or "").strip()
|
||||||
|
if not base:
|
||||||
|
return None
|
||||||
|
base = base[:253]
|
||||||
|
candidate = base
|
||||||
|
pending_norm = normalize_guid(pending_guid)
|
||||||
|
suffix = 1
|
||||||
|
|
||||||
|
cur = conn.cursor()
|
||||||
|
while True:
|
||||||
|
cur.execute("SELECT guid FROM devices WHERE hostname = ?", (candidate,))
|
||||||
|
row = cur.fetchone()
|
||||||
|
if not row:
|
||||||
|
return candidate
|
||||||
|
existing_guid = normalize_guid(row[0])
|
||||||
|
if pending_norm and existing_guid == pending_norm:
|
||||||
|
return candidate
|
||||||
|
candidate = f"{base}-{suffix}"
|
||||||
|
suffix += 1
|
||||||
|
if suffix > 50:
|
||||||
|
return pending_norm or candidate
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _isoformat(value: datetime) -> str:
|
def _isoformat(value: datetime) -> str:
|
||||||
if value.tzinfo is None:
|
if value.tzinfo is None:
|
||||||
|
|||||||
@@ -31,6 +31,9 @@ def apply_all(conn: sqlite3.Connection) -> None:
|
|||||||
_ensure_refresh_token_table(conn)
|
_ensure_refresh_token_table(conn)
|
||||||
_ensure_install_code_table(conn)
|
_ensure_install_code_table(conn)
|
||||||
_ensure_device_approval_table(conn)
|
_ensure_device_approval_table(conn)
|
||||||
|
_ensure_device_list_views_table(conn)
|
||||||
|
_ensure_sites_tables(conn)
|
||||||
|
_ensure_credentials_table(conn)
|
||||||
_ensure_github_token_table(conn)
|
_ensure_github_token_table(conn)
|
||||||
_ensure_scheduled_jobs_table(conn)
|
_ensure_scheduled_jobs_table(conn)
|
||||||
_ensure_scheduled_job_run_tables(conn)
|
_ensure_scheduled_job_run_tables(conn)
|
||||||
@@ -233,6 +236,73 @@ def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_device_list_views_table(conn: sqlite3.Connection) -> None:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS device_list_views (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT UNIQUE NOT NULL,
|
||||||
|
columns_json TEXT NOT NULL,
|
||||||
|
filters_json TEXT,
|
||||||
|
created_at INTEGER,
|
||||||
|
updated_at INTEGER
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_sites_tables(conn: sqlite3.Connection) -> None:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS sites (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT UNIQUE NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
created_at INTEGER
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS device_sites (
|
||||||
|
device_hostname TEXT UNIQUE NOT NULL,
|
||||||
|
site_id INTEGER NOT NULL,
|
||||||
|
assigned_at INTEGER,
|
||||||
|
FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE CASCADE
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_credentials_table(conn: sqlite3.Connection) -> None:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS credentials (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL UNIQUE,
|
||||||
|
description TEXT,
|
||||||
|
site_id INTEGER,
|
||||||
|
credential_type TEXT NOT NULL DEFAULT 'machine',
|
||||||
|
connection_type TEXT NOT NULL DEFAULT 'ssh',
|
||||||
|
username TEXT,
|
||||||
|
password_encrypted BLOB,
|
||||||
|
private_key_encrypted BLOB,
|
||||||
|
private_key_passphrase_encrypted BLOB,
|
||||||
|
become_method TEXT,
|
||||||
|
become_username TEXT,
|
||||||
|
become_password_encrypted BLOB,
|
||||||
|
metadata_json TEXT,
|
||||||
|
created_at INTEGER NOT NULL,
|
||||||
|
updated_at INTEGER NOT NULL,
|
||||||
|
FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE SET NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _ensure_github_token_table(conn: sqlite3.Connection) -> None:
|
def _ensure_github_token_table(conn: sqlite3.Connection) -> None:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute(
|
cur.execute(
|
||||||
|
|||||||
@@ -71,6 +71,57 @@ class SQLiteUserRepository:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
def resolve_identifier(self, username: str) -> Optional[str]:
|
||||||
|
normalized = (username or "").strip()
|
||||||
|
if not normalized:
|
||||||
|
return None
|
||||||
|
|
||||||
|
conn = self._connection_factory()
|
||||||
|
try:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"SELECT id FROM users WHERE LOWER(username) = LOWER(?)",
|
||||||
|
(normalized,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
return str(row[0]) if row[0] is not None else None
|
||||||
|
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||||
|
self._log.error("failed to resolve identifier for %s: %s", username, exc)
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def username_for_identifier(self, identifier: str) -> Optional[str]:
|
||||||
|
token = (identifier or "").strip()
|
||||||
|
if not token:
|
||||||
|
return None
|
||||||
|
|
||||||
|
conn = self._connection_factory()
|
||||||
|
try:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT username
|
||||||
|
FROM users
|
||||||
|
WHERE CAST(id AS TEXT) = ?
|
||||||
|
OR LOWER(username) = LOWER(?)
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
(token, token),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
username = str(row[0] or "").strip()
|
||||||
|
return username or None
|
||||||
|
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||||
|
self._log.error("failed to resolve username for %s: %s", identifier, exc)
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
def list_accounts(self) -> list[OperatorAccount]:
|
def list_accounts(self) -> list[OperatorAccount]:
|
||||||
conn = self._connection_factory()
|
conn = self._connection_factory()
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ __all__ = [
|
|||||||
"SchedulerService",
|
"SchedulerService",
|
||||||
"GitHubService",
|
"GitHubService",
|
||||||
"GitHubTokenPayload",
|
"GitHubTokenPayload",
|
||||||
|
"EnrollmentAdminService",
|
||||||
]
|
]
|
||||||
|
|
||||||
_LAZY_TARGETS: Dict[str, Tuple[str, str]] = {
|
_LAZY_TARGETS: Dict[str, Tuple[str, str]] = {
|
||||||
@@ -43,6 +44,10 @@ _LAZY_TARGETS: Dict[str, Tuple[str, str]] = {
|
|||||||
"SchedulerService": ("Data.Engine.services.jobs.scheduler_service", "SchedulerService"),
|
"SchedulerService": ("Data.Engine.services.jobs.scheduler_service", "SchedulerService"),
|
||||||
"GitHubService": ("Data.Engine.services.github.github_service", "GitHubService"),
|
"GitHubService": ("Data.Engine.services.github.github_service", "GitHubService"),
|
||||||
"GitHubTokenPayload": ("Data.Engine.services.github.github_service", "GitHubTokenPayload"),
|
"GitHubTokenPayload": ("Data.Engine.services.github.github_service", "GitHubTokenPayload"),
|
||||||
|
"EnrollmentAdminService": (
|
||||||
|
"Data.Engine.services.enrollment.admin_service",
|
||||||
|
"EnrollmentAdminService",
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from Data.Engine.services.auth import (
|
|||||||
)
|
)
|
||||||
from Data.Engine.services.crypto.signing import ScriptSigner, load_signer
|
from Data.Engine.services.crypto.signing import ScriptSigner, load_signer
|
||||||
from Data.Engine.services.enrollment import EnrollmentService
|
from Data.Engine.services.enrollment import EnrollmentService
|
||||||
|
from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService
|
||||||
from Data.Engine.services.enrollment.nonce_cache import NonceCache
|
from Data.Engine.services.enrollment.nonce_cache import NonceCache
|
||||||
from Data.Engine.services.github import GitHubService
|
from Data.Engine.services.github import GitHubService
|
||||||
from Data.Engine.services.jobs import SchedulerService
|
from Data.Engine.services.jobs import SchedulerService
|
||||||
@@ -44,6 +45,7 @@ class EngineServiceContainer:
|
|||||||
device_auth: DeviceAuthService
|
device_auth: DeviceAuthService
|
||||||
token_service: TokenService
|
token_service: TokenService
|
||||||
enrollment_service: EnrollmentService
|
enrollment_service: EnrollmentService
|
||||||
|
enrollment_admin_service: EnrollmentAdminService
|
||||||
jwt_service: JWTService
|
jwt_service: JWTService
|
||||||
dpop_validator: DPoPValidator
|
dpop_validator: DPoPValidator
|
||||||
agent_realtime: AgentRealtimeService
|
agent_realtime: AgentRealtimeService
|
||||||
@@ -93,6 +95,12 @@ def build_service_container(
|
|||||||
logger=log.getChild("enrollment"),
|
logger=log.getChild("enrollment"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
enrollment_admin_service = EnrollmentAdminService(
|
||||||
|
repository=enrollment_repo,
|
||||||
|
user_repository=user_repo,
|
||||||
|
logger=log.getChild("enrollment_admin"),
|
||||||
|
)
|
||||||
|
|
||||||
device_auth = DeviceAuthService(
|
device_auth = DeviceAuthService(
|
||||||
device_repository=device_repo,
|
device_repository=device_repo,
|
||||||
jwt_service=jwt_service,
|
jwt_service=jwt_service,
|
||||||
@@ -139,6 +147,7 @@ def build_service_container(
|
|||||||
device_auth=device_auth,
|
device_auth=device_auth,
|
||||||
token_service=token_service,
|
token_service=token_service,
|
||||||
enrollment_service=enrollment_service,
|
enrollment_service=enrollment_service,
|
||||||
|
enrollment_admin_service=enrollment_admin_service,
|
||||||
jwt_service=jwt_service,
|
jwt_service=jwt_service,
|
||||||
dpop_validator=dpop_validator,
|
dpop_validator=dpop_validator,
|
||||||
agent_realtime=agent_realtime,
|
agent_realtime=agent_realtime,
|
||||||
|
|||||||
@@ -2,20 +2,54 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .enrollment_service import (
|
from importlib import import_module
|
||||||
EnrollmentRequestResult,
|
from typing import Any
|
||||||
EnrollmentService,
|
|
||||||
EnrollmentStatus,
|
|
||||||
EnrollmentTokenBundle,
|
|
||||||
PollingResult,
|
|
||||||
)
|
|
||||||
from Data.Engine.domain.device_enrollment import EnrollmentValidationError
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EnrollmentRequestResult",
|
|
||||||
"EnrollmentService",
|
"EnrollmentService",
|
||||||
|
"EnrollmentRequestResult",
|
||||||
"EnrollmentStatus",
|
"EnrollmentStatus",
|
||||||
"EnrollmentTokenBundle",
|
"EnrollmentTokenBundle",
|
||||||
"EnrollmentValidationError",
|
|
||||||
"PollingResult",
|
"PollingResult",
|
||||||
|
"EnrollmentValidationError",
|
||||||
|
"EnrollmentAdminService",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
_LAZY: dict[str, tuple[str, str]] = {
|
||||||
|
"EnrollmentService": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentService"),
|
||||||
|
"EnrollmentRequestResult": (
|
||||||
|
"Data.Engine.services.enrollment.enrollment_service",
|
||||||
|
"EnrollmentRequestResult",
|
||||||
|
),
|
||||||
|
"EnrollmentStatus": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentStatus"),
|
||||||
|
"EnrollmentTokenBundle": (
|
||||||
|
"Data.Engine.services.enrollment.enrollment_service",
|
||||||
|
"EnrollmentTokenBundle",
|
||||||
|
),
|
||||||
|
"PollingResult": ("Data.Engine.services.enrollment.enrollment_service", "PollingResult"),
|
||||||
|
"EnrollmentValidationError": (
|
||||||
|
"Data.Engine.domain.device_enrollment",
|
||||||
|
"EnrollmentValidationError",
|
||||||
|
),
|
||||||
|
"EnrollmentAdminService": (
|
||||||
|
"Data.Engine.services.enrollment.admin_service",
|
||||||
|
"EnrollmentAdminService",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
|
||||||
|
try:
|
||||||
|
module_name, attribute = _LAZY[name]
|
||||||
|
except KeyError as exc: # pragma: no cover - defensive
|
||||||
|
raise AttributeError(name) from exc
|
||||||
|
|
||||||
|
module = import_module(module_name)
|
||||||
|
value = getattr(module, attribute)
|
||||||
|
globals()[name] = value
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def __dir__() -> list[str]: # pragma: no cover - interactive helper
|
||||||
|
return sorted(set(__all__))
|
||||||
|
|
||||||
|
|||||||
113
Data/Engine/services/enrollment/admin_service.py
Normal file
113
Data/Engine/services/enrollment/admin_service.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""Administrative helpers for enrollment workflows."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
|
from Data.Engine.domain.enrollment_admin import DeviceApprovalRecord, EnrollmentCodeRecord
|
||||||
|
from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository
|
||||||
|
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
|
||||||
|
|
||||||
|
__all__ = ["EnrollmentAdminService"]
|
||||||
|
|
||||||
|
|
||||||
|
class EnrollmentAdminService:
|
||||||
|
"""Expose administrative enrollment operations."""
|
||||||
|
|
||||||
|
_VALID_TTL_HOURS = {1, 3, 6, 12, 24}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
repository: SQLiteEnrollmentRepository,
|
||||||
|
user_repository: SQLiteUserRepository,
|
||||||
|
logger: Optional[logging.Logger] = None,
|
||||||
|
clock: Optional[Callable[[], datetime]] = None,
|
||||||
|
) -> None:
|
||||||
|
self._repository = repository
|
||||||
|
self._users = user_repository
|
||||||
|
self._log = logger or logging.getLogger("borealis.engine.services.enrollment_admin")
|
||||||
|
self._clock = clock or (lambda: datetime.now(tz=timezone.utc))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Enrollment install codes
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def list_install_codes(self, *, status: Optional[str] = None) -> List[EnrollmentCodeRecord]:
|
||||||
|
return self._repository.list_install_codes(status=status, now=self._clock())
|
||||||
|
|
||||||
|
def create_install_code(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ttl_hours: int,
|
||||||
|
max_uses: int,
|
||||||
|
created_by: Optional[str],
|
||||||
|
) -> EnrollmentCodeRecord:
|
||||||
|
if ttl_hours not in self._VALID_TTL_HOURS:
|
||||||
|
raise ValueError("invalid_ttl")
|
||||||
|
|
||||||
|
normalized_max = self._normalize_max_uses(max_uses)
|
||||||
|
|
||||||
|
now = self._clock()
|
||||||
|
expires_at = now + timedelta(hours=ttl_hours)
|
||||||
|
record_id = str(uuid.uuid4())
|
||||||
|
code = self._generate_install_code()
|
||||||
|
|
||||||
|
created_by_identifier = None
|
||||||
|
if created_by:
|
||||||
|
created_by_identifier = self._users.resolve_identifier(created_by)
|
||||||
|
if not created_by_identifier:
|
||||||
|
created_by_identifier = created_by.strip() or None
|
||||||
|
|
||||||
|
record = self._repository.insert_install_code(
|
||||||
|
record_id=record_id,
|
||||||
|
code=code,
|
||||||
|
expires_at=expires_at,
|
||||||
|
created_by=created_by_identifier,
|
||||||
|
max_uses=normalized_max,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._log.info(
|
||||||
|
"install code created id=%s ttl=%sh max_uses=%s",
|
||||||
|
record.record_id,
|
||||||
|
ttl_hours,
|
||||||
|
normalized_max,
|
||||||
|
)
|
||||||
|
|
||||||
|
return record
|
||||||
|
|
||||||
|
def delete_install_code(self, record_id: str) -> bool:
|
||||||
|
deleted = self._repository.delete_install_code_if_unused(record_id)
|
||||||
|
if deleted:
|
||||||
|
self._log.info("install code deleted id=%s", record_id)
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Device approvals
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def list_device_approvals(self, *, status: Optional[str] = None) -> List[DeviceApprovalRecord]:
|
||||||
|
return self._repository.list_device_approvals(status=status)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
@staticmethod
|
||||||
|
def _generate_install_code() -> str:
|
||||||
|
raw = secrets.token_hex(16).upper()
|
||||||
|
return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_max_uses(value: int) -> int:
|
||||||
|
try:
|
||||||
|
count = int(value)
|
||||||
|
except Exception:
|
||||||
|
count = 2
|
||||||
|
if count < 1:
|
||||||
|
return 1
|
||||||
|
if count > 10:
|
||||||
|
return 10
|
||||||
|
return count
|
||||||
|
|
||||||
122
Data/Engine/tests/test_enrollment_admin_service.py
Normal file
122
Data/Engine/tests/test_enrollment_admin_service.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
import base64
|
||||||
|
import sqlite3
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from Data.Engine.repositories.sqlite import connection as sqlite_connection
|
||||||
|
from Data.Engine.repositories.sqlite import migrations as sqlite_migrations
|
||||||
|
from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository
|
||||||
|
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
|
||||||
|
from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService
|
||||||
|
|
||||||
|
|
||||||
|
def _build_service(tmp_path):
|
||||||
|
db_path = tmp_path / "admin.db"
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
sqlite_migrations.apply_all(conn)
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
factory = sqlite_connection.connection_factory(db_path)
|
||||||
|
enrollment_repo = SQLiteEnrollmentRepository(factory)
|
||||||
|
user_repo = SQLiteUserRepository(factory)
|
||||||
|
|
||||||
|
fixed_now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||||
|
service = EnrollmentAdminService(
|
||||||
|
repository=enrollment_repo,
|
||||||
|
user_repository=user_repo,
|
||||||
|
clock=lambda: fixed_now,
|
||||||
|
)
|
||||||
|
return service, factory, fixed_now
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_and_list_install_codes(tmp_path):
|
||||||
|
service, factory, fixed_now = _build_service(tmp_path)
|
||||||
|
|
||||||
|
record = service.create_install_code(ttl_hours=3, max_uses=5, created_by="admin")
|
||||||
|
assert record.code
|
||||||
|
assert record.max_uses == 5
|
||||||
|
assert record.status(now=fixed_now) == "active"
|
||||||
|
|
||||||
|
records = service.list_install_codes()
|
||||||
|
assert any(r.record_id == record.record_id for r in records)
|
||||||
|
|
||||||
|
# Invalid TTL should raise
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
service.create_install_code(ttl_hours=2, max_uses=1, created_by=None)
|
||||||
|
|
||||||
|
# Deleting should succeed and remove the record
|
||||||
|
assert service.delete_install_code(record.record_id) is True
|
||||||
|
remaining = service.list_install_codes()
|
||||||
|
assert all(r.record_id != record.record_id for r in remaining)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_device_approvals_includes_conflict(tmp_path):
|
||||||
|
service, factory, fixed_now = _build_service(tmp_path)
|
||||||
|
|
||||||
|
conn = factory()
|
||||||
|
cur = conn.cursor()
|
||||||
|
|
||||||
|
cur.execute(
|
||||||
|
"INSERT INTO sites (name, description, created_at) VALUES (?, ?, ?)",
|
||||||
|
("HQ", "Primary site", int(fixed_now.timestamp())),
|
||||||
|
)
|
||||||
|
site_id = cur.lastrowid
|
||||||
|
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO devices (guid, hostname, created_at, last_seen, ssl_key_fingerprint, status)
|
||||||
|
VALUES (?, ?, ?, ?, ?, 'active')
|
||||||
|
""",
|
||||||
|
("11111111-1111-1111-1111-111111111111", "agent-one", int(fixed_now.timestamp()), int(fixed_now.timestamp()), "abc123",),
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
"INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)",
|
||||||
|
("agent-one", site_id, int(fixed_now.timestamp())),
|
||||||
|
)
|
||||||
|
|
||||||
|
now_iso = fixed_now.isoformat()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO device_approvals (
|
||||||
|
id,
|
||||||
|
approval_reference,
|
||||||
|
guid,
|
||||||
|
hostname_claimed,
|
||||||
|
ssl_key_fingerprint_claimed,
|
||||||
|
enrollment_code_id,
|
||||||
|
status,
|
||||||
|
client_nonce,
|
||||||
|
server_nonce,
|
||||||
|
created_at,
|
||||||
|
updated_at,
|
||||||
|
approved_by_user_id,
|
||||||
|
agent_pubkey_der
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
"approval-1",
|
||||||
|
"REF123",
|
||||||
|
None,
|
||||||
|
"agent-one",
|
||||||
|
"abc123",
|
||||||
|
"code-1",
|
||||||
|
"pending",
|
||||||
|
base64.b64encode(b"client").decode(),
|
||||||
|
base64.b64encode(b"server").decode(),
|
||||||
|
now_iso,
|
||||||
|
now_iso,
|
||||||
|
None,
|
||||||
|
b"pubkey",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
approvals = service.list_device_approvals()
|
||||||
|
assert len(approvals) == 1
|
||||||
|
record = approvals[0]
|
||||||
|
assert record.hostname_conflict is not None
|
||||||
|
assert record.hostname_conflict.fingerprint_match is True
|
||||||
|
assert record.conflict_requires_prompt is False
|
||||||
|
|
||||||
111
Data/Engine/tests/test_http_admin.py
Normal file
111
Data/Engine/tests/test_http_admin.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
import base64
|
||||||
|
import sqlite3
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from .test_http_auth import _login, prepared_app
|
||||||
|
|
||||||
|
|
||||||
|
def test_enrollment_codes_require_authentication(prepared_app):
|
||||||
|
client = prepared_app.test_client()
|
||||||
|
resp = client.get("/api/admin/enrollment-codes")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_enrollment_code_workflow(prepared_app):
|
||||||
|
client = prepared_app.test_client()
|
||||||
|
_login(client)
|
||||||
|
|
||||||
|
payload = {"ttl_hours": 3, "max_uses": 4}
|
||||||
|
resp = client.post("/api/admin/enrollment-codes", json=payload)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
created = resp.get_json()
|
||||||
|
assert created["max_uses"] == 4
|
||||||
|
assert created["status"] == "active"
|
||||||
|
|
||||||
|
resp = client.get("/api/admin/enrollment-codes")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
codes = resp.get_json().get("codes", [])
|
||||||
|
assert any(code["id"] == created["id"] for code in codes)
|
||||||
|
|
||||||
|
resp = client.delete(f"/api/admin/enrollment-codes/{created['id']}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_device_approvals_listing(prepared_app, engine_settings):
|
||||||
|
client = prepared_app.test_client()
|
||||||
|
_login(client)
|
||||||
|
|
||||||
|
conn = sqlite3.connect(engine_settings.database.path)
|
||||||
|
cur = conn.cursor()
|
||||||
|
|
||||||
|
now = datetime.now(tz=timezone.utc)
|
||||||
|
cur.execute(
|
||||||
|
"INSERT INTO sites (name, description, created_at) VALUES (?, ?, ?)",
|
||||||
|
("HQ", "Primary", int(now.timestamp())),
|
||||||
|
)
|
||||||
|
site_id = cur.lastrowid
|
||||||
|
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO devices (guid, hostname, created_at, last_seen, ssl_key_fingerprint, status)
|
||||||
|
VALUES (?, ?, ?, ?, ?, 'active')
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
"22222222-2222-2222-2222-222222222222",
|
||||||
|
"approval-host",
|
||||||
|
int(now.timestamp()),
|
||||||
|
int(now.timestamp()),
|
||||||
|
"deadbeef",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
"INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)",
|
||||||
|
("approval-host", site_id, int(now.timestamp())),
|
||||||
|
)
|
||||||
|
|
||||||
|
now_iso = now.isoformat()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO device_approvals (
|
||||||
|
id,
|
||||||
|
approval_reference,
|
||||||
|
guid,
|
||||||
|
hostname_claimed,
|
||||||
|
ssl_key_fingerprint_claimed,
|
||||||
|
enrollment_code_id,
|
||||||
|
status,
|
||||||
|
client_nonce,
|
||||||
|
server_nonce,
|
||||||
|
created_at,
|
||||||
|
updated_at,
|
||||||
|
approved_by_user_id,
|
||||||
|
agent_pubkey_der
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
"approval-http",
|
||||||
|
"REFHTTP",
|
||||||
|
None,
|
||||||
|
"approval-host",
|
||||||
|
"deadbeef",
|
||||||
|
"code-http",
|
||||||
|
"pending",
|
||||||
|
base64.b64encode(b"client").decode(),
|
||||||
|
base64.b64encode(b"server").decode(),
|
||||||
|
now_iso,
|
||||||
|
now_iso,
|
||||||
|
None,
|
||||||
|
b"pub",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
resp = client.get("/api/admin/device-approvals")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.get_json()
|
||||||
|
approvals = body.get("approvals", [])
|
||||||
|
assert any(a["id"] == "approval-http" for a in approvals)
|
||||||
|
record = next(a for a in approvals if a["id"] == "approval-http")
|
||||||
|
assert record.get("hostname_conflict", {}).get("fingerprint_match") is True
|
||||||
|
|
||||||
Reference in New Issue
Block a user