Implement Engine HTTP interfaces for health, enrollment, and tokens

This commit is contained in:
2025-10-22 13:33:15 -06:00
parent 7b5248dfe5
commit 9292cfb280
28 changed files with 1840 additions and 77 deletions

View File

@@ -10,6 +10,14 @@ from .auth import (
TokenRefreshErrorCode,
TokenService,
)
from .enrollment import (
EnrollmentRequestResult,
EnrollmentService,
EnrollmentStatus,
EnrollmentTokenBundle,
EnrollmentValidationError,
PollingResult,
)
__all__ = [
"DeviceAuthService",
@@ -18,4 +26,10 @@ __all__ = [
"TokenRefreshError",
"TokenRefreshErrorCode",
"TokenService",
"EnrollmentService",
"EnrollmentRequestResult",
"EnrollmentStatus",
"EnrollmentTokenBundle",
"EnrollmentValidationError",
"PollingResult",
]

View File

@@ -3,6 +3,8 @@
from __future__ import annotations
from .device_auth_service import DeviceAuthService, DeviceRecord
from .dpop import DPoPReplayError, DPoPVerificationError, DPoPValidator
from .jwt_service import JWTService, load_service as load_jwt_service
from .token_service import (
RefreshTokenRecord,
TokenRefreshError,
@@ -13,6 +15,11 @@ from .token_service import (
__all__ = [
"DeviceAuthService",
"DeviceRecord",
"DPoPReplayError",
"DPoPVerificationError",
"DPoPValidator",
"JWTService",
"load_jwt_service",
"RefreshTokenRecord",
"TokenRefreshError",
"TokenRefreshErrorCode",

View File

@@ -0,0 +1,105 @@
"""DPoP proof validation for Engine services."""
from __future__ import annotations
import hashlib
import time
from threading import Lock
from typing import Dict, Optional
import jwt
__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"]
_DP0P_MAX_SKEW = 300.0
class DPoPVerificationError(Exception):
pass
class DPoPReplayError(DPoPVerificationError):
pass
class DPoPValidator:
def __init__(self) -> None:
self._observed_jti: Dict[str, float] = {}
self._lock = Lock()
def verify(
self,
method: str,
htu: str,
proof: str,
access_token: Optional[str] = None,
) -> str:
if not proof:
raise DPoPVerificationError("DPoP proof missing")
try:
header = jwt.get_unverified_header(proof)
except Exception as exc:
raise DPoPVerificationError("invalid DPoP header") from exc
jwk = header.get("jwk")
alg = header.get("alg")
if not jwk or not isinstance(jwk, dict):
raise DPoPVerificationError("missing jwk in DPoP header")
if alg not in ("EdDSA", "ES256", "ES384", "ES512"):
raise DPoPVerificationError(f"unsupported DPoP alg {alg}")
try:
key = jwt.PyJWK(jwk)
public_key = key.key
except Exception as exc:
raise DPoPVerificationError("invalid jwk in DPoP header") from exc
try:
claims = jwt.decode(
proof,
public_key,
algorithms=[alg],
options={"require": ["htm", "htu", "jti", "iat"]},
)
except Exception as exc:
raise DPoPVerificationError("invalid DPoP signature") from exc
htm = claims.get("htm")
proof_htu = claims.get("htu")
jti = claims.get("jti")
iat = claims.get("iat")
ath = claims.get("ath")
if not isinstance(htm, str) or htm.lower() != method.lower():
raise DPoPVerificationError("DPoP htm mismatch")
if not isinstance(proof_htu, str) or proof_htu != htu:
raise DPoPVerificationError("DPoP htu mismatch")
if not isinstance(jti, str):
raise DPoPVerificationError("DPoP jti missing")
if not isinstance(iat, (int, float)):
raise DPoPVerificationError("DPoP iat missing")
now = time.time()
if abs(now - float(iat)) > _DP0P_MAX_SKEW:
raise DPoPVerificationError("DPoP proof outside allowed skew")
if ath and access_token:
expected_ath = jwt.utils.base64url_encode(
hashlib.sha256(access_token.encode("utf-8")).digest()
).decode("ascii")
if expected_ath != ath:
raise DPoPVerificationError("DPoP ath mismatch")
with self._lock:
expiry = self._observed_jti.get(jti)
if expiry and expiry > now:
raise DPoPReplayError("DPoP proof replay detected")
self._observed_jti[jti] = now + _DP0P_MAX_SKEW
stale = [key for key, exp in self._observed_jti.items() if exp <= now]
for key in stale:
self._observed_jti.pop(key, None)
thumbprint = jwt.PyJWK(jwk).thumbprint()
return thumbprint.decode("ascii")

View File

@@ -0,0 +1,124 @@
"""JWT issuance utilities for the Engine."""
from __future__ import annotations
import hashlib
import time
from typing import Any, Dict, Optional
import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from Data.Engine.runtime import ensure_runtime_dir, runtime_path
__all__ = ["JWTService", "load_service"]
_KEY_DIR = runtime_path("auth_keys")
_KEY_FILE = _KEY_DIR / "engine-jwt-ed25519.key"
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-jwt-ed25519.key"
class JWTService:
def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str) -> None:
self._private_key = private_key
self._public_key = private_key.public_key()
self._key_id = key_id
@property
def key_id(self) -> str:
return self._key_id
def issue_access_token(
self,
guid: str,
ssl_key_fingerprint: str,
token_version: int,
expires_in: int = 900,
extra_claims: Optional[Dict[str, Any]] = None,
) -> str:
now = int(time.time())
payload: Dict[str, Any] = {
"sub": f"device:{guid}",
"guid": guid,
"ssl_key_fingerprint": ssl_key_fingerprint,
"token_version": int(token_version),
"iat": now,
"nbf": now,
"exp": now + int(expires_in),
}
if extra_claims:
payload.update(extra_claims)
token = jwt.encode(
payload,
self._private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
),
algorithm="EdDSA",
headers={"kid": self._key_id},
)
return token
def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]:
options = {"require": ["exp", "iat", "sub"]}
public_pem = self._public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return jwt.decode(
token,
public_pem,
algorithms=["EdDSA"],
audience=audience,
options=options,
)
def public_jwk(self) -> Dict[str, Any]:
public_bytes = self._public_key.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii")
return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x}
def load_service() -> JWTService:
private_key = _load_or_create_private_key()
public_bytes = private_key.public_key().public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
key_id = hashlib.sha256(public_bytes).hexdigest()[:16]
return JWTService(private_key, key_id)
def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey:
ensure_runtime_dir("auth_keys")
if _KEY_FILE.exists():
with _KEY_FILE.open("rb") as fh:
return serialization.load_pem_private_key(fh.read(), password=None)
if _LEGACY_KEY_FILE.exists():
with _LEGACY_KEY_FILE.open("rb") as fh:
return serialization.load_pem_private_key(fh.read(), password=None)
private_key = ed25519.Ed25519PrivateKey.generate()
pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
_KEY_DIR.mkdir(parents=True, exist_ok=True)
with _KEY_FILE.open("wb") as fh:
fh.write(pem)
try:
if hasattr(_KEY_FILE, "chmod"):
_KEY_FILE.chmod(0o600)
except Exception:
pass
return private_key

