mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 15:41:58 -06:00
250 lines
7.4 KiB
Python
250 lines
7.4 KiB
Python
"""Domain primitives for device authentication and token validation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import Any, Mapping, Optional
|
|
import string
|
|
import uuid
|
|
|
|
# Public surface of this module: the value objects, the status / error-code
# enums, the failure exception and the two normalization helpers.
__all__ = [
    "DeviceGuid",
    "DeviceFingerprint",
    "DeviceIdentity",
    "DeviceStatus",
    "DeviceAuthErrorCode",
    "DeviceAuthFailure",
    "AccessTokenClaims",
    "DeviceAuthContext",
    "sanitize_service_context",
    "normalize_guid",
]
|
|
|
|
|
|
def _normalize_guid(value: Optional[str]) -> str:
|
|
"""Return a canonical GUID string or an empty string."""
|
|
candidate = (value or "").strip()
|
|
if not candidate:
|
|
return ""
|
|
candidate = candidate.strip("{}")
|
|
try:
|
|
return str(uuid.UUID(candidate)).upper()
|
|
except Exception:
|
|
cleaned = "".join(
|
|
ch for ch in candidate if ch in string.hexdigits or ch == "-"
|
|
).strip("-")
|
|
if cleaned:
|
|
try:
|
|
return str(uuid.UUID(cleaned)).upper()
|
|
except Exception:
|
|
pass
|
|
return candidate.upper()
|
|
|
|
|
|
def _normalize_fingerprint(value: Optional[str]) -> str:
|
|
return (value or "").strip().lower()
|
|
|
|
|
|
def sanitize_service_context(value: Optional[str]) -> Optional[str]:
    """Normalize the optional agent service context header value.

    Keeps only alphanumeric characters, underscores and hyphens, then
    upper-cases the result.

    Args:
        value: Raw header value; non-string inputs are passed through
            ``str()`` first.

    Returns:
        The cleaned, upper-cased value, or ``None`` when the input is falsy
        or nothing survives the filter.
    """
    if not value:
        return None
    # NOTE(review): str.isalnum() is Unicode-aware, so non-ASCII letters and
    # digits are kept as well -- confirm that is acceptable for this header.
    kept = [ch for ch in str(value) if ch.isalnum() or ch in ("_", "-")]
    # upper() never turns a non-empty string into an empty one, so an empty
    # join is the only case that maps to None here.
    return "".join(kept).upper() or None
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
class DeviceGuid:
    """Immutable device GUID, canonicalized when constructed.

    The supplied ``value`` is replaced by the output of ``_normalize_guid``.
    Construction fails with ``ValueError`` when normalization yields an empty
    string (e.g. ``""`` or whitespace-only input).
    """

    value: str

    def __post_init__(self) -> None:  # pragma: no cover - simple data normalization
        canonical = _normalize_guid(self.value)
        if canonical:
            # The dataclass is frozen, so bypass its __setattr__ guard to
            # store the canonical form.
            object.__setattr__(self, "value", canonical)
            return
        raise ValueError("device GUID is required")

    def __str__(self) -> str:
        return self.value
|
|
|
|
|
|
def normalize_guid(value: Optional[str]) -> str:
    """Public wrapper around the module's private GUID normalizer.

    Intended for administrative helpers. Returns the canonical GUID string,
    or an empty string when *value* is empty.
    """
    canonical = _normalize_guid(value)
    return canonical
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
class DeviceFingerprint:
    """Immutable TLS key fingerprint for a device, normalized on construction.

    The supplied ``value`` is replaced by the output of
    ``_normalize_fingerprint``. A ``ValueError`` is raised when that output is
    empty.
    """

    value: str

    def __post_init__(self) -> None:  # pragma: no cover - simple data normalization
        canonical = _normalize_fingerprint(self.value)
        if canonical:
            # Frozen dataclass: write through object.__setattr__.
            object.__setattr__(self, "value", canonical)
            return
        raise ValueError("device fingerprint is required")

    def __str__(self) -> str:
        return self.value
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class DeviceIdentity:
|
|
"""Immutable pairing of device GUID and TLS key fingerprint."""
|
|
|
|
guid: DeviceGuid
|
|
fingerprint: DeviceFingerprint
|
|
|
|
|
|
class DeviceStatus(str, Enum):
    """Lifecycle markers mirrored from the legacy devices table.

    Members are ``str`` subclasses, so they compare equal to their raw
    string values (e.g. ``DeviceStatus.ACTIVE == "active"``).
    """

    ACTIVE = "active"
    QUARANTINED = "quarantined"
    REVOKED = "revoked"
    DECOMMISSIONED = "decommissioned"

    @classmethod
    def from_string(cls, value: Optional[str]) -> "DeviceStatus":
        """Parse a stored status string, case- and whitespace-insensitively.

        A falsy *value* is treated as ``"active"``. Any value that matches no
        member also resolves to ``ACTIVE``.
        """
        key = value.strip().lower() if value else "active"
        for member in cls:
            if member.value == key:
                return member
        # NOTE(review): unrecognised statuses fall back to ACTIVE, which grants
        # access (fail-open) -- confirm this is intended for legacy rows.
        return cls.ACTIVE

    @property
    def allows_access(self) -> bool:
        """True for ACTIVE and QUARANTINED; False for REVOKED and DECOMMISSIONED."""
        # Members are referenced via the class rather than via another member.
        return self in (DeviceStatus.ACTIVE, DeviceStatus.QUARANTINED)
|
|
|
|
|
|
class DeviceAuthErrorCode(str, Enum):
    """Well-known authentication failure categories.

    Each value is the machine-readable error string; members compare equal
    to that string because the enum mixes in ``str``.
    """

    # Authorization header / access-token problems
    MISSING_AUTHORIZATION = "missing_authorization"
    TOKEN_EXPIRED = "token_expired"
    INVALID_TOKEN = "invalid_token"
    INVALID_CLAIMS = "invalid_claims"

    # Throttling
    RATE_LIMITED = "rate_limited"

    # Device record checks
    DEVICE_NOT_FOUND = "device_not_found"
    DEVICE_GUID_MISMATCH = "device_guid_mismatch"
    FINGERPRINT_MISMATCH = "fingerprint_mismatch"
    TOKEN_VERSION_REVOKED = "token_version_revoked"
    DEVICE_REVOKED = "device_revoked"

    # DPoP proof handling
    DPOP_NOT_SUPPORTED = "dpop_not_supported"
    DPOP_REPLAYED = "dpop_replayed"
    DPOP_INVALID = "dpop_invalid"
|
|
|
|
|
|
class DeviceAuthFailure(Exception):
    """Authentication failure raised by the device-auth domain layer.

    Alongside the error category it carries the metadata a transport layer
    needs to build a response.

    Args:
        code: Error category. Its ``value`` becomes the ``"error"`` field and
            the fallback detail / exception message.
        http_status: HTTP status to report (default 401), coerced with ``int()``.
        retry_after: Optional retry hint, stored as given and emitted as a
            float by :meth:`to_dict` (presumably seconds -- confirm with callers).
        detail: Optional human-readable message; ``code.value`` is used when
            this is falsy.
    """

    def __init__(
        self,
        code: DeviceAuthErrorCode,
        *,
        http_status: int = 401,
        retry_after: Optional[float] = None,
        detail: Optional[str] = None,
    ) -> None:
        self.code = code
        self.http_status = int(http_status)
        self.retry_after = retry_after
        self.detail = detail if detail else code.value
        # str(exc) is the resolved detail text.
        super().__init__(self.detail)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the failure into a JSON-friendly error payload.

        ``"error"`` is always present. ``"retry_after"`` (as a float) is added
        only when set. ``"detail"`` is added only when it differs from the
        error code string.
        """
        body: dict[str, Any] = {"error": self.code.value}
        if self.retry_after is not None:
            body["retry_after"] = float(self.retry_after)
        has_extra_detail = bool(self.detail) and self.detail != self.code.value
        if has_extra_detail:
            body["detail"] = self.detail
        return body
