ENGINE: Migrated Enrollment Logic

This commit is contained in:
2025-10-29 16:40:53 -06:00
parent 8fa7bd4fb0
commit 833c4b7d88
23 changed files with 1881 additions and 44 deletions

View File

@@ -0,0 +1,27 @@
# ======================================================
# Data\Engine\auth\__init__.py
# Description: Engine-native authentication utilities and helpers decoupled from the legacy server modules.
#
# API Endpoints (if applicable): None
# ======================================================
"""Authentication utility package for the Borealis Engine."""
from .jwt_service import JWTService, load_service
from .dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
from .rate_limit import SlidingWindowRateLimiter, RateLimitDecision
from .device_auth import DeviceAuthManager, DeviceAuthError, DeviceAuthContext, require_device_auth
__all__ = [
"JWTService",
"load_service",
"DPoPValidator",
"DPoPVerificationError",
"DPoPReplayError",
"SlidingWindowRateLimiter",
"RateLimitDecision",
"DeviceAuthManager",
"DeviceAuthError",
"DeviceAuthContext",
"require_device_auth",
]

View File

@@ -0,0 +1,310 @@
# ======================================================
# Data\Engine\auth\device_auth.py
# Description: Engine-native device authentication manager and decorators.
#
# API Endpoints (if applicable): None
# ======================================================
"""Device authentication helpers for the Borealis Engine runtime."""
from __future__ import annotations
import functools
import sqlite3
import time
from contextlib import closing
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional
import jwt
from flask import g, jsonify, request
from .dpop import DPoPReplayError, DPoPValidator, DPoPVerificationError
from .guid_utils import normalize_guid
from .rate_limit import SlidingWindowRateLimiter
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
if not value:
return None
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
if not cleaned:
return None
return cleaned.upper()
@dataclass
class DeviceAuthContext:
guid: str
ssl_key_fingerprint: str
token_version: int
access_token: str
claims: Dict[str, Any]
dpop_jkt: Optional[str]
status: str
service_mode: Optional[str]
class DeviceAuthError(Exception):
status_code = 401
error_code = "unauthorized"
def __init__(
self,
message: str = "unauthorized",
*,
status_code: Optional[int] = None,
retry_after: Optional[float] = None,
):
super().__init__(message)
if status_code is not None:
self.status_code = status_code
self.message = message
self.retry_after = retry_after
class DeviceAuthManager:
def __init__(
self,
*,
db_conn_factory: Callable[[], Any],
jwt_service,
dpop_validator: Optional[DPoPValidator],
log: Callable[[str, str, Optional[str]], None],
rate_limiter: Optional[SlidingWindowRateLimiter] = None,
) -> None:
self._db_conn_factory = db_conn_factory
self._jwt_service = jwt_service
self._dpop_validator = dpop_validator
self._log = log
self._rate_limiter = rate_limiter
def authenticate(self) -> DeviceAuthContext:
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise DeviceAuthError("missing_authorization")
token = auth_header[len("Bearer ") :].strip()
if not token:
raise DeviceAuthError("missing_authorization")
try:
claims = self._jwt_service.decode(token)
except jwt.ExpiredSignatureError:
raise DeviceAuthError("token_expired")
except Exception:
raise DeviceAuthError("invalid_token")
raw_guid = str(claims.get("guid") or "").strip()
guid = normalize_guid(raw_guid)
fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip()
token_version = int(claims.get("token_version") or 0)
if not guid or not fingerprint or token_version <= 0:
raise DeviceAuthError("invalid_claims")
if self._rate_limiter:
decision = self._rate_limiter.check(f"fp:{fingerprint}", 60, 60.0)
if not decision.allowed:
raise DeviceAuthError(
"rate_limited",
status_code=429,
retry_after=decision.retry_after,
)
context_label = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
with closing(self._db_conn_factory()) as conn:
cur = conn.cursor()
cur.execute(
"""
SELECT guid, ssl_key_fingerprint, token_version, status
FROM devices
WHERE UPPER(guid) = ?
""",
(guid,),
)
rows = cur.fetchall()
row = None
for candidate in rows or []:
candidate_guid = normalize_guid(candidate[0])
if candidate_guid == guid:
row = candidate
break
if row is None and rows:
row = rows[0]
if row is None:
row = self._recover_device_record(conn, guid, fingerprint, token_version, context_label)
if row is None:
raise DeviceAuthError("device_not_found", status_code=404)
stored_guid, stored_fingerprint, stored_version, status = row
stored_guid = normalize_guid(stored_guid)
if stored_guid != guid:
raise DeviceAuthError("device_mismatch", status_code=401)
if (stored_fingerprint or "").lower().strip() != fingerprint:
raise DeviceAuthError("fingerprint_mismatch", status_code=403)
if int(stored_version or 0) != token_version:
raise DeviceAuthError("token_version_mismatch", status_code=403)
status_norm = (status or "active").strip().lower()
if status_norm in {"revoked", "decommissioned"}:
raise DeviceAuthError("device_revoked", status_code=403)
dpop_proof = request.headers.get("DPoP")
jkt = None
if dpop_proof and self._dpop_validator:
try:
jkt = self._dpop_validator.verify(request.method, request.url, dpop_proof, access_token=token)
except DPoPReplayError:
raise DeviceAuthError("dpop_replayed", status_code=400)
except DPoPVerificationError:
raise DeviceAuthError("dpop_invalid", status_code=400)
ctx = DeviceAuthContext(
guid=guid,
ssl_key_fingerprint=fingerprint,
token_version=token_version,
access_token=token,
claims=claims,
dpop_jkt=jkt,
status=status_norm,
service_mode=context_label,
)
return ctx
def _recover_device_record(
self,
conn: sqlite3.Connection,
guid: str,
fingerprint: str,
token_version: int,
context_label: Optional[str],
) -> Optional[tuple]:
"""Attempt to recreate a missing device row for an authenticated token."""
guid = normalize_guid(guid)
fingerprint = (fingerprint or "").strip()
if not guid or not fingerprint:
return None
cur = conn.cursor()
now_ts = int(time.time())
try:
now_iso = datetime.now(tz=timezone.utc).isoformat()
except Exception:
now_iso = datetime.utcnow().isoformat()
base_hostname = f"RECOVERED-{guid[:12].upper()}" if guid else "RECOVERED"
for attempt in range(6):
hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
try:
cur.execute(
"""
INSERT INTO devices (
guid,
hostname,
created_at,
last_seen,
ssl_key_fingerprint,
token_version,
status,
key_added_at
)
VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
""",
(
guid,
hostname,
now_ts,
now_ts,
fingerprint,
max(token_version or 1, 1),
now_iso,
),
)
except sqlite3.IntegrityError as exc:
message = str(exc).lower()
if "hostname" in message and "unique" in message:
continue
self._log(
"server",
f"device auth failed to recover guid={guid} due to integrity error: {exc}",
context_label,
)
conn.rollback()
return None
except Exception as exc:
self._log(
"server",
f"device auth unexpected error recovering guid={guid}: {exc}",
context_label,
)
conn.rollback()
return None
else:
conn.commit()
break
else:
self._log(
"server",
f"device auth could not recover guid={guid}; hostname collisions persisted",
context_label,
)
conn.rollback()
return None
cur.execute(
"""
SELECT guid, ssl_key_fingerprint, token_version, status
FROM devices
WHERE guid = ?
""",
(guid,),
)
row = cur.fetchone()
if not row:
self._log(
"server",
f"device auth recovery for guid={guid} committed but row still missing",
context_label,
)
return row
def require_device_auth(manager: DeviceAuthManager):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
ctx = manager.authenticate()
except DeviceAuthError as exc:
response = jsonify({"error": exc.message})
response.status_code = exc.status_code
retry_after = getattr(exc, "retry_after", None)
if retry_after:
try:
response.headers["Retry-After"] = str(max(1, int(retry_after)))
except Exception:
response.headers["Retry-After"] = "1"
return response
g.device_auth = ctx
return func(*args, **kwargs)
return wrapper
return decorator
__all__ = [
"AGENT_CONTEXT_HEADER",
"DeviceAuthContext",
"DeviceAuthError",
"DeviceAuthManager",
"require_device_auth",
]

