mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 22:01:59 -06:00
Implement Engine HTTP interfaces for health, enrollment, and tokens
This commit is contained in:
@@ -10,6 +10,14 @@ from .auth import (
|
||||
TokenRefreshErrorCode,
|
||||
TokenService,
|
||||
)
|
||||
from .enrollment import (
|
||||
EnrollmentRequestResult,
|
||||
EnrollmentService,
|
||||
EnrollmentStatus,
|
||||
EnrollmentTokenBundle,
|
||||
EnrollmentValidationError,
|
||||
PollingResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthService",
|
||||
@@ -18,4 +26,10 @@ __all__ = [
|
||||
"TokenRefreshError",
|
||||
"TokenRefreshErrorCode",
|
||||
"TokenService",
|
||||
"EnrollmentService",
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"EnrollmentValidationError",
|
||||
"PollingResult",
|
||||
]
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .device_auth_service import DeviceAuthService, DeviceRecord
|
||||
from .dpop import DPoPReplayError, DPoPVerificationError, DPoPValidator
|
||||
from .jwt_service import JWTService, load_service as load_jwt_service
|
||||
from .token_service import (
|
||||
RefreshTokenRecord,
|
||||
TokenRefreshError,
|
||||
@@ -13,6 +15,11 @@ from .token_service import (
|
||||
__all__ = [
|
||||
"DeviceAuthService",
|
||||
"DeviceRecord",
|
||||
"DPoPReplayError",
|
||||
"DPoPVerificationError",
|
||||
"DPoPValidator",
|
||||
"JWTService",
|
||||
"load_jwt_service",
|
||||
"RefreshTokenRecord",
|
||||
"TokenRefreshError",
|
||||
"TokenRefreshErrorCode",
|
||||
|
||||
105
Data/Engine/services/auth/dpop.py
Normal file
105
Data/Engine/services/auth/dpop.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""DPoP proof validation for Engine services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict, Optional
|
||||
|
||||
import jwt
|
||||
|
||||
__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"]
|
||||
|
||||
|
||||
_DP0P_MAX_SKEW = 300.0
|
||||
|
||||
|
||||
class DPoPVerificationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DPoPReplayError(DPoPVerificationError):
|
||||
pass
|
||||
|
||||
|
||||
class DPoPValidator:
|
||||
def __init__(self) -> None:
|
||||
self._observed_jti: Dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def verify(
|
||||
self,
|
||||
method: str,
|
||||
htu: str,
|
||||
proof: str,
|
||||
access_token: Optional[str] = None,
|
||||
) -> str:
|
||||
if not proof:
|
||||
raise DPoPVerificationError("DPoP proof missing")
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(proof)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP header") from exc
|
||||
|
||||
jwk = header.get("jwk")
|
||||
alg = header.get("alg")
|
||||
if not jwk or not isinstance(jwk, dict):
|
||||
raise DPoPVerificationError("missing jwk in DPoP header")
|
||||
if alg not in ("EdDSA", "ES256", "ES384", "ES512"):
|
||||
raise DPoPVerificationError(f"unsupported DPoP alg {alg}")
|
||||
|
||||
try:
|
||||
key = jwt.PyJWK(jwk)
|
||||
public_key = key.key
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid jwk in DPoP header") from exc
|
||||
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
proof,
|
||||
public_key,
|
||||
algorithms=[alg],
|
||||
options={"require": ["htm", "htu", "jti", "iat"]},
|
||||
)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP signature") from exc
|
||||
|
||||
htm = claims.get("htm")
|
||||
proof_htu = claims.get("htu")
|
||||
jti = claims.get("jti")
|
||||
iat = claims.get("iat")
|
||||
ath = claims.get("ath")
|
||||
|
||||
if not isinstance(htm, str) or htm.lower() != method.lower():
|
||||
raise DPoPVerificationError("DPoP htm mismatch")
|
||||
if not isinstance(proof_htu, str) or proof_htu != htu:
|
||||
raise DPoPVerificationError("DPoP htu mismatch")
|
||||
if not isinstance(jti, str):
|
||||
raise DPoPVerificationError("DPoP jti missing")
|
||||
if not isinstance(iat, (int, float)):
|
||||
raise DPoPVerificationError("DPoP iat missing")
|
||||
|
||||
now = time.time()
|
||||
if abs(now - float(iat)) > _DP0P_MAX_SKEW:
|
||||
raise DPoPVerificationError("DPoP proof outside allowed skew")
|
||||
|
||||
if ath and access_token:
|
||||
expected_ath = jwt.utils.base64url_encode(
|
||||
hashlib.sha256(access_token.encode("utf-8")).digest()
|
||||
).decode("ascii")
|
||||
if expected_ath != ath:
|
||||
raise DPoPVerificationError("DPoP ath mismatch")
|
||||
|
||||
with self._lock:
|
||||
expiry = self._observed_jti.get(jti)
|
||||
if expiry and expiry > now:
|
||||
raise DPoPReplayError("DPoP proof replay detected")
|
||||
self._observed_jti[jti] = now + _DP0P_MAX_SKEW
|
||||
stale = [key for key, exp in self._observed_jti.items() if exp <= now]
|
||||
for key in stale:
|
||||
self._observed_jti.pop(key, None)
|
||||
|
||||
thumbprint = jwt.PyJWK(jwk).thumbprint()
|
||||
return thumbprint.decode("ascii")
|
||||
124
Data/Engine/services/auth/jwt_service.py
Normal file
124
Data/Engine/services/auth/jwt_service.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""JWT issuance utilities for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from Data.Engine.runtime import ensure_runtime_dir, runtime_path
|
||||
|
||||
__all__ = ["JWTService", "load_service"]
|
||||
|
||||
|
||||
_KEY_DIR = runtime_path("auth_keys")
|
||||
_KEY_FILE = _KEY_DIR / "engine-jwt-ed25519.key"
|
||||
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-jwt-ed25519.key"
|
||||
|
||||
|
||||
class JWTService:
|
||||
def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str) -> None:
|
||||
self._private_key = private_key
|
||||
self._public_key = private_key.public_key()
|
||||
self._key_id = key_id
|
||||
|
||||
@property
|
||||
def key_id(self) -> str:
|
||||
return self._key_id
|
||||
|
||||
def issue_access_token(
|
||||
self,
|
||||
guid: str,
|
||||
ssl_key_fingerprint: str,
|
||||
token_version: int,
|
||||
expires_in: int = 900,
|
||||
extra_claims: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
now = int(time.time())
|
||||
payload: Dict[str, Any] = {
|
||||
"sub": f"device:{guid}",
|
||||
"guid": guid,
|
||||
"ssl_key_fingerprint": ssl_key_fingerprint,
|
||||
"token_version": int(token_version),
|
||||
"iat": now,
|
||||
"nbf": now,
|
||||
"exp": now + int(expires_in),
|
||||
}
|
||||
if extra_claims:
|
||||
payload.update(extra_claims)
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
self._private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
),
|
||||
algorithm="EdDSA",
|
||||
headers={"kid": self._key_id},
|
||||
)
|
||||
return token
|
||||
|
||||
def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]:
|
||||
options = {"require": ["exp", "iat", "sub"]}
|
||||
public_pem = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
return jwt.decode(
|
||||
token,
|
||||
public_pem,
|
||||
algorithms=["EdDSA"],
|
||||
audience=audience,
|
||||
options=options,
|
||||
)
|
||||
|
||||
def public_jwk(self) -> Dict[str, Any]:
|
||||
public_bytes = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw,
|
||||
)
|
||||
jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii")
|
||||
return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x}
|
||||
|
||||
|
||||
def load_service() -> JWTService:
|
||||
private_key = _load_or_create_private_key()
|
||||
public_bytes = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
key_id = hashlib.sha256(public_bytes).hexdigest()[:16]
|
||||
return JWTService(private_key, key_id)
|
||||
|
||||
|
||||
def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey:
|
||||
ensure_runtime_dir("auth_keys")
|
||||
|
||||
if _KEY_FILE.exists():
|
||||
with _KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
if _LEGACY_KEY_FILE.exists():
|
||||
with _LEGACY_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
_KEY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
with _KEY_FILE.open("wb") as fh:
|
||||
fh.write(pem)
|
||||
try:
|
||||
if hasattr(_KEY_FILE, "chmod"):
|
||||
_KEY_FILE.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
return private_key
|
||||
119
Data/Engine/services/container.py
Normal file
119
Data/Engine/services/container.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Service container assembly for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from Data.Engine.config import EngineSettings
|
||||
from Data.Engine.repositories.sqlite import (
|
||||
SQLiteConnectionFactory,
|
||||
SQLiteDeviceRepository,
|
||||
SQLiteEnrollmentRepository,
|
||||
SQLiteRefreshTokenRepository,
|
||||
)
|
||||
from Data.Engine.services.auth import (
|
||||
DeviceAuthService,
|
||||
DPoPValidator,
|
||||
JWTService,
|
||||
TokenService,
|
||||
load_jwt_service,
|
||||
)
|
||||
from Data.Engine.services.crypto.signing import ScriptSigner, load_signer
|
||||
from Data.Engine.services.enrollment import EnrollmentService
|
||||
from Data.Engine.services.enrollment.nonce_cache import NonceCache
|
||||
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
|
||||
|
||||
__all__ = ["EngineServiceContainer", "build_service_container"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EngineServiceContainer:
|
||||
device_auth: DeviceAuthService
|
||||
token_service: TokenService
|
||||
enrollment_service: EnrollmentService
|
||||
jwt_service: JWTService
|
||||
dpop_validator: DPoPValidator
|
||||
|
||||
|
||||
def build_service_container(
|
||||
settings: EngineSettings,
|
||||
*,
|
||||
db_factory: SQLiteConnectionFactory,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> EngineServiceContainer:
|
||||
log = logger or logging.getLogger("borealis.engine.services")
|
||||
|
||||
device_repo = SQLiteDeviceRepository(db_factory, logger=log.getChild("devices"))
|
||||
token_repo = SQLiteRefreshTokenRepository(db_factory, logger=log.getChild("tokens"))
|
||||
enrollment_repo = SQLiteEnrollmentRepository(db_factory, logger=log.getChild("enrollment"))
|
||||
|
||||
jwt_service = load_jwt_service()
|
||||
dpop_validator = DPoPValidator()
|
||||
rate_limiter = SlidingWindowRateLimiter()
|
||||
|
||||
token_service = TokenService(
|
||||
refresh_token_repository=token_repo,
|
||||
device_repository=device_repo,
|
||||
jwt_service=jwt_service,
|
||||
dpop_validator=dpop_validator,
|
||||
logger=log.getChild("token_service"),
|
||||
)
|
||||
|
||||
enrollment_service = EnrollmentService(
|
||||
device_repository=device_repo,
|
||||
enrollment_repository=enrollment_repo,
|
||||
token_repository=token_repo,
|
||||
jwt_service=jwt_service,
|
||||
tls_bundle_loader=_tls_bundle_loader(settings),
|
||||
ip_rate_limiter=SlidingWindowRateLimiter(),
|
||||
fingerprint_rate_limiter=SlidingWindowRateLimiter(),
|
||||
nonce_cache=NonceCache(),
|
||||
script_signer=_load_script_signer(log),
|
||||
logger=log.getChild("enrollment"),
|
||||
)
|
||||
|
||||
device_auth = DeviceAuthService(
|
||||
device_repository=device_repo,
|
||||
jwt_service=jwt_service,
|
||||
logger=log.getChild("device_auth"),
|
||||
rate_limiter=rate_limiter,
|
||||
dpop_validator=dpop_validator,
|
||||
)
|
||||
|
||||
return EngineServiceContainer(
|
||||
device_auth=device_auth,
|
||||
token_service=token_service,
|
||||
enrollment_service=enrollment_service,
|
||||
jwt_service=jwt_service,
|
||||
dpop_validator=dpop_validator,
|
||||
)
|
||||
|
||||
|
||||
def _tls_bundle_loader(settings: EngineSettings) -> Callable[[], str]:
|
||||
candidates = [
|
||||
Path(os.getenv("BOREALIS_TLS_BUNDLE", "")),
|
||||
settings.project_root / "Certificates" / "Server" / "borealis-server-bundle.pem",
|
||||
]
|
||||
|
||||
def loader() -> str:
|
||||
for candidate in candidates:
|
||||
if candidate and candidate.is_file():
|
||||
try:
|
||||
return candidate.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
continue
|
||||
return ""
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
def _load_script_signer(logger: logging.Logger) -> Optional[ScriptSigner]:
|
||||
try:
|
||||
return load_signer()
|
||||
except Exception as exc:
|
||||
logger.warning("script signer unavailable: %s", exc)
|
||||
return None
|
||||
75
Data/Engine/services/crypto/signing.py
Normal file
75
Data/Engine/services/crypto/signing.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Script signing utilities for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from Data.Engine.integrations.crypto.keys import base64_from_spki_der
|
||||
from Data.Engine.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path
|
||||
|
||||
__all__ = ["ScriptSigner", "load_signer"]
|
||||
|
||||
|
||||
_KEY_DIR = server_certificates_path("Code-Signing")
|
||||
_SIGNING_KEY_FILE = _KEY_DIR / "engine-script-ed25519.key"
|
||||
_SIGNING_PUB_FILE = _KEY_DIR / "engine-script-ed25519.pub"
|
||||
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-script-ed25519.key"
|
||||
_LEGACY_PUB_FILE = runtime_path("keys") / "borealis-script-ed25519.pub"
|
||||
|
||||
|
||||
class ScriptSigner:
|
||||
def __init__(self, private_key: ed25519.Ed25519PrivateKey) -> None:
|
||||
self._private = private_key
|
||||
self._public = private_key.public_key()
|
||||
|
||||
def sign(self, payload: bytes) -> bytes:
|
||||
return self._private.sign(payload)
|
||||
|
||||
def public_spki_der(self) -> bytes:
|
||||
return self._public.public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def public_base64_spki(self) -> str:
|
||||
return base64_from_spki_der(self.public_spki_der())
|
||||
|
||||
|
||||
def load_signer() -> ScriptSigner:
|
||||
private_key = _load_or_create()
|
||||
return ScriptSigner(private_key)
|
||||
|
||||
|
||||
def _load_or_create() -> ed25519.Ed25519PrivateKey:
|
||||
ensure_server_certificates_dir("Code-Signing")
|
||||
|
||||
if _SIGNING_KEY_FILE.exists():
|
||||
with _SIGNING_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
if _LEGACY_KEY_FILE.exists():
|
||||
with _LEGACY_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
_KEY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
_SIGNING_KEY_FILE.write_bytes(pem)
|
||||
try:
|
||||
if hasattr(_SIGNING_KEY_FILE, "chmod"):
|
||||
_SIGNING_KEY_FILE.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
pub_der = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
_SIGNING_PUB_FILE.write_bytes(pub_der)
|
||||
|
||||
return private_key
|
||||
21
Data/Engine/services/enrollment/__init__.py
Normal file
21
Data/Engine/services/enrollment/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Enrollment services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .enrollment_service import (
|
||||
EnrollmentRequestResult,
|
||||
EnrollmentService,
|
||||
EnrollmentStatus,
|
||||
EnrollmentTokenBundle,
|
||||
EnrollmentValidationError,
|
||||
PollingResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentService",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"EnrollmentValidationError",
|
||||
"PollingResult",
|
||||
]
|
||||
487
Data/Engine/services/enrollment/enrollment_service.py
Normal file
487
Data/Engine/services/enrollment/enrollment_service.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""Enrollment workflow orchestration for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable, Optional, Protocol
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
from Data.Engine.builders.device_enrollment import EnrollmentRequestInput
|
||||
from Data.Engine.domain.device_auth import (
|
||||
DeviceFingerprint,
|
||||
DeviceGuid,
|
||||
sanitize_service_context,
|
||||
)
|
||||
from Data.Engine.domain.device_enrollment import (
|
||||
EnrollmentApproval,
|
||||
EnrollmentApprovalStatus,
|
||||
EnrollmentCode,
|
||||
)
|
||||
from Data.Engine.services.auth.device_auth_service import DeviceRecord
|
||||
from Data.Engine.services.auth.token_service import JWTIssuer
|
||||
from Data.Engine.services.enrollment.errors import EnrollmentValidationError
|
||||
from Data.Engine.services.enrollment.nonce_cache import NonceCache
|
||||
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentService",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"PollingResult",
|
||||
]
|
||||
|
||||
|
||||
class DeviceRepository(Protocol):
|
||||
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def ensure_device_record(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
hostname: str,
|
||||
fingerprint: DeviceFingerprint,
|
||||
) -> DeviceRecord: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def record_device_key(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
fingerprint: DeviceFingerprint,
|
||||
added_at: datetime,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class EnrollmentRepository(Protocol):
|
||||
def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_install_code_usage(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
use_count_increment: int,
|
||||
last_used_at: datetime,
|
||||
used_by_guid: Optional[DeviceGuid] = None,
|
||||
mark_first_use: bool = False,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_pending_approval_by_fingerprint(
|
||||
self, fingerprint: DeviceFingerprint
|
||||
) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_pending_approval(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
hostname: str,
|
||||
guid: Optional[DeviceGuid],
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
updated_at: datetime,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def create_device_approval(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
reference: str,
|
||||
claimed_hostname: str,
|
||||
claimed_fingerprint: DeviceFingerprint,
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
created_at: datetime,
|
||||
status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> EnrollmentApproval: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_device_approval_status(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
status: EnrollmentApprovalStatus,
|
||||
updated_at: datetime,
|
||||
approved_by: Optional[str] = None,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class RefreshTokenRepository(Protocol):
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
guid: DeviceGuid,
|
||||
token_hash: str,
|
||||
created_at: datetime,
|
||||
expires_at: Optional[datetime],
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class ScriptSigner(Protocol):
|
||||
def public_base64_spki(self) -> str: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class EnrollmentStatus(str):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
EXPIRED = "expired"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentTokenBundle:
|
||||
guid: DeviceGuid
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_in: int
|
||||
token_type: str = "Bearer"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentRequestResult:
|
||||
status: EnrollmentStatus
|
||||
approval_reference: Optional[str] = None
|
||||
server_nonce: Optional[str] = None
|
||||
poll_after_ms: Optional[int] = None
|
||||
server_certificate: str
|
||||
signing_key: str
|
||||
http_status: int = 200
|
||||
retry_after: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PollingResult:
|
||||
status: EnrollmentStatus
|
||||
http_status: int
|
||||
poll_after_ms: Optional[int] = None
|
||||
reason: Optional[str] = None
|
||||
detail: Optional[str] = None
|
||||
tokens: Optional[EnrollmentTokenBundle] = None
|
||||
server_certificate: Optional[str] = None
|
||||
signing_key: Optional[str] = None
|
||||
|
||||
|
||||
class EnrollmentService:
|
||||
"""Coordinate the Borealis device enrollment handshake."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
device_repository: DeviceRepository,
|
||||
enrollment_repository: EnrollmentRepository,
|
||||
token_repository: RefreshTokenRepository,
|
||||
jwt_service: JWTIssuer,
|
||||
tls_bundle_loader: Callable[[], str],
|
||||
ip_rate_limiter: SlidingWindowRateLimiter,
|
||||
fingerprint_rate_limiter: SlidingWindowRateLimiter,
|
||||
nonce_cache: NonceCache,
|
||||
script_signer: Optional[ScriptSigner] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._devices = device_repository
|
||||
self._enrollment = enrollment_repository
|
||||
self._tokens = token_repository
|
||||
self._jwt = jwt_service
|
||||
self._load_tls_bundle = tls_bundle_loader
|
||||
self._ip_rate_limiter = ip_rate_limiter
|
||||
self._fp_rate_limiter = fingerprint_rate_limiter
|
||||
self._nonce_cache = nonce_cache
|
||||
self._signer = script_signer
|
||||
self._log = logger or logging.getLogger("borealis.engine.enrollment")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
def request_enrollment(
|
||||
self,
|
||||
payload: EnrollmentRequestInput,
|
||||
*,
|
||||
remote_addr: str,
|
||||
) -> EnrollmentRequestResult:
|
||||
context_hint = sanitize_service_context(payload.service_context)
|
||||
self._log.info(
|
||||
"enrollment-request ip=%s host=%s code_mask=%s", remote_addr, payload.hostname, self._mask_code(payload.enrollment_code)
|
||||
)
|
||||
|
||||
self._enforce_rate_limit(self._ip_rate_limiter, f"ip:{remote_addr}")
|
||||
self._enforce_rate_limit(self._fp_rate_limiter, f"fp:{payload.fingerprint.value}")
|
||||
|
||||
install_code = self._enrollment.fetch_install_code(payload.enrollment_code)
|
||||
reuse_guid = self._determine_reuse_guid(install_code, payload.fingerprint)
|
||||
|
||||
server_nonce_bytes = secrets.token_bytes(32)
|
||||
server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii")
|
||||
|
||||
now = self._now()
|
||||
approval = self._enrollment.fetch_pending_approval_by_fingerprint(payload.fingerprint)
|
||||
if approval:
|
||||
self._enrollment.update_pending_approval(
|
||||
approval.record_id,
|
||||
hostname=payload.hostname,
|
||||
guid=reuse_guid,
|
||||
enrollment_code_id=install_code.identifier if install_code else None,
|
||||
client_nonce_b64=payload.client_nonce_b64,
|
||||
server_nonce_b64=server_nonce_b64,
|
||||
agent_pubkey_der=payload.agent_public_key_der,
|
||||
updated_at=now,
|
||||
)
|
||||
approval_reference = approval.reference
|
||||
else:
|
||||
record_id = str(uuid.uuid4())
|
||||
approval_reference = str(uuid.uuid4())
|
||||
approval = self._enrollment.create_device_approval(
|
||||
record_id=record_id,
|
||||
reference=approval_reference,
|
||||
claimed_hostname=payload.hostname,
|
||||
claimed_fingerprint=payload.fingerprint,
|
||||
enrollment_code_id=install_code.identifier if install_code else None,
|
||||
client_nonce_b64=payload.client_nonce_b64,
|
||||
server_nonce_b64=server_nonce_b64,
|
||||
agent_pubkey_der=payload.agent_public_key_der,
|
||||
created_at=now,
|
||||
guid=reuse_guid,
|
||||
)
|
||||
|
||||
signing_key = self._signer.public_base64_spki() if self._signer else ""
|
||||
certificate = self._load_tls_bundle()
|
||||
|
||||
return EnrollmentRequestResult(
|
||||
status=EnrollmentStatus.PENDING,
|
||||
approval_reference=approval.reference,
|
||||
server_nonce=server_nonce_b64,
|
||||
poll_after_ms=3000,
|
||||
server_certificate=certificate,
|
||||
signing_key=signing_key,
|
||||
)
|
||||
|
||||
def poll_enrollment(
|
||||
self,
|
||||
*,
|
||||
approval_reference: str,
|
||||
client_nonce_b64: str,
|
||||
proof_signature_b64: str,
|
||||
) -> PollingResult:
|
||||
if not approval_reference:
|
||||
raise EnrollmentValidationError("approval_reference_required")
|
||||
if not client_nonce_b64:
|
||||
raise EnrollmentValidationError("client_nonce_required")
|
||||
if not proof_signature_b64:
|
||||
raise EnrollmentValidationError("proof_sig_required")
|
||||
|
||||
approval = self._enrollment.fetch_device_approval_by_reference(approval_reference)
|
||||
if approval is None:
|
||||
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=404)
|
||||
|
||||
client_nonce = self._decode_base64(client_nonce_b64, "invalid_client_nonce")
|
||||
server_nonce = self._decode_base64(approval.server_nonce_b64, "server_nonce_invalid")
|
||||
proof_sig = self._decode_base64(proof_signature_b64, "invalid_proof_sig")
|
||||
|
||||
if approval.client_nonce_b64 != client_nonce_b64:
|
||||
raise EnrollmentValidationError("nonce_mismatch")
|
||||
|
||||
self._verify_proof_signature(
|
||||
approval=approval,
|
||||
client_nonce=client_nonce,
|
||||
server_nonce=server_nonce,
|
||||
signature=proof_sig,
|
||||
)
|
||||
|
||||
status = approval.status
|
||||
if status is EnrollmentApprovalStatus.PENDING:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.PENDING,
|
||||
http_status=200,
|
||||
poll_after_ms=5000,
|
||||
)
|
||||
if status is EnrollmentApprovalStatus.DENIED:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.DENIED,
|
||||
http_status=200,
|
||||
reason="operator_denied",
|
||||
)
|
||||
if status is EnrollmentApprovalStatus.EXPIRED:
|
||||
return PollingResult(status=EnrollmentStatus.EXPIRED, http_status=200)
|
||||
if status is EnrollmentApprovalStatus.COMPLETED:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.APPROVED,
|
||||
http_status=200,
|
||||
detail="finalized",
|
||||
)
|
||||
if status is not EnrollmentApprovalStatus.APPROVED:
|
||||
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=400)
|
||||
|
||||
nonce_key = f"{approval.reference}:{proof_signature_b64}"
|
||||
if not self._nonce_cache.consume(nonce_key):
|
||||
raise EnrollmentValidationError("proof_replayed", http_status=409)
|
||||
|
||||
token_bundle = self._finalize_approval(approval)
|
||||
signing_key = self._signer.public_base64_spki() if self._signer else ""
|
||||
certificate = self._load_tls_bundle()
|
||||
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.APPROVED,
|
||||
http_status=200,
|
||||
tokens=token_bundle,
|
||||
server_certificate=certificate,
|
||||
signing_key=signing_key,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internals
|
||||
# ------------------------------------------------------------------
|
||||
def _enforce_rate_limit(
|
||||
self,
|
||||
limiter: SlidingWindowRateLimiter,
|
||||
key: str,
|
||||
*,
|
||||
limit: int = 60,
|
||||
window_seconds: float = 60.0,
|
||||
) -> None:
|
||||
decision = limiter.check(key, limit, window_seconds)
|
||||
if not decision.allowed:
|
||||
raise EnrollmentValidationError(
|
||||
"rate_limited", http_status=429, retry_after=max(decision.retry_after, 1.0)
|
||||
)
|
||||
|
||||
def _determine_reuse_guid(
|
||||
self,
|
||||
install_code: Optional[EnrollmentCode],
|
||||
fingerprint: DeviceFingerprint,
|
||||
) -> Optional[DeviceGuid]:
|
||||
if install_code is None:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
if install_code.is_expired:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
if not install_code.is_exhausted:
|
||||
return None
|
||||
if not install_code.used_by_guid:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
|
||||
existing = self._devices.fetch_by_guid(install_code.used_by_guid)
|
||||
if existing and existing.identity.fingerprint.value == fingerprint.value:
|
||||
return install_code.used_by_guid
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
|
||||
def _finalize_approval(self, approval: EnrollmentApproval) -> EnrollmentTokenBundle:
|
||||
now = self._now()
|
||||
effective_guid = approval.guid or DeviceGuid(str(uuid.uuid4()))
|
||||
device_record = self._devices.ensure_device_record(
|
||||
guid=effective_guid,
|
||||
hostname=approval.claimed_hostname,
|
||||
fingerprint=approval.claimed_fingerprint,
|
||||
)
|
||||
self._devices.record_device_key(
|
||||
guid=effective_guid,
|
||||
fingerprint=approval.claimed_fingerprint,
|
||||
added_at=now,
|
||||
)
|
||||
|
||||
if approval.enrollment_code_id:
|
||||
code = self._enrollment.fetch_install_code_by_id(approval.enrollment_code_id)
|
||||
if code is not None:
|
||||
mark_first = code.used_at is None
|
||||
self._enrollment.update_install_code_usage(
|
||||
approval.enrollment_code_id,
|
||||
use_count_increment=1,
|
||||
last_used_at=now,
|
||||
used_by_guid=effective_guid,
|
||||
mark_first_use=mark_first,
|
||||
)
|
||||
|
||||
refresh_token = secrets.token_urlsafe(48)
|
||||
refresh_id = str(uuid.uuid4())
|
||||
expires_at = now + timedelta(days=30)
|
||||
token_hash = hashlib.sha256(refresh_token.encode("utf-8")).hexdigest()
|
||||
self._tokens.create(
|
||||
record_id=refresh_id,
|
||||
guid=effective_guid,
|
||||
token_hash=token_hash,
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
access_token = self._jwt.issue_access_token(
|
||||
effective_guid.value,
|
||||
device_record.identity.fingerprint.value,
|
||||
max(device_record.token_version, 1),
|
||||
)
|
||||
|
||||
self._enrollment.update_device_approval_status(
|
||||
approval.record_id,
|
||||
status=EnrollmentApprovalStatus.COMPLETED,
|
||||
updated_at=now,
|
||||
guid=effective_guid,
|
||||
)
|
||||
|
||||
return EnrollmentTokenBundle(
|
||||
guid=effective_guid,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=900,
|
||||
)
|
||||
|
||||
def _verify_proof_signature(
|
||||
self,
|
||||
*,
|
||||
approval: EnrollmentApproval,
|
||||
client_nonce: bytes,
|
||||
server_nonce: bytes,
|
||||
signature: bytes,
|
||||
) -> None:
|
||||
message = server_nonce + approval.reference.encode("utf-8") + client_nonce
|
||||
try:
|
||||
public_key = serialization.load_der_public_key(approval.agent_pubkey_der)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError("agent_pubkey_invalid") from exc
|
||||
|
||||
try:
|
||||
public_key.verify(signature, message)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError("invalid_proof") from exc
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64(value: str, error_code: str) -> bytes:
|
||||
try:
|
||||
return base64.b64decode(value, validate=True)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError(error_code) from exc
|
||||
|
||||
@staticmethod
|
||||
def _mask_code(code: str) -> str:
|
||||
trimmed = (code or "").strip()
|
||||
if len(trimmed) <= 6:
|
||||
return "***"
|
||||
return f"{trimmed[:3]}***{trimmed[-3:]}"
|
||||
|
||||
@staticmethod
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
26
Data/Engine/services/enrollment/errors.py
Normal file
26
Data/Engine/services/enrollment/errors.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Error types shared across enrollment components."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
__all__ = ["EnrollmentValidationError"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentValidationError(Exception):
|
||||
"""Raised when enrollment input fails validation."""
|
||||
|
||||
code: str
|
||||
http_status: int = 400
|
||||
retry_after: Optional[float] = None
|
||||
|
||||
def to_response(self) -> dict[str, object]:
|
||||
payload: dict[str, object] = {"error": self.code}
|
||||
if self.retry_after is not None:
|
||||
payload["retry_after"] = self.retry_after
|
||||
return payload
|
||||
|
||||
def __str__(self) -> str: # pragma: no cover - debug helper
|
||||
return f"{self.code} (status={self.http_status})"
|
||||
32
Data/Engine/services/enrollment/nonce_cache.py
Normal file
32
Data/Engine/services/enrollment/nonce_cache.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Nonce replay protection for enrollment workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict
|
||||
|
||||
__all__ = ["NonceCache"]
|
||||
|
||||
|
||||
class NonceCache:
|
||||
"""Track recently observed nonces to prevent replay."""
|
||||
|
||||
def __init__(self, ttl_seconds: float = 300.0) -> None:
|
||||
self._ttl = ttl_seconds
|
||||
self._entries: Dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def consume(self, key: str) -> bool:
|
||||
"""Consume *key* if it has not been seen recently."""
|
||||
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
expiry = self._entries.get(key)
|
||||
if expiry and expiry > now:
|
||||
return False
|
||||
self._entries[key] = now + self._ttl
|
||||
stale = [nonce for nonce, ttl in self._entries.items() if ttl <= now]
|
||||
for nonce in stale:
|
||||
self._entries.pop(nonce, None)
|
||||
return True
|
||||
45
Data/Engine/services/rate_limit.py
Normal file
45
Data/Engine/services/rate_limit.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""In-process rate limiting utilities for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import Deque, Dict
|
||||
|
||||
__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RateLimitDecision:
|
||||
"""Result of a rate limit check."""
|
||||
|
||||
allowed: bool
|
||||
retry_after: float
|
||||
|
||||
|
||||
class SlidingWindowRateLimiter:
|
||||
"""Tiny in-memory sliding window limiter suitable for single-process use."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buckets: Dict[str, Deque[float]] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
bucket = self._buckets.get(key)
|
||||
if bucket is None:
|
||||
bucket = deque()
|
||||
self._buckets[key] = bucket
|
||||
|
||||
while bucket and now - bucket[0] > window_seconds:
|
||||
bucket.popleft()
|
||||
|
||||
if len(bucket) >= limit:
|
||||
retry_after = max(0.0, window_seconds - (now - bucket[0]))
|
||||
return RateLimitDecision(False, retry_after)
|
||||
|
||||
bucket.append(now)
|
||||
return RateLimitDecision(True, 0.0)
|
||||
Reference in New Issue
Block a user