mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 15:21:57 -06:00
Add authentication and enrollment domain primitives
This commit is contained in:
@@ -21,7 +21,7 @@
|
||||
- 4.3 Wire migrations to run during Engine bootstrap (behind a flag) and confirm tables initialize in a sandbox DB.
|
||||
- 4.4 Commit once DB connection + migrations succeed independently of legacy server.
|
||||
|
||||
5. Extract authentication/enrollment domain surface
|
||||
[COMPLETED] 5. Extract authentication/enrollment domain surface
|
||||
- 5.1 Define immutable dataclasses in `domain/device_auth.py`, `domain/device_enrollment.py` for tokens, GUIDs, approvals.
|
||||
- 5.2 Map legacy error codes/enums into domain exceptions or enums in the same modules.
|
||||
- 5.3 Commit after unit tests (or doctests) validate dataclass invariants.
|
||||
|
||||
@@ -2,4 +2,38 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__: list[str] = []
|
||||
from .device_auth import ( # noqa: F401
|
||||
AccessTokenClaims,
|
||||
DeviceAuthContext,
|
||||
DeviceAuthErrorCode,
|
||||
DeviceAuthFailure,
|
||||
DeviceFingerprint,
|
||||
DeviceGuid,
|
||||
DeviceIdentity,
|
||||
DeviceStatus,
|
||||
sanitize_service_context,
|
||||
)
|
||||
from .device_enrollment import ( # noqa: F401
|
||||
EnrollmentApproval,
|
||||
EnrollmentApprovalStatus,
|
||||
EnrollmentCode,
|
||||
EnrollmentRequest,
|
||||
ProofChallenge,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AccessTokenClaims",
|
||||
"DeviceAuthContext",
|
||||
"DeviceAuthErrorCode",
|
||||
"DeviceAuthFailure",
|
||||
"DeviceFingerprint",
|
||||
"DeviceGuid",
|
||||
"DeviceIdentity",
|
||||
"DeviceStatus",
|
||||
"EnrollmentApproval",
|
||||
"EnrollmentApprovalStatus",
|
||||
"EnrollmentCode",
|
||||
"EnrollmentRequest",
|
||||
"ProofChallenge",
|
||||
"sanitize_service_context",
|
||||
]
|
||||
|
||||
242
Data/Engine/domain/device_auth.py
Normal file
242
Data/Engine/domain/device_auth.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Domain primitives for device authentication and token validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Mapping, Optional
|
||||
import string
|
||||
import uuid
|
||||
|
||||
__all__ = [
|
||||
"DeviceGuid",
|
||||
"DeviceFingerprint",
|
||||
"DeviceIdentity",
|
||||
"DeviceStatus",
|
||||
"DeviceAuthErrorCode",
|
||||
"DeviceAuthFailure",
|
||||
"AccessTokenClaims",
|
||||
"DeviceAuthContext",
|
||||
"sanitize_service_context",
|
||||
]
|
||||
|
||||
|
||||
def _normalize_guid(value: Optional[str]) -> str:
|
||||
"""Return a canonical GUID string or an empty string."""
|
||||
candidate = (value or "").strip()
|
||||
if not candidate:
|
||||
return ""
|
||||
candidate = candidate.strip("{}")
|
||||
try:
|
||||
return str(uuid.UUID(candidate)).upper()
|
||||
except Exception:
|
||||
cleaned = "".join(
|
||||
ch for ch in candidate if ch in string.hexdigits or ch == "-"
|
||||
).strip("-")
|
||||
if cleaned:
|
||||
try:
|
||||
return str(uuid.UUID(cleaned)).upper()
|
||||
except Exception:
|
||||
pass
|
||||
return candidate.upper()
|
||||
|
||||
|
||||
def _normalize_fingerprint(value: Optional[str]) -> str:
|
||||
return (value or "").strip().lower()
|
||||
|
||||
|
||||
def sanitize_service_context(value: Optional[str]) -> Optional[str]:
|
||||
"""Normalize the optional agent service context header value."""
|
||||
if not value:
|
||||
return None
|
||||
cleaned = "".join(
|
||||
ch for ch in str(value) if ch.isalnum() or ch in ("_", "-")
|
||||
)
|
||||
if not cleaned:
|
||||
return None
|
||||
return cleaned.upper()
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceGuid:
|
||||
"""Canonical GUID wrapper that enforces Borealis normalization."""
|
||||
|
||||
value: str
|
||||
|
||||
def __post_init__(self) -> None: # pragma: no cover - simple data normalization
|
||||
normalized = _normalize_guid(self.value)
|
||||
if not normalized:
|
||||
raise ValueError("device GUID is required")
|
||||
object.__setattr__(self, "value", normalized)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceFingerprint:
|
||||
"""Normalized TLS key fingerprint associated with a device."""
|
||||
|
||||
value: str
|
||||
|
||||
def __post_init__(self) -> None: # pragma: no cover - simple data normalization
|
||||
normalized = _normalize_fingerprint(self.value)
|
||||
if not normalized:
|
||||
raise ValueError("device fingerprint is required")
|
||||
object.__setattr__(self, "value", normalized)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceIdentity:
|
||||
"""Immutable pairing of device GUID and TLS key fingerprint."""
|
||||
|
||||
guid: DeviceGuid
|
||||
fingerprint: DeviceFingerprint
|
||||
|
||||
|
||||
class DeviceStatus(str, Enum):
|
||||
"""Lifecycle markers mirrored from the legacy devices table."""
|
||||
|
||||
ACTIVE = "active"
|
||||
QUARANTINED = "quarantined"
|
||||
REVOKED = "revoked"
|
||||
DECOMMISSIONED = "decommissioned"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: Optional[str]) -> "DeviceStatus":
|
||||
normalized = (value or "active").strip().lower()
|
||||
try:
|
||||
return cls(normalized)
|
||||
except ValueError:
|
||||
return cls.ACTIVE
|
||||
|
||||
@property
|
||||
def allows_access(self) -> bool:
|
||||
return self in {self.ACTIVE, self.QUARANTINED}
|
||||
|
||||
|
||||
class DeviceAuthErrorCode(str, Enum):
|
||||
"""Well-known authentication failure categories."""
|
||||
|
||||
MISSING_AUTHORIZATION = "missing_authorization"
|
||||
TOKEN_EXPIRED = "token_expired"
|
||||
INVALID_TOKEN = "invalid_token"
|
||||
INVALID_CLAIMS = "invalid_claims"
|
||||
RATE_LIMITED = "rate_limited"
|
||||
DEVICE_NOT_FOUND = "device_not_found"
|
||||
DEVICE_GUID_MISMATCH = "device_guid_mismatch"
|
||||
FINGERPRINT_MISMATCH = "fingerprint_mismatch"
|
||||
TOKEN_VERSION_REVOKED = "token_version_revoked"
|
||||
DEVICE_REVOKED = "device_revoked"
|
||||
DPOP_NOT_SUPPORTED = "dpop_not_supported"
|
||||
DPOP_REPLAYED = "dpop_replayed"
|
||||
DPOP_INVALID = "dpop_invalid"
|
||||
|
||||
|
||||
class DeviceAuthFailure(Exception):
|
||||
"""Domain-level authentication error with HTTP metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
code: DeviceAuthErrorCode,
|
||||
*,
|
||||
http_status: int = 401,
|
||||
retry_after: Optional[float] = None,
|
||||
detail: Optional[str] = None,
|
||||
) -> None:
|
||||
self.code = code
|
||||
self.http_status = int(http_status)
|
||||
self.retry_after = retry_after
|
||||
self.detail = detail or code.value
|
||||
super().__init__(self.detail)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {"error": self.code.value}
|
||||
if self.retry_after is not None:
|
||||
payload["retry_after"] = float(self.retry_after)
|
||||
if self.detail and self.detail != self.code.value:
|
||||
payload["detail"] = self.detail
|
||||
return payload
|
||||
|
||||
|
||||
def _coerce_int(value: Any, *, minimum: Optional[int] = None) -> int:
|
||||
try:
|
||||
result = int(value)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError("expected integer value") from None
|
||||
if minimum is not None and result < minimum:
|
||||
raise ValueError("integer below minimum")
|
||||
return result
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AccessTokenClaims:
|
||||
"""Validated subset of JWT claims issued to a device."""
|
||||
|
||||
subject: str
|
||||
guid: DeviceGuid
|
||||
fingerprint: DeviceFingerprint
|
||||
token_version: int
|
||||
issued_at: int
|
||||
not_before: int
|
||||
expires_at: int
|
||||
raw: Mapping[str, Any]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.token_version <= 0:
|
||||
raise ValueError("token_version must be positive")
|
||||
if self.issued_at <= 0 or self.not_before <= 0 or self.expires_at <= 0:
|
||||
raise ValueError("temporal claims must be positive integers")
|
||||
if self.expires_at <= self.not_before:
|
||||
raise ValueError("token expiration must be after not-before")
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls, claims: Mapping[str, Any]) -> "AccessTokenClaims":
|
||||
subject = str(claims.get("sub") or "").strip()
|
||||
if not subject:
|
||||
raise ValueError("missing token subject")
|
||||
guid = DeviceGuid(str(claims.get("guid") or ""))
|
||||
fingerprint = DeviceFingerprint(claims.get("ssl_key_fingerprint"))
|
||||
token_version = _coerce_int(claims.get("token_version"), minimum=1)
|
||||
issued_at = _coerce_int(claims.get("iat"), minimum=1)
|
||||
not_before = _coerce_int(claims.get("nbf"), minimum=1)
|
||||
expires_at = _coerce_int(claims.get("exp"), minimum=1)
|
||||
return cls(
|
||||
subject=subject,
|
||||
guid=guid,
|
||||
fingerprint=fingerprint,
|
||||
token_version=token_version,
|
||||
issued_at=issued_at,
|
||||
not_before=not_before,
|
||||
expires_at=expires_at,
|
||||
raw=dict(claims),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceAuthContext:
|
||||
"""Domain result emitted after successful authentication."""
|
||||
|
||||
identity: DeviceIdentity
|
||||
access_token: str
|
||||
claims: AccessTokenClaims
|
||||
status: DeviceStatus
|
||||
service_context: Optional[str]
|
||||
dpop_jkt: Optional[str] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.access_token:
|
||||
raise ValueError("access token is required")
|
||||
service = sanitize_service_context(self.service_context)
|
||||
object.__setattr__(self, "service_context", service)
|
||||
|
||||
@property
|
||||
def is_quarantined(self) -> bool:
|
||||
return self.status is DeviceStatus.QUARANTINED
|
||||
|
||||
@property
|
||||
def allows_access(self) -> bool:
|
||||
return self.status.allows_access
|
||||
221
Data/Engine/domain/device_enrollment.py
Normal file
221
Data/Engine/domain/device_enrollment.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Domain types describing device enrollment flows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
from .device_auth import DeviceFingerprint, DeviceGuid, sanitize_service_context
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentCode",
|
||||
"EnrollmentApprovalStatus",
|
||||
"EnrollmentApproval",
|
||||
"EnrollmentRequest",
|
||||
"ProofChallenge",
|
||||
]
|
||||
|
||||
|
||||
def _parse_iso8601(value: Optional[str]) -> Optional[datetime]:
|
||||
if not value:
|
||||
return None
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
dt = datetime.fromisoformat(raw)
|
||||
except Exception as exc: # pragma: no cover - error path
|
||||
raise ValueError(f"invalid ISO8601 timestamp: {raw}") from exc
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
def _require(value: Optional[str], field: str) -> str:
|
||||
text = (value or "").strip()
|
||||
if not text:
|
||||
raise ValueError(f"{field} is required")
|
||||
return text
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentCode:
|
||||
"""Installer code metadata loaded from the persistence layer."""
|
||||
|
||||
code: str
|
||||
expires_at: datetime
|
||||
max_uses: int
|
||||
use_count: int
|
||||
used_by_guid: Optional[DeviceGuid]
|
||||
last_used_at: Optional[datetime]
|
||||
used_at: Optional[datetime]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.code:
|
||||
raise ValueError("code is required")
|
||||
if self.max_uses < 1:
|
||||
raise ValueError("max_uses must be >= 1")
|
||||
if self.use_count < 0:
|
||||
raise ValueError("use_count cannot be negative")
|
||||
if self.use_count > self.max_uses:
|
||||
raise ValueError("use_count cannot exceed max_uses")
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls, record: Mapping[str, Any]) -> "EnrollmentCode":
|
||||
used_by = record.get("used_by_guid")
|
||||
used_by_guid = DeviceGuid(used_by) if used_by else None
|
||||
return cls(
|
||||
code=_require(record.get("code"), "code"),
|
||||
expires_at=_parse_iso8601(record.get("expires_at")) or datetime.now(tz=timezone.utc),
|
||||
max_uses=int(record.get("max_uses") or 1),
|
||||
use_count=int(record.get("use_count") or 0),
|
||||
used_by_guid=used_by_guid,
|
||||
last_used_at=_parse_iso8601(record.get("last_used_at")),
|
||||
used_at=_parse_iso8601(record.get("used_at")),
|
||||
)
|
||||
|
||||
@property
|
||||
def remaining_uses(self) -> int:
|
||||
return max(self.max_uses - self.use_count, 0)
|
||||
|
||||
@property
|
||||
def is_exhausted(self) -> bool:
|
||||
return self.remaining_uses == 0
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
return self.expires_at <= datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
class EnrollmentApprovalStatus(str, Enum):
|
||||
"""Possible states for a device approval entry."""
|
||||
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
COMPLETED = "completed"
|
||||
EXPIRED = "expired"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: Optional[str]) -> "EnrollmentApprovalStatus":
|
||||
normalized = (value or "pending").strip().lower()
|
||||
try:
|
||||
return cls(normalized)
|
||||
except ValueError:
|
||||
return cls.PENDING
|
||||
|
||||
@property
|
||||
def is_terminal(self) -> bool:
|
||||
return self in {self.APPROVED, self.DENIED, self.COMPLETED, self.EXPIRED}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ProofChallenge:
|
||||
"""Client/server nonce pair distributed during enrollment."""
|
||||
|
||||
client_nonce: bytes
|
||||
server_nonce: bytes
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.client_nonce or not self.server_nonce:
|
||||
raise ValueError("nonce payloads must be non-empty")
|
||||
|
||||
@classmethod
|
||||
def from_base64(cls, *, client: bytes, server: bytes) -> "ProofChallenge":
|
||||
return cls(client_nonce=client, server_nonce=server)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentRequest:
|
||||
"""Validated payload submitted by an agent during enrollment."""
|
||||
|
||||
hostname: str
|
||||
enrollment_code: str
|
||||
fingerprint: DeviceFingerprint
|
||||
proof: ProofChallenge
|
||||
service_context: Optional[str]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.hostname:
|
||||
raise ValueError("hostname is required")
|
||||
if not self.enrollment_code:
|
||||
raise ValueError("enrollment code is required")
|
||||
object.__setattr__(
|
||||
self,
|
||||
"service_context",
|
||||
sanitize_service_context(self.service_context),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_payload(
|
||||
cls,
|
||||
*,
|
||||
hostname: str,
|
||||
enrollment_code: str,
|
||||
fingerprint: str,
|
||||
client_nonce: bytes,
|
||||
server_nonce: bytes,
|
||||
service_context: Optional[str] = None,
|
||||
) -> "EnrollmentRequest":
|
||||
proof = ProofChallenge(client_nonce=client_nonce, server_nonce=server_nonce)
|
||||
return cls(
|
||||
hostname=_require(hostname, "hostname"),
|
||||
enrollment_code=_require(enrollment_code, "enrollment_code"),
|
||||
fingerprint=DeviceFingerprint(fingerprint),
|
||||
proof=proof,
|
||||
service_context=service_context,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentApproval:
|
||||
"""Pending or resolved approval tracked by operators."""
|
||||
|
||||
record_id: str
|
||||
reference: str
|
||||
status: EnrollmentApprovalStatus
|
||||
claimed_hostname: str
|
||||
claimed_fingerprint: DeviceFingerprint
|
||||
enrollment_code_id: Optional[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
guid: Optional[DeviceGuid] = None
|
||||
approved_by: Optional[str] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.record_id:
|
||||
raise ValueError("record identifier is required")
|
||||
if not self.reference:
|
||||
raise ValueError("approval reference is required")
|
||||
if not self.claimed_hostname:
|
||||
raise ValueError("claimed hostname is required")
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls, record: Mapping[str, Any]) -> "EnrollmentApproval":
|
||||
guid_raw = record.get("guid")
|
||||
approved_raw = record.get("approved_by_user_id")
|
||||
return cls(
|
||||
record_id=_require(record.get("id"), "id"),
|
||||
reference=_require(record.get("approval_reference"), "approval_reference"),
|
||||
status=EnrollmentApprovalStatus.from_string(record.get("status")),
|
||||
claimed_hostname=_require(record.get("hostname_claimed"), "hostname_claimed"),
|
||||
claimed_fingerprint=DeviceFingerprint(record.get("ssl_key_fingerprint_claimed")),
|
||||
enrollment_code_id=record.get("enrollment_code_id"),
|
||||
created_at=_parse_iso8601(record.get("created_at")) or datetime.now(tz=timezone.utc),
|
||||
updated_at=_parse_iso8601(record.get("updated_at")) or datetime.now(tz=timezone.utc),
|
||||
guid=DeviceGuid(guid_raw) if guid_raw else None,
|
||||
approved_by=(approved_raw or None),
|
||||
)
|
||||
|
||||
@property
|
||||
def is_pending(self) -> bool:
|
||||
return self.status is EnrollmentApprovalStatus.PENDING
|
||||
|
||||
@property
|
||||
def is_completed(self) -> bool:
|
||||
return self.status in {
|
||||
EnrollmentApprovalStatus.APPROVED,
|
||||
EnrollmentApprovalStatus.COMPLETED,
|
||||
}
|
||||
Reference in New Issue
Block a user