Implement Engine HTTP interfaces for health, enrollment, and tokens

This commit is contained in:
2025-10-22 13:33:15 -06:00
parent 7b5248dfe5
commit 9292cfb280
28 changed files with 1840 additions and 77 deletions

View File

@@ -37,7 +37,7 @@
- 7.2 Write integration tests exercising CRUD against a temporary SQLite file.
- 7.3 Commit when repositories provide the required ports used by services.
8. Recreate HTTP interfaces
[COMPLETED] 8. Recreate HTTP interfaces
- 8.1 Port health/enrollment/token blueprints into `interfaces/http/<feature>/routes.py`, calling Engine services only.
- 8.2 Ensure request validation occurs via builders; response schemas stay aligned with legacy JSON.
- 8.3 Register blueprints through Engine `server.py`; confirm endpoints respond via manual or automated tests.

View File

@@ -25,14 +25,21 @@ The Engine mirrors the legacy defaults so it can boot without additional configu
## Bootstrapping flow
1. `Data/Engine/bootstrapper.py` loads the environment, configures logging, prepares the SQLite connection factory, optionally applies schema migrations, and builds the Flask application via `Data/Engine/server.py`.
2. Placeholder HTTP and Socket.IO registration hooks run so the Engine can start without any migrated routes yet.
3. The resulting runtime object exposes the Flask app, resolved settings, optional Socket.IO server, and the configured database connection factory. `bootstrapper.main()` runs the appropriate server based on whether Socket.IO is present.
2. A service container is assembled (`Data/Engine/services/container.py`) that wires repositories, JWT/DPoP helpers, and Engine services (device auth, token refresh, enrollment). The container is stored on the Flask app for interface modules to consume.
3. HTTP and Socket.IO interfaces register against the new service container. The resulting runtime object exposes the Flask app, resolved settings, optional Socket.IO server, and the configured database connection factory. `bootstrapper.main()` runs the appropriate server based on whether Socket.IO is present.
As migration continues, services, repositories, interfaces, and integrations will live under their respective subpackages while maintaining isolation from the legacy server.
## Interface scaffolding
## HTTP interfaces
The Engine currently exposes placeholder HTTP blueprints under `Data/Engine/interfaces/http/` (agents, enrollment, tokens, admin, and health) so that future commits can drop in real routes without reshaping the bootstrap wiring. WebSocket namespaces follow the same pattern in `Data/Engine/interfaces/ws/`, with feature-oriented modules (e.g., `agents`, `job_management`) registered by `bootstrapper.bootstrap()` when Socket.IO is available. These stubs intentionally contain no business logic yet—they merely ensure the application factory exercises the full wiring path.
The Engine now exposes working HTTP routes alongside the remaining scaffolding:
- `Data/Engine/interfaces/http/health.py` implements `GET /health` for liveness probes.
- `Data/Engine/interfaces/http/tokens.py` ports the refresh-token endpoint (`POST /api/agent/token/refresh`) using the Engine `TokenService` and request builders.
- `Data/Engine/interfaces/http/enrollment.py` handles the enrollment handshake (`/api/agent/enroll/request` and `/api/agent/enroll/poll`) with rate limiting, nonce protection, and repository-backed approvals.
- The admin and agent blueprints remain placeholders until their services migrate.
WebSocket namespaces continue to follow the same pattern in `Data/Engine/interfaces/ws/`, with feature-oriented modules (e.g., `agents`, `job_management`) registered by `bootstrapper.bootstrap()` when Socket.IO is available.
## Authentication services
@@ -43,7 +50,7 @@ Step6 introduces the first real Engine services:
- `Data/Engine/services/auth/device_auth_service.py` ports the legacy `DeviceAuthManager` into a repository-driven service that emits `DeviceAuthContext` instances from the new domain layer.
- `Data/Engine/services/auth/token_service.py` issues refreshed access tokens while enforcing DPoP bindings and repository lookups.
Interfaces will begin consuming these services once the repository adapters land in the next milestone.
Interfaces now consume these services via the shared container, keeping business logic inside the Engine service layer while HTTP modules remain thin request/response translators.
## SQLite repositories

View File

