ENGINE: Migrated Enrollment Logic

This commit is contained in:
2025-10-29 16:40:53 -06:00
parent 8fa7bd4fb0
commit 833c4b7d88
23 changed files with 1881 additions and 44 deletions

View File

@@ -38,6 +38,8 @@ Lastly, everytime that you complete a stage, you will create a pull request name
- [ ] Add migration switch in the legacy server for WebUI delegation.
- [x] Extend tests to cover critical WebUI routes.
- [ ] Port device API endpoints into Engine services (device + admin coverage in progress).
- [x] Move authentication/token stack onto Engine services without legacy fallbacks.
- [x] Port enrollment request/poll flows to Engine services and drop legacy imports.
- [ ] **Stage 7 — Plan WebSocket migration**
- [ ] Extract Socket.IO handlers into Data/Engine/services/WebSocket.
- [ ] Provide register_realtime hook for the Engine factory.
@@ -46,4 +48,4 @@ Lastly, everytime that you complete a stage, you will create a pull request name
## Current Status
- **Stage:** Stage 6 — Plan WebUI migration
- **Active Task:** Migrating device endpoints into the Engine API (legacy bridge removed).
- **Active Task:** Continue Stage 6 device/admin API migration (focus on remaining device and admin endpoints now that auth, token, and enrollment paths are Engine-native).

View File

@@ -17,7 +17,7 @@ from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from flask.testing import FlaskClient
from Modules.crypto import keys as crypto_keys
from Data.Engine.crypto import keys as crypto_keys
from .conftest import EngineTestHarness

View File

@@ -0,0 +1,27 @@
# ======================================================
# Data\Engine\auth\__init__.py
# Description: Engine-native authentication utilities and helpers decoupled from the legacy server modules.
#
# API Endpoints (if applicable): None
# ======================================================
"""Authentication utility package for the Borealis Engine."""
from .jwt_service import JWTService, load_service
from .dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
from .rate_limit import SlidingWindowRateLimiter, RateLimitDecision
from .device_auth import DeviceAuthManager, DeviceAuthError, DeviceAuthContext, require_device_auth
__all__ = [
"JWTService",
"load_service",
"DPoPValidator",
"DPoPVerificationError",
"DPoPReplayError",
"SlidingWindowRateLimiter",
"RateLimitDecision",
"DeviceAuthManager",
"DeviceAuthError",
"DeviceAuthContext",
"require_device_auth",
]

View File

@@ -0,0 +1,310 @@
# ======================================================
# Data\Engine\auth\device_auth.py
# Description: Engine-native device authentication manager and decorators.
#
# API Endpoints (if applicable): None
# ======================================================
"""Device authentication helpers for the Borealis Engine runtime."""
from __future__ import annotations
import functools
import sqlite3
import time
from contextlib import closing
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional
import jwt
from flask import g, jsonify, request
from .dpop import DPoPReplayError, DPoPValidator, DPoPVerificationError
from .guid_utils import normalize_guid
from .rate_limit import SlidingWindowRateLimiter
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
if not value:
return None
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
if not cleaned:
return None
return cleaned.upper()
@dataclass
class DeviceAuthContext:
guid: str
ssl_key_fingerprint: str
token_version: int
access_token: str
claims: Dict[str, Any]
dpop_jkt: Optional[str]
status: str
service_mode: Optional[str]
class DeviceAuthError(Exception):
status_code = 401
error_code = "unauthorized"
def __init__(
self,
message: str = "unauthorized",
*,
status_code: Optional[int] = None,
retry_after: Optional[float] = None,
):
super().__init__(message)
if status_code is not None:
self.status_code = status_code
self.message = message
self.retry_after = retry_after
class DeviceAuthManager:
def __init__(
self,
*,
db_conn_factory: Callable[[], Any],
jwt_service,
dpop_validator: Optional[DPoPValidator],
log: Callable[[str, str, Optional[str]], None],
rate_limiter: Optional[SlidingWindowRateLimiter] = None,
) -> None:
self._db_conn_factory = db_conn_factory
self._jwt_service = jwt_service
self._dpop_validator = dpop_validator
self._log = log
self._rate_limiter = rate_limiter
def authenticate(self) -> DeviceAuthContext:
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise DeviceAuthError("missing_authorization")
token = auth_header[len("Bearer ") :].strip()
if not token:
raise DeviceAuthError("missing_authorization")
try:
claims = self._jwt_service.decode(token)
except jwt.ExpiredSignatureError:
raise DeviceAuthError("token_expired")
except Exception:
raise DeviceAuthError("invalid_token")
raw_guid = str(claims.get("guid") or "").strip()
guid = normalize_guid(raw_guid)
fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip()
token_version = int(claims.get("token_version") or 0)
if not guid or not fingerprint or token_version <= 0:
raise DeviceAuthError("invalid_claims")
if self._rate_limiter:
decision = self._rate_limiter.check(f"fp:{fingerprint}", 60, 60.0)
if not decision.allowed:
raise DeviceAuthError(
"rate_limited",
status_code=429,
retry_after=decision.retry_after,
)
context_label = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
with closing(self._db_conn_factory()) as conn:
cur = conn.cursor()
cur.execute(
"""
SELECT guid, ssl_key_fingerprint, token_version, status
FROM devices
WHERE UPPER(guid) = ?
""",
(guid,),
)
rows = cur.fetchall()
row = None
for candidate in rows or []:
candidate_guid = normalize_guid(candidate[0])
if candidate_guid == guid:
row = candidate
break
if row is None and rows:
row = rows[0]
if row is None:
row = self._recover_device_record(conn, guid, fingerprint, token_version, context_label)
if row is None:
raise DeviceAuthError("device_not_found", status_code=404)
stored_guid, stored_fingerprint, stored_version, status = row
stored_guid = normalize_guid(stored_guid)
if stored_guid != guid:
raise DeviceAuthError("device_mismatch", status_code=401)
if (stored_fingerprint or "").lower().strip() != fingerprint:
raise DeviceAuthError("fingerprint_mismatch", status_code=403)
if int(stored_version or 0) != token_version:
raise DeviceAuthError("token_version_mismatch", status_code=403)
status_norm = (status or "active").strip().lower()
if status_norm in {"revoked", "decommissioned"}:
raise DeviceAuthError("device_revoked", status_code=403)
dpop_proof = request.headers.get("DPoP")
jkt = None
if dpop_proof and self._dpop_validator:
try:
jkt = self._dpop_validator.verify(request.method, request.url, dpop_proof, access_token=token)
except DPoPReplayError:
raise DeviceAuthError("dpop_replayed", status_code=400)
except DPoPVerificationError:
raise DeviceAuthError("dpop_invalid", status_code=400)
ctx = DeviceAuthContext(
guid=guid,
ssl_key_fingerprint=fingerprint,
token_version=token_version,
access_token=token,
claims=claims,
dpop_jkt=jkt,
status=status_norm,
service_mode=context_label,
)
return ctx
def _recover_device_record(
self,
conn: sqlite3.Connection,
guid: str,
fingerprint: str,
token_version: int,
context_label: Optional[str],
) -> Optional[tuple]:
"""Attempt to recreate a missing device row for an authenticated token."""
guid = normalize_guid(guid)
fingerprint = (fingerprint or "").strip()
if not guid or not fingerprint:
return None
cur = conn.cursor()
now_ts = int(time.time())
try:
now_iso = datetime.now(tz=timezone.utc).isoformat()
except Exception:
now_iso = datetime.utcnow().isoformat()
base_hostname = f"RECOVERED-{guid[:12].upper()}" if guid else "RECOVERED"
for attempt in range(6):
hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
try:
cur.execute(
"""
INSERT INTO devices (
guid,
hostname,
created_at,
last_seen,
ssl_key_fingerprint,
token_version,
status,
key_added_at
)
VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
""",
(
guid,
hostname,
now_ts,
now_ts,
fingerprint,
max(token_version or 1, 1),
now_iso,
),
)
except sqlite3.IntegrityError as exc:
message = str(exc).lower()
if "hostname" in message and "unique" in message:
continue
self._log(
"server",
f"device auth failed to recover guid={guid} due to integrity error: {exc}",
context_label,
)
conn.rollback()
return None
except Exception as exc:
self._log(
"server",
f"device auth unexpected error recovering guid={guid}: {exc}",
context_label,
)
conn.rollback()
return None
else:
conn.commit()
break
else:
self._log(
"server",
f"device auth could not recover guid={guid}; hostname collisions persisted",
context_label,
)
conn.rollback()
return None
cur.execute(
"""
SELECT guid, ssl_key_fingerprint, token_version, status
FROM devices
WHERE guid = ?
""",
(guid,),
)
row = cur.fetchone()
if not row:
self._log(
"server",
f"device auth recovery for guid={guid} committed but row still missing",
context_label,
)
return row
def require_device_auth(manager: DeviceAuthManager):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
ctx = manager.authenticate()
except DeviceAuthError as exc:
response = jsonify({"error": exc.message})
response.status_code = exc.status_code
retry_after = getattr(exc, "retry_after", None)
if retry_after:
try:
response.headers["Retry-After"] = str(max(1, int(retry_after)))
except Exception:
response.headers["Retry-After"] = "1"
return response
g.device_auth = ctx
return func(*args, **kwargs)
return wrapper
return decorator
__all__ = [
"AGENT_CONTEXT_HEADER",
"DeviceAuthContext",
"DeviceAuthError",
"DeviceAuthManager",
"require_device_auth",
]

