feat: secure agent auth and heartbeat endpoints

This commit is contained in:
2025-10-17 17:15:02 -06:00
parent 78a5d3d7f9
commit a72bff5e8e
6 changed files with 420 additions and 1 deletions

View File

@@ -0,0 +1,148 @@
from __future__ import annotations
import functools
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional
import jwt
from flask import g, jsonify, request
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
@dataclass
class DeviceAuthContext:
guid: str
ssl_key_fingerprint: str
token_version: int
access_token: str
claims: Dict[str, Any]
dpop_jkt: Optional[str]
status: str
class DeviceAuthError(Exception):
status_code = 401
error_code = "unauthorized"
def __init__(self, message: str = "unauthorized", *, status_code: Optional[int] = None):
super().__init__(message)
if status_code is not None:
self.status_code = status_code
self.message = message
class DeviceAuthManager:
def __init__(
self,
*,
db_conn_factory: Callable[[], Any],
jwt_service,
dpop_validator: Optional[DPoPValidator],
log: Callable[[str, str], None],
) -> None:
self._db_conn_factory = db_conn_factory
self._jwt_service = jwt_service
self._dpop_validator = dpop_validator
self._log = log
def authenticate(self) -> DeviceAuthContext:
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise DeviceAuthError("missing_authorization")
token = auth_header[len("Bearer ") :].strip()
if not token:
raise DeviceAuthError("missing_authorization")
try:
claims = self._jwt_service.decode(token)
except jwt.ExpiredSignatureError:
raise DeviceAuthError("token_expired")
except Exception:
raise DeviceAuthError("invalid_token")
guid = str(claims.get("guid") or "").strip()
fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip()
token_version = int(claims.get("token_version") or 0)
if not guid or not fingerprint or token_version <= 0:
raise DeviceAuthError("invalid_claims")
conn = self._db_conn_factory()
try:
cur = conn.cursor()
cur.execute(
"""
SELECT guid, ssl_key_fingerprint, token_version, status
FROM devices
WHERE guid = ?
""",
(guid,),
)
row = cur.fetchone()
finally:
conn.close()
if not row:
raise DeviceAuthError("device_not_found", status_code=403)
db_guid, db_fp, db_token_version, status = row
if str(db_guid or "").lower() != guid.lower():
raise DeviceAuthError("device_guid_mismatch", status_code=403)
db_fp = (db_fp or "").lower().strip()
if db_fp and db_fp != fingerprint:
raise DeviceAuthError("fingerprint_mismatch", status_code=403)
if db_token_version and db_token_version > token_version:
raise DeviceAuthError("token_version_revoked", status_code=401)
status_normalized = (status or "active").strip().lower()
allowed_statuses = {"active", "quarantined"}
if status_normalized not in allowed_statuses:
raise DeviceAuthError("device_revoked", status_code=403)
if status_normalized == "quarantined":
self._log("server", f"device {guid} is quarantined; limited access for {request.path}")
dpop_jkt: Optional[str] = None
dpop_proof = request.headers.get("DPoP")
if dpop_proof:
if not self._dpop_validator:
raise DeviceAuthError("dpop_not_supported", status_code=400)
try:
htu = request.url
dpop_jkt = self._dpop_validator.verify(request.method, htu, dpop_proof, token)
except DPoPReplayError:
raise DeviceAuthError("dpop_replayed", status_code=400)
except DPoPVerificationError:
raise DeviceAuthError("dpop_invalid", status_code=400)
ctx = DeviceAuthContext(
guid=guid,
ssl_key_fingerprint=fingerprint,
token_version=token_version,
access_token=token,
claims=claims,
dpop_jkt=dpop_jkt,
status=status_normalized,
)
return ctx
def require_device_auth(manager: DeviceAuthManager):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
ctx = manager.authenticate()
except DeviceAuthError as exc:
response = jsonify({"error": exc.message})
response.status_code = exc.status_code
return response
g.device_auth = ctx
return func(*args, **kwargs)
return wrapper
return decorator