mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-12-16 04:05:48 -07:00
ENGINE: Migrated Enrollment Logic
This commit is contained in:
27
Data/Engine/auth/__init__.py
Normal file
27
Data/Engine/auth/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# ======================================================
|
||||
# Data\Engine\auth\__init__.py
|
||||
# Description: Engine-native authentication utilities and helpers decoupled from the legacy server modules.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""Authentication utility package for the Borealis Engine."""
|
||||
|
||||
from .jwt_service import JWTService, load_service
|
||||
from .dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
|
||||
from .rate_limit import SlidingWindowRateLimiter, RateLimitDecision
|
||||
from .device_auth import DeviceAuthManager, DeviceAuthError, DeviceAuthContext, require_device_auth
|
||||
|
||||
__all__ = [
|
||||
"JWTService",
|
||||
"load_service",
|
||||
"DPoPValidator",
|
||||
"DPoPVerificationError",
|
||||
"DPoPReplayError",
|
||||
"SlidingWindowRateLimiter",
|
||||
"RateLimitDecision",
|
||||
"DeviceAuthManager",
|
||||
"DeviceAuthError",
|
||||
"DeviceAuthContext",
|
||||
"require_device_auth",
|
||||
]
|
||||
310
Data/Engine/auth/device_auth.py
Normal file
310
Data/Engine/auth/device_auth.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# ======================================================
|
||||
# Data\Engine\auth\device_auth.py
|
||||
# Description: Engine-native device authentication manager and decorators.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""Device authentication helpers for the Borealis Engine runtime."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import sqlite3
|
||||
import time
|
||||
from contextlib import closing
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from flask import g, jsonify, request
|
||||
|
||||
from .dpop import DPoPReplayError, DPoPValidator, DPoPVerificationError
|
||||
from .guid_utils import normalize_guid
|
||||
from .rate_limit import SlidingWindowRateLimiter
|
||||
|
||||
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||
|
||||
|
||||
def _canonical_context(value: Optional[str]) -> Optional[str]:
|
||||
if not value:
|
||||
return None
|
||||
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
|
||||
if not cleaned:
|
||||
return None
|
||||
return cleaned.upper()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceAuthContext:
|
||||
guid: str
|
||||
ssl_key_fingerprint: str
|
||||
token_version: int
|
||||
access_token: str
|
||||
claims: Dict[str, Any]
|
||||
dpop_jkt: Optional[str]
|
||||
status: str
|
||||
service_mode: Optional[str]
|
||||
|
||||
|
||||
class DeviceAuthError(Exception):
|
||||
status_code = 401
|
||||
error_code = "unauthorized"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "unauthorized",
|
||||
*,
|
||||
status_code: Optional[int] = None,
|
||||
retry_after: Optional[float] = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
if status_code is not None:
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.retry_after = retry_after
|
||||
|
||||
|
||||
class DeviceAuthManager:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
db_conn_factory: Callable[[], Any],
|
||||
jwt_service,
|
||||
dpop_validator: Optional[DPoPValidator],
|
||||
log: Callable[[str, str, Optional[str]], None],
|
||||
rate_limiter: Optional[SlidingWindowRateLimiter] = None,
|
||||
) -> None:
|
||||
self._db_conn_factory = db_conn_factory
|
||||
self._jwt_service = jwt_service
|
||||
self._dpop_validator = dpop_validator
|
||||
self._log = log
|
||||
self._rate_limiter = rate_limiter
|
||||
|
||||
def authenticate(self) -> DeviceAuthContext:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise DeviceAuthError("missing_authorization")
|
||||
token = auth_header[len("Bearer ") :].strip()
|
||||
if not token:
|
||||
raise DeviceAuthError("missing_authorization")
|
||||
|
||||
try:
|
||||
claims = self._jwt_service.decode(token)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise DeviceAuthError("token_expired")
|
||||
except Exception:
|
||||
raise DeviceAuthError("invalid_token")
|
||||
|
||||
raw_guid = str(claims.get("guid") or "").strip()
|
||||
guid = normalize_guid(raw_guid)
|
||||
fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip()
|
||||
token_version = int(claims.get("token_version") or 0)
|
||||
if not guid or not fingerprint or token_version <= 0:
|
||||
raise DeviceAuthError("invalid_claims")
|
||||
|
||||
if self._rate_limiter:
|
||||
decision = self._rate_limiter.check(f"fp:{fingerprint}", 60, 60.0)
|
||||
if not decision.allowed:
|
||||
raise DeviceAuthError(
|
||||
"rate_limited",
|
||||
status_code=429,
|
||||
retry_after=decision.retry_after,
|
||||
)
|
||||
|
||||
context_label = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
|
||||
|
||||
with closing(self._db_conn_factory()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE UPPER(guid) = ?
|
||||
""",
|
||||
(guid,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
row = None
|
||||
for candidate in rows or []:
|
||||
candidate_guid = normalize_guid(candidate[0])
|
||||
if candidate_guid == guid:
|
||||
row = candidate
|
||||
break
|
||||
if row is None and rows:
|
||||
row = rows[0]
|
||||
|
||||
if row is None:
|
||||
row = self._recover_device_record(conn, guid, fingerprint, token_version, context_label)
|
||||
|
||||
if row is None:
|
||||
raise DeviceAuthError("device_not_found", status_code=404)
|
||||
|
||||
stored_guid, stored_fingerprint, stored_version, status = row
|
||||
stored_guid = normalize_guid(stored_guid)
|
||||
if stored_guid != guid:
|
||||
raise DeviceAuthError("device_mismatch", status_code=401)
|
||||
if (stored_fingerprint or "").lower().strip() != fingerprint:
|
||||
raise DeviceAuthError("fingerprint_mismatch", status_code=403)
|
||||
if int(stored_version or 0) != token_version:
|
||||
raise DeviceAuthError("token_version_mismatch", status_code=403)
|
||||
|
||||
status_norm = (status or "active").strip().lower()
|
||||
if status_norm in {"revoked", "decommissioned"}:
|
||||
raise DeviceAuthError("device_revoked", status_code=403)
|
||||
|
||||
dpop_proof = request.headers.get("DPoP")
|
||||
jkt = None
|
||||
if dpop_proof and self._dpop_validator:
|
||||
try:
|
||||
jkt = self._dpop_validator.verify(request.method, request.url, dpop_proof, access_token=token)
|
||||
except DPoPReplayError:
|
||||
raise DeviceAuthError("dpop_replayed", status_code=400)
|
||||
except DPoPVerificationError:
|
||||
raise DeviceAuthError("dpop_invalid", status_code=400)
|
||||
|
||||
ctx = DeviceAuthContext(
|
||||
guid=guid,
|
||||
ssl_key_fingerprint=fingerprint,
|
||||
token_version=token_version,
|
||||
access_token=token,
|
||||
claims=claims,
|
||||
dpop_jkt=jkt,
|
||||
status=status_norm,
|
||||
service_mode=context_label,
|
||||
)
|
||||
return ctx
|
||||
|
||||
def _recover_device_record(
|
||||
self,
|
||||
conn: sqlite3.Connection,
|
||||
guid: str,
|
||||
fingerprint: str,
|
||||
token_version: int,
|
||||
context_label: Optional[str],
|
||||
) -> Optional[tuple]:
|
||||
"""Attempt to recreate a missing device row for an authenticated token."""
|
||||
|
||||
guid = normalize_guid(guid)
|
||||
fingerprint = (fingerprint or "").strip()
|
||||
if not guid or not fingerprint:
|
||||
return None
|
||||
|
||||
cur = conn.cursor()
|
||||
now_ts = int(time.time())
|
||||
try:
|
||||
now_iso = datetime.now(tz=timezone.utc).isoformat()
|
||||
except Exception:
|
||||
now_iso = datetime.utcnow().isoformat()
|
||||
|
||||
base_hostname = f"RECOVERED-{guid[:12].upper()}" if guid else "RECOVERED"
|
||||
|
||||
for attempt in range(6):
|
||||
hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
|
||||
try:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO devices (
|
||||
guid,
|
||||
hostname,
|
||||
created_at,
|
||||
last_seen,
|
||||
ssl_key_fingerprint,
|
||||
token_version,
|
||||
status,
|
||||
key_added_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
|
||||
""",
|
||||
(
|
||||
guid,
|
||||
hostname,
|
||||
now_ts,
|
||||
now_ts,
|
||||
fingerprint,
|
||||
max(token_version or 1, 1),
|
||||
now_iso,
|
||||
),
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
message = str(exc).lower()
|
||||
if "hostname" in message and "unique" in message:
|
||||
continue
|
||||
self._log(
|
||||
"server",
|
||||
f"device auth failed to recover guid={guid} due to integrity error: {exc}",
|
||||
context_label,
|
||||
)
|
||||
conn.rollback()
|
||||
return None
|
||||
except Exception as exc:
|
||||
self._log(
|
||||
"server",
|
||||
f"device auth unexpected error recovering guid={guid}: {exc}",
|
||||
context_label,
|
||||
)
|
||||
conn.rollback()
|
||||
return None
|
||||
else:
|
||||
conn.commit()
|
||||
break
|
||||
else:
|
||||
self._log(
|
||||
"server",
|
||||
f"device auth could not recover guid={guid}; hostname collisions persisted",
|
||||
context_label,
|
||||
)
|
||||
conn.rollback()
|
||||
return None
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(guid,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
self._log(
|
||||
"server",
|
||||
f"device auth recovery for guid={guid} committed but row still missing",
|
||||
context_label,
|
||||
)
|
||||
return row
|
||||
|
||||
|
||||
def require_device_auth(manager: DeviceAuthManager):
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
ctx = manager.authenticate()
|
||||
except DeviceAuthError as exc:
|
||||
response = jsonify({"error": exc.message})
|
||||
response.status_code = exc.status_code
|
||||
retry_after = getattr(exc, "retry_after", None)
|
||||
if retry_after:
|
||||
try:
|
||||
response.headers["Retry-After"] = str(max(1, int(retry_after)))
|
||||
except Exception:
|
||||
response.headers["Retry-After"] = "1"
|
||||
return response
|
||||
|
||||
g.device_auth = ctx
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AGENT_CONTEXT_HEADER",
|
||||
"DeviceAuthContext",
|
||||
"DeviceAuthError",
|
||||
"DeviceAuthManager",
|
||||
"require_device_auth",
|
||||
]
|
||||
118
Data/Engine/auth/dpop.py
Normal file
118
Data/Engine/auth/dpop.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# ======================================================
|
||||
# Data\Engine\auth\dpop.py
|
||||
# Description: Engine-side DPoP proof validation helpers with replay protection.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""DPoP proof verification helpers for the Engine runtime."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict, Optional
|
||||
|
||||
import jwt
|
||||
|
||||
_DPOP_MAX_SKEW = 300.0 # seconds
|
||||
|
||||
|
||||
class DPoPVerificationError(Exception):
|
||||
"""Raised when DPoP verification fails for structural reasons."""
|
||||
|
||||
|
||||
class DPoPReplayError(DPoPVerificationError):
|
||||
"""Raised when a DPoP proof replay is detected."""
|
||||
|
||||
|
||||
class DPoPValidator:
|
||||
"""Validate DPoP proofs and track observed JTIs to prevent replay attacks."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._observed_jti: Dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def verify(
|
||||
self,
|
||||
method: str,
|
||||
htu: str,
|
||||
proof: str,
|
||||
access_token: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Verify the presented DPoP proof and return the JWK thumbprint.
|
||||
"""
|
||||
|
||||
if not proof:
|
||||
raise DPoPVerificationError("DPoP proof missing")
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(proof)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP header") from exc
|
||||
|
||||
jwk = header.get("jwk")
|
||||
alg = header.get("alg")
|
||||
if not jwk or not isinstance(jwk, dict):
|
||||
raise DPoPVerificationError("missing jwk in DPoP header")
|
||||
if alg not in ("EdDSA", "ES256", "ES384", "ES512"):
|
||||
raise DPoPVerificationError(f"unsupported DPoP alg {alg}")
|
||||
|
||||
try:
|
||||
key = jwt.PyJWK(jwk)
|
||||
public_key = key.key
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid jwk in DPoP header") from exc
|
||||
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
proof,
|
||||
public_key,
|
||||
algorithms=[alg],
|
||||
options={"require": ["htm", "htu", "jti", "iat"]},
|
||||
)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP signature") from exc
|
||||
|
||||
htm = claims.get("htm")
|
||||
proof_htu = claims.get("htu")
|
||||
jti = claims.get("jti")
|
||||
iat = claims.get("iat")
|
||||
ath = claims.get("ath")
|
||||
|
||||
if not isinstance(htm, str) or htm.lower() != method.lower():
|
||||
raise DPoPVerificationError("DPoP htm mismatch")
|
||||
if not isinstance(proof_htu, str) or proof_htu != htu:
|
||||
raise DPoPVerificationError("DPoP htu mismatch")
|
||||
if not isinstance(jti, str):
|
||||
raise DPoPVerificationError("DPoP jti missing")
|
||||
if not isinstance(iat, (int, float)):
|
||||
raise DPoPVerificationError("DPoP iat missing")
|
||||
|
||||
now = time.time()
|
||||
if abs(now - float(iat)) > _DPOP_MAX_SKEW:
|
||||
raise DPoPVerificationError("DPoP proof outside allowed skew")
|
||||
|
||||
if ath and access_token:
|
||||
expected_ath = jwt.utils.base64url_encode(
|
||||
hashlib.sha256(access_token.encode("utf-8")).digest()
|
||||
).decode("ascii")
|
||||
if expected_ath != ath:
|
||||
raise DPoPVerificationError("DPoP ath mismatch")
|
||||
|
||||
with self._lock:
|
||||
expiry = self._observed_jti.get(jti)
|
||||
if expiry and expiry > now:
|
||||
raise DPoPReplayError("DPoP proof replay detected")
|
||||
self._observed_jti[jti] = now + _DPOP_MAX_SKEW
|
||||
stale = [key for key, exp in self._observed_jti.items() if exp <= now]
|
||||
for key in stale:
|
||||
self._observed_jti.pop(key, None)
|
||||
|
||||
thumbprint = jwt.PyJWK(jwk).thumbprint()
|
||||
return thumbprint.decode("ascii")
|
||||
|
||||
|
||||
__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"]
|
||||
39
Data/Engine/auth/guid_utils.py
Normal file
39
Data/Engine/auth/guid_utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# ======================================================
|
||||
# Data\Engine\auth\guid_utils.py
|
||||
# Description: GUID normalisation helpers used by Engine authentication flows.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""GUID normalisation helpers for Engine-managed authentication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def normalize_guid(value: Optional[str]) -> str:
|
||||
"""
|
||||
Canonicalise GUID strings so Engine services treat different casings uniformly.
|
||||
"""
|
||||
|
||||
candidate = (value or "").strip()
|
||||
if not candidate:
|
||||
return ""
|
||||
candidate = candidate.strip("{}")
|
||||
try:
|
||||
return str(uuid.UUID(candidate)).upper()
|
||||
except Exception:
|
||||
cleaned = "".join(ch for ch in candidate if ch in string.hexdigits or ch == "-")
|
||||
cleaned = cleaned.strip("-")
|
||||
if cleaned:
|
||||
try:
|
||||
return str(uuid.UUID(cleaned)).upper()
|
||||
except Exception:
|
||||
pass
|
||||
return candidate.upper()
|
||||
|
||||
|
||||
__all__ = ["normalize_guid"]
|
||||
206
Data/Engine/auth/jwt_service.py
Normal file
206
Data/Engine/auth/jwt_service.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# ======================================================
|
||||
# Data\Engine\auth\jwt_service.py
|
||||
# Description: Engine-native JWT access-token helpers with signing key storage under Engine/Data/Auth_Tokens.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""JWT access-token helpers backed by an Engine-managed Ed25519 key."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from ..security.certificates import project_root_path
|
||||
|
||||
_TOKEN_ENV_ROOT = "BOREALIS_ENGINE_AUTH_TOKEN_ROOT"
|
||||
_LEGACY_SERVER_ROOT_ENV = "BOREALIS_SERVER_ROOT"
|
||||
_KEY_FILENAME = "borealis-jwt-ed25519.key"
|
||||
|
||||
|
||||
def _env_path(name: str) -> Optional[Path]:
|
||||
value = os.environ.get(name)
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return Path(value).expanduser().resolve()
|
||||
except Exception:
|
||||
try:
|
||||
return Path(value).expanduser()
|
||||
except Exception:
|
||||
return Path(value)
|
||||
|
||||
|
||||
def _engine_runtime_root() -> Path:
|
||||
env = _env_path("BOREALIS_ENGINE_ROOT") or _env_path("BOREALIS_ENGINE_RUNTIME")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
root = project_root_path() / "Engine"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
def _token_root() -> Path:
|
||||
env = _env_path(_TOKEN_ENV_ROOT)
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
root = _engine_runtime_root() / "Data" / "Auth_Tokens"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
def _legacy_key_paths() -> Dict[str, Path]:
|
||||
project_root = project_root_path()
|
||||
server_root = _env_path(_LEGACY_SERVER_ROOT_ENV) or (project_root / "Server" / "Borealis")
|
||||
candidates = {
|
||||
"auth_keys": server_root / "auth_keys" / _KEY_FILENAME,
|
||||
"keys": server_root / "keys" / _KEY_FILENAME,
|
||||
}
|
||||
return candidates
|
||||
|
||||
|
||||
def _tighten_permissions(path: Path) -> None:
|
||||
try:
|
||||
if os.name != "nt":
|
||||
path.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
_KEY_DIR = _token_root()
|
||||
_KEY_FILE = _KEY_DIR / _KEY_FILENAME
|
||||
|
||||
|
||||
class JWTService:
|
||||
def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str):
|
||||
self._private_key = private_key
|
||||
self._public_key = private_key.public_key()
|
||||
self._key_id = key_id
|
||||
|
||||
@property
|
||||
def key_id(self) -> str:
|
||||
return self._key_id
|
||||
|
||||
def issue_access_token(
|
||||
self,
|
||||
guid: str,
|
||||
ssl_key_fingerprint: str,
|
||||
token_version: int,
|
||||
expires_in: int = 900,
|
||||
extra_claims: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
now = int(time.time())
|
||||
payload: Dict[str, Any] = {
|
||||
"sub": f"device:{guid}",
|
||||
"guid": guid,
|
||||
"ssl_key_fingerprint": ssl_key_fingerprint,
|
||||
"token_version": int(token_version),
|
||||
"iat": now,
|
||||
"nbf": now,
|
||||
"exp": now + int(expires_in),
|
||||
}
|
||||
if extra_claims:
|
||||
payload.update(extra_claims)
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
self._private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
),
|
||||
algorithm="EdDSA",
|
||||
headers={"kid": self._key_id},
|
||||
)
|
||||
return token
|
||||
|
||||
def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]:
|
||||
options = {"require": ["exp", "iat", "sub"]}
|
||||
public_pem = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
return jwt.decode(
|
||||
token,
|
||||
public_pem,
|
||||
algorithms=["EdDSA"],
|
||||
audience=audience,
|
||||
options=options,
|
||||
)
|
||||
|
||||
def public_jwk(self) -> Dict[str, Any]:
|
||||
public_bytes = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw,
|
||||
)
|
||||
jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii")
|
||||
return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x}
|
||||
|
||||
|
||||
def load_service() -> JWTService:
|
||||
private_key = _load_or_create_private_key()
|
||||
public_bytes = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
key_id = hashlib.sha256(public_bytes).hexdigest()[:16]
|
||||
return JWTService(private_key, key_id)
|
||||
|
||||
|
||||
def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey:
|
||||
_KEY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
_migrate_legacy_key_if_present()
|
||||
|
||||
if _KEY_FILE.exists():
|
||||
with _KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
with _KEY_FILE.open("wb") as fh:
|
||||
fh.write(pem)
|
||||
_tighten_permissions(_KEY_FILE)
|
||||
return private_key
|
||||
|
||||
|
||||
def _migrate_legacy_key_if_present() -> None:
|
||||
if _KEY_FILE.exists():
|
||||
return
|
||||
|
||||
legacy_paths = _legacy_key_paths()
|
||||
for legacy_file in legacy_paths.values():
|
||||
if not legacy_file.exists():
|
||||
continue
|
||||
try:
|
||||
legacy_bytes = legacy_file.read_bytes()
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
try:
|
||||
_KEY_FILE.write_bytes(legacy_bytes)
|
||||
_tighten_permissions(_KEY_FILE)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
try:
|
||||
legacy_file.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
|
||||
|
||||
__all__ = ["JWTService", "load_service"]
|
||||
51
Data/Engine/auth/rate_limit.py
Normal file
51
Data/Engine/auth/rate_limit.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# ======================================================
|
||||
# Data\Engine\auth\rate_limit.py
|
||||
# Description: Sliding-window rate limiter for Engine authentication endpoints.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""Engine-native in-memory rate limiter suitable for single-process development."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import Deque, Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitDecision:
|
||||
allowed: bool
|
||||
retry_after: float
|
||||
|
||||
|
||||
class SlidingWindowRateLimiter:
|
||||
"""Simple sliding-window limiter to guard authentication endpoints."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buckets: Dict[str, Deque[float]] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
bucket = self._buckets.get(key)
|
||||
if bucket is None:
|
||||
bucket = deque()
|
||||
self._buckets[key] = bucket
|
||||
|
||||
while bucket and now - bucket[0] > window_seconds:
|
||||
bucket.popleft()
|
||||
|
||||
if len(bucket) >= limit:
|
||||
retry_after = max(0.0, window_seconds - (now - bucket[0]))
|
||||
return RateLimitDecision(False, retry_after)
|
||||
|
||||
bucket.append(now)
|
||||
return RateLimitDecision(True, 0.0)
|
||||
|
||||
|
||||
__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"]
|
||||
Reference in New Issue
Block a user