Files
Borealis-Github-Replica/Data/Server/Modules/auth/device_auth.py
2025-10-17 19:11:23 -06:00

174 lines
5.7 KiB
Python

from __future__ import annotations

import functools
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

import jwt
from flask import g, jsonify, request

from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
from Modules.auth.rate_limit import SlidingWindowRateLimiter
@dataclass
class DeviceAuthContext:
    """Result of a successful device authentication.

    An instance is built by ``DeviceAuthManager.authenticate`` and stored on
    ``flask.g.device_auth`` by ``require_device_auth`` so that the wrapped
    view can read the authenticated device's identity.
    """

    # Device identifier taken from the token's "guid" claim (stripped).
    guid: str
    # Value of the "ssl_key_fingerprint" claim, lower-cased and stripped.
    ssl_key_fingerprint: str
    # Positive integer from the "token_version" claim.
    token_version: int
    # The raw bearer token exactly as extracted from the Authorization header.
    access_token: str
    # Full set of decoded JWT claims as returned by the JWT service.
    claims: Dict[str, Any]
    # Value returned by the DPoP validator when a DPoP header was supplied,
    # otherwise None.
    dpop_jkt: Optional[str]
    # Normalized (lower-case) device status; authentication only succeeds for
    # "active" or "quarantined".
    status: str
class DeviceAuthError(Exception):
    """Raised when a device request cannot be authenticated.

    The ``message`` doubles as the machine-readable error string returned to
    the client (e.g. ``"token_expired"``); ``status_code`` is the HTTP status
    to respond with.

    Attributes:
        status_code: HTTP status for the response. Defaults to the class-level
            401 unless overridden per instance.
        error_code: Class-level default ``"unauthorized"``. It is never
            reassigned by ``__init__``.
        message: The error string passed to the constructor.
        retry_after: Optional number of seconds the client should wait before
            retrying, or ``None``.
    """

    status_code = 401
    error_code = "unauthorized"

    def __init__(
        self,
        message: str = "unauthorized",
        *,
        status_code: Optional[int] = None,
        retry_after: Optional[float] = None,
    ):
        """Create the error.

        Args:
            message: Error string; also used as the exception text.
            status_code: HTTP status overriding the class default when given.
            retry_after: Seconds to wait before retrying, if applicable.
        """
        super().__init__(message)
        # Only shadow the class-level default when an explicit status is given.
        if status_code is not None:
            self.status_code = status_code
        self.message = message
        self.retry_after = retry_after
