Files
Borealis-Github-Replica/Data/Engine/services/enrollment/enrollment_service.py

488 lines
16 KiB
Python

"""Enrollment workflow orchestration for the Borealis Engine."""
from __future__ import annotations
import base64
import hashlib
import logging
import secrets
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Callable, Optional, Protocol
from cryptography.hazmat.primitives import serialization
from Data.Engine.builders.device_enrollment import EnrollmentRequestInput
from Data.Engine.domain.device_auth import (
DeviceFingerprint,
DeviceGuid,
sanitize_service_context,
)
from Data.Engine.domain.device_enrollment import (
EnrollmentApproval,
EnrollmentApprovalStatus,
EnrollmentCode,
EnrollmentValidationError,
)
from Data.Engine.services.auth.device_auth_service import DeviceRecord
from Data.Engine.services.auth.token_service import JWTIssuer
from Data.Engine.services.enrollment.nonce_cache import NonceCache
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
# Public API of this module. The repository/signer Protocols below are
# deliberately not exported here; importers reference them explicitly.
__all__ = [
"EnrollmentRequestResult",
"EnrollmentService",
"EnrollmentStatus",
"EnrollmentTokenBundle",
"PollingResult",
]
class DeviceRepository(Protocol):
    """Device persistence operations required by :class:`EnrollmentService`."""

    def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]:  # pragma: no cover - protocol
        """Return the device stored under *guid*, or ``None`` if there is none."""

    def ensure_device_record(
        self,
        *,
        guid: DeviceGuid,
        hostname: str,
        fingerprint: DeviceFingerprint,
    ) -> DeviceRecord:  # pragma: no cover - protocol
        """Return a device record for *guid* with the given hostname/fingerprint.

        NOTE(review): the name suggests create-if-missing semantics; confirm
        whether an existing row is also updated.
        """

    def record_device_key(
        self,
        *,
        guid: DeviceGuid,
        fingerprint: DeviceFingerprint,
        added_at: datetime,
    ) -> None:  # pragma: no cover - protocol
        """Record *fingerprint* as a key of device *guid*, stamped *added_at*."""
class EnrollmentRepository(Protocol):
    """Install-code and approval persistence required by :class:`EnrollmentService`."""

    def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]:  # pragma: no cover - protocol
        """Look up an install code by its value; ``None`` when unknown."""

    def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]:  # pragma: no cover - protocol
        """Look up an install code by its record id; ``None`` when unknown."""

    def update_install_code_usage(
        self,
        record_id: str,
        *,
        use_count_increment: int,
        last_used_at: datetime,
        used_by_guid: Optional[DeviceGuid] = None,
        mark_first_use: bool = False,
    ) -> None:  # pragma: no cover - protocol
        """Apply a usage update (count increment, timestamps, GUID) to *record_id*."""

    def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:  # pragma: no cover - protocol
        """Look up an approval by its reference string; ``None`` when unknown."""

    def fetch_pending_approval_by_fingerprint(
        self, fingerprint: DeviceFingerprint
    ) -> Optional[EnrollmentApproval]:  # pragma: no cover - protocol
        """Return the pending approval claimed by *fingerprint*, if any."""

    def update_pending_approval(
        self,
        record_id: str,
        *,
        hostname: str,
        guid: Optional[DeviceGuid],
        enrollment_code_id: Optional[str],
        client_nonce_b64: str,
        server_nonce_b64: str,
        agent_pubkey_der: bytes,
        updated_at: datetime,
    ) -> None:  # pragma: no cover - protocol
        """Overwrite the request details stored on pending approval *record_id*."""

    def create_device_approval(
        self,
        *,
        record_id: str,
        reference: str,
        claimed_hostname: str,
        claimed_fingerprint: DeviceFingerprint,
        enrollment_code_id: Optional[str],
        client_nonce_b64: str,
        server_nonce_b64: str,
        agent_pubkey_der: bytes,
        created_at: datetime,
        status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
        guid: Optional[DeviceGuid] = None,
    ) -> EnrollmentApproval:  # pragma: no cover - protocol
        """Create a new approval row and return it."""

    def update_device_approval_status(
        self,
        record_id: str,
        *,
        status: EnrollmentApprovalStatus,
        updated_at: datetime,
        approved_by: Optional[str] = None,
        guid: Optional[DeviceGuid] = None,
    ) -> None:  # pragma: no cover - protocol
        """Set the status (and optionally approver/GUID) of approval *record_id*."""