@@ -16,6 +16,7 @@ from .interfaces import (
from .repositories.sqlite import connection as sqlite_connection
from .repositories.sqlite import migrations as sqlite_migrations
from .server import create_app
from .services.container import build_service_container
@dataclass(frozen=True, slots=True)
@@ -45,7 +46,9 @@ def bootstrap() -> EngineRuntime:
logger.info("migrations-skipped")
app = create_app(settings, db_factory=db_factory)
register_http_interfaces(app)
services = build_service_container(settings, db_factory=db_factory, logger=logger.getChild("services"))
app.extensions["engine_services"] = services
register_http_interfaces(app, services)
socketio = create_socket_server(app, settings.socketio)
register_ws_interfaces(socketio)
logger.info("bootstrap-complete")

View File

@@ -2,67 +2,96 @@
from __future__ import annotations
import base64
from dataclasses import dataclass
from typing import Optional
from Data.Engine.domain.device_auth import DeviceFingerprint
from Data.Engine.domain.device_enrollment import EnrollmentRequest, ProofChallenge
from Data.Engine.domain.device_auth import DeviceFingerprint, sanitize_service_context
from Data.Engine.domain.device_enrollment import ProofChallenge
from Data.Engine.integrations.crypto import keys as crypto_keys
from Data.Engine.services.enrollment.errors import EnrollmentValidationError
__all__ = [
"EnrollmentRequestBuilder",
"EnrollmentRequestInput",
"ProofChallengeBuilder",
]
@dataclass(frozen=True, slots=True)
class _EnrollmentPayload:
class EnrollmentRequestInput:
"""Structured enrollment request payload ready for the service layer."""
hostname: str
enrollment_code: str
fingerprint: str
fingerprint: DeviceFingerprint
client_nonce: bytes
server_nonce: bytes
client_nonce_b64: str
agent_public_key_der: bytes
service_context: Optional[str]
class EnrollmentRequestBuilder:
"""Normalize agent enrollment JSON payloads into domain objects."""
def __init__(self) -> None:
self._payload: Optional[_EnrollmentPayload] = None
self._hostname: Optional[str] = None
self._enrollment_code: Optional[str] = None
self._agent_pubkey_b64: Optional[str] = None
self._client_nonce_b64: Optional[str] = None
self._service_context: Optional[str] = None
def with_payload(self, payload: Optional[dict[str, object]]) -> "EnrollmentRequestBuilder":
payload = payload or {}
hostname = str(payload.get("hostname") or "").strip()
enrollment_code = str(payload.get("enrollment_code") or "").strip()
fingerprint = str(payload.get("fingerprint") or "").strip()
client_nonce = self._coerce_bytes(payload.get("client_nonce"))
server_nonce = self._coerce_bytes(payload.get("server_nonce"))
self._payload = _EnrollmentPayload(
hostname=hostname,
enrollment_code=enrollment_code,
fingerprint=fingerprint,
client_nonce=client_nonce,
server_nonce=server_nonce,
)
self._hostname = str(payload.get("hostname") or "").strip()
self._enrollment_code = str(payload.get("enrollment_code") or "").strip()
agent_pubkey = payload.get("agent_pubkey")
self._agent_pubkey_b64 = agent_pubkey if isinstance(agent_pubkey, str) else None
client_nonce = payload.get("client_nonce")
self._client_nonce_b64 = client_nonce if isinstance(client_nonce, str) else None
return self
def build(self) -> EnrollmentRequest:
if not self._payload:
raise ValueError("payload has not been provided")
return EnrollmentRequest.from_payload(
hostname=self._payload.hostname,
enrollment_code=self._payload.enrollment_code,
fingerprint=self._payload.fingerprint,
client_nonce=self._payload.client_nonce,
server_nonce=self._payload.server_nonce,
)
def with_service_context(self, value: Optional[str]) -> "EnrollmentRequestBuilder":
self._service_context = value
return self
@staticmethod
def _coerce_bytes(value: object) -> bytes:
if isinstance(value, (bytes, bytearray)):
return bytes(value)
if isinstance(value, str):
return value.encode("utf-8")
raise ValueError("nonce values must be bytes or base strings")
def build(self) -> EnrollmentRequestInput:
if not self._hostname:
raise EnrollmentValidationError("hostname_required")
if not self._enrollment_code:
raise EnrollmentValidationError("enrollment_code_required")
if not self._agent_pubkey_b64:
raise EnrollmentValidationError("agent_pubkey_required")
if not self._client_nonce_b64:
raise EnrollmentValidationError("client_nonce_required")
try:
agent_pubkey_der = crypto_keys.spki_der_from_base64(self._agent_pubkey_b64)
except Exception as exc: # pragma: no cover - invalid input path
raise EnrollmentValidationError("invalid_agent_pubkey") from exc
if len(agent_pubkey_der) < 10:
raise EnrollmentValidationError("invalid_agent_pubkey")
try:
client_nonce_bytes = base64.b64decode(self._client_nonce_b64, validate=True)
except Exception as exc: # pragma: no cover - invalid input path
raise EnrollmentValidationError("invalid_client_nonce") from exc
if len(client_nonce_bytes) < 16:
raise EnrollmentValidationError("invalid_client_nonce")
fingerprint_value = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
return EnrollmentRequestInput(
hostname=self._hostname,
enrollment_code=self._enrollment_code,
fingerprint=DeviceFingerprint(fingerprint_value),
client_nonce=client_nonce_bytes,
client_nonce_b64=self._client_nonce_b64,
agent_public_key_der=agent_pubkey_der,
service_context=sanitize_service_context(self._service_context),
)
class ProofChallengeBuilder:

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
import base64
from enum import Enum
from typing import Any, Mapping, Optional
@@ -44,6 +45,7 @@ def _require(value: Optional[str], field: str) -> str:
class EnrollmentCode:
"""Installer code metadata loaded from the persistence layer."""
record_id: Optional[str] = None
code: str
expires_at: datetime
max_uses: int
@@ -67,6 +69,7 @@ class EnrollmentCode:
used_by = record.get("used_by_guid")
used_by_guid = DeviceGuid(used_by) if used_by else None
return cls(
record_id=str(record.get("id") or "") or None,
code=_require(record.get("code"), "code"),
expires_at=_parse_iso8601(record.get("expires_at")) or datetime.now(tz=timezone.utc),
max_uses=int(record.get("max_uses") or 1),
@@ -88,6 +91,10 @@ class EnrollmentCode:
def is_expired(self) -> bool:
return self.expires_at <= datetime.now(tz=timezone.utc)
@property
def identifier(self) -> Optional[str]:
return self.record_id
class EnrollmentApprovalStatus(str, Enum):
"""Possible states for a device approval entry."""
@@ -181,6 +188,9 @@ class EnrollmentApproval:
enrollment_code_id: Optional[str]
created_at: datetime
updated_at: datetime
client_nonce_b64: str
server_nonce_b64: str
agent_pubkey_der: bytes
guid: Optional[DeviceGuid] = None
approved_by: Optional[str] = None
@@ -207,6 +217,9 @@ class EnrollmentApproval:
updated_at=_parse_iso8601(record.get("updated_at")) or datetime.now(tz=timezone.utc),
guid=DeviceGuid(guid_raw) if guid_raw else None,
approved_by=(approved_raw or None),
client_nonce_b64=_require(record.get("client_nonce"), "client_nonce"),
server_nonce_b64=_require(record.get("server_nonce"), "server_nonce"),
agent_pubkey_der=bytes(record.get("agent_pubkey_der") or b""),
)
@property
@@ -219,3 +232,11 @@ class EnrollmentApproval:
EnrollmentApprovalStatus.APPROVED,
EnrollmentApprovalStatus.COMPLETED,
}
@property
def client_nonce_bytes(self) -> bytes:
return base64.b64decode(self.client_nonce_b64.encode("ascii"), validate=True)
@property
def server_nonce_bytes(self) -> bytes:
return base64.b64decode(self.server_nonce_b64.encode("ascii"), validate=True)

View File

@@ -0,0 +1,25 @@
"""Crypto integration helpers for the Engine."""
from __future__ import annotations
from .keys import (
base64_from_spki_der,
fingerprint_from_base64_spki,
fingerprint_from_spki_der,
generate_ed25519_keypair,
normalize_base64,
private_key_to_pem,
public_key_to_pem,
spki_der_from_base64,
)
__all__ = [
"base64_from_spki_der",
"fingerprint_from_base64_spki",
"fingerprint_from_spki_der",
"generate_ed25519_keypair",
"normalize_base64",
"private_key_to_pem",
"public_key_to_pem",
"spki_der_from_base64",
]

View File