118
Data/Engine/auth/dpop.py Normal file
View File

@@ -0,0 +1,118 @@
# ======================================================
# Data\Engine\auth\dpop.py
# Description: Engine-side DPoP proof validation helpers with replay protection.
#
# API Endpoints (if applicable): None
# ======================================================
"""DPoP proof verification helpers for the Engine runtime."""
from __future__ import annotations
import hashlib
import time
from threading import Lock
from typing import Dict, Optional
import jwt
_DPOP_MAX_SKEW = 300.0 # seconds
class DPoPVerificationError(Exception):
"""Raised when DPoP verification fails for structural reasons."""
class DPoPReplayError(DPoPVerificationError):
"""Raised when a DPoP proof replay is detected."""
class DPoPValidator:
"""Validate DPoP proofs and track observed JTIs to prevent replay attacks."""
def __init__(self) -> None:
self._observed_jti: Dict[str, float] = {}
self._lock = Lock()
def verify(
self,
method: str,
htu: str,
proof: str,
access_token: Optional[str] = None,
) -> str:
"""
Verify the presented DPoP proof and return the JWK thumbprint.
"""
if not proof:
raise DPoPVerificationError("DPoP proof missing")
try:
header = jwt.get_unverified_header(proof)
except Exception as exc:
raise DPoPVerificationError("invalid DPoP header") from exc
jwk = header.get("jwk")
alg = header.get("alg")
if not jwk or not isinstance(jwk, dict):
raise DPoPVerificationError("missing jwk in DPoP header")
if alg not in ("EdDSA", "ES256", "ES384", "ES512"):
raise DPoPVerificationError(f"unsupported DPoP alg {alg}")
try:
key = jwt.PyJWK(jwk)
public_key = key.key
except Exception as exc:
raise DPoPVerificationError("invalid jwk in DPoP header") from exc
try:
claims = jwt.decode(
proof,
public_key,
algorithms=[alg],
options={"require": ["htm", "htu", "jti", "iat"]},
)
except Exception as exc:
raise DPoPVerificationError("invalid DPoP signature") from exc
htm = claims.get("htm")
proof_htu = claims.get("htu")
jti = claims.get("jti")
iat = claims.get("iat")
ath = claims.get("ath")
if not isinstance(htm, str) or htm.lower() != method.lower():
raise DPoPVerificationError("DPoP htm mismatch")
if not isinstance(proof_htu, str) or proof_htu != htu:
raise DPoPVerificationError("DPoP htu mismatch")
if not isinstance(jti, str):
raise DPoPVerificationError("DPoP jti missing")
if not isinstance(iat, (int, float)):
raise DPoPVerificationError("DPoP iat missing")
now = time.time()
if abs(now - float(iat)) > _DPOP_MAX_SKEW:
raise DPoPVerificationError("DPoP proof outside allowed skew")
if ath and access_token:
expected_ath = jwt.utils.base64url_encode(
hashlib.sha256(access_token.encode("utf-8")).digest()
).decode("ascii")
if expected_ath != ath:
raise DPoPVerificationError("DPoP ath mismatch")
with self._lock:
expiry = self._observed_jti.get(jti)
if expiry and expiry > now:
raise DPoPReplayError("DPoP proof replay detected")
self._observed_jti[jti] = now + _DPOP_MAX_SKEW
stale = [key for key, exp in self._observed_jti.items() if exp <= now]
for key in stale:
self._observed_jti.pop(key, None)
thumbprint = jwt.PyJWK(jwk).thumbprint()
return thumbprint.decode("ascii")
__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"]

View File

@@ -0,0 +1,39 @@
# ======================================================
# Data\Engine\auth\guid_utils.py
# Description: GUID normalisation helpers used by Engine authentication flows.
#
# API Endpoints (if applicable): None
# ======================================================
"""GUID normalisation helpers for Engine-managed authentication."""
from __future__ import annotations
import string
import uuid
from typing import Optional
def normalize_guid(value: Optional[str]) -> str:
"""
Canonicalise GUID strings so Engine services treat different casings uniformly.
"""
candidate = (value or "").strip()
if not candidate:
return ""
candidate = candidate.strip("{}")
try:
return str(uuid.UUID(candidate)).upper()
except Exception:
cleaned = "".join(ch for ch in candidate if ch in string.hexdigits or ch == "-")
cleaned = cleaned.strip("-")
if cleaned:
try:
return str(uuid.UUID(cleaned)).upper()
except Exception:
pass
return candidate.upper()
__all__ = ["normalize_guid"]

