mirror of
				https://github.com/bunny-lab-io/Borealis.git
				synced 2025-10-26 15:41:58 -06:00 
			
		
		
		
	
		
			
				
	
	
		
			106 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			106 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """DPoP proof validation for Engine services."""
 | |
| 
 | |
from __future__ import annotations

import base64
import hashlib
import json
import time
from threading import Lock
from typing import Dict, Optional

import jwt
 | |
| 
 | |
__all__ = [
    "DPoPValidator",
    "DPoPVerificationError",
    "DPoPReplayError",
]

# Largest tolerated gap, in seconds, between a proof's ``iat`` claim and the
# server clock (``time.time()``). Also used as the base jti retention time.
# NOTE: the "0" (zero) in the name is the established spelling; DPoPValidator
# references it under exactly this name.
_DP0P_MAX_SKEW = 5 * 60.0
 | |
| 
 | |
| 
 | |
class DPoPVerificationError(Exception):
    """Raised when a DPoP proof is missing, malformed, or fails validation."""
 | |
| 
 | |
| 
 | |
class DPoPReplayError(DPoPVerificationError):
    """Raised when a DPoP proof's ``jti`` has already been accepted.

    Being a ``DPoPVerificationError`` subclass, it is also caught by
    handlers for the general verification error.
    """
 | |
| 
 | |
| 
 | |