@@ -0,0 +1,70 @@
"""Key utilities mirrored from the legacy crypto helpers."""
from __future__ import annotations
import base64
import hashlib
import re
from typing import Tuple
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from cryptography.hazmat.primitives.serialization import load_der_public_key
__all__ = [
"base64_from_spki_der",
"fingerprint_from_base64_spki",
"fingerprint_from_spki_der",
"generate_ed25519_keypair",
"normalize_base64",
"private_key_to_pem",
"public_key_to_pem",
"spki_der_from_base64",
]
def generate_ed25519_keypair() -> Tuple[ed25519.Ed25519PrivateKey, bytes]:
private_key = ed25519.Ed25519PrivateKey.generate()
public_key = private_key.public_key().public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return private_key, public_key
def normalize_base64(data: str) -> str:
cleaned = re.sub(r"\s+", "", data or "")
return cleaned.replace("-", "+").replace("_", "/")
def spki_der_from_base64(spki_b64: str) -> bytes:
return base64.b64decode(normalize_base64(spki_b64), validate=True)
def base64_from_spki_der(spki_der: bytes) -> str:
return base64.b64encode(spki_der).decode("ascii")
def fingerprint_from_spki_der(spki_der: bytes) -> str:
digest = hashlib.sha256(spki_der).hexdigest()
return digest.lower()
def fingerprint_from_base64_spki(spki_b64: str) -> str:
return fingerprint_from_spki_der(spki_der_from_base64(spki_b64))
def private_key_to_pem(private_key: ed25519.Ed25519PrivateKey) -> bytes:
return private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
def public_key_to_pem(public_spki_der: bytes) -> bytes:
public_key = load_der_public_key(public_spki_der)
return public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
from flask import Flask
from Data.Engine.services.container import EngineServiceContainer
from . import admin, agents, enrollment, health, tokens
_REGISTRARS = (
@@ -15,14 +17,14 @@ _REGISTRARS = (
)
def register_http_interfaces(app: Flask) -> None:
def register_http_interfaces(app: Flask, services: EngineServiceContainer) -> None:
"""Attach HTTP blueprints to *app*.
The implementation is intentionally minimal for the initial scaffolding.
"""
for registrar in _REGISTRARS:
registrar(app)
registrar(app, services)
__all__ = ["register_http_interfaces"]

View File

@@ -4,11 +4,13 @@ from __future__ import annotations
from flask import Blueprint, Flask
from Data.Engine.services.container import EngineServiceContainer
blueprint = Blueprint("engine_admin", __name__, url_prefix="/api/admin")
def register(app: Flask) -> None:
def register(app: Flask, _services: EngineServiceContainer) -> None:
"""Attach administrative routes to *app*.
Concrete endpoints will be migrated in subsequent phases.

View File

@@ -4,11 +4,13 @@ from __future__ import annotations
from flask import Blueprint, Flask
from Data.Engine.services.container import EngineServiceContainer
blueprint = Blueprint("engine_agents", __name__, url_prefix="/api/agents")
def register(app: Flask) -> None:
def register(app: Flask, _services: EngineServiceContainer) -> None:
"""Attach agent management routes to *app*.
Implementation will be populated as services migrate from the legacy server.

View File

@@ -2,20 +2,110 @@
from __future__ import annotations
from flask import Blueprint, Flask
from flask import Blueprint, Flask, current_app, jsonify, request
from Data.Engine.builders.device_enrollment import EnrollmentRequestBuilder
from Data.Engine.services import EnrollmentValidationError
from Data.Engine.services.container import EngineServiceContainer
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
blueprint = Blueprint("engine_enrollment", __name__, url_prefix="/api/enrollment")
blueprint = Blueprint("engine_enrollment", __name__)
def register(app: Flask) -> None:
"""Attach enrollment routes to *app*.
Implementation will be ported during later migration phases.
"""
def register(app: Flask, _services: EngineServiceContainer) -> None:
"""Attach enrollment routes to *app*."""
if "engine_enrollment" not in app.blueprints:
app.register_blueprint(blueprint)
__all__ = ["register", "blueprint"]
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
def enrollment_request() -> object:
services: EngineServiceContainer = current_app.extensions["engine_services"]
payload = request.get_json(force=True, silent=True)
builder = EnrollmentRequestBuilder().with_payload(payload).with_service_context(
request.headers.get(AGENT_CONTEXT_HEADER)
)
try:
normalized = builder.build()
result = services.enrollment_service.request_enrollment(
normalized,
remote_addr=_remote_addr(),
)
except EnrollmentValidationError as exc:
response = jsonify(exc.to_response())
response.status_code = exc.http_status
if exc.retry_after is not None:
response.headers["Retry-After"] = f"{int(exc.retry_after)}"
return response
response_payload = {
"status": result.status,
"approval_reference": result.approval_reference,
"server_nonce": result.server_nonce,
"poll_after_ms": result.poll_after_ms,
"server_certificate": result.server_certificate,
"signing_key": result.signing_key,
}
response = jsonify(response_payload)
response.status_code = result.http_status
if result.retry_after is not None:
response.headers["Retry-After"] = f"{int(result.retry_after)}"
return response
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
def enrollment_poll() -> object:
services: EngineServiceContainer = current_app.extensions["engine_services"]
payload = request.get_json(force=True, silent=True) or {}
approval_reference = str(payload.get("approval_reference") or "").strip()
client_nonce = str(payload.get("client_nonce") or "").strip()
proof_sig = str(payload.get("proof_sig") or "").strip()
try:
result = services.enrollment_service.poll_enrollment(
approval_reference=approval_reference,
client_nonce_b64=client_nonce,
proof_signature_b64=proof_sig,
)
except EnrollmentValidationError as exc:
return jsonify(exc.to_response()), exc.http_status
body = {"status": result.status}
if result.poll_after_ms is not None:
body["poll_after_ms"] = result.poll_after_ms
if result.reason:
body["reason"] = result.reason
if result.detail:
body["detail"] = result.detail
if result.tokens:
body.update(
{
"guid": result.tokens.guid.value,
"access_token": result.tokens.access_token,
"refresh_token": result.tokens.refresh_token,
"token_type": result.tokens.token_type,
"expires_in": result.tokens.expires_in,
"server_certificate": result.server_certificate or "",
"signing_key": result.signing_key or "",
}
)
else:
if result.server_certificate:
body["server_certificate"] = result.server_certificate
if result.signing_key:
body["signing_key"] = result.signing_key
return jsonify(body), result.http_status
def _remote_addr() -> str:
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
return (request.remote_addr or "unknown").strip()
__all__ = ["register", "blueprint", "enrollment_request", "enrollment_poll"]

View File

@@ -1,21 +1,26 @@
"""Health check HTTP interface placeholders for the Engine."""
"""Health check HTTP interface for the Engine."""
from __future__ import annotations
from flask import Blueprint, Flask
from flask import Blueprint, Flask, jsonify
from Data.Engine.services.container import EngineServiceContainer
blueprint = Blueprint("engine_health", __name__)
def register(app: Flask) -> None:
"""Attach health-related routes to *app*.
Routes will be populated in later migration phases.
"""
def register(app: Flask, _services: EngineServiceContainer) -> None:
"""Attach health-related routes to *app*."""
if "engine_health" not in app.blueprints:
app.register_blueprint(blueprint)
@blueprint.route("/health", methods=["GET"])
def health() -> object:
"""Return a basic liveness response."""
return jsonify({"status": "ok"})
__all__ = ["register", "blueprint"]

View File

@@ -1,21 +1,52 @@
"""Token management HTTP interface placeholders for the Engine."""
"""Token management HTTP interface for the Engine."""
from __future__ import annotations
from flask import Blueprint, Flask
from flask import Blueprint, Flask, current_app, jsonify, request
from Data.Engine.builders.device_auth import RefreshTokenRequestBuilder
from Data.Engine.domain.device_auth import DeviceAuthFailure
from Data.Engine.services.container import EngineServiceContainer
from Data.Engine.services import TokenRefreshError
blueprint = Blueprint("engine_tokens", __name__)
blueprint = Blueprint("engine_tokens", __name__, url_prefix="/api/tokens")
def register(app: Flask) -> None:
"""Attach token management routes to *app*.
Implementation will be introduced as authentication services are migrated.
"""
def register(app: Flask, _services: EngineServiceContainer) -> None:
"""Attach token management routes to *app*."""
if "engine_tokens" not in app.blueprints:
app.register_blueprint(blueprint)
__all__ = ["register", "blueprint"]
@blueprint.route("/api/agent/token/refresh", methods=["POST"])
def refresh_token() -> object:
services: EngineServiceContainer = current_app.extensions["engine_services"]
builder = (
RefreshTokenRequestBuilder()
.with_payload(request.get_json(force=True, silent=True))
.with_http_method(request.method)
.with_htu(request.url)
.with_dpop_proof(request.headers.get("DPoP"))
)
try:
refresh_request = builder.build()
except DeviceAuthFailure as exc:
payload = exc.to_dict()
return jsonify(payload), exc.http_status
try:
response = services.token_service.refresh_access_token(refresh_request)
except TokenRefreshError as exc:
return jsonify(exc.to_dict()), exc.http_status
return jsonify(
{
"access_token": response.access_token,
"expires_in": response.expires_in,
"token_type": response.token_type,
}
)
__all__ = ["register", "blueprint", "refresh_token"]

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import logging
import sqlite3
import time
import uuid
from contextlib import closing
from datetime import datetime, timezone
from typing import Optional
@@ -152,6 +153,133 @@ class SQLiteDeviceRepository:
return self._row_to_record(row)
def ensure_device_record(
self,
*,
guid: DeviceGuid,
hostname: str,
fingerprint: DeviceFingerprint,
) -> DeviceRecord:
now_iso = datetime.now(tz=timezone.utc).isoformat()
now_ts = int(time.time())
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
FROM devices
WHERE UPPER(guid) = ?
""",
(guid.value.upper(),),
)
row = cur.fetchone()
if row:
stored_fp = (row[4] or "").strip().lower()
new_fp = fingerprint.value
if not stored_fp:
cur.execute(
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
(new_fp, now_iso, row[0]),
)
elif stored_fp != new_fp:
token_version = self._coerce_int(row[2], default=1) + 1
cur.execute(
"""
UPDATE devices
SET ssl_key_fingerprint = ?,
key_added_at = ?,
token_version = ?,
status = 'active'
WHERE guid = ?
""",
(new_fp, now_iso, token_version, row[0]),
)
cur.execute(
"""
UPDATE refresh_tokens
SET revoked_at = ?
WHERE guid = ?
AND revoked_at IS NULL
""",
(now_iso, row[0]),
)
conn.commit()
else:
resolved_hostname = self._resolve_hostname(cur, hostname, guid)
cur.execute(
"""
INSERT INTO devices (
guid,
hostname,
created_at,
last_seen,
ssl_key_fingerprint,
token_version,
status,
key_added_at
)
VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
""",
(
guid.value,
resolved_hostname,
now_ts,
now_ts,
fingerprint.value,
now_iso,
),
)
conn.commit()
cur.execute(
"""
SELECT guid, ssl_key_fingerprint, token_version, status
FROM devices
WHERE UPPER(guid) = ?
""",
(guid.value.upper(),),
)
latest = cur.fetchone()
if not latest:
raise RuntimeError("device record could not be ensured")
record = self._row_to_record(latest)
if record is None:
raise RuntimeError("device record invalid after ensure")
return record
def record_device_key(
self,
*,
guid: DeviceGuid,
fingerprint: DeviceFingerprint,
added_at: datetime,
) -> None:
added_iso = added_at.astimezone(timezone.utc).isoformat()
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
VALUES (?, ?, ?, ?)
""",
(str(uuid.uuid4()), guid.value, fingerprint.value, added_iso),
)
cur.execute(
"""
UPDATE device_keys
SET retired_at = ?
WHERE guid = ?
AND ssl_key_fingerprint != ?
AND retired_at IS NULL
""",
(added_iso, guid.value, fingerprint.value),
)
conn.commit()
def _row_to_record(self, row: tuple) -> Optional[DeviceRecord]:
try:
guid = DeviceGuid(row[0])
@@ -181,3 +309,31 @@ class SQLiteDeviceRepository:
token_version=max(token_version, 1),
status=status,
)
@staticmethod
def _coerce_int(value: object, *, default: int = 0) -> int:
try:
return int(value)
except Exception:
return default
def _resolve_hostname(self, cur: sqlite3.Cursor, hostname: str, guid: DeviceGuid) -> str:
base = (hostname or "").strip() or guid.value
base = base[:253]
candidate = base
suffix = 1
while True:
cur.execute(
"SELECT guid FROM devices WHERE hostname = ?",
(candidate,),
)
row = cur.fetchone()
if not row:
return candidate
existing = (row[0] or "").strip().upper()
if existing == guid.value:
return candidate
candidate = f"{base}-{suffix}"
suffix += 1
if suffix > 50:
return guid.value

