mirror of https://github.com/bunny-lab-io/Borealis.git (synced 2025-10-26 17:41:58 -06:00)
			
		
		
		
	
		
			
				
	
	
		
			488 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			488 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Enrollment workflow orchestration for the Borealis Engine."""
 | |
| 
 | |
| from __future__ import annotations
 | |
| 
 | |
| import base64
 | |
| import hashlib
 | |
| import logging
 | |
| import secrets
 | |
| import uuid
 | |
| from dataclasses import dataclass
 | |
| from datetime import datetime, timedelta, timezone
 | |
| from typing import Callable, Optional, Protocol
 | |
| 
 | |
| from cryptography.hazmat.primitives import serialization
 | |
| 
 | |
| from Data.Engine.builders.device_enrollment import EnrollmentRequestInput
 | |
| from Data.Engine.domain.device_auth import (
 | |
|     DeviceFingerprint,
 | |
|     DeviceGuid,
 | |
|     sanitize_service_context,
 | |
| )
 | |
| from Data.Engine.domain.device_enrollment import (
 | |
|     EnrollmentApproval,
 | |
|     EnrollmentApprovalStatus,
 | |
|     EnrollmentCode,
 | |
| )
 | |
| from Data.Engine.services.auth.device_auth_service import DeviceRecord
 | |
| from Data.Engine.services.auth.token_service import JWTIssuer
 | |
| from Data.Engine.services.enrollment.errors import EnrollmentValidationError
 | |
| from Data.Engine.services.enrollment.nonce_cache import NonceCache
 | |
| from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
 | |
| 
 | |
# Names exported by ``from ... import *``; the repository/signer protocols
# defined below are not listed here.
__all__ = [
    "EnrollmentRequestResult", "EnrollmentService", "EnrollmentStatus",
    "EnrollmentTokenBundle", "PollingResult",
]
 | |
| 
 | |
| 
 | |
class DeviceRepository(Protocol):
    """Structural contract for the device store used by enrollment."""

    def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]:  # pragma: no cover - protocol
        """Return the device record for ``guid``, or ``None`` if there is none."""
        ...

    def ensure_device_record(
        self, *, guid: DeviceGuid, hostname: str, fingerprint: DeviceFingerprint
    ) -> DeviceRecord:  # pragma: no cover - protocol
        """Return a device record for ``guid``, creating it when needed."""
        ...

    def record_device_key(
        self, *, guid: DeviceGuid, fingerprint: DeviceFingerprint, added_at: datetime
    ) -> None:  # pragma: no cover - protocol
        """Store ``fingerprint`` as a key of device ``guid``, stamped ``added_at``."""
        ...
 | |
| 
 | |
| 
 | |
class EnrollmentRepository(Protocol):
    """Structural contract for install-code and device-approval storage."""

    def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]:  # pragma: no cover - protocol
        """Look up an install code by the code string an agent presented."""
        ...

    def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]:  # pragma: no cover - protocol
        """Look up an install code by its record identifier."""
        ...

    def update_install_code_usage(
        self,
        record_id: str,
        *,
        use_count_increment: int,
        last_used_at: datetime,
        used_by_guid: Optional[DeviceGuid] = None,
        mark_first_use: bool = False,
    ) -> None:  # pragma: no cover - protocol
        """Bump an install code's usage counters after a finalized enrollment."""
        ...

    def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:  # pragma: no cover - protocol
        """Return the approval whose public reference is ``reference``, if any."""
        ...

    def fetch_pending_approval_by_fingerprint(
        self, fingerprint: DeviceFingerprint
    ) -> Optional[EnrollmentApproval]:  # pragma: no cover - protocol
        """Return the pending approval for ``fingerprint``, if one exists."""
        ...

    def update_pending_approval(
        self,
        record_id: str,
        *,
        hostname: str,
        guid: Optional[DeviceGuid],
        enrollment_code_id: Optional[str],
        client_nonce_b64: str,
        server_nonce_b64: str,
        agent_pubkey_der: bytes,
        updated_at: datetime,
    ) -> None:  # pragma: no cover - protocol
        """Overwrite a pending approval with the details of a repeated request."""
        ...

    def create_device_approval(
        self,
        *,
        record_id: str,
        reference: str,
        claimed_hostname: str,
        claimed_fingerprint: DeviceFingerprint,
        enrollment_code_id: Optional[str],
        client_nonce_b64: str,
        server_nonce_b64: str,
        agent_pubkey_der: bytes,
        created_at: datetime,
        status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
        guid: Optional[DeviceGuid] = None,
    ) -> EnrollmentApproval:  # pragma: no cover - protocol
        """Persist a new approval record and return it."""
        ...

    def update_device_approval_status(
        self,
        record_id: str,
        *,
        status: EnrollmentApprovalStatus,
        updated_at: datetime,
        approved_by: Optional[str] = None,
        guid: Optional[DeviceGuid] = None,
    ) -> None:  # pragma: no cover - protocol
        """Move an approval record to ``status``."""
        ...
 | |
