diff --git a/Data/Engine/CURRENT_STAGE.md b/Data/Engine/CURRENT_STAGE.md index 2ec116f..80cb072 100644 --- a/Data/Engine/CURRENT_STAGE.md +++ b/Data/Engine/CURRENT_STAGE.md @@ -37,7 +37,7 @@ - 7.2 Write integration tests exercising CRUD against a temporary SQLite file. - 7.3 Commit when repositories provide the required ports used by services. -8. Recreate HTTP interfaces +[COMPLETED] 8. Recreate HTTP interfaces - 8.1 Port health/enrollment/token blueprints into `interfaces/http//routes.py`, calling Engine services only. - 8.2 Ensure request validation occurs via builders; response schemas stay aligned with legacy JSON. - 8.3 Register blueprints through Engine `server.py`; confirm endpoints respond via manual or automated tests. diff --git a/Data/Engine/README.md b/Data/Engine/README.md index dcab08c..fda3c1a 100644 --- a/Data/Engine/README.md +++ b/Data/Engine/README.md @@ -25,14 +25,21 @@ The Engine mirrors the legacy defaults so it can boot without additional configu ## Bootstrapping flow 1. `Data/Engine/bootstrapper.py` loads the environment, configures logging, prepares the SQLite connection factory, optionally applies schema migrations, and builds the Flask application via `Data/Engine/server.py`. -2. Placeholder HTTP and Socket.IO registration hooks run so the Engine can start without any migrated routes yet. -3. The resulting runtime object exposes the Flask app, resolved settings, optional Socket.IO server, and the configured database connection factory. `bootstrapper.main()` runs the appropriate server based on whether Socket.IO is present. +2. A service container is assembled (`Data/Engine/services/container.py`) that wires repositories, JWT/DPoP helpers, and Engine services (device auth, token refresh, enrollment). The container is stored on the Flask app for interface modules to consume. +3. HTTP and Socket.IO interfaces register against the new service container. The resulting runtime object exposes the Flask app, resolved settings, optional Socket.IO server, and the configured database connection factory. `bootstrapper.main()` runs the appropriate server based on whether Socket.IO is present. As migration continues, services, repositories, interfaces, and integrations will live under their respective subpackages while maintaining isolation from the legacy server. -## Interface scaffolding +## HTTP interfaces -The Engine currently exposes placeholder HTTP blueprints under `Data/Engine/interfaces/http/` (agents, enrollment, tokens, admin, and health) so that future commits can drop in real routes without reshaping the bootstrap wiring. WebSocket namespaces follow the same pattern in `Data/Engine/interfaces/ws/`, with feature-oriented modules (e.g., `agents`, `job_management`) registered by `bootstrapper.bootstrap()` when Socket.IO is available. These stubs intentionally contain no business logic yet—they merely ensure the application factory exercises the full wiring path. +The Engine now exposes working HTTP routes alongside the remaining scaffolding: + +- `Data/Engine/interfaces/http/health.py` implements `GET /health` for liveness probes. +- `Data/Engine/interfaces/http/tokens.py` ports the refresh-token endpoint (`POST /api/agent/token/refresh`) using the Engine `TokenService` and request builders. +- `Data/Engine/interfaces/http/enrollment.py` handles the enrollment handshake (`/api/agent/enroll/request` and `/api/agent/enroll/poll`) with rate limiting, nonce protection, and repository-backed approvals. +- The admin and agent blueprints remain placeholders until their services migrate. + +WebSocket namespaces continue to follow the same pattern in `Data/Engine/interfaces/ws/`, with feature-oriented modules (e.g., `agents`, `job_management`) registered by `bootstrapper.bootstrap()` when Socket.IO is available. ## Authentication services @@ -43,7 +50,7 @@ Step 6 introduces the first real Engine services: - `Data/Engine/services/auth/device_auth_service.py` ports the legacy `DeviceAuthManager` into a repository-driven service that emits `DeviceAuthContext` instances from the new domain layer. - `Data/Engine/services/auth/token_service.py` issues refreshed access tokens while enforcing DPoP bindings and repository lookups. -Interfaces will begin consuming these services once the repository adapters land in the next milestone. +Interfaces now consume these services via the shared container, keeping business logic inside the Engine service layer while HTTP modules remain thin request/response translators. ## SQLite repositories diff --git a/Data/Engine/bootstrapper.py b/Data/Engine/bootstrapper.py index 1e4d7b7..7a18f45 100644 --- a/Data/Engine/bootstrapper.py +++ b/Data/Engine/bootstrapper.py @@ -16,6 +16,7 @@ from .interfaces import ( from .repositories.sqlite import connection as sqlite_connection from .repositories.sqlite import migrations as sqlite_migrations from .server import create_app +from .services.container import build_service_container @dataclass(frozen=True, slots=True) @@ -45,7 +46,9 @@ def bootstrap() -> EngineRuntime: logger.info("migrations-skipped") app = create_app(settings, db_factory=db_factory) - register_http_interfaces(app) + services = build_service_container(settings, db_factory=db_factory, logger=logger.getChild("services")) + app.extensions["engine_services"] = services + register_http_interfaces(app, services) socketio = create_socket_server(app, settings.socketio) register_ws_interfaces(socketio) logger.info("bootstrap-complete") diff --git a/Data/Engine/builders/device_enrollment.py b/Data/Engine/builders/device_enrollment.py index 0bb8fa9..26f64da 100644 --- a/Data/Engine/builders/device_enrollment.py +++ b/Data/Engine/builders/device_enrollment.py @@ -2,67 +2,96 @@ from __future__ import annotations +import base64 from dataclasses import dataclass from typing import Optional -from Data.Engine.domain.device_auth import DeviceFingerprint -from Data.Engine.domain.device_enrollment import EnrollmentRequest, ProofChallenge +from Data.Engine.domain.device_auth import DeviceFingerprint, sanitize_service_context +from Data.Engine.domain.device_enrollment import ProofChallenge +from Data.Engine.integrations.crypto import keys as crypto_keys +from Data.Engine.services.enrollment.errors import EnrollmentValidationError __all__ = [ "EnrollmentRequestBuilder", + "EnrollmentRequestInput", "ProofChallengeBuilder", ] @dataclass(frozen=True, slots=True) -class _EnrollmentPayload: +class EnrollmentRequestInput: + """Structured enrollment request payload ready for the service layer.""" + hostname: str enrollment_code: str - fingerprint: str + fingerprint: DeviceFingerprint client_nonce: bytes - server_nonce: bytes + client_nonce_b64: str + agent_public_key_der: bytes + service_context: Optional[str] class EnrollmentRequestBuilder: """Normalize agent enrollment JSON payloads into domain objects.""" def __init__(self) -> None: - self._payload: Optional[_EnrollmentPayload] = None + self._hostname: Optional[str] = None + self._enrollment_code: Optional[str] = None + self._agent_pubkey_b64: Optional[str] = None + self._client_nonce_b64: Optional[str] = None + self._service_context: Optional[str] = None def with_payload(self, payload: Optional[dict[str, object]]) -> "EnrollmentRequestBuilder": payload = payload or {} - hostname = str(payload.get("hostname") or "").strip() - enrollment_code = str(payload.get("enrollment_code") or "").strip() - fingerprint = str(payload.get("fingerprint") or "").strip() - client_nonce = self._coerce_bytes(payload.get("client_nonce")) - server_nonce = self._coerce_bytes(payload.get("server_nonce")) - self._payload = _EnrollmentPayload( - hostname=hostname, - enrollment_code=enrollment_code, - fingerprint=fingerprint, - client_nonce=client_nonce, - server_nonce=server_nonce, - ) + self._hostname = str(payload.get("hostname") or "").strip() + self._enrollment_code = str(payload.get("enrollment_code") or "").strip() + agent_pubkey = payload.get("agent_pubkey") + self._agent_pubkey_b64 = agent_pubkey if isinstance(agent_pubkey, str) else None + client_nonce = payload.get("client_nonce") + self._client_nonce_b64 = client_nonce if isinstance(client_nonce, str) else None return self - def build(self) -> EnrollmentRequest: - if not self._payload: - raise ValueError("payload has not been provided") - return EnrollmentRequest.from_payload( - hostname=self._payload.hostname, - enrollment_code=self._payload.enrollment_code, - fingerprint=self._payload.fingerprint, - client_nonce=self._payload.client_nonce, - server_nonce=self._payload.server_nonce, - ) + def with_service_context(self, value: Optional[str]) -> "EnrollmentRequestBuilder": + self._service_context = value + return self - @staticmethod - def _coerce_bytes(value: object) -> bytes: - if isinstance(value, (bytes, bytearray)): - return bytes(value) - if isinstance(value, str): - return value.encode("utf-8") - raise ValueError("nonce values must be bytes or base strings") + def build(self) -> EnrollmentRequestInput: + if not self._hostname: + raise EnrollmentValidationError("hostname_required") + if not self._enrollment_code: + raise EnrollmentValidationError("enrollment_code_required") + if not self._agent_pubkey_b64: + raise EnrollmentValidationError("agent_pubkey_required") + if not self._client_nonce_b64: + raise EnrollmentValidationError("client_nonce_required") + + try: + agent_pubkey_der = crypto_keys.spki_der_from_base64(self._agent_pubkey_b64) + except Exception as exc: # pragma: no cover - invalid input path + raise EnrollmentValidationError("invalid_agent_pubkey") from exc + + if len(agent_pubkey_der) < 10: + raise EnrollmentValidationError("invalid_agent_pubkey") + + try: + client_nonce_bytes = base64.b64decode(self._client_nonce_b64, validate=True) + except Exception as exc: # pragma: no cover - invalid input path + raise EnrollmentValidationError("invalid_client_nonce") from exc + + if len(client_nonce_bytes) < 16: + raise EnrollmentValidationError("invalid_client_nonce") + + fingerprint_value = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der) + + return EnrollmentRequestInput( + hostname=self._hostname, + enrollment_code=self._enrollment_code, + fingerprint=DeviceFingerprint(fingerprint_value), + client_nonce=client_nonce_bytes, + client_nonce_b64=self._client_nonce_b64, + agent_public_key_der=agent_pubkey_der, + service_context=sanitize_service_context(self._service_context), + ) class ProofChallengeBuilder: diff --git a/Data/Engine/domain/device_enrollment.py b/Data/Engine/domain/device_enrollment.py index 283d16d..713b4b5 100644 --- a/Data/Engine/domain/device_enrollment.py +++ b/Data/Engine/domain/device_enrollment.py @@ -4,6 +4,7 @@ from __future__ import annotations from dataclasses import dataclass from datetime import datetime, timezone +import base64 from enum import Enum from typing import Any, Mapping, Optional @@ -44,6 +45,7 @@ def _require(value: Optional[str], field: str) -> str: class EnrollmentCode: """Installer code metadata loaded from the persistence layer.""" + record_id: Optional[str] = None code: str expires_at: datetime max_uses: int @@ -67,6 +69,7 @@ class EnrollmentCode: used_by = record.get("used_by_guid") used_by_guid = DeviceGuid(used_by) if used_by else None return cls( + record_id=str(record.get("id") or "") or None, code=_require(record.get("code"), "code"), expires_at=_parse_iso8601(record.get("expires_at")) or datetime.now(tz=timezone.utc), max_uses=int(record.get("max_uses") or 1), @@ -88,6 +91,10 @@ class EnrollmentCode: def is_expired(self) -> bool: return self.expires_at <= datetime.now(tz=timezone.utc) + @property + def identifier(self) -> Optional[str]: + return self.record_id + class EnrollmentApprovalStatus(str, Enum): """Possible states for a device approval entry.""" @@ -181,6 +188,9 @@ class EnrollmentApproval: enrollment_code_id: Optional[str] created_at: datetime updated_at: datetime + client_nonce_b64: str + server_nonce_b64: str + agent_pubkey_der: bytes guid: Optional[DeviceGuid] = None approved_by: Optional[str] = None @@ -207,6 +217,9 @@ class EnrollmentApproval: updated_at=_parse_iso8601(record.get("updated_at")) or datetime.now(tz=timezone.utc), guid=DeviceGuid(guid_raw) if guid_raw else None, approved_by=(approved_raw or None), + client_nonce_b64=_require(record.get("client_nonce"), "client_nonce"), + server_nonce_b64=_require(record.get("server_nonce"), "server_nonce"), + agent_pubkey_der=bytes(record.get("agent_pubkey_der") or b""), ) @property @@ -219,3 +232,11 @@ class EnrollmentApproval: EnrollmentApprovalStatus.APPROVED, EnrollmentApprovalStatus.COMPLETED, } + + @property + def client_nonce_bytes(self) -> bytes: + return base64.b64decode(self.client_nonce_b64.encode("ascii"), validate=True) + + @property + def server_nonce_bytes(self) -> bytes: + return base64.b64decode(self.server_nonce_b64.encode("ascii"), validate=True) diff --git a/Data/Engine/integrations/crypto/__init__.py b/Data/Engine/integrations/crypto/__init__.py new file mode 100644 index 0000000..0d790bf --- /dev/null +++ b/Data/Engine/integrations/crypto/__init__.py @@ -0,0 +1,25 @@ +"""Crypto integration helpers for the Engine.""" + +from __future__ import annotations + +from .keys import ( + base64_from_spki_der, + fingerprint_from_base64_spki, + fingerprint_from_spki_der, + generate_ed25519_keypair, + normalize_base64, + private_key_to_pem, + public_key_to_pem, + spki_der_from_base64, +) + +__all__ = [ + "base64_from_spki_der", + "fingerprint_from_base64_spki", + "fingerprint_from_spki_der", + "generate_ed25519_keypair", + "normalize_base64", + "private_key_to_pem", + "public_key_to_pem", + "spki_der_from_base64", +] diff --git a/Data/Engine/integrations/crypto/keys.py b/Data/Engine/integrations/crypto/keys.py new file mode 100644 index 0000000..076a520 --- /dev/null +++ b/Data/Engine/integrations/crypto/keys.py @@ -0,0 +1,70 @@ +"""Key utilities mirrored from the legacy crypto helpers.""" + +from __future__ import annotations + +import base64 +import hashlib +import re +from typing import Tuple + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 +from cryptography.hazmat.primitives.serialization import load_der_public_key + +__all__ = [ + "base64_from_spki_der", + "fingerprint_from_base64_spki", + "fingerprint_from_spki_der", + "generate_ed25519_keypair", + "normalize_base64", + "private_key_to_pem", + "public_key_to_pem", + "spki_der_from_base64", +] + + +def generate_ed25519_keypair() -> Tuple[ed25519.Ed25519PrivateKey, bytes]: + private_key = ed25519.Ed25519PrivateKey.generate() + public_key = private_key.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return private_key, public_key + + +def normalize_base64(data: str) -> str: + cleaned = re.sub(r"\s+", "", data or "") + return cleaned.replace("-", "+").replace("_", "/") + + +def spki_der_from_base64(spki_b64: str) -> bytes: + return base64.b64decode(normalize_base64(spki_b64), validate=True) + + +def base64_from_spki_der(spki_der: bytes) -> str: + return base64.b64encode(spki_der).decode("ascii") + + +def fingerprint_from_spki_der(spki_der: bytes) -> str: + digest = hashlib.sha256(spki_der).hexdigest() + return digest.lower() + + +def fingerprint_from_base64_spki(spki_b64: str) -> str: + return fingerprint_from_spki_der(spki_der_from_base64(spki_b64)) + + +def private_key_to_pem(private_key: ed25519.Ed25519PrivateKey) -> bytes: + return private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + +def public_key_to_pem(public_spki_der: bytes) -> bytes: + public_key = load_der_public_key(public_spki_der) + return public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) diff --git a/Data/Engine/interfaces/http/__init__.py b/Data/Engine/interfaces/http/__init__.py index cc625c7..7626bf1 100644 --- a/Data/Engine/interfaces/http/__init__.py +++ b/Data/Engine/interfaces/http/__init__.py @@ -4,6 +4,8 @@ from __future__ import annotations from flask import Flask +from Data.Engine.services.container import EngineServiceContainer + from . import admin, agents, enrollment, health, tokens _REGISTRARS = ( @@ -15,14 +17,14 @@ _REGISTRARS = ( ) -def register_http_interfaces(app: Flask) -> None: +def register_http_interfaces(app: Flask, services: EngineServiceContainer) -> None: """Attach HTTP blueprints to *app*. The implementation is intentionally minimal for the initial scaffolding. """ for registrar in _REGISTRARS: - registrar(app) + registrar(app, services) __all__ = ["register_http_interfaces"] diff --git a/Data/Engine/interfaces/http/admin.py b/Data/Engine/interfaces/http/admin.py index fb95511..2da2ec2 100644 --- a/Data/Engine/interfaces/http/admin.py +++ b/Data/Engine/interfaces/http/admin.py @@ -4,11 +4,13 @@ from __future__ import annotations from flask import Blueprint, Flask +from Data.Engine.services.container import EngineServiceContainer + blueprint = Blueprint("engine_admin", __name__, url_prefix="/api/admin") -def register(app: Flask) -> None: +def register(app: Flask, _services: EngineServiceContainer) -> None: """Attach administrative routes to *app*. Concrete endpoints will be migrated in subsequent phases. diff --git a/Data/Engine/interfaces/http/agents.py b/Data/Engine/interfaces/http/agents.py index 618ade6..0485bd0 100644 --- a/Data/Engine/interfaces/http/agents.py +++ b/Data/Engine/interfaces/http/agents.py @@ -4,11 +4,13 @@ from __future__ import annotations from flask import Blueprint, Flask +from Data.Engine.services.container import EngineServiceContainer + blueprint = Blueprint("engine_agents", __name__, url_prefix="/api/agents") -def register(app: Flask) -> None: +def register(app: Flask, _services: EngineServiceContainer) -> None: """Attach agent management routes to *app*. Implementation will be populated as services migrate from the legacy server. diff --git a/Data/Engine/interfaces/http/enrollment.py b/Data/Engine/interfaces/http/enrollment.py index a514011..5d65ff5 100644 --- a/Data/Engine/interfaces/http/enrollment.py +++ b/Data/Engine/interfaces/http/enrollment.py @@ -2,20 +2,110 @@ from __future__ import annotations -from flask import Blueprint, Flask +from flask import Blueprint, Flask, current_app, jsonify, request + +from Data.Engine.builders.device_enrollment import EnrollmentRequestBuilder +from Data.Engine.services import EnrollmentValidationError +from Data.Engine.services.container import EngineServiceContainer + +AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context" -blueprint = Blueprint("engine_enrollment", __name__, url_prefix="/api/enrollment") +blueprint = Blueprint("engine_enrollment", __name__) -def register(app: Flask) -> None: - """Attach enrollment routes to *app*. - - Implementation will be ported during later migration phases. - """ +def register(app: Flask, _services: EngineServiceContainer) -> None: + """Attach enrollment routes to *app*.""" if "engine_enrollment" not in app.blueprints: app.register_blueprint(blueprint) -__all__ = ["register", "blueprint"] +@blueprint.route("/api/agent/enroll/request", methods=["POST"]) +def enrollment_request() -> object: + services: EngineServiceContainer = current_app.extensions["engine_services"] + payload = request.get_json(force=True, silent=True) + builder = EnrollmentRequestBuilder().with_payload(payload).with_service_context( + request.headers.get(AGENT_CONTEXT_HEADER) + ) + try: + normalized = builder.build() + result = services.enrollment_service.request_enrollment( + normalized, + remote_addr=_remote_addr(), + ) + except EnrollmentValidationError as exc: + response = jsonify(exc.to_response()) + response.status_code = exc.http_status + if exc.retry_after is not None: + response.headers["Retry-After"] = f"{int(exc.retry_after)}" + return response + + response_payload = { + "status": result.status, + "approval_reference": result.approval_reference, + "server_nonce": result.server_nonce, + "poll_after_ms": result.poll_after_ms, + "server_certificate": result.server_certificate, + "signing_key": result.signing_key, + } + response = jsonify(response_payload) + response.status_code = result.http_status + if result.retry_after is not None: + response.headers["Retry-After"] = f"{int(result.retry_after)}" + return response + + +@blueprint.route("/api/agent/enroll/poll", methods=["POST"]) +def enrollment_poll() -> object: + services: EngineServiceContainer = current_app.extensions["engine_services"] + payload = request.get_json(force=True, silent=True) or {} + approval_reference = str(payload.get("approval_reference") or "").strip() + client_nonce = str(payload.get("client_nonce") or "").strip() + proof_sig = str(payload.get("proof_sig") or "").strip() + + try: + result = services.enrollment_service.poll_enrollment( + approval_reference=approval_reference, + client_nonce_b64=client_nonce, + proof_signature_b64=proof_sig, + ) + except EnrollmentValidationError as exc: + return jsonify(exc.to_response()), exc.http_status + + body = {"status": result.status} + if result.poll_after_ms is not None: + body["poll_after_ms"] = result.poll_after_ms + if result.reason: + body["reason"] = result.reason + if result.detail: + body["detail"] = result.detail + if result.tokens: + body.update( + { + "guid": result.tokens.guid.value, + "access_token": result.tokens.access_token, + "refresh_token": result.tokens.refresh_token, + "token_type": result.tokens.token_type, + "expires_in": result.tokens.expires_in, + "server_certificate": result.server_certificate or "", + "signing_key": result.signing_key or "", + } + ) + else: + if result.server_certificate: + body["server_certificate"] = result.server_certificate + if result.signing_key: + body["signing_key"] = result.signing_key + + return jsonify(body), result.http_status + + +def _remote_addr() -> str: + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + return forwarded.split(",")[0].strip() + return (request.remote_addr or "unknown").strip() + + +__all__ = ["register", "blueprint", "enrollment_request", "enrollment_poll"] diff --git a/Data/Engine/interfaces/http/health.py b/Data/Engine/interfaces/http/health.py index 4cbfa46..37e74a7 100644 --- a/Data/Engine/interfaces/http/health.py +++ b/Data/Engine/interfaces/http/health.py @@ -1,21 +1,26 @@ -"""Health check HTTP interface placeholders for the Engine.""" +"""Health check HTTP interface for the Engine.""" from __future__ import annotations -from flask import Blueprint, Flask +from flask import Blueprint, Flask, jsonify +from Data.Engine.services.container import EngineServiceContainer blueprint = Blueprint("engine_health", __name__) -def register(app: Flask) -> None: - """Attach health-related routes to *app*. - - Routes will be populated in later migration phases. - """ +def register(app: Flask, _services: EngineServiceContainer) -> None: + """Attach health-related routes to *app*.""" if "engine_health" not in app.blueprints: app.register_blueprint(blueprint) +@blueprint.route("/health", methods=["GET"]) +def health() -> object: + """Return a basic liveness response.""" + + return jsonify({"status": "ok"}) + + __all__ = ["register", "blueprint"] diff --git a/Data/Engine/interfaces/http/tokens.py b/Data/Engine/interfaces/http/tokens.py index 6aa4bbc..89bbc3e 100644 --- a/Data/Engine/interfaces/http/tokens.py +++ b/Data/Engine/interfaces/http/tokens.py @@ -1,21 +1,52 @@ -"""Token management HTTP interface placeholders for the Engine.""" +"""Token management HTTP interface for the Engine.""" from __future__ import annotations -from flask import Blueprint, Flask +from flask import Blueprint, Flask, current_app, jsonify, request + +from Data.Engine.builders.device_auth import RefreshTokenRequestBuilder +from Data.Engine.domain.device_auth import DeviceAuthFailure +from Data.Engine.services.container import EngineServiceContainer +from Data.Engine.services import TokenRefreshError + +blueprint = Blueprint("engine_tokens", __name__) -blueprint = Blueprint("engine_tokens", __name__, url_prefix="/api/tokens") - - -def register(app: Flask) -> None: - """Attach token management routes to *app*. - - Implementation will be introduced as authentication services are migrated. - """ +def register(app: Flask, _services: EngineServiceContainer) -> None: + """Attach token management routes to *app*.""" if "engine_tokens" not in app.blueprints: app.register_blueprint(blueprint) -__all__ = ["register", "blueprint"] +@blueprint.route("/api/agent/token/refresh", methods=["POST"]) +def refresh_token() -> object: + services: EngineServiceContainer = current_app.extensions["engine_services"] + builder = ( + RefreshTokenRequestBuilder() + .with_payload(request.get_json(force=True, silent=True)) + .with_http_method(request.method) + .with_htu(request.url) + .with_dpop_proof(request.headers.get("DPoP")) + ) + try: + refresh_request = builder.build() + except DeviceAuthFailure as exc: + payload = exc.to_dict() + return jsonify(payload), exc.http_status + + try: + response = services.token_service.refresh_access_token(refresh_request) + except TokenRefreshError as exc: + return jsonify(exc.to_dict()), exc.http_status + + return jsonify( + { + "access_token": response.access_token, + "expires_in": response.expires_in, + "token_type": response.token_type, + } + ) + + +__all__ = ["register", "blueprint", "refresh_token"] diff --git a/Data/Engine/repositories/sqlite/device_repository.py b/Data/Engine/repositories/sqlite/device_repository.py index 35fc00c..481243e 100644 --- a/Data/Engine/repositories/sqlite/device_repository.py +++ b/Data/Engine/repositories/sqlite/device_repository.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging import sqlite3 import time +import uuid from contextlib import closing from datetime import datetime, timezone from typing import Optional @@ -152,6 +153,133 @@ class SQLiteDeviceRepository: return self._row_to_record(row) + def ensure_device_record( + self, + *, + guid: DeviceGuid, + hostname: str, + fingerprint: DeviceFingerprint, + ) -> DeviceRecord: + now_iso = datetime.now(tz=timezone.utc).isoformat() + now_ts = int(time.time()) + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at + FROM devices + WHERE UPPER(guid) = ? + """, + (guid.value.upper(),), + ) + row = cur.fetchone() + + if row: + stored_fp = (row[4] or "").strip().lower() + new_fp = fingerprint.value + if not stored_fp: + cur.execute( + "UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?", + (new_fp, now_iso, row[0]), + ) + elif stored_fp != new_fp: + token_version = self._coerce_int(row[2], default=1) + 1 + cur.execute( + """ + UPDATE devices + SET ssl_key_fingerprint = ?, + key_added_at = ?, + token_version = ?, + status = 'active' + WHERE guid = ? + """, + (new_fp, now_iso, token_version, row[0]), + ) + cur.execute( + """ + UPDATE refresh_tokens + SET revoked_at = ? + WHERE guid = ? + AND revoked_at IS NULL + """, + (now_iso, row[0]), + ) + conn.commit() + else: + resolved_hostname = self._resolve_hostname(cur, hostname, guid) + cur.execute( + """ + INSERT INTO devices ( + guid, + hostname, + created_at, + last_seen, + ssl_key_fingerprint, + token_version, + status, + key_added_at + ) + VALUES (?, ?, ?, ?, ?, 1, 'active', ?) + """, + ( + guid.value, + resolved_hostname, + now_ts, + now_ts, + fingerprint.value, + now_iso, + ), + ) + conn.commit() + + cur.execute( + """ + SELECT guid, ssl_key_fingerprint, token_version, status + FROM devices + WHERE UPPER(guid) = ? + """, + (guid.value.upper(),), + ) + latest = cur.fetchone() + + if not latest: + raise RuntimeError("device record could not be ensured") + + record = self._row_to_record(latest) + if record is None: + raise RuntimeError("device record invalid after ensure") + return record + + def record_device_key( + self, + *, + guid: DeviceGuid, + fingerprint: DeviceFingerprint, + added_at: datetime, + ) -> None: + added_iso = added_at.astimezone(timezone.utc).isoformat() + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at) + VALUES (?, ?, ?, ?) + """, + (str(uuid.uuid4()), guid.value, fingerprint.value, added_iso), + ) + cur.execute( + """ + UPDATE device_keys + SET retired_at = ? + WHERE guid = ? + AND ssl_key_fingerprint != ? + AND retired_at IS NULL + """, + (added_iso, guid.value, fingerprint.value), + ) + conn.commit() + def _row_to_record(self, row: tuple) -> Optional[DeviceRecord]: try: guid = DeviceGuid(row[0]) @@ -181,3 +309,31 @@ class SQLiteDeviceRepository: token_version=max(token_version, 1), status=status, ) + + @staticmethod + def _coerce_int(value: object, *, default: int = 0) -> int: + try: + return int(value) + except Exception: + return default + + def _resolve_hostname(self, cur: sqlite3.Cursor, hostname: str, guid: DeviceGuid) -> str: + base = (hostname or "").strip() or guid.value + base = base[:253] + candidate = base + suffix = 1 + while True: + cur.execute( + "SELECT guid FROM devices WHERE hostname = ?", + (candidate,), + ) + row = cur.fetchone() + if not row: + return candidate + existing = (row[0] or "").strip().upper() + if existing == guid.value: + return candidate + candidate = f"{base}-{suffix}" + suffix += 1 + if suffix > 50: + return guid.value diff --git a/Data/Engine/repositories/sqlite/enrollment_repository.py b/Data/Engine/repositories/sqlite/enrollment_repository.py index 207bbce..a6549ec 100644 --- a/Data/Engine/repositories/sqlite/enrollment_repository.py +++ b/Data/Engine/repositories/sqlite/enrollment_repository.py @@ -78,6 +78,50 @@ class SQLiteEnrollmentRepository: self._log.warning("invalid enrollment code record for code=%s: %s", code_value, exc) return None + def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]: + record_value = (record_id or "").strip() + if not record_value: + return None + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT id, + code, + expires_at, + used_at, + used_by_guid, + max_uses, + use_count, + last_used_at + FROM enrollment_install_codes + WHERE id = ? + """, + (record_value,), + ) + row = cur.fetchone() + + if not row: + return None + + record = { + "id": row[0], + "code": row[1], + "expires_at": row[2], + "used_at": row[3], + "used_by_guid": row[4], + "max_uses": row[5], + "use_count": row[6], + "last_used_at": row[7], + } + + try: + return EnrollmentCode.from_mapping(record) + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc) + return None + def update_install_code_usage( self, record_id: str, @@ -135,6 +179,53 @@ class SQLiteEnrollmentRepository: return None return self._fetch_device_approval("id = ?", (record_value,)) + def fetch_pending_approval_by_fingerprint( + self, fingerprint: DeviceFingerprint + ) -> Optional[EnrollmentApproval]: + return self._fetch_device_approval( + "ssl_key_fingerprint_claimed = ? AND status = 'pending'", + (fingerprint.value,), + ) + + def update_pending_approval( + self, + record_id: str, + *, + hostname: str, + guid: Optional[DeviceGuid], + enrollment_code_id: Optional[str], + client_nonce_b64: str, + server_nonce_b64: str, + agent_pubkey_der: bytes, + updated_at: datetime, + ) -> None: + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + UPDATE device_approvals + SET hostname_claimed = ?, + guid = ?, + enrollment_code_id = ?, + client_nonce = ?, + server_nonce = ?, + agent_pubkey_der = ?, + updated_at = ? + WHERE id = ? + """, + ( + hostname, + guid.value if guid else None, + enrollment_code_id, + client_nonce_b64, + server_nonce_b64, + agent_pubkey_der, + self._isoformat(updated_at), + record_id, + ), + ) + conn.commit() + def create_device_approval( self, *, @@ -143,8 +234,8 @@ class SQLiteEnrollmentRepository: claimed_hostname: str, claimed_fingerprint: DeviceFingerprint, enrollment_code_id: Optional[str], - client_nonce: bytes, - server_nonce: bytes, + client_nonce_b64: str, + server_nonce_b64: str, agent_pubkey_der: bytes, created_at: datetime, status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING, @@ -183,8 +274,8 @@ class SQLiteEnrollmentRepository: status.value, created_iso, created_iso, - client_nonce, - server_nonce, + client_nonce_b64, + server_nonce_b64, agent_pubkey_der, ), ) @@ -244,7 +335,10 @@ class SQLiteEnrollmentRepository: created_at, updated_at, status, - approved_by_user_id + approved_by_user_id, + client_nonce, + server_nonce, + agent_pubkey_der FROM device_approvals WHERE {where} """, @@ -266,6 +360,9 @@ class SQLiteEnrollmentRepository: "updated_at": row[7], "status": row[8], "approved_by_user_id": row[9], + "client_nonce": row[10], + "server_nonce": row[11], + "agent_pubkey_der": row[12], } try: diff --git a/Data/Engine/repositories/sqlite/token_repository.py b/Data/Engine/repositories/sqlite/token_repository.py index 5a2850d..fb2f605 100644 --- a/Data/Engine/repositories/sqlite/token_repository.py +++ b/Data/Engine/repositories/sqlite/token_repository.py @@ -78,6 +78,35 @@ class SQLiteRefreshTokenRepository: ) conn.commit() + def create( + self, + *, + record_id: str, + guid: DeviceGuid, + token_hash: str, + created_at: datetime, + expires_at: Optional[datetime], + ) -> None: + created_iso = self._isoformat(created_at) + expires_iso = self._isoformat(expires_at) if expires_at else None + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT INTO refresh_tokens ( + id, + guid, + token_hash, + created_at, + expires_at + ) + VALUES (?, ?, ?, ?, ?) + """, + (record_id, guid.value, token_hash, created_iso, expires_iso), + ) + conn.commit() + def _row_to_record(self, row: tuple) -> Optional[RefreshTokenRecord]: try: guid = DeviceGuid(row[1]) diff --git a/Data/Engine/runtime.py b/Data/Engine/runtime.py new file mode 100644 index 0000000..7e0167e --- /dev/null +++ b/Data/Engine/runtime.py @@ -0,0 +1,139 @@ +"""Runtime filesystem helpers for the Borealis Engine.""" + +from __future__ import annotations + +import os +from functools import lru_cache +from pathlib import Path +from typing import Optional + +__all__ = [ + "agent_certificates_path", + "ensure_agent_certificates_dir", + "ensure_certificates_dir", + "ensure_runtime_dir", + "ensure_server_certificates_dir", + "project_root", + "runtime_path", + "server_certificates_path", +] + + +def _env_path(name: str) -> Optional[Path]: + value = os.environ.get(name) + if not value: + return None + try: + return Path(value).expanduser().resolve() + except Exception: + return None + + +@lru_cache(maxsize=None) +def project_root() -> Path: + env = _env_path("BOREALIS_PROJECT_ROOT") + if env: + return env + + current = Path(__file__).resolve() + for parent in current.parents: + if (parent / "Borealis.ps1").exists() or (parent / ".git").is_dir(): + return parent + + try: + return current.parents[1] + except IndexError: + return current.parent + + +@lru_cache(maxsize=None) +def server_runtime_root() -> Path: + env = _env_path("BOREALIS_ENGINE_ROOT") or _env_path("BOREALIS_SERVER_ROOT") + if env: + env.mkdir(parents=True, exist_ok=True) + return env + + root = project_root() / "Engine" + root.mkdir(parents=True, exist_ok=True) + return root + + +def runtime_path(*parts: str) -> Path: + return server_runtime_root().joinpath(*parts) + + +def ensure_runtime_dir(*parts: str) -> Path: + path = runtime_path(*parts) + path.mkdir(parents=True, exist_ok=True) + return path + + +@lru_cache(maxsize=None) +def certificates_root() -> Path: + env = _env_path("BOREALIS_CERTIFICATES_ROOT") or _env_path("BOREALIS_CERT_ROOT") + if env: + env.mkdir(parents=True, exist_ok=True) + return env + + root = project_root() / "Certificates" + root.mkdir(parents=True, exist_ok=True) + for name in ("Server", "Agent"): + try: + (root / name).mkdir(parents=True, exist_ok=True) + except Exception: + pass + return root + + +@lru_cache(maxsize=None) +def server_certificates_root() -> Path: + env = _env_path("BOREALIS_SERVER_CERT_ROOT") + if env: + env.mkdir(parents=True, exist_ok=True) + return env + + root = certificates_root() / "Server" + root.mkdir(parents=True, exist_ok=True) + return root + + +@lru_cache(maxsize=None) +def agent_certificates_root() -> Path: + env = _env_path("BOREALIS_AGENT_CERT_ROOT") + if env: + env.mkdir(parents=True, exist_ok=True) + return env + + root = certificates_root() / "Agent" + root.mkdir(parents=True, exist_ok=True) + return root + + +def certificates_path(*parts: str) -> Path: + return certificates_root().joinpath(*parts) + + +def ensure_certificates_dir(*parts: str) -> Path: + path = certificates_path(*parts) + path.mkdir(parents=True, exist_ok=True) + return path + + +def server_certificates_path(*parts: str) -> Path: + return server_certificates_root().joinpath(*parts) + + +def ensure_server_certificates_dir(*parts: str) -> Path: + path = server_certificates_path(*parts) + path.mkdir(parents=True, exist_ok=True) + return path + + +def agent_certificates_path(*parts: str) -> Path: + return agent_certificates_root().joinpath(*parts) + + +def ensure_agent_certificates_dir(*parts: str) -> Path: + path = agent_certificates_path(*parts) + path.mkdir(parents=True, exist_ok=True) + return path diff --git a/Data/Engine/services/__init__.py b/Data/Engine/services/__init__.py index 0f932d1..dcb8c24 100644 --- a/Data/Engine/services/__init__.py +++ b/Data/Engine/services/__init__.py @@ -10,6 +10,14 @@ from .auth import ( TokenRefreshErrorCode, TokenService, ) +from .enrollment import ( + EnrollmentRequestResult, + EnrollmentService, + EnrollmentStatus, + EnrollmentTokenBundle, + EnrollmentValidationError, + PollingResult, +) __all__ = [ "DeviceAuthService", @@ -18,4 +26,10 @@ __all__ = [ "TokenRefreshError", "TokenRefreshErrorCode", "TokenService", + "EnrollmentService", + "EnrollmentRequestResult", + "EnrollmentStatus", + "EnrollmentTokenBundle", + "EnrollmentValidationError", + "PollingResult", ] diff --git a/Data/Engine/services/auth/__init__.py b/Data/Engine/services/auth/__init__.py index a4efad1..f24d072 100644 --- a/Data/Engine/services/auth/__init__.py +++ b/Data/Engine/services/auth/__init__.py @@ -3,6 +3,8 @@ from __future__ import annotations from .device_auth_service import DeviceAuthService, DeviceRecord +from .dpop import DPoPReplayError, DPoPVerificationError, DPoPValidator +from .jwt_service import JWTService, load_service as load_jwt_service from .token_service import ( RefreshTokenRecord, TokenRefreshError, @@ -13,6 +15,11 @@ from .token_service import ( __all__ = [ "DeviceAuthService", "DeviceRecord", + "DPoPReplayError", + "DPoPVerificationError", + "DPoPValidator", + "JWTService", + "load_jwt_service", "RefreshTokenRecord", "TokenRefreshError", "TokenRefreshErrorCode", diff --git a/Data/Engine/services/auth/dpop.py b/Data/Engine/services/auth/dpop.py new file mode 100644 index 0000000..2ea7e02 --- /dev/null +++ b/Data/Engine/services/auth/dpop.py @@ -0,0 +1,105 @@ +"""DPoP proof validation for Engine services.""" + +from __future__ import annotations + +import hashlib +import time +from threading import Lock +from typing import Dict, Optional + +import jwt + +__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"] + + +_DP0P_MAX_SKEW = 300.0 + + +class DPoPVerificationError(Exception): + pass + + +class DPoPReplayError(DPoPVerificationError): + pass + + +class DPoPValidator: + def __init__(self) -> None: + self._observed_jti: Dict[str, float] = {} + self._lock = Lock() + + def verify( + self, + method: str, + htu: str, + proof: str, + access_token: Optional[str] = None, + ) -> str: + if not proof: + raise DPoPVerificationError("DPoP proof missing") + + try: + header = jwt.get_unverified_header(proof) + except Exception as exc: + raise DPoPVerificationError("invalid DPoP header") from exc + + jwk = header.get("jwk") + alg = header.get("alg") + if not jwk or not isinstance(jwk, dict): + raise DPoPVerificationError("missing jwk in DPoP header") + if alg not in ("EdDSA", "ES256", "ES384", "ES512"): + raise DPoPVerificationError(f"unsupported DPoP alg {alg}") + + try: + key = jwt.PyJWK(jwk) + public_key = key.key + except Exception as exc: + raise DPoPVerificationError("invalid jwk in DPoP header") from exc + + try: + claims = jwt.decode( + proof, + public_key, + algorithms=[alg], + options={"require": ["htm", "htu", "jti", "iat"]}, + ) + except Exception as exc: + raise DPoPVerificationError("invalid DPoP signature") from exc + + htm = claims.get("htm") + proof_htu = claims.get("htu") + jti = claims.get("jti") + iat = claims.get("iat") + ath = claims.get("ath") + + if not isinstance(htm, str) or htm.lower() != method.lower(): + raise DPoPVerificationError("DPoP htm mismatch") + if not isinstance(proof_htu, str) or proof_htu != htu: + raise DPoPVerificationError("DPoP htu mismatch") + if not isinstance(jti, str): + raise DPoPVerificationError("DPoP jti missing") + if not isinstance(iat, (int, float)): + raise DPoPVerificationError("DPoP iat missing") + + now = time.time() + if abs(now - float(iat)) > _DP0P_MAX_SKEW: + raise DPoPVerificationError("DPoP proof outside allowed skew") + + if ath and access_token: + expected_ath = jwt.utils.base64url_encode( + hashlib.sha256(access_token.encode("utf-8")).digest() + ).decode("ascii") + if expected_ath != ath: + raise DPoPVerificationError("DPoP ath mismatch") + + with self._lock: + expiry = self._observed_jti.get(jti) + if expiry and expiry > now: + raise DPoPReplayError("DPoP proof replay detected") + self._observed_jti[jti] = now + _DP0P_MAX_SKEW + stale = [key for key, exp in self._observed_jti.items() if exp <= now] + for key in stale: + self._observed_jti.pop(key, None) + + thumbprint = jwt.PyJWK(jwk).thumbprint() + return thumbprint.decode("ascii") diff --git a/Data/Engine/services/auth/jwt_service.py b/Data/Engine/services/auth/jwt_service.py new file mode 100644 index 0000000..6a9d2e9 --- /dev/null +++ b/Data/Engine/services/auth/jwt_service.py @@ -0,0 +1,124 @@ +"""JWT issuance utilities for the Engine.""" + +from __future__ import annotations + +import hashlib +import time +from typing import Any, Dict, Optional + +import jwt +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 + +from Data.Engine.runtime import ensure_runtime_dir, runtime_path + +__all__ = ["JWTService", "load_service"] + + +_KEY_DIR = runtime_path("auth_keys") +_KEY_FILE = _KEY_DIR / "engine-jwt-ed25519.key" +_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-jwt-ed25519.key" + + +class JWTService: + def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str) -> None: + self._private_key = private_key + self._public_key = private_key.public_key() + self._key_id = key_id + + @property + def key_id(self) -> str: + return self._key_id + + def issue_access_token( + self, + guid: str, + ssl_key_fingerprint: str, + token_version: int, + expires_in: int = 900, + extra_claims: Optional[Dict[str, Any]] = None, + ) -> str: + now = int(time.time()) + payload: Dict[str, Any] = { + "sub": f"device:{guid}", + "guid": guid, + "ssl_key_fingerprint": ssl_key_fingerprint, + "token_version": int(token_version), + "iat": now, + "nbf": now, + "exp": now + int(expires_in), + } + if extra_claims: + payload.update(extra_claims) + + token = jwt.encode( + payload, + self._private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ), + algorithm="EdDSA", + headers={"kid": self._key_id}, + ) + return token + + def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]: + options = {"require": ["exp", "iat", "sub"]} + public_pem = self._public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return jwt.decode( + token, + public_pem, + algorithms=["EdDSA"], + audience=audience, + options=options, + ) + + def public_jwk(self) -> Dict[str, Any]: + public_bytes = self._public_key.public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, + ) + jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii") + return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x} + + +def load_service() -> JWTService: + private_key = _load_or_create_private_key() + public_bytes = private_key.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + key_id = hashlib.sha256(public_bytes).hexdigest()[:16] + return JWTService(private_key, key_id) + + +def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey: + ensure_runtime_dir("auth_keys") + + if _KEY_FILE.exists(): + with _KEY_FILE.open("rb") as fh: + return serialization.load_pem_private_key(fh.read(), password=None) + + if _LEGACY_KEY_FILE.exists(): + with _LEGACY_KEY_FILE.open("rb") as fh: + return serialization.load_pem_private_key(fh.read(), password=None) + + private_key = ed25519.Ed25519PrivateKey.generate() + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + _KEY_DIR.mkdir(parents=True, exist_ok=True) + with _KEY_FILE.open("wb") as fh: + fh.write(pem) + try: + if hasattr(_KEY_FILE, "chmod"): + _KEY_FILE.chmod(0o600) + except Exception: + pass + return private_key diff --git a/Data/Engine/services/container.py b/Data/Engine/services/container.py new file mode 100644 index 0000000..4e883d8 --- /dev/null +++ b/Data/Engine/services/container.py @@ -0,0 +1,119 @@ +"""Service container assembly for the Borealis Engine.""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Optional + +from Data.Engine.config import EngineSettings +from Data.Engine.repositories.sqlite import ( + SQLiteConnectionFactory, + SQLiteDeviceRepository, + SQLiteEnrollmentRepository, + SQLiteRefreshTokenRepository, +) +from Data.Engine.services.auth import ( + DeviceAuthService, + DPoPValidator, + JWTService, + TokenService, + load_jwt_service, +) +from Data.Engine.services.crypto.signing import ScriptSigner, load_signer +from Data.Engine.services.enrollment import EnrollmentService +from Data.Engine.services.enrollment.nonce_cache import NonceCache +from Data.Engine.services.rate_limit import SlidingWindowRateLimiter + +__all__ = ["EngineServiceContainer", "build_service_container"] + + +@dataclass(frozen=True, slots=True) +class EngineServiceContainer: + device_auth: DeviceAuthService + token_service: TokenService + enrollment_service: EnrollmentService + jwt_service: JWTService + dpop_validator: DPoPValidator + + +def build_service_container( + settings: EngineSettings, + *, + db_factory: SQLiteConnectionFactory, + logger: Optional[logging.Logger] = None, +) -> EngineServiceContainer: + log = logger or logging.getLogger("borealis.engine.services") + + device_repo = SQLiteDeviceRepository(db_factory, logger=log.getChild("devices")) + token_repo = SQLiteRefreshTokenRepository(db_factory, logger=log.getChild("tokens")) + enrollment_repo = SQLiteEnrollmentRepository(db_factory, logger=log.getChild("enrollment")) + + jwt_service = load_jwt_service() + dpop_validator = DPoPValidator() + rate_limiter = SlidingWindowRateLimiter() + + token_service = TokenService( + refresh_token_repository=token_repo, + device_repository=device_repo, + jwt_service=jwt_service, + dpop_validator=dpop_validator, + logger=log.getChild("token_service"), + ) + + enrollment_service = EnrollmentService( + device_repository=device_repo, + enrollment_repository=enrollment_repo, + token_repository=token_repo, + jwt_service=jwt_service, + tls_bundle_loader=_tls_bundle_loader(settings), + ip_rate_limiter=SlidingWindowRateLimiter(), + fingerprint_rate_limiter=SlidingWindowRateLimiter(), + nonce_cache=NonceCache(), + script_signer=_load_script_signer(log), + logger=log.getChild("enrollment"), + ) + + device_auth = DeviceAuthService( + device_repository=device_repo, + jwt_service=jwt_service, + logger=log.getChild("device_auth"), + rate_limiter=rate_limiter, + dpop_validator=dpop_validator, + ) + + return EngineServiceContainer( + device_auth=device_auth, + token_service=token_service, + enrollment_service=enrollment_service, + jwt_service=jwt_service, + dpop_validator=dpop_validator, + ) + + +def _tls_bundle_loader(settings: EngineSettings) -> Callable[[], str]: + candidates = [ + Path(os.getenv("BOREALIS_TLS_BUNDLE", "")), + settings.project_root / "Certificates" / "Server" / "borealis-server-bundle.pem", + ] + + def loader() -> str: + for candidate in candidates: + if candidate and candidate.is_file(): + try: + return candidate.read_text(encoding="utf-8") + except Exception: + continue + return "" + + return loader + + +def _load_script_signer(logger: logging.Logger) -> Optional[ScriptSigner]: + try: + return load_signer() + except Exception as exc: + logger.warning("script signer unavailable: %s", exc) + return None diff --git a/Data/Engine/services/crypto/signing.py b/Data/Engine/services/crypto/signing.py new file mode 100644 index 0000000..17d8875 --- /dev/null +++ b/Data/Engine/services/crypto/signing.py @@ -0,0 +1,75 @@ +"""Script signing utilities for the Engine.""" + +from __future__ import annotations + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 + +from Data.Engine.integrations.crypto.keys import base64_from_spki_der +from Data.Engine.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path + +__all__ = ["ScriptSigner", "load_signer"] + + +_KEY_DIR = server_certificates_path("Code-Signing") +_SIGNING_KEY_FILE = _KEY_DIR / "engine-script-ed25519.key" +_SIGNING_PUB_FILE = _KEY_DIR / "engine-script-ed25519.pub" +_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-script-ed25519.key" +_LEGACY_PUB_FILE = runtime_path("keys") / "borealis-script-ed25519.pub" + + +class ScriptSigner: + def __init__(self, private_key: ed25519.Ed25519PrivateKey) -> None: + self._private = private_key + self._public = private_key.public_key() + + def sign(self, payload: bytes) -> bytes: + return self._private.sign(payload) + + def public_spki_der(self) -> bytes: + return self._public.public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + def public_base64_spki(self) -> str: + return base64_from_spki_der(self.public_spki_der()) + + +def load_signer() -> ScriptSigner: + private_key = _load_or_create() + return ScriptSigner(private_key) + + +def _load_or_create() -> ed25519.Ed25519PrivateKey: + ensure_server_certificates_dir("Code-Signing") + + if _SIGNING_KEY_FILE.exists(): + with _SIGNING_KEY_FILE.open("rb") as fh: + return serialization.load_pem_private_key(fh.read(), password=None) + + if _LEGACY_KEY_FILE.exists(): + with _LEGACY_KEY_FILE.open("rb") as fh: + return serialization.load_pem_private_key(fh.read(), password=None) + + private_key = ed25519.Ed25519PrivateKey.generate() + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + _KEY_DIR.mkdir(parents=True, exist_ok=True) + _SIGNING_KEY_FILE.write_bytes(pem) + try: + if hasattr(_SIGNING_KEY_FILE, "chmod"): + _SIGNING_KEY_FILE.chmod(0o600) + except Exception: + pass + + pub_der = private_key.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + _SIGNING_PUB_FILE.write_bytes(pub_der) + + return private_key diff --git a/Data/Engine/services/enrollment/__init__.py b/Data/Engine/services/enrollment/__init__.py new file mode 100644 index 0000000..129d32c --- /dev/null +++ b/Data/Engine/services/enrollment/__init__.py @@ -0,0 +1,21 @@ +"""Enrollment services for the Borealis Engine.""" + +from __future__ import annotations + +from .enrollment_service import ( + EnrollmentRequestResult, + EnrollmentService, + EnrollmentStatus, + EnrollmentTokenBundle, + EnrollmentValidationError, + PollingResult, +) + +__all__ = [ + "EnrollmentRequestResult", + "EnrollmentService", + "EnrollmentStatus", + "EnrollmentTokenBundle", + "EnrollmentValidationError", + "PollingResult", +] diff --git a/Data/Engine/services/enrollment/enrollment_service.py b/Data/Engine/services/enrollment/enrollment_service.py new file mode 100644 index 0000000..ae960e7 --- /dev/null +++ b/Data/Engine/services/enrollment/enrollment_service.py @@ -0,0 +1,487 @@ +"""Enrollment workflow orchestration for the Borealis Engine.""" + +from __future__ import annotations + +import base64 +import hashlib +import logging +import secrets +import uuid +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Callable, Optional, Protocol + +from cryptography.hazmat.primitives import serialization + +from Data.Engine.builders.device_enrollment import EnrollmentRequestInput +from Data.Engine.domain.device_auth import ( + DeviceFingerprint, + DeviceGuid, + sanitize_service_context, +) +from Data.Engine.domain.device_enrollment import ( + EnrollmentApproval, + EnrollmentApprovalStatus, + EnrollmentCode, +) +from Data.Engine.services.auth.device_auth_service import DeviceRecord +from Data.Engine.services.auth.token_service import JWTIssuer +from Data.Engine.services.enrollment.errors import EnrollmentValidationError +from Data.Engine.services.enrollment.nonce_cache import NonceCache +from Data.Engine.services.rate_limit import SlidingWindowRateLimiter + +__all__ = [ + "EnrollmentRequestResult", + "EnrollmentService", + "EnrollmentStatus", + "EnrollmentTokenBundle", + "PollingResult", +] + + +class DeviceRepository(Protocol): + def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: # pragma: no cover - protocol + ... + + def ensure_device_record( + self, + *, + guid: DeviceGuid, + hostname: str, + fingerprint: DeviceFingerprint, + ) -> DeviceRecord: # pragma: no cover - protocol + ... + + def record_device_key( + self, + *, + guid: DeviceGuid, + fingerprint: DeviceFingerprint, + added_at: datetime, + ) -> None: # pragma: no cover - protocol + ... + + +class EnrollmentRepository(Protocol): + def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol + ... + + def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol + ... + + def update_install_code_usage( + self, + record_id: str, + *, + use_count_increment: int, + last_used_at: datetime, + used_by_guid: Optional[DeviceGuid] = None, + mark_first_use: bool = False, + ) -> None: # pragma: no cover - protocol + ... + + def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol + ... + + def fetch_pending_approval_by_fingerprint( + self, fingerprint: DeviceFingerprint + ) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol + ... + + def update_pending_approval( + self, + record_id: str, + *, + hostname: str, + guid: Optional[DeviceGuid], + enrollment_code_id: Optional[str], + client_nonce_b64: str, + server_nonce_b64: str, + agent_pubkey_der: bytes, + updated_at: datetime, + ) -> None: # pragma: no cover - protocol + ... + + def create_device_approval( + self, + *, + record_id: str, + reference: str, + claimed_hostname: str, + claimed_fingerprint: DeviceFingerprint, + enrollment_code_id: Optional[str], + client_nonce_b64: str, + server_nonce_b64: str, + agent_pubkey_der: bytes, + created_at: datetime, + status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING, + guid: Optional[DeviceGuid] = None, + ) -> EnrollmentApproval: # pragma: no cover - protocol + ... + + def update_device_approval_status( + self, + record_id: str, + *, + status: EnrollmentApprovalStatus, + updated_at: datetime, + approved_by: Optional[str] = None, + guid: Optional[DeviceGuid] = None, + ) -> None: # pragma: no cover - protocol + ... + + +class RefreshTokenRepository(Protocol): + def create( + self, + *, + record_id: str, + guid: DeviceGuid, + token_hash: str, + created_at: datetime, + expires_at: Optional[datetime], + ) -> None: # pragma: no cover - protocol + ... + + +class ScriptSigner(Protocol): + def public_base64_spki(self) -> str: # pragma: no cover - protocol + ... + + +class EnrollmentStatus(str): + PENDING = "pending" + APPROVED = "approved" + DENIED = "denied" + EXPIRED = "expired" + UNKNOWN = "unknown" + + +@dataclass(frozen=True, slots=True) +class EnrollmentTokenBundle: + guid: DeviceGuid + access_token: str + refresh_token: str + expires_in: int + token_type: str = "Bearer" + + +@dataclass(frozen=True, slots=True) +class EnrollmentRequestResult: + status: EnrollmentStatus + approval_reference: Optional[str] = None + server_nonce: Optional[str] = None + poll_after_ms: Optional[int] = None + server_certificate: str + signing_key: str + http_status: int = 200 + retry_after: Optional[float] = None + + +@dataclass(frozen=True, slots=True) +class PollingResult: + status: EnrollmentStatus + http_status: int + poll_after_ms: Optional[int] = None + reason: Optional[str] = None + detail: Optional[str] = None + tokens: Optional[EnrollmentTokenBundle] = None + server_certificate: Optional[str] = None + signing_key: Optional[str] = None + + +class EnrollmentService: + """Coordinate the Borealis device enrollment handshake.""" + + def __init__( + self, + *, + device_repository: DeviceRepository, + enrollment_repository: EnrollmentRepository, + token_repository: RefreshTokenRepository, + jwt_service: JWTIssuer, + tls_bundle_loader: Callable[[], str], + ip_rate_limiter: SlidingWindowRateLimiter, + fingerprint_rate_limiter: SlidingWindowRateLimiter, + nonce_cache: NonceCache, + script_signer: Optional[ScriptSigner] = None, + logger: Optional[logging.Logger] = None, + ) -> None: + self._devices = device_repository + self._enrollment = enrollment_repository + self._tokens = token_repository + self._jwt = jwt_service + self._load_tls_bundle = tls_bundle_loader + self._ip_rate_limiter = ip_rate_limiter + self._fp_rate_limiter = fingerprint_rate_limiter + self._nonce_cache = nonce_cache + self._signer = script_signer + self._log = logger or logging.getLogger("borealis.engine.enrollment") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def request_enrollment( + self, + payload: EnrollmentRequestInput, + *, + remote_addr: str, + ) -> EnrollmentRequestResult: + context_hint = sanitize_service_context(payload.service_context) + self._log.info( + "enrollment-request ip=%s host=%s code_mask=%s", remote_addr, payload.hostname, self._mask_code(payload.enrollment_code) + ) + + self._enforce_rate_limit(self._ip_rate_limiter, f"ip:{remote_addr}") + self._enforce_rate_limit(self._fp_rate_limiter, f"fp:{payload.fingerprint.value}") + + install_code = self._enrollment.fetch_install_code(payload.enrollment_code) + reuse_guid = self._determine_reuse_guid(install_code, payload.fingerprint) + + server_nonce_bytes = secrets.token_bytes(32) + server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii") + + now = self._now() + approval = self._enrollment.fetch_pending_approval_by_fingerprint(payload.fingerprint) + if approval: + self._enrollment.update_pending_approval( + approval.record_id, + hostname=payload.hostname, + guid=reuse_guid, + enrollment_code_id=install_code.identifier if install_code else None, + client_nonce_b64=payload.client_nonce_b64, + server_nonce_b64=server_nonce_b64, + agent_pubkey_der=payload.agent_public_key_der, + updated_at=now, + ) + approval_reference = approval.reference + else: + record_id = str(uuid.uuid4()) + approval_reference = str(uuid.uuid4()) + approval = self._enrollment.create_device_approval( + record_id=record_id, + reference=approval_reference, + claimed_hostname=payload.hostname, + claimed_fingerprint=payload.fingerprint, + enrollment_code_id=install_code.identifier if install_code else None, + client_nonce_b64=payload.client_nonce_b64, + server_nonce_b64=server_nonce_b64, + agent_pubkey_der=payload.agent_public_key_der, + created_at=now, + guid=reuse_guid, + ) + + signing_key = self._signer.public_base64_spki() if self._signer else "" + certificate = self._load_tls_bundle() + + return EnrollmentRequestResult( + status=EnrollmentStatus.PENDING, + approval_reference=approval.reference, + server_nonce=server_nonce_b64, + poll_after_ms=3000, + server_certificate=certificate, + signing_key=signing_key, + ) + + def poll_enrollment( + self, + *, + approval_reference: str, + client_nonce_b64: str, + proof_signature_b64: str, + ) -> PollingResult: + if not approval_reference: + raise EnrollmentValidationError("approval_reference_required") + if not client_nonce_b64: + raise EnrollmentValidationError("client_nonce_required") + if not proof_signature_b64: + raise EnrollmentValidationError("proof_sig_required") + + approval = self._enrollment.fetch_device_approval_by_reference(approval_reference) + if approval is None: + return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=404) + + client_nonce = self._decode_base64(client_nonce_b64, "invalid_client_nonce") + server_nonce = self._decode_base64(approval.server_nonce_b64, "server_nonce_invalid") + proof_sig = self._decode_base64(proof_signature_b64, "invalid_proof_sig") + + if approval.client_nonce_b64 != client_nonce_b64: + raise EnrollmentValidationError("nonce_mismatch") + + self._verify_proof_signature( + approval=approval, + client_nonce=client_nonce, + server_nonce=server_nonce, + signature=proof_sig, + ) + + status = approval.status + if status is EnrollmentApprovalStatus.PENDING: + return PollingResult( + status=EnrollmentStatus.PENDING, + http_status=200, + poll_after_ms=5000, + ) + if status is EnrollmentApprovalStatus.DENIED: + return PollingResult( + status=EnrollmentStatus.DENIED, + http_status=200, + reason="operator_denied", + ) + if status is EnrollmentApprovalStatus.EXPIRED: + return PollingResult(status=EnrollmentStatus.EXPIRED, http_status=200) + if status is EnrollmentApprovalStatus.COMPLETED: + return PollingResult( + status=EnrollmentStatus.APPROVED, + http_status=200, + detail="finalized", + ) + if status is not EnrollmentApprovalStatus.APPROVED: + return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=400) + + nonce_key = f"{approval.reference}:{proof_signature_b64}" + if not self._nonce_cache.consume(nonce_key): + raise EnrollmentValidationError("proof_replayed", http_status=409) + + token_bundle = self._finalize_approval(approval) + signing_key = self._signer.public_base64_spki() if self._signer else "" + certificate = self._load_tls_bundle() + + return PollingResult( + status=EnrollmentStatus.APPROVED, + http_status=200, + tokens=token_bundle, + server_certificate=certificate, + signing_key=signing_key, + ) + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + def _enforce_rate_limit( + self, + limiter: SlidingWindowRateLimiter, + key: str, + *, + limit: int = 60, + window_seconds: float = 60.0, + ) -> None: + decision = limiter.check(key, limit, window_seconds) + if not decision.allowed: + raise EnrollmentValidationError( + "rate_limited", http_status=429, retry_after=max(decision.retry_after, 1.0) + ) + + def _determine_reuse_guid( + self, + install_code: Optional[EnrollmentCode], + fingerprint: DeviceFingerprint, + ) -> Optional[DeviceGuid]: + if install_code is None: + raise EnrollmentValidationError("invalid_enrollment_code") + if install_code.is_expired: + raise EnrollmentValidationError("invalid_enrollment_code") + if not install_code.is_exhausted: + return None + if not install_code.used_by_guid: + raise EnrollmentValidationError("invalid_enrollment_code") + + existing = self._devices.fetch_by_guid(install_code.used_by_guid) + if existing and existing.identity.fingerprint.value == fingerprint.value: + return install_code.used_by_guid + raise EnrollmentValidationError("invalid_enrollment_code") + + def _finalize_approval(self, approval: EnrollmentApproval) -> EnrollmentTokenBundle: + now = self._now() + effective_guid = approval.guid or DeviceGuid(str(uuid.uuid4())) + device_record = self._devices.ensure_device_record( + guid=effective_guid, + hostname=approval.claimed_hostname, + fingerprint=approval.claimed_fingerprint, + ) + self._devices.record_device_key( + guid=effective_guid, + fingerprint=approval.claimed_fingerprint, + added_at=now, + ) + + if approval.enrollment_code_id: + code = self._enrollment.fetch_install_code_by_id(approval.enrollment_code_id) + if code is not None: + mark_first = code.used_at is None + self._enrollment.update_install_code_usage( + approval.enrollment_code_id, + use_count_increment=1, + last_used_at=now, + used_by_guid=effective_guid, + mark_first_use=mark_first, + ) + + refresh_token = secrets.token_urlsafe(48) + refresh_id = str(uuid.uuid4()) + expires_at = now + timedelta(days=30) + token_hash = hashlib.sha256(refresh_token.encode("utf-8")).hexdigest() + self._tokens.create( + record_id=refresh_id, + guid=effective_guid, + token_hash=token_hash, + created_at=now, + expires_at=expires_at, + ) + + access_token = self._jwt.issue_access_token( + effective_guid.value, + device_record.identity.fingerprint.value, + max(device_record.token_version, 1), + ) + + self._enrollment.update_device_approval_status( + approval.record_id, + status=EnrollmentApprovalStatus.COMPLETED, + updated_at=now, + guid=effective_guid, + ) + + return EnrollmentTokenBundle( + guid=effective_guid, + access_token=access_token, + refresh_token=refresh_token, + expires_in=900, + ) + + def _verify_proof_signature( + self, + *, + approval: EnrollmentApproval, + client_nonce: bytes, + server_nonce: bytes, + signature: bytes, + ) -> None: + message = server_nonce + approval.reference.encode("utf-8") + client_nonce + try: + public_key = serialization.load_der_public_key(approval.agent_pubkey_der) + except Exception as exc: + raise EnrollmentValidationError("agent_pubkey_invalid") from exc + + try: + public_key.verify(signature, message) + except Exception as exc: + raise EnrollmentValidationError("invalid_proof") from exc + + @staticmethod + def _decode_base64(value: str, error_code: str) -> bytes: + try: + return base64.b64decode(value, validate=True) + except Exception as exc: + raise EnrollmentValidationError(error_code) from exc + + @staticmethod + def _mask_code(code: str) -> str: + trimmed = (code or "").strip() + if len(trimmed) <= 6: + return "***" + return f"{trimmed[:3]}***{trimmed[-3:]}" + + @staticmethod + def _now() -> datetime: + return datetime.now(tz=timezone.utc) diff --git a/Data/Engine/services/enrollment/errors.py b/Data/Engine/services/enrollment/errors.py new file mode 100644 index 0000000..a77df2d --- /dev/null +++ b/Data/Engine/services/enrollment/errors.py @@ -0,0 +1,26 @@ +"""Error types shared across enrollment components.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +__all__ = ["EnrollmentValidationError"] + + +@dataclass(frozen=True, slots=True) +class EnrollmentValidationError(Exception): + """Raised when enrollment input fails validation.""" + + code: str + http_status: int = 400 + retry_after: Optional[float] = None + + def to_response(self) -> dict[str, object]: + payload: dict[str, object] = {"error": self.code} + if self.retry_after is not None: + payload["retry_after"] = self.retry_after + return payload + + def __str__(self) -> str: # pragma: no cover - debug helper + return f"{self.code} (status={self.http_status})" diff --git a/Data/Engine/services/enrollment/nonce_cache.py b/Data/Engine/services/enrollment/nonce_cache.py new file mode 100644 index 0000000..6653a7d --- /dev/null +++ b/Data/Engine/services/enrollment/nonce_cache.py @@ -0,0 +1,32 @@ +"""Nonce replay protection for enrollment workflows.""" + +from __future__ import annotations + +import time +from threading import Lock +from typing import Dict + +__all__ = ["NonceCache"] + + +class NonceCache: + """Track recently observed nonces to prevent replay.""" + + def __init__(self, ttl_seconds: float = 300.0) -> None: + self._ttl = ttl_seconds + self._entries: Dict[str, float] = {} + self._lock = Lock() + + def consume(self, key: str) -> bool: + """Consume *key* if it has not been seen recently.""" + + now = time.monotonic() + with self._lock: + expiry = self._entries.get(key) + if expiry and expiry > now: + return False + self._entries[key] = now + self._ttl + stale = [nonce for nonce, ttl in self._entries.items() if ttl <= now] + for nonce in stale: + self._entries.pop(nonce, None) + return True diff --git a/Data/Engine/services/rate_limit.py b/Data/Engine/services/rate_limit.py new file mode 100644 index 0000000..49b8fd8 --- /dev/null +++ b/Data/Engine/services/rate_limit.py @@ -0,0 +1,45 @@ +"""In-process rate limiting utilities for the Borealis Engine.""" + +from __future__ import annotations + +import time +from collections import deque +from dataclasses import dataclass +from threading import Lock +from typing import Deque, Dict + +__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"] + + +@dataclass(frozen=True, slots=True) +class RateLimitDecision: + """Result of a rate limit check.""" + + allowed: bool + retry_after: float + + +class SlidingWindowRateLimiter: + """Tiny in-memory sliding window limiter suitable for single-process use.""" + + def __init__(self) -> None: + self._buckets: Dict[str, Deque[float]] = {} + self._lock = Lock() + + def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision: + now = time.monotonic() + with self._lock: + bucket = self._buckets.get(key) + if bucket is None: + bucket = deque() + self._buckets[key] = bucket + + while bucket and now - bucket[0] > window_seconds: + bucket.popleft() + + if len(bucket) >= limit: + retry_after = max(0.0, window_seconds - (now - bucket[0])) + return RateLimitDecision(False, retry_after) + + bucket.append(now) + return RateLimitDecision(True, 0.0)