View File

@@ -0,0 +1,206 @@
# ======================================================
# Data\Engine\auth\jwt_service.py
# Description: Engine-native JWT access-token helpers with signing key storage under Engine/Data/Auth_Tokens.
#
# API Endpoints (if applicable): None
# ======================================================
"""JWT access-token helpers backed by an Engine-managed Ed25519 key."""
from __future__ import annotations
import hashlib
import os
import time
from pathlib import Path
from typing import Any, Dict, Optional
import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from ..security.certificates import project_root_path
_TOKEN_ENV_ROOT = "BOREALIS_ENGINE_AUTH_TOKEN_ROOT"
_LEGACY_SERVER_ROOT_ENV = "BOREALIS_SERVER_ROOT"
_KEY_FILENAME = "borealis-jwt-ed25519.key"
def _env_path(name: str) -> Optional[Path]:
value = os.environ.get(name)
if not value:
return None
try:
return Path(value).expanduser().resolve()
except Exception:
try:
return Path(value).expanduser()
except Exception:
return Path(value)
def _engine_runtime_root() -> Path:
env = _env_path("BOREALIS_ENGINE_ROOT") or _env_path("BOREALIS_ENGINE_RUNTIME")
if env:
env.mkdir(parents=True, exist_ok=True)
return env
root = project_root_path() / "Engine"
root.mkdir(parents=True, exist_ok=True)
return root
def _token_root() -> Path:
env = _env_path(_TOKEN_ENV_ROOT)
if env:
env.mkdir(parents=True, exist_ok=True)
return env
root = _engine_runtime_root() / "Data" / "Auth_Tokens"
root.mkdir(parents=True, exist_ok=True)
return root
def _legacy_key_paths() -> Dict[str, Path]:
project_root = project_root_path()
server_root = _env_path(_LEGACY_SERVER_ROOT_ENV) or (project_root / "Server" / "Borealis")
candidates = {
"auth_keys": server_root / "auth_keys" / _KEY_FILENAME,
"keys": server_root / "keys" / _KEY_FILENAME,
}
return candidates
def _tighten_permissions(path: Path) -> None:
try:
if os.name != "nt":
path.chmod(0o600)
except Exception:
pass
_KEY_DIR = _token_root()
_KEY_FILE = _KEY_DIR / _KEY_FILENAME
class JWTService:
def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str):
self._private_key = private_key
self._public_key = private_key.public_key()
self._key_id = key_id
@property
def key_id(self) -> str:
return self._key_id
def issue_access_token(
self,
guid: str,
ssl_key_fingerprint: str,
token_version: int,
expires_in: int = 900,
extra_claims: Optional[Dict[str, Any]] = None,
) -> str:
now = int(time.time())
payload: Dict[str, Any] = {
"sub": f"device:{guid}",
"guid": guid,
"ssl_key_fingerprint": ssl_key_fingerprint,
"token_version": int(token_version),
"iat": now,
"nbf": now,
"exp": now + int(expires_in),
}
if extra_claims:
payload.update(extra_claims)
token = jwt.encode(
payload,
self._private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
),
algorithm="EdDSA",
headers={"kid": self._key_id},
)
return token
def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]:
options = {"require": ["exp", "iat", "sub"]}
public_pem = self._public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return jwt.decode(
token,
public_pem,
algorithms=["EdDSA"],
audience=audience,
options=options,
)
def public_jwk(self) -> Dict[str, Any]:
public_bytes = self._public_key.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii")
return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x}
def load_service() -> JWTService:
private_key = _load_or_create_private_key()
public_bytes = private_key.public_key().public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
key_id = hashlib.sha256(public_bytes).hexdigest()[:16]
return JWTService(private_key, key_id)
def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey:
_KEY_DIR.mkdir(parents=True, exist_ok=True)
_migrate_legacy_key_if_present()
if _KEY_FILE.exists():
with _KEY_FILE.open("rb") as fh:
return serialization.load_pem_private_key(fh.read(), password=None)
private_key = ed25519.Ed25519PrivateKey.generate()
pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
with _KEY_FILE.open("wb") as fh:
fh.write(pem)
_tighten_permissions(_KEY_FILE)
return private_key
def _migrate_legacy_key_if_present() -> None:
if _KEY_FILE.exists():
return
legacy_paths = _legacy_key_paths()
for legacy_file in legacy_paths.values():
if not legacy_file.exists():
continue
try:
legacy_bytes = legacy_file.read_bytes()
except Exception:
continue
try:
_KEY_FILE.write_bytes(legacy_bytes)
_tighten_permissions(_KEY_FILE)
except Exception:
continue
try:
legacy_file.unlink()
except Exception:
pass
break
__all__ = ["JWTService", "load_service"]

View File

@@ -0,0 +1,51 @@
# ======================================================
# Data\Engine\auth\rate_limit.py
# Description: Sliding-window rate limiter for Engine authentication endpoints.
#
# API Endpoints (if applicable): None
# ======================================================
"""Engine-native in-memory rate limiter suitable for single-process development."""
from __future__ import annotations
import time
from collections import deque
from dataclasses import dataclass
from threading import Lock
from typing import Deque, Dict
@dataclass
class RateLimitDecision:
allowed: bool
retry_after: float
class SlidingWindowRateLimiter:
"""Simple sliding-window limiter to guard authentication endpoints."""
def __init__(self) -> None:
self._buckets: Dict[str, Deque[float]] = {}
self._lock = Lock()
def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
now = time.monotonic()
with self._lock:
bucket = self._buckets.get(key)
if bucket is None:
bucket = deque()
self._buckets[key] = bucket
while bucket and now - bucket[0] > window_seconds:
bucket.popleft()
if len(bucket) >= limit:
retry_after = max(0.0, window_seconds - (now - bucket[0]))
return RateLimitDecision(False, retry_after)
bucket.append(now)
return RateLimitDecision(True, 0.0)
__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"]

View File

@@ -0,0 +1,30 @@
# ======================================================
# Data\Engine\crypto\__init__.py
# Description: Engine cryptographic helpers and key utilities.
#
# API Endpoints (if applicable): None
# ======================================================
"""Cryptographic helper utilities for the Borealis Engine runtime."""
from .keys import (
generate_ed25519_keypair,
normalize_base64,
spki_der_from_base64,
base64_from_spki_der,
fingerprint_from_spki_der,
fingerprint_from_base64_spki,
private_key_to_pem,
public_key_to_pem,
)
__all__ = [
"generate_ed25519_keypair",
"normalize_base64",
"spki_der_from_base64",
"base64_from_spki_der",
"fingerprint_from_spki_der",
"fingerprint_from_base64_spki",
"private_key_to_pem",
"public_key_to_pem",
]

View File

