""" JWT access-token helpers backed by an Ed25519 signing key. """ from __future__ import annotations import hashlib import time from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, Optional import jwt from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ed25519 _KEY_DIR = Path(__file__).resolve().parent.parent / "keys" _KEY_FILE = _KEY_DIR / "borealis-jwt-ed25519.key" class JWTService: def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str): self._private_key = private_key self._public_key = private_key.public_key() self._key_id = key_id @property def key_id(self) -> str: return self._key_id def issue_access_token( self, guid: str, ssl_key_fingerprint: str, token_version: int, expires_in: int = 900, extra_claims: Optional[Dict[str, Any]] = None, ) -> str: now = int(time.time()) payload: Dict[str, Any] = { "sub": f"device:{guid}", "guid": guid, "ssl_key_fingerprint": ssl_key_fingerprint, "token_version": int(token_version), "iat": now, "nbf": now, "exp": now + int(expires_in), } if extra_claims: payload.update(extra_claims) token = jwt.encode( payload, self._private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption(), ), algorithm="EdDSA", headers={"kid": self._key_id}, ) return token def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]: options = {"require": ["exp", "iat", "sub"]} public_pem = self._public_key.public_bytes( encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo, ) return jwt.decode( token, public_pem, algorithms=["EdDSA"], audience=audience, options=options, ) def public_jwk(self) -> Dict[str, Any]: public_bytes = self._public_key.public_bytes( encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw, ) # PyJWT expects base64url without padding. jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii") return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x} def load_service() -> JWTService: private_key = _load_or_create_private_key() public_bytes = private_key.public_key().public_bytes( encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo, ) key_id = hashlib.sha256(public_bytes).hexdigest()[:16] return JWTService(private_key, key_id) def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey: _KEY_DIR.mkdir(parents=True, exist_ok=True) if _KEY_FILE.exists(): with _KEY_FILE.open("rb") as fh: return serialization.load_pem_private_key(fh.read(), password=None) private_key = ed25519.Ed25519PrivateKey.generate() pem = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption(), ) with _KEY_FILE.open("wb") as fh: fh.write(pem) try: if _KEY_FILE.exists() and hasattr(_KEY_FILE, "chmod"): _KEY_FILE.chmod(0o600) except Exception: pass return private_key