View File

@@ -0,0 +1,119 @@
"""Service container assembly for the Borealis Engine."""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional
from Data.Engine.config import EngineSettings
from Data.Engine.repositories.sqlite import (
SQLiteConnectionFactory,
SQLiteDeviceRepository,
SQLiteEnrollmentRepository,
SQLiteRefreshTokenRepository,
)
from Data.Engine.services.auth import (
DeviceAuthService,
DPoPValidator,
JWTService,
TokenService,
load_jwt_service,
)
from Data.Engine.services.crypto.signing import ScriptSigner, load_signer
from Data.Engine.services.enrollment import EnrollmentService
from Data.Engine.services.enrollment.nonce_cache import NonceCache
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
__all__ = ["EngineServiceContainer", "build_service_container"]
@dataclass(frozen=True, slots=True)
class EngineServiceContainer:
device_auth: DeviceAuthService
token_service: TokenService
enrollment_service: EnrollmentService
jwt_service: JWTService
dpop_validator: DPoPValidator
def build_service_container(
settings: EngineSettings,
*,
db_factory: SQLiteConnectionFactory,
logger: Optional[logging.Logger] = None,
) -> EngineServiceContainer:
log = logger or logging.getLogger("borealis.engine.services")
device_repo = SQLiteDeviceRepository(db_factory, logger=log.getChild("devices"))
token_repo = SQLiteRefreshTokenRepository(db_factory, logger=log.getChild("tokens"))
enrollment_repo = SQLiteEnrollmentRepository(db_factory, logger=log.getChild("enrollment"))
jwt_service = load_jwt_service()
dpop_validator = DPoPValidator()
rate_limiter = SlidingWindowRateLimiter()
token_service = TokenService(
refresh_token_repository=token_repo,
device_repository=device_repo,
jwt_service=jwt_service,
dpop_validator=dpop_validator,
logger=log.getChild("token_service"),
)
enrollment_service = EnrollmentService(
device_repository=device_repo,
enrollment_repository=enrollment_repo,
token_repository=token_repo,
jwt_service=jwt_service,
tls_bundle_loader=_tls_bundle_loader(settings),
ip_rate_limiter=SlidingWindowRateLimiter(),
fingerprint_rate_limiter=SlidingWindowRateLimiter(),
nonce_cache=NonceCache(),
script_signer=_load_script_signer(log),
logger=log.getChild("enrollment"),
)
device_auth = DeviceAuthService(
device_repository=device_repo,
jwt_service=jwt_service,
logger=log.getChild("device_auth"),
rate_limiter=rate_limiter,
dpop_validator=dpop_validator,
)
return EngineServiceContainer(
device_auth=device_auth,
token_service=token_service,
enrollment_service=enrollment_service,
jwt_service=jwt_service,
dpop_validator=dpop_validator,
)
def _tls_bundle_loader(settings: EngineSettings) -> Callable[[], str]:
candidates = [
Path(os.getenv("BOREALIS_TLS_BUNDLE", "")),
settings.project_root / "Certificates" / "Server" / "borealis-server-bundle.pem",
]
def loader() -> str:
for candidate in candidates:
if candidate and candidate.is_file():
try:
return candidate.read_text(encoding="utf-8")
except Exception:
continue
return ""
return loader
def _load_script_signer(logger: logging.Logger) -> Optional[ScriptSigner]:
try:
return load_signer()
except Exception as exc:
logger.warning("script signer unavailable: %s", exc)
return None