@@ -0,0 +1,88 @@
# ======================================================
# Data\Engine\crypto\keys.py
# Description: Engine-native Ed25519 key helpers and fingerprint utilities.
#
# API Endpoints (if applicable): None
# ======================================================
"""Utility helpers for working with Ed25519 keys and fingerprints."""
from __future__ import annotations
import base64
import hashlib
import re
from typing import Tuple
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from cryptography.hazmat.primitives.serialization import load_der_public_key
def generate_ed25519_keypair() -> Tuple[ed25519.Ed25519PrivateKey, bytes]:
"""
Generate a new Ed25519 keypair.
Returns the private key object and the public key encoded as SubjectPublicKeyInfo DER bytes.
"""
private_key = ed25519.Ed25519PrivateKey.generate()
public_key = private_key.public_key().public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return private_key, public_key
def normalize_base64(data: str) -> str:
"""
Collapse whitespace and normalise URL-safe encodings so we can reliably decode.
"""
cleaned = re.sub(r"\s+", "", data or "")
return cleaned.replace("-", "+").replace("_", "/")
def spki_der_from_base64(spki_b64: str) -> bytes:
return base64.b64decode(normalize_base64(spki_b64), validate=True)
def base64_from_spki_der(spki_der: bytes) -> str:
return base64.b64encode(spki_der).decode("ascii")
def fingerprint_from_spki_der(spki_der: bytes) -> str:
digest = hashlib.sha256(spki_der).hexdigest()
return digest.lower()
def fingerprint_from_base64_spki(spki_b64: str) -> str:
return fingerprint_from_spki_der(spki_der_from_base64(spki_b64))
def private_key_to_pem(private_key: ed25519.Ed25519PrivateKey) -> bytes:
return private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
def public_key_to_pem(public_spki_der: bytes) -> bytes:
public_key = load_der_public_key(public_spki_der)
return public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
__all__ = [
"generate_ed25519_keypair",
"normalize_base64",
"spki_der_from_base64",
"base64_from_spki_der",
"fingerprint_from_spki_der",
"fingerprint_from_base64_spki",
"private_key_to_pem",
"public_key_to_pem",
]

View File

@@ -0,0 +1,12 @@
# ======================================================
# Data\Engine\enrollment\__init__.py
# Description: Enrollment utilities for Engine-managed device onboarding.
#
# API Endpoints (if applicable): None
# ======================================================
"""Enrollment helper utilities for the Borealis Engine runtime."""
from .nonce_store import NonceCache
__all__ = ["NonceCache"]

View File

@@ -0,0 +1,42 @@
# ======================================================
# Data\Engine\enrollment\nonce_store.py
# Description: Short-lived nonce cache preventing replay during Engine enrollment flows.
#
# API Endpoints (if applicable): None
# ======================================================
"""Short-lived nonce cache to defend against enrollment replay attacks."""
from __future__ import annotations
import time
from threading import Lock
from typing import Dict
class NonceCache:
def __init__(self, ttl_seconds: float = 300.0) -> None:
self._ttl = ttl_seconds
self._entries: Dict[str, float] = {}
self._lock = Lock()
def consume(self, key: str) -> bool:
"""
Attempt to consume the nonce identified by `key`.
Returns True on first use within TTL, False if already consumed.
"""
now = time.monotonic()
with self._lock:
expire_at = self._entries.get(key)
if expire_at and expire_at > now:
return False
self._entries[key] = now + self._ttl
stale = [nonce for nonce, expiry in self._entries.items() if expiry <= now]
for nonce in stale:
self._entries.pop(nonce, None)
return True
__all__ = ["NonceCache"]

View File