View File

@@ -78,6 +78,50 @@ class SQLiteEnrollmentRepository:
self._log.warning("invalid enrollment code record for code=%s: %s", code_value, exc)
return None
def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]:
record_value = (record_id or "").strip()
if not record_value:
return None
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
SELECT id,
code,
expires_at,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at
FROM enrollment_install_codes
WHERE id = ?
""",
(record_value,),
)
row = cur.fetchone()
if not row:
return None
record = {
"id": row[0],
"code": row[1],
"expires_at": row[2],
"used_at": row[3],
"used_by_guid": row[4],
"max_uses": row[5],
"use_count": row[6],
"last_used_at": row[7],
}
try:
return EnrollmentCode.from_mapping(record)
except Exception as exc: # pragma: no cover - defensive logging
self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc)
return None
def update_install_code_usage(
self,
record_id: str,
@@ -135,6 +179,53 @@ class SQLiteEnrollmentRepository:
return None
return self._fetch_device_approval("id = ?", (record_value,))
def fetch_pending_approval_by_fingerprint(
self, fingerprint: DeviceFingerprint
) -> Optional[EnrollmentApproval]:
return self._fetch_device_approval(
"ssl_key_fingerprint_claimed = ? AND status = 'pending'",
(fingerprint.value,),
)
def update_pending_approval(
self,
record_id: str,
*,
hostname: str,
guid: Optional[DeviceGuid],
enrollment_code_id: Optional[str],
client_nonce_b64: str,
server_nonce_b64: str,
agent_pubkey_der: bytes,
updated_at: datetime,
) -> None:
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
UPDATE device_approvals
SET hostname_claimed = ?,
guid = ?,
enrollment_code_id = ?,
client_nonce = ?,
server_nonce = ?,
agent_pubkey_der = ?,
updated_at = ?
WHERE id = ?
""",
(
hostname,
guid.value if guid else None,
enrollment_code_id,
client_nonce_b64,
server_nonce_b64,
agent_pubkey_der,
self._isoformat(updated_at),
record_id,
),
)
conn.commit()
def create_device_approval(
self,
*,
@@ -143,8 +234,8 @@ class SQLiteEnrollmentRepository:
claimed_hostname: str,
claimed_fingerprint: DeviceFingerprint,
enrollment_code_id: Optional[str],
client_nonce: bytes,
server_nonce: bytes,
client_nonce_b64: str,
server_nonce_b64: str,
agent_pubkey_der: bytes,
created_at: datetime,
status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
@@ -183,8 +274,8 @@ class SQLiteEnrollmentRepository:
status.value,
created_iso,
created_iso,
client_nonce,
server_nonce,
client_nonce_b64,
server_nonce_b64,
agent_pubkey_der,
),
)
@@ -244,7 +335,10 @@ class SQLiteEnrollmentRepository:
created_at,
updated_at,
status,
approved_by_user_id
approved_by_user_id,
client_nonce,
server_nonce,
agent_pubkey_der
FROM device_approvals
WHERE {where}
""",
@@ -266,6 +360,9 @@ class SQLiteEnrollmentRepository:
"updated_at": row[7],
"status": row[8],
"approved_by_user_id": row[9],
"client_nonce": row[10],
"server_nonce": row[11],
"agent_pubkey_der": row[12],
}
try:

View File

@@ -78,6 +78,35 @@ class SQLiteRefreshTokenRepository:
)
conn.commit()
def create(
self,
*,
record_id: str,
guid: DeviceGuid,
token_hash: str,
created_at: datetime,
expires_at: Optional[datetime],
) -> None:
created_iso = self._isoformat(created_at)
expires_iso = self._isoformat(expires_at) if expires_at else None
with closing(self._connections()) as conn:
cur = conn.cursor()
cur.execute(
"""
INSERT INTO refresh_tokens (
id,
guid,
token_hash,
created_at,
expires_at
)
VALUES (?, ?, ?, ?, ?)
""",
(record_id, guid.value, token_hash, created_iso, expires_iso),
)
conn.commit()
def _row_to_record(self, row: tuple) -> Optional[RefreshTokenRecord]:
try:
guid = DeviceGuid(row[1])

139
Data/Engine/runtime.py Normal file
View File

@@ -0,0 +1,139 @@
"""Runtime filesystem helpers for the Borealis Engine."""
from __future__ import annotations
import os
from functools import lru_cache
from pathlib import Path
from typing import Optional
__all__ = [
"agent_certificates_path",
"ensure_agent_certificates_dir",
"ensure_certificates_dir",
"ensure_runtime_dir",
"ensure_server_certificates_dir",
"project_root",
"runtime_path",
"server_certificates_path",
]
def _env_path(name: str) -> Optional[Path]:
value = os.environ.get(name)
if not value:
return None
try:
return Path(value).expanduser().resolve()
except Exception:
return None
@lru_cache(maxsize=None)
def project_root() -> Path:
env = _env_path("BOREALIS_PROJECT_ROOT")
if env:
return env
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "Borealis.ps1").exists() or (parent / ".git").is_dir():
return parent
try:
return current.parents[1]
except IndexError:
return current.parent
@lru_cache(maxsize=None)
def server_runtime_root() -> Path:
env = _env_path("BOREALIS_ENGINE_ROOT") or _env_path("BOREALIS_SERVER_ROOT")
if env:
env.mkdir(parents=True, exist_ok=True)
return env
root = project_root() / "Engine"
root.mkdir(parents=True, exist_ok=True)
return root
def runtime_path(*parts: str) -> Path:
return server_runtime_root().joinpath(*parts)
def ensure_runtime_dir(*parts: str) -> Path:
path = runtime_path(*parts)
path.mkdir(parents=True, exist_ok=True)
return path
@lru_cache(maxsize=None)
def certificates_root() -> Path:
env = _env_path("BOREALIS_CERTIFICATES_ROOT") or _env_path("BOREALIS_CERT_ROOT")
if env:
env.mkdir(parents=True, exist_ok=True)
return env
root = project_root() / "Certificates"
root.mkdir(parents=True, exist_ok=True)
for name in ("Server", "Agent"):
try:
(root / name).mkdir(parents=True, exist_ok=True)
except Exception:
pass
return root
@lru_cache(maxsize=None)
def server_certificates_root() -> Path:
env = _env_path("BOREALIS_SERVER_CERT_ROOT")
if env:
env.mkdir(parents=True, exist_ok=True)
return env
root = certificates_root() / "Server"
root.mkdir(parents=True, exist_ok=True)
return root
@lru_cache(maxsize=None)
def agent_certificates_root() -> Path:
env = _env_path("BOREALIS_AGENT_CERT_ROOT")
if env:
env.mkdir(parents=True, exist_ok=True)
return env
root = certificates_root() / "Agent"
root.mkdir(parents=True, exist_ok=True)
return root
def certificates_path(*parts: str) -> Path:
return certificates_root().joinpath(*parts)
def ensure_certificates_dir(*parts: str) -> Path:
path = certificates_path(*parts)
path.mkdir(parents=True, exist_ok=True)
return path
def server_certificates_path(*parts: str) -> Path:
return server_certificates_root().joinpath(*parts)
def ensure_server_certificates_dir(*parts: str) -> Path:
path = server_certificates_path(*parts)
path.mkdir(parents=True, exist_ok=True)
return path
def agent_certificates_path(*parts: str) -> Path:
return agent_certificates_root().joinpath(*parts)
def ensure_agent_certificates_dir(*parts: str) -> Path:
path = agent_certificates_path(*parts)
path.mkdir(parents=True, exist_ok=True)
return path

View File

@@ -10,6 +10,14 @@ from .auth import (
TokenRefreshErrorCode,
TokenService,
)
from .enrollment import (
EnrollmentRequestResult,
EnrollmentService,
EnrollmentStatus,
EnrollmentTokenBundle,
EnrollmentValidationError,
PollingResult,
)
__all__ = [
"DeviceAuthService",
@@ -18,4 +26,10 @@ __all__ = [
"TokenRefreshError",
"TokenRefreshErrorCode",
"TokenService",
"EnrollmentService",
"EnrollmentRequestResult",
"EnrollmentStatus",
"EnrollmentTokenBundle",
"EnrollmentValidationError",
"PollingResult",
]

View File

@@ -3,6 +3,8 @@
from __future__ import annotations
from .device_auth_service import DeviceAuthService, DeviceRecord
from .dpop import DPoPReplayError, DPoPVerificationError, DPoPValidator
from .jwt_service import JWTService, load_service as load_jwt_service
from .token_service import (
RefreshTokenRecord,
TokenRefreshError,
@@ -13,6 +15,11 @@ from .token_service import (
__all__ = [
"DeviceAuthService",
"DeviceRecord",
"DPoPReplayError",
"DPoPVerificationError",
"DPoPValidator",
"JWTService",
"load_jwt_service",
"RefreshTokenRecord",
"TokenRefreshError",
"TokenRefreshErrorCode",

View File

@@ -0,0 +1,105 @@
"""DPoP proof validation for Engine services."""
from __future__ import annotations
import hashlib
import time
from threading import Lock
from typing import Dict, Optional
import jwt
__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"]
_DP0P_MAX_SKEW = 300.0
class DPoPVerificationError(Exception):
pass
class DPoPReplayError(DPoPVerificationError):
pass
class DPoPValidator:
def __init__(self) -> None:
self._observed_jti: Dict[str, float] = {}
self._lock = Lock()
def verify(
self,
method: str,
htu: str,
proof: str,
access_token: Optional[str] = None,
) -> str:
if not proof:
raise DPoPVerificationError("DPoP proof missing")
try:
header = jwt.get_unverified_header(proof)
except Exception as exc:
raise DPoPVerificationError("invalid DPoP header") from exc
jwk = header.get("jwk")
alg = header.get("alg")
if not jwk or not isinstance(jwk, dict):
raise DPoPVerificationError("missing jwk in DPoP header")
if alg not in ("EdDSA", "ES256", "ES384", "ES512"):
raise DPoPVerificationError(f"unsupported DPoP alg {alg}")
try:
key = jwt.PyJWK(jwk)
public_key = key.key
except Exception as exc:
raise DPoPVerificationError("invalid jwk in DPoP header") from exc
try:
claims = jwt.decode(
proof,
public_key,
algorithms=[alg],
options={"require": ["htm", "htu", "jti", "iat"]},
)
except Exception as exc:
raise DPoPVerificationError("invalid DPoP signature") from exc
htm = claims.get("htm")
proof_htu = claims.get("htu")
jti = claims.get("jti")
iat = claims.get("iat")
ath = claims.get("ath")
if not isinstance(htm, str) or htm.lower() != method.lower():
raise DPoPVerificationError("DPoP htm mismatch")
if not isinstance(proof_htu, str) or proof_htu != htu:
raise DPoPVerificationError("DPoP htu mismatch")
if not isinstance(jti, str):
raise DPoPVerificationError("DPoP jti missing")
if not isinstance(iat, (int, float)):
raise DPoPVerificationError("DPoP iat missing")
now = time.time()
if abs(now - float(iat)) > _DP0P_MAX_SKEW:
raise DPoPVerificationError("DPoP proof outside allowed skew")
if ath and access_token:
expected_ath = jwt.utils.base64url_encode(
hashlib.sha256(access_token.encode("utf-8")).digest()
).decode("ascii")
if expected_ath != ath:
raise DPoPVerificationError("DPoP ath mismatch")
with self._lock:
expiry = self._observed_jti.get(jti)
if expiry and expiry > now:
raise DPoPReplayError("DPoP proof replay detected")
self._observed_jti[jti] = now + _DP0P_MAX_SKEW
stale = [key for key, exp in self._observed_jti.items() if exp <= now]
for key in stale:
self._observed_jti.pop(key, None)
thumbprint = jwt.PyJWK(jwk).thumbprint()
return thumbprint.decode("ascii")

View File

@@ -0,0 +1,124 @@
"""JWT issuance utilities for the Engine."""
from __future__ import annotations
import hashlib
import time
from typing import Any, Dict, Optional
import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from Data.Engine.runtime import ensure_runtime_dir, runtime_path
__all__ = ["JWTService", "load_service"]
_KEY_DIR = runtime_path("auth_keys")
_KEY_FILE = _KEY_DIR / "engine-jwt-ed25519.key"
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-jwt-ed25519.key"
class JWTService:
def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str) -> None:
self._private_key = private_key
self._public_key = private_key.public_key()
self._key_id = key_id
@property
def key_id(self) -> str:
return self._key_id
def issue_access_token(
self,
guid: str,
ssl_key_fingerprint: str,
token_version: int,
expires_in: int = 900,
extra_claims: Optional[Dict[str, Any]] = None,
) -> str:
now = int(time.time())
payload: Dict[str, Any] = {
"sub": f"device:{guid}",
"guid": guid,
"ssl_key_fingerprint": ssl_key_fingerprint,
"token_version": int(token_version),
"iat": now,
"nbf": now,
"exp": now + int(expires_in),
}
if extra_claims:
payload.update(extra_claims)
token = jwt.encode(
payload,
self._private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
),
algorithm="EdDSA",
headers={"kid": self._key_id},
)
return token
def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]:
options = {"require": ["exp", "iat", "sub"]}
public_pem = self._public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return jwt.decode(
token,
public_pem,
algorithms=["EdDSA"],
audience=audience,
options=options,
)
def public_jwk(self) -> Dict[str, Any]:
public_bytes = self._public_key.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii")
return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x}
def load_service() -> JWTService:
private_key = _load_or_create_private_key()
public_bytes = private_key.public_key().public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
key_id = hashlib.sha256(public_bytes).hexdigest()[:16]
return JWTService(private_key, key_id)
def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey:
ensure_runtime_dir("auth_keys")
if _KEY_FILE.exists():
with _KEY_FILE.open("rb") as fh:
return serialization.load_pem_private_key(fh.read(), password=None)
if _LEGACY_KEY_FILE.exists():
with _LEGACY_KEY_FILE.open("rb") as fh:
return serialization.load_pem_private_key(fh.read(), password=None)
private_key = ed25519.Ed25519PrivateKey.generate()
pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
_KEY_DIR.mkdir(parents=True, exist_ok=True)
with _KEY_FILE.open("wb") as fh:
fh.write(pem)
try:
if hasattr(_KEY_FILE, "chmod"):
_KEY_FILE.chmod(0o600)
except Exception:
pass
return private_key

View File

@@ -0,0 +1,119 @@
"""Service container assembly for the Borealis Engine."""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional
from Data.Engine.config import EngineSettings
from Data.Engine.repositories.sqlite import (
SQLiteConnectionFactory,
SQLiteDeviceRepository,
SQLiteEnrollmentRepository,
SQLiteRefreshTokenRepository,
)
from Data.Engine.services.auth import (
DeviceAuthService,
DPoPValidator,
JWTService,
TokenService,
load_jwt_service,
)
from Data.Engine.services.crypto.signing import ScriptSigner, load_signer
from Data.Engine.services.enrollment import EnrollmentService
from Data.Engine.services.enrollment.nonce_cache import NonceCache
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
__all__ = ["EngineServiceContainer", "build_service_container"]
@dataclass(frozen=True, slots=True)
class EngineServiceContainer:
device_auth: DeviceAuthService
token_service: TokenService
enrollment_service: EnrollmentService
jwt_service: JWTService
dpop_validator: DPoPValidator
def build_service_container(
settings: EngineSettings,
*,
db_factory: SQLiteConnectionFactory,
logger: Optional[logging.Logger] = None,
) -> EngineServiceContainer:
log = logger or logging.getLogger("borealis.engine.services")
device_repo = SQLiteDeviceRepository(db_factory, logger=log.getChild("devices"))
token_repo = SQLiteRefreshTokenRepository(db_factory, logger=log.getChild("tokens"))
enrollment_repo = SQLiteEnrollmentRepository(db_factory, logger=log.getChild("enrollment"))
jwt_service = load_jwt_service()
dpop_validator = DPoPValidator()
rate_limiter = SlidingWindowRateLimiter()
token_service = TokenService(
refresh_token_repository=token_repo,
device_repository=device_repo,
jwt_service=jwt_service,
dpop_validator=dpop_validator,
logger=log.getChild("token_service"),
)
enrollment_service = EnrollmentService(
device_repository=device_repo,
enrollment_repository=enrollment_repo,
token_repository=token_repo,
jwt_service=jwt_service,
tls_bundle_loader=_tls_bundle_loader(settings),
ip_rate_limiter=SlidingWindowRateLimiter(),
fingerprint_rate_limiter=SlidingWindowRateLimiter(),
nonce_cache=NonceCache(),
script_signer=_load_script_signer(log),
logger=log.getChild("enrollment"),
)
device_auth = DeviceAuthService(
device_repository=device_repo,
jwt_service=jwt_service,
logger=log.getChild("device_auth"),
rate_limiter=rate_limiter,
dpop_validator=dpop_validator,
)
return EngineServiceContainer(
device_auth=device_auth,
token_service=token_service,
enrollment_service=enrollment_service,
jwt_service=jwt_service,
dpop_validator=dpop_validator,
)
def _tls_bundle_loader(settings: EngineSettings) -> Callable[[], str]:
candidates = [
Path(os.getenv("BOREALIS_TLS_BUNDLE", "")),
settings.project_root / "Certificates" / "Server" / "borealis-server-bundle.pem",
]
def loader() -> str:
for candidate in candidates:
if candidate and candidate.is_file():
try:
return candidate.read_text(encoding="utf-8")
except Exception:
continue
return ""
return loader
def _load_script_signer(logger: logging.Logger) -> Optional[ScriptSigner]:
try:
return load_signer()
except Exception as exc:
logger.warning("script signer unavailable: %s", exc)
return None

View File

@@ -0,0 +1,75 @@
"""Script signing utilities for the Engine."""
from __future__ import annotations
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from Data.Engine.integrations.crypto.keys import base64_from_spki_der
from Data.Engine.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path
__all__ = ["ScriptSigner", "load_signer"]
_KEY_DIR = server_certificates_path("Code-Signing")
_SIGNING_KEY_FILE = _KEY_DIR / "engine-script-ed25519.key"
_SIGNING_PUB_FILE = _KEY_DIR / "engine-script-ed25519.pub"
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-script-ed25519.key"
_LEGACY_PUB_FILE = runtime_path("keys") / "borealis-script-ed25519.pub"
class ScriptSigner:
def __init__(self, private_key: ed25519.Ed25519PrivateKey) -> None:
self._private = private_key
self._public = private_key.public_key()
def sign(self, payload: bytes) -> bytes:
return self._private.sign(payload)
def public_spki_der(self) -> bytes:
return self._public.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
def public_base64_spki(self) -> str:
return base64_from_spki_der(self.public_spki_der())
def load_signer() -> ScriptSigner:
private_key = _load_or_create()
return ScriptSigner(private_key)
def _load_or_create() -> ed25519.Ed25519PrivateKey:
ensure_server_certificates_dir("Code-Signing")
if _SIGNING_KEY_FILE.exists():
with _SIGNING_KEY_FILE.open("rb") as fh:
return serialization.load_pem_private_key(fh.read(), password=None)
if _LEGACY_KEY_FILE.exists():
with _LEGACY_KEY_FILE.open("rb") as fh:
return serialization.load_pem_private_key(fh.read(), password=None)
private_key = ed25519.Ed25519PrivateKey.generate()
pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
_KEY_DIR.mkdir(parents=True, exist_ok=True)
_SIGNING_KEY_FILE.write_bytes(pem)
try:
if hasattr(_SIGNING_KEY_FILE, "chmod"):
_SIGNING_KEY_FILE.chmod(0o600)
except Exception:
pass
pub_der = private_key.public_key().public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
_SIGNING_PUB_FILE.write_bytes(pub_der)
return private_key

View File

@@ -0,0 +1,21 @@
"""Enrollment services for the Borealis Engine."""
from __future__ import annotations
from .enrollment_service import (
EnrollmentRequestResult,
EnrollmentService,
EnrollmentStatus,
EnrollmentTokenBundle,
EnrollmentValidationError,
PollingResult,
)
__all__ = [
"EnrollmentRequestResult",
"EnrollmentService",
"EnrollmentStatus",
"EnrollmentTokenBundle",
"EnrollmentValidationError",
"PollingResult",
]

View File

@@ -0,0 +1,487 @@
"""Enrollment workflow orchestration for the Borealis Engine."""
from __future__ import annotations
import base64
import hashlib
import logging
import secrets
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Callable, Optional, Protocol
from cryptography.hazmat.primitives import serialization
from Data.Engine.builders.device_enrollment import EnrollmentRequestInput
from Data.Engine.domain.device_auth import (
DeviceFingerprint,
DeviceGuid,
sanitize_service_context,
)
from Data.Engine.domain.device_enrollment import (
EnrollmentApproval,
EnrollmentApprovalStatus,
EnrollmentCode,
)
from Data.Engine.services.auth.device_auth_service import DeviceRecord
from Data.Engine.services.auth.token_service import JWTIssuer
from Data.Engine.services.enrollment.errors import EnrollmentValidationError
from Data.Engine.services.enrollment.nonce_cache import NonceCache
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
__all__ = [
"EnrollmentRequestResult",
"EnrollmentService",
"EnrollmentStatus",
"EnrollmentTokenBundle",
"PollingResult",
]
class DeviceRepository(Protocol):
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: # pragma: no cover - protocol
...
def ensure_device_record(
self,
*,
guid: DeviceGuid,
hostname: str,
fingerprint: DeviceFingerprint,
) -> DeviceRecord: # pragma: no cover - protocol
...
def record_device_key(
self,
*,
guid: DeviceGuid,
fingerprint: DeviceFingerprint,
added_at: datetime,
) -> None: # pragma: no cover - protocol
...
class EnrollmentRepository(Protocol):
def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
...
def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
...
def update_install_code_usage(
self,
record_id: str,
*,
use_count_increment: int,
last_used_at: datetime,
used_by_guid: Optional[DeviceGuid] = None,
mark_first_use: bool = False,
) -> None: # pragma: no cover - protocol
...
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
...
def fetch_pending_approval_by_fingerprint(
self, fingerprint: DeviceFingerprint
) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
...
def update_pending_approval(
self,
record_id: str,
*,
hostname: str,
guid: Optional[DeviceGuid],
enrollment_code_id: Optional[str],
client_nonce_b64: str,
server_nonce_b64: str,
agent_pubkey_der: bytes,
updated_at: datetime,
) -> None: # pragma: no cover - protocol
...
def create_device_approval(
self,
*,
record_id: str,
reference: str,
claimed_hostname: str,
claimed_fingerprint: DeviceFingerprint,
enrollment_code_id: Optional[str],
client_nonce_b64: str,
server_nonce_b64: str,
agent_pubkey_der: bytes,
created_at: datetime,
status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
guid: Optional[DeviceGuid] = None,
) -> EnrollmentApproval: # pragma: no cover - protocol
...
def update_device_approval_status(
self,
record_id: str,
*,
status: EnrollmentApprovalStatus,
updated_at: datetime,
approved_by: Optional[str] = None,
guid: Optional[DeviceGuid] = None,
) -> None: # pragma: no cover - protocol
...
class RefreshTokenRepository(Protocol):
def create(
self,
*,
record_id: str,
guid: DeviceGuid,
token_hash: str,
created_at: datetime,
expires_at: Optional[datetime],
) -> None: # pragma: no cover - protocol
...
class ScriptSigner(Protocol):
def public_base64_spki(self) -> str: # pragma: no cover - protocol
...
class EnrollmentStatus(str):
PENDING = "pending"
APPROVED = "approved"
DENIED = "denied"
EXPIRED = "expired"
UNKNOWN = "unknown"
@dataclass(frozen=True, slots=True)
class EnrollmentTokenBundle:
guid: DeviceGuid
access_token: str
refresh_token: str
expires_in: int
token_type: str = "Bearer"
@dataclass(frozen=True, slots=True)
class EnrollmentRequestResult:
status: EnrollmentStatus
approval_reference: Optional[str] = None
server_nonce: Optional[str] = None
poll_after_ms: Optional[int] = None
server_certificate: str
signing_key: str
http_status: int = 200
retry_after: Optional[float] = None
@dataclass(frozen=True, slots=True)
class PollingResult:
status: EnrollmentStatus
http_status: int
poll_after_ms: Optional[int] = None
reason: Optional[str] = None
detail: Optional[str] = None
tokens: Optional[EnrollmentTokenBundle] = None
server_certificate: Optional[str] = None
signing_key: Optional[str] = None
class EnrollmentService:
"""Coordinate the Borealis device enrollment handshake."""
def __init__(
self,
*,
device_repository: DeviceRepository,
enrollment_repository: EnrollmentRepository,
token_repository: RefreshTokenRepository,
jwt_service: JWTIssuer,
tls_bundle_loader: Callable[[], str],
ip_rate_limiter: SlidingWindowRateLimiter,
fingerprint_rate_limiter: SlidingWindowRateLimiter,
nonce_cache: NonceCache,
script_signer: Optional[ScriptSigner] = None,
logger: Optional[logging.Logger] = None,
) -> None:
self._devices = device_repository
self._enrollment = enrollment_repository
self._tokens = token_repository
self._jwt = jwt_service
self._load_tls_bundle = tls_bundle_loader
self._ip_rate_limiter = ip_rate_limiter
self._fp_rate_limiter = fingerprint_rate_limiter
self._nonce_cache = nonce_cache
self._signer = script_signer
self._log = logger or logging.getLogger("borealis.engine.enrollment")
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def request_enrollment(
self,
payload: EnrollmentRequestInput,
*,
remote_addr: str,
) -> EnrollmentRequestResult:
context_hint = sanitize_service_context(payload.service_context)
self._log.info(
"enrollment-request ip=%s host=%s code_mask=%s", remote_addr, payload.hostname, self._mask_code(payload.enrollment_code)
)
self._enforce_rate_limit(self._ip_rate_limiter, f"ip:{remote_addr}")
self._enforce_rate_limit(self._fp_rate_limiter, f"fp:{payload.fingerprint.value}")
install_code = self._enrollment.fetch_install_code(payload.enrollment_code)
reuse_guid = self._determine_reuse_guid(install_code, payload.fingerprint)
server_nonce_bytes = secrets.token_bytes(32)
server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii")
now = self._now()
approval = self._enrollment.fetch_pending_approval_by_fingerprint(payload.fingerprint)
if approval:
self._enrollment.update_pending_approval(
approval.record_id,
hostname=payload.hostname,
guid=reuse_guid,
enrollment_code_id=install_code.identifier if install_code else None,
client_nonce_b64=payload.client_nonce_b64,
server_nonce_b64=server_nonce_b64,
agent_pubkey_der=payload.agent_public_key_der,
updated_at=now,
)
approval_reference = approval.reference
else:
record_id = str(uuid.uuid4())
approval_reference = str(uuid.uuid4())
approval = self._enrollment.create_device_approval(
record_id=record_id,
reference=approval_reference,
claimed_hostname=payload.hostname,
claimed_fingerprint=payload.fingerprint,
enrollment_code_id=install_code.identifier if install_code else None,
client_nonce_b64=payload.client_nonce_b64,
server_nonce_b64=server_nonce_b64,
agent_pubkey_der=payload.agent_public_key_der,
created_at=now,
guid=reuse_guid,
)
signing_key = self._signer.public_base64_spki() if self._signer else ""
certificate = self._load_tls_bundle()
return EnrollmentRequestResult(
status=EnrollmentStatus.PENDING,
approval_reference=approval.reference,
server_nonce=server_nonce_b64,
poll_after_ms=3000,
server_certificate=certificate,
signing_key=signing_key,
)
def poll_enrollment(
self,
*,
approval_reference: str,
client_nonce_b64: str,
proof_signature_b64: str,
) -> PollingResult:
if not approval_reference:
raise EnrollmentValidationError("approval_reference_required")
if not client_nonce_b64:
raise EnrollmentValidationError("client_nonce_required")
if not proof_signature_b64:
raise EnrollmentValidationError("proof_sig_required")
approval = self._enrollment.fetch_device_approval_by_reference(approval_reference)
if approval is None:
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=404)
client_nonce = self._decode_base64(client_nonce_b64, "invalid_client_nonce")
server_nonce = self._decode_base64(approval.server_nonce_b64, "server_nonce_invalid")
proof_sig = self._decode_base64(proof_signature_b64, "invalid_proof_sig")
if approval.client_nonce_b64 != client_nonce_b64:
raise EnrollmentValidationError("nonce_mismatch")
self._verify_proof_signature(
approval=approval,
client_nonce=client_nonce,
server_nonce=server_nonce,
signature=proof_sig,
)
status = approval.status
if status is EnrollmentApprovalStatus.PENDING:
return PollingResult(
status=EnrollmentStatus.PENDING,
http_status=200,
poll_after_ms=5000,
)
if status is EnrollmentApprovalStatus.DENIED:
return PollingResult(
status=EnrollmentStatus.DENIED,
http_status=200,
reason="operator_denied",
)
if status is EnrollmentApprovalStatus.EXPIRED:
return PollingResult(status=EnrollmentStatus.EXPIRED, http_status=200)
if status is EnrollmentApprovalStatus.COMPLETED:
return PollingResult(
status=EnrollmentStatus.APPROVED,
http_status=200,
detail="finalized",
)
if status is not EnrollmentApprovalStatus.APPROVED:
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=400)
nonce_key = f"{approval.reference}:{proof_signature_b64}"
if not self._nonce_cache.consume(nonce_key):
raise EnrollmentValidationError("proof_replayed", http_status=409)
token_bundle = self._finalize_approval(approval)
signing_key = self._signer.public_base64_spki() if self._signer else ""
certificate = self._load_tls_bundle()
return PollingResult(
status=EnrollmentStatus.APPROVED,
http_status=200,
tokens=token_bundle,
server_certificate=certificate,
signing_key=signing_key,
)
# ------------------------------------------------------------------
# Internals
# ------------------------------------------------------------------
def _enforce_rate_limit(
self,
limiter: SlidingWindowRateLimiter,
key: str,
*,
limit: int = 60,
window_seconds: float = 60.0,
) -> None:
decision = limiter.check(key, limit, window_seconds)
if not decision.allowed:
raise EnrollmentValidationError(
"rate_limited", http_status=429, retry_after=max(decision.retry_after, 1.0)
)
def _determine_reuse_guid(
self,
install_code: Optional[EnrollmentCode],
fingerprint: DeviceFingerprint,
) -> Optional[DeviceGuid]:
if install_code is None:
raise EnrollmentValidationError("invalid_enrollment_code")
if install_code.is_expired:
raise EnrollmentValidationError("invalid_enrollment_code")
if not install_code.is_exhausted:
return None
if not install_code.used_by_guid:
raise EnrollmentValidationError("invalid_enrollment_code")
existing = self._devices.fetch_by_guid(install_code.used_by_guid)
if existing and existing.identity.fingerprint.value == fingerprint.value:
return install_code.used_by_guid
raise EnrollmentValidationError("invalid_enrollment_code")
def _finalize_approval(self, approval: EnrollmentApproval) -> EnrollmentTokenBundle:
now = self._now()
effective_guid = approval.guid or DeviceGuid(str(uuid.uuid4()))
device_record = self._devices.ensure_device_record(
guid=effective_guid,
hostname=approval.claimed_hostname,
fingerprint=approval.claimed_fingerprint,
)
self._devices.record_device_key(
guid=effective_guid,
fingerprint=approval.claimed_fingerprint,
added_at=now,
)
if approval.enrollment_code_id:
code = self._enrollment.fetch_install_code_by_id(approval.enrollment_code_id)
if code is not None:
mark_first = code.used_at is None
self._enrollment.update_install_code_usage(
approval.enrollment_code_id,
use_count_increment=1,
last_used_at=now,
used_by_guid=effective_guid,
mark_first_use=mark_first,
)
refresh_token = secrets.token_urlsafe(48)
refresh_id = str(uuid.uuid4())
expires_at = now + timedelta(days=30)
token_hash = hashlib.sha256(refresh_token.encode("utf-8")).hexdigest()
self._tokens.create(
record_id=refresh_id,
guid=effective_guid,
token_hash=token_hash,
created_at=now,
expires_at=expires_at,
)
access_token = self._jwt.issue_access_token(
effective_guid.value,
device_record.identity.fingerprint.value,
max(device_record.token_version, 1),
)
self._enrollment.update_device_approval_status(
approval.record_id,
status=EnrollmentApprovalStatus.COMPLETED,
updated_at=now,
guid=effective_guid,
)
return EnrollmentTokenBundle(
guid=effective_guid,
access_token=access_token,
refresh_token=refresh_token,
expires_in=900,
)
def _verify_proof_signature(
self,
*,
approval: EnrollmentApproval,
client_nonce: bytes,
server_nonce: bytes,
signature: bytes,
) -> None:
message = server_nonce + approval.reference.encode("utf-8") + client_nonce
try:
public_key = serialization.load_der_public_key(approval.agent_pubkey_der)
except Exception as exc:
raise EnrollmentValidationError("agent_pubkey_invalid") from exc
try:
public_key.verify(signature, message)
except Exception as exc:
raise EnrollmentValidationError("invalid_proof") from exc
@staticmethod
def _decode_base64(value: str, error_code: str) -> bytes:
try:
return base64.b64decode(value, validate=True)
except Exception as exc:
raise EnrollmentValidationError(error_code) from exc
@staticmethod
def _mask_code(code: str) -> str:
trimmed = (code or "").strip()
if len(trimmed) <= 6:
return "***"
return f"{trimmed[:3]}***{trimmed[-3:]}"
@staticmethod
def _now() -> datetime:
return datetime.now(tz=timezone.utc)

View File

@@ -0,0 +1,26 @@
"""Error types shared across enrollment components."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
__all__ = ["EnrollmentValidationError"]
@dataclass(frozen=True, slots=True)
class EnrollmentValidationError(Exception):
"""Raised when enrollment input fails validation."""
code: str
http_status: int = 400
retry_after: Optional[float] = None
def to_response(self) -> dict[str, object]:
payload: dict[str, object] = {"error": self.code}
if self.retry_after is not None:
payload["retry_after"] = self.retry_after
return payload
def __str__(self) -> str: # pragma: no cover - debug helper
return f"{self.code} (status={self.http_status})"