118
Data/Engine/auth/dpop.py Normal file
View File

@@ -0,0 +1,118 @@
# ======================================================
# Data\Engine\auth\dpop.py
# Description: Engine-side DPoP proof validation helpers with replay protection.
#
# API Endpoints (if applicable): None
# ======================================================
"""DPoP proof verification helpers for the Engine runtime."""
from __future__ import annotations
import hashlib
import time
from threading import Lock
from typing import Dict, Optional
import jwt
_DPOP_MAX_SKEW = 300.0 # seconds
class DPoPVerificationError(Exception):
"""Raised when DPoP verification fails for structural reasons."""
class DPoPReplayError(DPoPVerificationError):
"""Raised when a DPoP proof replay is detected."""
class DPoPValidator:
"""Validate DPoP proofs and track observed JTIs to prevent replay attacks."""
def __init__(self) -> None:
self._observed_jti: Dict[str, float] = {}
self._lock = Lock()
def verify(
self,
method: str,
htu: str,
proof: str,
access_token: Optional[str] = None,
) -> str:
"""
Verify the presented DPoP proof and return the JWK thumbprint.
"""
if not proof:
raise DPoPVerificationError("DPoP proof missing")
try:
header = jwt.get_unverified_header(proof)
except Exception as exc:
raise DPoPVerificationError("invalid DPoP header") from exc
jwk = header.get("jwk")
alg = header.get("alg")
if not jwk or not isinstance(jwk, dict):
raise DPoPVerificationError("missing jwk in DPoP header")
if alg not in ("EdDSA", "ES256", "ES384", "ES512"):
raise DPoPVerificationError(f"unsupported DPoP alg {alg}")
try:
key = jwt.PyJWK(jwk)
public_key = key.key
except Exception as exc:
raise DPoPVerificationError("invalid jwk in DPoP header") from exc
try:
claims = jwt.decode(
proof,
public_key,
algorithms=[alg],
options={"require": ["htm", "htu", "jti", "iat"]},
)
except Exception as exc:
raise DPoPVerificationError("invalid DPoP signature") from exc
htm = claims.get("htm")
proof_htu = claims.get("htu")
jti = claims.get("jti")
iat = claims.get("iat")
ath = claims.get("ath")
if not isinstance(htm, str) or htm.lower() != method.lower():
raise DPoPVerificationError("DPoP htm mismatch")
if not isinstance(proof_htu, str) or proof_htu != htu:
raise DPoPVerificationError("DPoP htu mismatch")
if not isinstance(jti, str):
raise DPoPVerificationError("DPoP jti missing")
if not isinstance(iat, (int, float)):
raise DPoPVerificationError("DPoP iat missing")
now = time.time()
if abs(now - float(iat)) > _DPOP_MAX_SKEW:
raise DPoPVerificationError("DPoP proof outside allowed skew")
if ath and access_token:
expected_ath = jwt.utils.base64url_encode(
hashlib.sha256(access_token.encode("utf-8")).digest()
).decode("ascii")
if expected_ath != ath:
raise DPoPVerificationError("DPoP ath mismatch")
with self._lock:
expiry = self._observed_jti.get(jti)
if expiry and expiry > now:
raise DPoPReplayError("DPoP proof replay detected")
self._observed_jti[jti] = now + _DPOP_MAX_SKEW
stale = [key for key, exp in self._observed_jti.items() if exp <= now]
for key in stale:
self._observed_jti.pop(key, None)
thumbprint = jwt.PyJWK(jwk).thumbprint()
return thumbprint.decode("ascii")
__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"]