class RefreshTokenRepository(Protocol):
    """Refresh-token persistence required by :class:`EnrollmentService`."""

    def create(
        self,
        *,
        record_id: str,
        guid: DeviceGuid,
        token_hash: str,
        created_at: datetime,
        expires_at: Optional[datetime],
    ) -> None:  # pragma: no cover - protocol
        """Store a refresh token (by hash only) for device *guid*."""
class ScriptSigner(Protocol):
    """Source of the public key agents receive during enrollment."""

    def public_base64_spki(self) -> str:  # pragma: no cover - protocol
        """Return the signer's public key as base64-encoded SPKI."""
class EnrollmentStatus(str):
    """String constants for the ``status`` reported to enrolling agents.

    The attributes are plain ``str`` values, not enum members, so they
    serialize and format exactly as the literal strings.
    """

    # Request recorded; awaiting a decision.
    PENDING: str = "pending"
    # Approved (tokens issued on this poll or an earlier one).
    APPROVED: str = "approved"
    # Terminal outcomes.
    DENIED: str = "denied"
    EXPIRED: str = "expired"
    # Unknown reference or unrecognized approval status.
    UNKNOWN: str = "unknown"
@dataclass(frozen=True, slots=True)
class EnrollmentTokenBundle:
guid: DeviceGuid
access_token: str
refresh_token: str
expires_in: int
token_type: str = "Bearer"
@dataclass(frozen=True, slots=True)
class EnrollmentRequestResult:
status: EnrollmentStatus
server_certificate: str
signing_key: str
approval_reference: Optional[str] = None
server_nonce: Optional[str] = None
poll_after_ms: Optional[int] = None
http_status: int = 200
retry_after: Optional[float] = None
@dataclass(frozen=True, slots=True)
class PollingResult:
status: EnrollmentStatus
http_status: int
poll_after_ms: Optional[int] = None
reason: Optional[str] = None
detail: Optional[str] = None
tokens: Optional[EnrollmentTokenBundle] = None
server_certificate: Optional[str] = None
signing_key: Optional[str] = None
class EnrollmentService:
    """Coordinate the Borealis device enrollment handshake.

    The handshake has two steps:

    1. :meth:`request_enrollment` validates the agent's install code. It then
       stores a pending approval for the agent's fingerprint, or refreshes the
       existing one. It returns a fresh server nonce together with the approval
       reference the agent polls with.
    2. :meth:`poll_enrollment` verifies the agent's proof signature over
       ``server_nonce + reference + client_nonce`` and reports the approval
       status. The first valid poll after the approval has been marked
       ``APPROVED`` (outside this service) finalizes it. Finalizing creates the
       device record, device key, refresh token and access token, and marks
       the approval ``COMPLETED``.
    """

    # Poll-delay hint (ms) returned with a freshly submitted request.
    _REQUEST_POLL_AFTER_MS = 3000
    # Poll-delay hint (ms) returned while the approval is still pending.
    _PENDING_POLL_AFTER_MS = 5000
    # Lifetime of the refresh token minted at finalization.
    _REFRESH_TOKEN_TTL = timedelta(days=30)
    # ``expires_in`` (seconds) reported to the agent for the access token.
    # NOTE(review): should match the TTL the JWTIssuer encodes in the token — confirm.
    _ACCESS_TOKEN_EXPIRES_IN = 900

    def __init__(
        self,
        *,
        device_repository: DeviceRepository,
        enrollment_repository: EnrollmentRepository,
        token_repository: RefreshTokenRepository,
        jwt_service: JWTIssuer,
        tls_bundle_loader: Callable[[], str],
        ip_rate_limiter: SlidingWindowRateLimiter,
        fingerprint_rate_limiter: SlidingWindowRateLimiter,
        nonce_cache: NonceCache,
        script_signer: Optional[ScriptSigner] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """Wire the service to its collaborators.

        Args:
            device_repository: Device record / device key persistence.
            enrollment_repository: Install-code and approval persistence.
            token_repository: Refresh-token persistence.
            jwt_service: Issues access tokens on finalization.
            tls_bundle_loader: Returns the TLS bundle text sent to agents; it
                is called on every request and finalizing poll.
            ip_rate_limiter: Limiter keyed by ``ip:<remote_addr>``.
            fingerprint_rate_limiter: Limiter keyed by ``fp:<fingerprint>``.
            nonce_cache: Replay guard for finalizing polls.
            script_signer: Optional source of the signing public key; when
                absent an empty string is sent instead.
            logger: Optional logger; defaults to ``borealis.engine.enrollment``.
        """
        self._devices = device_repository
        self._enrollment = enrollment_repository
        self._tokens = token_repository
        self._jwt = jwt_service
        self._load_tls_bundle = tls_bundle_loader
        self._ip_rate_limiter = ip_rate_limiter
        self._fp_rate_limiter = fingerprint_rate_limiter
        self._nonce_cache = nonce_cache
        self._signer = script_signer
        self._log = logger or logging.getLogger("borealis.engine.enrollment")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def request_enrollment(
        self,
        payload: EnrollmentRequestInput,
        *,
        remote_addr: str,
    ) -> EnrollmentRequestResult:
        """Record (or refresh) a pending approval for an enrolling agent.

        Args:
            payload: Parsed enrollment request from the agent.
            remote_addr: Client address, used for logging and per-IP rate limiting.

        Returns:
            A PENDING :class:`EnrollmentRequestResult`. It carries the approval
            reference, a freshly generated base64 server nonce, the TLS bundle,
            and the signing public key. The signing key is ``""`` when no
            signer is configured.

        Raises:
            EnrollmentValidationError: ``rate_limited`` (HTTP 429) when either
                limiter rejects the request. ``invalid_enrollment_code`` when
                the install code is unknown, expired, or exhausted by a device
                with a different fingerprint.
        """
        context_hint = sanitize_service_context(payload.service_context)
        self._log.info(
            "enrollment-request ip=%s host=%s code_mask=%s context=%s",
            remote_addr,
            payload.hostname,
            self._mask_code(payload.enrollment_code),
            context_hint,
        )
        self._enforce_rate_limit(self._ip_rate_limiter, f"ip:{remote_addr}")
        self._enforce_rate_limit(self._fp_rate_limiter, f"fp:{payload.fingerprint.value}")

        install_code = self._enrollment.fetch_install_code(payload.enrollment_code)
        # Raises for unusable codes; returns a GUID only when an exhausted code
        # is being reused by the device that already consumed it.
        reuse_guid = self._determine_reuse_guid(install_code, payload.fingerprint)
        code_id = install_code.identifier if install_code else None

        # Every request gets a fresh server nonce; the agent must sign it when polling.
        server_nonce_b64 = base64.b64encode(secrets.token_bytes(32)).decode("ascii")
        now = self._now()

        approval = self._enrollment.fetch_pending_approval_by_fingerprint(payload.fingerprint)
        if approval:
            # Same fingerprint re-submitting: refresh the existing pending
            # approval instead of creating a duplicate.
            self._enrollment.update_pending_approval(
                approval.record_id,
                hostname=payload.hostname,
                guid=reuse_guid,
                enrollment_code_id=code_id,
                client_nonce_b64=payload.client_nonce_b64,
                server_nonce_b64=server_nonce_b64,
                agent_pubkey_der=payload.agent_public_key_der,
                updated_at=now,
            )
        else:
            approval = self._enrollment.create_device_approval(
                record_id=str(uuid.uuid4()),
                reference=str(uuid.uuid4()),
                claimed_hostname=payload.hostname,
                claimed_fingerprint=payload.fingerprint,
                enrollment_code_id=code_id,
                client_nonce_b64=payload.client_nonce_b64,
                server_nonce_b64=server_nonce_b64,
                agent_pubkey_der=payload.agent_public_key_der,
                created_at=now,
                guid=reuse_guid,
            )

        signing_key = self._signer.public_base64_spki() if self._signer else ""
        certificate = self._load_tls_bundle()
        return EnrollmentRequestResult(
            status=EnrollmentStatus.PENDING,
            approval_reference=approval.reference,
            server_nonce=server_nonce_b64,
            poll_after_ms=self._REQUEST_POLL_AFTER_MS,
            server_certificate=certificate,
            signing_key=signing_key,
        )

    def poll_enrollment(
        self,
        *,
        approval_reference: str,
        client_nonce_b64: str,
        proof_signature_b64: str,
    ) -> PollingResult:
        """Verify an agent's proof of possession and report its approval status.

        Args:
            approval_reference: Reference returned by :meth:`request_enrollment`.
            client_nonce_b64: Base64 client nonce; must equal the nonce stored
                with the approval.
            proof_signature_b64: Base64 signature over
                ``server_nonce + approval_reference + client_nonce``, made with
                the agent key submitted at request time.

        Returns:
            A :class:`PollingResult`, which is one of:

            - UNKNOWN/404 for an unknown reference.
            - PENDING, DENIED or EXPIRED as recorded.
            - APPROVED with ``detail="finalized"`` for an already completed approval.
            - APPROVED with tokens, TLS bundle and signing key on the poll that
              finalizes an approved request.
            - UNKNOWN/400 for any other status.

        Raises:
            EnrollmentValidationError: For missing or malformed arguments, a
                nonce mismatch, an invalid agent key or proof signature, or a
                replayed proof (``proof_replayed``, HTTP 409).
        """
        if not approval_reference:
            raise EnrollmentValidationError("approval_reference_required")
        if not client_nonce_b64:
            raise EnrollmentValidationError("client_nonce_required")
        if not proof_signature_b64:
            raise EnrollmentValidationError("proof_sig_required")

        approval = self._enrollment.fetch_device_approval_by_reference(approval_reference)
        if approval is None:
            return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=404)

        client_nonce = self._decode_base64(client_nonce_b64, "invalid_client_nonce")
        server_nonce = self._decode_base64(approval.server_nonce_b64, "server_nonce_invalid")
        proof_sig = self._decode_base64(proof_signature_b64, "invalid_proof_sig")
        if approval.client_nonce_b64 != client_nonce_b64:
            raise EnrollmentValidationError("nonce_mismatch")
        # The proof is checked before any status is disclosed.
        self._verify_proof_signature(
            approval=approval,
            client_nonce=client_nonce,
            server_nonce=server_nonce,
            signature=proof_sig,
        )

        status = approval.status
        if status is EnrollmentApprovalStatus.PENDING:
            return PollingResult(
                status=EnrollmentStatus.PENDING,
                http_status=200,
                poll_after_ms=self._PENDING_POLL_AFTER_MS,
            )
        if status is EnrollmentApprovalStatus.DENIED:
            return PollingResult(
                status=EnrollmentStatus.DENIED,
                http_status=200,
                reason="operator_denied",
            )
        if status is EnrollmentApprovalStatus.EXPIRED:
            return PollingResult(status=EnrollmentStatus.EXPIRED, http_status=200)
        if status is EnrollmentApprovalStatus.COMPLETED:
            # Tokens were already issued by an earlier poll; they are not re-sent.
            return PollingResult(
                status=EnrollmentStatus.APPROVED,
                http_status=200,
                detail="finalized",
            )
        if status is not EnrollmentApprovalStatus.APPROVED:
            return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=400)

        # Replay guard so a given proof can finalize an approval only once.
        # Assumes NonceCache.consume() returns a falsy value for an
        # already-seen key.
        # NOTE(review): the signed message depends only on stored values. An
        # agent with a deterministic signature scheme (e.g. Ed25519) therefore
        # resends the same proof. If _finalize_approval fails after consume(),
        # retries get 409 until the cache entry lapses.
        nonce_key = f"{approval.reference}:{proof_signature_b64}"
        if not self._nonce_cache.consume(nonce_key):
            raise EnrollmentValidationError("proof_replayed", http_status=409)

        token_bundle = self._finalize_approval(approval)
        signing_key = self._signer.public_base64_spki() if self._signer else ""
        certificate = self._load_tls_bundle()
        return PollingResult(
            status=EnrollmentStatus.APPROVED,
            http_status=200,
            tokens=token_bundle,
            server_certificate=certificate,
            signing_key=signing_key,
        )

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------
    def _enforce_rate_limit(
        self,
        limiter: SlidingWindowRateLimiter,
        key: str,
        *,
        limit: int = 60,
        window_seconds: float = 60.0,
    ) -> None:
        """Raise ``rate_limited`` (HTTP 429) when *limiter* rejects *key*.

        The reported ``retry_after`` is never below one second.
        """
        decision = limiter.check(key, limit, window_seconds)
        if not decision.allowed:
            raise EnrollmentValidationError(
                "rate_limited", http_status=429, retry_after=max(decision.retry_after, 1.0)
            )

    def _determine_reuse_guid(
        self,
        install_code: Optional[EnrollmentCode],
        fingerprint: DeviceFingerprint,
    ) -> Optional[DeviceGuid]:
        """Validate *install_code* for a device presenting *fingerprint*.

        Returns:
            ``None`` when the code still has uses left. For an exhausted code,
            returns the GUID that consumed it, provided the stored device with
            that GUID has the same fingerprint (i.e. the same device is
            re-enrolling).

        Raises:
            EnrollmentValidationError: ``invalid_enrollment_code`` in every other case.
        """
        if install_code is None or install_code.is_expired:
            raise EnrollmentValidationError("invalid_enrollment_code")
        if not install_code.is_exhausted:
            return None
        if not install_code.used_by_guid:
            raise EnrollmentValidationError("invalid_enrollment_code")
        existing = self._devices.fetch_by_guid(install_code.used_by_guid)
        if existing and existing.identity.fingerprint.value == fingerprint.value:
            return install_code.used_by_guid
        raise EnrollmentValidationError("invalid_enrollment_code")

    def _finalize_approval(self, approval: EnrollmentApproval) -> EnrollmentTokenBundle:
        """Turn an approved request into a device record and issued tokens.

        Steps, in order:

        1. Ensure the device record exists.
        2. Record the device key.
        3. Bump usage on the install code, if there is one.
        4. Store the hashed refresh token.
        5. Mint the access token.
        6. Mark the approval COMPLETED.

        A new GUID is generated when the approval carries none.
        """
        now = self._now()
        effective_guid = approval.guid or DeviceGuid(str(uuid.uuid4()))
        device_record = self._devices.ensure_device_record(
            guid=effective_guid,
            hostname=approval.claimed_hostname,
            fingerprint=approval.claimed_fingerprint,
        )
        self._devices.record_device_key(
            guid=effective_guid,
            fingerprint=approval.claimed_fingerprint,
            added_at=now,
        )

        if approval.enrollment_code_id:
            code = self._enrollment.fetch_install_code_by_id(approval.enrollment_code_id)
            if code is not None:
                self._enrollment.update_install_code_usage(
                    approval.enrollment_code_id,
                    use_count_increment=1,
                    last_used_at=now,
                    used_by_guid=effective_guid,
                    mark_first_use=code.used_at is None,
                )

        # Only the SHA-256 hash of the refresh token is persisted; the raw
        # token is returned to the agent exactly once.
        refresh_token = secrets.token_urlsafe(48)
        self._tokens.create(
            record_id=str(uuid.uuid4()),
            guid=effective_guid,
            token_hash=hashlib.sha256(refresh_token.encode("utf-8")).hexdigest(),
            created_at=now,
            expires_at=now + self._REFRESH_TOKEN_TTL,
        )
        access_token = self._jwt.issue_access_token(
            effective_guid.value,
            device_record.identity.fingerprint.value,
            max(device_record.token_version, 1),
        )
        self._enrollment.update_device_approval_status(
            approval.record_id,
            status=EnrollmentApprovalStatus.COMPLETED,
            updated_at=now,
            guid=effective_guid,
        )
        return EnrollmentTokenBundle(
            guid=effective_guid,
            access_token=access_token,
            refresh_token=refresh_token,
            expires_in=self._ACCESS_TOKEN_EXPIRES_IN,
        )

    def _verify_proof_signature(
        self,
        *,
        approval: EnrollmentApproval,
        client_nonce: bytes,
        server_nonce: bytes,
        signature: bytes,
    ) -> None:
        """Check *signature* over ``server_nonce + reference + client_nonce``.

        Raises:
            EnrollmentValidationError: ``agent_pubkey_invalid`` when the stored
                DER key cannot be loaded. ``invalid_proof`` when verification
                fails for any reason.
        """
        message = server_nonce + approval.reference.encode("utf-8") + client_nonce
        try:
            public_key = serialization.load_der_public_key(approval.agent_pubkey_der)
        except Exception as exc:
            raise EnrollmentValidationError("agent_pubkey_invalid") from exc
        try:
            # NOTE(review): the two-argument verify() matches Ed25519/Ed448
            # keys. RSA/EC keys require padding/algorithm arguments and will
            # land in invalid_proof here.
            public_key.verify(signature, message)
        except Exception as exc:
            raise EnrollmentValidationError("invalid_proof") from exc

    @staticmethod
    def _decode_base64(value: str, error_code: str) -> bytes:
        """Strictly decode base64 *value*, raising ``error_code`` on bad input."""
        try:
            return base64.b64decode(value, validate=True)
        except (ValueError, TypeError) as exc:
            # binascii.Error subclasses ValueError. Non-ASCII text also raises
            # ValueError, and non-string/bytes input raises TypeError.
            raise EnrollmentValidationError(error_code) from exc

    @staticmethod
    def _mask_code(code: str) -> str:
        """Mask an install code for logging: ``abc***xyz``, or ``***`` if short.

        NOTE(review): for 7-8 character codes this still reveals most of the code.
        """
        trimmed = (code or "").strip()
        if len(trimmed) <= 6:
            return "***"
        return f"{trimmed[:3]}***{trimmed[-3:]}"

    @staticmethod
    def _now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        return datetime.now(tz=timezone.utc)