| 
 | |
| 
 | |
class RefreshTokenRepository(Protocol):
    """Structural contract for persisting refresh tokens (by hash only)."""

    def create(
        self,
        *,
        record_id: str,
        guid: DeviceGuid,
        token_hash: str,
        created_at: datetime,
        expires_at: Optional[datetime],
    ) -> None:  # pragma: no cover - protocol
        """Store a refresh-token hash for device ``guid``."""
        ...
 | |
| 
 | |
| 
 | |
class ScriptSigner(Protocol):
    """Anything able to publish its public key as base64-encoded SPKI."""

    def public_base64_spki(self) -> str:  # pragma: no cover - protocol
        """Return the signer's public key as a base64 SPKI string."""
        ...
 | |
| 
 | |
| 
 | |
class EnrollmentStatus(str):
    """Status strings returned by :class:`EnrollmentService`.

    The attributes are plain ``str`` values (not instances of this class),
    so they compare equal to the bare strings.
    """

    PENDING = "pending"
    APPROVED = "approved"
    DENIED = "denied"
    EXPIRED = "expired"
    UNKNOWN = "unknown"
 | |
| 
 | |
| 
 | |
| @dataclass(frozen=True, slots=True)
 | |
| class EnrollmentTokenBundle:
 | |
|     guid: DeviceGuid
 | |
|     access_token: str
 | |
|     refresh_token: str
 | |
|     expires_in: int
 | |
|     token_type: str = "Bearer"
 | |
| 
 | |
| 
 | |
| @dataclass(frozen=True, slots=True)
 | |
| class EnrollmentRequestResult:
 | |
|     status: EnrollmentStatus
 | |
|     approval_reference: Optional[str] = None
 | |
|     server_nonce: Optional[str] = None
 | |
|     poll_after_ms: Optional[int] = None
 | |
|     server_certificate: str
 | |
|     signing_key: str
 | |
|     http_status: int = 200
 | |
|     retry_after: Optional[float] = None
 | |
| 
 | |
| 
 | |
| @dataclass(frozen=True, slots=True)
 | |
| class PollingResult:
 | |
|     status: EnrollmentStatus
 | |
|     http_status: int
 | |
|     poll_after_ms: Optional[int] = None
 | |
|     reason: Optional[str] = None
 | |
|     detail: Optional[str] = None
 | |
|     tokens: Optional[EnrollmentTokenBundle] = None
 | |
|     server_certificate: Optional[str] = None
 | |
|     signing_key: Optional[str] = None
 | |
| 
 | |
| 
 | |
