mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 17:41:58 -06:00
Implement Engine HTTP interfaces for health, enrollment, and tokens
This commit is contained in:
@@ -37,7 +37,7 @@
|
||||
- 7.2 Write integration tests exercising CRUD against a temporary SQLite file.
|
||||
- 7.3 Commit when repositories provide the required ports used by services.
|
||||
|
||||
8. Recreate HTTP interfaces
|
||||
[COMPLETED] 8. Recreate HTTP interfaces
|
||||
- 8.1 Port health/enrollment/token blueprints into `interfaces/http/<feature>/routes.py`, calling Engine services only.
|
||||
- 8.2 Ensure request validation occurs via builders; response schemas stay aligned with legacy JSON.
|
||||
- 8.3 Register blueprints through Engine `server.py`; confirm endpoints respond via manual or automated tests.
|
||||
|
||||
@@ -25,14 +25,21 @@ The Engine mirrors the legacy defaults so it can boot without additional configu
|
||||
## Bootstrapping flow
|
||||
|
||||
1. `Data/Engine/bootstrapper.py` loads the environment, configures logging, prepares the SQLite connection factory, optionally applies schema migrations, and builds the Flask application via `Data/Engine/server.py`.
|
||||
2. Placeholder HTTP and Socket.IO registration hooks run so the Engine can start without any migrated routes yet.
|
||||
3. The resulting runtime object exposes the Flask app, resolved settings, optional Socket.IO server, and the configured database connection factory. `bootstrapper.main()` runs the appropriate server based on whether Socket.IO is present.
|
||||
2. A service container is assembled (`Data/Engine/services/container.py`) that wires repositories, JWT/DPoP helpers, and Engine services (device auth, token refresh, enrollment). The container is stored on the Flask app for interface modules to consume.
|
||||
3. HTTP and Socket.IO interfaces register against the new service container. The resulting runtime object exposes the Flask app, resolved settings, optional Socket.IO server, and the configured database connection factory. `bootstrapper.main()` runs the appropriate server based on whether Socket.IO is present.
|
||||
|
||||
As migration continues, services, repositories, interfaces, and integrations will live under their respective subpackages while maintaining isolation from the legacy server.
|
||||
|
||||
## Interface scaffolding
|
||||
## HTTP interfaces
|
||||
|
||||
The Engine currently exposes placeholder HTTP blueprints under `Data/Engine/interfaces/http/` (agents, enrollment, tokens, admin, and health) so that future commits can drop in real routes without reshaping the bootstrap wiring. WebSocket namespaces follow the same pattern in `Data/Engine/interfaces/ws/`, with feature-oriented modules (e.g., `agents`, `job_management`) registered by `bootstrapper.bootstrap()` when Socket.IO is available. These stubs intentionally contain no business logic yet—they merely ensure the application factory exercises the full wiring path.
|
||||
The Engine now exposes working HTTP routes alongside the remaining scaffolding:
|
||||
|
||||
- `Data/Engine/interfaces/http/health.py` implements `GET /health` for liveness probes.
|
||||
- `Data/Engine/interfaces/http/tokens.py` ports the refresh-token endpoint (`POST /api/agent/token/refresh`) using the Engine `TokenService` and request builders.
|
||||
- `Data/Engine/interfaces/http/enrollment.py` handles the enrollment handshake (`/api/agent/enroll/request` and `/api/agent/enroll/poll`) with rate limiting, nonce protection, and repository-backed approvals.
|
||||
- The admin and agent blueprints remain placeholders until their services migrate.
|
||||
|
||||
WebSocket namespaces continue to follow the same pattern in `Data/Engine/interfaces/ws/`, with feature-oriented modules (e.g., `agents`, `job_management`) registered by `bootstrapper.bootstrap()` when Socket.IO is available.
|
||||
|
||||
## Authentication services
|
||||
|
||||
@@ -43,7 +50,7 @@ Step 6 introduces the first real Engine services:
|
||||
- `Data/Engine/services/auth/device_auth_service.py` ports the legacy `DeviceAuthManager` into a repository-driven service that emits `DeviceAuthContext` instances from the new domain layer.
|
||||
- `Data/Engine/services/auth/token_service.py` issues refreshed access tokens while enforcing DPoP bindings and repository lookups.
|
||||
|
||||
Interfaces will begin consuming these services once the repository adapters land in the next milestone.
|
||||
Interfaces now consume these services via the shared container, keeping business logic inside the Engine service layer while HTTP modules remain thin request/response translators.
|
||||
|
||||
## SQLite repositories
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from .interfaces import (
|
||||
from .repositories.sqlite import connection as sqlite_connection
|
||||
from .repositories.sqlite import migrations as sqlite_migrations
|
||||
from .server import create_app
|
||||
from .services.container import build_service_container
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
@@ -45,7 +46,9 @@ def bootstrap() -> EngineRuntime:
|
||||
logger.info("migrations-skipped")
|
||||
|
||||
app = create_app(settings, db_factory=db_factory)
|
||||
register_http_interfaces(app)
|
||||
services = build_service_container(settings, db_factory=db_factory, logger=logger.getChild("services"))
|
||||
app.extensions["engine_services"] = services
|
||||
register_http_interfaces(app, services)
|
||||
socketio = create_socket_server(app, settings.socketio)
|
||||
register_ws_interfaces(socketio)
|
||||
logger.info("bootstrap-complete")
|
||||
|
||||
@@ -2,67 +2,96 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from Data.Engine.domain.device_auth import DeviceFingerprint
|
||||
from Data.Engine.domain.device_enrollment import EnrollmentRequest, ProofChallenge
|
||||
from Data.Engine.domain.device_auth import DeviceFingerprint, sanitize_service_context
|
||||
from Data.Engine.domain.device_enrollment import ProofChallenge
|
||||
from Data.Engine.integrations.crypto import keys as crypto_keys
|
||||
from Data.Engine.services.enrollment.errors import EnrollmentValidationError
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentRequestBuilder",
|
||||
"EnrollmentRequestInput",
|
||||
"ProofChallengeBuilder",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _EnrollmentPayload:
|
||||
class EnrollmentRequestInput:
|
||||
"""Structured enrollment request payload ready for the service layer."""
|
||||
|
||||
hostname: str
|
||||
enrollment_code: str
|
||||
fingerprint: str
|
||||
fingerprint: DeviceFingerprint
|
||||
client_nonce: bytes
|
||||
server_nonce: bytes
|
||||
client_nonce_b64: str
|
||||
agent_public_key_der: bytes
|
||||
service_context: Optional[str]
|
||||
|
||||
|
||||
class EnrollmentRequestBuilder:
|
||||
"""Normalize agent enrollment JSON payloads into domain objects."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._payload: Optional[_EnrollmentPayload] = None
|
||||
self._hostname: Optional[str] = None
|
||||
self._enrollment_code: Optional[str] = None
|
||||
self._agent_pubkey_b64: Optional[str] = None
|
||||
self._client_nonce_b64: Optional[str] = None
|
||||
self._service_context: Optional[str] = None
|
||||
|
||||
def with_payload(self, payload: Optional[dict[str, object]]) -> "EnrollmentRequestBuilder":
|
||||
payload = payload or {}
|
||||
hostname = str(payload.get("hostname") or "").strip()
|
||||
enrollment_code = str(payload.get("enrollment_code") or "").strip()
|
||||
fingerprint = str(payload.get("fingerprint") or "").strip()
|
||||
client_nonce = self._coerce_bytes(payload.get("client_nonce"))
|
||||
server_nonce = self._coerce_bytes(payload.get("server_nonce"))
|
||||
self._payload = _EnrollmentPayload(
|
||||
hostname=hostname,
|
||||
enrollment_code=enrollment_code,
|
||||
fingerprint=fingerprint,
|
||||
client_nonce=client_nonce,
|
||||
server_nonce=server_nonce,
|
||||
)
|
||||
self._hostname = str(payload.get("hostname") or "").strip()
|
||||
self._enrollment_code = str(payload.get("enrollment_code") or "").strip()
|
||||
agent_pubkey = payload.get("agent_pubkey")
|
||||
self._agent_pubkey_b64 = agent_pubkey if isinstance(agent_pubkey, str) else None
|
||||
client_nonce = payload.get("client_nonce")
|
||||
self._client_nonce_b64 = client_nonce if isinstance(client_nonce, str) else None
|
||||
return self
|
||||
|
||||
def build(self) -> EnrollmentRequest:
|
||||
if not self._payload:
|
||||
raise ValueError("payload has not been provided")
|
||||
return EnrollmentRequest.from_payload(
|
||||
hostname=self._payload.hostname,
|
||||
enrollment_code=self._payload.enrollment_code,
|
||||
fingerprint=self._payload.fingerprint,
|
||||
client_nonce=self._payload.client_nonce,
|
||||
server_nonce=self._payload.server_nonce,
|
||||
)
|
||||
def with_service_context(self, value: Optional[str]) -> "EnrollmentRequestBuilder":
|
||||
self._service_context = value
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _coerce_bytes(value: object) -> bytes:
|
||||
if isinstance(value, (bytes, bytearray)):
|
||||
return bytes(value)
|
||||
if isinstance(value, str):
|
||||
return value.encode("utf-8")
|
||||
raise ValueError("nonce values must be bytes or base strings")
|
||||
def build(self) -> EnrollmentRequestInput:
|
||||
if not self._hostname:
|
||||
raise EnrollmentValidationError("hostname_required")
|
||||
if not self._enrollment_code:
|
||||
raise EnrollmentValidationError("enrollment_code_required")
|
||||
if not self._agent_pubkey_b64:
|
||||
raise EnrollmentValidationError("agent_pubkey_required")
|
||||
if not self._client_nonce_b64:
|
||||
raise EnrollmentValidationError("client_nonce_required")
|
||||
|
||||
try:
|
||||
agent_pubkey_der = crypto_keys.spki_der_from_base64(self._agent_pubkey_b64)
|
||||
except Exception as exc: # pragma: no cover - invalid input path
|
||||
raise EnrollmentValidationError("invalid_agent_pubkey") from exc
|
||||
|
||||
if len(agent_pubkey_der) < 10:
|
||||
raise EnrollmentValidationError("invalid_agent_pubkey")
|
||||
|
||||
try:
|
||||
client_nonce_bytes = base64.b64decode(self._client_nonce_b64, validate=True)
|
||||
except Exception as exc: # pragma: no cover - invalid input path
|
||||
raise EnrollmentValidationError("invalid_client_nonce") from exc
|
||||
|
||||
if len(client_nonce_bytes) < 16:
|
||||
raise EnrollmentValidationError("invalid_client_nonce")
|
||||
|
||||
fingerprint_value = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
|
||||
|
||||
return EnrollmentRequestInput(
|
||||
hostname=self._hostname,
|
||||
enrollment_code=self._enrollment_code,
|
||||
fingerprint=DeviceFingerprint(fingerprint_value),
|
||||
client_nonce=client_nonce_bytes,
|
||||
client_nonce_b64=self._client_nonce_b64,
|
||||
agent_public_key_der=agent_pubkey_der,
|
||||
service_context=sanitize_service_context(self._service_context),
|
||||
)
|
||||
|
||||
|
||||
class ProofChallengeBuilder:
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
import base64
|
||||
from enum import Enum
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
@@ -44,6 +45,7 @@ def _require(value: Optional[str], field: str) -> str:
|
||||
class EnrollmentCode:
|
||||
"""Installer code metadata loaded from the persistence layer."""
|
||||
|
||||
record_id: Optional[str] = None
|
||||
code: str
|
||||
expires_at: datetime
|
||||
max_uses: int
|
||||
@@ -67,6 +69,7 @@ class EnrollmentCode:
|
||||
used_by = record.get("used_by_guid")
|
||||
used_by_guid = DeviceGuid(used_by) if used_by else None
|
||||
return cls(
|
||||
record_id=str(record.get("id") or "") or None,
|
||||
code=_require(record.get("code"), "code"),
|
||||
expires_at=_parse_iso8601(record.get("expires_at")) or datetime.now(tz=timezone.utc),
|
||||
max_uses=int(record.get("max_uses") or 1),
|
||||
@@ -88,6 +91,10 @@ class EnrollmentCode:
|
||||
def is_expired(self) -> bool:
|
||||
return self.expires_at <= datetime.now(tz=timezone.utc)
|
||||
|
||||
@property
|
||||
def identifier(self) -> Optional[str]:
|
||||
return self.record_id
|
||||
|
||||
|
||||
class EnrollmentApprovalStatus(str, Enum):
|
||||
"""Possible states for a device approval entry."""
|
||||
@@ -181,6 +188,9 @@ class EnrollmentApproval:
|
||||
enrollment_code_id: Optional[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
client_nonce_b64: str
|
||||
server_nonce_b64: str
|
||||
agent_pubkey_der: bytes
|
||||
guid: Optional[DeviceGuid] = None
|
||||
approved_by: Optional[str] = None
|
||||
|
||||
@@ -207,6 +217,9 @@ class EnrollmentApproval:
|
||||
updated_at=_parse_iso8601(record.get("updated_at")) or datetime.now(tz=timezone.utc),
|
||||
guid=DeviceGuid(guid_raw) if guid_raw else None,
|
||||
approved_by=(approved_raw or None),
|
||||
client_nonce_b64=_require(record.get("client_nonce"), "client_nonce"),
|
||||
server_nonce_b64=_require(record.get("server_nonce"), "server_nonce"),
|
||||
agent_pubkey_der=bytes(record.get("agent_pubkey_der") or b""),
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -219,3 +232,11 @@ class EnrollmentApproval:
|
||||
EnrollmentApprovalStatus.APPROVED,
|
||||
EnrollmentApprovalStatus.COMPLETED,
|
||||
}
|
||||
|
||||
@property
|
||||
def client_nonce_bytes(self) -> bytes:
|
||||
return base64.b64decode(self.client_nonce_b64.encode("ascii"), validate=True)
|
||||
|
||||
@property
|
||||
def server_nonce_bytes(self) -> bytes:
|
||||
return base64.b64decode(self.server_nonce_b64.encode("ascii"), validate=True)
|
||||
|
||||
25
Data/Engine/integrations/crypto/__init__.py
Normal file
25
Data/Engine/integrations/crypto/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Crypto integration helpers for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .keys import (
|
||||
base64_from_spki_der,
|
||||
fingerprint_from_base64_spki,
|
||||
fingerprint_from_spki_der,
|
||||
generate_ed25519_keypair,
|
||||
normalize_base64,
|
||||
private_key_to_pem,
|
||||
public_key_to_pem,
|
||||
spki_der_from_base64,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"base64_from_spki_der",
|
||||
"fingerprint_from_base64_spki",
|
||||
"fingerprint_from_spki_der",
|
||||
"generate_ed25519_keypair",
|
||||
"normalize_base64",
|
||||
"private_key_to_pem",
|
||||
"public_key_to_pem",
|
||||
"spki_der_from_base64",
|
||||
]
|
||||
70
Data/Engine/integrations/crypto/keys.py
Normal file
70
Data/Engine/integrations/crypto/keys.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Key utilities mirrored from the legacy crypto helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
||||
|
||||
__all__ = [
|
||||
"base64_from_spki_der",
|
||||
"fingerprint_from_base64_spki",
|
||||
"fingerprint_from_spki_der",
|
||||
"generate_ed25519_keypair",
|
||||
"normalize_base64",
|
||||
"private_key_to_pem",
|
||||
"public_key_to_pem",
|
||||
"spki_der_from_base64",
|
||||
]
|
||||
|
||||
|
||||
def generate_ed25519_keypair() -> Tuple[ed25519.Ed25519PrivateKey, bytes]:
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
public_key = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
return private_key, public_key
|
||||
|
||||
|
||||
def normalize_base64(data: str) -> str:
|
||||
cleaned = re.sub(r"\s+", "", data or "")
|
||||
return cleaned.replace("-", "+").replace("_", "/")
|
||||
|
||||
|
||||
def spki_der_from_base64(spki_b64: str) -> bytes:
|
||||
return base64.b64decode(normalize_base64(spki_b64), validate=True)
|
||||
|
||||
|
||||
def base64_from_spki_der(spki_der: bytes) -> str:
|
||||
return base64.b64encode(spki_der).decode("ascii")
|
||||
|
||||
|
||||
def fingerprint_from_spki_der(spki_der: bytes) -> str:
|
||||
digest = hashlib.sha256(spki_der).hexdigest()
|
||||
return digest.lower()
|
||||
|
||||
|
||||
def fingerprint_from_base64_spki(spki_b64: str) -> str:
|
||||
return fingerprint_from_spki_der(spki_der_from_base64(spki_b64))
|
||||
|
||||
|
||||
def private_key_to_pem(private_key: ed25519.Ed25519PrivateKey) -> bytes:
|
||||
return private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
|
||||
|
||||
def public_key_to_pem(public_spki_der: bytes) -> bytes:
|
||||
public_key = load_der_public_key(public_spki_der)
|
||||
return public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
from . import admin, agents, enrollment, health, tokens
|
||||
|
||||
_REGISTRARS = (
|
||||
@@ -15,14 +17,14 @@ _REGISTRARS = (
|
||||
)
|
||||
|
||||
|
||||
def register_http_interfaces(app: Flask) -> None:
|
||||
def register_http_interfaces(app: Flask, services: EngineServiceContainer) -> None:
|
||||
"""Attach HTTP blueprints to *app*.
|
||||
|
||||
The implementation is intentionally minimal for the initial scaffolding.
|
||||
"""
|
||||
|
||||
for registrar in _REGISTRARS:
|
||||
registrar(app)
|
||||
registrar(app, services)
|
||||
|
||||
|
||||
__all__ = ["register_http_interfaces"]
|
||||
|
||||
@@ -4,11 +4,13 @@ from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
|
||||
blueprint = Blueprint("engine_admin", __name__, url_prefix="/api/admin")
|
||||
|
||||
|
||||
def register(app: Flask) -> None:
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach administrative routes to *app*.
|
||||
|
||||
Concrete endpoints will be migrated in subsequent phases.
|
||||
|
||||
@@ -4,11 +4,13 @@ from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
|
||||
blueprint = Blueprint("engine_agents", __name__, url_prefix="/api/agents")
|
||||
|
||||
|
||||
def register(app: Flask) -> None:
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach agent management routes to *app*.
|
||||
|
||||
Implementation will be populated as services migrate from the legacy server.
|
||||
|
||||
@@ -2,20 +2,110 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request
|
||||
|
||||
from Data.Engine.builders.device_enrollment import EnrollmentRequestBuilder
|
||||
from Data.Engine.services import EnrollmentValidationError
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||
|
||||
|
||||
blueprint = Blueprint("engine_enrollment", __name__, url_prefix="/api/enrollment")
|
||||
blueprint = Blueprint("engine_enrollment", __name__)
|
||||
|
||||
|
||||
def register(app: Flask) -> None:
|
||||
"""Attach enrollment routes to *app*.
|
||||
|
||||
Implementation will be ported during later migration phases.
|
||||
"""
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach enrollment routes to *app*."""
|
||||
|
||||
if "engine_enrollment" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
|
||||
def enrollment_request() -> object:
|
||||
services: EngineServiceContainer = current_app.extensions["engine_services"]
|
||||
payload = request.get_json(force=True, silent=True)
|
||||
builder = EnrollmentRequestBuilder().with_payload(payload).with_service_context(
|
||||
request.headers.get(AGENT_CONTEXT_HEADER)
|
||||
)
|
||||
try:
|
||||
normalized = builder.build()
|
||||
result = services.enrollment_service.request_enrollment(
|
||||
normalized,
|
||||
remote_addr=_remote_addr(),
|
||||
)
|
||||
except EnrollmentValidationError as exc:
|
||||
response = jsonify(exc.to_response())
|
||||
response.status_code = exc.http_status
|
||||
if exc.retry_after is not None:
|
||||
response.headers["Retry-After"] = f"{int(exc.retry_after)}"
|
||||
return response
|
||||
|
||||
response_payload = {
|
||||
"status": result.status,
|
||||
"approval_reference": result.approval_reference,
|
||||
"server_nonce": result.server_nonce,
|
||||
"poll_after_ms": result.poll_after_ms,
|
||||
"server_certificate": result.server_certificate,
|
||||
"signing_key": result.signing_key,
|
||||
}
|
||||
response = jsonify(response_payload)
|
||||
response.status_code = result.http_status
|
||||
if result.retry_after is not None:
|
||||
response.headers["Retry-After"] = f"{int(result.retry_after)}"
|
||||
return response
|
||||
|
||||
|
||||
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
|
||||
def enrollment_poll() -> object:
|
||||
services: EngineServiceContainer = current_app.extensions["engine_services"]
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
approval_reference = str(payload.get("approval_reference") or "").strip()
|
||||
client_nonce = str(payload.get("client_nonce") or "").strip()
|
||||
proof_sig = str(payload.get("proof_sig") or "").strip()
|
||||
|
||||
try:
|
||||
result = services.enrollment_service.poll_enrollment(
|
||||
approval_reference=approval_reference,
|
||||
client_nonce_b64=client_nonce,
|
||||
proof_signature_b64=proof_sig,
|
||||
)
|
||||
except EnrollmentValidationError as exc:
|
||||
return jsonify(exc.to_response()), exc.http_status
|
||||
|
||||
body = {"status": result.status}
|
||||
if result.poll_after_ms is not None:
|
||||
body["poll_after_ms"] = result.poll_after_ms
|
||||
if result.reason:
|
||||
body["reason"] = result.reason
|
||||
if result.detail:
|
||||
body["detail"] = result.detail
|
||||
if result.tokens:
|
||||
body.update(
|
||||
{
|
||||
"guid": result.tokens.guid.value,
|
||||
"access_token": result.tokens.access_token,
|
||||
"refresh_token": result.tokens.refresh_token,
|
||||
"token_type": result.tokens.token_type,
|
||||
"expires_in": result.tokens.expires_in,
|
||||
"server_certificate": result.server_certificate or "",
|
||||
"signing_key": result.signing_key or "",
|
||||
}
|
||||
)
|
||||
else:
|
||||
if result.server_certificate:
|
||||
body["server_certificate"] = result.server_certificate
|
||||
if result.signing_key:
|
||||
body["signing_key"] = result.signing_key
|
||||
|
||||
return jsonify(body), result.http_status
|
||||
|
||||
|
||||
def _remote_addr() -> str:
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return (request.remote_addr or "unknown").strip()
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint", "enrollment_request", "enrollment_poll"]
|
||||
|
||||
@@ -1,21 +1,26 @@
|
||||
"""Health check HTTP interface placeholders for the Engine."""
|
||||
"""Health check HTTP interface for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask
|
||||
from flask import Blueprint, Flask, jsonify
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
blueprint = Blueprint("engine_health", __name__)
|
||||
|
||||
|
||||
def register(app: Flask) -> None:
|
||||
"""Attach health-related routes to *app*.
|
||||
|
||||
Routes will be populated in later migration phases.
|
||||
"""
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach health-related routes to *app*."""
|
||||
|
||||
if "engine_health" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@blueprint.route("/health", methods=["GET"])
|
||||
def health() -> object:
|
||||
"""Return a basic liveness response."""
|
||||
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
|
||||
@@ -1,21 +1,52 @@
|
||||
"""Token management HTTP interface placeholders for the Engine."""
|
||||
"""Token management HTTP interface for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request
|
||||
|
||||
from Data.Engine.builders.device_auth import RefreshTokenRequestBuilder
|
||||
from Data.Engine.domain.device_auth import DeviceAuthFailure
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
from Data.Engine.services import TokenRefreshError
|
||||
|
||||
blueprint = Blueprint("engine_tokens", __name__)
|
||||
|
||||
|
||||
blueprint = Blueprint("engine_tokens", __name__, url_prefix="/api/tokens")
|
||||
|
||||
|
||||
def register(app: Flask) -> None:
|
||||
"""Attach token management routes to *app*.
|
||||
|
||||
Implementation will be introduced as authentication services are migrated.
|
||||
"""
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach token management routes to *app*."""
|
||||
|
||||
if "engine_tokens" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
@blueprint.route("/api/agent/token/refresh", methods=["POST"])
|
||||
def refresh_token() -> object:
|
||||
services: EngineServiceContainer = current_app.extensions["engine_services"]
|
||||
builder = (
|
||||
RefreshTokenRequestBuilder()
|
||||
.with_payload(request.get_json(force=True, silent=True))
|
||||
.with_http_method(request.method)
|
||||
.with_htu(request.url)
|
||||
.with_dpop_proof(request.headers.get("DPoP"))
|
||||
)
|
||||
try:
|
||||
refresh_request = builder.build()
|
||||
except DeviceAuthFailure as exc:
|
||||
payload = exc.to_dict()
|
||||
return jsonify(payload), exc.http_status
|
||||
|
||||
try:
|
||||
response = services.token_service.refresh_access_token(refresh_request)
|
||||
except TokenRefreshError as exc:
|
||||
return jsonify(exc.to_dict()), exc.http_status
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"access_token": response.access_token,
|
||||
"expires_in": response.expires_in,
|
||||
"token_type": response.token_type,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint", "refresh_token"]
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import closing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
@@ -152,6 +153,133 @@ class SQLiteDeviceRepository:
|
||||
|
||||
return self._row_to_record(row)
|
||||
|
||||
def ensure_device_record(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
hostname: str,
|
||||
fingerprint: DeviceFingerprint,
|
||||
) -> DeviceRecord:
|
||||
now_iso = datetime.now(tz=timezone.utc).isoformat()
|
||||
now_ts = int(time.time())
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
|
||||
FROM devices
|
||||
WHERE UPPER(guid) = ?
|
||||
""",
|
||||
(guid.value.upper(),),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if row:
|
||||
stored_fp = (row[4] or "").strip().lower()
|
||||
new_fp = fingerprint.value
|
||||
if not stored_fp:
|
||||
cur.execute(
|
||||
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
|
||||
(new_fp, now_iso, row[0]),
|
||||
)
|
||||
elif stored_fp != new_fp:
|
||||
token_version = self._coerce_int(row[2], default=1) + 1
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE devices
|
||||
SET ssl_key_fingerprint = ?,
|
||||
key_added_at = ?,
|
||||
token_version = ?,
|
||||
status = 'active'
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(new_fp, now_iso, token_version, row[0]),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = ?
|
||||
WHERE guid = ?
|
||||
AND revoked_at IS NULL
|
||||
""",
|
||||
(now_iso, row[0]),
|
||||
)
|
||||
conn.commit()
|
||||
else:
|
||||
resolved_hostname = self._resolve_hostname(cur, hostname, guid)
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO devices (
|
||||
guid,
|
||||
hostname,
|
||||
created_at,
|
||||
last_seen,
|
||||
ssl_key_fingerprint,
|
||||
token_version,
|
||||
status,
|
||||
key_added_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
|
||||
""",
|
||||
(
|
||||
guid.value,
|
||||
resolved_hostname,
|
||||
now_ts,
|
||||
now_ts,
|
||||
fingerprint.value,
|
||||
now_iso,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE UPPER(guid) = ?
|
||||
""",
|
||||
(guid.value.upper(),),
|
||||
)
|
||||
latest = cur.fetchone()
|
||||
|
||||
if not latest:
|
||||
raise RuntimeError("device record could not be ensured")
|
||||
|
||||
record = self._row_to_record(latest)
|
||||
if record is None:
|
||||
raise RuntimeError("device record invalid after ensure")
|
||||
return record
|
||||
|
||||
def record_device_key(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
fingerprint: DeviceFingerprint,
|
||||
added_at: datetime,
|
||||
) -> None:
|
||||
added_iso = added_at.astimezone(timezone.utc).isoformat()
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(str(uuid.uuid4()), guid.value, fingerprint.value, added_iso),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_keys
|
||||
SET retired_at = ?
|
||||
WHERE guid = ?
|
||||
AND ssl_key_fingerprint != ?
|
||||
AND retired_at IS NULL
|
||||
""",
|
||||
(added_iso, guid.value, fingerprint.value),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def _row_to_record(self, row: tuple) -> Optional[DeviceRecord]:
|
||||
try:
|
||||
guid = DeviceGuid(row[0])
|
||||
@@ -181,3 +309,31 @@ class SQLiteDeviceRepository:
|
||||
token_version=max(token_version, 1),
|
||||
status=status,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: object, *, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
def _resolve_hostname(self, cur: sqlite3.Cursor, hostname: str, guid: DeviceGuid) -> str:
|
||||
base = (hostname or "").strip() or guid.value
|
||||
base = base[:253]
|
||||
candidate = base
|
||||
suffix = 1
|
||||
while True:
|
||||
cur.execute(
|
||||
"SELECT guid FROM devices WHERE hostname = ?",
|
||||
(candidate,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return candidate
|
||||
existing = (row[0] or "").strip().upper()
|
||||
if existing == guid.value:
|
||||
return candidate
|
||||
candidate = f"{base}-{suffix}"
|
||||
suffix += 1
|
||||
if suffix > 50:
|
||||
return guid.value
|
||||
|
||||
@@ -78,6 +78,50 @@ class SQLiteEnrollmentRepository:
|
||||
self._log.warning("invalid enrollment code record for code=%s: %s", code_value, exc)
|
||||
return None
|
||||
|
||||
def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]:
|
||||
record_value = (record_id or "").strip()
|
||||
if not record_value:
|
||||
return None
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
WHERE id = ?
|
||||
""",
|
||||
(record_value,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
record = {
|
||||
"id": row[0],
|
||||
"code": row[1],
|
||||
"expires_at": row[2],
|
||||
"used_at": row[3],
|
||||
"used_by_guid": row[4],
|
||||
"max_uses": row[5],
|
||||
"use_count": row[6],
|
||||
"last_used_at": row[7],
|
||||
}
|
||||
|
||||
try:
|
||||
return EnrollmentCode.from_mapping(record)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc)
|
||||
return None
|
||||
|
||||
def update_install_code_usage(
|
||||
self,
|
||||
record_id: str,
|
||||
@@ -135,6 +179,53 @@ class SQLiteEnrollmentRepository:
|
||||
return None
|
||||
return self._fetch_device_approval("id = ?", (record_value,))
|
||||
|
||||
def fetch_pending_approval_by_fingerprint(
|
||||
self, fingerprint: DeviceFingerprint
|
||||
) -> Optional[EnrollmentApproval]:
|
||||
return self._fetch_device_approval(
|
||||
"ssl_key_fingerprint_claimed = ? AND status = 'pending'",
|
||||
(fingerprint.value,),
|
||||
)
|
||||
|
||||
def update_pending_approval(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
hostname: str,
|
||||
guid: Optional[DeviceGuid],
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
updated_at: datetime,
|
||||
) -> None:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET hostname_claimed = ?,
|
||||
guid = ?,
|
||||
enrollment_code_id = ?,
|
||||
client_nonce = ?,
|
||||
server_nonce = ?,
|
||||
agent_pubkey_der = ?,
|
||||
updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
hostname,
|
||||
guid.value if guid else None,
|
||||
enrollment_code_id,
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
self._isoformat(updated_at),
|
||||
record_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def create_device_approval(
|
||||
self,
|
||||
*,
|
||||
@@ -143,8 +234,8 @@ class SQLiteEnrollmentRepository:
|
||||
claimed_hostname: str,
|
||||
claimed_fingerprint: DeviceFingerprint,
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce: bytes,
|
||||
server_nonce: bytes,
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
created_at: datetime,
|
||||
status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
|
||||
@@ -183,8 +274,8 @@ class SQLiteEnrollmentRepository:
|
||||
status.value,
|
||||
created_iso,
|
||||
created_iso,
|
||||
client_nonce,
|
||||
server_nonce,
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
),
|
||||
)
|
||||
@@ -244,7 +335,10 @@ class SQLiteEnrollmentRepository:
|
||||
created_at,
|
||||
updated_at,
|
||||
status,
|
||||
approved_by_user_id
|
||||
approved_by_user_id,
|
||||
client_nonce,
|
||||
server_nonce,
|
||||
agent_pubkey_der
|
||||
FROM device_approvals
|
||||
WHERE {where}
|
||||
""",
|
||||
@@ -266,6 +360,9 @@ class SQLiteEnrollmentRepository:
|
||||
"updated_at": row[7],
|
||||
"status": row[8],
|
||||
"approved_by_user_id": row[9],
|
||||
"client_nonce": row[10],
|
||||
"server_nonce": row[11],
|
||||
"agent_pubkey_der": row[12],
|
||||
}
|
||||
|
||||
try:
|
||||
|
||||
@@ -78,6 +78,35 @@ class SQLiteRefreshTokenRepository:
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
guid: DeviceGuid,
|
||||
token_hash: str,
|
||||
created_at: datetime,
|
||||
expires_at: Optional[datetime],
|
||||
) -> None:
|
||||
created_iso = self._isoformat(created_at)
|
||||
expires_iso = self._isoformat(expires_at) if expires_at else None
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO refresh_tokens (
|
||||
id,
|
||||
guid,
|
||||
token_hash,
|
||||
created_at,
|
||||
expires_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(record_id, guid.value, token_hash, created_iso, expires_iso),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def _row_to_record(self, row: tuple) -> Optional[RefreshTokenRecord]:
|
||||
try:
|
||||
guid = DeviceGuid(row[1])
|
||||
|
||||
139
Data/Engine/runtime.py
Normal file
139
Data/Engine/runtime.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Runtime filesystem helpers for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
__all__ = [
|
||||
"agent_certificates_path",
|
||||
"ensure_agent_certificates_dir",
|
||||
"ensure_certificates_dir",
|
||||
"ensure_runtime_dir",
|
||||
"ensure_server_certificates_dir",
|
||||
"project_root",
|
||||
"runtime_path",
|
||||
"server_certificates_path",
|
||||
]
|
||||
|
||||
|
||||
def _env_path(name: str) -> Optional[Path]:
|
||||
value = os.environ.get(name)
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return Path(value).expanduser().resolve()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def project_root() -> Path:
|
||||
env = _env_path("BOREALIS_PROJECT_ROOT")
|
||||
if env:
|
||||
return env
|
||||
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "Borealis.ps1").exists() or (parent / ".git").is_dir():
|
||||
return parent
|
||||
|
||||
try:
|
||||
return current.parents[1]
|
||||
except IndexError:
|
||||
return current.parent
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def server_runtime_root() -> Path:
|
||||
env = _env_path("BOREALIS_ENGINE_ROOT") or _env_path("BOREALIS_SERVER_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = project_root() / "Engine"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
def runtime_path(*parts: str) -> Path:
|
||||
return server_runtime_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_runtime_dir(*parts: str) -> Path:
|
||||
path = runtime_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def certificates_root() -> Path:
|
||||
env = _env_path("BOREALIS_CERTIFICATES_ROOT") or _env_path("BOREALIS_CERT_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = project_root() / "Certificates"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
for name in ("Server", "Agent"):
|
||||
try:
|
||||
(root / name).mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
return root
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def server_certificates_root() -> Path:
|
||||
env = _env_path("BOREALIS_SERVER_CERT_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = certificates_root() / "Server"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def agent_certificates_root() -> Path:
|
||||
env = _env_path("BOREALIS_AGENT_CERT_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = certificates_root() / "Agent"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
def certificates_path(*parts: str) -> Path:
|
||||
return certificates_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_certificates_dir(*parts: str) -> Path:
|
||||
path = certificates_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def server_certificates_path(*parts: str) -> Path:
|
||||
return server_certificates_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_server_certificates_dir(*parts: str) -> Path:
|
||||
path = server_certificates_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def agent_certificates_path(*parts: str) -> Path:
|
||||
return agent_certificates_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_agent_certificates_dir(*parts: str) -> Path:
|
||||
path = agent_certificates_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
@@ -10,6 +10,14 @@ from .auth import (
|
||||
TokenRefreshErrorCode,
|
||||
TokenService,
|
||||
)
|
||||
from .enrollment import (
|
||||
EnrollmentRequestResult,
|
||||
EnrollmentService,
|
||||
EnrollmentStatus,
|
||||
EnrollmentTokenBundle,
|
||||
EnrollmentValidationError,
|
||||
PollingResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthService",
|
||||
@@ -18,4 +26,10 @@ __all__ = [
|
||||
"TokenRefreshError",
|
||||
"TokenRefreshErrorCode",
|
||||
"TokenService",
|
||||
"EnrollmentService",
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"EnrollmentValidationError",
|
||||
"PollingResult",
|
||||
]
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .device_auth_service import DeviceAuthService, DeviceRecord
|
||||
from .dpop import DPoPReplayError, DPoPVerificationError, DPoPValidator
|
||||
from .jwt_service import JWTService, load_service as load_jwt_service
|
||||
from .token_service import (
|
||||
RefreshTokenRecord,
|
||||
TokenRefreshError,
|
||||
@@ -13,6 +15,11 @@ from .token_service import (
|
||||
__all__ = [
|
||||
"DeviceAuthService",
|
||||
"DeviceRecord",
|
||||
"DPoPReplayError",
|
||||
"DPoPVerificationError",
|
||||
"DPoPValidator",
|
||||
"JWTService",
|
||||
"load_jwt_service",
|
||||
"RefreshTokenRecord",
|
||||
"TokenRefreshError",
|
||||
"TokenRefreshErrorCode",
|
||||
|
||||
105
Data/Engine/services/auth/dpop.py
Normal file
105
Data/Engine/services/auth/dpop.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""DPoP proof validation for Engine services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict, Optional
|
||||
|
||||
import jwt
|
||||
|
||||
__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"]
|
||||
|
||||
|
||||
_DP0P_MAX_SKEW = 300.0
|
||||
|
||||
|
||||
class DPoPVerificationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DPoPReplayError(DPoPVerificationError):
|
||||
pass
|
||||
|
||||
|
||||
class DPoPValidator:
|
||||
def __init__(self) -> None:
|
||||
self._observed_jti: Dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def verify(
|
||||
self,
|
||||
method: str,
|
||||
htu: str,
|
||||
proof: str,
|
||||
access_token: Optional[str] = None,
|
||||
) -> str:
|
||||
if not proof:
|
||||
raise DPoPVerificationError("DPoP proof missing")
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(proof)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP header") from exc
|
||||
|
||||
jwk = header.get("jwk")
|
||||
alg = header.get("alg")
|
||||
if not jwk or not isinstance(jwk, dict):
|
||||
raise DPoPVerificationError("missing jwk in DPoP header")
|
||||
if alg not in ("EdDSA", "ES256", "ES384", "ES512"):
|
||||
raise DPoPVerificationError(f"unsupported DPoP alg {alg}")
|
||||
|
||||
try:
|
||||
key = jwt.PyJWK(jwk)
|
||||
public_key = key.key
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid jwk in DPoP header") from exc
|
||||
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
proof,
|
||||
public_key,
|
||||
algorithms=[alg],
|
||||
options={"require": ["htm", "htu", "jti", "iat"]},
|
||||
)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP signature") from exc
|
||||
|
||||
htm = claims.get("htm")
|
||||
proof_htu = claims.get("htu")
|
||||
jti = claims.get("jti")
|
||||
iat = claims.get("iat")
|
||||
ath = claims.get("ath")
|
||||
|
||||
if not isinstance(htm, str) or htm.lower() != method.lower():
|
||||
raise DPoPVerificationError("DPoP htm mismatch")
|
||||
if not isinstance(proof_htu, str) or proof_htu != htu:
|
||||
raise DPoPVerificationError("DPoP htu mismatch")
|
||||
if not isinstance(jti, str):
|
||||
raise DPoPVerificationError("DPoP jti missing")
|
||||
if not isinstance(iat, (int, float)):
|
||||
raise DPoPVerificationError("DPoP iat missing")
|
||||
|
||||
now = time.time()
|
||||
if abs(now - float(iat)) > _DP0P_MAX_SKEW:
|
||||
raise DPoPVerificationError("DPoP proof outside allowed skew")
|
||||
|
||||
if ath and access_token:
|
||||
expected_ath = jwt.utils.base64url_encode(
|
||||
hashlib.sha256(access_token.encode("utf-8")).digest()
|
||||
).decode("ascii")
|
||||
if expected_ath != ath:
|
||||
raise DPoPVerificationError("DPoP ath mismatch")
|
||||
|
||||
with self._lock:
|
||||
expiry = self._observed_jti.get(jti)
|
||||
if expiry and expiry > now:
|
||||
raise DPoPReplayError("DPoP proof replay detected")
|
||||
self._observed_jti[jti] = now + _DP0P_MAX_SKEW
|
||||
stale = [key for key, exp in self._observed_jti.items() if exp <= now]
|
||||
for key in stale:
|
||||
self._observed_jti.pop(key, None)
|
||||
|
||||
thumbprint = jwt.PyJWK(jwk).thumbprint()
|
||||
return thumbprint.decode("ascii")
|
||||
124
Data/Engine/services/auth/jwt_service.py
Normal file
124
Data/Engine/services/auth/jwt_service.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""JWT issuance utilities for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from Data.Engine.runtime import ensure_runtime_dir, runtime_path
|
||||
|
||||
__all__ = ["JWTService", "load_service"]
|
||||
|
||||
|
||||
_KEY_DIR = runtime_path("auth_keys")
|
||||
_KEY_FILE = _KEY_DIR / "engine-jwt-ed25519.key"
|
||||
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-jwt-ed25519.key"
|
||||
|
||||
|
||||
class JWTService:
|
||||
def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str) -> None:
|
||||
self._private_key = private_key
|
||||
self._public_key = private_key.public_key()
|
||||
self._key_id = key_id
|
||||
|
||||
@property
|
||||
def key_id(self) -> str:
|
||||
return self._key_id
|
||||
|
||||
def issue_access_token(
|
||||
self,
|
||||
guid: str,
|
||||
ssl_key_fingerprint: str,
|
||||
token_version: int,
|
||||
expires_in: int = 900,
|
||||
extra_claims: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
now = int(time.time())
|
||||
payload: Dict[str, Any] = {
|
||||
"sub": f"device:{guid}",
|
||||
"guid": guid,
|
||||
"ssl_key_fingerprint": ssl_key_fingerprint,
|
||||
"token_version": int(token_version),
|
||||
"iat": now,
|
||||
"nbf": now,
|
||||
"exp": now + int(expires_in),
|
||||
}
|
||||
if extra_claims:
|
||||
payload.update(extra_claims)
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
self._private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
),
|
||||
algorithm="EdDSA",
|
||||
headers={"kid": self._key_id},
|
||||
)
|
||||
return token
|
||||
|
||||
def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]:
|
||||
options = {"require": ["exp", "iat", "sub"]}
|
||||
public_pem = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
return jwt.decode(
|
||||
token,
|
||||
public_pem,
|
||||
algorithms=["EdDSA"],
|
||||
audience=audience,
|
||||
options=options,
|
||||
)
|
||||
|
||||
def public_jwk(self) -> Dict[str, Any]:
|
||||
public_bytes = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw,
|
||||
)
|
||||
jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii")
|
||||
return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x}
|
||||
|
||||
|
||||
def load_service() -> JWTService:
|
||||
private_key = _load_or_create_private_key()
|
||||
public_bytes = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
key_id = hashlib.sha256(public_bytes).hexdigest()[:16]
|
||||
return JWTService(private_key, key_id)
|
||||
|
||||
|
||||
def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey:
|
||||
ensure_runtime_dir("auth_keys")
|
||||
|
||||
if _KEY_FILE.exists():
|
||||
with _KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
if _LEGACY_KEY_FILE.exists():
|
||||
with _LEGACY_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
_KEY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
with _KEY_FILE.open("wb") as fh:
|
||||
fh.write(pem)
|
||||
try:
|
||||
if hasattr(_KEY_FILE, "chmod"):
|
||||
_KEY_FILE.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
return private_key
|
||||
119
Data/Engine/services/container.py
Normal file
119
Data/Engine/services/container.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Service container assembly for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from Data.Engine.config import EngineSettings
|
||||
from Data.Engine.repositories.sqlite import (
|
||||
SQLiteConnectionFactory,
|
||||
SQLiteDeviceRepository,
|
||||
SQLiteEnrollmentRepository,
|
||||
SQLiteRefreshTokenRepository,
|
||||
)
|
||||
from Data.Engine.services.auth import (
|
||||
DeviceAuthService,
|
||||
DPoPValidator,
|
||||
JWTService,
|
||||
TokenService,
|
||||
load_jwt_service,
|
||||
)
|
||||
from Data.Engine.services.crypto.signing import ScriptSigner, load_signer
|
||||
from Data.Engine.services.enrollment import EnrollmentService
|
||||
from Data.Engine.services.enrollment.nonce_cache import NonceCache
|
||||
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
|
||||
|
||||
__all__ = ["EngineServiceContainer", "build_service_container"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EngineServiceContainer:
|
||||
device_auth: DeviceAuthService
|
||||
token_service: TokenService
|
||||
enrollment_service: EnrollmentService
|
||||
jwt_service: JWTService
|
||||
dpop_validator: DPoPValidator
|
||||
|
||||
|
||||
def build_service_container(
|
||||
settings: EngineSettings,
|
||||
*,
|
||||
db_factory: SQLiteConnectionFactory,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> EngineServiceContainer:
|
||||
log = logger or logging.getLogger("borealis.engine.services")
|
||||
|
||||
device_repo = SQLiteDeviceRepository(db_factory, logger=log.getChild("devices"))
|
||||
token_repo = SQLiteRefreshTokenRepository(db_factory, logger=log.getChild("tokens"))
|
||||
enrollment_repo = SQLiteEnrollmentRepository(db_factory, logger=log.getChild("enrollment"))
|
||||
|
||||
jwt_service = load_jwt_service()
|
||||
dpop_validator = DPoPValidator()
|
||||
rate_limiter = SlidingWindowRateLimiter()
|
||||
|
||||
token_service = TokenService(
|
||||
refresh_token_repository=token_repo,
|
||||
device_repository=device_repo,
|
||||
jwt_service=jwt_service,
|
||||
dpop_validator=dpop_validator,
|
||||
logger=log.getChild("token_service"),
|
||||
)
|
||||
|
||||
enrollment_service = EnrollmentService(
|
||||
device_repository=device_repo,
|
||||
enrollment_repository=enrollment_repo,
|
||||
token_repository=token_repo,
|
||||
jwt_service=jwt_service,
|
||||
tls_bundle_loader=_tls_bundle_loader(settings),
|
||||
ip_rate_limiter=SlidingWindowRateLimiter(),
|
||||
fingerprint_rate_limiter=SlidingWindowRateLimiter(),
|
||||
nonce_cache=NonceCache(),
|
||||
script_signer=_load_script_signer(log),
|
||||
logger=log.getChild("enrollment"),
|
||||
)
|
||||
|
||||
device_auth = DeviceAuthService(
|
||||
device_repository=device_repo,
|
||||
jwt_service=jwt_service,
|
||||
logger=log.getChild("device_auth"),
|
||||
rate_limiter=rate_limiter,
|
||||
dpop_validator=dpop_validator,
|
||||
)
|
||||
|
||||
return EngineServiceContainer(
|
||||
device_auth=device_auth,
|
||||
token_service=token_service,
|
||||
enrollment_service=enrollment_service,
|
||||
jwt_service=jwt_service,
|
||||
dpop_validator=dpop_validator,
|
||||
)
|
||||
|
||||
|
||||
def _tls_bundle_loader(settings: EngineSettings) -> Callable[[], str]:
|
||||
candidates = [
|
||||
Path(os.getenv("BOREALIS_TLS_BUNDLE", "")),
|
||||
settings.project_root / "Certificates" / "Server" / "borealis-server-bundle.pem",
|
||||
]
|
||||
|
||||
def loader() -> str:
|
||||
for candidate in candidates:
|
||||
if candidate and candidate.is_file():
|
||||
try:
|
||||
return candidate.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
continue
|
||||
return ""
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
def _load_script_signer(logger: logging.Logger) -> Optional[ScriptSigner]:
|
||||
try:
|
||||
return load_signer()
|
||||
except Exception as exc:
|
||||
logger.warning("script signer unavailable: %s", exc)
|
||||
return None
|
||||
75
Data/Engine/services/crypto/signing.py
Normal file
75
Data/Engine/services/crypto/signing.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Script signing utilities for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from Data.Engine.integrations.crypto.keys import base64_from_spki_der
|
||||
from Data.Engine.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path
|
||||
|
||||
__all__ = ["ScriptSigner", "load_signer"]
|
||||
|
||||
|
||||
_KEY_DIR = server_certificates_path("Code-Signing")
|
||||
_SIGNING_KEY_FILE = _KEY_DIR / "engine-script-ed25519.key"
|
||||
_SIGNING_PUB_FILE = _KEY_DIR / "engine-script-ed25519.pub"
|
||||
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-script-ed25519.key"
|
||||
_LEGACY_PUB_FILE = runtime_path("keys") / "borealis-script-ed25519.pub"
|
||||
|
||||
|
||||
class ScriptSigner:
|
||||
def __init__(self, private_key: ed25519.Ed25519PrivateKey) -> None:
|
||||
self._private = private_key
|
||||
self._public = private_key.public_key()
|
||||
|
||||
def sign(self, payload: bytes) -> bytes:
|
||||
return self._private.sign(payload)
|
||||
|
||||
def public_spki_der(self) -> bytes:
|
||||
return self._public.public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def public_base64_spki(self) -> str:
|
||||
return base64_from_spki_der(self.public_spki_der())
|
||||
|
||||
|
||||
def load_signer() -> ScriptSigner:
|
||||
private_key = _load_or_create()
|
||||
return ScriptSigner(private_key)
|
||||
|
||||
|
||||
def _load_or_create() -> ed25519.Ed25519PrivateKey:
|
||||
ensure_server_certificates_dir("Code-Signing")
|
||||
|
||||
if _SIGNING_KEY_FILE.exists():
|
||||
with _SIGNING_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
if _LEGACY_KEY_FILE.exists():
|
||||
with _LEGACY_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
_KEY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
_SIGNING_KEY_FILE.write_bytes(pem)
|
||||
try:
|
||||
if hasattr(_SIGNING_KEY_FILE, "chmod"):
|
||||
_SIGNING_KEY_FILE.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
pub_der = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
_SIGNING_PUB_FILE.write_bytes(pub_der)
|
||||
|
||||
return private_key
|
||||
21
Data/Engine/services/enrollment/__init__.py
Normal file
21
Data/Engine/services/enrollment/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Enrollment services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .enrollment_service import (
|
||||
EnrollmentRequestResult,
|
||||
EnrollmentService,
|
||||
EnrollmentStatus,
|
||||
EnrollmentTokenBundle,
|
||||
EnrollmentValidationError,
|
||||
PollingResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentService",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"EnrollmentValidationError",
|
||||
"PollingResult",
|
||||
]
|
||||
487
Data/Engine/services/enrollment/enrollment_service.py
Normal file
487
Data/Engine/services/enrollment/enrollment_service.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""Enrollment workflow orchestration for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable, Optional, Protocol
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
from Data.Engine.builders.device_enrollment import EnrollmentRequestInput
|
||||
from Data.Engine.domain.device_auth import (
|
||||
DeviceFingerprint,
|
||||
DeviceGuid,
|
||||
sanitize_service_context,
|
||||
)
|
||||
from Data.Engine.domain.device_enrollment import (
|
||||
EnrollmentApproval,
|
||||
EnrollmentApprovalStatus,
|
||||
EnrollmentCode,
|
||||
)
|
||||
from Data.Engine.services.auth.device_auth_service import DeviceRecord
|
||||
from Data.Engine.services.auth.token_service import JWTIssuer
|
||||
from Data.Engine.services.enrollment.errors import EnrollmentValidationError
|
||||
from Data.Engine.services.enrollment.nonce_cache import NonceCache
|
||||
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentService",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"PollingResult",
|
||||
]
|
||||
|
||||
|
||||
class DeviceRepository(Protocol):
|
||||
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def ensure_device_record(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
hostname: str,
|
||||
fingerprint: DeviceFingerprint,
|
||||
) -> DeviceRecord: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def record_device_key(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
fingerprint: DeviceFingerprint,
|
||||
added_at: datetime,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class EnrollmentRepository(Protocol):
|
||||
def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_install_code_usage(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
use_count_increment: int,
|
||||
last_used_at: datetime,
|
||||
used_by_guid: Optional[DeviceGuid] = None,
|
||||
mark_first_use: bool = False,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_pending_approval_by_fingerprint(
|
||||
self, fingerprint: DeviceFingerprint
|
||||
) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_pending_approval(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
hostname: str,
|
||||
guid: Optional[DeviceGuid],
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
updated_at: datetime,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def create_device_approval(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
reference: str,
|
||||
claimed_hostname: str,
|
||||
claimed_fingerprint: DeviceFingerprint,
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
created_at: datetime,
|
||||
status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> EnrollmentApproval: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_device_approval_status(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
status: EnrollmentApprovalStatus,
|
||||
updated_at: datetime,
|
||||
approved_by: Optional[str] = None,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class RefreshTokenRepository(Protocol):
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
guid: DeviceGuid,
|
||||
token_hash: str,
|
||||
created_at: datetime,
|
||||
expires_at: Optional[datetime],
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class ScriptSigner(Protocol):
|
||||
def public_base64_spki(self) -> str: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class EnrollmentStatus(str):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
EXPIRED = "expired"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentTokenBundle:
|
||||
guid: DeviceGuid
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_in: int
|
||||
token_type: str = "Bearer"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentRequestResult:
|
||||
status: EnrollmentStatus
|
||||
approval_reference: Optional[str] = None
|
||||
server_nonce: Optional[str] = None
|
||||
poll_after_ms: Optional[int] = None
|
||||
server_certificate: str
|
||||
signing_key: str
|
||||
http_status: int = 200
|
||||
retry_after: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PollingResult:
|
||||
status: EnrollmentStatus
|
||||
http_status: int
|
||||
poll_after_ms: Optional[int] = None
|
||||
reason: Optional[str] = None
|
||||
detail: Optional[str] = None
|
||||
tokens: Optional[EnrollmentTokenBundle] = None
|
||||
server_certificate: Optional[str] = None
|
||||
signing_key: Optional[str] = None
|
||||
|
||||
|
||||
class EnrollmentService:
|
||||
"""Coordinate the Borealis device enrollment handshake."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
device_repository: DeviceRepository,
|
||||
enrollment_repository: EnrollmentRepository,
|
||||
token_repository: RefreshTokenRepository,
|
||||
jwt_service: JWTIssuer,
|
||||
tls_bundle_loader: Callable[[], str],
|
||||
ip_rate_limiter: SlidingWindowRateLimiter,
|
||||
fingerprint_rate_limiter: SlidingWindowRateLimiter,
|
||||
nonce_cache: NonceCache,
|
||||
script_signer: Optional[ScriptSigner] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._devices = device_repository
|
||||
self._enrollment = enrollment_repository
|
||||
self._tokens = token_repository
|
||||
self._jwt = jwt_service
|
||||
self._load_tls_bundle = tls_bundle_loader
|
||||
self._ip_rate_limiter = ip_rate_limiter
|
||||
self._fp_rate_limiter = fingerprint_rate_limiter
|
||||
self._nonce_cache = nonce_cache
|
||||
self._signer = script_signer
|
||||
self._log = logger or logging.getLogger("borealis.engine.enrollment")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
def request_enrollment(
|
||||
self,
|
||||
payload: EnrollmentRequestInput,
|
||||
*,
|
||||
remote_addr: str,
|
||||
) -> EnrollmentRequestResult:
|
||||
context_hint = sanitize_service_context(payload.service_context)
|
||||
self._log.info(
|
||||
"enrollment-request ip=%s host=%s code_mask=%s", remote_addr, payload.hostname, self._mask_code(payload.enrollment_code)
|
||||
)
|
||||
|
||||
self._enforce_rate_limit(self._ip_rate_limiter, f"ip:{remote_addr}")
|
||||
self._enforce_rate_limit(self._fp_rate_limiter, f"fp:{payload.fingerprint.value}")
|
||||
|
||||
install_code = self._enrollment.fetch_install_code(payload.enrollment_code)
|
||||
reuse_guid = self._determine_reuse_guid(install_code, payload.fingerprint)
|
||||
|
||||
server_nonce_bytes = secrets.token_bytes(32)
|
||||
server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii")
|
||||
|
||||
now = self._now()
|
||||
approval = self._enrollment.fetch_pending_approval_by_fingerprint(payload.fingerprint)
|
||||
if approval:
|
||||
self._enrollment.update_pending_approval(
|
||||
approval.record_id,
|
||||
hostname=payload.hostname,
|
||||
guid=reuse_guid,
|
||||
enrollment_code_id=install_code.identifier if install_code else None,
|
||||
client_nonce_b64=payload.client_nonce_b64,
|
||||
server_nonce_b64=server_nonce_b64,
|
||||
agent_pubkey_der=payload.agent_public_key_der,
|
||||
updated_at=now,
|
||||
)
|
||||
approval_reference = approval.reference
|
||||
else:
|
||||
record_id = str(uuid.uuid4())
|
||||
approval_reference = str(uuid.uuid4())
|
||||
approval = self._enrollment.create_device_approval(
|
||||
record_id=record_id,
|
||||
reference=approval_reference,
|
||||
claimed_hostname=payload.hostname,
|
||||
claimed_fingerprint=payload.fingerprint,
|
||||
enrollment_code_id=install_code.identifier if install_code else None,
|
||||
client_nonce_b64=payload.client_nonce_b64,
|
||||
server_nonce_b64=server_nonce_b64,
|
||||
agent_pubkey_der=payload.agent_public_key_der,
|
||||
created_at=now,
|
||||
guid=reuse_guid,
|
||||
)
|
||||
|
||||
signing_key = self._signer.public_base64_spki() if self._signer else ""
|
||||
certificate = self._load_tls_bundle()
|
||||
|
||||
return EnrollmentRequestResult(
|
||||
status=EnrollmentStatus.PENDING,
|
||||
approval_reference=approval.reference,
|
||||
server_nonce=server_nonce_b64,
|
||||
poll_after_ms=3000,
|
||||
server_certificate=certificate,
|
||||
signing_key=signing_key,
|
||||
)
|
||||
|
||||
def poll_enrollment(
|
||||
self,
|
||||
*,
|
||||
approval_reference: str,
|
||||
client_nonce_b64: str,
|
||||
proof_signature_b64: str,
|
||||
) -> PollingResult:
|
||||
if not approval_reference:
|
||||
raise EnrollmentValidationError("approval_reference_required")
|
||||
if not client_nonce_b64:
|
||||
raise EnrollmentValidationError("client_nonce_required")
|
||||
if not proof_signature_b64:
|
||||
raise EnrollmentValidationError("proof_sig_required")
|
||||
|
||||
approval = self._enrollment.fetch_device_approval_by_reference(approval_reference)
|
||||
if approval is None:
|
||||
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=404)
|
||||
|
||||
client_nonce = self._decode_base64(client_nonce_b64, "invalid_client_nonce")
|
||||
server_nonce = self._decode_base64(approval.server_nonce_b64, "server_nonce_invalid")
|
||||
proof_sig = self._decode_base64(proof_signature_b64, "invalid_proof_sig")
|
||||
|
||||
if approval.client_nonce_b64 != client_nonce_b64:
|
||||
raise EnrollmentValidationError("nonce_mismatch")
|
||||
|
||||
self._verify_proof_signature(
|
||||
approval=approval,
|
||||
client_nonce=client_nonce,
|
||||
server_nonce=server_nonce,
|
||||
signature=proof_sig,
|
||||
)
|
||||
|
||||
status = approval.status
|
||||
if status is EnrollmentApprovalStatus.PENDING:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.PENDING,
|
||||
http_status=200,
|
||||
poll_after_ms=5000,
|
||||
)
|
||||
if status is EnrollmentApprovalStatus.DENIED:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.DENIED,
|
||||
http_status=200,
|
||||
reason="operator_denied",
|
||||
)
|
||||
if status is EnrollmentApprovalStatus.EXPIRED:
|
||||
return PollingResult(status=EnrollmentStatus.EXPIRED, http_status=200)
|
||||
if status is EnrollmentApprovalStatus.COMPLETED:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.APPROVED,
|
||||
http_status=200,
|
||||
detail="finalized",
|
||||
)
|
||||
if status is not EnrollmentApprovalStatus.APPROVED:
|
||||
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=400)
|
||||
|
||||
nonce_key = f"{approval.reference}:{proof_signature_b64}"
|
||||
if not self._nonce_cache.consume(nonce_key):
|
||||
raise EnrollmentValidationError("proof_replayed", http_status=409)
|
||||
|
||||
token_bundle = self._finalize_approval(approval)
|
||||
signing_key = self._signer.public_base64_spki() if self._signer else ""
|
||||
certificate = self._load_tls_bundle()
|
||||
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.APPROVED,
|
||||
http_status=200,
|
||||
tokens=token_bundle,
|
||||
server_certificate=certificate,
|
||||
signing_key=signing_key,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internals
|
||||
# ------------------------------------------------------------------
|
||||
def _enforce_rate_limit(
|
||||
self,
|
||||
limiter: SlidingWindowRateLimiter,
|
||||
key: str,
|
||||
*,
|
||||
limit: int = 60,
|
||||
window_seconds: float = 60.0,
|
||||
) -> None:
|
||||
decision = limiter.check(key, limit, window_seconds)
|
||||
if not decision.allowed:
|
||||
raise EnrollmentValidationError(
|
||||
"rate_limited", http_status=429, retry_after=max(decision.retry_after, 1.0)
|
||||
)
|
||||
|
||||
def _determine_reuse_guid(
|
||||
self,
|
||||
install_code: Optional[EnrollmentCode],
|
||||
fingerprint: DeviceFingerprint,
|
||||
) -> Optional[DeviceGuid]:
|
||||
if install_code is None:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
if install_code.is_expired:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
if not install_code.is_exhausted:
|
||||
return None
|
||||
if not install_code.used_by_guid:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
|
||||
existing = self._devices.fetch_by_guid(install_code.used_by_guid)
|
||||
if existing and existing.identity.fingerprint.value == fingerprint.value:
|
||||
return install_code.used_by_guid
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
|
||||
def _finalize_approval(self, approval: EnrollmentApproval) -> EnrollmentTokenBundle:
|
||||
now = self._now()
|
||||
effective_guid = approval.guid or DeviceGuid(str(uuid.uuid4()))
|
||||
device_record = self._devices.ensure_device_record(
|
||||
guid=effective_guid,
|
||||
hostname=approval.claimed_hostname,
|
||||
fingerprint=approval.claimed_fingerprint,
|
||||
)
|
||||
self._devices.record_device_key(
|
||||
guid=effective_guid,
|
||||
fingerprint=approval.claimed_fingerprint,
|
||||
added_at=now,
|
||||
)
|
||||
|
||||
if approval.enrollment_code_id:
|
||||
code = self._enrollment.fetch_install_code_by_id(approval.enrollment_code_id)
|
||||
if code is not None:
|
||||
mark_first = code.used_at is None
|
||||
self._enrollment.update_install_code_usage(
|
||||
approval.enrollment_code_id,
|
||||
use_count_increment=1,
|
||||
last_used_at=now,
|
||||
used_by_guid=effective_guid,
|
||||
mark_first_use=mark_first,
|
||||
)
|
||||
|
||||
refresh_token = secrets.token_urlsafe(48)
|
||||
refresh_id = str(uuid.uuid4())
|
||||
expires_at = now + timedelta(days=30)
|
||||
token_hash = hashlib.sha256(refresh_token.encode("utf-8")).hexdigest()
|
||||
self._tokens.create(
|
||||
record_id=refresh_id,
|
||||
guid=effective_guid,
|
||||
token_hash=token_hash,
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
access_token = self._jwt.issue_access_token(
|
||||
effective_guid.value,
|
||||
device_record.identity.fingerprint.value,
|
||||
max(device_record.token_version, 1),
|
||||
)
|
||||
|
||||
self._enrollment.update_device_approval_status(
|
||||
approval.record_id,
|
||||
status=EnrollmentApprovalStatus.COMPLETED,
|
||||
updated_at=now,
|
||||
guid=effective_guid,
|
||||
)
|
||||
|
||||
return EnrollmentTokenBundle(
|
||||
guid=effective_guid,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=900,
|
||||
)
|
||||
|
||||
def _verify_proof_signature(
|
||||
self,
|
||||
*,
|
||||
approval: EnrollmentApproval,
|
||||
client_nonce: bytes,
|
||||
server_nonce: bytes,
|
||||
signature: bytes,
|
||||
) -> None:
|
||||
message = server_nonce + approval.reference.encode("utf-8") + client_nonce
|
||||
try:
|
||||
public_key = serialization.load_der_public_key(approval.agent_pubkey_der)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError("agent_pubkey_invalid") from exc
|
||||
|
||||
try:
|
||||
public_key.verify(signature, message)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError("invalid_proof") from exc
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64(value: str, error_code: str) -> bytes:
|
||||
try:
|
||||
return base64.b64decode(value, validate=True)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError(error_code) from exc
|
||||
|
||||
@staticmethod
|
||||
def _mask_code(code: str) -> str:
|
||||
trimmed = (code or "").strip()
|
||||
if len(trimmed) <= 6:
|
||||
return "***"
|
||||
return f"{trimmed[:3]}***{trimmed[-3:]}"
|
||||
|
||||
@staticmethod
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
26
Data/Engine/services/enrollment/errors.py
Normal file
26
Data/Engine/services/enrollment/errors.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Error types shared across enrollment components."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
__all__ = ["EnrollmentValidationError"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentValidationError(Exception):
|
||||
"""Raised when enrollment input fails validation."""
|
||||
|
||||
code: str
|
||||
http_status: int = 400
|
||||
retry_after: Optional[float] = None
|
||||
|
||||
def to_response(self) -> dict[str, object]:
|
||||
payload: dict[str, object] = {"error": self.code}
|
||||
if self.retry_after is not None:
|
||||
payload["retry_after"] = self.retry_after
|
||||
return payload
|
||||
|
||||
def __str__(self) -> str: # pragma: no cover - debug helper
|
||||
return f"{self.code} (status={self.http_status})"
|
||||
32
Data/Engine/services/enrollment/nonce_cache.py
Normal file
32
Data/Engine/services/enrollment/nonce_cache.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Nonce replay protection for enrollment workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict
|
||||
|
||||
__all__ = ["NonceCache"]
|
||||
|
||||
|
||||
class NonceCache:
|
||||
"""Track recently observed nonces to prevent replay."""
|
||||
|
||||
def __init__(self, ttl_seconds: float = 300.0) -> None:
|
||||
self._ttl = ttl_seconds
|
||||
self._entries: Dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def consume(self, key: str) -> bool:
|
||||
"""Consume *key* if it has not been seen recently."""
|
||||
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
expiry = self._entries.get(key)
|
||||
if expiry and expiry > now:
|
||||
return False
|
||||
self._entries[key] = now + self._ttl
|
||||
stale = [nonce for nonce, ttl in self._entries.items() if ttl <= now]
|
||||
for nonce in stale:
|
||||
self._entries.pop(nonce, None)
|
||||
return True
|
||||
45
Data/Engine/services/rate_limit.py
Normal file
45
Data/Engine/services/rate_limit.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""In-process rate limiting utilities for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import Deque, Dict
|
||||
|
||||
__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RateLimitDecision:
|
||||
"""Result of a rate limit check."""
|
||||
|
||||
allowed: bool
|
||||
retry_after: float
|
||||
|
||||
|
||||
class SlidingWindowRateLimiter:
|
||||
"""Tiny in-memory sliding window limiter suitable for single-process use."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buckets: Dict[str, Deque[float]] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
bucket = self._buckets.get(key)
|
||||
if bucket is None:
|
||||
bucket = deque()
|
||||
self._buckets[key] = bucket
|
||||
|
||||
while bucket and now - bucket[0] > window_seconds:
|
||||
bucket.popleft()
|
||||
|
||||
if len(bucket) >= limit:
|
||||
retry_after = max(0.0, window_seconds - (now - bucket[0]))
|
||||
return RateLimitDecision(False, retry_after)
|
||||
|
||||
bucket.append(now)
|
||||
return RateLimitDecision(True, 0.0)
|
||||
Reference in New Issue
Block a user