class DPoPValidator:
    """Validate DPoP proof JWTs and reject replayed proofs.

    Each instance keeps an in-memory map of recently accepted ``jti`` values,
    guarded by a lock, so a proof cannot be accepted twice while it is still
    inside the allowed ``iat`` skew window.
    """

    # Signature algorithms accepted in the proof header.
    _SUPPORTED_ALGS = ("EdDSA", "ES256", "ES384", "ES512")

    # Required public JWK members per key type (RFC 7638 section 3.2). The
    # thumbprint is computed over exactly these members.
    _THUMBPRINT_MEMBERS = {
        "EC": ("crv", "kty", "x", "y"),
        "OKP": ("crv", "kty", "x"),
        "RSA": ("e", "kty", "n"),
    }

    def __init__(self) -> None:
        # jti -> time.time() value after which the entry may be discarded.
        self._observed_jti: Dict[str, float] = {}
        self._lock = Lock()

    @staticmethod
    def _b64url(data: bytes) -> str:
        """Return *data* encoded as unpadded base64url text."""
        return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")

    @classmethod
    def _jwk_thumbprint(cls, jwk: Dict[str, object]) -> str:
        """Return the RFC 7638 SHA-256 thumbprint of the public JWK *jwk*.

        Only the required members for the key's ``kty`` are hashed, serialized
        as compact JSON with lexicographically sorted member names.

        Raises:
            DPoPVerificationError: if ``kty`` is unsupported or a required
                member is missing or not a string.
        """
        kty = jwk.get("kty")
        members = cls._THUMBPRINT_MEMBERS.get(kty) if isinstance(kty, str) else None
        if members is None:
            raise DPoPVerificationError("unsupported jwk key type")
        required = {}
        for name in members:
            value = jwk.get(name)
            if not isinstance(value, str):
                raise DPoPVerificationError(f"jwk member {name} missing")
            required[name] = value
        canonical = json.dumps(
            required, separators=(",", ":"), sort_keys=True, ensure_ascii=False
        )
        return cls._b64url(hashlib.sha256(canonical.encode("utf-8")).digest())

    def verify(
        self,
        method: str,
        htu: str,
        proof: str,
        access_token: Optional[str] = None,
    ) -> str:
        """Validate a DPoP proof for one request and return its key thumbprint.

        Args:
            method: HTTP method of the request; compared case-insensitively
                with the proof's ``htm`` claim.
            htu: Request URI the proof must name; compared exactly with the
                proof's ``htu`` claim.
            proof: The DPoP proof JWT, signed with the key in its ``jwk``
                header.
            access_token: When non-empty, the proof must carry an ``ath``
                claim equal to the base64url SHA-256 hash of this token.

        Returns:
            The RFC 7638 SHA-256 thumbprint (unpadded base64url) of the
            proof's public key.

        Raises:
            DPoPVerificationError: if the proof is missing, malformed, wrongly
                signed, does not match the request, lacks or mismatches
                ``ath``, or is outside the allowed ``iat`` skew.
            DPoPReplayError: if the proof's ``jti`` was already accepted and
                is still being tracked.
        """
        if not proof:
            raise DPoPVerificationError("DPoP proof missing")

        try:
            header = jwt.get_unverified_header(proof)
        except Exception as exc:
            raise DPoPVerificationError("invalid DPoP header") from exc

        jwk = header.get("jwk")
        alg = header.get("alg")
        if not jwk or not isinstance(jwk, dict):
            raise DPoPVerificationError("missing jwk in DPoP header")
        if alg not in self._SUPPORTED_ALGS:
            raise DPoPVerificationError(f"unsupported DPoP alg {alg}")
        # RFC 9449 requires the embedded key to be a public key; "d" is the
        # private-key member for EC/OKP/RSA JWKs.
        if "d" in jwk:
            raise DPoPVerificationError("DPoP jwk must not contain a private key")
        # NOTE(review): RFC 9449 section 4.3 also requires header typ ==
        # "dpop+jwt"; not enforced here -- confirm all clients set it first.

        try:
            public_key = jwt.PyJWK(jwk).key
        except Exception as exc:
            raise DPoPVerificationError("invalid jwk in DPoP header") from exc

        try:
            claims = jwt.decode(
                proof,
                public_key,
                algorithms=[alg],
                options={"require": ["htm", "htu", "jti", "iat"]},
            )
        except Exception as exc:
            raise DPoPVerificationError("invalid DPoP signature") from exc

        htm = claims.get("htm")
        proof_htu = claims.get("htu")
        jti = claims.get("jti")
        iat = claims.get("iat")
        ath = claims.get("ath")

        if not isinstance(htm, str) or htm.lower() != method.lower():
            raise DPoPVerificationError("DPoP htm mismatch")
        if not isinstance(proof_htu, str) or proof_htu != htu:
            raise DPoPVerificationError("DPoP htu mismatch")
        if not isinstance(jti, str):
            raise DPoPVerificationError("DPoP jti missing")
        if not isinstance(iat, (int, float)):
            raise DPoPVerificationError("DPoP iat missing")

        issued_at = float(iat)
        now = time.time()
        if abs(now - issued_at) > _DP0P_MAX_SKEW:
            raise DPoPVerificationError("DPoP proof outside allowed skew")

        if access_token:
            # A proof presented with an access token must be bound to it via
            # ath; accepting a proof without ath would defeat the binding.
            if not isinstance(ath, str):
                raise DPoPVerificationError("DPoP ath missing")
            expected_ath = self._b64url(
                hashlib.sha256(access_token.encode("utf-8")).digest()
            )
            if expected_ath != ath:
                raise DPoPVerificationError("DPoP ath mismatch")

        # Computed before the jti is recorded so a failure here cannot
        # consume the proof's jti.
        thumbprint = self._jwk_thumbprint(jwk)

        with self._lock:
            expiry = self._observed_jti.get(jti)
            if expiry is not None and expiry > now:
                raise DPoPReplayError("DPoP proof replay detected")
            # The proof passes the skew check until iat + skew, which is later
            # than now + skew when iat lies in the future; remember the jti
            # for the longer of the two so it cannot be replayed meanwhile.
            self._observed_jti[jti] = max(now, issued_at) + _DP0P_MAX_SKEW
            stale = [seen for seen, exp in self._observed_jti.items() if exp <= now]
            for seen in stale:
                del self._observed_jti[seen]

        return thumbprint
 |