Files
Borealis-Github-Replica/Data/Engine/services/auth/device_auth_service.py

238 lines
7.1 KiB
Python

"""Device authentication service copied from the legacy server stack."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Mapping, Optional, Protocol
import logging
from Data.Engine.builders.device_auth import DeviceAuthRequest
from Data.Engine.domain.device_auth import (
AccessTokenClaims,
DeviceAuthContext,
DeviceAuthErrorCode,
DeviceAuthFailure,
DeviceFingerprint,
DeviceGuid,
DeviceIdentity,
DeviceStatus,
)
# Public API of this module. JWTDecoder is exported alongside the other
# ports because DeviceAuthService requires an implementation of it.
__all__ = [
    "DeviceAuthService",
    "DeviceRecord",
    "DPoPValidator",
    "DPoPVerificationError",
    "DPoPReplayError",
    "JWTDecoder",
    "RateLimiter",
    "RateLimitDecision",
    "DeviceRepository",
]
class RateLimitDecision(Protocol):
    """Result returned by a :class:`RateLimiter` check.

    ``allowed`` tells the caller whether the request may proceed.
    ``retry_after`` is an optional back-off hint; DeviceAuthService copies it
    onto the HTTP 429 failure it raises when ``allowed`` is false.
    """

    allowed: bool
    retry_after: Optional[float]
class RateLimiter(Protocol):
    """Port for a request rate limiter keyed by an arbitrary string."""

    def check(self, key: str, max_requests: int, window_seconds: float) -> RateLimitDecision:  # pragma: no cover - protocol
        """Return a decision for ``key`` given a budget of ``max_requests`` per ``window_seconds``."""
class JWTDecoder(Protocol):
    """Port that turns an encoded access token into its raw claim mapping."""

    def decode(self, token: str) -> Mapping[str, object]:  # pragma: no cover - protocol
        """Decode ``token`` (presumably verifying signature and expiry) and return its claims.

        DeviceAuthService treats any exception raised here as a decode
        failure. It reports TOKEN_EXPIRED when the exception's class is
        named ``ExpiredSignatureError``.
        """
class DPoPValidator(Protocol):
    """Port that checks DPoP proofs presented alongside an access token."""

    def verify(
        self,
        method: str,
        htu: str,
        proof: str,
        access_token: Optional[str] = None,
    ) -> str:  # pragma: no cover - protocol
        """Validate ``proof`` for the HTTP ``method`` and target URI ``htu``.

        ``access_token`` is the bearer token the proof accompanies, when known.
        The returned string is stored by DeviceAuthService as ``dpop_jkt``
        (presumably the proof key's JWK thumbprint; confirm with the
        implementation). Failures are expected to be signalled by raising
        DPoPVerificationError, or DPoPReplayError for a reused proof.
        """
class DPoPVerificationError(Exception):
    """Signal that a DPoP proof was rejected by the validator."""
class DPoPReplayError(DPoPVerificationError):
    """Signal that a DPoP proof was already used (a replay attempt)."""
@dataclass(frozen=True, slots=True)
class DeviceRecord:
"""Snapshot of a device record required for authentication."""
identity: DeviceIdentity
token_version: int
status: DeviceStatus
class DeviceRepository(Protocol):
    """Persistence port offering only what device authentication needs."""

    def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]:  # pragma: no cover - protocol
        """Return the stored record for ``guid``, or ``None`` when there is none."""

    def recover_missing(
        self,
        guid: DeviceGuid,
        fingerprint: DeviceFingerprint,
        token_version: int,
        service_context: Optional[str],
    ) -> Optional[DeviceRecord]:  # pragma: no cover - protocol
        """Try to produce a record for a token whose GUID had no stored row.

        DeviceAuthService calls this after ``fetch_by_guid`` returned ``None``,
        passing values taken from the decoded token claims and the request's
        service context. Returning ``None`` makes authentication fail with
        DEVICE_NOT_FOUND.
        """
class DeviceAuthService:
    """Authenticate devices using access tokens, repositories, and DPoP proofs.

    ``authenticate`` runs these steps in order and raises ``DeviceAuthFailure``
    at the first one that fails:

    1. decode the access token into ``AccessTokenClaims``;
    2. apply the optional per-fingerprint rate limit;
    3. load the device record, falling back to ``recover_missing``;
    4. check the record against the claims (GUID, fingerprint, token
       version, status);
    5. verify the DPoP proof when the request carries one.
    """

    def __init__(
        self,
        *,
        device_repository: DeviceRepository,
        jwt_service: JWTDecoder,
        logger: Optional[logging.Logger] = None,
        rate_limiter: Optional[RateLimiter] = None,
        dpop_validator: Optional[DPoPValidator] = None,
        rate_limit_max_requests: int = 60,
        rate_limit_window_seconds: float = 60.0,
    ) -> None:
        """Wire the service to its ports.

        Args:
            device_repository: Source of device records.
            jwt_service: Decodes access tokens into raw claim mappings.
            logger: Logger to use; defaults to ``borealis.engine.auth``.
            rate_limiter: Optional limiter; when ``None`` no rate limiting
                is applied.
            dpop_validator: Optional DPoP verifier; when ``None``, requests
                that carry a DPoP proof are rejected with DPOP_NOT_SUPPORTED.
            rate_limit_max_requests: Requests allowed per fingerprint per
                window (defaults to the previously hard-coded 60).
            rate_limit_window_seconds: Rate-limit window length passed to the
                limiter (defaults to the previously hard-coded 60.0).
        """
        self._repository = device_repository
        self._jwt = jwt_service
        self._log = logger or logging.getLogger("borealis.engine.auth")
        self._rate_limiter = rate_limiter
        self._dpop_validator = dpop_validator
        self._rate_limit_max_requests = rate_limit_max_requests
        self._rate_limit_window_seconds = rate_limit_window_seconds

    def authenticate(self, request: DeviceAuthRequest, *, path: str) -> DeviceAuthContext:
        """Authenticate an access token and return the resulting context.

        Args:
            request: Request carrying the access token, optional DPoP proof,
                HTTP method, target URI and service context.
            path: Request path; used only in the quarantine log message.

        Returns:
            A DeviceAuthContext for the authenticated device.

        Raises:
            DeviceAuthFailure: for an invalid or expired token, unparsable
                claims, rate limiting, an unknown or mismatched device, a
                revoked token version or device, or a rejected DPoP proof.
        """
        claims = self._decode_claims(request.access_token)
        self._enforce_rate_limit(claims)
        record = self._load_record(claims, request.service_context)
        self._validate_identity(record, claims)
        dpop_jkt = self._validate_dpop(request, record, claims)
        context = DeviceAuthContext(
            identity=record.identity,
            access_token=request.access_token,
            claims=claims,
            status=record.status,
            service_context=request.service_context,
            dpop_jkt=dpop_jkt,
        )
        if context.is_quarantined:
            # Quarantined devices are still authenticated; the context carries
            # the status so callers can restrict access. Only a warning here.
            self._log.warning(
                "device %s is quarantined; limited access for %s",
                record.identity.guid,
                path,
            )
        return context

    def _enforce_rate_limit(self, claims: AccessTokenClaims) -> None:
        """Raise RATE_LIMITED (HTTP 429) when the token fingerprint's budget is spent."""
        if self._rate_limiter is None:
            return
        decision = self._rate_limiter.check(
            f"fp:{claims.fingerprint.value}",
            self._rate_limit_max_requests,
            self._rate_limit_window_seconds,
        )
        if not decision.allowed:
            raise DeviceAuthFailure(
                DeviceAuthErrorCode.RATE_LIMITED,
                http_status=429,
                retry_after=decision.retry_after,
            )

    def _load_record(
        self,
        claims: AccessTokenClaims,
        service_context: Optional[str],
    ) -> DeviceRecord:
        """Fetch the device record for the token's GUID, attempting recovery if absent.

        Raises:
            DeviceAuthFailure: DEVICE_NOT_FOUND (HTTP 403) when neither lookup
                nor recovery yields a record.
        """
        record = self._repository.fetch_by_guid(claims.guid)
        if record is None:
            record = self._repository.recover_missing(
                claims.guid,
                claims.fingerprint,
                claims.token_version,
                service_context,
            )
        if record is None:
            raise DeviceAuthFailure(
                DeviceAuthErrorCode.DEVICE_NOT_FOUND,
                http_status=403,
            )
        return record

    def _decode_claims(self, token: str) -> AccessTokenClaims:
        """Decode ``token`` and parse the payload into AccessTokenClaims.

        Raises:
            DeviceAuthFailure: TOKEN_EXPIRED when the decoder raised an
                expired-signature error, INVALID_TOKEN for any other decode
                error, INVALID_CLAIMS when the payload cannot be parsed.
        """
        try:
            raw_claims = self._jwt.decode(token)
        except Exception as exc:  # pragma: no cover - defensive fallback
            # The decoder is an injected port with no declared exception
            # types, so every failure is mapped to a DeviceAuthFailure.
            if self._is_expired_signature(exc):
                raise DeviceAuthFailure(DeviceAuthErrorCode.TOKEN_EXPIRED) from exc
            raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN) from exc
        try:
            return AccessTokenClaims.from_mapping(raw_claims)
        except Exception as exc:
            raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_CLAIMS) from exc

    @staticmethod
    def _is_expired_signature(exc: Exception) -> bool:
        """Return True if ``exc`` is an ``ExpiredSignatureError`` or a subclass of one.

        Matching by class name avoids importing a specific JWT library here.
        Walking the MRO also recognizes subclasses (e.g. wrappers defined by
        a JWT service), which an exact-name comparison would miss.
        """
        return any(cls.__name__ == "ExpiredSignatureError" for cls in type(exc).__mro__)

    def _validate_identity(
        self,
        record: DeviceRecord,
        claims: AccessTokenClaims,
    ) -> None:
        """Check the stored record against the token claims.

        Raises:
            DeviceAuthFailure: DEVICE_GUID_MISMATCH or FINGERPRINT_MISMATCH
                (HTTP 403), TOKEN_VERSION_REVOKED, or DEVICE_REVOKED (HTTP 403).
        """
        if record.identity.guid.value != claims.guid.value:
            raise DeviceAuthFailure(
                DeviceAuthErrorCode.DEVICE_GUID_MISMATCH,
                http_status=403,
            )
        stored_fingerprint = record.identity.fingerprint.value
        # A record with an empty stored fingerprint is not pinned, so any
        # token fingerprint is accepted for it.
        if stored_fingerprint and stored_fingerprint != claims.fingerprint.value:
            raise DeviceAuthFailure(
                DeviceAuthErrorCode.FINGERPRINT_MISMATCH,
                http_status=403,
            )
        # A stored version newer than the token's means the token was issued
        # before the device's token version was bumped.
        if record.token_version > claims.token_version:
            raise DeviceAuthFailure(DeviceAuthErrorCode.TOKEN_VERSION_REVOKED)
        if not record.status.allows_access:
            raise DeviceAuthFailure(
                DeviceAuthErrorCode.DEVICE_REVOKED,
                http_status=403,
            )

    def _validate_dpop(
        self,
        request: DeviceAuthRequest,
        record: DeviceRecord,
        claims: AccessTokenClaims,
    ) -> Optional[str]:
        """Verify the request's DPoP proof, if any, and return the validator's result.

        Returns ``None`` when the request carries no proof.

        Raises:
            DeviceAuthFailure: DPOP_NOT_SUPPORTED when a proof is present but
                no validator is configured; DPOP_REPLAYED or DPOP_INVALID when
                the validator rejects it (all HTTP 400).
        """
        # NOTE(review): ``record`` and ``claims`` are unused, so no check here
        # ties the token's key binding to the proof (or requires a proof for a
        # bound token). Confirm whether the validator enforces this via
        # ``access_token``.
        if not request.dpop_proof:
            return None
        if self._dpop_validator is None:
            raise DeviceAuthFailure(
                DeviceAuthErrorCode.DPOP_NOT_SUPPORTED,
                http_status=400,
            )
        try:
            return self._dpop_validator.verify(
                request.http_method,
                request.htu,
                request.dpop_proof,
                request.access_token,
            )
        except DPoPReplayError as exc:
            # Must precede the DPoPVerificationError clause: replay is a subclass.
            raise DeviceAuthFailure(
                DeviceAuthErrorCode.DPOP_REPLAYED,
                http_status=400,
            ) from exc
        except DPoPVerificationError as exc:
            raise DeviceAuthFailure(
                DeviceAuthErrorCode.DPOP_INVALID,
                http_status=400,
            ) from exc