From 833c4b7d881426070b9cf2a70deac272d2ff925e Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Wed, 29 Oct 2025 16:40:53 -0600 Subject: [PATCH] ENGINE: Migrated Enrollment Logic --- Data/Engine/CODE_MIGRATION_TRACKER.md | 4 +- Data/Engine/Unit_Tests/test_enrollment_api.py | 2 +- Data/Engine/auth/__init__.py | 27 + Data/Engine/auth/device_auth.py | 310 ++++++++ Data/Engine/auth/dpop.py | 118 +++ Data/Engine/auth/guid_utils.py | 39 + Data/Engine/auth/jwt_service.py | 206 +++++ Data/Engine/auth/rate_limit.py | 51 ++ Data/Engine/crypto/__init__.py | 30 + Data/Engine/crypto/keys.py | 88 +++ Data/Engine/enrollment/__init__.py | 12 + Data/Engine/enrollment/nonce_store.py | 42 + Data/Engine/services/API/__init__.py | 30 +- .../services/API/access_management/login.py | 7 +- .../services/API/assemblies/management.py | 7 +- Data/Engine/services/API/authentication.py | 7 +- Data/Engine/services/API/devices/approval.py | 9 +- .../Engine/services/API/devices/enrollment.py | 8 - .../Engine/services/API/devices/management.py | 13 +- .../services/API/enrollment/__init__.py | 12 + Data/Engine/services/API/enrollment/routes.py | 744 ++++++++++++++++++ Data/Engine/services/API/tokens/__init__.py | 12 + Data/Engine/services/API/tokens/routes.py | 147 ++++ 23 files changed, 1881 insertions(+), 44 deletions(-) create mode 100644 Data/Engine/auth/__init__.py create mode 100644 Data/Engine/auth/device_auth.py create mode 100644 Data/Engine/auth/dpop.py create mode 100644 Data/Engine/auth/guid_utils.py create mode 100644 Data/Engine/auth/jwt_service.py create mode 100644 Data/Engine/auth/rate_limit.py create mode 100644 Data/Engine/crypto/__init__.py create mode 100644 Data/Engine/crypto/keys.py create mode 100644 Data/Engine/enrollment/__init__.py create mode 100644 Data/Engine/enrollment/nonce_store.py delete mode 100644 Data/Engine/services/API/devices/enrollment.py create mode 100644 Data/Engine/services/API/enrollment/__init__.py create mode 100644 Data/Engine/services/API/enrollment/routes.py create mode 100644 Data/Engine/services/API/tokens/__init__.py create mode 100644 Data/Engine/services/API/tokens/routes.py diff --git a/Data/Engine/CODE_MIGRATION_TRACKER.md b/Data/Engine/CODE_MIGRATION_TRACKER.md index c21d8cba..8a2fceb1 100644 --- a/Data/Engine/CODE_MIGRATION_TRACKER.md +++ b/Data/Engine/CODE_MIGRATION_TRACKER.md @@ -38,6 +38,8 @@ Lastly, everytime that you complete a stage, you will create a pull request name - [ ] Add migration switch in the legacy server for WebUI delegation. - [x] Extend tests to cover critical WebUI routes. - [ ] Port device API endpoints into Engine services (device + admin coverage in progress). + - [x] Move authentication/token stack onto Engine services without legacy fallbacks. + - [x] Port enrollment request/poll flows to Engine services and drop legacy imports. - [ ] **Stage 7 — Plan WebSocket migration** - [ ] Extract Socket.IO handlers into Data/Engine/services/WebSocket. - [ ] Provide register_realtime hook for the Engine factory. @@ -46,4 +48,4 @@ Lastly, everytime that you complete a stage, you will create a pull request name ## Current Status - **Stage:** Stage 6 — Plan WebUI migration -- **Active Task:** Migrating device endpoints into the Engine API (legacy bridge removed). +- **Active Task:** Continue Stage 6 device/admin API migration (focus on remaining device and admin endpoints now that auth, token, and enrollment paths are Engine-native). diff --git a/Data/Engine/Unit_Tests/test_enrollment_api.py b/Data/Engine/Unit_Tests/test_enrollment_api.py index d0ebecc7..b59fa7c2 100644 --- a/Data/Engine/Unit_Tests/test_enrollment_api.py +++ b/Data/Engine/Unit_Tests/test_enrollment_api.py @@ -17,7 +17,7 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ed25519 from flask.testing import FlaskClient -from Modules.crypto import keys as crypto_keys +from Data.Engine.crypto import keys as crypto_keys from .conftest import EngineTestHarness diff --git a/Data/Engine/auth/__init__.py b/Data/Engine/auth/__init__.py new file mode 100644 index 00000000..60a53575 --- /dev/null +++ b/Data/Engine/auth/__init__.py @@ -0,0 +1,27 @@ +# ====================================================== +# Data\Engine\auth\__init__.py +# Description: Engine-native authentication utilities and helpers decoupled from the legacy server modules. +# +# API Endpoints (if applicable): None +# ====================================================== + +"""Authentication utility package for the Borealis Engine.""" + +from .jwt_service import JWTService, load_service +from .dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError +from .rate_limit import SlidingWindowRateLimiter, RateLimitDecision +from .device_auth import DeviceAuthManager, DeviceAuthError, DeviceAuthContext, require_device_auth + +__all__ = [ + "JWTService", + "load_service", + "DPoPValidator", + "DPoPVerificationError", + "DPoPReplayError", + "SlidingWindowRateLimiter", + "RateLimitDecision", + "DeviceAuthManager", + "DeviceAuthError", + "DeviceAuthContext", + "require_device_auth", +] diff --git a/Data/Engine/auth/device_auth.py b/Data/Engine/auth/device_auth.py new file mode 100644 index 00000000..00313406 --- /dev/null +++ b/Data/Engine/auth/device_auth.py @@ -0,0 +1,310 @@ +# ====================================================== +# Data\Engine\auth\device_auth.py +# Description: Engine-native device authentication manager and decorators. +# +# API Endpoints (if applicable): None +# ====================================================== + +"""Device authentication helpers for the Borealis Engine runtime.""" + +from __future__ import annotations + +import functools +import sqlite3 +import time +from contextlib import closing +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Callable, Dict, Optional + +import jwt +from flask import g, jsonify, request + +from .dpop import DPoPReplayError, DPoPValidator, DPoPVerificationError +from .guid_utils import normalize_guid +from .rate_limit import SlidingWindowRateLimiter + +AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context" + + +def _canonical_context(value: Optional[str]) -> Optional[str]: + if not value: + return None + cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-")) + if not cleaned: + return None + return cleaned.upper() + + +@dataclass +class DeviceAuthContext: + guid: str + ssl_key_fingerprint: str + token_version: int + access_token: str + claims: Dict[str, Any] + dpop_jkt: Optional[str] + status: str + service_mode: Optional[str] + + +class DeviceAuthError(Exception): + status_code = 401 + error_code = "unauthorized" + + def __init__( + self, + message: str = "unauthorized", + *, + status_code: Optional[int] = None, + retry_after: Optional[float] = None, + ): + super().__init__(message) + if status_code is not None: + self.status_code = status_code + self.message = message + self.retry_after = retry_after + + +class DeviceAuthManager: + def __init__( + self, + *, + db_conn_factory: Callable[[], Any], + jwt_service, + dpop_validator: Optional[DPoPValidator], + log: Callable[[str, str, Optional[str]], None], + rate_limiter: Optional[SlidingWindowRateLimiter] = None, + ) -> None: + self._db_conn_factory = db_conn_factory + self._jwt_service = jwt_service + self._dpop_validator = dpop_validator + self._log = log + self._rate_limiter = rate_limiter + + def authenticate(self) -> DeviceAuthContext: + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + raise DeviceAuthError("missing_authorization") + token = auth_header[len("Bearer ") :].strip() + if not token: + raise DeviceAuthError("missing_authorization") + + try: + claims = self._jwt_service.decode(token) + except jwt.ExpiredSignatureError: + raise DeviceAuthError("token_expired") + except Exception: + raise DeviceAuthError("invalid_token") + + raw_guid = str(claims.get("guid") or "").strip() + guid = normalize_guid(raw_guid) + fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip() + token_version = int(claims.get("token_version") or 0) + if not guid or not fingerprint or token_version <= 0: + raise DeviceAuthError("invalid_claims") + + if self._rate_limiter: + decision = self._rate_limiter.check(f"fp:{fingerprint}", 60, 60.0) + if not decision.allowed: + raise DeviceAuthError( + "rate_limited", + status_code=429, + retry_after=decision.retry_after, + ) + + context_label = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER)) + + with closing(self._db_conn_factory()) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT guid, ssl_key_fingerprint, token_version, status + FROM devices + WHERE UPPER(guid) = ? + """, + (guid,), + ) + rows = cur.fetchall() + row = None + for candidate in rows or []: + candidate_guid = normalize_guid(candidate[0]) + if candidate_guid == guid: + row = candidate + break + if row is None and rows: + row = rows[0] + + if row is None: + row = self._recover_device_record(conn, guid, fingerprint, token_version, context_label) + + if row is None: + raise DeviceAuthError("device_not_found", status_code=404) + + stored_guid, stored_fingerprint, stored_version, status = row + stored_guid = normalize_guid(stored_guid) + if stored_guid != guid: + raise DeviceAuthError("device_mismatch", status_code=401) + if (stored_fingerprint or "").lower().strip() != fingerprint: + raise DeviceAuthError("fingerprint_mismatch", status_code=403) + if int(stored_version or 0) != token_version: + raise DeviceAuthError("token_version_mismatch", status_code=403) + + status_norm = (status or "active").strip().lower() + if status_norm in {"revoked", "decommissioned"}: + raise DeviceAuthError("device_revoked", status_code=403) + + dpop_proof = request.headers.get("DPoP") + jkt = None + if dpop_proof and self._dpop_validator: + try: + jkt = self._dpop_validator.verify(request.method, request.url, dpop_proof, access_token=token) + except DPoPReplayError: + raise DeviceAuthError("dpop_replayed", status_code=400) + except DPoPVerificationError: + raise DeviceAuthError("dpop_invalid", status_code=400) + + ctx = DeviceAuthContext( + guid=guid, + ssl_key_fingerprint=fingerprint, + token_version=token_version, + access_token=token, + claims=claims, + dpop_jkt=jkt, + status=status_norm, + service_mode=context_label, + ) + return ctx + + def _recover_device_record( + self, + conn: sqlite3.Connection, + guid: str, + fingerprint: str, + token_version: int, + context_label: Optional[str], + ) -> Optional[tuple]: + """Attempt to recreate a missing device row for an authenticated token.""" + + guid = normalize_guid(guid) + fingerprint = (fingerprint or "").strip() + if not guid or not fingerprint: + return None + + cur = conn.cursor() + now_ts = int(time.time()) + try: + now_iso = datetime.now(tz=timezone.utc).isoformat() + except Exception: + now_iso = datetime.utcnow().isoformat() + + base_hostname = f"RECOVERED-{guid[:12].upper()}" if guid else "RECOVERED" + + for attempt in range(6): + hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}" + try: + cur.execute( + """ + INSERT INTO devices ( + guid, + hostname, + created_at, + last_seen, + ssl_key_fingerprint, + token_version, + status, + key_added_at + ) + VALUES (?, ?, ?, ?, ?, ?, 'active', ?) + """, + ( + guid, + hostname, + now_ts, + now_ts, + fingerprint, + max(token_version or 1, 1), + now_iso, + ), + ) + except sqlite3.IntegrityError as exc: + message = str(exc).lower() + if "hostname" in message and "unique" in message: + continue + self._log( + "server", + f"device auth failed to recover guid={guid} due to integrity error: {exc}", + context_label, + ) + conn.rollback() + return None + except Exception as exc: + self._log( + "server", + f"device auth unexpected error recovering guid={guid}: {exc}", + context_label, + ) + conn.rollback() + return None + else: + conn.commit() + break + else: + self._log( + "server", + f"device auth could not recover guid={guid}; hostname collisions persisted", + context_label, + ) + conn.rollback() + return None + + cur.execute( + """ + SELECT guid, ssl_key_fingerprint, token_version, status + FROM devices + WHERE guid = ? + """, + (guid,), + ) + row = cur.fetchone() + if not row: + self._log( + "server", + f"device auth recovery for guid={guid} committed but row still missing", + context_label, + ) + return row + + +def require_device_auth(manager: DeviceAuthManager): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + ctx = manager.authenticate() + except DeviceAuthError as exc: + response = jsonify({"error": exc.message}) + response.status_code = exc.status_code + retry_after = getattr(exc, "retry_after", None) + if retry_after: + try: + response.headers["Retry-After"] = str(max(1, int(retry_after))) + except Exception: + response.headers["Retry-After"] = "1" + return response + + g.device_auth = ctx + return func(*args, **kwargs) + + return wrapper + + return decorator + + +__all__ = [ + "AGENT_CONTEXT_HEADER", + "DeviceAuthContext", + "DeviceAuthError", + "DeviceAuthManager", + "require_device_auth", +] diff --git a/Data/Engine/auth/dpop.py b/Data/Engine/auth/dpop.py new file mode 100644 index 00000000..26d4b646 --- /dev/null +++ b/Data/Engine/auth/dpop.py @@ -0,0 +1,118 @@ +# ====================================================== +# Data\Engine\auth\dpop.py +# Description: Engine-side DPoP proof validation helpers with replay protection. +# +# API Endpoints (if applicable): None +# ====================================================== + +"""DPoP proof verification helpers for the Engine runtime.""" + +from __future__ import annotations + +import hashlib +import time +from threading import Lock +from typing import Dict, Optional + +import jwt + +_DPOP_MAX_SKEW = 300.0 # seconds + + +class DPoPVerificationError(Exception): + """Raised when DPoP verification fails for structural reasons.""" + + +class DPoPReplayError(DPoPVerificationError): + """Raised when a DPoP proof replay is detected.""" + + +class DPoPValidator: + """Validate DPoP proofs and track observed JTIs to prevent replay attacks.""" + + def __init__(self) -> None: + self._observed_jti: Dict[str, float] = {} + self._lock = Lock() + + def verify( + self, + method: str, + htu: str, + proof: str, + access_token: Optional[str] = None, + ) -> str: + """ + Verify the presented DPoP proof and return the JWK thumbprint. + """ + + if not proof: + raise DPoPVerificationError("DPoP proof missing") + + try: + header = jwt.get_unverified_header(proof) + except Exception as exc: + raise DPoPVerificationError("invalid DPoP header") from exc + + jwk = header.get("jwk") + alg = header.get("alg") + if not jwk or not isinstance(jwk, dict): + raise DPoPVerificationError("missing jwk in DPoP header") + if alg not in ("EdDSA", "ES256", "ES384", "ES512"): + raise DPoPVerificationError(f"unsupported DPoP alg {alg}") + + try: + key = jwt.PyJWK(jwk) + public_key = key.key + except Exception as exc: + raise DPoPVerificationError("invalid jwk in DPoP header") from exc + + try: + claims = jwt.decode( + proof, + public_key, + algorithms=[alg], + options={"require": ["htm", "htu", "jti", "iat"]}, + ) + except Exception as exc: + raise DPoPVerificationError("invalid DPoP signature") from exc + + htm = claims.get("htm") + proof_htu = claims.get("htu") + jti = claims.get("jti") + iat = claims.get("iat") + ath = claims.get("ath") + + if not isinstance(htm, str) or htm.lower() != method.lower(): + raise DPoPVerificationError("DPoP htm mismatch") + if not isinstance(proof_htu, str) or proof_htu != htu: + raise DPoPVerificationError("DPoP htu mismatch") + if not isinstance(jti, str): + raise DPoPVerificationError("DPoP jti missing") + if not isinstance(iat, (int, float)): + raise DPoPVerificationError("DPoP iat missing") + + now = time.time() + if abs(now - float(iat)) > _DPOP_MAX_SKEW: + raise DPoPVerificationError("DPoP proof outside allowed skew") + + if ath and access_token: + expected_ath = jwt.utils.base64url_encode( + hashlib.sha256(access_token.encode("utf-8")).digest() + ).decode("ascii") + if expected_ath != ath: + raise DPoPVerificationError("DPoP ath mismatch") + + with self._lock: + expiry = self._observed_jti.get(jti) + if expiry and expiry > now: + raise DPoPReplayError("DPoP proof replay detected") + self._observed_jti[jti] = now + _DPOP_MAX_SKEW + stale = [key for key, exp in self._observed_jti.items() if exp <= now] + for key in stale: + self._observed_jti.pop(key, None) + + thumbprint = jwt.PyJWK(jwk).thumbprint() + return thumbprint.decode("ascii") + + +__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"] diff --git a/Data/Engine/auth/guid_utils.py b/Data/Engine/auth/guid_utils.py new file mode 100644 index 00000000..65c37437 --- /dev/null +++ b/Data/Engine/auth/guid_utils.py @@ -0,0 +1,39 @@ +# ====================================================== +# Data\Engine\auth\guid_utils.py +# Description: GUID normalisation helpers used by Engine authentication flows. +# +# API Endpoints (if applicable): None +# ====================================================== + +"""GUID normalisation helpers for Engine-managed authentication.""" + +from __future__ import annotations + +import string +import uuid +from typing import Optional + + +def normalize_guid(value: Optional[str]) -> str: + """ + Canonicalise GUID strings so Engine services treat different casings uniformly. + """ + + candidate = (value or "").strip() + if not candidate: + return "" + candidate = candidate.strip("{}") + try: + return str(uuid.UUID(candidate)).upper() + except Exception: + cleaned = "".join(ch for ch in candidate if ch in string.hexdigits or ch == "-") + cleaned = cleaned.strip("-") + if cleaned: + try: + return str(uuid.UUID(cleaned)).upper() + except Exception: + pass + return candidate.upper() + + +__all__ = ["normalize_guid"] diff --git a/Data/Engine/auth/jwt_service.py b/Data/Engine/auth/jwt_service.py new file mode 100644 index 00000000..9098e5d0 --- /dev/null +++ b/Data/Engine/auth/jwt_service.py @@ -0,0 +1,206 @@ +# ====================================================== +# Data\Engine\auth\jwt_service.py +# Description: Engine-native JWT access-token helpers with signing key storage under Engine/Data/Auth_Tokens. +# +# API Endpoints (if applicable): None +# ====================================================== + +"""JWT access-token helpers backed by an Engine-managed Ed25519 key.""" + +from __future__ import annotations + +import hashlib +import os +import time +from pathlib import Path +from typing import Any, Dict, Optional + +import jwt +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 + +from ..security.certificates import project_root_path + +_TOKEN_ENV_ROOT = "BOREALIS_ENGINE_AUTH_TOKEN_ROOT" +_LEGACY_SERVER_ROOT_ENV = "BOREALIS_SERVER_ROOT" +_KEY_FILENAME = "borealis-jwt-ed25519.key" + + +def _env_path(name: str) -> Optional[Path]: + value = os.environ.get(name) + if not value: + return None + try: + return Path(value).expanduser().resolve() + except Exception: + try: + return Path(value).expanduser() + except Exception: + return Path(value) + + +def _engine_runtime_root() -> Path: + env = _env_path("BOREALIS_ENGINE_ROOT") or _env_path("BOREALIS_ENGINE_RUNTIME") + if env: + env.mkdir(parents=True, exist_ok=True) + return env + root = project_root_path() / "Engine" + root.mkdir(parents=True, exist_ok=True) + return root + + +def _token_root() -> Path: + env = _env_path(_TOKEN_ENV_ROOT) + if env: + env.mkdir(parents=True, exist_ok=True) + return env + root = _engine_runtime_root() / "Data" / "Auth_Tokens" + root.mkdir(parents=True, exist_ok=True) + return root + + +def _legacy_key_paths() -> Dict[str, Path]: + project_root = project_root_path() + server_root = _env_path(_LEGACY_SERVER_ROOT_ENV) or (project_root / "Server" / "Borealis") + candidates = { + "auth_keys": server_root / "auth_keys" / _KEY_FILENAME, + "keys": server_root / "keys" / _KEY_FILENAME, + } + return candidates + + +def _tighten_permissions(path: Path) -> None: + try: + if os.name != "nt": + path.chmod(0o600) + except Exception: + pass + + +_KEY_DIR = _token_root() +_KEY_FILE = _KEY_DIR / _KEY_FILENAME + + +class JWTService: + def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str): + self._private_key = private_key + self._public_key = private_key.public_key() + self._key_id = key_id + + @property + def key_id(self) -> str: + return self._key_id + + def issue_access_token( + self, + guid: str, + ssl_key_fingerprint: str, + token_version: int, + expires_in: int = 900, + extra_claims: Optional[Dict[str, Any]] = None, + ) -> str: + now = int(time.time()) + payload: Dict[str, Any] = { + "sub": f"device:{guid}", + "guid": guid, + "ssl_key_fingerprint": ssl_key_fingerprint, + "token_version": int(token_version), + "iat": now, + "nbf": now, + "exp": now + int(expires_in), + } + if extra_claims: + payload.update(extra_claims) + + token = jwt.encode( + payload, + self._private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ), + algorithm="EdDSA", + headers={"kid": self._key_id}, + ) + return token + + def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]: + options = {"require": ["exp", "iat", "sub"]} + public_pem = self._public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return jwt.decode( + token, + public_pem, + algorithms=["EdDSA"], + audience=audience, + options=options, + ) + + def public_jwk(self) -> Dict[str, Any]: + public_bytes = self._public_key.public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, + ) + jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii") + return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x} + + +def load_service() -> JWTService: + private_key = _load_or_create_private_key() + public_bytes = private_key.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + key_id = hashlib.sha256(public_bytes).hexdigest()[:16] + return JWTService(private_key, key_id) + + +def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey: + _KEY_DIR.mkdir(parents=True, exist_ok=True) + _migrate_legacy_key_if_present() + + if _KEY_FILE.exists(): + with _KEY_FILE.open("rb") as fh: + return serialization.load_pem_private_key(fh.read(), password=None) + + private_key = ed25519.Ed25519PrivateKey.generate() + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + with _KEY_FILE.open("wb") as fh: + fh.write(pem) + _tighten_permissions(_KEY_FILE) + return private_key + + +def _migrate_legacy_key_if_present() -> None: + if _KEY_FILE.exists(): + return + + legacy_paths = _legacy_key_paths() + for legacy_file in legacy_paths.values(): + if not legacy_file.exists(): + continue + try: + legacy_bytes = legacy_file.read_bytes() + except Exception: + continue + + try: + _KEY_FILE.write_bytes(legacy_bytes) + _tighten_permissions(_KEY_FILE) + except Exception: + continue + + try: + legacy_file.unlink() + except Exception: + pass + break + + +__all__ = ["JWTService", "load_service"] diff --git a/Data/Engine/auth/rate_limit.py b/Data/Engine/auth/rate_limit.py new file mode 100644 index 00000000..a872c1c9 --- /dev/null +++ b/Data/Engine/auth/rate_limit.py @@ -0,0 +1,51 @@ +# ====================================================== +# Data\Engine\auth\rate_limit.py +# Description: Sliding-window rate limiter for Engine authentication endpoints. +# +# API Endpoints (if applicable): None +# ====================================================== + +"""Engine-native in-memory rate limiter suitable for single-process development.""" + +from __future__ import annotations + +import time +from collections import deque +from dataclasses import dataclass +from threading import Lock +from typing import Deque, Dict + + +@dataclass +class RateLimitDecision: + allowed: bool + retry_after: float + + +class SlidingWindowRateLimiter: + """Simple sliding-window limiter to guard authentication endpoints.""" + + def __init__(self) -> None: + self._buckets: Dict[str, Deque[float]] = {} + self._lock = Lock() + + def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision: + now = time.monotonic() + with self._lock: + bucket = self._buckets.get(key) + if bucket is None: + bucket = deque() + self._buckets[key] = bucket + + while bucket and now - bucket[0] > window_seconds: + bucket.popleft() + + if len(bucket) >= limit: + retry_after = max(0.0, window_seconds - (now - bucket[0])) + return RateLimitDecision(False, retry_after) + + bucket.append(now) + return RateLimitDecision(True, 0.0) + + +__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"] diff --git a/Data/Engine/crypto/__init__.py b/Data/Engine/crypto/__init__.py new file mode 100644 index 00000000..498499a0 --- /dev/null +++ b/Data/Engine/crypto/__init__.py @@ -0,0 +1,30 @@ +# ====================================================== +# Data\Engine\crypto\__init__.py +# Description: Engine cryptographic helpers and key utilities. +# +# API Endpoints (if applicable): None +# ====================================================== + +"""Cryptographic helper utilities for the Borealis Engine runtime.""" + +from .keys import ( + generate_ed25519_keypair, + normalize_base64, + spki_der_from_base64, + base64_from_spki_der, + fingerprint_from_spki_der, + fingerprint_from_base64_spki, + private_key_to_pem, + public_key_to_pem, +) + +__all__ = [ + "generate_ed25519_keypair", + "normalize_base64", + "spki_der_from_base64", + "base64_from_spki_der", + "fingerprint_from_spki_der", + "fingerprint_from_base64_spki", + "private_key_to_pem", + "public_key_to_pem", +] diff --git a/Data/Engine/crypto/keys.py b/Data/Engine/crypto/keys.py new file mode 100644 index 00000000..69602069 --- /dev/null +++ b/Data/Engine/crypto/keys.py @@ -0,0 +1,88 @@ +# ====================================================== +# Data\Engine\crypto\keys.py +# Description: Engine-native Ed25519 key helpers and fingerprint utilities. +# +# API Endpoints (if applicable): None +# ====================================================== + +"""Utility helpers for working with Ed25519 keys and fingerprints.""" + +from __future__ import annotations + +import base64 +import hashlib +import re +from typing import Tuple + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 +from cryptography.hazmat.primitives.serialization import load_der_public_key + + +def generate_ed25519_keypair() -> Tuple[ed25519.Ed25519PrivateKey, bytes]: + """ + Generate a new Ed25519 keypair. + + Returns the private key object and the public key encoded as SubjectPublicKeyInfo DER bytes. + """ + + private_key = ed25519.Ed25519PrivateKey.generate() + public_key = private_key.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return private_key, public_key + + +def normalize_base64(data: str) -> str: + """ + Collapse whitespace and normalise URL-safe encodings so we can reliably decode. + """ + + cleaned = re.sub(r"\s+", "", data or "") + return cleaned.replace("-", "+").replace("_", "/") + + +def spki_der_from_base64(spki_b64: str) -> bytes: + return base64.b64decode(normalize_base64(spki_b64), validate=True) + + +def base64_from_spki_der(spki_der: bytes) -> str: + return base64.b64encode(spki_der).decode("ascii") + + +def fingerprint_from_spki_der(spki_der: bytes) -> str: + digest = hashlib.sha256(spki_der).hexdigest() + return digest.lower() + + +def fingerprint_from_base64_spki(spki_b64: str) -> str: + return fingerprint_from_spki_der(spki_der_from_base64(spki_b64)) + + +def private_key_to_pem(private_key: ed25519.Ed25519PrivateKey) -> bytes: + return private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + +def public_key_to_pem(public_spki_der: bytes) -> bytes: + public_key = load_der_public_key(public_spki_der) + return public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + +__all__ = [ + "generate_ed25519_keypair", + "normalize_base64", + "spki_der_from_base64", + "base64_from_spki_der", + "fingerprint_from_spki_der", + "fingerprint_from_base64_spki", + "private_key_to_pem", + "public_key_to_pem", +] diff --git a/Data/Engine/enrollment/__init__.py b/Data/Engine/enrollment/__init__.py new file mode 100644 index 00000000..0bdc3679 --- /dev/null +++ b/Data/Engine/enrollment/__init__.py @@ -0,0 +1,12 @@ +# ====================================================== +# Data\Engine\enrollment\__init__.py +# Description: Enrollment utilities for Engine-managed device onboarding. +# +# API Endpoints (if applicable): None +# ====================================================== + +"""Enrollment helper utilities for the Borealis Engine runtime.""" + +from .nonce_store import NonceCache + +__all__ = ["NonceCache"] diff --git a/Data/Engine/enrollment/nonce_store.py b/Data/Engine/enrollment/nonce_store.py new file mode 100644 index 00000000..5a132d5f --- /dev/null +++ b/Data/Engine/enrollment/nonce_store.py @@ -0,0 +1,42 @@ +# ====================================================== +# Data\Engine\enrollment\nonce_store.py +# Description: Short-lived nonce cache preventing replay during Engine enrollment flows. +# +# API Endpoints (if applicable): None +# ====================================================== + +"""Short-lived nonce cache to defend against enrollment replay attacks.""" + +from __future__ import annotations + +import time +from threading import Lock +from typing import Dict + + +class NonceCache: + def __init__(self, ttl_seconds: float = 300.0) -> None: + self._ttl = ttl_seconds + self._entries: Dict[str, float] = {} + self._lock = Lock() + + def consume(self, key: str) -> bool: + """ + Attempt to consume the nonce identified by `key`. + + Returns True on first use within TTL, False if already consumed. + """ + + now = time.monotonic() + with self._lock: + expire_at = self._entries.get(key) + if expire_at and expire_at > now: + return False + self._entries[key] = now + self._ttl + stale = [nonce for nonce, expiry in self._entries.items() if expiry <= now] + for nonce in stale: + self._entries.pop(nonce, None) + return True + + +__all__ = ["NonceCache"] diff --git a/Data/Engine/services/API/__init__.py b/Data/Engine/services/API/__init__.py index 6b353dd0..3692d991 100644 --- a/Data/Engine/services/API/__init__.py +++ b/Data/Engine/services/API/__init__.py @@ -1,6 +1,6 @@ # ====================================================== # Data\Engine\services\API\__init__.py -# Description: Registers Engine API groups and bridges to legacy modules while exposing core utility routes. +# Description: Registers Engine API groups, wiring Engine-native authentication while delegating remaining legacy modules. # # API Endpoints (if applicable): # - GET /health (No Authentication) - Returns an OK status for liveness probing. @@ -20,15 +20,15 @@ from typing import Any, Callable, Iterable, Mapping, Optional, Sequence from flask import Blueprint, Flask, jsonify -from Modules.auth import jwt_service as jwt_service_module -from Modules.auth.device_auth import DeviceAuthManager -from Modules.auth.dpop import DPoPValidator -from Modules.auth.rate_limit import SlidingWindowRateLimiter +from ...auth import jwt_service as jwt_service_module +from ...auth.device_auth import DeviceAuthManager +from ...auth.dpop import DPoPValidator +from ...auth.rate_limit import SlidingWindowRateLimiter from ...database import initialise_engine_database from ...security import signing -from Modules.enrollment import routes as enrollment_routes -from Modules.enrollment.nonce_store import NonceCache -from Modules.tokens import routes as token_routes +from ...enrollment import NonceCache +from .enrollment import routes as enrollment_routes +from .tokens import routes as token_routes from ...server import EngineContext from .access_management.login import register_auth @@ -137,7 +137,7 @@ def _make_db_conn_factory(database_path: str) -> Callable[[], sqlite3.Connection @dataclass -class LegacyServiceAdapters: +class EngineServiceAdapters: context: EngineContext db_conn_factory: Callable[[], sqlite3.Connection] = field(init=False) jwt_service: Any = field(init=False) @@ -180,7 +180,7 @@ class LegacyServiceAdapters: ) -def _register_tokens(app: Flask, adapters: LegacyServiceAdapters) -> None: +def _register_tokens(app: Flask, adapters: EngineServiceAdapters) -> None: token_routes.register( app, db_conn_factory=adapters.db_conn_factory, @@ -189,7 +189,7 @@ def _register_tokens(app: Flask, adapters: LegacyServiceAdapters) -> None: ) -def _register_enrollment(app: Flask, adapters: LegacyServiceAdapters) -> None: +def _register_enrollment(app: Flask, adapters: EngineServiceAdapters) -> None: tls_bundle = adapters.context.tls_bundle_path or "" enrollment_routes.register( app, @@ -204,12 +204,12 @@ def _register_enrollment(app: Flask, adapters: LegacyServiceAdapters) -> None: ) -def _register_devices(app: Flask, adapters: LegacyServiceAdapters) -> None: +def _register_devices(app: Flask, adapters: EngineServiceAdapters) -> None: register_management(app, adapters) register_admin_endpoints(app, adapters) -_GROUP_REGISTRARS: Mapping[str, Callable[[Flask, LegacyServiceAdapters], None]] = { +_GROUP_REGISTRARS: Mapping[str, Callable[[Flask, EngineServiceAdapters], None]] = { "auth": register_auth, "tokens": _register_tokens, "enrollment": _register_enrollment, @@ -236,7 +236,7 @@ def register_api(app: Flask, context: EngineContext) -> None: enabled_groups: Iterable[str] = context.api_groups or DEFAULT_API_GROUPS normalized = [group.strip().lower() for group in enabled_groups if group] - adapters: Optional[LegacyServiceAdapters] = None + adapters: Optional[EngineServiceAdapters] = None for group in normalized: if group == "core": @@ -244,7 +244,7 @@ def register_api(app: Flask, context: EngineContext) -> None: continue if adapters is None: - adapters = LegacyServiceAdapters(context) + adapters = EngineServiceAdapters(context) registrar = _GROUP_REGISTRARS.get(group) if registrar is None: context.logger.info("Engine API group '%s' is not implemented; skipping.", group) diff --git a/Data/Engine/services/API/access_management/login.py b/Data/Engine/services/API/access_management/login.py index 65848f15..af571555 100644 --- a/Data/Engine/services/API/access_management/login.py +++ b/Data/Engine/services/API/access_management/login.py @@ -35,7 +35,7 @@ except Exception: # pragma: no cover - optional dependency qrcode = None # type: ignore if TYPE_CHECKING: # pragma: no cover - typing helper - from Data.Engine.services.API import LegacyServiceAdapters + from Data.Engine.services.API import EngineServiceAdapters def _now_ts() -> int: @@ -103,7 +103,7 @@ def _user_row_to_dict(row: Sequence[Any]) -> Mapping[str, Any]: class _AuthService: - def __init__(self, app: Flask, adapters: "LegacyServiceAdapters") -> None: + def __init__(self, app: Flask, adapters: "EngineServiceAdapters") -> None: self.app = app self.adapters = adapters self.context = adapters.context @@ -393,7 +393,7 @@ class _AuthService: ) -def register_auth(app: Flask, adapters: "LegacyServiceAdapters") -> None: +def register_auth(app: Flask, adapters: "EngineServiceAdapters") -> None: """Register authentication endpoints for the Engine.""" service = _AuthService(app, adapters) @@ -416,3 +416,4 @@ def register_auth(app: Flask, adapters: "LegacyServiceAdapters") -> None: return service.me() app.register_blueprint(blueprint) + diff --git a/Data/Engine/services/API/assemblies/management.py b/Data/Engine/services/API/assemblies/management.py index 8130b60e..4349151d 100644 --- a/Data/Engine/services/API/assemblies/management.py +++ b/Data/Engine/services/API/assemblies/management.py @@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Opti from flask import Blueprint, jsonify, request if TYPE_CHECKING: # pragma: no cover - typing aide - from .. import LegacyServiceAdapters + from .. import EngineServiceAdapters _ISLAND_DIR_MAP: Mapping[str, str] = { @@ -49,7 +49,7 @@ _BASE64_CLEANER = re.compile(r"\s+") class AssemblyManagementService: """Implements assembly CRUD helpers for Engine routes.""" - def __init__(self, adapters: "LegacyServiceAdapters") -> None: + def __init__(self, adapters: "EngineServiceAdapters") -> None: self.adapters = adapters self.logger = adapters.context.logger or logging.getLogger(__name__) self.service_log = adapters.service_log @@ -679,7 +679,7 @@ class AssemblyManagementService: return obj -def register_assemblies(app, adapters: "LegacyServiceAdapters") -> None: +def register_assemblies(app, adapters: "EngineServiceAdapters") -> None: """Register assembly CRUD endpoints on the Flask app.""" service = AssemblyManagementService(adapters) @@ -726,3 +726,4 @@ def register_assemblies(app, adapters: "LegacyServiceAdapters") -> None: return jsonify(response), status app.register_blueprint(blueprint) + diff --git a/Data/Engine/services/API/authentication.py b/Data/Engine/services/API/authentication.py index 5b173c3c..d57eb87b 100644 --- a/Data/Engine/services/API/authentication.py +++ b/Data/Engine/services/API/authentication.py @@ -35,7 +35,7 @@ except Exception: # pragma: no cover - optional dependency qrcode = None # type: ignore if TYPE_CHECKING: # pragma: no cover - typing helper - from . import LegacyServiceAdapters + from . import EngineServiceAdapters def _now_ts() -> int: @@ -103,7 +103,7 @@ def _user_row_to_dict(row: Sequence[Any]) -> Mapping[str, Any]: class _AuthService: - def __init__(self, app: Flask, adapters: "LegacyServiceAdapters") -> None: + def __init__(self, app: Flask, adapters: "EngineServiceAdapters") -> None: self.app = app self.adapters = adapters self.context = adapters.context @@ -398,7 +398,7 @@ class _AuthService: ) -def register_auth(app: Flask, adapters: "LegacyServiceAdapters") -> None: +def register_auth(app: Flask, adapters: "EngineServiceAdapters") -> None: """Register authentication endpoints for the Engine.""" service = _AuthService(app, adapters) @@ -422,3 +422,4 @@ def register_auth(app: Flask, adapters: "LegacyServiceAdapters") -> None: app.register_blueprint(blueprint) adapters.context.logger.info("Engine registered API group 'auth'.") + diff --git a/Data/Engine/services/API/devices/approval.py b/Data/Engine/services/API/devices/approval.py index fe14801a..8a61899a 100644 --- a/Data/Engine/services/API/devices/approval.py +++ b/Data/Engine/services/API/devices/approval.py @@ -24,10 +24,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from flask import Blueprint, jsonify, request, session from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer -from Modules.guid_utils import normalize_guid +from ....auth.guid_utils import normalize_guid if TYPE_CHECKING: # pragma: no cover - typing helper - from .. import LegacyServiceAdapters + from .. import EngineServiceAdapters VALID_TTL_HOURS = {1, 3, 6, 12, 24} @@ -49,7 +49,7 @@ def _generate_install_code() -> str: class AdminDeviceService: """Utility wrapper for admin device APIs.""" - def __init__(self, app, adapters: "LegacyServiceAdapters") -> None: + def __init__(self, app, adapters: "EngineServiceAdapters") -> None: self.app = app self.adapters = adapters self.db_conn_factory = adapters.db_conn_factory @@ -477,7 +477,7 @@ class AdminDeviceService: return self._set_approval_status(approval_id, "denied") -def register_admin_endpoints(app, adapters: "LegacyServiceAdapters") -> None: +def register_admin_endpoints(app, adapters: "EngineServiceAdapters") -> None: """Register admin enrollment + approval endpoints.""" service = AdminDeviceService(app, adapters) @@ -532,3 +532,4 @@ def register_admin_endpoints(app, adapters: "LegacyServiceAdapters") -> None: return jsonify(payload), status app.register_blueprint(blueprint) + diff --git a/Data/Engine/services/API/devices/enrollment.py b/Data/Engine/services/API/devices/enrollment.py deleted file mode 100644 index 687a10b6..00000000 --- a/Data/Engine/services/API/devices/enrollment.py +++ /dev/null @@ -1,8 +0,0 @@ -# ====================================================== -# Data\Engine\services\API\devices\enrollment.py -# Description: Placeholder for device enrollment API bridge (not yet implemented). -# -# API Endpoints (if applicable): None -# ====================================================== - -"Placeholder for API module devices/enrollment.py." diff --git a/Data/Engine/services/API/devices/management.py b/Data/Engine/services/API/devices/management.py index ec017a29..5a8d0bc5 100644 --- a/Data/Engine/services/API/devices/management.py +++ b/Data/Engine/services/API/devices/management.py @@ -41,8 +41,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from flask import Blueprint, jsonify, request, session, g from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer -from Modules.auth.device_auth import require_device_auth -from Modules.guid_utils import normalize_guid +from ....auth.device_auth import require_device_auth +from ....auth.guid_utils import normalize_guid try: import requests # type: ignore @@ -57,7 +57,7 @@ except ImportError: # pragma: no cover - fallback for minimal test environments requests = _RequestsStub() # type: ignore if TYPE_CHECKING: # pragma: no cover - typing aide - from .. import LegacyServiceAdapters + from .. import EngineServiceAdapters def _safe_json(raw: Optional[str], default: Any) -> Any: @@ -340,7 +340,7 @@ def _device_upsert( class RepositoryHashCache: """Lightweight GitHub head cache with on-disk persistence.""" - def __init__(self, adapters: "LegacyServiceAdapters") -> None: + def __init__(self, adapters: "EngineServiceAdapters") -> None: self._db_conn_factory = adapters.db_conn_factory self._service_log = adapters.service_log self._logger = adapters.context.logger @@ -617,7 +617,7 @@ class DeviceManagementService: "connection_endpoint", ) - def __init__(self, app, adapters: "LegacyServiceAdapters") -> None: + def __init__(self, app, adapters: "EngineServiceAdapters") -> None: self.app = app self.adapters = adapters self.db_conn_factory = adapters.db_conn_factory @@ -1513,7 +1513,7 @@ class DeviceManagementService: finally: conn.close() -def register_management(app, adapters: "LegacyServiceAdapters") -> None: +def register_management(app, adapters: "EngineServiceAdapters") -> None: """Register device management endpoints onto the Flask app.""" service = DeviceManagementService(app, adapters) @@ -1679,3 +1679,4 @@ def register_management(app, adapters: "LegacyServiceAdapters") -> None: return jsonify(payload), status app.register_blueprint(blueprint) + diff --git a/Data/Engine/services/API/enrollment/__init__.py b/Data/Engine/services/API/enrollment/__init__.py new file mode 100644 index 00000000..58e22e17 --- /dev/null +++ b/Data/Engine/services/API/enrollment/__init__.py @@ -0,0 +1,12 @@ +# ====================================================== +# Data\Engine\services\API\enrollment\__init__.py +# Description: Engine enrollment API registration helpers. +# +# API Endpoints (if applicable): None +# ====================================================== + +"""Expose Engine-native enrollment API routes.""" + +from .routes import register + +__all__ = ["register"] diff --git a/Data/Engine/services/API/enrollment/routes.py b/Data/Engine/services/API/enrollment/routes.py new file mode 100644 index 00000000..b5b57a75 --- /dev/null +++ b/Data/Engine/services/API/enrollment/routes.py @@ -0,0 +1,744 @@ +# ====================================================== +# Data\Engine\services\API\enrollment\routes.py +# Description: Engine-native device enrollment endpoints handling install codes, approvals, and token issuance. +# +# API Endpoints (if applicable): +# - POST /api/agent/enroll/request (No Authentication) - Submits device enrollment requests. +# - POST /api/agent/enroll/poll (No Authentication) - Finalises approved enrollment requests. +# ====================================================== + +"""Device enrollment routes for the Borealis Engine runtime.""" +from __future__ import annotations + +import base64 +import secrets +import sqlite3 +import uuid +from datetime import datetime, timezone, timedelta +import time +from typing import Any, Callable, Dict, Optional, Tuple + +AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context" + + +def _canonical_context(value: Optional[str]) -> Optional[str]: + if not value: + return None + cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-")) + if not cleaned: + return None + return cleaned.upper() + +from flask import Blueprint, jsonify, request + +from ....auth.rate_limit import SlidingWindowRateLimiter +from ....crypto import keys as crypto_keys +from ....enrollment.nonce_store import NonceCache +from ....auth.guid_utils import normalize_guid +from cryptography.hazmat.primitives import serialization + + +def register( + app, + *, + db_conn_factory: Callable[[], sqlite3.Connection], + log: Callable[[str, str, Optional[str]], None], + jwt_service, + tls_bundle_path: str, + ip_rate_limiter: SlidingWindowRateLimiter, + fp_rate_limiter: SlidingWindowRateLimiter, + nonce_cache: NonceCache, + script_signer, +) -> None: + blueprint = Blueprint("enrollment", __name__) + + def _now() -> datetime: + return datetime.now(tz=timezone.utc) + + def _iso(dt: datetime) -> str: + return dt.isoformat() + + def _remote_addr() -> str: + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + return forwarded.split(",")[0].strip() + addr = request.remote_addr or "unknown" + return addr.strip() + + def _signing_key_b64() -> str: + if not script_signer: + return "" + try: + return script_signer.public_base64_spki() + except Exception: + return "" + + def _rate_limited( + key: str, + limiter: SlidingWindowRateLimiter, + limit: int, + window_s: float, + context_hint: Optional[str], + ): + decision = limiter.check(key, limit, window_s) + if not decision.allowed: + log( + "server", + f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}", + context_hint, + ) + response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after}) + response.status_code = 429 + response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}" + return response + return None + + def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]: + cur.execute( + """ + SELECT id, + code, + expires_at, + used_at, + used_by_guid, + max_uses, + use_count, + last_used_at + FROM enrollment_install_codes + WHERE code = ? + """, + (code_value,), + ) + row = cur.fetchone() + if not row: + return None + keys = [ + "id", + "code", + "expires_at", + "used_at", + "used_by_guid", + "max_uses", + "use_count", + "last_used_at", + ] + record = dict(zip(keys, row)) + return record + + def _install_code_valid( + record: Dict[str, Any], fingerprint: str, cur: sqlite3.Cursor + ) -> Tuple[bool, Optional[str]]: + if not record: + return False, None + expires_at = record.get("expires_at") + if not isinstance(expires_at, str): + return False, None + try: + expiry = datetime.fromisoformat(expires_at) + except Exception: + return False, None + if expiry <= _now(): + return False, None + try: + max_uses = int(record.get("max_uses") or 1) + except Exception: + max_uses = 1 + if max_uses < 1: + max_uses = 1 + try: + use_count = int(record.get("use_count") or 0) + except Exception: + use_count = 0 + if use_count < max_uses: + return True, None + + guid = normalize_guid(record.get("used_by_guid")) + if not guid: + return False, None + cur.execute( + "SELECT ssl_key_fingerprint FROM devices WHERE UPPER(guid) = ?", + (guid,), + ) + row = cur.fetchone() + if not row: + return False, None + stored_fp = (row[0] or "").strip().lower() + if not stored_fp: + return False, None + if stored_fp == (fingerprint or "").strip().lower(): + return True, guid + return False, None + + def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str: + guid_norm = normalize_guid(guid) + base = (hostname or "").strip() or guid_norm + base = base[:253] + candidate = base + suffix = 1 + while True: + cur.execute( + "SELECT guid FROM devices WHERE hostname = ?", + (candidate,), + ) + row = cur.fetchone() + if not row: + return candidate + existing_guid = normalize_guid(row[0]) + if existing_guid == guid_norm: + return candidate + candidate = f"{base}-{suffix}" + suffix += 1 + if suffix > 50: + return guid_norm + + def _store_device_key(cur: sqlite3.Cursor, guid: str, fingerprint: str) -> None: + guid_norm = normalize_guid(guid) + added_at = _iso(_now()) + cur.execute( + """ + INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at) + VALUES (?, ?, ?, ?) + """, + (str(uuid.uuid4()), guid_norm, fingerprint, added_at), + ) + cur.execute( + """ + UPDATE device_keys + SET retired_at = ? + WHERE guid = ? + AND ssl_key_fingerprint != ? + AND retired_at IS NULL + """, + (_iso(_now()), guid_norm, fingerprint), + ) + + def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]: + guid_norm = normalize_guid(guid) + cur.execute( + """ + SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at + FROM devices + WHERE UPPER(guid) = ? + """, + (guid_norm,), + ) + row = cur.fetchone() + if row: + keys = [ + "guid", + "hostname", + "token_version", + "status", + "ssl_key_fingerprint", + "key_added_at", + ] + record = dict(zip(keys, row)) + record["guid"] = normalize_guid(record.get("guid")) + stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower() + new_fp = (fingerprint or "").strip().lower() + if not stored_fp and new_fp: + cur.execute( + "UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?", + (fingerprint, _iso(_now()), record["guid"]), + ) + record["ssl_key_fingerprint"] = fingerprint + elif new_fp and stored_fp != new_fp: + now_iso = _iso(_now()) + try: + current_version = int(record.get("token_version") or 1) + except Exception: + current_version = 1 + new_version = max(current_version + 1, 1) + cur.execute( + """ + UPDATE devices + SET ssl_key_fingerprint = ?, + key_added_at = ?, + token_version = ?, + status = 'active' + WHERE guid = ? + """, + (fingerprint, now_iso, new_version, record["guid"]), + ) + cur.execute( + """ + UPDATE refresh_tokens + SET revoked_at = ? + WHERE guid = ? + AND revoked_at IS NULL + """, + (now_iso, record["guid"]), + ) + record["ssl_key_fingerprint"] = fingerprint + record["token_version"] = new_version + record["status"] = "active" + record["key_added_at"] = now_iso + return record + + resolved_hostname = _normalize_host(hostname, guid_norm, cur) + created_at = int(time.time()) + key_added_at = _iso(_now()) + cur.execute( + """ + INSERT INTO devices ( + guid, hostname, created_at, last_seen, ssl_key_fingerprint, + token_version, status, key_added_at + ) + VALUES (?, ?, ?, ?, ?, 1, 'active', ?) + """, + ( + guid_norm, + resolved_hostname, + created_at, + created_at, + fingerprint, + key_added_at, + ), + ) + return { + "guid": guid_norm, + "hostname": resolved_hostname, + "token_version": 1, + "status": "active", + "ssl_key_fingerprint": fingerprint, + "key_added_at": key_added_at, + } + + def _hash_refresh_token(token: str) -> str: + import hashlib + + return hashlib.sha256(token.encode("utf-8")).hexdigest() + + def _issue_refresh_token(cur: sqlite3.Cursor, guid: str) -> Dict[str, Any]: + token = secrets.token_urlsafe(48) + now = _now() + expires_at = now.replace(microsecond=0) + timedelta(days=30) + cur.execute( + """ + INSERT INTO refresh_tokens (id, guid, token_hash, created_at, expires_at) + VALUES (?, ?, ?, ?, ?) + """, + ( + str(uuid.uuid4()), + guid, + _hash_refresh_token(token), + _iso(now), + _iso(expires_at), + ), + ) + return {"token": token, "expires_at": expires_at} + + @blueprint.route("/api/agent/enroll/request", methods=["POST"]) + def enrollment_request(): + remote = _remote_addr() + context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER)) + + rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0, context_hint) + if rate_error: + return rate_error + + payload = request.get_json(force=True, silent=True) or {} + hostname = str(payload.get("hostname") or "").strip() + enrollment_code = str(payload.get("enrollment_code") or "").strip() + agent_pubkey_b64 = payload.get("agent_pubkey") + client_nonce_b64 = payload.get("client_nonce") + + log( + "server", + "enrollment request received " + f"ip={remote} hostname={hostname or ''} code_mask={_mask_code(enrollment_code)} " + f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}", + context_hint, + ) + + if not hostname: + log("server", f"enrollment rejected missing_hostname ip={remote}", context_hint) + return jsonify({"error": "hostname_required"}), 400 + if not enrollment_code: + log("server", f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint) + return jsonify({"error": "enrollment_code_required"}), 400 + if not isinstance(agent_pubkey_b64, str): + log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint) + return jsonify({"error": "agent_pubkey_required"}), 400 + if not isinstance(client_nonce_b64, str): + log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint) + return jsonify({"error": "client_nonce_required"}), 400 + + try: + agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64) + except Exception: + log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint) + return jsonify({"error": "invalid_agent_pubkey"}), 400 + + if len(agent_pubkey_der) < 10: + log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint) + return jsonify({"error": "invalid_agent_pubkey"}), 400 + + try: + client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True) + except Exception: + log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint) + return jsonify({"error": "invalid_client_nonce"}), 400 + if len(client_nonce_bytes) < 16: + log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint) + return jsonify({"error": "invalid_client_nonce"}), 400 + + fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der) + rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0, context_hint) + if rate_error: + return rate_error + + conn = db_conn_factory() + try: + cur = conn.cursor() + install_code = _load_install_code(cur, enrollment_code) + valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur) + if not valid_code: + log( + "server", + "enrollment request invalid_code " + f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}", + context_hint, + ) + return jsonify({"error": "invalid_enrollment_code"}), 400 + + approval_reference: str + record_id: str + server_nonce_bytes = secrets.token_bytes(32) + server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii") + now = _iso(_now()) + + cur.execute( + """ + SELECT id, approval_reference + FROM device_approvals + WHERE ssl_key_fingerprint_claimed = ? + AND status = 'pending' + """, + (fingerprint,), + ) + existing = cur.fetchone() + if existing: + record_id = existing[0] + approval_reference = existing[1] + cur.execute( + """ + UPDATE device_approvals + SET hostname_claimed = ?, + guid = ?, + enrollment_code_id = ?, + client_nonce = ?, + server_nonce = ?, + agent_pubkey_der = ?, + updated_at = ? + WHERE id = ? + """, + ( + hostname, + reuse_guid, + install_code["id"], + client_nonce_b64, + server_nonce_b64, + agent_pubkey_der, + now, + record_id, + ), + ) + else: + record_id = str(uuid.uuid4()) + approval_reference = str(uuid.uuid4()) + cur.execute( + """ + INSERT INTO device_approvals ( + id, approval_reference, guid, hostname_claimed, + ssl_key_fingerprint_claimed, enrollment_code_id, + status, client_nonce, server_nonce, agent_pubkey_der, + created_at, updated_at + ) + VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?) + """, + ( + record_id, + approval_reference, + reuse_guid, + hostname, + fingerprint, + install_code["id"], + client_nonce_b64, + server_nonce_b64, + agent_pubkey_der, + now, + now, + ), + ) + + conn.commit() + finally: + conn.close() + + response = { + "status": "pending", + "approval_reference": approval_reference, + "server_nonce": server_nonce_b64, + "poll_after_ms": 3000, + "server_certificate": _load_tls_bundle(tls_bundle_path), + "signing_key": _signing_key_b64(), + } + log( + "server", + f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}", + context_hint, + ) + return jsonify(response) + + @blueprint.route("/api/agent/enroll/poll", methods=["POST"]) + def enrollment_poll(): + payload = request.get_json(force=True, silent=True) or {} + approval_reference = payload.get("approval_reference") + client_nonce_b64 = payload.get("client_nonce") + proof_sig_b64 = payload.get("proof_sig") + context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER)) + + log( + "server", + "enrollment poll received " + f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}" + f" proof_sig_len={len(proof_sig_b64 or '')}", + context_hint, + ) + + if not isinstance(approval_reference, str) or not approval_reference: + log("server", "enrollment poll rejected missing_reference", context_hint) + return jsonify({"error": "approval_reference_required"}), 400 + if not isinstance(client_nonce_b64, str): + log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint) + return jsonify({"error": "client_nonce_required"}), 400 + if not isinstance(proof_sig_b64, str): + log("server", f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint) + return jsonify({"error": "proof_sig_required"}), 400 + + try: + client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True) + except Exception: + log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint) + return jsonify({"error": "invalid_client_nonce"}), 400 + + try: + proof_sig = base64.b64decode(proof_sig_b64, validate=True) + except Exception: + log("server", f"enrollment poll invalid_sig ref={approval_reference}", context_hint) + return jsonify({"error": "invalid_proof_sig"}), 400 + + conn = db_conn_factory() + try: + cur = conn.cursor() + cur.execute( + """ + SELECT id, guid, hostname_claimed, ssl_key_fingerprint_claimed, + enrollment_code_id, status, client_nonce, server_nonce, + agent_pubkey_der, created_at, updated_at, approved_by_user_id + FROM device_approvals + WHERE approval_reference = ? + """, + (approval_reference,), + ) + row = cur.fetchone() + if not row: + log("server", f"enrollment poll unknown_reference ref={approval_reference}", context_hint) + return jsonify({"status": "unknown"}), 404 + + ( + record_id, + guid, + hostname_claimed, + fingerprint, + enrollment_code_id, + status, + client_nonce_stored, + server_nonce_b64, + agent_pubkey_der, + created_at, + updated_at, + approved_by, + ) = row + + if client_nonce_stored != client_nonce_b64: + log("server", f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint) + return jsonify({"error": "nonce_mismatch"}), 400 + + try: + server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True) + except Exception: + log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint) + return jsonify({"error": "server_nonce_invalid"}), 400 + + message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes + + try: + public_key = serialization.load_der_public_key(agent_pubkey_der) + except Exception: + log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint) + public_key = None + + if public_key is None: + log("server", f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint) + return jsonify({"error": "agent_pubkey_invalid"}), 400 + + try: + public_key.verify(proof_sig, message) + except Exception: + log("server", f"enrollment poll invalid_proof ref={approval_reference}", context_hint) + return jsonify({"error": "invalid_proof"}), 400 + + if status == "pending": + log( + "server", + f"enrollment poll pending ref={approval_reference} host={hostname_claimed}" + f" fingerprint={fingerprint[:12]}", + context_hint, + ) + return jsonify({"status": "pending", "poll_after_ms": 5000}) + if status == "denied": + log( + "server", + f"enrollment poll denied ref={approval_reference} host={hostname_claimed}", + context_hint, + ) + return jsonify({"status": "denied", "reason": "operator_denied"}) + if status == "expired": + log( + "server", + f"enrollment poll expired ref={approval_reference} host={hostname_claimed}", + context_hint, + ) + return jsonify({"status": "expired"}) + if status == "completed": + log( + "server", + f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}", + context_hint, + ) + return jsonify({"status": "approved", "detail": "finalized"}) + + if status != "approved": + log( + "server", + f"enrollment poll unexpected_status={status} ref={approval_reference}", + context_hint, + ) + return jsonify({"status": status or "unknown"}), 400 + + nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}" + if not nonce_cache.consume(nonce_key): + log( + "server", + f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}", + context_hint, + ) + return jsonify({"error": "proof_replayed"}), 409 + + # Finalize enrollment + effective_guid = normalize_guid(guid) if guid else normalize_guid(str(uuid.uuid4())) + now_iso = _iso(_now()) + + device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint) + _store_device_key(cur, effective_guid, fingerprint) + + # Mark install code used + if enrollment_code_id: + cur.execute( + "SELECT use_count, max_uses FROM enrollment_install_codes WHERE id = ?", + (enrollment_code_id,), + ) + usage_row = cur.fetchone() + try: + prior_count = int(usage_row[0]) if usage_row else 0 + except Exception: + prior_count = 0 + try: + allowed_uses = int(usage_row[1]) if usage_row else 1 + except Exception: + allowed_uses = 1 + if allowed_uses < 1: + allowed_uses = 1 + new_count = prior_count + 1 + consumed = new_count >= allowed_uses + cur.execute( + """ + UPDATE enrollment_install_codes + SET use_count = ?, + used_by_guid = ?, + last_used_at = ?, + used_at = CASE WHEN ? THEN ? ELSE used_at END + WHERE id = ? + """, + ( + new_count, + effective_guid, + now_iso, + 1 if consumed else 0, + now_iso, + enrollment_code_id, + ), + ) + + # Update approval record with final state + cur.execute( + """ + UPDATE device_approvals + SET guid = ?, + status = 'completed', + updated_at = ? + WHERE id = ? + """, + (effective_guid, now_iso, record_id), + ) + + refresh_info = _issue_refresh_token(cur, effective_guid) + access_token = jwt_service.issue_access_token( + effective_guid, + fingerprint, + device_record.get("token_version") or 1, + ) + + conn.commit() + finally: + conn.close() + + log( + "server", + f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}", + context_hint, + ) + return jsonify( + { + "status": "approved", + "guid": effective_guid, + "access_token": access_token, + "expires_in": 900, + "refresh_token": refresh_info["token"], + "token_type": "Bearer", + "server_certificate": _load_tls_bundle(tls_bundle_path), + "signing_key": _signing_key_b64(), + } + ) + + app.register_blueprint(blueprint) + + +def _load_tls_bundle(path: str) -> str: + try: + with open(path, "r", encoding="utf-8") as fh: + return fh.read() + except Exception: + return "" + + +def _mask_code(code: str) -> str: + if not code: + return "" + trimmed = str(code).strip() + if len(trimmed) <= 6: + return "***" + return f"{trimmed[:3]}***{trimmed[-3:]}" + diff --git a/Data/Engine/services/API/tokens/__init__.py b/Data/Engine/services/API/tokens/__init__.py new file mode 100644 index 00000000..0000e30e --- /dev/null +++ b/Data/Engine/services/API/tokens/__init__.py @@ -0,0 +1,12 @@ +# ====================================================== +# Data\Engine\services\API\tokens\__init__.py +# Description: Token management API registration helpers for the Engine runtime. +# +# API Endpoints (if applicable): None +# ====================================================== + +"""Expose Engine-native token management routes.""" + +from .routes import register + +__all__ = ["register"] diff --git a/Data/Engine/services/API/tokens/routes.py b/Data/Engine/services/API/tokens/routes.py new file mode 100644 index 00000000..784093b1 --- /dev/null +++ b/Data/Engine/services/API/tokens/routes.py @@ -0,0 +1,147 @@ +# ====================================================== +# Data\Engine\services\API\tokens\routes.py +# Description: Engine-native refresh token endpoints decoupled from legacy server modules. +# +# API Endpoints (if applicable): +# - POST /api/agent/token/refresh (Authenticated via refresh token) - Issues a new access token. +# ====================================================== + +"""Token management routes backed by the Engine authentication stack.""" + +from __future__ import annotations + +import hashlib +import sqlite3 +from datetime import datetime, timezone +from typing import Callable + +from flask import Blueprint, current_app, jsonify, request + +from ....auth.dpop import DPoPReplayError, DPoPValidator, DPoPVerificationError + + +def register( + app, + *, + db_conn_factory: Callable[[], sqlite3.Connection], + jwt_service, + dpop_validator: DPoPValidator, +) -> None: + blueprint = Blueprint("tokens", __name__) + + def _hash_token(token: str) -> str: + return hashlib.sha256(token.encode("utf-8")).hexdigest() + + def _iso_now() -> str: + return datetime.now(tz=timezone.utc).isoformat() + + def _parse_iso(ts: str) -> datetime: + return datetime.fromisoformat(ts) + + @blueprint.route("/api/agent/token/refresh", methods=["POST"]) + def refresh(): + payload = request.get_json(force=True, silent=True) or {} + guid = str(payload.get("guid") or "").strip() + refresh_token = str(payload.get("refresh_token") or "").strip() + + if not guid or not refresh_token: + return jsonify({"error": "invalid_request"}), 400 + + conn = db_conn_factory() + try: + cur = conn.cursor() + cur.execute( + """ + SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at + FROM refresh_tokens + WHERE guid = ? + AND token_hash = ? + """, + (guid, _hash_token(refresh_token)), + ) + row = cur.fetchone() + if not row: + return jsonify({"error": "invalid_refresh_token"}), 401 + + record_id, row_guid, _token_hash, stored_jkt, created_at, expires_at, revoked_at = row + if row_guid != guid: + return jsonify({"error": "invalid_refresh_token"}), 401 + if revoked_at: + return jsonify({"error": "refresh_token_revoked"}), 401 + if expires_at: + try: + if _parse_iso(expires_at) <= datetime.now(tz=timezone.utc): + return jsonify({"error": "refresh_token_expired"}), 401 + except Exception: + pass + + cur.execute( + """ + SELECT guid, ssl_key_fingerprint, token_version, status + FROM devices + WHERE guid = ? + """, + (guid,), + ) + device_row = cur.fetchone() + if not device_row: + return jsonify({"error": "device_not_found"}), 404 + + device_guid, fingerprint, token_version, status = device_row + status_norm = (status or "active").strip().lower() + if status_norm in {"revoked", "decommissioned"}: + return jsonify({"error": "device_revoked"}), 403 + + dpop_proof = request.headers.get("DPoP") + jkt = stored_jkt or "" + if dpop_proof: + try: + jkt = dpop_validator.verify(request.method, request.url, dpop_proof, access_token=None) + except DPoPReplayError: + return jsonify({"error": "dpop_replayed"}), 400 + except DPoPVerificationError: + return jsonify({"error": "dpop_invalid"}), 400 + elif stored_jkt: + try: + current_app.logger.warning( + "Clearing stored DPoP binding for guid=%s due to missing proof", + guid, + ) + except Exception: + pass + cur.execute( + "UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?", + (record_id,), + ) + + new_access_token = jwt_service.issue_access_token( + guid, + fingerprint or "", + token_version or 1, + ) + + cur.execute( + """ + UPDATE refresh_tokens + SET last_used_at = ?, + dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt) + WHERE id = ? + """, + (_iso_now(), jkt, record_id), + ) + conn.commit() + finally: + conn.close() + + return jsonify( + { + "access_token": new_access_token, + "expires_in": 900, + "token_type": "Bearer", + } + ) + + app.register_blueprint(blueprint) + + +__all__ = ["register"]