class DeviceAuthManager:
    """Authenticate device requests that carry a bearer JWT.

    ``authenticate`` reads the current Flask ``request``, validates the token
    and its claims, optionally applies a per-fingerprint rate limit, checks
    the device record in the ``devices`` table and, when a ``DPoP`` header is
    present, verifies the proof through the configured validator.
    """

    def __init__(
        self,
        *,
        db_conn_factory: Callable[[], Any],
        jwt_service,
        dpop_validator: Optional[DPoPValidator],
        log: Callable[[str, str], None],
        rate_limiter: Optional[SlidingWindowRateLimiter] = None,
    ) -> None:
        """Store the collaborators used during authentication.

        Args:
            db_conn_factory: Zero-argument callable returning a connection
                object that provides ``cursor()`` and ``close()``.
            jwt_service: Object whose ``decode(token)`` returns the claims
                mapping or raises on an invalid/expired token.
            dpop_validator: Validator used when a ``DPoP`` header is sent;
                may be ``None``, in which case DPoP proofs are rejected.
            log: Callable invoked with two strings; this class calls it as
                ``log("server", message)``.
            rate_limiter: Optional limiter; when set, its
                ``check(key, 60, 60.0)`` result must expose ``allowed`` and
                ``retry_after``.
        """
        self._db_conn_factory = db_conn_factory
        self._jwt_service = jwt_service
        self._dpop_validator = dpop_validator
        self._log = log
        self._rate_limiter = rate_limiter

    def authenticate(self) -> DeviceAuthContext:
        """Authenticate the current Flask request as a device.

        Checks run in this order: Authorization header, JWT decoding, claim
        validation, rate limiting, device lookup, fingerprint / token-version
        / status checks, and finally DPoP verification.

        Returns:
            A ``DeviceAuthContext`` describing the authenticated device.

        Raises:
            DeviceAuthError: With the message and status code describing
                which check failed (401 by default; 403 for device-record
                failures; 429 when rate limited; 400 for DPoP failures).
        """
        token = self._read_bearer_token()
        claims = self._decode_claims(token)

        guid = str(claims.get("guid") or "").strip()
        fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip()
        try:
            token_version = int(claims.get("token_version") or 0)
        except (TypeError, ValueError, OverflowError):
            # A non-numeric "token_version" claim is a client error, not a
            # server fault: report it like any other malformed claim.
            raise DeviceAuthError("invalid_claims") from None
        if not guid or not fingerprint or token_version <= 0:
            raise DeviceAuthError("invalid_claims")

        self._enforce_rate_limit(fingerprint)

        row = self._fetch_device_row(guid)
        if not row:
            raise DeviceAuthError("device_not_found", status_code=403)

        db_guid, db_fp, db_token_version, status = row
        if str(db_guid or "").lower() != guid.lower():
            raise DeviceAuthError("device_guid_mismatch", status_code=403)

        # An empty stored fingerprint is not compared against the token's.
        db_fp = (db_fp or "").lower().strip()
        if db_fp and db_fp != fingerprint:
            raise DeviceAuthError("fingerprint_mismatch", status_code=403)

        # A stored version newer than the token's means the token is stale.
        if db_token_version and db_token_version > token_version:
            raise DeviceAuthError("token_version_revoked", status_code=401)

        # A missing/empty status is treated as "active".
        status_normalized = (status or "active").strip().lower()
        if status_normalized not in {"active", "quarantined"}:
            raise DeviceAuthError("device_revoked", status_code=403)
        if status_normalized == "quarantined":
            self._log("server", f"device {guid} is quarantined; limited access for {request.path}")

        dpop_jkt = self._verify_dpop(token)

        return DeviceAuthContext(
            guid=guid,
            ssl_key_fingerprint=fingerprint,
            token_version=token_version,
            access_token=token,
            claims=claims,
            dpop_jkt=dpop_jkt,
            status=status_normalized,
        )

    def _read_bearer_token(self) -> str:
        """Return the non-empty bearer token from the Authorization header."""
        auth_header = request.headers.get("Authorization", "")
        if not auth_header.startswith("Bearer "):
            raise DeviceAuthError("missing_authorization")
        token = auth_header[len("Bearer "):].strip()
        if not token:
            raise DeviceAuthError("missing_authorization")
        return token

    def _decode_claims(self, token: str) -> Dict[str, Any]:
        """Decode *token* via the JWT service, mapping failures to auth errors."""
        try:
            return self._jwt_service.decode(token)
        except jwt.ExpiredSignatureError as exc:
            raise DeviceAuthError("token_expired") from exc
        except Exception as exc:
            # Deliberately broad: any other decode failure is reported to the
            # client as an invalid token rather than a server error.
            raise DeviceAuthError("invalid_token") from exc

    def _enforce_rate_limit(self, fingerprint: str) -> None:
        """Raise a 429 error when the limiter rejects this fingerprint."""
        if not self._rate_limiter:
            return
        decision = self._rate_limiter.check(f"fp:{fingerprint}", 60, 60.0)
        if not decision.allowed:
            raise DeviceAuthError(
                "rate_limited",
                status_code=429,
                retry_after=decision.retry_after,
            )

    def _fetch_device_row(self, guid: str):
        """Return the ``devices`` row for *guid* (or a falsy value if absent)."""
        conn = self._db_conn_factory()
        try:
            cur = conn.cursor()
            cur.execute(
                """
                SELECT guid, ssl_key_fingerprint, token_version, status
                FROM devices
                WHERE guid = ?
                """,
                (guid,),
            )
            return cur.fetchone()
        finally:
            # Always release the connection, even if the query fails.
            conn.close()

    def _verify_dpop(self, token: str) -> Optional[str]:
        """Verify the request's DPoP proof, if any, and return the validator result.

        Returns ``None`` when the request carries no ``DPoP`` header.
        """
        dpop_proof = request.headers.get("DPoP")
        if not dpop_proof:
            return None
        if not self._dpop_validator:
            raise DeviceAuthError("dpop_not_supported", status_code=400)
        try:
            # NOTE(review): request.url includes the query string; RFC 9449
            # defines htu without query/fragment. Confirm that the validator
            # normalizes the URL before comparing.
            htu = request.url
            return self._dpop_validator.verify(request.method, htu, dpop_proof, token)
        except DPoPReplayError as exc:
            raise DeviceAuthError("dpop_replayed", status_code=400) from exc
        except DPoPVerificationError as exc:
            raise DeviceAuthError("dpop_invalid", status_code=400) from exc
def require_device_auth(manager: DeviceAuthManager):
    """Build a decorator that requires device authentication for a Flask view.

    The wrapped view runs only when ``manager.authenticate()`` succeeds; the
    resulting context is stored on ``flask.g.device_auth`` first. On a
    ``DeviceAuthError`` the view is skipped and a JSON body
    ``{"error": <message>}`` is returned with the error's status code. When
    the error carries a truthy ``retry_after``, a ``Retry-After`` header is
    added.

    Args:
        manager: The ``DeviceAuthManager`` used to authenticate each request.

    Returns:
        A decorator to apply to Flask view functions.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                ctx = manager.authenticate()
            except DeviceAuthError as exc:
                response = jsonify({"error": exc.message})
                response.status_code = exc.status_code
                retry_after = getattr(exc, "retry_after", None)
                if retry_after:
                    try:
                        # Round up: truncating would tell the client to retry
                        # before the wait has actually elapsed. Never send
                        # less than one second.
                        seconds = max(1, math.ceil(float(retry_after)))
                    except (TypeError, ValueError, OverflowError):
                        # Non-numeric, NaN or infinite value: fall back to
                        # the minimum delay.
                        seconds = 1
                    response.headers["Retry-After"] = str(seconds)
                return response
            g.device_auth = ctx
            return func(*args, **kwargs)

        return wrapper

    return decorator