View File

@@ -0,0 +1,32 @@
"""Nonce replay protection for enrollment workflows."""
from __future__ import annotations
import time
from threading import Lock
from typing import Dict
__all__ = ["NonceCache"]
class NonceCache:
"""Track recently observed nonces to prevent replay."""
def __init__(self, ttl_seconds: float = 300.0) -> None:
self._ttl = ttl_seconds
self._entries: Dict[str, float] = {}
self._lock = Lock()
def consume(self, key: str) -> bool:
"""Consume *key* if it has not been seen recently."""
now = time.monotonic()
with self._lock:
expiry = self._entries.get(key)
if expiry and expiry > now:
return False
self._entries[key] = now + self._ttl
stale = [nonce for nonce, ttl in self._entries.items() if ttl <= now]
for nonce in stale:
self._entries.pop(nonce, None)
return True

View File

@@ -0,0 +1,45 @@
"""In-process rate limiting utilities for the Borealis Engine."""
from __future__ import annotations
import time
from collections import deque
from dataclasses import dataclass
from threading import Lock
from typing import Deque, Dict
__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"]
@dataclass(frozen=True, slots=True)
class RateLimitDecision:
"""Result of a rate limit check."""
allowed: bool
retry_after: float
class SlidingWindowRateLimiter:
"""Tiny in-memory sliding window limiter suitable for single-process use."""
def __init__(self) -> None:
self._buckets: Dict[str, Deque[float]] = {}
self._lock = Lock()
def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
now = time.monotonic()
with self._lock:
bucket = self._buckets.get(key)
if bucket is None:
bucket = deque()
self._buckets[key] = bucket
while bucket and now - bucket[0] > window_seconds:
bucket.popleft()
if len(bucket) >= limit:
retry_after = max(0.0, window_seconds - (now - bucket[0]))
return RateLimitDecision(False, retry_after)
bucket.append(now)
return RateLimitDecision(True, 0.0)