|
|
|
|
|
|
def _coerce_int(value: Any, *, minimum: Optional[int] = None) -> int:
|
|
try:
|
|
result = int(value)
|
|
except (TypeError, ValueError):
|
|
raise ValueError("expected integer value") from None
|
|
if minimum is not None and result < minimum:
|
|
raise ValueError("integer below minimum")
|
|
return result
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class AccessTokenClaims:
|
|
"""Validated subset of JWT claims issued to a device."""
|
|
|
|
subject: str
|
|
guid: DeviceGuid
|
|
fingerprint: DeviceFingerprint
|
|
token_version: int
|
|
issued_at: int
|
|
not_before: int
|
|
expires_at: int
|
|
raw: Mapping[str, Any]
|
|
|
|
def __post_init__(self) -> None:
|
|
if self.token_version <= 0:
|
|
raise ValueError("token_version must be positive")
|
|
if self.issued_at <= 0 or self.not_before <= 0 or self.expires_at <= 0:
|
|
raise ValueError("temporal claims must be positive integers")
|
|
if self.expires_at <= self.not_before:
|
|
raise ValueError("token expiration must be after not-before")
|
|
|
|
@classmethod
|
|
def from_mapping(cls, claims: Mapping[str, Any]) -> "AccessTokenClaims":
|
|
subject = str(claims.get("sub") or "").strip()
|
|
if not subject:
|
|
raise ValueError("missing token subject")
|
|
guid = DeviceGuid(str(claims.get("guid") or ""))
|
|
fingerprint = DeviceFingerprint(claims.get("ssl_key_fingerprint"))
|
|
token_version = _coerce_int(claims.get("token_version"), minimum=1)
|
|
issued_at = _coerce_int(claims.get("iat"), minimum=1)
|
|
not_before = _coerce_int(claims.get("nbf"), minimum=1)
|
|
expires_at = _coerce_int(claims.get("exp"), minimum=1)
|
|
return cls(
|
|
subject=subject,
|
|
guid=guid,
|
|
fingerprint=fingerprint,
|
|
token_version=token_version,
|
|
issued_at=issued_at,
|
|
not_before=not_before,
|
|
expires_at=expires_at,
|
|
raw=dict(claims),
|
|
)
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
class DeviceAuthContext:
    """Immutable result produced after a device authenticates successfully.

    ``access_token`` must be non-empty. ``service_context`` is passed through
    ``sanitize_service_context`` during construction. ``dpop_jkt`` is optional
    (presumably the DPoP key thumbprint bound to the token -- confirm with
    callers).
    """

    identity: DeviceIdentity
    access_token: str
    claims: AccessTokenClaims
    status: DeviceStatus
    service_context: Optional[str]
    dpop_jkt: Optional[str] = None

    def __post_init__(self) -> None:
        if self.access_token:
            # Frozen dataclass: store the sanitized header via object.__setattr__.
            cleaned = sanitize_service_context(self.service_context)
            object.__setattr__(self, "service_context", cleaned)
            return
        raise ValueError("access token is required")

    @property
    def is_quarantined(self) -> bool:
        """True when the device status is exactly ``DeviceStatus.QUARANTINED``."""
        return self.status is DeviceStatus.QUARANTINED

    @property
    def allows_access(self) -> bool:
        """Delegate to the status's own ``allows_access`` flag."""
        return self.status.allows_access
|