mirror of
				https://github.com/bunny-lab-io/Borealis.git
				synced 2025-10-26 15:41:58 -06:00 
			
		
		
		
	
		
			
				
	
	
		
			262 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			262 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Domain types describing device enrollment flows."""
 | |
| 
 | |
| from __future__ import annotations
 | |
| 
 | |
| from dataclasses import dataclass
 | |
| from datetime import datetime, timezone
 | |
| import base64
 | |
| from enum import Enum
 | |
| from typing import Any, Mapping, Optional
 | |
| 
 | |
| from .device_auth import DeviceFingerprint, DeviceGuid, sanitize_service_context
 | |
| 
 | |
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "EnrollmentCode",
    "EnrollmentApprovalStatus",
    "EnrollmentApproval",
    "EnrollmentRequest",
    "EnrollmentValidationError",
    "ProofChallenge",
]
 | |
| 
 | |
| 
 | |
| def _parse_iso8601(value: Optional[str]) -> Optional[datetime]:
 | |
|     if not value:
 | |
|         return None
 | |
|     raw = str(value).strip()
 | |
|     if not raw:
 | |
|         return None
 | |
|     try:
 | |
|         dt = datetime.fromisoformat(raw)
 | |
|     except Exception as exc:  # pragma: no cover - error path
 | |
|         raise ValueError(f"invalid ISO8601 timestamp: {raw}") from exc
 | |
|     if dt.tzinfo is None:
 | |
|         return dt.replace(tzinfo=timezone.utc)
 | |
|     return dt
 | |
| 
 | |
| 
 | |
| def _require(value: Optional[str], field: str) -> str:
 | |
|     text = (value or "").strip()
 | |
|     if not text:
 | |
|         raise ValueError(f"{field} is required")
 | |
|     return text
 | |
| 
 | |
| 
 | |
| @dataclass(frozen=True, slots=True)
 | |
| class EnrollmentValidationError(Exception):
 | |
|     """Raised when enrollment input fails validation."""
 | |
| 
 | |
|     code: str
 | |
|     http_status: int = 400
 | |
|     retry_after: Optional[float] = None
 | |
| 
 | |
|     def to_response(self) -> dict[str, object]:
 | |
|         payload: dict[str, object] = {"error": self.code}
 | |
|         if self.retry_after is not None:
 | |
|             payload["retry_after"] = self.retry_after
 | |
|         return payload
 | |
| 
 | |
|     def __str__(self) -> str:  # pragma: no cover - debug helper
 | |
|         return f"{self.code} (status={self.http_status})"
 | |
| 
 | |
| 
 | |
| @dataclass(frozen=True, slots=True)
 | |
| class EnrollmentCode:
 | |
|     """Installer code metadata loaded from the persistence layer."""
 | |
| 
 | |
|     code: str
 | |
|     expires_at: datetime
 | |
|     max_uses: int
 | |
|     use_count: int
 | |
|     used_by_guid: Optional[DeviceGuid]
 | |
|     last_used_at: Optional[datetime]
 | |
|     used_at: Optional[datetime]
 | |
|     record_id: Optional[str] = None
 | |
| 
 | |
|     def __post_init__(self) -> None:
 | |
|         if not self.code:
 | |
|             raise ValueError("code is required")
 | |
|         if self.max_uses < 1:
 | |
|             raise ValueError("max_uses must be >= 1")
 | |
|         if self.use_count < 0:
 | |
|             raise ValueError("use_count cannot be negative")
 | |
|         if self.use_count > self.max_uses:
 | |
|             raise ValueError("use_count cannot exceed max_uses")
 | |
| 
 | |
|     @classmethod
 | |
|     def from_mapping(cls, record: Mapping[str, Any]) -> "EnrollmentCode":
 | |
|         used_by = record.get("used_by_guid")
 | |
|         used_by_guid = DeviceGuid(used_by) if used_by else None
 | |
|         return cls(
 | |
|             code=_require(record.get("code"), "code"),
 | |
|             expires_at=_parse_iso8601(record.get("expires_at")) or datetime.now(tz=timezone.utc),
 | |
|             max_uses=int(record.get("max_uses") or 1),
 | |
|             use_count=int(record.get("use_count") or 0),
 | |
|             used_by_guid=used_by_guid,
 | |
|             last_used_at=_parse_iso8601(record.get("last_used_at")),
 | |
|             used_at=_parse_iso8601(record.get("used_at")),
 | |
|             record_id=str(record.get("id") or "") or None,
 | |
|         )
 | |
| 
 | |
|     @property
 | |
|     def remaining_uses(self) -> int:
 | |
|         return max(self.max_uses - self.use_count, 0)
 | |
| 
 | |
|     @property
 | |
|     def is_exhausted(self) -> bool:
 | |
|         return self.remaining_uses == 0
 | |
| 
 | |
|     @property
 | |
|     def is_expired(self) -> bool:
 | |
|         return self.expires_at <= datetime.now(tz=timezone.utc)
 | |
| 
 | |
|     @property
 | |
|     def identifier(self) -> Optional[str]:
 | |
|         return self.record_id
 | |
| 
 | |
| 
 | |
class EnrollmentApprovalStatus(str, Enum):
    """Possible states for a device approval entry.

    Members mix in ``str``, so each compares equal to its raw value
    (e.g. ``EnrollmentApprovalStatus.PENDING == "pending"``).
    """

    PENDING = "pending"
    APPROVED = "approved"
    DENIED = "denied"
    COMPLETED = "completed"
    EXPIRED = "expired"

    @classmethod
    def from_string(cls, value: Optional[str]) -> "EnrollmentApprovalStatus":
        """Parse stored status text into a member.

        Surrounding whitespace and letter case are ignored. ``None``/empty
        input and unrecognised values both map to ``PENDING``.
        """
        normalized = (value or "pending").strip().lower()
        try:
            return cls(normalized)
        except ValueError:
            # NOTE(review): unknown statuses silently fall back to PENDING;
            # confirm this is preferred over surfacing corrupt data.
            return cls.PENDING

    @property
    def is_terminal(self) -> bool:
        """Return True for every state other than ``PENDING``."""
        # Reference members through the class rather than via ``self.X``.
        # The Python 3.11 Enum HOWTO states member-from-member access is
        # not allowed there.
        return self in {
            EnrollmentApprovalStatus.APPROVED,
            EnrollmentApprovalStatus.DENIED,
            EnrollmentApprovalStatus.COMPLETED,
            EnrollmentApprovalStatus.EXPIRED,
        }
 | |
| 
 | |
| 
 | |
| @dataclass(frozen=True, slots=True)
 | |
| class ProofChallenge:
 | |
|     """Client/server nonce pair distributed during enrollment."""
 | |
| 
 | |
|     client_nonce: bytes
 | |
|     server_nonce: bytes
 | |
| 
 | |
|     def __post_init__(self) -> None:
 | |
|         if not self.client_nonce or not self.server_nonce:
 | |
|             raise ValueError("nonce payloads must be non-empty")
 | |
| 
 | |
|     @classmethod
 | |
|     def from_base64(cls, *, client: bytes, server: bytes) -> "ProofChallenge":
 | |
|         return cls(client_nonce=client, server_nonce=server)
 | |
| 
 | |
| 
 | |
| @dataclass(frozen=True, slots=True)
 | |
| class EnrollmentRequest:
 | |
|     """Validated payload submitted by an agent during enrollment."""
 | |
| 
 | |
|     hostname: str
 | |
|     enrollment_code: str
 | |
|     fingerprint: DeviceFingerprint
 | |
|     proof: ProofChallenge
 | |
|     service_context: Optional[str]
 | |
| 
 | |
|     def __post_init__(self) -> None:
 | |
|         if not self.hostname:
 | |
|             raise ValueError("hostname is required")
 | |
|         if not self.enrollment_code:
 | |
|             raise ValueError("enrollment code is required")
 | |
|         object.__setattr__(
 | |
|             self,
 | |
|             "service_context",
 | |
|             sanitize_service_context(self.service_context),
 | |
|         )
 | |
| 
 | |
|     @classmethod
 | |
|     def from_payload(
 | |
|         cls,
 | |
|         *,
 | |
|         hostname: str,
 | |
|         enrollment_code: str,
 | |
|         fingerprint: str,
 | |
|         client_nonce: bytes,
 | |
|         server_nonce: bytes,
 | |
|         service_context: Optional[str] = None,
 | |
|     ) -> "EnrollmentRequest":
 | |
|         proof = ProofChallenge(client_nonce=client_nonce, server_nonce=server_nonce)
 | |
|         return cls(
 | |
|             hostname=_require(hostname, "hostname"),
 | |
|             enrollment_code=_require(enrollment_code, "enrollment_code"),
 | |
|             fingerprint=DeviceFingerprint(fingerprint),
 | |
|             proof=proof,
 | |
|             service_context=service_context,
 | |
|         )
 | |
| 
 | |
| 
 | |
| @dataclass(frozen=True, slots=True)
 | |
| class EnrollmentApproval:
 | |
|     """Pending or resolved approval tracked by operators."""
 | |
| 
 | |
|     record_id: str
 | |
|     reference: str
 | |
|     status: EnrollmentApprovalStatus
 | |
|     claimed_hostname: str
 | |
|     claimed_fingerprint: DeviceFingerprint
 | |
|     enrollment_code_id: Optional[str]
 | |
|     created_at: datetime
 | |
|     updated_at: datetime
 | |
|     client_nonce_b64: str
 | |
|     server_nonce_b64: str
 | |
|     agent_pubkey_der: bytes
 | |
|     guid: Optional[DeviceGuid] = None
 | |
|     approved_by: Optional[str] = None
 | |
| 
 | |
|     def __post_init__(self) -> None:
 | |
|         if not self.record_id:
 | |
|             raise ValueError("record identifier is required")
 | |
|         if not self.reference:
 | |
|             raise ValueError("approval reference is required")
 | |
|         if not self.claimed_hostname:
 | |
|             raise ValueError("claimed hostname is required")
 | |
| 
 | |
|     @classmethod
 | |
|     def from_mapping(cls, record: Mapping[str, Any]) -> "EnrollmentApproval":
 | |
|         guid_raw = record.get("guid")
 | |
|         approved_raw = record.get("approved_by_user_id")
 | |
|         return cls(
 | |
|             record_id=_require(record.get("id"), "id"),
 | |
|             reference=_require(record.get("approval_reference"), "approval_reference"),
 | |
|             status=EnrollmentApprovalStatus.from_string(record.get("status")),
 | |
|             claimed_hostname=_require(record.get("hostname_claimed"), "hostname_claimed"),
 | |
|             claimed_fingerprint=DeviceFingerprint(record.get("ssl_key_fingerprint_claimed")),
 | |
|             enrollment_code_id=record.get("enrollment_code_id"),
 | |
|             created_at=_parse_iso8601(record.get("created_at")) or datetime.now(tz=timezone.utc),
 | |
|             updated_at=_parse_iso8601(record.get("updated_at")) or datetime.now(tz=timezone.utc),
 | |
|             guid=DeviceGuid(guid_raw) if guid_raw else None,
 | |
|             approved_by=(approved_raw or None),
 | |
|             client_nonce_b64=_require(record.get("client_nonce"), "client_nonce"),
 | |
|             server_nonce_b64=_require(record.get("server_nonce"), "server_nonce"),
 | |
|             agent_pubkey_der=bytes(record.get("agent_pubkey_der") or b""),
 | |
|         )
 | |
| 
 | |
|     @property
 | |
|     def is_pending(self) -> bool:
 | |
|         return self.status is EnrollmentApprovalStatus.PENDING
 | |
| 
 | |
|     @property
 | |
|     def is_completed(self) -> bool:
 | |
|         return self.status in {
 | |
|             EnrollmentApprovalStatus.APPROVED,
 | |
|             EnrollmentApprovalStatus.COMPLETED,
 | |
|         }
 | |
| 
 | |
|     @property
 | |
|     def client_nonce_bytes(self) -> bytes:
 | |
|         return base64.b64decode(self.client_nonce_b64.encode("ascii"), validate=True)
 | |
| 
 | |
|     @property
 | |
|     def server_nonce_bytes(self) -> bytes:
 | |
|         return base64.b64decode(self.server_nonce_b64.encode("ascii"), validate=True)
 |