mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-12-15 18:55:48 -07:00
311 lines
10 KiB
Python
311 lines
10 KiB
Python
# ======================================================
|
|
# Data\Engine\auth\device_auth.py
|
|
# Description: Engine-native device authentication manager and decorators.
|
|
#
|
|
# API Endpoints (if applicable): None
|
|
# ======================================================
|
|
|
|
"""Device authentication helpers for the Borealis Engine runtime."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import functools
|
|
import sqlite3
|
|
import time
|
|
from contextlib import closing
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Callable, Dict, Optional
|
|
|
|
import jwt
|
|
from flask import g, jsonify, request
|
|
|
|
from .dpop import DPoPReplayError, DPoPValidator, DPoPVerificationError
|
|
from .guid_utils import normalize_guid
|
|
from .rate_limit import SlidingWindowRateLimiter
|
|
|
|
# Name of the HTTP request header read by DeviceAuthManager.authenticate; its
# value is passed through _canonical_context and surfaced as
# DeviceAuthContext.service_mode (and as the context label in log calls).
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
|
|
|
|
|
def _canonical_context(value: Optional[str]) -> Optional[str]:
|
|
if not value:
|
|
return None
|
|
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
|
|
if not cleaned:
|
|
return None
|
|
return cleaned.upper()
|
|
|
|
|
|
@dataclass
class DeviceAuthContext:
    """Result of a successful device authentication.

    Produced by ``DeviceAuthManager.authenticate`` and attached to
    ``flask.g.device_auth`` by the ``require_device_auth`` decorator.
    """

    # Normalised device GUID taken from the token's "guid" claim.
    guid: str
    # Lower-cased, stripped SSL key fingerprint from the token claims; it
    # matched the stored device row.
    ssl_key_fingerprint: str
    # Positive token version from the claims; it matched the stored device row.
    token_version: int
    # The raw bearer token as received in the Authorization header.
    access_token: str
    # Claims as returned by the JWT service's decode().
    claims: Dict[str, Any]
    # Value returned by the DPoP validator's verify() (presumably the proof
    # key thumbprint); None when no DPoP proof was checked.
    dpop_jkt: Optional[str]
    # Lower-cased device status from the database ("active" when unset).
    status: str
    # Canonicalised agent-context header value, or None when absent.
    service_mode: Optional[str]
|
|
|
|
|
|
class DeviceAuthError(Exception):
    """Authentication failure raised by ``DeviceAuthManager``.

    ``message`` doubles as the machine-readable error code sent to clients.
    ``status_code`` is the HTTP status to respond with (401 unless
    overridden), and ``retry_after`` is an optional delay in seconds used for
    a ``Retry-After`` header.
    """

    status_code = 401
    # Class-level default; not read anywhere within this module.
    error_code = "unauthorized"

    def __init__(
        self,
        message: str = "unauthorized",
        *,
        status_code: Optional[int] = None,
        retry_after: Optional[float] = None,
    ):
        super().__init__(message)
        self.message = message
        self.retry_after = retry_after
        # Shadow the class-level default only when the caller supplied a code.
        if status_code is not None:
            self.status_code = status_code
|
|
|
|
|
|
class DeviceAuthManager:
    """Authenticate agent HTTP requests against the ``devices`` table.

    ``authenticate`` inspects the active Flask ``request``. It decodes the
    bearer JWT with ``jwt_service`` and optionally rate-limits per SSL key
    fingerprint. It then checks the token's identity claims against the stored
    device row, re-creating a missing row when possible. Finally, it verifies
    a ``DPoP`` proof header when one is sent and a validator is configured.
    """

    def __init__(
        self,
        *,
        db_conn_factory: Callable[[], Any],
        jwt_service,
        dpop_validator: Optional[DPoPValidator],
        log: Callable[[str, str, Optional[str]], None],
        rate_limiter: Optional[SlidingWindowRateLimiter] = None,
    ) -> None:
        """Store collaborators.

        Args:
            db_conn_factory: Zero-argument callable returning a new DB
                connection. Each ``authenticate`` call closes the connection
                it opens. Recovery catches ``sqlite3`` errors, so a SQLite
                connection is presumably expected.
            jwt_service: Object whose ``decode(token)`` returns the claims
                mapping. ``jwt.ExpiredSignatureError`` is reported as
                ``token_expired``; any other exception becomes ``invalid_token``.
            dpop_validator: Optional DPoP proof validator. When ``None``, any
                ``DPoP`` header is ignored.
            log: Callback invoked as ``log("server", message, context_label)``.
            rate_limiter: Optional limiter consulted once per request, keyed
                by the token's SSL key fingerprint.
        """
        self._db_conn_factory = db_conn_factory
        self._jwt_service = jwt_service
        self._dpop_validator = dpop_validator
        self._log = log
        self._rate_limiter = rate_limiter

    def authenticate(self) -> DeviceAuthContext:
        """Authenticate the current Flask request as a registered device.

        Returns:
            A DeviceAuthContext describing the authenticated device.

        Raises:
            DeviceAuthError: with ``message`` / ``status_code`` of
                - 401: missing_authorization, token_expired, invalid_token,
                  invalid_claims, device_mismatch;
                - 429: rate_limited;
                - 404: device_not_found;
                - 403: fingerprint_mismatch, token_version_mismatch,
                  device_revoked;
                - 400: dpop_replayed, dpop_invalid.
        """
        token = self._extract_bearer_token()
        claims = self._decode_claims(token)
        guid, fingerprint, token_version = self._parse_identity_claims(claims)
        self._enforce_rate_limit(fingerprint)

        context_label = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))

        with closing(self._db_conn_factory()) as conn:
            row = self._find_device_row(conn, guid)
            if row is None:
                # The token verified but no device row exists: try to re-create it.
                row = self._recover_device_record(conn, guid, fingerprint, token_version, context_label)

        if row is None:
            raise DeviceAuthError("device_not_found", status_code=404)

        status_norm = self._check_device_row(row, guid, fingerprint, token_version)
        jkt = self._verify_dpop(token)

        return DeviceAuthContext(
            guid=guid,
            ssl_key_fingerprint=fingerprint,
            token_version=token_version,
            access_token=token,
            claims=claims,
            dpop_jkt=jkt,
            status=status_norm,
            service_mode=context_label,
        )

    @staticmethod
    def _extract_bearer_token() -> str:
        """Return the non-empty bearer token from the Authorization header."""
        auth_header = request.headers.get("Authorization", "")
        scheme, _, credentials = auth_header.partition(" ")
        # HTTP auth-scheme names are case-insensitive ("bearer" == "Bearer").
        if scheme.lower() != "bearer":
            raise DeviceAuthError("missing_authorization")
        token = credentials.strip()
        if not token:
            raise DeviceAuthError("missing_authorization")
        return token

    def _decode_claims(self, token: str) -> Dict[str, Any]:
        """Decode *token* with the JWT service, mapping failures to DeviceAuthError."""
        try:
            return self._jwt_service.decode(token)
        except jwt.ExpiredSignatureError as exc:
            raise DeviceAuthError("token_expired") from exc
        except Exception as exc:
            # Any other decode failure is treated as an invalid token (fail closed).
            raise DeviceAuthError("invalid_token") from exc

    @staticmethod
    def _parse_identity_claims(claims: Dict[str, Any]) -> tuple:
        """Extract ``(guid, fingerprint, token_version)`` from decoded claims.

        Raises DeviceAuthError("invalid_claims") when any of them is missing,
        empty, non-numeric (token_version) or non-positive.
        """
        raw_guid = str(claims.get("guid") or "").strip()
        guid = normalize_guid(raw_guid)
        fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip()
        try:
            token_version = int(claims.get("token_version") or 0)
        except (TypeError, ValueError, OverflowError):
            # A malformed version (e.g. "abc", a list, inf) must be rejected as
            # invalid claims instead of escaping as an unhandled server error.
            token_version = 0
        if not guid or not fingerprint or token_version <= 0:
            raise DeviceAuthError("invalid_claims")
        return guid, fingerprint, token_version

    def _enforce_rate_limit(self, fingerprint: str) -> None:
        """Raise a 429 DeviceAuthError when the per-fingerprint limiter refuses."""
        if not self._rate_limiter:
            return
        # Presumably 60 requests per 60-second window; confirm against
        # SlidingWindowRateLimiter.check.
        decision = self._rate_limiter.check(f"fp:{fingerprint}", 60, 60.0)
        if not decision.allowed:
            raise DeviceAuthError(
                "rate_limited",
                status_code=429,
                retry_after=decision.retry_after,
            )

    @staticmethod
    def _find_device_row(conn: sqlite3.Connection, guid: str) -> Optional[tuple]:
        """Return the device row for *guid*, or ``None`` when no row matches.

        Rows are fetched case-insensitively. A row whose normalised GUID equals
        *guid* is preferred. Otherwise the first fetched row is returned, so
        the caller reports ``device_mismatch`` rather than running recovery.
        """
        cur = conn.cursor()
        cur.execute(
            """
            SELECT guid, ssl_key_fingerprint, token_version, status
            FROM devices
            WHERE UPPER(guid) = ?
            """,
            (guid,),
        )
        rows = cur.fetchall()
        for candidate in rows or []:
            if normalize_guid(candidate[0]) == guid:
                return candidate
        return rows[0] if rows else None

    @staticmethod
    def _check_device_row(row, guid: str, fingerprint: str, token_version: int) -> str:
        """Validate a stored device row against the token identity.

        Returns:
            The lower-cased device status ("active" when the column is empty).

        Raises:
            DeviceAuthError: device_mismatch (401), fingerprint_mismatch (403),
                token_version_mismatch (403) or device_revoked (403).
        """
        stored_guid, stored_fingerprint, stored_version, status = row
        if normalize_guid(stored_guid) != guid:
            raise DeviceAuthError("device_mismatch", status_code=401)
        if (stored_fingerprint or "").lower().strip() != fingerprint:
            raise DeviceAuthError("fingerprint_mismatch", status_code=403)
        if int(stored_version or 0) != token_version:
            raise DeviceAuthError("token_version_mismatch", status_code=403)

        status_norm = (status or "active").strip().lower()
        if status_norm in {"revoked", "decommissioned"}:
            raise DeviceAuthError("device_revoked", status_code=403)
        return status_norm

    def _verify_dpop(self, token: str) -> Optional[str]:
        """Verify the request's DPoP proof, if one was sent and a validator exists.

        Returns:
            The validator's result, or ``None`` when no proof was checked.
        """
        dpop_proof = request.headers.get("DPoP")
        if not dpop_proof or not self._dpop_validator:
            return None
        try:
            return self._dpop_validator.verify(request.method, request.url, dpop_proof, access_token=token)
        except DPoPReplayError as exc:
            # Checked first so a replay is reported distinctly from other failures.
            raise DeviceAuthError("dpop_replayed", status_code=400) from exc
        except DPoPVerificationError as exc:
            raise DeviceAuthError("dpop_invalid", status_code=400) from exc

    def _recover_device_record(
        self,
        conn: sqlite3.Connection,
        guid: str,
        fingerprint: str,
        token_version: int,
        context_label: Optional[str],
    ) -> Optional[tuple]:
        """Attempt to recreate a missing device row for an authenticated token.

        Inserts an 'active' ``devices`` row with hostname
        ``RECOVERED-<first 12 GUID chars>``. If the hostname hits a UNIQUE
        collision, it retries with ``-1`` .. ``-5`` suffixes.

        Returns:
            The freshly selected ``(guid, ssl_key_fingerprint, token_version,
            status)`` row, or ``None`` when recovery fails. Failures are logged
            and the transaction is rolled back.
        """
        guid = normalize_guid(guid)
        fingerprint = (fingerprint or "").strip()
        if not guid or not fingerprint:
            return None

        cur = conn.cursor()
        now_ts = int(time.time())
        now_iso = datetime.now(tz=timezone.utc).isoformat()
        # guid is known to be non-empty here.
        base_hostname = f"RECOVERED-{guid[:12].upper()}"

        for attempt in range(6):
            hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
            try:
                cur.execute(
                    """
                    INSERT INTO devices (
                        guid,
                        hostname,
                        created_at,
                        last_seen,
                        ssl_key_fingerprint,
                        token_version,
                        status,
                        key_added_at
                    )
                    VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
                    """,
                    (
                        guid,
                        hostname,
                        now_ts,
                        now_ts,
                        fingerprint,
                        max(token_version or 1, 1),
                        now_iso,
                    ),
                )
            except sqlite3.IntegrityError as exc:
                message = str(exc).lower()
                if "hostname" in message and "unique" in message:
                    # Hostname already taken: retry with the next suffix.
                    continue
                self._log(
                    "server",
                    f"device auth failed to recover guid={guid} due to integrity error: {exc}",
                    context_label,
                )
                conn.rollback()
                return None
            except Exception as exc:
                self._log(
                    "server",
                    f"device auth unexpected error recovering guid={guid}: {exc}",
                    context_label,
                )
                conn.rollback()
                return None
            conn.commit()
            break
        else:
            # Every hostname candidate collided.
            self._log(
                "server",
                f"device auth could not recover guid={guid}; hostname collisions persisted",
                context_label,
            )
            conn.rollback()
            return None

        cur.execute(
            """
            SELECT guid, ssl_key_fingerprint, token_version, status
            FROM devices
            WHERE guid = ?
            """,
            (guid,),
        )
        row = cur.fetchone()
        if not row:
            self._log(
                "server",
                f"device auth recovery for guid={guid} committed but row still missing",
                context_label,
            )
        return row
|
|
|
|
|
|
def require_device_auth(manager: DeviceAuthManager):
    """Build a Flask view decorator that enforces device authentication.

    The wrapped view runs only after ``manager.authenticate()`` succeeds, and
    the resulting DeviceAuthContext is stored on ``flask.g.device_auth``.
    On DeviceAuthError the view is skipped. The response is then JSON
    ``{"error": <message>}`` with the error's status code, plus a
    ``Retry-After`` header (whole seconds, at least 1) when the error carries
    a truthy ``retry_after``.

    Args:
        manager: The DeviceAuthManager used to authenticate each request.

    Returns:
        A decorator for Flask view functions.
    """
    import math

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                ctx = manager.authenticate()
            except DeviceAuthError as exc:
                response = jsonify({"error": exc.message})
                response.status_code = exc.status_code
                retry_after = getattr(exc, "retry_after", None)
                if retry_after:
                    try:
                        # Round up: truncating (59.9 -> 59) would invite a retry
                        # that still lands inside the active limit window.
                        seconds = math.ceil(float(retry_after))
                        response.headers["Retry-After"] = str(max(1, seconds))
                    except (TypeError, ValueError, OverflowError):
                        # Unusable value (non-numeric, NaN, inf): best-effort default.
                        response.headers["Retry-After"] = "1"
                return response

            g.device_auth = ctx
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
# Names exported by ``from ... import *`` on this module.
__all__ = [
    "AGENT_CONTEXT_HEADER",
    "DeviceAuthContext",
    "DeviceAuthError",
    "DeviceAuthManager",
    "require_device_auth",
]
|