@@ -1,6 +1,6 @@
# ======================================================
# Data\Engine\services\API\__init__.py
# Description: Registers Engine API groups and bridges to legacy modules while exposing core utility routes.
# Description: Registers Engine API groups, wiring Engine-native authentication while delegating remaining legacy modules.
#
# API Endpoints (if applicable):
# - GET /health (No Authentication) - Returns an OK status for liveness probing.
@@ -20,15 +20,15 @@ from typing import Any, Callable, Iterable, Mapping, Optional, Sequence
from flask import Blueprint, Flask, jsonify
from Modules.auth import jwt_service as jwt_service_module
from Modules.auth.device_auth import DeviceAuthManager
from Modules.auth.dpop import DPoPValidator
from Modules.auth.rate_limit import SlidingWindowRateLimiter
from ...auth import jwt_service as jwt_service_module
from ...auth.device_auth import DeviceAuthManager
from ...auth.dpop import DPoPValidator
from ...auth.rate_limit import SlidingWindowRateLimiter
from ...database import initialise_engine_database
from ...security import signing
from Modules.enrollment import routes as enrollment_routes
from Modules.enrollment.nonce_store import NonceCache
from Modules.tokens import routes as token_routes
from ...enrollment import NonceCache
from .enrollment import routes as enrollment_routes
from .tokens import routes as token_routes
from ...server import EngineContext
from .access_management.login import register_auth
@@ -137,7 +137,7 @@ def _make_db_conn_factory(database_path: str) -> Callable[[], sqlite3.Connection
@dataclass
class LegacyServiceAdapters:
class EngineServiceAdapters:
context: EngineContext
db_conn_factory: Callable[[], sqlite3.Connection] = field(init=False)
jwt_service: Any = field(init=False)
@@ -180,7 +180,7 @@ class LegacyServiceAdapters:
)
def _register_tokens(app: Flask, adapters: LegacyServiceAdapters) -> None:
def _register_tokens(app: Flask, adapters: EngineServiceAdapters) -> None:
token_routes.register(
app,
db_conn_factory=adapters.db_conn_factory,
@@ -189,7 +189,7 @@ def _register_tokens(app: Flask, adapters: LegacyServiceAdapters) -> None:
)
def _register_enrollment(app: Flask, adapters: LegacyServiceAdapters) -> None:
def _register_enrollment(app: Flask, adapters: EngineServiceAdapters) -> None:
tls_bundle = adapters.context.tls_bundle_path or ""
enrollment_routes.register(
app,
@@ -204,12 +204,12 @@ def _register_enrollment(app: Flask, adapters: LegacyServiceAdapters) -> None:
)
def _register_devices(app: Flask, adapters: LegacyServiceAdapters) -> None:
def _register_devices(app: Flask, adapters: EngineServiceAdapters) -> None:
register_management(app, adapters)
register_admin_endpoints(app, adapters)
_GROUP_REGISTRARS: Mapping[str, Callable[[Flask, LegacyServiceAdapters], None]] = {
_GROUP_REGISTRARS: Mapping[str, Callable[[Flask, EngineServiceAdapters], None]] = {
"auth": register_auth,
"tokens": _register_tokens,
"enrollment": _register_enrollment,
@@ -236,7 +236,7 @@ def register_api(app: Flask, context: EngineContext) -> None:
enabled_groups: Iterable[str] = context.api_groups or DEFAULT_API_GROUPS
normalized = [group.strip().lower() for group in enabled_groups if group]
adapters: Optional[LegacyServiceAdapters] = None
adapters: Optional[EngineServiceAdapters] = None
for group in normalized:
if group == "core":
@@ -244,7 +244,7 @@ def register_api(app: Flask, context: EngineContext) -> None:
continue
if adapters is None:
adapters = LegacyServiceAdapters(context)
adapters = EngineServiceAdapters(context)
registrar = _GROUP_REGISTRARS.get(group)
if registrar is None:
context.logger.info("Engine API group '%s' is not implemented; skipping.", group)

View File

@@ -35,7 +35,7 @@ except Exception: # pragma: no cover - optional dependency
qrcode = None # type: ignore
if TYPE_CHECKING: # pragma: no cover - typing helper
from Data.Engine.services.API import LegacyServiceAdapters
from Data.Engine.services.API import EngineServiceAdapters
def _now_ts() -> int:
@@ -103,7 +103,7 @@ def _user_row_to_dict(row: Sequence[Any]) -> Mapping[str, Any]:
class _AuthService:
def __init__(self, app: Flask, adapters: "LegacyServiceAdapters") -> None:
def __init__(self, app: Flask, adapters: "EngineServiceAdapters") -> None:
self.app = app
self.adapters = adapters
self.context = adapters.context
@@ -393,7 +393,7 @@ class _AuthService:
)
def register_auth(app: Flask, adapters: "LegacyServiceAdapters") -> None:
def register_auth(app: Flask, adapters: "EngineServiceAdapters") -> None:
"""Register authentication endpoints for the Engine."""
service = _AuthService(app, adapters)
@@ -416,3 +416,4 @@ def register_auth(app: Flask, adapters: "LegacyServiceAdapters") -> None:
return service.me()
app.register_blueprint(blueprint)

View File

@@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Opti
from flask import Blueprint, jsonify, request
if TYPE_CHECKING: # pragma: no cover - typing aide
from .. import LegacyServiceAdapters
from .. import EngineServiceAdapters
_ISLAND_DIR_MAP: Mapping[str, str] = {
@@ -49,7 +49,7 @@ _BASE64_CLEANER = re.compile(r"\s+")
class AssemblyManagementService:
"""Implements assembly CRUD helpers for Engine routes."""
def __init__(self, adapters: "LegacyServiceAdapters") -> None:
def __init__(self, adapters: "EngineServiceAdapters") -> None:
self.adapters = adapters
self.logger = adapters.context.logger or logging.getLogger(__name__)
self.service_log = adapters.service_log
@@ -679,7 +679,7 @@ class AssemblyManagementService:
return obj
def register_assemblies(app, adapters: "LegacyServiceAdapters") -> None:
def register_assemblies(app, adapters: "EngineServiceAdapters") -> None:
"""Register assembly CRUD endpoints on the Flask app."""
service = AssemblyManagementService(adapters)
@@ -726,3 +726,4 @@ def register_assemblies(app, adapters: "LegacyServiceAdapters") -> None:
return jsonify(response), status
app.register_blueprint(blueprint)

View File

@@ -35,7 +35,7 @@ except Exception: # pragma: no cover - optional dependency
qrcode = None # type: ignore
if TYPE_CHECKING: # pragma: no cover - typing helper
from . import LegacyServiceAdapters
from . import EngineServiceAdapters
def _now_ts() -> int:
@@ -103,7 +103,7 @@ def _user_row_to_dict(row: Sequence[Any]) -> Mapping[str, Any]:
class _AuthService:
def __init__(self, app: Flask, adapters: "LegacyServiceAdapters") -> None:
def __init__(self, app: Flask, adapters: "EngineServiceAdapters") -> None:
self.app = app
self.adapters = adapters
self.context = adapters.context
@@ -398,7 +398,7 @@ class _AuthService:
)
def register_auth(app: Flask, adapters: "LegacyServiceAdapters") -> None:
def register_auth(app: Flask, adapters: "EngineServiceAdapters") -> None:
"""Register authentication endpoints for the Engine."""
service = _AuthService(app, adapters)
@@ -422,3 +422,4 @@ def register_auth(app: Flask, adapters: "LegacyServiceAdapters") -> None:
app.register_blueprint(blueprint)
adapters.context.logger.info("Engine registered API group 'auth'.")

View File

@@ -24,10 +24,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from flask import Blueprint, jsonify, request, session
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from Modules.guid_utils import normalize_guid
from ....auth.guid_utils import normalize_guid
if TYPE_CHECKING: # pragma: no cover - typing helper
from .. import LegacyServiceAdapters
from .. import EngineServiceAdapters
VALID_TTL_HOURS = {1, 3, 6, 12, 24}
@@ -49,7 +49,7 @@ def _generate_install_code() -> str:
class AdminDeviceService:
"""Utility wrapper for admin device APIs."""
def __init__(self, app, adapters: "LegacyServiceAdapters") -> None:
def __init__(self, app, adapters: "EngineServiceAdapters") -> None:
self.app = app
self.adapters = adapters
self.db_conn_factory = adapters.db_conn_factory
@@ -477,7 +477,7 @@ class AdminDeviceService:
return self._set_approval_status(approval_id, "denied")
def register_admin_endpoints(app, adapters: "LegacyServiceAdapters") -> None:
def register_admin_endpoints(app, adapters: "EngineServiceAdapters") -> None:
"""Register admin enrollment + approval endpoints."""
service = AdminDeviceService(app, adapters)
@@ -532,3 +532,4 @@ def register_admin_endpoints(app, adapters: "LegacyServiceAdapters") -> None:
return jsonify(payload), status
app.register_blueprint(blueprint)

View File

@@ -1,8 +0,0 @@
# ======================================================
# Data\Engine\services\API\devices\enrollment.py
# Description: Placeholder for device enrollment API bridge (not yet implemented).
#
# API Endpoints (if applicable): None
# ======================================================
"Placeholder for API module devices/enrollment.py."

View File

@@ -41,8 +41,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from flask import Blueprint, jsonify, request, session, g
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from Modules.auth.device_auth import require_device_auth
from Modules.guid_utils import normalize_guid
from ....auth.device_auth import require_device_auth
from ....auth.guid_utils import normalize_guid
try:
import requests # type: ignore
@@ -57,7 +57,7 @@ except ImportError: # pragma: no cover - fallback for minimal test environments
requests = _RequestsStub() # type: ignore
if TYPE_CHECKING: # pragma: no cover - typing aide
from .. import LegacyServiceAdapters
from .. import EngineServiceAdapters
def _safe_json(raw: Optional[str], default: Any) -> Any:
@@ -340,7 +340,7 @@ def _device_upsert(
class RepositoryHashCache:
"""Lightweight GitHub head cache with on-disk persistence."""
def __init__(self, adapters: "LegacyServiceAdapters") -> None:
def __init__(self, adapters: "EngineServiceAdapters") -> None:
self._db_conn_factory = adapters.db_conn_factory
self._service_log = adapters.service_log
self._logger = adapters.context.logger
@@ -617,7 +617,7 @@ class DeviceManagementService:
"connection_endpoint",
)
def __init__(self, app, adapters: "LegacyServiceAdapters") -> None:
def __init__(self, app, adapters: "EngineServiceAdapters") -> None:
self.app = app
self.adapters = adapters
self.db_conn_factory = adapters.db_conn_factory
@@ -1513,7 +1513,7 @@ class DeviceManagementService:
finally:
conn.close()
def register_management(app, adapters: "LegacyServiceAdapters") -> None:
def register_management(app, adapters: "EngineServiceAdapters") -> None:
"""Register device management endpoints onto the Flask app."""
service = DeviceManagementService(app, adapters)
@@ -1679,3 +1679,4 @@ def register_management(app, adapters: "LegacyServiceAdapters") -> None:
return jsonify(payload), status
app.register_blueprint(blueprint)

View File

@@ -0,0 +1,12 @@
# ======================================================
# Data\Engine\services\API\enrollment\__init__.py
# Description: Engine enrollment API registration helpers.
#
# API Endpoints (if applicable): None
# ======================================================
"""Expose Engine-native enrollment API routes."""
from .routes import register
__all__ = ["register"]

View File

@@ -0,0 +1,744 @@
# ======================================================
# Data\Engine\services\API\enrollment\routes.py
# Description: Engine-native device enrollment endpoints handling install codes, approvals, and token issuance.
#
# API Endpoints (if applicable):
# - POST /api/agent/enroll/request (No Authentication) - Submits device enrollment requests.
# - POST /api/agent/enroll/poll (No Authentication) - Finalises approved enrollment requests.
# ======================================================
"""Device enrollment routes for the Borealis Engine runtime."""
from __future__ import annotations
import base64
import secrets
import sqlite3
import uuid
from datetime import datetime, timezone, timedelta
import time
from typing import Any, Callable, Dict, Optional, Tuple
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
if not value:
return None
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
if not cleaned:
return None
return cleaned.upper()
from flask import Blueprint, jsonify, request
from ....auth.rate_limit import SlidingWindowRateLimiter
from ....crypto import keys as crypto_keys
from ....enrollment.nonce_store import NonceCache
from ....auth.guid_utils import normalize_guid
from cryptography.hazmat.primitives import serialization
def register(
app,
*,
db_conn_factory: Callable[[], sqlite3.Connection],
log: Callable[[str, str, Optional[str]], None],
jwt_service,
tls_bundle_path: str,
ip_rate_limiter: SlidingWindowRateLimiter,
fp_rate_limiter: SlidingWindowRateLimiter,
nonce_cache: NonceCache,
script_signer,
) -> None:
blueprint = Blueprint("enrollment", __name__)
def _now() -> datetime:
return datetime.now(tz=timezone.utc)
def _iso(dt: datetime) -> str:
return dt.isoformat()
def _remote_addr() -> str:
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
addr = request.remote_addr or "unknown"
return addr.strip()
def _signing_key_b64() -> str:
if not script_signer:
return ""
try:
return script_signer.public_base64_spki()
except Exception:
return ""
def _rate_limited(
key: str,
limiter: SlidingWindowRateLimiter,
limit: int,
window_s: float,
context_hint: Optional[str],
):
decision = limiter.check(key, limit, window_s)
if not decision.allowed:
log(
"server",
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
context_hint,
)
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
response.status_code = 429
response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}"
return response
return None
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
cur.execute(
"""
SELECT id,
code,
expires_at,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at
FROM enrollment_install_codes
WHERE code = ?
""",
(code_value,),
)
row = cur.fetchone()
if not row:
return None
keys = [
"id",
"code",
"expires_at",
"used_at",
"used_by_guid",
"max_uses",
"use_count",
"last_used_at",
]
record = dict(zip(keys, row))
return record
def _install_code_valid(
record: Dict[str, Any], fingerprint: str, cur: sqlite3.Cursor
) -> Tuple[bool, Optional[str]]:
if not record:
return False, None
expires_at = record.get("expires_at")
if not isinstance(expires_at, str):
return False, None
try:
expiry = datetime.fromisoformat(expires_at)
except Exception:
return False, None
if expiry <= _now():
return False, None
try:
max_uses = int(record.get("max_uses") or 1)
except Exception:
max_uses = 1
if max_uses < 1:
max_uses = 1
try:
use_count = int(record.get("use_count") or 0)
except Exception:
use_count = 0
if use_count < max_uses:
return True, None
guid = normalize_guid(record.get("used_by_guid"))
if not guid:
return False, None
cur.execute(
"SELECT ssl_key_fingerprint FROM devices WHERE UPPER(guid) = ?",
(guid,),
)
row = cur.fetchone()
if not row:
return False, None
stored_fp = (row[0] or "").strip().lower()
if not stored_fp:
return False, None
if stored_fp == (fingerprint or "").strip().lower():
return True, guid
return False, None
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
guid_norm = normalize_guid(guid)
base = (hostname or "").strip() or guid_norm
base = base[:253]
candidate = base
suffix = 1
while True:
cur.execute(
"SELECT guid FROM devices WHERE hostname = ?",
(candidate,),
)
row = cur.fetchone()
if not row:
return candidate
existing_guid = normalize_guid(row[0])
if existing_guid == guid_norm:
return candidate
candidate = f"{base}-{suffix}"
suffix += 1
if suffix > 50:
return guid_norm
def _store_device_key(cur: sqlite3.Cursor, guid: str, fingerprint: str) -> None:
guid_norm = normalize_guid(guid)
added_at = _iso(_now())
cur.execute(
"""
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
VALUES (?, ?, ?, ?)
""",
(str(uuid.uuid4()), guid_norm, fingerprint, added_at),
)
cur.execute(
"""
UPDATE device_keys
SET retired_at = ?
WHERE guid = ?
AND ssl_key_fingerprint != ?
AND retired_at IS NULL
""",
(_iso(_now()), guid_norm, fingerprint),
)
def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
guid_norm = normalize_guid(guid)
cur.execute(
"""
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
FROM devices
WHERE UPPER(guid) = ?
""",
(guid_norm,),
)
row = cur.fetchone()
if row:
keys = [
"guid",
"hostname",
"token_version",
"status",
"ssl_key_fingerprint",
"key_added_at",
]
record = dict(zip(keys, row))
record["guid"] = normalize_guid(record.get("guid"))
stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower()
new_fp = (fingerprint or "").strip().lower()
if not stored_fp and new_fp:
cur.execute(
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
(fingerprint, _iso(_now()), record["guid"]),
)
record["ssl_key_fingerprint"] = fingerprint
elif new_fp and stored_fp != new_fp:
now_iso = _iso(_now())
try:
current_version = int(record.get("token_version") or 1)
except Exception:
current_version = 1
new_version = max(current_version + 1, 1)
cur.execute(
"""
UPDATE devices
SET ssl_key_fingerprint = ?,
key_added_at = ?,
token_version = ?,
status = 'active'
WHERE guid = ?
""",
(fingerprint, now_iso, new_version, record["guid"]),
)
cur.execute(
"""
UPDATE refresh_tokens
SET revoked_at = ?
WHERE guid = ?
AND revoked_at IS NULL
""",
(now_iso, record["guid"]),
)
record["ssl_key_fingerprint"] = fingerprint
record["token_version"] = new_version
record["status"] = "active"
record["key_added_at"] = now_iso
return record
resolved_hostname = _normalize_host(hostname, guid_norm, cur)
created_at = int(time.time())
key_added_at = _iso(_now())
cur.execute(
"""
INSERT INTO devices (
guid, hostname, created_at, last_seen, ssl_key_fingerprint,
token_version, status, key_added_at
)
VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
""",
(
guid_norm,
resolved_hostname,
created_at,
created_at,
fingerprint,
key_added_at,
),
)
return {
"guid": guid_norm,
"hostname": resolved_hostname,
"token_version": 1,
"status": "active",
"ssl_key_fingerprint": fingerprint,
"key_added_at": key_added_at,
}
def _hash_refresh_token(token: str) -> str:
import hashlib
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def _issue_refresh_token(cur: sqlite3.Cursor, guid: str) -> Dict[str, Any]:
token = secrets.token_urlsafe(48)
now = _now()
expires_at = now.replace(microsecond=0) + timedelta(days=30)
cur.execute(
"""
INSERT INTO refresh_tokens (id, guid, token_hash, created_at, expires_at)
VALUES (?, ?, ?, ?, ?)
""",
(
str(uuid.uuid4()),
guid,
_hash_refresh_token(token),
_iso(now),
_iso(expires_at),
),
)
return {"token": token, "expires_at": expires_at}
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
def enrollment_request():
remote = _remote_addr()
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0, context_hint)
if rate_error:
return rate_error
payload = request.get_json(force=True, silent=True) or {}
hostname = str(payload.get("hostname") or "").strip()
enrollment_code = str(payload.get("enrollment_code") or "").strip()
agent_pubkey_b64 = payload.get("agent_pubkey")
client_nonce_b64 = payload.get("client_nonce")
log(
"server",
"enrollment request received "
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
context_hint,
)
if not hostname:
log("server", f"enrollment rejected missing_hostname ip={remote}", context_hint)
return jsonify({"error": "hostname_required"}), 400
if not enrollment_code:
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
return jsonify({"error": "enrollment_code_required"}), 400
if not isinstance(agent_pubkey_b64, str):
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "agent_pubkey_required"}), 400
if not isinstance(client_nonce_b64, str):
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "client_nonce_required"}), 400
try:
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
except Exception:
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_agent_pubkey"}), 400
if len(agent_pubkey_der) < 10:
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_agent_pubkey"}), 400
try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
if len(client_nonce_bytes) < 16:
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0, context_hint)
if rate_error:
return rate_error
conn = db_conn_factory()
try:
cur = conn.cursor()
install_code = _load_install_code(cur, enrollment_code)
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
if not valid_code:
log(
"server",
"enrollment request invalid_code "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
)
return jsonify({"error": "invalid_enrollment_code"}), 400
approval_reference: str
record_id: str
server_nonce_bytes = secrets.token_bytes(32)
server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii")
now = _iso(_now())
cur.execute(
"""
SELECT id, approval_reference
FROM device_approvals
WHERE ssl_key_fingerprint_claimed = ?
AND status = 'pending'
""",
(fingerprint,),
)
existing = cur.fetchone()
if existing:
record_id = existing[0]
approval_reference = existing[1]
cur.execute(
"""
UPDATE device_approvals
SET hostname_claimed = ?,
guid = ?,
enrollment_code_id = ?,
client_nonce = ?,
server_nonce = ?,
agent_pubkey_der = ?,
updated_at = ?
WHERE id = ?
""",
(
hostname,
reuse_guid,
install_code["id"],
client_nonce_b64,
server_nonce_b64,
agent_pubkey_der,
now,
record_id,
),
)
else:
record_id = str(uuid.uuid4())
approval_reference = str(uuid.uuid4())
cur.execute(
"""
INSERT INTO device_approvals (
id, approval_reference, guid, hostname_claimed,
ssl_key_fingerprint_claimed, enrollment_code_id,
status, client_nonce, server_nonce, agent_pubkey_der,
created_at, updated_at
)
VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
""",
(
record_id,
approval_reference,
reuse_guid,
hostname,
fingerprint,
install_code["id"],
client_nonce_b64,
server_nonce_b64,
agent_pubkey_der,
now,
now,
),
)
conn.commit()
finally:
conn.close()
response = {
"status": "pending",
"approval_reference": approval_reference,
"server_nonce": server_nonce_b64,
"poll_after_ms": 3000,
"server_certificate": _load_tls_bundle(tls_bundle_path),
"signing_key": _signing_key_b64(),
}
log(
"server",
f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}",
context_hint,
)
return jsonify(response)
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
def enrollment_poll():
payload = request.get_json(force=True, silent=True) or {}
approval_reference = payload.get("approval_reference")
client_nonce_b64 = payload.get("client_nonce")
proof_sig_b64 = payload.get("proof_sig")
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
log(
"server",
"enrollment poll received "
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
f" proof_sig_len={len(proof_sig_b64 or '')}",
context_hint,
)
if not isinstance(approval_reference, str) or not approval_reference:
log("server", "enrollment poll rejected missing_reference", context_hint)
return jsonify({"error": "approval_reference_required"}), 400
if not isinstance(client_nonce_b64, str):
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "client_nonce_required"}), 400
if not isinstance(proof_sig_b64, str):
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
return jsonify({"error": "proof_sig_required"}), 400
try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
try:
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_proof_sig"}), 400
conn = db_conn_factory()
try:
cur = conn.cursor()
cur.execute(
"""
SELECT id, guid, hostname_claimed, ssl_key_fingerprint_claimed,
enrollment_code_id, status, client_nonce, server_nonce,
agent_pubkey_der, created_at, updated_at, approved_by_user_id
FROM device_approvals
WHERE approval_reference = ?
""",
(approval_reference,),
)
row = cur.fetchone()
if not row:
log("server", f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
return jsonify({"status": "unknown"}), 404
(
record_id,
guid,
hostname_claimed,
fingerprint,
enrollment_code_id,
status,
client_nonce_stored,
server_nonce_b64,
agent_pubkey_der,
created_at,
updated_at,
approved_by,
) = row
if client_nonce_stored != client_nonce_b64:
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
return jsonify({"error": "nonce_mismatch"}), 400
try:
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "server_nonce_invalid"}), 400
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
try:
public_key = serialization.load_der_public_key(agent_pubkey_der)
except Exception:
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
public_key = None
if public_key is None:
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
return jsonify({"error": "agent_pubkey_invalid"}), 400
try:
public_key.verify(proof_sig, message)
except Exception:
log("server", f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_proof"}), 400
if status == "pending":
log(
"server",
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
f" fingerprint={fingerprint[:12]}",
context_hint,
)
return jsonify({"status": "pending", "poll_after_ms": 5000})
if status == "denied":
log(
"server",
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "denied", "reason": "operator_denied"})
if status == "expired":
log(
"server",
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "expired"})
if status == "completed":
log(
"server",
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "approved", "detail": "finalized"})
if status != "approved":
log(
"server",
f"enrollment poll unexpected_status={status} ref={approval_reference}",
context_hint,
)
return jsonify({"status": status or "unknown"}), 400
nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
if not nonce_cache.consume(nonce_key):
log(
"server",
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
context_hint,
)
return jsonify({"error": "proof_replayed"}), 409
# Finalize enrollment
effective_guid = normalize_guid(guid) if guid else normalize_guid(str(uuid.uuid4()))
now_iso = _iso(_now())
device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint)
_store_device_key(cur, effective_guid, fingerprint)
# Mark install code used
if enrollment_code_id:
cur.execute(
"SELECT use_count, max_uses FROM enrollment_install_codes WHERE id = ?",
(enrollment_code_id,),
)
usage_row = cur.fetchone()
try:
prior_count = int(usage_row[0]) if usage_row else 0
except Exception:
prior_count = 0
try:
allowed_uses = int(usage_row[1]) if usage_row else 1
except Exception:
allowed_uses = 1
if allowed_uses < 1:
allowed_uses = 1
new_count = prior_count + 1
consumed = new_count >= allowed_uses
cur.execute(
"""
UPDATE enrollment_install_codes
SET use_count = ?,
used_by_guid = ?,
last_used_at = ?,
used_at = CASE WHEN ? THEN ? ELSE used_at END
WHERE id = ?
""",
(
new_count,
effective_guid,
now_iso,
1 if consumed else 0,
now_iso,
enrollment_code_id,
),
)
# Update approval record with final state
cur.execute(
"""
UPDATE device_approvals
SET guid = ?,
status = 'completed',
updated_at = ?
WHERE id = ?
""",
(effective_guid, now_iso, record_id),
)
refresh_info = _issue_refresh_token(cur, effective_guid)
access_token = jwt_service.issue_access_token(
effective_guid,
fingerprint,
device_record.get("token_version") or 1,
)
conn.commit()
finally:
conn.close()
log(
"server",
f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
context_hint,
)
return jsonify(
{
"status": "approved",
"guid": effective_guid,
"access_token": access_token,
"expires_in": 900,
"refresh_token": refresh_info["token"],
"token_type": "Bearer",
"server_certificate": _load_tls_bundle(tls_bundle_path),
"signing_key": _signing_key_b64(),
}
)
app.register_blueprint(blueprint)
def _load_tls_bundle(path: str) -> str:
try:
with open(path, "r", encoding="utf-8") as fh:
return fh.read()
except Exception:
return ""
def _mask_code(code: str) -> str:
if not code:
return "<missing>"
trimmed = str(code).strip()
if len(trimmed) <= 6:
return "***"
return f"{trimmed[:3]}***{trimmed[-3:]}"