View File

@@ -0,0 +1,75 @@
"""Script signing utilities for the Engine."""
from __future__ import annotations
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from Data.Engine.integrations.crypto.keys import base64_from_spki_der
from Data.Engine.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path
__all__ = ["ScriptSigner", "load_signer"]
_KEY_DIR = server_certificates_path("Code-Signing")
_SIGNING_KEY_FILE = _KEY_DIR / "engine-script-ed25519.key"
_SIGNING_PUB_FILE = _KEY_DIR / "engine-script-ed25519.pub"
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-script-ed25519.key"
_LEGACY_PUB_FILE = runtime_path("keys") / "borealis-script-ed25519.pub"
class ScriptSigner:
def __init__(self, private_key: ed25519.Ed25519PrivateKey) -> None:
self._private = private_key
self._public = private_key.public_key()
def sign(self, payload: bytes) -> bytes:
return self._private.sign(payload)
def public_spki_der(self) -> bytes:
return self._public.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
def public_base64_spki(self) -> str:
return base64_from_spki_der(self.public_spki_der())
def load_signer() -> ScriptSigner:
private_key = _load_or_create()
return ScriptSigner(private_key)
def _load_or_create() -> ed25519.Ed25519PrivateKey:
ensure_server_certificates_dir("Code-Signing")
if _SIGNING_KEY_FILE.exists():
with _SIGNING_KEY_FILE.open("rb") as fh:
return serialization.load_pem_private_key(fh.read(), password=None)
if _LEGACY_KEY_FILE.exists():
with _LEGACY_KEY_FILE.open("rb") as fh:
return serialization.load_pem_private_key(fh.read(), password=None)
private_key = ed25519.Ed25519PrivateKey.generate()
pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
_KEY_DIR.mkdir(parents=True, exist_ok=True)
_SIGNING_KEY_FILE.write_bytes(pem)
try:
if hasattr(_SIGNING_KEY_FILE, "chmod"):
_SIGNING_KEY_FILE.chmod(0o600)
except Exception:
pass
pub_der = private_key.public_key().public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
_SIGNING_PUB_FILE.write_bytes(pub_der)
return private_key