class EnrollmentService:
    """Coordinate the Borealis device enrollment handshake.

    The handshake has two phases:

    1. :meth:`request_enrollment` validates the agent's install code, mints a
       fresh server nonce, and creates (or refreshes) the pending approval
       record for the agent's fingerprint.
    2. :meth:`poll_enrollment` checks the agent's proof signature over
       ``server_nonce + approval_reference + client_nonce`` and reports the
       approval state. On the first valid poll after approval, it finalizes
       the device and issues access/refresh tokens.
    """

    # Poll delay (ms) suggested right after an enrollment request.
    REQUEST_POLL_AFTER_MS = 3000
    # Poll delay (ms) suggested while the approval is still pending.
    PENDING_POLL_AFTER_MS = 5000
    # Lifetime of the refresh token minted when an approval is finalized.
    REFRESH_TOKEN_TTL = timedelta(days=30)
    # ``expires_in`` reported with the access token (presumably seconds, as in
    # OAuth). NOTE(review): must agree with the TTL JWTIssuer applies — confirm.
    ACCESS_TOKEN_EXPIRES_IN = 900

    def __init__(
        self,
        *,
        device_repository: DeviceRepository,
        enrollment_repository: EnrollmentRepository,
        token_repository: RefreshTokenRepository,
        jwt_service: JWTIssuer,
        tls_bundle_loader: Callable[[], str],
        ip_rate_limiter: SlidingWindowRateLimiter,
        fingerprint_rate_limiter: SlidingWindowRateLimiter,
        nonce_cache: NonceCache,
        script_signer: Optional[ScriptSigner] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """Wire the service to its collaborators.

        Args:
            device_repository: Lookup/creation of device records and keys.
            enrollment_repository: Install codes and approval records.
            token_repository: Persistence for hashed refresh tokens.
            jwt_service: Issues access tokens for finalized devices.
            tls_bundle_loader: Returns the certificate text handed to agents.
            ip_rate_limiter: Limiter keyed by ``ip:<remote_addr>``.
            fingerprint_rate_limiter: Limiter keyed by ``fp:<fingerprint>``.
            nonce_cache: Single-use cache that rejects replayed proofs.
            script_signer: Optional signer whose public key is shared with
                agents. When absent, an empty signing key is returned.
            logger: Logger to use. Defaults to ``borealis.engine.enrollment``.
        """
        self._devices = device_repository
        self._enrollment = enrollment_repository
        self._tokens = token_repository
        self._jwt = jwt_service
        self._load_tls_bundle = tls_bundle_loader
        self._ip_rate_limiter = ip_rate_limiter
        self._fp_rate_limiter = fingerprint_rate_limiter
        self._nonce_cache = nonce_cache
        self._signer = script_signer
        self._log = logger or logging.getLogger("borealis.engine.enrollment")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def request_enrollment(
        self,
        payload: EnrollmentRequestInput,
        *,
        remote_addr: str,
    ) -> EnrollmentRequestResult:
        """Register (or refresh) a pending enrollment request for an agent.

        Args:
            payload: Parsed enrollment request from the agent.
            remote_addr: Client address, used as the per-IP rate-limit key.

        Returns:
            A ``PENDING`` result carrying the approval reference, the fresh
            server nonce the agent must sign, the TLS certificate, and the
            script signing key (empty string when no signer is configured).

        Raises:
            EnrollmentValidationError: ``rate_limited`` (HTTP 429) when either
                limiter rejects the request, or ``invalid_enrollment_code``
                when the install code is unknown, expired, or exhausted by a
                device with a different fingerprint.
        """
        # Previously computed and then discarded; it is now logged so the
        # agent's service context is visible alongside the request.
        context_hint = sanitize_service_context(payload.service_context)
        self._log.info(
            "enrollment-request ip=%s host=%s code_mask=%s context=%s",
            remote_addr,
            payload.hostname,
            self._mask_code(payload.enrollment_code),
            context_hint,
        )

        self._enforce_rate_limit(self._ip_rate_limiter, f"ip:{remote_addr}")
        self._enforce_rate_limit(self._fp_rate_limiter, f"fp:{payload.fingerprint.value}")

        install_code = self._enrollment.fetch_install_code(payload.enrollment_code)
        # Raises for a missing/expired/foreign-exhausted code; otherwise returns
        # the GUID to reuse when this device already consumed the code.
        reuse_guid = self._determine_reuse_guid(install_code, payload.fingerprint)
        enrollment_code_id = install_code.identifier if install_code is not None else None

        server_nonce_b64 = base64.b64encode(secrets.token_bytes(32)).decode("ascii")
        now = self._now()

        approval = self._enrollment.fetch_pending_approval_by_fingerprint(payload.fingerprint)
        if approval is not None:
            # Repeated request: overwrite the stored nonces/key so proofs signed
            # over an older server nonce no longer verify.
            self._enrollment.update_pending_approval(
                approval.record_id,
                hostname=payload.hostname,
                guid=reuse_guid,
                enrollment_code_id=enrollment_code_id,
                client_nonce_b64=payload.client_nonce_b64,
                server_nonce_b64=server_nonce_b64,
                agent_pubkey_der=payload.agent_public_key_der,
                updated_at=now,
            )
        else:
            approval = self._enrollment.create_device_approval(
                record_id=str(uuid.uuid4()),
                reference=str(uuid.uuid4()),
                claimed_hostname=payload.hostname,
                claimed_fingerprint=payload.fingerprint,
                enrollment_code_id=enrollment_code_id,
                client_nonce_b64=payload.client_nonce_b64,
                server_nonce_b64=server_nonce_b64,
                agent_pubkey_der=payload.agent_public_key_der,
                created_at=now,
                guid=reuse_guid,
            )

        certificate, signing_key = self._server_material()
        return EnrollmentRequestResult(
            status=EnrollmentStatus.PENDING,
            approval_reference=approval.reference,
            server_nonce=server_nonce_b64,
            poll_after_ms=self.REQUEST_POLL_AFTER_MS,
            server_certificate=certificate,
            signing_key=signing_key,
        )

    def poll_enrollment(
        self,
        *,
        approval_reference: str,
        client_nonce_b64: str,
        proof_signature_b64: str,
    ) -> PollingResult:
        """Check an agent's proof and report the state of its approval.

        Args:
            approval_reference: Reference returned by :meth:`request_enrollment`.
            client_nonce_b64: The base64 client nonce from the original
                request. Must match the stored value exactly.
            proof_signature_b64: Base64 signature by the agent's key over
                ``server_nonce + approval_reference + client_nonce``.

        Returns:
            One of:

            - ``UNKNOWN``/404 for an unknown reference.
            - ``PENDING`` with a poll delay.
            - ``DENIED``.
            - ``EXPIRED``.
            - ``APPROVED`` with ``detail="finalized"`` when tokens were
              already issued.
            - ``UNKNOWN``/400 for any other status.
            - ``APPROVED`` with freshly minted tokens.

        Raises:
            EnrollmentValidationError: For any of these conditions:

                - A missing or undecodable input.
                - ``nonce_mismatch``, ``agent_pubkey_invalid``, or
                  ``invalid_proof``.
                - ``proof_replayed`` (HTTP 409) when an already-consumed
                  proof is presented for an approved request.
        """
        if not approval_reference:
            raise EnrollmentValidationError("approval_reference_required")
        if not client_nonce_b64:
            raise EnrollmentValidationError("client_nonce_required")
        if not proof_signature_b64:
            raise EnrollmentValidationError("proof_sig_required")

        approval = self._enrollment.fetch_device_approval_by_reference(approval_reference)
        if approval is None:
            return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=404)

        client_nonce = self._decode_base64(client_nonce_b64, "invalid_client_nonce")
        server_nonce = self._decode_base64(approval.server_nonce_b64, "server_nonce_invalid")
        proof_sig = self._decode_base64(proof_signature_b64, "invalid_proof_sig")

        if approval.client_nonce_b64 != client_nonce_b64:
            raise EnrollmentValidationError("nonce_mismatch")

        self._verify_proof_signature(
            approval=approval,
            client_nonce=client_nonce,
            server_nonce=server_nonce,
            signature=proof_sig,
        )

        status = approval.status
        if status is EnrollmentApprovalStatus.PENDING:
            return PollingResult(
                status=EnrollmentStatus.PENDING,
                http_status=200,
                poll_after_ms=self.PENDING_POLL_AFTER_MS,
            )
        if status is EnrollmentApprovalStatus.DENIED:
            return PollingResult(
                status=EnrollmentStatus.DENIED,
                http_status=200,
                reason="operator_denied",
            )
        if status is EnrollmentApprovalStatus.EXPIRED:
            return PollingResult(status=EnrollmentStatus.EXPIRED, http_status=200)
        if status is EnrollmentApprovalStatus.COMPLETED:
            return PollingResult(
                status=EnrollmentStatus.APPROVED,
                http_status=200,
                detail="finalized",
            )
        if status is not EnrollmentApprovalStatus.APPROVED:
            return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=400)

        # Each (reference, signature) pair may finalize at most once.
        nonce_key = f"{approval.reference}:{proof_signature_b64}"
        if not self._nonce_cache.consume(nonce_key):
            raise EnrollmentValidationError("proof_replayed", http_status=409)

        token_bundle = self._finalize_approval(approval)
        certificate, signing_key = self._server_material()

        return PollingResult(
            status=EnrollmentStatus.APPROVED,
            http_status=200,
            tokens=token_bundle,
            server_certificate=certificate,
            signing_key=signing_key,
        )

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------
    def _server_material(self) -> tuple[str, str]:
        """Return ``(certificate, signing_key)`` shared with agents.

        The signing key is read before the TLS bundle is loaded, matching the
        original call order. The key is ``""`` when no signer is configured.
        """
        signing_key = self._signer.public_base64_spki() if self._signer else ""
        certificate = self._load_tls_bundle()
        return certificate, signing_key

    def _enforce_rate_limit(
        self,
        limiter: SlidingWindowRateLimiter,
        key: str,
        *,
        limit: int = 60,
        window_seconds: float = 60.0,
    ) -> None:
        """Raise ``rate_limited`` (HTTP 429) when ``limiter`` denies ``key``.

        The reported ``retry_after`` is never below one second.
        """
        decision = limiter.check(key, limit, window_seconds)
        if not decision.allowed:
            raise EnrollmentValidationError(
                "rate_limited", http_status=429, retry_after=max(decision.retry_after, 1.0)
            )

    def _determine_reuse_guid(
        self,
        install_code: Optional[EnrollmentCode],
        fingerprint: DeviceFingerprint,
    ) -> Optional[DeviceGuid]:
        """Validate ``install_code`` and decide whether to reuse a device GUID.

        Returns:
            ``None`` for a usable, non-exhausted code. For an exhausted code,
            returns the GUID that consumed it, but only when that device's
            fingerprint equals ``fingerprint``.

        Raises:
            EnrollmentValidationError: ``invalid_enrollment_code`` in every
                other case (missing, expired, or exhausted by another device).
        """
        if install_code is None:
            raise EnrollmentValidationError("invalid_enrollment_code")
        if install_code.is_expired:
            raise EnrollmentValidationError("invalid_enrollment_code")
        if not install_code.is_exhausted:
            return None
        if not install_code.used_by_guid:
            raise EnrollmentValidationError("invalid_enrollment_code")

        existing = self._devices.fetch_by_guid(install_code.used_by_guid)
        if existing and existing.identity.fingerprint.value == fingerprint.value:
            return install_code.used_by_guid
        raise EnrollmentValidationError("invalid_enrollment_code")

    def _finalize_approval(self, approval: EnrollmentApproval) -> EnrollmentTokenBundle:
        """Finalize an approved request and mint its tokens.

        Steps, in order:

        1. Ensure the device record exists and record its key.
        2. Bump the install code's usage counters.
        3. Store the refresh-token hash.
        4. Issue an access token.
        5. Mark the approval ``COMPLETED``.
        """
        now = self._now()
        # Reuse the GUID chosen at request time; otherwise mint a new identity.
        effective_guid = approval.guid or DeviceGuid(str(uuid.uuid4()))
        device_record = self._devices.ensure_device_record(
            guid=effective_guid,
            hostname=approval.claimed_hostname,
            fingerprint=approval.claimed_fingerprint,
        )
        self._devices.record_device_key(
            guid=effective_guid,
            fingerprint=approval.claimed_fingerprint,
            added_at=now,
        )

        if approval.enrollment_code_id:
            code = self._enrollment.fetch_install_code_by_id(approval.enrollment_code_id)
            if code is not None:
                self._enrollment.update_install_code_usage(
                    approval.enrollment_code_id,
                    use_count_increment=1,
                    last_used_at=now,
                    used_by_guid=effective_guid,
                    mark_first_use=code.used_at is None,
                )

        # Only the SHA-256 hash of the refresh token is persisted; the raw
        # value is returned to the agent once.
        refresh_token = secrets.token_urlsafe(48)
        token_hash = hashlib.sha256(refresh_token.encode("utf-8")).hexdigest()
        self._tokens.create(
            record_id=str(uuid.uuid4()),
            guid=effective_guid,
            token_hash=token_hash,
            created_at=now,
            expires_at=now + self.REFRESH_TOKEN_TTL,
        )

        access_token = self._jwt.issue_access_token(
            effective_guid.value,
            device_record.identity.fingerprint.value,
            max(device_record.token_version, 1),
        )

        self._enrollment.update_device_approval_status(
            approval.record_id,
            status=EnrollmentApprovalStatus.COMPLETED,
            updated_at=now,
            guid=effective_guid,
        )

        return EnrollmentTokenBundle(
            guid=effective_guid,
            access_token=access_token,
            refresh_token=refresh_token,
            expires_in=self.ACCESS_TOKEN_EXPIRES_IN,
        )

    def _verify_proof_signature(
        self,
        *,
        approval: EnrollmentApproval,
        client_nonce: bytes,
        server_nonce: bytes,
        signature: bytes,
    ) -> None:
        """Verify ``signature`` over ``server_nonce + reference + client_nonce``.

        NOTE(review): the two-argument ``verify(signature, data)`` call matches
        Ed25519/Ed448 keys. RSA/EC keys need extra arguments, so they would
        fail here and be reported as ``invalid_proof``.

        Raises:
            EnrollmentValidationError: ``agent_pubkey_invalid`` if the stored
                DER key cannot be loaded, ``invalid_proof`` if verification fails.
        """
        message = server_nonce + approval.reference.encode("utf-8") + client_nonce
        try:
            public_key = serialization.load_der_public_key(approval.agent_pubkey_der)
        except Exception as exc:
            raise EnrollmentValidationError("agent_pubkey_invalid") from exc

        try:
            public_key.verify(signature, message)
        except Exception as exc:
            raise EnrollmentValidationError("invalid_proof") from exc

    @staticmethod
    def _decode_base64(value: str, error_code: str) -> bytes:
        """Strictly decode base64 ``value``, raising ``error_code`` on failure."""
        try:
            # validate=True rejects characters outside the base64 alphabet.
            return base64.b64decode(value, validate=True)
        except Exception as exc:
            raise EnrollmentValidationError(error_code) from exc

    @staticmethod
    def _mask_code(code: str) -> str:
        """Return a log-safe form of an install code.

        Codes longer than six characters keep their first and last three
        characters. Anything shorter, including empty or ``None``, becomes
        ``***``.
        """
        trimmed = (code or "").strip()
        if len(trimmed) <= 6:
            return "***"
        return f"{trimmed[:3]}***{trimmed[-3:]}"

    @staticmethod
    def _now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        return datetime.now(tz=timezone.utc)
 |