View File

@@ -0,0 +1,12 @@
# ======================================================
# Data\Engine\services\API\tokens\__init__.py
# Description: Token management API registration helpers for the Engine runtime.
#
# API Endpoints (if applicable): None
# ======================================================
"""Expose Engine-native token management routes."""
from .routes import register
__all__ = ["register"]

View File

@@ -0,0 +1,147 @@
# ======================================================
# Data\Engine\services\API\tokens\routes.py
# Description: Engine-native refresh token endpoints decoupled from legacy server modules.
#
# API Endpoints (if applicable):
# - POST /api/agent/token/refresh (Authenticated via refresh token) - Issues a new access token.
# ======================================================
"""Token management routes backed by the Engine authentication stack."""
from __future__ import annotations
import hashlib
import sqlite3
from datetime import datetime, timezone
from typing import Callable
from flask import Blueprint, current_app, jsonify, request
from ....auth.dpop import DPoPReplayError, DPoPValidator, DPoPVerificationError
def register(
app,
*,
db_conn_factory: Callable[[], sqlite3.Connection],
jwt_service,
dpop_validator: DPoPValidator,
) -> None:
blueprint = Blueprint("tokens", __name__)
def _hash_token(token: str) -> str:
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def _iso_now() -> str:
return datetime.now(tz=timezone.utc).isoformat()
def _parse_iso(ts: str) -> datetime:
return datetime.fromisoformat(ts)
@blueprint.route("/api/agent/token/refresh", methods=["POST"])
def refresh():
payload = request.get_json(force=True, silent=True) or {}
guid = str(payload.get("guid") or "").strip()
refresh_token = str(payload.get("refresh_token") or "").strip()
if not guid or not refresh_token:
return jsonify({"error": "invalid_request"}), 400
conn = db_conn_factory()
try:
cur = conn.cursor()
cur.execute(
"""
SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at
FROM refresh_tokens
WHERE guid = ?
AND token_hash = ?
""",
(guid, _hash_token(refresh_token)),
)
row = cur.fetchone()
if not row:
return jsonify({"error": "invalid_refresh_token"}), 401
record_id, row_guid, _token_hash, stored_jkt, created_at, expires_at, revoked_at = row
if row_guid != guid:
return jsonify({"error": "invalid_refresh_token"}), 401
if revoked_at:
return jsonify({"error": "refresh_token_revoked"}), 401
if expires_at:
try:
if _parse_iso(expires_at) <= datetime.now(tz=timezone.utc):
return jsonify({"error": "refresh_token_expired"}), 401
except Exception:
pass
cur.execute(
"""
SELECT guid, ssl_key_fingerprint, token_version, status
FROM devices
WHERE guid = ?
""",
(guid,),
)
device_row = cur.fetchone()
if not device_row:
return jsonify({"error": "device_not_found"}), 404
device_guid, fingerprint, token_version, status = device_row
status_norm = (status or "active").strip().lower()
if status_norm in {"revoked", "decommissioned"}:
return jsonify({"error": "device_revoked"}), 403
dpop_proof = request.headers.get("DPoP")
jkt = stored_jkt or ""
if dpop_proof:
try:
jkt = dpop_validator.verify(request.method, request.url, dpop_proof, access_token=None)
except DPoPReplayError:
return jsonify({"error": "dpop_replayed"}), 400
except DPoPVerificationError:
return jsonify({"error": "dpop_invalid"}), 400
elif stored_jkt:
try:
current_app.logger.warning(
"Clearing stored DPoP binding for guid=%s due to missing proof",
guid,
)
except Exception:
pass
cur.execute(
"UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?",
(record_id,),
)
new_access_token = jwt_service.issue_access_token(
guid,
fingerprint or "",
token_version or 1,
)
cur.execute(
"""
UPDATE refresh_tokens
SET last_used_at = ?,
dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt)
WHERE id = ?
""",
(_iso_now(), jkt, record_id),
)
conn.commit()
finally:
conn.close()
return jsonify(
{
"access_token": new_access_token,
"expires_in": 900,
"token_type": "Bearer",
}
)
app.register_blueprint(blueprint)
__all__ = ["register"]