Mirror of https://github.com/bunny-lab-io/Borealis.git (synced 2025-12-14 22:35:47 -07:00)
ENGINE: Migrated Enrollment Logic
@@ -38,6 +38,8 @@ Lastly, everytime that you complete a stage, you will create a pull request name
- [ ] Add migration switch in the legacy server for WebUI delegation.
- [x] Extend tests to cover critical WebUI routes.
- [ ] Port device API endpoints into Engine services (device + admin coverage in progress).
- [x] Move authentication/token stack onto Engine services without legacy fallbacks.
- [x] Port enrollment request/poll flows to Engine services and drop legacy imports.
- [ ] **Stage 7 — Plan WebSocket migration**
- [ ] Extract Socket.IO handlers into Data/Engine/services/WebSocket.
- [ ] Provide register_realtime hook for the Engine factory.
@@ -46,4 +48,4 @@ Lastly, everytime that you complete a stage, you will create a pull request name

## Current Status

- **Stage:** Stage 6 — Plan WebUI migration
- **Active Task:** Migrating device endpoints into the Engine API (legacy bridge removed).
- **Active Task:** Continue Stage 6 device/admin API migration (focus on remaining device and admin endpoints now that auth, token, and enrollment paths are Engine-native).

@@ -17,7 +17,7 @@ from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from flask.testing import FlaskClient

from Modules.crypto import keys as crypto_keys
from Data.Engine.crypto import keys as crypto_keys

from .conftest import EngineTestHarness

27  Data/Engine/auth/__init__.py  Normal file
@@ -0,0 +1,27 @@
# ======================================================
# Data\Engine\auth\__init__.py
# Description: Engine-native authentication utilities and helpers decoupled from the legacy server modules.
#
# API Endpoints (if applicable): None
# ======================================================

"""Authentication utility package for the Borealis Engine."""

from .jwt_service import JWTService, load_service
from .dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
from .rate_limit import SlidingWindowRateLimiter, RateLimitDecision
from .device_auth import DeviceAuthManager, DeviceAuthError, DeviceAuthContext, require_device_auth

__all__ = [
    "JWTService",
    "load_service",
    "DPoPValidator",
    "DPoPVerificationError",
    "DPoPReplayError",
    "SlidingWindowRateLimiter",
    "RateLimitDecision",
    "DeviceAuthManager",
    "DeviceAuthError",
    "DeviceAuthContext",
    "require_device_auth",
]
310  Data/Engine/auth/device_auth.py  Normal file
@@ -0,0 +1,310 @@
# ======================================================
# Data\Engine\auth\device_auth.py
# Description: Engine-native device authentication manager and decorators.
#
# API Endpoints (if applicable): None
# ======================================================

"""Device authentication helpers for the Borealis Engine runtime."""

from __future__ import annotations

import functools
import sqlite3
import time
from contextlib import closing
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional

import jwt
from flask import g, jsonify, request

from .dpop import DPoPReplayError, DPoPValidator, DPoPVerificationError
from .guid_utils import normalize_guid
from .rate_limit import SlidingWindowRateLimiter

AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"


def _canonical_context(value: Optional[str]) -> Optional[str]:
    if not value:
        return None
    cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
    if not cleaned:
        return None
    return cleaned.upper()


@dataclass
class DeviceAuthContext:
    guid: str
    ssl_key_fingerprint: str
    token_version: int
    access_token: str
    claims: Dict[str, Any]
    dpop_jkt: Optional[str]
    status: str
    service_mode: Optional[str]


class DeviceAuthError(Exception):
    status_code = 401
    error_code = "unauthorized"

    def __init__(
        self,
        message: str = "unauthorized",
        *,
        status_code: Optional[int] = None,
        retry_after: Optional[float] = None,
    ):
        super().__init__(message)
        if status_code is not None:
            self.status_code = status_code
        self.message = message
        self.retry_after = retry_after


class DeviceAuthManager:
    def __init__(
        self,
        *,
        db_conn_factory: Callable[[], Any],
        jwt_service,
        dpop_validator: Optional[DPoPValidator],
        log: Callable[[str, str, Optional[str]], None],
        rate_limiter: Optional[SlidingWindowRateLimiter] = None,
    ) -> None:
        self._db_conn_factory = db_conn_factory
        self._jwt_service = jwt_service
        self._dpop_validator = dpop_validator
        self._log = log
        self._rate_limiter = rate_limiter

    def authenticate(self) -> DeviceAuthContext:
        auth_header = request.headers.get("Authorization", "")
        if not auth_header.startswith("Bearer "):
            raise DeviceAuthError("missing_authorization")
        token = auth_header[len("Bearer ") :].strip()
        if not token:
            raise DeviceAuthError("missing_authorization")

        try:
            claims = self._jwt_service.decode(token)
        except jwt.ExpiredSignatureError:
            raise DeviceAuthError("token_expired")
        except Exception:
            raise DeviceAuthError("invalid_token")

        raw_guid = str(claims.get("guid") or "").strip()
        guid = normalize_guid(raw_guid)
        fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip()
        token_version = int(claims.get("token_version") or 0)
        if not guid or not fingerprint or token_version <= 0:
            raise DeviceAuthError("invalid_claims")

        if self._rate_limiter:
            decision = self._rate_limiter.check(f"fp:{fingerprint}", 60, 60.0)
            if not decision.allowed:
                raise DeviceAuthError(
                    "rate_limited",
                    status_code=429,
                    retry_after=decision.retry_after,
                )

        context_label = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))

        with closing(self._db_conn_factory()) as conn:
            cur = conn.cursor()
            cur.execute(
                """
                SELECT guid, ssl_key_fingerprint, token_version, status
                FROM devices
                WHERE UPPER(guid) = ?
                """,
                (guid,),
            )
            rows = cur.fetchall()
            row = None
            for candidate in rows or []:
                candidate_guid = normalize_guid(candidate[0])
                if candidate_guid == guid:
                    row = candidate
                    break
            if row is None and rows:
                row = rows[0]

            if row is None:
                row = self._recover_device_record(conn, guid, fingerprint, token_version, context_label)

        if row is None:
            raise DeviceAuthError("device_not_found", status_code=404)

        stored_guid, stored_fingerprint, stored_version, status = row
        stored_guid = normalize_guid(stored_guid)
        if stored_guid != guid:
            raise DeviceAuthError("device_mismatch", status_code=401)
        if (stored_fingerprint or "").lower().strip() != fingerprint:
            raise DeviceAuthError("fingerprint_mismatch", status_code=403)
        if int(stored_version or 0) != token_version:
            raise DeviceAuthError("token_version_mismatch", status_code=403)

        status_norm = (status or "active").strip().lower()
        if status_norm in {"revoked", "decommissioned"}:
            raise DeviceAuthError("device_revoked", status_code=403)

        dpop_proof = request.headers.get("DPoP")
        jkt = None
        if dpop_proof and self._dpop_validator:
            try:
                jkt = self._dpop_validator.verify(request.method, request.url, dpop_proof, access_token=token)
            except DPoPReplayError:
                raise DeviceAuthError("dpop_replayed", status_code=400)
            except DPoPVerificationError:
                raise DeviceAuthError("dpop_invalid", status_code=400)

        ctx = DeviceAuthContext(
            guid=guid,
            ssl_key_fingerprint=fingerprint,
            token_version=token_version,
            access_token=token,
            claims=claims,
            dpop_jkt=jkt,
            status=status_norm,
            service_mode=context_label,
        )
        return ctx

    def _recover_device_record(
        self,
        conn: sqlite3.Connection,
        guid: str,
        fingerprint: str,
        token_version: int,
        context_label: Optional[str],
    ) -> Optional[tuple]:
        """Attempt to recreate a missing device row for an authenticated token."""

        guid = normalize_guid(guid)
        fingerprint = (fingerprint or "").strip()
        if not guid or not fingerprint:
            return None

        cur = conn.cursor()
        now_ts = int(time.time())
        try:
            now_iso = datetime.now(tz=timezone.utc).isoformat()
        except Exception:
            now_iso = datetime.utcnow().isoformat()

        base_hostname = f"RECOVERED-{guid[:12].upper()}" if guid else "RECOVERED"

        for attempt in range(6):
            hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
            try:
                cur.execute(
                    """
                    INSERT INTO devices (
                        guid,
                        hostname,
                        created_at,
                        last_seen,
                        ssl_key_fingerprint,
                        token_version,
                        status,
                        key_added_at
                    )
                    VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
                    """,
                    (
                        guid,
                        hostname,
                        now_ts,
                        now_ts,
                        fingerprint,
                        max(token_version or 1, 1),
                        now_iso,
                    ),
                )
            except sqlite3.IntegrityError as exc:
                message = str(exc).lower()
                if "hostname" in message and "unique" in message:
                    continue
                self._log(
                    "server",
                    f"device auth failed to recover guid={guid} due to integrity error: {exc}",
                    context_label,
                )
                conn.rollback()
                return None
            except Exception as exc:
                self._log(
                    "server",
                    f"device auth unexpected error recovering guid={guid}: {exc}",
                    context_label,
                )
                conn.rollback()
                return None
            else:
                conn.commit()
                break
        else:
            self._log(
                "server",
                f"device auth could not recover guid={guid}; hostname collisions persisted",
                context_label,
            )
            conn.rollback()
            return None

        cur.execute(
            """
            SELECT guid, ssl_key_fingerprint, token_version, status
            FROM devices
            WHERE guid = ?
            """,
            (guid,),
        )
        row = cur.fetchone()
        if not row:
            self._log(
                "server",
                f"device auth recovery for guid={guid} committed but row still missing",
                context_label,
            )
        return row


def require_device_auth(manager: DeviceAuthManager):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                ctx = manager.authenticate()
            except DeviceAuthError as exc:
                response = jsonify({"error": exc.message})
                response.status_code = exc.status_code
                retry_after = getattr(exc, "retry_after", None)
                if retry_after:
                    try:
                        response.headers["Retry-After"] = str(max(1, int(retry_after)))
                    except Exception:
                        response.headers["Retry-After"] = "1"
                return response

            g.device_auth = ctx
            return func(*args, **kwargs)

        return wrapper

    return decorator


__all__ = [
    "AGENT_CONTEXT_HEADER",
    "DeviceAuthContext",
    "DeviceAuthError",
    "DeviceAuthManager",
    "require_device_auth",
]
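For reference, a minimal sketch of how a route can consume this manager. The wiring below (the sqlite path, the logging lambda, the route name) is assumed for illustration; in the Engine the equivalent objects come from the API factory's adapters rather than being built inline:

```python
import sqlite3

from flask import Flask, g, jsonify

from Data.Engine.auth import (
    DeviceAuthManager,
    DPoPValidator,
    SlidingWindowRateLimiter,
    load_service,
    require_device_auth,
)

app = Flask(__name__)

manager = DeviceAuthManager(
    db_conn_factory=lambda: sqlite3.connect("borealis.db"),  # placeholder path; the Engine supplies its own factory
    jwt_service=load_service(),
    dpop_validator=DPoPValidator(),
    log=lambda scope, message, context=None: app.logger.info("[%s] %s", scope, message),
    rate_limiter=SlidingWindowRateLimiter(),
)


@app.route("/api/agent/ping")  # hypothetical endpoint, not part of this commit
@require_device_auth(manager)
def agent_ping():
    ctx = g.device_auth  # DeviceAuthContext stored by the decorator on success
    return jsonify({"guid": ctx.guid, "status": ctx.status, "dpop_bound": bool(ctx.dpop_jkt)})
```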
118  Data/Engine/auth/dpop.py  Normal file
@@ -0,0 +1,118 @@
# ======================================================
# Data\Engine\auth\dpop.py
# Description: Engine-side DPoP proof validation helpers with replay protection.
#
# API Endpoints (if applicable): None
# ======================================================

"""DPoP proof verification helpers for the Engine runtime."""

from __future__ import annotations

import hashlib
import time
from threading import Lock
from typing import Dict, Optional

import jwt

_DPOP_MAX_SKEW = 300.0  # seconds


class DPoPVerificationError(Exception):
    """Raised when DPoP verification fails for structural reasons."""


class DPoPReplayError(DPoPVerificationError):
    """Raised when a DPoP proof replay is detected."""


class DPoPValidator:
    """Validate DPoP proofs and track observed JTIs to prevent replay attacks."""

    def __init__(self) -> None:
        self._observed_jti: Dict[str, float] = {}
        self._lock = Lock()

    def verify(
        self,
        method: str,
        htu: str,
        proof: str,
        access_token: Optional[str] = None,
    ) -> str:
        """
        Verify the presented DPoP proof and return the JWK thumbprint.
        """

        if not proof:
            raise DPoPVerificationError("DPoP proof missing")

        try:
            header = jwt.get_unverified_header(proof)
        except Exception as exc:
            raise DPoPVerificationError("invalid DPoP header") from exc

        jwk = header.get("jwk")
        alg = header.get("alg")
        if not jwk or not isinstance(jwk, dict):
            raise DPoPVerificationError("missing jwk in DPoP header")
        if alg not in ("EdDSA", "ES256", "ES384", "ES512"):
            raise DPoPVerificationError(f"unsupported DPoP alg {alg}")

        try:
            key = jwt.PyJWK(jwk)
            public_key = key.key
        except Exception as exc:
            raise DPoPVerificationError("invalid jwk in DPoP header") from exc

        try:
            claims = jwt.decode(
                proof,
                public_key,
                algorithms=[alg],
                options={"require": ["htm", "htu", "jti", "iat"]},
            )
        except Exception as exc:
            raise DPoPVerificationError("invalid DPoP signature") from exc

        htm = claims.get("htm")
        proof_htu = claims.get("htu")
        jti = claims.get("jti")
        iat = claims.get("iat")
        ath = claims.get("ath")

        if not isinstance(htm, str) or htm.lower() != method.lower():
            raise DPoPVerificationError("DPoP htm mismatch")
        if not isinstance(proof_htu, str) or proof_htu != htu:
            raise DPoPVerificationError("DPoP htu mismatch")
        if not isinstance(jti, str):
            raise DPoPVerificationError("DPoP jti missing")
        if not isinstance(iat, (int, float)):
            raise DPoPVerificationError("DPoP iat missing")

        now = time.time()
        if abs(now - float(iat)) > _DPOP_MAX_SKEW:
            raise DPoPVerificationError("DPoP proof outside allowed skew")

        if ath and access_token:
            expected_ath = jwt.utils.base64url_encode(
                hashlib.sha256(access_token.encode("utf-8")).digest()
            ).decode("ascii")
            if expected_ath != ath:
                raise DPoPVerificationError("DPoP ath mismatch")

        with self._lock:
            expiry = self._observed_jti.get(jti)
            if expiry and expiry > now:
                raise DPoPReplayError("DPoP proof replay detected")
            self._observed_jti[jti] = now + _DPOP_MAX_SKEW
            stale = [key for key, exp in self._observed_jti.items() if exp <= now]
            for key in stale:
                self._observed_jti.pop(key, None)

        thumbprint = jwt.PyJWK(jwk).thumbprint()
        return thumbprint.decode("ascii")


__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"]
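The agent-side construction of a proof is not part of this diff, so the client half below is an assumption; it simply builds an EdDSA-signed proof with the claims verify() requires, mirroring the validator's own PyJWT usage for the ath and thumbprint handling:

```python
import hashlib
import time
import uuid

import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519

from Data.Engine.auth.dpop import DPoPReplayError, DPoPValidator

# Client side: Ed25519 key plus a public JWK embedded in the proof header.
sk = ed25519.Ed25519PrivateKey.generate()
raw_pub = sk.public_key().public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
jwk = {"kty": "OKP", "crv": "Ed25519", "x": jwt.utils.base64url_encode(raw_pub).decode("ascii")}

access_token = "example.access.token"  # made-up value; normally the bearer token being bound
htu = "https://borealis.local/api/agent/ping"  # made-up URL for the sketch
proof = jwt.encode(
    {
        "htm": "GET",
        "htu": htu,
        "jti": str(uuid.uuid4()),
        "iat": int(time.time()),
        "ath": jwt.utils.base64url_encode(hashlib.sha256(access_token.encode("utf-8")).digest()).decode("ascii"),
    },
    sk,
    algorithm="EdDSA",
    headers={"jwk": jwk},
)

# Server side: first use returns the key thumbprint, replaying the same jti raises DPoPReplayError.
validator = DPoPValidator()
print(validator.verify("GET", htu, proof, access_token=access_token))
try:
    validator.verify("GET", htu, proof, access_token=access_token)
except DPoPReplayError:
    print("replay rejected")
```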
39  Data/Engine/auth/guid_utils.py  Normal file
@@ -0,0 +1,39 @@
# ======================================================
# Data\Engine\auth\guid_utils.py
# Description: GUID normalisation helpers used by Engine authentication flows.
#
# API Endpoints (if applicable): None
# ======================================================

"""GUID normalisation helpers for Engine-managed authentication."""

from __future__ import annotations

import string
import uuid
from typing import Optional


def normalize_guid(value: Optional[str]) -> str:
    """
    Canonicalise GUID strings so Engine services treat different casings uniformly.
    """

    candidate = (value or "").strip()
    if not candidate:
        return ""
    candidate = candidate.strip("{}")
    try:
        return str(uuid.UUID(candidate)).upper()
    except Exception:
        cleaned = "".join(ch for ch in candidate if ch in string.hexdigits or ch == "-")
        cleaned = cleaned.strip("-")
        if cleaned:
            try:
                return str(uuid.UUID(cleaned)).upper()
            except Exception:
                pass
        return candidate.upper()


__all__ = ["normalize_guid"]
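Illustrative inputs and the canonical form they collapse to (the GUID values are arbitrary examples):

```python
from Data.Engine.auth.guid_utils import normalize_guid

# Braces, casing, and missing dashes all normalise to one upper-case canonical form.
print(normalize_guid("{2f5c0e1a-9b7d-4c3e-8a21-0d9e6f4b1c77}"))  # 2F5C0E1A-9B7D-4C3E-8A21-0D9E6F4B1C77
print(normalize_guid("2F5C0E1A9B7D4C3E8A210D9E6F4B1C77"))        # same canonical value, dashes restored
print(normalize_guid(None))                                       # "" for missing input
```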
206  Data/Engine/auth/jwt_service.py  Normal file
@@ -0,0 +1,206 @@
# ======================================================
# Data\Engine\auth\jwt_service.py
# Description: Engine-native JWT access-token helpers with signing key storage under Engine/Data/Auth_Tokens.
#
# API Endpoints (if applicable): None
# ======================================================

"""JWT access-token helpers backed by an Engine-managed Ed25519 key."""

from __future__ import annotations

import hashlib
import os
import time
from pathlib import Path
from typing import Any, Dict, Optional

import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519

from ..security.certificates import project_root_path

_TOKEN_ENV_ROOT = "BOREALIS_ENGINE_AUTH_TOKEN_ROOT"
_LEGACY_SERVER_ROOT_ENV = "BOREALIS_SERVER_ROOT"
_KEY_FILENAME = "borealis-jwt-ed25519.key"


def _env_path(name: str) -> Optional[Path]:
    value = os.environ.get(name)
    if not value:
        return None
    try:
        return Path(value).expanduser().resolve()
    except Exception:
        try:
            return Path(value).expanduser()
        except Exception:
            return Path(value)


def _engine_runtime_root() -> Path:
    env = _env_path("BOREALIS_ENGINE_ROOT") or _env_path("BOREALIS_ENGINE_RUNTIME")
    if env:
        env.mkdir(parents=True, exist_ok=True)
        return env
    root = project_root_path() / "Engine"
    root.mkdir(parents=True, exist_ok=True)
    return root


def _token_root() -> Path:
    env = _env_path(_TOKEN_ENV_ROOT)
    if env:
        env.mkdir(parents=True, exist_ok=True)
        return env
    root = _engine_runtime_root() / "Data" / "Auth_Tokens"
    root.mkdir(parents=True, exist_ok=True)
    return root


def _legacy_key_paths() -> Dict[str, Path]:
    project_root = project_root_path()
    server_root = _env_path(_LEGACY_SERVER_ROOT_ENV) or (project_root / "Server" / "Borealis")
    candidates = {
        "auth_keys": server_root / "auth_keys" / _KEY_FILENAME,
        "keys": server_root / "keys" / _KEY_FILENAME,
    }
    return candidates


def _tighten_permissions(path: Path) -> None:
    try:
        if os.name != "nt":
            path.chmod(0o600)
    except Exception:
        pass


_KEY_DIR = _token_root()
_KEY_FILE = _KEY_DIR / _KEY_FILENAME


class JWTService:
    def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str):
        self._private_key = private_key
        self._public_key = private_key.public_key()
        self._key_id = key_id

    @property
    def key_id(self) -> str:
        return self._key_id

    def issue_access_token(
        self,
        guid: str,
        ssl_key_fingerprint: str,
        token_version: int,
        expires_in: int = 900,
        extra_claims: Optional[Dict[str, Any]] = None,
    ) -> str:
        now = int(time.time())
        payload: Dict[str, Any] = {
            "sub": f"device:{guid}",
            "guid": guid,
            "ssl_key_fingerprint": ssl_key_fingerprint,
            "token_version": int(token_version),
            "iat": now,
            "nbf": now,
            "exp": now + int(expires_in),
        }
        if extra_claims:
            payload.update(extra_claims)

        token = jwt.encode(
            payload,
            self._private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            ),
            algorithm="EdDSA",
            headers={"kid": self._key_id},
        )
        return token

    def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]:
        options = {"require": ["exp", "iat", "sub"]}
        public_pem = self._public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        return jwt.decode(
            token,
            public_pem,
            algorithms=["EdDSA"],
            audience=audience,
            options=options,
        )

    def public_jwk(self) -> Dict[str, Any]:
        public_bytes = self._public_key.public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )
        jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii")
        return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x}


def load_service() -> JWTService:
    private_key = _load_or_create_private_key()
    public_bytes = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    key_id = hashlib.sha256(public_bytes).hexdigest()[:16]
    return JWTService(private_key, key_id)


def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey:
    _KEY_DIR.mkdir(parents=True, exist_ok=True)
    _migrate_legacy_key_if_present()

    if _KEY_FILE.exists():
        with _KEY_FILE.open("rb") as fh:
            return serialization.load_pem_private_key(fh.read(), password=None)

    private_key = ed25519.Ed25519PrivateKey.generate()
    pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    with _KEY_FILE.open("wb") as fh:
        fh.write(pem)
    _tighten_permissions(_KEY_FILE)
    return private_key


def _migrate_legacy_key_if_present() -> None:
    if _KEY_FILE.exists():
        return

    legacy_paths = _legacy_key_paths()
    for legacy_file in legacy_paths.values():
        if not legacy_file.exists():
            continue
        try:
            legacy_bytes = legacy_file.read_bytes()
        except Exception:
            continue

        try:
            _KEY_FILE.write_bytes(legacy_bytes)
            _tighten_permissions(_KEY_FILE)
        except Exception:
            continue

        try:
            legacy_file.unlink()
        except Exception:
            pass
        break


__all__ = ["JWTService", "load_service"]
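A minimal issue/decode round trip. Note that importing the module resolves the signing-key directory (Engine/Data/Auth_Tokens, or BOREALIS_ENGINE_AUTH_TOKEN_ROOT if set) and load_service() creates the key on first use; the GUID and fingerprint below are made-up values:

```python
from Data.Engine.auth.jwt_service import load_service

svc = load_service()  # loads or generates the Ed25519 signing key and derives a kid from its SPKI hash

token = svc.issue_access_token(
    guid="2F5C0E1A-9B7D-4C3E-8A21-0D9E6F4B1C77",
    ssl_key_fingerprint="ab" * 32,  # placeholder for the device's SHA-256 SPKI fingerprint
    token_version=1,
    expires_in=900,
)

claims = svc.decode(token)
assert claims["sub"] == "device:2F5C0E1A-9B7D-4C3E-8A21-0D9E6F4B1C77"
print(svc.public_jwk())  # OKP/Ed25519 JWK a verifier could use
```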
51  Data/Engine/auth/rate_limit.py  Normal file
@@ -0,0 +1,51 @@
# ======================================================
# Data\Engine\auth\rate_limit.py
# Description: Sliding-window rate limiter for Engine authentication endpoints.
#
# API Endpoints (if applicable): None
# ======================================================

"""Engine-native in-memory rate limiter suitable for single-process development."""

from __future__ import annotations

import time
from collections import deque
from dataclasses import dataclass
from threading import Lock
from typing import Deque, Dict


@dataclass
class RateLimitDecision:
    allowed: bool
    retry_after: float


class SlidingWindowRateLimiter:
    """Simple sliding-window limiter to guard authentication endpoints."""

    def __init__(self) -> None:
        self._buckets: Dict[str, Deque[float]] = {}
        self._lock = Lock()

    def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
        now = time.monotonic()
        with self._lock:
            bucket = self._buckets.get(key)
            if bucket is None:
                bucket = deque()
                self._buckets[key] = bucket

            while bucket and now - bucket[0] > window_seconds:
                bucket.popleft()

            if len(bucket) >= limit:
                retry_after = max(0.0, window_seconds - (now - bucket[0]))
                return RateLimitDecision(False, retry_after)

            bucket.append(now)
            return RateLimitDecision(True, 0.0)


__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"]
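Usage is a single check() call per key; a compact sketch with arbitrarily chosen limits:

```python
from Data.Engine.auth.rate_limit import SlidingWindowRateLimiter

limiter = SlidingWindowRateLimiter()

# Allow at most 3 hits per 60-second sliding window for this key.
for attempt in range(5):
    decision = limiter.check("fp:abc123", 3, 60.0)
    if decision.allowed:
        print(f"attempt {attempt}: allowed")
    else:
        print(f"attempt {attempt}: rate limited, retry after {decision.retry_after:.1f}s")
```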
30  Data/Engine/crypto/__init__.py  Normal file
@@ -0,0 +1,30 @@
# ======================================================
# Data\Engine\crypto\__init__.py
# Description: Engine cryptographic helpers and key utilities.
#
# API Endpoints (if applicable): None
# ======================================================

"""Cryptographic helper utilities for the Borealis Engine runtime."""

from .keys import (
    generate_ed25519_keypair,
    normalize_base64,
    spki_der_from_base64,
    base64_from_spki_der,
    fingerprint_from_spki_der,
    fingerprint_from_base64_spki,
    private_key_to_pem,
    public_key_to_pem,
)

__all__ = [
    "generate_ed25519_keypair",
    "normalize_base64",
    "spki_der_from_base64",
    "base64_from_spki_der",
    "fingerprint_from_spki_der",
    "fingerprint_from_base64_spki",
    "private_key_to_pem",
    "public_key_to_pem",
]
88  Data/Engine/crypto/keys.py  Normal file
@@ -0,0 +1,88 @@
# ======================================================
# Data\Engine\crypto\keys.py
# Description: Engine-native Ed25519 key helpers and fingerprint utilities.
#
# API Endpoints (if applicable): None
# ======================================================

"""Utility helpers for working with Ed25519 keys and fingerprints."""

from __future__ import annotations

import base64
import hashlib
import re
from typing import Tuple

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from cryptography.hazmat.primitives.serialization import load_der_public_key


def generate_ed25519_keypair() -> Tuple[ed25519.Ed25519PrivateKey, bytes]:
    """
    Generate a new Ed25519 keypair.

    Returns the private key object and the public key encoded as SubjectPublicKeyInfo DER bytes.
    """

    private_key = ed25519.Ed25519PrivateKey.generate()
    public_key = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return private_key, public_key


def normalize_base64(data: str) -> str:
    """
    Collapse whitespace and normalise URL-safe encodings so we can reliably decode.
    """

    cleaned = re.sub(r"\s+", "", data or "")
    return cleaned.replace("-", "+").replace("_", "/")


def spki_der_from_base64(spki_b64: str) -> bytes:
    return base64.b64decode(normalize_base64(spki_b64), validate=True)


def base64_from_spki_der(spki_der: bytes) -> str:
    return base64.b64encode(spki_der).decode("ascii")


def fingerprint_from_spki_der(spki_der: bytes) -> str:
    digest = hashlib.sha256(spki_der).hexdigest()
    return digest.lower()


def fingerprint_from_base64_spki(spki_b64: str) -> str:
    return fingerprint_from_spki_der(spki_der_from_base64(spki_b64))


def private_key_to_pem(private_key: ed25519.Ed25519PrivateKey) -> bytes:
    return private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )


def public_key_to_pem(public_spki_der: bytes) -> bytes:
    public_key = load_der_public_key(public_spki_der)
    return public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )


__all__ = [
    "generate_ed25519_keypair",
    "normalize_base64",
    "spki_der_from_base64",
    "base64_from_spki_der",
    "fingerprint_from_spki_der",
    "fingerprint_from_base64_spki",
    "private_key_to_pem",
    "public_key_to_pem",
]
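A short sketch chaining the helpers the way the enrollment flow uses them: the agent generates a keypair and ships the base64 SPKI, the server fingerprints it (lowercase SHA-256 of the DER):

```python
from Data.Engine.crypto import (
    base64_from_spki_der,
    fingerprint_from_base64_spki,
    generate_ed25519_keypair,
    private_key_to_pem,
)

# Agent side: keypair plus the transportable base64-encoded SubjectPublicKeyInfo.
private_key, spki_der = generate_ed25519_keypair()
agent_pubkey_b64 = base64_from_spki_der(spki_der)
pem = private_key_to_pem(private_key)  # what an agent would persist locally

# Server side: the value stored as ssl_key_fingerprint.
fingerprint = fingerprint_from_base64_spki(agent_pubkey_b64)
print(len(fingerprint), fingerprint[:12])  # 64 hex characters
```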
12  Data/Engine/enrollment/__init__.py  Normal file
@@ -0,0 +1,12 @@
# ======================================================
# Data\Engine\enrollment\__init__.py
# Description: Enrollment utilities for Engine-managed device onboarding.
#
# API Endpoints (if applicable): None
# ======================================================

"""Enrollment helper utilities for the Borealis Engine runtime."""

from .nonce_store import NonceCache

__all__ = ["NonceCache"]
42  Data/Engine/enrollment/nonce_store.py  Normal file
@@ -0,0 +1,42 @@
# ======================================================
# Data\Engine\enrollment\nonce_store.py
# Description: Short-lived nonce cache preventing replay during Engine enrollment flows.
#
# API Endpoints (if applicable): None
# ======================================================

"""Short-lived nonce cache to defend against enrollment replay attacks."""

from __future__ import annotations

import time
from threading import Lock
from typing import Dict


class NonceCache:
    def __init__(self, ttl_seconds: float = 300.0) -> None:
        self._ttl = ttl_seconds
        self._entries: Dict[str, float] = {}
        self._lock = Lock()

    def consume(self, key: str) -> bool:
        """
        Attempt to consume the nonce identified by `key`.

        Returns True on first use within TTL, False if already consumed.
        """

        now = time.monotonic()
        with self._lock:
            expire_at = self._entries.get(key)
            if expire_at and expire_at > now:
                return False
            self._entries[key] = now + self._ttl
            stale = [nonce for nonce, expiry in self._entries.items() if expiry <= now]
            for nonce in stale:
                self._entries.pop(nonce, None)
            return True


__all__ = ["NonceCache"]
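Consume semantics in one small sketch: the first presentation of a nonce wins, repeats within the TTL are rejected (the composite key format is an assumption for illustration):

```python
from Data.Engine.enrollment import NonceCache

cache = NonceCache(ttl_seconds=300.0)

nonce_key = "fp:abc123:client-nonce-b64"  # made-up composite key
assert cache.consume(nonce_key) is True   # first use within the TTL is accepted
assert cache.consume(nonce_key) is False  # replay is rejected until the entry expires
```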
@@ -1,6 +1,6 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\API\__init__.py
|
||||
# Description: Registers Engine API groups and bridges to legacy modules while exposing core utility routes.
|
||||
# Description: Registers Engine API groups, wiring Engine-native authentication while delegating remaining legacy modules.
|
||||
#
|
||||
# API Endpoints (if applicable):
|
||||
# - GET /health (No Authentication) - Returns an OK status for liveness probing.
|
||||
@@ -20,15 +20,15 @@ from typing import Any, Callable, Iterable, Mapping, Optional, Sequence
|
||||
|
||||
from flask import Blueprint, Flask, jsonify
|
||||
|
||||
from Modules.auth import jwt_service as jwt_service_module
|
||||
from Modules.auth.device_auth import DeviceAuthManager
|
||||
from Modules.auth.dpop import DPoPValidator
|
||||
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||
from ...auth import jwt_service as jwt_service_module
|
||||
from ...auth.device_auth import DeviceAuthManager
|
||||
from ...auth.dpop import DPoPValidator
|
||||
from ...auth.rate_limit import SlidingWindowRateLimiter
|
||||
from ...database import initialise_engine_database
|
||||
from ...security import signing
|
||||
from Modules.enrollment import routes as enrollment_routes
|
||||
from Modules.enrollment.nonce_store import NonceCache
|
||||
from Modules.tokens import routes as token_routes
|
||||
from ...enrollment import NonceCache
|
||||
from .enrollment import routes as enrollment_routes
|
||||
from .tokens import routes as token_routes
|
||||
|
||||
from ...server import EngineContext
|
||||
from .access_management.login import register_auth
|
||||
@@ -137,7 +137,7 @@ def _make_db_conn_factory(database_path: str) -> Callable[[], sqlite3.Connection
|
||||
|
||||
|
||||
@dataclass
|
||||
class LegacyServiceAdapters:
|
||||
class EngineServiceAdapters:
|
||||
context: EngineContext
|
||||
db_conn_factory: Callable[[], sqlite3.Connection] = field(init=False)
|
||||
jwt_service: Any = field(init=False)
|
||||
@@ -180,7 +180,7 @@ class LegacyServiceAdapters:
|
||||
)
|
||||
|
||||
|
||||
def _register_tokens(app: Flask, adapters: LegacyServiceAdapters) -> None:
|
||||
def _register_tokens(app: Flask, adapters: EngineServiceAdapters) -> None:
|
||||
token_routes.register(
|
||||
app,
|
||||
db_conn_factory=adapters.db_conn_factory,
|
||||
@@ -189,7 +189,7 @@ def _register_tokens(app: Flask, adapters: LegacyServiceAdapters) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _register_enrollment(app: Flask, adapters: LegacyServiceAdapters) -> None:
|
||||
def _register_enrollment(app: Flask, adapters: EngineServiceAdapters) -> None:
|
||||
tls_bundle = adapters.context.tls_bundle_path or ""
|
||||
enrollment_routes.register(
|
||||
app,
|
||||
@@ -204,12 +204,12 @@ def _register_enrollment(app: Flask, adapters: LegacyServiceAdapters) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _register_devices(app: Flask, adapters: LegacyServiceAdapters) -> None:
|
||||
def _register_devices(app: Flask, adapters: EngineServiceAdapters) -> None:
|
||||
register_management(app, adapters)
|
||||
register_admin_endpoints(app, adapters)
|
||||
|
||||
|
||||
_GROUP_REGISTRARS: Mapping[str, Callable[[Flask, LegacyServiceAdapters], None]] = {
|
||||
_GROUP_REGISTRARS: Mapping[str, Callable[[Flask, EngineServiceAdapters], None]] = {
|
||||
"auth": register_auth,
|
||||
"tokens": _register_tokens,
|
||||
"enrollment": _register_enrollment,
|
||||
@@ -236,7 +236,7 @@ def register_api(app: Flask, context: EngineContext) -> None:
|
||||
|
||||
enabled_groups: Iterable[str] = context.api_groups or DEFAULT_API_GROUPS
|
||||
normalized = [group.strip().lower() for group in enabled_groups if group]
|
||||
adapters: Optional[LegacyServiceAdapters] = None
|
||||
adapters: Optional[EngineServiceAdapters] = None
|
||||
|
||||
for group in normalized:
|
||||
if group == "core":
|
||||
@@ -244,7 +244,7 @@ def register_api(app: Flask, context: EngineContext) -> None:
|
||||
continue
|
||||
|
||||
if adapters is None:
|
||||
adapters = LegacyServiceAdapters(context)
|
||||
adapters = EngineServiceAdapters(context)
|
||||
registrar = _GROUP_REGISTRARS.get(group)
|
||||
if registrar is None:
|
||||
context.logger.info("Engine API group '%s' is not implemented; skipping.", group)
|
||||
|
||||
@@ -35,7 +35,7 @@ except Exception: # pragma: no cover - optional dependency
|
||||
qrcode = None # type: ignore
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - typing helper
|
||||
from Data.Engine.services.API import LegacyServiceAdapters
|
||||
from Data.Engine.services.API import EngineServiceAdapters
|
||||
|
||||
|
||||
def _now_ts() -> int:
|
||||
@@ -103,7 +103,7 @@ def _user_row_to_dict(row: Sequence[Any]) -> Mapping[str, Any]:
|
||||
|
||||
|
||||
class _AuthService:
|
||||
def __init__(self, app: Flask, adapters: "LegacyServiceAdapters") -> None:
|
||||
def __init__(self, app: Flask, adapters: "EngineServiceAdapters") -> None:
|
||||
self.app = app
|
||||
self.adapters = adapters
|
||||
self.context = adapters.context
|
||||
@@ -393,7 +393,7 @@ class _AuthService:
|
||||
)
|
||||
|
||||
|
||||
def register_auth(app: Flask, adapters: "LegacyServiceAdapters") -> None:
|
||||
def register_auth(app: Flask, adapters: "EngineServiceAdapters") -> None:
|
||||
"""Register authentication endpoints for the Engine."""
|
||||
|
||||
service = _AuthService(app, adapters)
|
||||
@@ -416,3 +416,4 @@ def register_auth(app: Flask, adapters: "LegacyServiceAdapters") -> None:
|
||||
return service.me()
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Opti
|
||||
from flask import Blueprint, jsonify, request
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - typing aide
|
||||
from .. import LegacyServiceAdapters
|
||||
from .. import EngineServiceAdapters
|
||||
|
||||
|
||||
_ISLAND_DIR_MAP: Mapping[str, str] = {
|
||||
@@ -49,7 +49,7 @@ _BASE64_CLEANER = re.compile(r"\s+")
|
||||
class AssemblyManagementService:
|
||||
"""Implements assembly CRUD helpers for Engine routes."""
|
||||
|
||||
def __init__(self, adapters: "LegacyServiceAdapters") -> None:
|
||||
def __init__(self, adapters: "EngineServiceAdapters") -> None:
|
||||
self.adapters = adapters
|
||||
self.logger = adapters.context.logger or logging.getLogger(__name__)
|
||||
self.service_log = adapters.service_log
|
||||
@@ -679,7 +679,7 @@ class AssemblyManagementService:
|
||||
return obj
|
||||
|
||||
|
||||
def register_assemblies(app, adapters: "LegacyServiceAdapters") -> None:
|
||||
def register_assemblies(app, adapters: "EngineServiceAdapters") -> None:
|
||||
"""Register assembly CRUD endpoints on the Flask app."""
|
||||
|
||||
service = AssemblyManagementService(adapters)
|
||||
@@ -726,3 +726,4 @@ def register_assemblies(app, adapters: "LegacyServiceAdapters") -> None:
|
||||
return jsonify(response), status
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ except Exception: # pragma: no cover - optional dependency
|
||||
qrcode = None # type: ignore
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - typing helper
|
||||
from . import LegacyServiceAdapters
|
||||
from . import EngineServiceAdapters
|
||||
|
||||
|
||||
def _now_ts() -> int:
|
||||
@@ -103,7 +103,7 @@ def _user_row_to_dict(row: Sequence[Any]) -> Mapping[str, Any]:
|
||||
|
||||
|
||||
class _AuthService:
|
||||
def __init__(self, app: Flask, adapters: "LegacyServiceAdapters") -> None:
|
||||
def __init__(self, app: Flask, adapters: "EngineServiceAdapters") -> None:
|
||||
self.app = app
|
||||
self.adapters = adapters
|
||||
self.context = adapters.context
|
||||
@@ -398,7 +398,7 @@ class _AuthService:
|
||||
)
|
||||
|
||||
|
||||
def register_auth(app: Flask, adapters: "LegacyServiceAdapters") -> None:
|
||||
def register_auth(app: Flask, adapters: "EngineServiceAdapters") -> None:
|
||||
"""Register authentication endpoints for the Engine."""
|
||||
|
||||
service = _AuthService(app, adapters)
|
||||
@@ -422,3 +422,4 @@ def register_auth(app: Flask, adapters: "LegacyServiceAdapters") -> None:
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
adapters.context.logger.info("Engine registered API group 'auth'.")
|
||||
|
||||
|
||||
@@ -24,10 +24,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from flask import Blueprint, jsonify, request, session
|
||||
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
|
||||
|
||||
from Modules.guid_utils import normalize_guid
|
||||
from ....auth.guid_utils import normalize_guid
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - typing helper
|
||||
from .. import LegacyServiceAdapters
|
||||
from .. import EngineServiceAdapters
|
||||
|
||||
|
||||
VALID_TTL_HOURS = {1, 3, 6, 12, 24}
|
||||
@@ -49,7 +49,7 @@ def _generate_install_code() -> str:
|
||||
class AdminDeviceService:
|
||||
"""Utility wrapper for admin device APIs."""
|
||||
|
||||
def __init__(self, app, adapters: "LegacyServiceAdapters") -> None:
|
||||
def __init__(self, app, adapters: "EngineServiceAdapters") -> None:
|
||||
self.app = app
|
||||
self.adapters = adapters
|
||||
self.db_conn_factory = adapters.db_conn_factory
|
||||
@@ -477,7 +477,7 @@ class AdminDeviceService:
|
||||
return self._set_approval_status(approval_id, "denied")
|
||||
|
||||
|
||||
def register_admin_endpoints(app, adapters: "LegacyServiceAdapters") -> None:
|
||||
def register_admin_endpoints(app, adapters: "EngineServiceAdapters") -> None:
|
||||
"""Register admin enrollment + approval endpoints."""
|
||||
|
||||
service = AdminDeviceService(app, adapters)
|
||||
@@ -532,3 +532,4 @@ def register_admin_endpoints(app, adapters: "LegacyServiceAdapters") -> None:
|
||||
return jsonify(payload), status
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\API\devices\enrollment.py
|
||||
# Description: Placeholder for device enrollment API bridge (not yet implemented).
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"Placeholder for API module devices/enrollment.py."
|
||||
@@ -41,8 +41,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from flask import Blueprint, jsonify, request, session, g
|
||||
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
|
||||
|
||||
from Modules.auth.device_auth import require_device_auth
|
||||
from Modules.guid_utils import normalize_guid
|
||||
from ....auth.device_auth import require_device_auth
|
||||
from ....auth.guid_utils import normalize_guid
|
||||
|
||||
try:
|
||||
import requests # type: ignore
|
||||
@@ -57,7 +57,7 @@ except ImportError: # pragma: no cover - fallback for minimal test environments
|
||||
requests = _RequestsStub() # type: ignore
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - typing aide
|
||||
from .. import LegacyServiceAdapters
|
||||
from .. import EngineServiceAdapters
|
||||
|
||||
|
||||
def _safe_json(raw: Optional[str], default: Any) -> Any:
|
||||
@@ -340,7 +340,7 @@ def _device_upsert(
|
||||
class RepositoryHashCache:
|
||||
"""Lightweight GitHub head cache with on-disk persistence."""
|
||||
|
||||
def __init__(self, adapters: "LegacyServiceAdapters") -> None:
|
||||
def __init__(self, adapters: "EngineServiceAdapters") -> None:
|
||||
self._db_conn_factory = adapters.db_conn_factory
|
||||
self._service_log = adapters.service_log
|
||||
self._logger = adapters.context.logger
|
||||
@@ -617,7 +617,7 @@ class DeviceManagementService:
|
||||
"connection_endpoint",
|
||||
)
|
||||
|
||||
def __init__(self, app, adapters: "LegacyServiceAdapters") -> None:
|
||||
def __init__(self, app, adapters: "EngineServiceAdapters") -> None:
|
||||
self.app = app
|
||||
self.adapters = adapters
|
||||
self.db_conn_factory = adapters.db_conn_factory
|
||||
@@ -1513,7 +1513,7 @@ class DeviceManagementService:
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def register_management(app, adapters: "LegacyServiceAdapters") -> None:
|
||||
def register_management(app, adapters: "EngineServiceAdapters") -> None:
|
||||
"""Register device management endpoints onto the Flask app."""
|
||||
|
||||
service = DeviceManagementService(app, adapters)
|
||||
@@ -1679,3 +1679,4 @@ def register_management(app, adapters: "LegacyServiceAdapters") -> None:
|
||||
return jsonify(payload), status
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
12  Data/Engine/services/API/enrollment/__init__.py  Normal file
@@ -0,0 +1,12 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\API\enrollment\__init__.py
|
||||
# Description: Engine enrollment API registration helpers.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""Expose Engine-native enrollment API routes."""
|
||||
|
||||
from .routes import register
|
||||
|
||||
__all__ = ["register"]
|
||||
744  Data/Engine/services/API/enrollment/routes.py  Normal file
@@ -0,0 +1,744 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\API\enrollment\routes.py
|
||||
# Description: Engine-native device enrollment endpoints handling install codes, approvals, and token issuance.
|
||||
#
|
||||
# API Endpoints (if applicable):
|
||||
# - POST /api/agent/enroll/request (No Authentication) - Submits device enrollment requests.
|
||||
# - POST /api/agent/enroll/poll (No Authentication) - Finalises approved enrollment requests.
|
||||
# ======================================================
|
||||
|
||||
"""Device enrollment routes for the Borealis Engine runtime."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import secrets
|
||||
import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||
|
||||
|
||||
def _canonical_context(value: Optional[str]) -> Optional[str]:
|
||||
if not value:
|
||||
return None
|
||||
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
|
||||
if not cleaned:
|
||||
return None
|
||||
return cleaned.upper()
|
||||
|
||||
from flask import Blueprint, jsonify, request
|
||||
|
||||
from ....auth.rate_limit import SlidingWindowRateLimiter
|
||||
from ....crypto import keys as crypto_keys
|
||||
from ....enrollment.nonce_store import NonceCache
|
||||
from ....auth.guid_utils import normalize_guid
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
|
||||
def register(
|
||||
app,
|
||||
*,
|
||||
db_conn_factory: Callable[[], sqlite3.Connection],
|
||||
log: Callable[[str, str, Optional[str]], None],
|
||||
jwt_service,
|
||||
tls_bundle_path: str,
|
||||
ip_rate_limiter: SlidingWindowRateLimiter,
|
||||
fp_rate_limiter: SlidingWindowRateLimiter,
|
||||
nonce_cache: NonceCache,
|
||||
script_signer,
|
||||
) -> None:
|
||||
blueprint = Blueprint("enrollment", __name__)
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
|
||||
def _iso(dt: datetime) -> str:
|
||||
return dt.isoformat()
|
||||
|
||||
def _remote_addr() -> str:
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
addr = request.remote_addr or "unknown"
|
||||
return addr.strip()
|
||||
|
||||
def _signing_key_b64() -> str:
|
||||
if not script_signer:
|
||||
return ""
|
||||
try:
|
||||
return script_signer.public_base64_spki()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def _rate_limited(
|
||||
key: str,
|
||||
limiter: SlidingWindowRateLimiter,
|
||||
limit: int,
|
||||
window_s: float,
|
||||
context_hint: Optional[str],
|
||||
):
|
||||
decision = limiter.check(key, limit, window_s)
|
||||
if not decision.allowed:
|
||||
log(
|
||||
"server",
|
||||
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
|
||||
context_hint,
|
||||
)
|
||||
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
|
||||
response.status_code = 429
|
||||
response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}"
|
||||
return response
|
||||
return None
|
||||
|
||||
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
WHERE code = ?
|
||||
""",
|
||||
(code_value,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
keys = [
|
||||
"id",
|
||||
"code",
|
||||
"expires_at",
|
||||
"used_at",
|
||||
"used_by_guid",
|
||||
"max_uses",
|
||||
"use_count",
|
||||
"last_used_at",
|
||||
]
|
||||
record = dict(zip(keys, row))
|
||||
return record
|
||||
|
||||
def _install_code_valid(
|
||||
record: Dict[str, Any], fingerprint: str, cur: sqlite3.Cursor
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
if not record:
|
||||
return False, None
|
||||
expires_at = record.get("expires_at")
|
||||
if not isinstance(expires_at, str):
|
||||
return False, None
|
||||
try:
|
||||
expiry = datetime.fromisoformat(expires_at)
|
||||
except Exception:
|
||||
return False, None
|
||||
if expiry <= _now():
|
||||
return False, None
|
||||
try:
|
||||
max_uses = int(record.get("max_uses") or 1)
|
||||
except Exception:
|
||||
max_uses = 1
|
||||
if max_uses < 1:
|
||||
max_uses = 1
|
||||
try:
|
||||
use_count = int(record.get("use_count") or 0)
|
||||
except Exception:
|
||||
use_count = 0
|
||||
if use_count < max_uses:
|
||||
return True, None
|
||||
|
||||
guid = normalize_guid(record.get("used_by_guid"))
|
||||
if not guid:
|
||||
return False, None
|
||||
cur.execute(
|
||||
"SELECT ssl_key_fingerprint FROM devices WHERE UPPER(guid) = ?",
|
||||
(guid,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return False, None
|
||||
stored_fp = (row[0] or "").strip().lower()
|
||||
if not stored_fp:
|
||||
return False, None
|
||||
if stored_fp == (fingerprint or "").strip().lower():
|
||||
return True, guid
|
||||
return False, None
|
||||
|
||||
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
|
||||
guid_norm = normalize_guid(guid)
|
||||
base = (hostname or "").strip() or guid_norm
|
||||
base = base[:253]
|
||||
candidate = base
|
||||
suffix = 1
|
||||
while True:
|
||||
cur.execute(
|
||||
"SELECT guid FROM devices WHERE hostname = ?",
|
||||
(candidate,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return candidate
|
||||
existing_guid = normalize_guid(row[0])
|
||||
if existing_guid == guid_norm:
|
||||
return candidate
|
||||
candidate = f"{base}-{suffix}"
|
||||
suffix += 1
|
||||
if suffix > 50:
|
||||
return guid_norm
|
||||
|
||||
def _store_device_key(cur: sqlite3.Cursor, guid: str, fingerprint: str) -> None:
|
||||
guid_norm = normalize_guid(guid)
|
||||
added_at = _iso(_now())
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(str(uuid.uuid4()), guid_norm, fingerprint, added_at),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_keys
|
||||
SET retired_at = ?
|
||||
WHERE guid = ?
|
||||
AND ssl_key_fingerprint != ?
|
||||
AND retired_at IS NULL
|
||||
""",
|
||||
(_iso(_now()), guid_norm, fingerprint),
|
||||
)
|
||||
|
||||
def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
|
||||
guid_norm = normalize_guid(guid)
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
|
||||
FROM devices
|
||||
WHERE UPPER(guid) = ?
|
||||
""",
|
||||
(guid_norm,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
keys = [
|
||||
"guid",
|
||||
"hostname",
|
||||
"token_version",
|
||||
"status",
|
||||
"ssl_key_fingerprint",
|
||||
"key_added_at",
|
||||
]
|
||||
record = dict(zip(keys, row))
|
||||
record["guid"] = normalize_guid(record.get("guid"))
|
||||
stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower()
|
||||
new_fp = (fingerprint or "").strip().lower()
|
||||
if not stored_fp and new_fp:
|
||||
cur.execute(
|
||||
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
|
||||
(fingerprint, _iso(_now()), record["guid"]),
|
||||
)
|
||||
record["ssl_key_fingerprint"] = fingerprint
|
||||
elif new_fp and stored_fp != new_fp:
|
||||
now_iso = _iso(_now())
|
||||
try:
|
||||
current_version = int(record.get("token_version") or 1)
|
||||
except Exception:
|
||||
current_version = 1
|
||||
new_version = max(current_version + 1, 1)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE devices
|
||||
SET ssl_key_fingerprint = ?,
|
||||
key_added_at = ?,
|
||||
token_version = ?,
|
||||
status = 'active'
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(fingerprint, now_iso, new_version, record["guid"]),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = ?
|
||||
WHERE guid = ?
|
||||
AND revoked_at IS NULL
|
||||
""",
|
||||
(now_iso, record["guid"]),
|
||||
)
|
||||
record["ssl_key_fingerprint"] = fingerprint
|
||||
record["token_version"] = new_version
|
||||
record["status"] = "active"
|
||||
record["key_added_at"] = now_iso
|
||||
return record
|
||||
|
||||
resolved_hostname = _normalize_host(hostname, guid_norm, cur)
|
||||
created_at = int(time.time())
|
||||
key_added_at = _iso(_now())
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO devices (
|
||||
guid, hostname, created_at, last_seen, ssl_key_fingerprint,
|
||||
token_version, status, key_added_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
|
||||
""",
|
||||
(
|
||||
guid_norm,
|
||||
resolved_hostname,
|
||||
created_at,
|
||||
created_at,
|
||||
fingerprint,
|
||||
key_added_at,
|
||||
),
|
||||
)
|
||||
return {
|
||||
"guid": guid_norm,
|
||||
"hostname": resolved_hostname,
|
||||
"token_version": 1,
|
||||
"status": "active",
|
||||
"ssl_key_fingerprint": fingerprint,
|
||||
"key_added_at": key_added_at,
|
||||
}
|
||||
|
||||
def _hash_refresh_token(token: str) -> str:
|
||||
import hashlib
|
||||
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
def _issue_refresh_token(cur: sqlite3.Cursor, guid: str) -> Dict[str, Any]:
|
||||
token = secrets.token_urlsafe(48)
|
||||
now = _now()
|
||||
expires_at = now.replace(microsecond=0) + timedelta(days=30)
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO refresh_tokens (id, guid, token_hash, created_at, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
str(uuid.uuid4()),
|
||||
guid,
|
||||
_hash_refresh_token(token),
|
||||
_iso(now),
|
||||
_iso(expires_at),
|
||||
),
|
||||
)
|
||||
return {"token": token, "expires_at": expires_at}
|
||||
|
||||
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
|
||||
def enrollment_request():
|
||||
remote = _remote_addr()
|
||||
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
|
||||
|
||||
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0, context_hint)
|
||||
if rate_error:
|
||||
return rate_error
|
||||
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
hostname = str(payload.get("hostname") or "").strip()
|
||||
enrollment_code = str(payload.get("enrollment_code") or "").strip()
|
||||
agent_pubkey_b64 = payload.get("agent_pubkey")
|
||||
client_nonce_b64 = payload.get("client_nonce")
|
||||
|
||||
log(
|
||||
"server",
|
||||
"enrollment request received "
|
||||
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
|
||||
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
|
||||
context_hint,
|
||||
)
|
||||
|
||||
if not hostname:
|
||||
log("server", f"enrollment rejected missing_hostname ip={remote}", context_hint)
|
||||
return jsonify({"error": "hostname_required"}), 400
|
||||
if not enrollment_code:
|
||||
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "enrollment_code_required"}), 400
|
||||
if not isinstance(agent_pubkey_b64, str):
|
||||
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "agent_pubkey_required"}), 400
|
||||
if not isinstance(client_nonce_b64, str):
|
||||
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "client_nonce_required"}), 400
|
||||
|
||||
try:
|
||||
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
|
||||
except Exception:
|
||||
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||
|
||||
if len(agent_pubkey_der) < 10:
|
||||
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||
|
||||
try:
|
||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
if len(client_nonce_bytes) < 16:
|
||||
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
|
||||
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
|
||||
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0, context_hint)
|
||||
if rate_error:
|
||||
return rate_error
|
||||
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
install_code = _load_install_code(cur, enrollment_code)
|
||||
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
|
||||
if not valid_code:
|
||||
log(
|
||||
"server",
|
||||
"enrollment request invalid_code "
|
||||
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"error": "invalid_enrollment_code"}), 400
|
||||
|
||||
approval_reference: str
|
||||
record_id: str
|
||||
server_nonce_bytes = secrets.token_bytes(32)
|
||||
server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii")
|
||||
now = _iso(_now())
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, approval_reference
|
||||
FROM device_approvals
|
||||
WHERE ssl_key_fingerprint_claimed = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(fingerprint,),
|
||||
)
|
||||
existing = cur.fetchone()
|
||||
if existing:
|
||||
record_id = existing[0]
|
||||
approval_reference = existing[1]
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET hostname_claimed = ?,
|
||||
guid = ?,
|
||||
enrollment_code_id = ?,
|
||||
client_nonce = ?,
|
||||
server_nonce = ?,
|
||||
agent_pubkey_der = ?,
|
||||
updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
hostname,
|
||||
reuse_guid,
|
||||
install_code["id"],
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
now,
|
||||
record_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
record_id = str(uuid.uuid4())
|
||||
approval_reference = str(uuid.uuid4())
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO device_approvals (
|
||||
id, approval_reference, guid, hostname_claimed,
|
||||
ssl_key_fingerprint_claimed, enrollment_code_id,
|
||||
status, client_nonce, server_nonce, agent_pubkey_der,
|
||||
created_at, updated_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
record_id,
|
||||
approval_reference,
|
||||
reuse_guid,
|
||||
hostname,
|
||||
fingerprint,
|
||||
install_code["id"],
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
now,
|
||||
now,
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
response = {
|
||||
"status": "pending",
|
||||
"approval_reference": approval_reference,
|
||||
"server_nonce": server_nonce_b64,
|
||||
"poll_after_ms": 3000,
|
||||
"server_certificate": _load_tls_bundle(tls_bundle_path),
|
||||
"signing_key": _signing_key_b64(),
|
||||
}
|
||||
log(
|
||||
"server",
|
||||
f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify(response)
|
||||
|
||||
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
|
||||
def enrollment_poll():
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
approval_reference = payload.get("approval_reference")
|
||||
client_nonce_b64 = payload.get("client_nonce")
|
||||
proof_sig_b64 = payload.get("proof_sig")
|
||||
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
|
||||
|
||||
log(
|
||||
"server",
|
||||
"enrollment poll received "
|
||||
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
|
||||
f" proof_sig_len={len(proof_sig_b64 or '')}",
|
||||
context_hint,
|
||||
)
|
||||
|
||||
if not isinstance(approval_reference, str) or not approval_reference:
|
||||
log("server", "enrollment poll rejected missing_reference", context_hint)
|
||||
return jsonify({"error": "approval_reference_required"}), 400
|
||||
if not isinstance(client_nonce_b64, str):
|
||||
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "client_nonce_required"}), 400
|
||||
if not isinstance(proof_sig_b64, str):
|
||||
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "proof_sig_required"}), 400
|
||||
|
||||
try:
|
||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
|
||||
try:
|
||||
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "invalid_proof_sig"}), 400
|
||||
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, guid, hostname_claimed, ssl_key_fingerprint_claimed,
|
||||
enrollment_code_id, status, client_nonce, server_nonce,
|
||||
agent_pubkey_der, created_at, updated_at, approved_by_user_id
|
||||
FROM device_approvals
|
||||
WHERE approval_reference = ?
|
||||
""",
|
||||
(approval_reference,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
log("server", f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
|
||||
return jsonify({"status": "unknown"}), 404
|
||||
|
||||
(
|
||||
record_id,
|
||||
guid,
|
||||
hostname_claimed,
|
||||
fingerprint,
|
||||
enrollment_code_id,
|
||||
status,
|
||||
client_nonce_stored,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
created_at,
|
||||
updated_at,
|
||||
approved_by,
|
||||
) = row
|
||||
|
||||
if client_nonce_stored != client_nonce_b64:
|
||||
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "nonce_mismatch"}), 400
|
||||
|
||||
try:
|
||||
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "server_nonce_invalid"}), 400
|
||||
|
||||
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
|
||||
|
||||
try:
|
||||
public_key = serialization.load_der_public_key(agent_pubkey_der)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
|
||||
public_key = None
|
||||
|
||||
if public_key is None:
|
||||
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "agent_pubkey_invalid"}), 400
|
||||
|
||||
try:
|
||||
public_key.verify(proof_sig, message)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "invalid_proof"}), 400
|
||||
|
||||
if status == "pending":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
|
||||
f" fingerprint={fingerprint[:12]}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": "pending", "poll_after_ms": 5000})
|
||||
if status == "denied":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": "denied", "reason": "operator_denied"})
|
||||
if status == "expired":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": "expired"})
|
||||
if status == "completed":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": "approved", "detail": "finalized"})
|
||||
|
||||
if status != "approved":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll unexpected_status={status} ref={approval_reference}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": status or "unknown"}), 400
|
||||
|
||||
nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
|
||||
if not nonce_cache.consume(nonce_key):
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"error": "proof_replayed"}), 409
|
||||
|
||||
# Finalize enrollment
|
||||
effective_guid = normalize_guid(guid) if guid else normalize_guid(str(uuid.uuid4()))
|
||||
now_iso = _iso(_now())
|
||||
|
||||
device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint)
|
||||
_store_device_key(cur, effective_guid, fingerprint)
|
||||
|
||||
# Mark install code used
|
||||
if enrollment_code_id:
|
||||
cur.execute(
|
||||
"SELECT use_count, max_uses FROM enrollment_install_codes WHERE id = ?",
|
||||
(enrollment_code_id,),
|
||||
)
|
||||
usage_row = cur.fetchone()
|
||||
try:
|
||||
prior_count = int(usage_row[0]) if usage_row else 0
|
||||
except Exception:
|
||||
prior_count = 0
|
||||
try:
|
||||
allowed_uses = int(usage_row[1]) if usage_row else 1
|
||||
except Exception:
|
||||
allowed_uses = 1
|
||||
if allowed_uses < 1:
|
||||
allowed_uses = 1
|
||||
new_count = prior_count + 1
|
||||
consumed = new_count >= allowed_uses
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE enrollment_install_codes
|
||||
SET use_count = ?,
|
||||
used_by_guid = ?,
|
||||
last_used_at = ?,
|
||||
used_at = CASE WHEN ? THEN ? ELSE used_at END
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
new_count,
|
||||
effective_guid,
|
||||
now_iso,
|
||||
1 if consumed else 0,
|
||||
now_iso,
|
||||
enrollment_code_id,
|
||||
),
|
||||
)
|
||||
|
||||
# Update approval record with final state
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET guid = ?,
|
||||
status = 'completed',
|
||||
updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(effective_guid, now_iso, record_id),
|
||||
)
|
||||
|
||||
refresh_info = _issue_refresh_token(cur, effective_guid)
|
||||
access_token = jwt_service.issue_access_token(
|
||||
effective_guid,
|
||||
fingerprint,
|
||||
device_record.get("token_version") or 1,
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
log(
|
||||
"server",
|
||||
f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify(
|
||||
{
|
||||
"status": "approved",
|
||||
"guid": effective_guid,
|
||||
"access_token": access_token,
|
||||
"expires_in": 900,
|
||||
"refresh_token": refresh_info["token"],
|
||||
"token_type": "Bearer",
|
||||
"server_certificate": _load_tls_bundle(tls_bundle_path),
|
||||
"signing_key": _signing_key_b64(),
|
||||
}
|
||||
)
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||

def _load_tls_bundle(path: str) -> str:
    try:
        with open(path, "r", encoding="utf-8") as fh:
            return fh.read()
    except Exception:
        return ""


def _mask_code(code: str) -> str:
    if not code:
        return "<missing>"
    trimmed = str(code).strip()
    if len(trimmed) <= 6:
        return "***"
    return f"{trimmed[:3]}***{trimmed[-3:]}"
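
The two routes above are the whole handshake from the agent's perspective: one POST to /api/agent/enroll/request with the hostname, install code, base64 SPKI public key, and a client nonce, then repeated POSTs to /api/agent/enroll/poll carrying an Ed25519 signature over server_nonce + approval_reference + client_nonce until an operator approves the device. The sketch below illustrates that sequence from the agent side; it is not part of this commit, and the server URL, the use of the `requests` package, and the key handling are assumptions for illustration only.

# --- Illustrative agent-side enrollment sketch (not part of this commit) ---
# Assumptions: the server is reachable at SERVER, `requests` is installed, and
# the agent key is an Ed25519 keypair matching what the routes above expect.
import base64
import os
import time

import requests
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519

SERVER = "https://borealis.example.local"  # hypothetical URL for this sketch


def enroll(hostname: str, install_code: str) -> dict:
    key = ed25519.Ed25519PrivateKey.generate()
    pub_b64 = base64.b64encode(
        key.public_key().public_bytes(
            serialization.Encoding.DER,
            serialization.PublicFormat.SubjectPublicKeyInfo,
        )
    ).decode("ascii")
    client_nonce_b64 = base64.b64encode(os.urandom(32)).decode("ascii")

    # Step 1: queue an approval request and receive the server nonce.
    pending = requests.post(
        f"{SERVER}/api/agent/enroll/request",
        json={
            "hostname": hostname,
            "enrollment_code": install_code,
            "agent_pubkey": pub_b64,
            "client_nonce": client_nonce_b64,
        },
        timeout=10,
    ).json()
    reference = pending["approval_reference"]
    server_nonce = base64.b64decode(pending["server_nonce"])

    # Step 2: sign server_nonce + approval_reference + client_nonce and poll
    # until the operator approves, denies, or the request expires.
    message = server_nonce + reference.encode("utf-8") + base64.b64decode(client_nonce_b64)
    proof_sig_b64 = base64.b64encode(key.sign(message)).decode("ascii")
    while True:
        result = requests.post(
            f"{SERVER}/api/agent/enroll/poll",
            json={
                "approval_reference": reference,
                "client_nonce": client_nonce_b64,
                "proof_sig": proof_sig_b64,
            },
            timeout=10,
        ).json()
        if result.get("status") != "pending":
            # An approved payload carries the access token, refresh token,
            # server certificate bundle, and signing key.
            return result
        time.sleep(result.get("poll_after_ms", 5000) / 1000.0)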
12
Data/Engine/services/API/tokens/__init__.py
Normal file
@@ -0,0 +1,12 @@
# ======================================================
# Data\Engine\services\API\tokens\__init__.py
# Description: Token management API registration helpers for the Engine runtime.
#
# API Endpoints (if applicable): None
# ======================================================

"""Expose Engine-native token management routes."""

from .routes import register

__all__ = ["register"]
147
Data/Engine/services/API/tokens/routes.py
Normal file
@@ -0,0 +1,147 @@
# ======================================================
# Data\Engine\services\API\tokens\routes.py
# Description: Engine-native refresh token endpoints decoupled from legacy server modules.
#
# API Endpoints (if applicable):
# - POST /api/agent/token/refresh (Authenticated via refresh token) - Issues a new access token.
# ======================================================

"""Token management routes backed by the Engine authentication stack."""

from __future__ import annotations

import hashlib
import sqlite3
from datetime import datetime, timezone
from typing import Callable

from flask import Blueprint, current_app, jsonify, request

from ....auth.dpop import DPoPReplayError, DPoPValidator, DPoPVerificationError


def register(
    app,
    *,
    db_conn_factory: Callable[[], sqlite3.Connection],
    jwt_service,
    dpop_validator: DPoPValidator,
) -> None:
    blueprint = Blueprint("tokens", __name__)

    def _hash_token(token: str) -> str:
        return hashlib.sha256(token.encode("utf-8")).hexdigest()

    def _iso_now() -> str:
        return datetime.now(tz=timezone.utc).isoformat()

    def _parse_iso(ts: str) -> datetime:
        return datetime.fromisoformat(ts)

    @blueprint.route("/api/agent/token/refresh", methods=["POST"])
    def refresh():
        payload = request.get_json(force=True, silent=True) or {}
        guid = str(payload.get("guid") or "").strip()
        refresh_token = str(payload.get("refresh_token") or "").strip()

        if not guid or not refresh_token:
            return jsonify({"error": "invalid_request"}), 400

        conn = db_conn_factory()
        try:
            cur = conn.cursor()
            cur.execute(
                """
                SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at
                FROM refresh_tokens
                WHERE guid = ?
                  AND token_hash = ?
                """,
                (guid, _hash_token(refresh_token)),
            )
            row = cur.fetchone()
            if not row:
                return jsonify({"error": "invalid_refresh_token"}), 401

            record_id, row_guid, _token_hash, stored_jkt, created_at, expires_at, revoked_at = row
            if row_guid != guid:
                return jsonify({"error": "invalid_refresh_token"}), 401
            if revoked_at:
                return jsonify({"error": "refresh_token_revoked"}), 401
            if expires_at:
                try:
                    if _parse_iso(expires_at) <= datetime.now(tz=timezone.utc):
                        return jsonify({"error": "refresh_token_expired"}), 401
                except Exception:
                    pass

            cur.execute(
                """
                SELECT guid, ssl_key_fingerprint, token_version, status
                FROM devices
                WHERE guid = ?
                """,
                (guid,),
            )
            device_row = cur.fetchone()
            if not device_row:
                return jsonify({"error": "device_not_found"}), 404

            device_guid, fingerprint, token_version, status = device_row
            status_norm = (status or "active").strip().lower()
            if status_norm in {"revoked", "decommissioned"}:
                return jsonify({"error": "device_revoked"}), 403

            dpop_proof = request.headers.get("DPoP")
            jkt = stored_jkt or ""
            if dpop_proof:
                try:
                    jkt = dpop_validator.verify(request.method, request.url, dpop_proof, access_token=None)
                except DPoPReplayError:
                    return jsonify({"error": "dpop_replayed"}), 400
                except DPoPVerificationError:
                    return jsonify({"error": "dpop_invalid"}), 400
            elif stored_jkt:
                try:
                    current_app.logger.warning(
                        "Clearing stored DPoP binding for guid=%s due to missing proof",
                        guid,
                    )
                except Exception:
                    pass
                cur.execute(
                    "UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?",
                    (record_id,),
                )

            new_access_token = jwt_service.issue_access_token(
                guid,
                fingerprint or "",
                token_version or 1,
            )

            cur.execute(
                """
                UPDATE refresh_tokens
                SET last_used_at = ?,
                    dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt)
                WHERE id = ?
                """,
                (_iso_now(), jkt, record_id),
            )
            conn.commit()
        finally:
            conn.close()

        return jsonify(
            {
                "access_token": new_access_token,
                "expires_in": 900,
                "token_type": "Bearer",
            }
        )

    app.register_blueprint(blueprint)


__all__ = ["register"]
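
The refresh route accepts the device GUID and the opaque refresh token issued during enrollment, optionally bound to a DPoP proof, and returns a short-lived Bearer access token. A minimal client-side sketch, assuming the same hypothetical SERVER and `requests` package as in the earlier enrollment sketch and omitting DPoP binding:

# --- Illustrative token refresh sketch (not part of this commit) ---
import requests

SERVER = "https://borealis.example.local"  # hypothetical URL for this sketch


def refresh_access_token(guid: str, refresh_token: str) -> str:
    resp = requests.post(
        f"{SERVER}/api/agent/token/refresh",
        json={"guid": guid, "refresh_token": refresh_token},
        timeout=10,
    )
    resp.raise_for_status()
    body = resp.json()
    # The route above issues a 900-second Bearer token; callers should refresh
    # again before "expires_in" elapses.
    return body["access_token"]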