View File

@@ -0,0 +1,39 @@
# ======================================================
# Data\Engine\auth\guid_utils.py
# Description: GUID normalisation helpers used by Engine authentication flows.
#
# API Endpoints (if applicable): None
# ======================================================
"""GUID normalisation helpers for Engine-managed authentication."""
from __future__ import annotations
import string
import uuid
from typing import Optional
def normalize_guid(value: Optional[str]) -> str:
"""
Canonicalise GUID strings so Engine services treat different casings uniformly.
"""
candidate = (value or "").strip()
if not candidate:
return ""
candidate = candidate.strip("{}")
try:
return str(uuid.UUID(candidate)).upper()
except Exception:
cleaned = "".join(ch for ch in candidate if ch in string.hexdigits or ch == "-")
cleaned = cleaned.strip("-")
if cleaned:
try:
return str(uuid.UUID(cleaned)).upper()
except Exception:
pass
return candidate.upper()
__all__ = ["normalize_guid"]

View File

@@ -0,0 +1,206 @@
# ======================================================
# Data\Engine\auth\jwt_service.py
# Description: Engine-native JWT access-token helpers with signing key storage under Engine/Data/Auth_Tokens.
#
# API Endpoints (if applicable): None
# ======================================================
"""JWT access-token helpers backed by an Engine-managed Ed25519 key."""
from __future__ import annotations
import hashlib
import os
import time
from pathlib import Path
from typing import Any, Dict, Optional
import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from ..security.certificates import project_root_path
_TOKEN_ENV_ROOT = "BOREALIS_ENGINE_AUTH_TOKEN_ROOT"
_LEGACY_SERVER_ROOT_ENV = "BOREALIS_SERVER_ROOT"
_KEY_FILENAME = "borealis-jwt-ed25519.key"
def _env_path(name: str) -> Optional[Path]:
value = os.environ.get(name)
if not value:
return None
try:
return Path(value).expanduser().resolve()
except Exception:
try:
return Path(value).expanduser()
except Exception:
return Path(value)
def _engine_runtime_root() -> Path:
env = _env_path("BOREALIS_ENGINE_ROOT") or _env_path("BOREALIS_ENGINE_RUNTIME")
if env:
env.mkdir(parents=True, exist_ok=True)
return env
root = project_root_path() / "Engine"
root.mkdir(parents=True, exist_ok=True)
return root
def _token_root() -> Path:
env = _env_path(_TOKEN_ENV_ROOT)
if env:
env.mkdir(parents=True, exist_ok=True)
return env
root = _engine_runtime_root() / "Data" / "Auth_Tokens"
root.mkdir(parents=True, exist_ok=True)
return root
def _legacy_key_paths() -> Dict[str, Path]:
project_root = project_root_path()
server_root = _env_path(_LEGACY_SERVER_ROOT_ENV) or (project_root / "Server" / "Borealis")
candidates = {
"auth_keys": server_root / "auth_keys" / _KEY_FILENAME,
"keys": server_root / "keys" / _KEY_FILENAME,
}
return candidates
def _tighten_permissions(path: Path) -> None:
try:
if os.name != "nt":
path.chmod(0o600)
except Exception:
pass
_KEY_DIR = _token_root()
_KEY_FILE = _KEY_DIR / _KEY_FILENAME
class JWTService:
def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str):
self._private_key = private_key
self._public_key = private_key.public_key()
self._key_id = key_id
@property
def key_id(self) -> str:
return self._key_id
def issue_access_token(
self,
guid: str,
ssl_key_fingerprint: str,
token_version: int,
expires_in: int = 900,
extra_claims: Optional[Dict[str, Any]] = None,
) -> str:
now = int(time.time())
payload: Dict[str, Any] = {
"sub": f"device:{guid}",
"guid": guid,
"ssl_key_fingerprint": ssl_key_fingerprint,
"token_version": int(token_version),
"iat": now,
"nbf": now,
"exp": now + int(expires_in),
}
if extra_claims:
payload.update(extra_claims)
token = jwt.encode(
payload,
self._private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
),
algorithm="EdDSA",
headers={"kid": self._key_id},
)
return token
def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]:
options = {"require": ["exp", "iat", "sub"]}
public_pem = self._public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return jwt.decode(
token,
public_pem,
algorithms=["EdDSA"],
audience=audience,
options=options,
)
def public_jwk(self) -> Dict[str, Any]:
public_bytes = self._public_key.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii")
return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x}
def load_service() -> JWTService:
private_key = _load_or_create_private_key()
public_bytes = private_key.public_key().public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
key_id = hashlib.sha256(public_bytes).hexdigest()[:16]
return JWTService(private_key, key_id)
def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey:
_KEY_DIR.mkdir(parents=True, exist_ok=True)
_migrate_legacy_key_if_present()
if _KEY_FILE.exists():
with _KEY_FILE.open("rb") as fh:
return serialization.load_pem_private_key(fh.read(), password=None)
private_key = ed25519.Ed25519PrivateKey.generate()
pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
with _KEY_FILE.open("wb") as fh:
fh.write(pem)
_tighten_permissions(_KEY_FILE)
return private_key
def _migrate_legacy_key_if_present() -> None:
if _KEY_FILE.exists():
return
legacy_paths = _legacy_key_paths()
for legacy_file in legacy_paths.values():
if not legacy_file.exists():
continue
try:
legacy_bytes = legacy_file.read_bytes()
except Exception:
continue
try:
_KEY_FILE.write_bytes(legacy_bytes)
_tighten_permissions(_KEY_FILE)
except Exception:
continue
try:
legacy_file.unlink()
except Exception:
pass
break
__all__ = ["JWTService", "load_service"]

View File

@@ -0,0 +1,51 @@
# ======================================================
# Data\Engine\auth\rate_limit.py
# Description: Sliding-window rate limiter for Engine authentication endpoints.
#
# API Endpoints (if applicable): None
# ======================================================
"""Engine-native in-memory rate limiter suitable for single-process development."""
from __future__ import annotations
import time
from collections import deque
from dataclasses import dataclass
from threading import Lock
from typing import Deque, Dict
@dataclass
class RateLimitDecision:
allowed: bool
retry_after: float
class SlidingWindowRateLimiter:
"""Simple sliding-window limiter to guard authentication endpoints."""
def __init__(self) -> None:
self._buckets: Dict[str, Deque[float]] = {}
self._lock = Lock()
def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
now = time.monotonic()
with self._lock:
bucket = self._buckets.get(key)
if bucket is None:
bucket = deque()
self._buckets[key] = bucket
while bucket and now - bucket[0] > window_seconds:
bucket.popleft()
if len(bucket) >= limit:
retry_after = max(0.0, window_seconds - (now - bucket[0]))
return RateLimitDecision(False, retry_after)
bucket.append(now)
return RateLimitDecision(True, 0.0)
__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"]