View File

@@ -0,0 +1,21 @@
"""Enrollment services for the Borealis Engine."""
from __future__ import annotations
from .enrollment_service import (
EnrollmentRequestResult,
EnrollmentService,
EnrollmentStatus,
EnrollmentTokenBundle,
EnrollmentValidationError,
PollingResult,
)
__all__ = [
"EnrollmentRequestResult",
"EnrollmentService",
"EnrollmentStatus",
"EnrollmentTokenBundle",
"EnrollmentValidationError",
"PollingResult",
]

View File

@@ -0,0 +1,487 @@
"""Enrollment workflow orchestration for the Borealis Engine."""
from __future__ import annotations
import base64
import hashlib
import logging
import secrets
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Callable, Optional, Protocol
from cryptography.hazmat.primitives import serialization
from Data.Engine.builders.device_enrollment import EnrollmentRequestInput
from Data.Engine.domain.device_auth import (
DeviceFingerprint,
DeviceGuid,
sanitize_service_context,
)
from Data.Engine.domain.device_enrollment import (
EnrollmentApproval,
EnrollmentApprovalStatus,
EnrollmentCode,
)
from Data.Engine.services.auth.device_auth_service import DeviceRecord
from Data.Engine.services.auth.token_service import JWTIssuer
from Data.Engine.services.enrollment.errors import EnrollmentValidationError
from Data.Engine.services.enrollment.nonce_cache import NonceCache
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
__all__ = [
"EnrollmentRequestResult",
"EnrollmentService",
"EnrollmentStatus",
"EnrollmentTokenBundle",
"PollingResult",
]
class DeviceRepository(Protocol):
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: # pragma: no cover - protocol
...
def ensure_device_record(
self,
*,
guid: DeviceGuid,
hostname: str,
fingerprint: DeviceFingerprint,
) -> DeviceRecord: # pragma: no cover - protocol
...
def record_device_key(
self,
*,
guid: DeviceGuid,
fingerprint: DeviceFingerprint,
added_at: datetime,
) -> None: # pragma: no cover - protocol
...
class EnrollmentRepository(Protocol):
def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
...
def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
...
def update_install_code_usage(
self,
record_id: str,
*,
use_count_increment: int,
last_used_at: datetime,
used_by_guid: Optional[DeviceGuid] = None,
mark_first_use: bool = False,
) -> None: # pragma: no cover - protocol
...
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
...
def fetch_pending_approval_by_fingerprint(
self, fingerprint: DeviceFingerprint
) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
...
def update_pending_approval(
self,
record_id: str,
*,
hostname: str,
guid: Optional[DeviceGuid],
enrollment_code_id: Optional[str],
client_nonce_b64: str,
server_nonce_b64: str,
agent_pubkey_der: bytes,
updated_at: datetime,
) -> None: # pragma: no cover - protocol
...
def create_device_approval(
self,
*,
record_id: str,
reference: str,
claimed_hostname: str,
claimed_fingerprint: DeviceFingerprint,
enrollment_code_id: Optional[str],
client_nonce_b64: str,
server_nonce_b64: str,
agent_pubkey_der: bytes,
created_at: datetime,
status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
guid: Optional[DeviceGuid] = None,
) -> EnrollmentApproval: # pragma: no cover - protocol
...
def update_device_approval_status(
self,
record_id: str,
*,
status: EnrollmentApprovalStatus,
updated_at: datetime,
approved_by: Optional[str] = None,
guid: Optional[DeviceGuid] = None,
) -> None: # pragma: no cover - protocol
...
class RefreshTokenRepository(Protocol):
def create(
self,
*,
record_id: str,
guid: DeviceGuid,
token_hash: str,
created_at: datetime,
expires_at: Optional[datetime],
) -> None: # pragma: no cover - protocol
...
class ScriptSigner(Protocol):
def public_base64_spki(self) -> str: # pragma: no cover - protocol
...
class EnrollmentStatus(str):
PENDING = "pending"
APPROVED = "approved"
DENIED = "denied"
EXPIRED = "expired"
UNKNOWN = "unknown"
@dataclass(frozen=True, slots=True)
class EnrollmentTokenBundle:
guid: DeviceGuid
access_token: str
refresh_token: str
expires_in: int
token_type: str = "Bearer"
@dataclass(frozen=True, slots=True)
class EnrollmentRequestResult:
status: EnrollmentStatus
approval_reference: Optional[str] = None
server_nonce: Optional[str] = None
poll_after_ms: Optional[int] = None
server_certificate: str
signing_key: str
http_status: int = 200
retry_after: Optional[float] = None
@dataclass(frozen=True, slots=True)
class PollingResult:
status: EnrollmentStatus
http_status: int
poll_after_ms: Optional[int] = None
reason: Optional[str] = None
detail: Optional[str] = None
tokens: Optional[EnrollmentTokenBundle] = None
server_certificate: Optional[str] = None
signing_key: Optional[str] = None
class EnrollmentService:
"""Coordinate the Borealis device enrollment handshake."""
def __init__(
self,
*,
device_repository: DeviceRepository,
enrollment_repository: EnrollmentRepository,
token_repository: RefreshTokenRepository,
jwt_service: JWTIssuer,
tls_bundle_loader: Callable[[], str],
ip_rate_limiter: SlidingWindowRateLimiter,
fingerprint_rate_limiter: SlidingWindowRateLimiter,
nonce_cache: NonceCache,
script_signer: Optional[ScriptSigner] = None,
logger: Optional[logging.Logger] = None,
) -> None:
self._devices = device_repository
self._enrollment = enrollment_repository
self._tokens = token_repository
self._jwt = jwt_service
self._load_tls_bundle = tls_bundle_loader
self._ip_rate_limiter = ip_rate_limiter
self._fp_rate_limiter = fingerprint_rate_limiter
self._nonce_cache = nonce_cache
self._signer = script_signer
self._log = logger or logging.getLogger("borealis.engine.enrollment")
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def request_enrollment(
self,
payload: EnrollmentRequestInput,
*,
remote_addr: str,
) -> EnrollmentRequestResult:
context_hint = sanitize_service_context(payload.service_context)
self._log.info(
"enrollment-request ip=%s host=%s code_mask=%s", remote_addr, payload.hostname, self._mask_code(payload.enrollment_code)
)
self._enforce_rate_limit(self._ip_rate_limiter, f"ip:{remote_addr}")
self._enforce_rate_limit(self._fp_rate_limiter, f"fp:{payload.fingerprint.value}")
install_code = self._enrollment.fetch_install_code(payload.enrollment_code)
reuse_guid = self._determine_reuse_guid(install_code, payload.fingerprint)
server_nonce_bytes = secrets.token_bytes(32)
server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii")
now = self._now()
approval = self._enrollment.fetch_pending_approval_by_fingerprint(payload.fingerprint)
if approval:
self._enrollment.update_pending_approval(
approval.record_id,
hostname=payload.hostname,
guid=reuse_guid,
enrollment_code_id=install_code.identifier if install_code else None,
client_nonce_b64=payload.client_nonce_b64,
server_nonce_b64=server_nonce_b64,
agent_pubkey_der=payload.agent_public_key_der,
updated_at=now,
)
approval_reference = approval.reference
else:
record_id = str(uuid.uuid4())
approval_reference = str(uuid.uuid4())
approval = self._enrollment.create_device_approval(
record_id=record_id,
reference=approval_reference,
claimed_hostname=payload.hostname,
claimed_fingerprint=payload.fingerprint,
enrollment_code_id=install_code.identifier if install_code else None,
client_nonce_b64=payload.client_nonce_b64,
server_nonce_b64=server_nonce_b64,
agent_pubkey_der=payload.agent_public_key_der,
created_at=now,
guid=reuse_guid,
)
signing_key = self._signer.public_base64_spki() if self._signer else ""
certificate = self._load_tls_bundle()
return EnrollmentRequestResult(
status=EnrollmentStatus.PENDING,
approval_reference=approval.reference,
server_nonce=server_nonce_b64,
poll_after_ms=3000,
server_certificate=certificate,
signing_key=signing_key,
)
def poll_enrollment(
self,
*,
approval_reference: str,
client_nonce_b64: str,
proof_signature_b64: str,
) -> PollingResult:
if not approval_reference:
raise EnrollmentValidationError("approval_reference_required")
if not client_nonce_b64:
raise EnrollmentValidationError("client_nonce_required")
if not proof_signature_b64:
raise EnrollmentValidationError("proof_sig_required")
approval = self._enrollment.fetch_device_approval_by_reference(approval_reference)
if approval is None:
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=404)
client_nonce = self._decode_base64(client_nonce_b64, "invalid_client_nonce")
server_nonce = self._decode_base64(approval.server_nonce_b64, "server_nonce_invalid")
proof_sig = self._decode_base64(proof_signature_b64, "invalid_proof_sig")
if approval.client_nonce_b64 != client_nonce_b64:
raise EnrollmentValidationError("nonce_mismatch")
self._verify_proof_signature(
approval=approval,
client_nonce=client_nonce,
server_nonce=server_nonce,
signature=proof_sig,
)
status = approval.status
if status is EnrollmentApprovalStatus.PENDING:
return PollingResult(
status=EnrollmentStatus.PENDING,
http_status=200,
poll_after_ms=5000,
)
if status is EnrollmentApprovalStatus.DENIED:
return PollingResult(
status=EnrollmentStatus.DENIED,
http_status=200,
reason="operator_denied",
)
if status is EnrollmentApprovalStatus.EXPIRED:
return PollingResult(status=EnrollmentStatus.EXPIRED, http_status=200)
if status is EnrollmentApprovalStatus.COMPLETED:
return PollingResult(
status=EnrollmentStatus.APPROVED,
http_status=200,
detail="finalized",
)
if status is not EnrollmentApprovalStatus.APPROVED:
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=400)
nonce_key = f"{approval.reference}:{proof_signature_b64}"
if not self._nonce_cache.consume(nonce_key):
raise EnrollmentValidationError("proof_replayed", http_status=409)
token_bundle = self._finalize_approval(approval)
signing_key = self._signer.public_base64_spki() if self._signer else ""
certificate = self._load_tls_bundle()
return PollingResult(
status=EnrollmentStatus.APPROVED,
http_status=200,
tokens=token_bundle,
server_certificate=certificate,
signing_key=signing_key,
)
# ------------------------------------------------------------------
# Internals
# ------------------------------------------------------------------
def _enforce_rate_limit(
self,
limiter: SlidingWindowRateLimiter,
key: str,
*,
limit: int = 60,
window_seconds: float = 60.0,
) -> None:
decision = limiter.check(key, limit, window_seconds)
if not decision.allowed:
raise EnrollmentValidationError(
"rate_limited", http_status=429, retry_after=max(decision.retry_after, 1.0)
)
def _determine_reuse_guid(
self,
install_code: Optional[EnrollmentCode],
fingerprint: DeviceFingerprint,
) -> Optional[DeviceGuid]:
if install_code is None:
raise EnrollmentValidationError("invalid_enrollment_code")
if install_code.is_expired:
raise EnrollmentValidationError("invalid_enrollment_code")
if not install_code.is_exhausted:
return None
if not install_code.used_by_guid:
raise EnrollmentValidationError("invalid_enrollment_code")
existing = self._devices.fetch_by_guid(install_code.used_by_guid)
if existing and existing.identity.fingerprint.value == fingerprint.value:
return install_code.used_by_guid
raise EnrollmentValidationError("invalid_enrollment_code")
def _finalize_approval(self, approval: EnrollmentApproval) -> EnrollmentTokenBundle:
now = self._now()
effective_guid = approval.guid or DeviceGuid(str(uuid.uuid4()))
device_record = self._devices.ensure_device_record(
guid=effective_guid,
hostname=approval.claimed_hostname,
fingerprint=approval.claimed_fingerprint,
)
self._devices.record_device_key(
guid=effective_guid,
fingerprint=approval.claimed_fingerprint,
added_at=now,
)
if approval.enrollment_code_id:
code = self._enrollment.fetch_install_code_by_id(approval.enrollment_code_id)
if code is not None:
mark_first = code.used_at is None
self._enrollment.update_install_code_usage(
approval.enrollment_code_id,
use_count_increment=1,
last_used_at=now,
used_by_guid=effective_guid,
mark_first_use=mark_first,
)
refresh_token = secrets.token_urlsafe(48)
refresh_id = str(uuid.uuid4())
expires_at = now + timedelta(days=30)
token_hash = hashlib.sha256(refresh_token.encode("utf-8")).hexdigest()
self._tokens.create(
record_id=refresh_id,
guid=effective_guid,
token_hash=token_hash,
created_at=now,
expires_at=expires_at,
)
access_token = self._jwt.issue_access_token(
effective_guid.value,
device_record.identity.fingerprint.value,
max(device_record.token_version, 1),
)
self._enrollment.update_device_approval_status(
approval.record_id,
status=EnrollmentApprovalStatus.COMPLETED,
updated_at=now,
guid=effective_guid,
)
return EnrollmentTokenBundle(
guid=effective_guid,
access_token=access_token,
refresh_token=refresh_token,
expires_in=900,
)
def _verify_proof_signature(
self,
*,
approval: EnrollmentApproval,
client_nonce: bytes,
server_nonce: bytes,
signature: bytes,
) -> None:
message = server_nonce + approval.reference.encode("utf-8") + client_nonce
try:
public_key = serialization.load_der_public_key(approval.agent_pubkey_der)
except Exception as exc:
raise EnrollmentValidationError("agent_pubkey_invalid") from exc
try:
public_key.verify(signature, message)
except Exception as exc:
raise EnrollmentValidationError("invalid_proof") from exc
@staticmethod
def _decode_base64(value: str, error_code: str) -> bytes:
try:
return base64.b64decode(value, validate=True)
except Exception as exc:
raise EnrollmentValidationError(error_code) from exc
@staticmethod
def _mask_code(code: str) -> str:
trimmed = (code or "").strip()
if len(trimmed) <= 6:
return "***"
return f"{trimmed[:3]}***{trimmed[-3:]}"
@staticmethod
def _now() -> datetime:
return datetime.now(tz=timezone.utc)

