From c931cd90601054d4017f246c46d369eb83bbda79 Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Wed, 22 Oct 2025 06:33:04 -0600 Subject: [PATCH] Add authentication and enrollment domain primitives --- Data/Engine/CURRENT_STAGE.md | 2 +- Data/Engine/domain/__init__.py | 36 +++- Data/Engine/domain/device_auth.py | 242 ++++++++++++++++++++++++ Data/Engine/domain/device_enrollment.py | 221 ++++++++++++++++++++++ 4 files changed, 499 insertions(+), 2 deletions(-) create mode 100644 Data/Engine/domain/device_auth.py create mode 100644 Data/Engine/domain/device_enrollment.py diff --git a/Data/Engine/CURRENT_STAGE.md b/Data/Engine/CURRENT_STAGE.md index c01d635..40f156e 100644 --- a/Data/Engine/CURRENT_STAGE.md +++ b/Data/Engine/CURRENT_STAGE.md @@ -21,7 +21,7 @@ - 4.3 Wire migrations to run during Engine bootstrap (behind a flag) and confirm tables initialize in a sandbox DB. - 4.4 Commit once DB connection + migrations succeed independently of legacy server. -5. Extract authentication/enrollment domain surface +[COMPLETED] 5. Extract authentication/enrollment domain surface - 5.1 Define immutable dataclasses in `domain/device_auth.py`, `domain/device_enrollment.py` for tokens, GUIDs, approvals. - 5.2 Map legacy error codes/enums into domain exceptions or enums in the same modules. - 5.3 Commit after unit tests (or doctests) validate dataclass invariants. diff --git a/Data/Engine/domain/__init__.py b/Data/Engine/domain/__init__.py index 3bcd0ef..79d9c87 100644 --- a/Data/Engine/domain/__init__.py +++ b/Data/Engine/domain/__init__.py @@ -2,4 +2,38 @@ from __future__ import annotations -__all__: list[str] = [] +from .device_auth import ( # noqa: F401 + AccessTokenClaims, + DeviceAuthContext, + DeviceAuthErrorCode, + DeviceAuthFailure, + DeviceFingerprint, + DeviceGuid, + DeviceIdentity, + DeviceStatus, + sanitize_service_context, +) +from .device_enrollment import ( # noqa: F401 + EnrollmentApproval, + EnrollmentApprovalStatus, + EnrollmentCode, + EnrollmentRequest, + ProofChallenge, +) + +__all__ = [ + "AccessTokenClaims", + "DeviceAuthContext", + "DeviceAuthErrorCode", + "DeviceAuthFailure", + "DeviceFingerprint", + "DeviceGuid", + "DeviceIdentity", + "DeviceStatus", + "EnrollmentApproval", + "EnrollmentApprovalStatus", + "EnrollmentCode", + "EnrollmentRequest", + "ProofChallenge", + "sanitize_service_context", +] diff --git a/Data/Engine/domain/device_auth.py b/Data/Engine/domain/device_auth.py new file mode 100644 index 0000000..d377e52 --- /dev/null +++ b/Data/Engine/domain/device_auth.py @@ -0,0 +1,242 @@ +"""Domain primitives for device authentication and token validation.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Mapping, Optional +import string +import uuid + +__all__ = [ + "DeviceGuid", + "DeviceFingerprint", + "DeviceIdentity", + "DeviceStatus", + "DeviceAuthErrorCode", + "DeviceAuthFailure", + "AccessTokenClaims", + "DeviceAuthContext", + "sanitize_service_context", +] + + +def _normalize_guid(value: Optional[str]) -> str: + """Return a canonical GUID string or an empty string.""" + candidate = (value or "").strip() + if not candidate: + return "" + candidate = candidate.strip("{}") + try: + return str(uuid.UUID(candidate)).upper() + except Exception: + cleaned = "".join( + ch for ch in candidate if ch in string.hexdigits or ch == "-" + ).strip("-") + if cleaned: + try: + return str(uuid.UUID(cleaned)).upper() + except Exception: + pass + return candidate.upper() + + +def _normalize_fingerprint(value: Optional[str]) -> str: + return (value or "").strip().lower() + + +def sanitize_service_context(value: Optional[str]) -> Optional[str]: + """Normalize the optional agent service context header value.""" + if not value: + return None + cleaned = "".join( + ch for ch in str(value) if ch.isalnum() or ch in ("_", "-") + ) + if not cleaned: + return None + return cleaned.upper() + + +@dataclass(frozen=True, slots=True) +class DeviceGuid: + """Canonical GUID wrapper that enforces Borealis normalization.""" + + value: str + + def __post_init__(self) -> None: # pragma: no cover - simple data normalization + normalized = _normalize_guid(self.value) + if not normalized: + raise ValueError("device GUID is required") + object.__setattr__(self, "value", normalized) + + def __str__(self) -> str: + return self.value + + +@dataclass(frozen=True, slots=True) +class DeviceFingerprint: + """Normalized TLS key fingerprint associated with a device.""" + + value: str + + def __post_init__(self) -> None: # pragma: no cover - simple data normalization + normalized = _normalize_fingerprint(self.value) + if not normalized: + raise ValueError("device fingerprint is required") + object.__setattr__(self, "value", normalized) + + def __str__(self) -> str: + return self.value + + +@dataclass(frozen=True, slots=True) +class DeviceIdentity: + """Immutable pairing of device GUID and TLS key fingerprint.""" + + guid: DeviceGuid + fingerprint: DeviceFingerprint + + +class DeviceStatus(str, Enum): + """Lifecycle markers mirrored from the legacy devices table.""" + + ACTIVE = "active" + QUARANTINED = "quarantined" + REVOKED = "revoked" + DECOMMISSIONED = "decommissioned" + + @classmethod + def from_string(cls, value: Optional[str]) -> "DeviceStatus": + normalized = (value or "active").strip().lower() + try: + return cls(normalized) + except ValueError: + return cls.ACTIVE + + @property + def allows_access(self) -> bool: + return self in {self.ACTIVE, self.QUARANTINED} + + +class DeviceAuthErrorCode(str, Enum): + """Well-known authentication failure categories.""" + + MISSING_AUTHORIZATION = "missing_authorization" + TOKEN_EXPIRED = "token_expired" + INVALID_TOKEN = "invalid_token" + INVALID_CLAIMS = "invalid_claims" + RATE_LIMITED = "rate_limited" + DEVICE_NOT_FOUND = "device_not_found" + DEVICE_GUID_MISMATCH = "device_guid_mismatch" + FINGERPRINT_MISMATCH = "fingerprint_mismatch" + TOKEN_VERSION_REVOKED = "token_version_revoked" + DEVICE_REVOKED = "device_revoked" + DPOP_NOT_SUPPORTED = "dpop_not_supported" + DPOP_REPLAYED = "dpop_replayed" + DPOP_INVALID = "dpop_invalid" + + +class DeviceAuthFailure(Exception): + """Domain-level authentication error with HTTP metadata.""" + + def __init__( + self, + code: DeviceAuthErrorCode, + *, + http_status: int = 401, + retry_after: Optional[float] = None, + detail: Optional[str] = None, + ) -> None: + self.code = code + self.http_status = int(http_status) + self.retry_after = retry_after + self.detail = detail or code.value + super().__init__(self.detail) + + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = {"error": self.code.value} + if self.retry_after is not None: + payload["retry_after"] = float(self.retry_after) + if self.detail and self.detail != self.code.value: + payload["detail"] = self.detail + return payload + + +def _coerce_int(value: Any, *, minimum: Optional[int] = None) -> int: + try: + result = int(value) + except (TypeError, ValueError): + raise ValueError("expected integer value") from None + if minimum is not None and result < minimum: + raise ValueError("integer below minimum") + return result + + +@dataclass(frozen=True, slots=True) +class AccessTokenClaims: + """Validated subset of JWT claims issued to a device.""" + + subject: str + guid: DeviceGuid + fingerprint: DeviceFingerprint + token_version: int + issued_at: int + not_before: int + expires_at: int + raw: Mapping[str, Any] + + def __post_init__(self) -> None: + if self.token_version <= 0: + raise ValueError("token_version must be positive") + if self.issued_at <= 0 or self.not_before <= 0 or self.expires_at <= 0: + raise ValueError("temporal claims must be positive integers") + if self.expires_at <= self.not_before: + raise ValueError("token expiration must be after not-before") + + @classmethod + def from_mapping(cls, claims: Mapping[str, Any]) -> "AccessTokenClaims": + subject = str(claims.get("sub") or "").strip() + if not subject: + raise ValueError("missing token subject") + guid = DeviceGuid(str(claims.get("guid") or "")) + fingerprint = DeviceFingerprint(claims.get("ssl_key_fingerprint")) + token_version = _coerce_int(claims.get("token_version"), minimum=1) + issued_at = _coerce_int(claims.get("iat"), minimum=1) + not_before = _coerce_int(claims.get("nbf"), minimum=1) + expires_at = _coerce_int(claims.get("exp"), minimum=1) + return cls( + subject=subject, + guid=guid, + fingerprint=fingerprint, + token_version=token_version, + issued_at=issued_at, + not_before=not_before, + expires_at=expires_at, + raw=dict(claims), + ) + + +@dataclass(frozen=True, slots=True) +class DeviceAuthContext: + """Domain result emitted after successful authentication.""" + + identity: DeviceIdentity + access_token: str + claims: AccessTokenClaims + status: DeviceStatus + service_context: Optional[str] + dpop_jkt: Optional[str] = None + + def __post_init__(self) -> None: + if not self.access_token: + raise ValueError("access token is required") + service = sanitize_service_context(self.service_context) + object.__setattr__(self, "service_context", service) + + @property + def is_quarantined(self) -> bool: + return self.status is DeviceStatus.QUARANTINED + + @property + def allows_access(self) -> bool: + return self.status.allows_access diff --git a/Data/Engine/domain/device_enrollment.py b/Data/Engine/domain/device_enrollment.py new file mode 100644 index 0000000..283d16d --- /dev/null +++ b/Data/Engine/domain/device_enrollment.py @@ -0,0 +1,221 @@ +"""Domain types describing device enrollment flows.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Mapping, Optional + +from .device_auth import DeviceFingerprint, DeviceGuid, sanitize_service_context + +__all__ = [ + "EnrollmentCode", + "EnrollmentApprovalStatus", + "EnrollmentApproval", + "EnrollmentRequest", + "ProofChallenge", +] + + +def _parse_iso8601(value: Optional[str]) -> Optional[datetime]: + if not value: + return None + raw = str(value).strip() + if not raw: + return None + try: + dt = datetime.fromisoformat(raw) + except Exception as exc: # pragma: no cover - error path + raise ValueError(f"invalid ISO8601 timestamp: {raw}") from exc + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + return dt + + +def _require(value: Optional[str], field: str) -> str: + text = (value or "").strip() + if not text: + raise ValueError(f"{field} is required") + return text + + +@dataclass(frozen=True, slots=True) +class EnrollmentCode: + """Installer code metadata loaded from the persistence layer.""" + + code: str + expires_at: datetime + max_uses: int + use_count: int + used_by_guid: Optional[DeviceGuid] + last_used_at: Optional[datetime] + used_at: Optional[datetime] + + def __post_init__(self) -> None: + if not self.code: + raise ValueError("code is required") + if self.max_uses < 1: + raise ValueError("max_uses must be >= 1") + if self.use_count < 0: + raise ValueError("use_count cannot be negative") + if self.use_count > self.max_uses: + raise ValueError("use_count cannot exceed max_uses") + + @classmethod + def from_mapping(cls, record: Mapping[str, Any]) -> "EnrollmentCode": + used_by = record.get("used_by_guid") + used_by_guid = DeviceGuid(used_by) if used_by else None + return cls( + code=_require(record.get("code"), "code"), + expires_at=_parse_iso8601(record.get("expires_at")) or datetime.now(tz=timezone.utc), + max_uses=int(record.get("max_uses") or 1), + use_count=int(record.get("use_count") or 0), + used_by_guid=used_by_guid, + last_used_at=_parse_iso8601(record.get("last_used_at")), + used_at=_parse_iso8601(record.get("used_at")), + ) + + @property + def remaining_uses(self) -> int: + return max(self.max_uses - self.use_count, 0) + + @property + def is_exhausted(self) -> bool: + return self.remaining_uses == 0 + + @property + def is_expired(self) -> bool: + return self.expires_at <= datetime.now(tz=timezone.utc) + + +class EnrollmentApprovalStatus(str, Enum): + """Possible states for a device approval entry.""" + + PENDING = "pending" + APPROVED = "approved" + DENIED = "denied" + COMPLETED = "completed" + EXPIRED = "expired" + + @classmethod + def from_string(cls, value: Optional[str]) -> "EnrollmentApprovalStatus": + normalized = (value or "pending").strip().lower() + try: + return cls(normalized) + except ValueError: + return cls.PENDING + + @property + def is_terminal(self) -> bool: + return self in {self.APPROVED, self.DENIED, self.COMPLETED, self.EXPIRED} + + +@dataclass(frozen=True, slots=True) +class ProofChallenge: + """Client/server nonce pair distributed during enrollment.""" + + client_nonce: bytes + server_nonce: bytes + + def __post_init__(self) -> None: + if not self.client_nonce or not self.server_nonce: + raise ValueError("nonce payloads must be non-empty") + + @classmethod + def from_base64(cls, *, client: bytes, server: bytes) -> "ProofChallenge": + return cls(client_nonce=client, server_nonce=server) + + +@dataclass(frozen=True, slots=True) +class EnrollmentRequest: + """Validated payload submitted by an agent during enrollment.""" + + hostname: str + enrollment_code: str + fingerprint: DeviceFingerprint + proof: ProofChallenge + service_context: Optional[str] + + def __post_init__(self) -> None: + if not self.hostname: + raise ValueError("hostname is required") + if not self.enrollment_code: + raise ValueError("enrollment code is required") + object.__setattr__( + self, + "service_context", + sanitize_service_context(self.service_context), + ) + + @classmethod + def from_payload( + cls, + *, + hostname: str, + enrollment_code: str, + fingerprint: str, + client_nonce: bytes, + server_nonce: bytes, + service_context: Optional[str] = None, + ) -> "EnrollmentRequest": + proof = ProofChallenge(client_nonce=client_nonce, server_nonce=server_nonce) + return cls( + hostname=_require(hostname, "hostname"), + enrollment_code=_require(enrollment_code, "enrollment_code"), + fingerprint=DeviceFingerprint(fingerprint), + proof=proof, + service_context=service_context, + ) + + +@dataclass(frozen=True, slots=True) +class EnrollmentApproval: + """Pending or resolved approval tracked by operators.""" + + record_id: str + reference: str + status: EnrollmentApprovalStatus + claimed_hostname: str + claimed_fingerprint: DeviceFingerprint + enrollment_code_id: Optional[str] + created_at: datetime + updated_at: datetime + guid: Optional[DeviceGuid] = None + approved_by: Optional[str] = None + + def __post_init__(self) -> None: + if not self.record_id: + raise ValueError("record identifier is required") + if not self.reference: + raise ValueError("approval reference is required") + if not self.claimed_hostname: + raise ValueError("claimed hostname is required") + + @classmethod + def from_mapping(cls, record: Mapping[str, Any]) -> "EnrollmentApproval": + guid_raw = record.get("guid") + approved_raw = record.get("approved_by_user_id") + return cls( + record_id=_require(record.get("id"), "id"), + reference=_require(record.get("approval_reference"), "approval_reference"), + status=EnrollmentApprovalStatus.from_string(record.get("status")), + claimed_hostname=_require(record.get("hostname_claimed"), "hostname_claimed"), + claimed_fingerprint=DeviceFingerprint(record.get("ssl_key_fingerprint_claimed")), + enrollment_code_id=record.get("enrollment_code_id"), + created_at=_parse_iso8601(record.get("created_at")) or datetime.now(tz=timezone.utc), + updated_at=_parse_iso8601(record.get("updated_at")) or datetime.now(tz=timezone.utc), + guid=DeviceGuid(guid_raw) if guid_raw else None, + approved_by=(approved_raw or None), + ) + + @property + def is_pending(self) -> bool: + return self.status is EnrollmentApprovalStatus.PENDING + + @property + def is_completed(self) -> bool: + return self.status in { + EnrollmentApprovalStatus.APPROVED, + EnrollmentApprovalStatus.COMPLETED, + }