View File

@@ -0,0 +1,26 @@
"""Error types shared across enrollment components."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
__all__ = ["EnrollmentValidationError"]
@dataclass(frozen=True, slots=True)
class EnrollmentValidationError(Exception):
"""Raised when enrollment input fails validation."""
code: str
http_status: int = 400
retry_after: Optional[float] = None
def to_response(self) -> dict[str, object]:
payload: dict[str, object] = {"error": self.code}
if self.retry_after is not None:
payload["retry_after"] = self.retry_after
return payload
def __str__(self) -> str: # pragma: no cover - debug helper
return f"{self.code} (status={self.http_status})"

View File

@@ -0,0 +1,32 @@
"""Nonce replay protection for enrollment workflows."""
from __future__ import annotations
import time
from threading import Lock
from typing import Dict
__all__ = ["NonceCache"]
class NonceCache:
"""Track recently observed nonces to prevent replay."""
def __init__(self, ttl_seconds: float = 300.0) -> None:
self._ttl = ttl_seconds
self._entries: Dict[str, float] = {}
self._lock = Lock()
def consume(self, key: str) -> bool:
"""Consume *key* if it has not been seen recently."""
now = time.monotonic()
with self._lock:
expiry = self._entries.get(key)
if expiry and expiry > now:
return False
self._entries[key] = now + self._ttl
stale = [nonce for nonce, ttl in self._entries.items() if ttl <= now]
for nonce in stale:
self._entries.pop(nonce, None)
return True

View File

@@ -0,0 +1,45 @@
"""In-process rate limiting utilities for the Borealis Engine."""
from __future__ import annotations
import time
from collections import deque
from dataclasses import dataclass
from threading import Lock
from typing import Deque, Dict
__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"]
@dataclass(frozen=True, slots=True)
class RateLimitDecision:
"""Result of a rate limit check."""
allowed: bool
retry_after: float
class SlidingWindowRateLimiter:
"""Tiny in-memory sliding window limiter suitable for single-process use."""
def __init__(self) -> None:
self._buckets: Dict[str, Deque[float]] = {}
self._lock = Lock()
def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
now = time.monotonic()
with self._lock:
bucket = self._buckets.get(key)
if bucket is None:
bucket = deque()
self._buckets[key] = bucket
while bucket and now - bucket[0] > window_seconds:
bucket.popleft()
if len(bucket) >= limit:
retry_after = max(0.0, window_seconds - (now - bucket[0]))
return RateLimitDecision(False, retry_after)
bucket.append(now)
return RateLimitDecision(True, 0.0)