Mirror of https://github.com/bunny-lab-io/Borealis.git (synced 2025-10-26 13:21:57 -06:00)
Merge pull request #134 from bunny-lab-io:codex/create-diagnostic-report-and-migration-plan
Add initial Engine scaffolding
Borealis.ps1 (158 lines changed)
@@ -977,6 +977,8 @@ if (-not $choice) {
    Write-Host "[Experimental]" -ForegroundColor Red
    Write-Host " 4) Package Self-Contained EXE of Server or Agent " -NoNewline -ForegroundColor DarkGray
    Write-Host "[Experimental]" -ForegroundColor Red
    Write-Host " 5) Borealis Engine " -NoNewline -ForegroundColor DarkGray
    Write-Host "[In Development]" -ForegroundColor Cyan
    # (Removed) AutoHotKey experimental testing
    Write-Host "Type a number and press " -NoNewline
    Write-Host "<ENTER>" -ForegroundColor DarkCyan
@@ -1214,6 +1216,162 @@ switch ($choice) {
        }
    }

    "5" {
        $host.UI.RawUI.WindowTitle = "Borealis Engine"
        Write-Host "Ensuring Engine Dependencies Exist..." -ForegroundColor DarkCyan

        Install_Shared_Dependencies
        Install_Server_Dependencies

        foreach ($tool in @($pythonExe, $nodeExe, $npmCmd, $npxCmd)) {
            if (-not (Test-Path $tool)) {
                Write-Host "`r$($symbols.Fail) Bundled executable not found at '$tool'." -ForegroundColor Red
                exit 1
            }
        }
        $env:PATH = '{0};{1};{2}' -f (Split-Path $pythonExe), (Split-Path $nodeExe), $env:PATH

        Write-Host " "
        Write-Host "Configure Borealis Engine Mode:" -ForegroundColor DarkYellow
        Write-Host " 1) Build & Launch > Production Flask Server @ http://localhost:5001" -ForegroundColor DarkCyan
        Write-Host " 2) [Skip Build] & Immediately Launch > Production Flask Server @ http://localhost:5001" -ForegroundColor DarkCyan
        Write-Host " 3) Launch > [Hotload-Ready] Vite Dev Server @ http://localhost:5173" -ForegroundColor DarkCyan
        $engineModeChoice = Read-Host "Enter choice [1/2/3]"

        $engineOperationMode = "production"
        $engineImmediateLaunch = $false
        switch ($engineModeChoice) {
            "1" { $engineOperationMode = "production" }
            "2" { $engineImmediateLaunch = $true }
            "3" { $engineOperationMode = "developer" }
            default {
                Write-Host "Invalid mode choice: $engineModeChoice" -ForegroundColor Red
                break
            }
        }

        if ($engineModeChoice -notin @('1','2','3')) {
            break
        }
        if ($engineImmediateLaunch) {
            Run-Step "Borealis Engine: Launch Flask Server" {
                Push-Location (Join-Path $scriptDir "Engine")
                $py = Join-Path $scriptDir "Engine\Scripts\python.exe"
                Write-Host "`nLaunching Borealis Engine..." -ForegroundColor Green
                Write-Host "===================================================================================="
                Write-Host "$($symbols.Running) Engine Socket Server Started..."
                & $py -m Data.Engine.bootstrapper
                Pop-Location
            }
            break
        }

        Write-Host "Deploying Borealis Engine in '$engineOperationMode' mode" -ForegroundColor Blue
        $venvFolder = "Engine"
        $dataSource = "Data"
        $engineSource = "$dataSource\Engine"
        $engineDataDestination = "$venvFolder\Data\Engine"
        $webUIFallbackSource = "$dataSource\Server\WebUI"
        $webUIDestination = "$venvFolder\web-interface"
        $venvPython = Join-Path $venvFolder 'Scripts\python.exe'
        $engineSourceAbsolute = Join-Path $scriptDir $engineSource
        $webUIFallbackAbsolute = Join-Path $scriptDir $webUIFallbackSource

        Run-Step "Create Borealis Engine Virtual Python Environment" {
            $venvActivate = Join-Path $venvFolder 'Scripts\Activate'
            if (-not (Test-Path $venvActivate)) {
                & $pythonExe -m venv $venvFolder | Out-Null
            }

            $engineDataRoot = Join-Path $venvFolder 'Data'
            if (-not (Test-Path $engineDataRoot)) {
                New-Item -Path $engineDataRoot -ItemType Directory -Force | Out-Null
            }

            if (Test-Path (Join-Path $scriptDir $engineDataDestination)) {
                Remove-Item (Join-Path $scriptDir $engineDataDestination) -Recurse -Force -ErrorAction SilentlyContinue
            }
            New-Item -Path (Join-Path $scriptDir $engineDataDestination) -ItemType Directory -Force | Out-Null

            if (-not (Test-Path $engineSourceAbsolute)) {
                throw "Engine source directory '$engineSourceAbsolute' not found."
            }
            Copy-Item (Join-Path $engineSourceAbsolute '*') (Join-Path $scriptDir $engineDataDestination) -Recurse -Force

            . (Join-Path $venvFolder 'Scripts\Activate')
        }
Run-Step "Install Engine Python Dependencies into Virtual Python Environment" {
|
||||
$engineRequirements = @(
|
||||
(Join-Path $engineSourceAbsolute 'engine-requirements.txt'),
|
||||
(Join-Path $engineSourceAbsolute 'requirements.txt')
|
||||
)
|
||||
$requirementsPath = $engineRequirements | Where-Object { Test-Path $_ } | Select-Object -First 1
|
||||
if ($requirementsPath) {
|
||||
& $venvPython -m pip install --disable-pip-version-check -q -r $requirementsPath | Out-Null
|
||||
}
|
||||
}
|
||||
|
||||
Run-Step "Copy Borealis Engine WebUI Files into: $webUIDestination" {
|
||||
$engineWebUISource = Join-Path $engineSourceAbsolute 'WebUI'
|
||||
if (Test-Path $engineWebUISource) {
|
||||
$webUIDestinationAbsolute = Join-Path $scriptDir $webUIDestination
|
||||
if (Test-Path $webUIDestinationAbsolute) {
|
||||
Remove-Item (Join-Path $webUIDestinationAbsolute '*') -Recurse -Force -ErrorAction SilentlyContinue
|
||||
} else {
|
||||
New-Item -Path $webUIDestinationAbsolute -ItemType Directory -Force | Out-Null
|
||||
}
|
||||
Copy-Item (Join-Path $engineWebUISource '*') $webUIDestinationAbsolute -Recurse -Force
|
||||
} elseif (-not (Test-Path (Join-Path $scriptDir $webUIDestination)) -or -not (Get-ChildItem -Path (Join-Path $scriptDir $webUIDestination) -ErrorAction SilentlyContinue | Select-Object -First 1)) {
|
||||
if (Test-Path $webUIFallbackAbsolute) {
|
||||
$webUIDestinationAbsolute = Join-Path $scriptDir $webUIDestination
|
||||
if (-not (Test-Path $webUIDestinationAbsolute)) {
|
||||
New-Item -Path $webUIDestinationAbsolute -ItemType Directory -Force | Out-Null
|
||||
}
|
||||
Copy-Item (Join-Path $webUIFallbackAbsolute '*') $webUIDestinationAbsolute -Recurse -Force
|
||||
} else {
|
||||
Write-Host "Fallback WebUI source not found at '$webUIFallbackAbsolute'." -ForegroundColor Yellow
|
||||
}
|
||||
} else {
|
||||
Write-Host "Existing Engine web interface detected; skipping fallback copy." -ForegroundColor DarkYellow
|
||||
}
|
||||
}
|
||||
|
||||
Run-Step "Vite Web Frontend: Install NPM Packages" {
|
||||
$webUIDestinationAbsolute = Join-Path $scriptDir $webUIDestination
|
||||
if (Test-Path $webUIDestinationAbsolute) {
|
||||
Push-Location $webUIDestinationAbsolute
|
||||
$env:npm_config_loglevel = "silent"
|
||||
& $npmCmd install --silent --no-fund --audit=false | Out-Null
|
||||
Pop-Location
|
||||
} else {
|
||||
Write-Host "Web interface destination '$webUIDestinationAbsolute' not found." -ForegroundColor Yellow
|
||||
}
|
||||
}
|
||||
|
||||
Run-Step "Vite Web Frontend: Start ($engineOperationMode)" {
|
||||
$webUIDestinationAbsolute = Join-Path $scriptDir $webUIDestination
|
||||
if (Test-Path $webUIDestinationAbsolute) {
|
||||
Push-Location $webUIDestinationAbsolute
|
||||
if ($engineOperationMode -eq "developer") { $viteSubCommand = "dev" } else { $viteSubCommand = "build" }
|
||||
Start-Process -NoNewWindow -FilePath $npmCmd -ArgumentList @("run", $viteSubCommand)
|
||||
Pop-Location
|
||||
}
|
||||
}
|
||||
|
||||
Run-Step "Borealis Engine: Launch Flask Server" {
|
||||
Push-Location (Join-Path $scriptDir "Engine")
|
||||
$py = Join-Path $scriptDir "Engine\Scripts\python.exe"
|
||||
Write-Host "`nLaunching Borealis Engine..." -ForegroundColor Green
|
||||
Write-Host "===================================================================================="
|
||||
Write-Host "$($symbols.Running) Engine Socket Server Started..."
|
||||
& $py -m Data.Engine.bootstrapper
|
||||
Pop-Location
|
||||
}
|
||||
}
|
||||
|
||||
# (Removed) case "6" experimental AHK test
|
||||
|
||||
default { Write-Host "Invalid selection. Exiting..." -ForegroundColor Red; exit 1 }
|
||||
|
||||
Data/Engine/CURRENT_STAGE.md (new file, 67 lines)
@@ -0,0 +1,67 @@
# Borealis Engine Migration Progress

[COMPLETED] 1. Stabilize Engine foundation (baseline already in repo)
- 1.1 Confirm `Data/Engine/bootstrapper.py` launches the placeholder Flask app without side effects.
- 1.2 Document environment variables/settings expected by the Engine to keep parity with legacy defaults.
- 1.3 Verify Engine logging produces `Logs/Server/engine.log` entries alongside the legacy server.

[COMPLETED] 2. Introduce configuration & dependency wiring
- 2.1 Create `config/environment.py` loaders mirroring legacy defaults (TLS paths, feature flags).
- 2.2 Add settings dataclasses for Flask, Socket.IO, and DB paths; inject them via `server.py`.
- 2.3 Commit once the Engine can start with equivalent config but no real routes.

[COMPLETED] 3. Copy Flask application scaffolding
- 3.1 Port proxy/CORS/static setup from `Data/Server/server.py` into Engine `server.py` using dependency injection.
- 3.2 Stub out blueprint/Socket.IO registration hooks that mirror names from legacy code (no logic yet).
- 3.3 Smoke-test app startup via `python Data/Engine/bootstrapper.py` (or Flask CLI) to ensure no regressions.

[COMPLETED] 4. Establish SQLite infrastructure
- 4.1 Copy `_db_conn` logic into `repositories/sqlite/connection.py`, parameterized by database path (`<root>/database.db`).
- 4.2 Port migration helpers into `repositories/sqlite/migrations.py`; expose an `apply_all()` callable.
- 4.3 Wire migrations to run during Engine bootstrap (behind a flag) and confirm tables initialize in a sandbox DB.
- 4.4 Commit once DB connection + migrations succeed independently of legacy server.

[COMPLETED] 5. Extract authentication/enrollment domain surface
- 5.1 Define immutable dataclasses in `domain/device_auth.py`, `domain/device_enrollment.py` for tokens, GUIDs, approvals.
- 5.2 Map legacy error codes/enums into domain exceptions or enums in the same modules.
- 5.3 Commit after unit tests (or doctests) validate dataclass invariants.

[COMPLETED] 6. Port authentication services
- 6.1 Copy `DeviceAuthManager` logic into `services/auth/device_auth_service.py`, refactoring to use new repositories and domain types.
- 6.2 Create `builders/device_auth.py` to assemble `DeviceAuthContext` from headers/DPoP proof.
- 6.3 Mirror refresh token issuance into `services/auth/token_service.py`; use `builders/device_enrollment.py` for payload assembly.
- 6.4 Commit once services pass targeted unit tests and integrate with placeholder repositories.

[COMPLETED] 7. Implement SQLite repositories
- 7.1 Introduce `repositories/sqlite/device_repository.py`, `token_repository.py`, `enrollment_repository.py` using copied SQL.
- 7.2 Write integration tests exercising CRUD against a temporary SQLite file.
- 7.3 Commit when repositories provide the required ports used by services.

[COMPLETED] 8. Recreate HTTP interfaces
- 8.1 Port health/enrollment/token blueprints into `interfaces/http/<feature>/routes.py`, calling Engine services only.
- 8.2 Ensure request validation occurs via builders; response schemas stay aligned with legacy JSON.
- 8.3 Register blueprints through Engine `server.py`; confirm endpoints respond via manual or automated tests.
- 8.4 Commit after each major blueprint migration for clear milestones.

[COMPLETED] 9. Rebuild WebSocket interfaces
- 9.1 Establish feature-scoped modules (e.g., `interfaces/ws/agents/events.py`) and copy event handlers.
- 9.2 Replace global state with repository/service calls where feasible; otherwise encapsulate in Engine-managed caches.
- 9.3 Validate namespace registration with Socket.IO test clients before committing.

[COMPLETED] 10. Scheduler & job management
- 10.1 Port scheduler core into `services/jobs/scheduler_service.py`; wrap job state persistence via new repositories.
- 10.2 Implement `builders/job_fabricator.py` for manifest assembly; ensure immutability and validation.
- 10.3 Expose HTTP orchestration via `interfaces/http/job_management.py` and WS notifications via dedicated modules.
- 10.4 Commit after scheduler can run a no-op job loop independently.

[COMPLETED] 11. GitHub integration
- 11.1 Copy GitHub helper logic into `integrations/github/artifact_provider.py` with proper configuration injection.
- 11.2 Provide repository/service hooks for fetching artifacts or repo heads; add resilience logging.
- 11.3 Commit after integration tests (or mocked unit tests) confirm API workflows.

[IN PROGRESS] 12. Final parity verification
- 12.1 Follow the staging playbook in `Data/Engine/STAGING_GUIDE.md` to stand up the Engine end-to-end and exercise enrollment, token refresh, agent connections, GitHub integration, and scheduler flows.
- 12.2 Record any divergences in the staging guide's table and address them with follow-up commits before cut-over.
- 12.3 Once parity is confirmed, coordinate entrypoint switching (point deployment at `Data/Engine/bootstrapper.py`) and plan the legacy server deprecation.
- Supporting documentation and unit tests live in `Data/Engine/README.md`, `Data/Engine/STAGING_GUIDE.md`, and `Data/Engine/tests/` to guide the remaining staging work.
- Engine deployment now installs dependencies via `Data/Engine/requirements.txt` so parity runs include Flask, Socket.IO, and security packages.
Data/Engine/README.md (new file, 204 lines)
@@ -0,0 +1,204 @@
# Borealis Engine Overview

The Engine is an additive server stack that will ultimately replace the legacy Flask app under `Data/Server`. It is safe to run the Engine entrypoint (`Data/Engine/bootstrapper.py`) side-by-side with the legacy server while we migrate functionality feature-by-feature.

## Architectural roles

The Engine is organized around explicit dependency layers so each concern stays
testable and replaceable:

- **Configuration (`Data/Engine/config/`)** parses environment variables into
  immutable settings objects that the bootstrapper hands to factories and
  integrations.
- **Builders (`Data/Engine/builders/`)** transform external inputs (HTTP
  headers, JSON payloads, scheduled job definitions) into validated immutable
  records that services can trust.
- **Domain models (`Data/Engine/domain/`)** house pure value objects, enums, and
  error types with no I/O so services can express intent without depending on
  Flask or SQLite.
- **Repositories (`Data/Engine/repositories/`)** encapsulate all SQLite access
  and expose protocol methods that return domain models. They are injected into
  services through the container so persistence can be swapped or mocked.
- **Services (`Data/Engine/services/`)** host business logic such as device
  authentication, enrollment, job scheduling, GitHub artifact lookups, and
  real-time agent coordination. Services depend only on repositories,
  integrations, and builders.
- **Integrations (`Data/Engine/integrations/`)** wrap external systems (GitHub
  today) and keep HTTP/token handling outside the services that consume them.
- **Interfaces (`Data/Engine/interfaces/`)** provide thin HTTP/Socket.IO
  adapters that translate requests to builder/service calls and serialize
  responses. They contain no business rules of their own.

The runtime factory (`Data/Engine/runtime.py`) wires these layers together and
attaches the resulting container to the Flask app created in
`Data/Engine/server.py`.
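
In practice the layering means plain constructor injection: an interface asks the container for a service, and the service already holds the repositories it needs. A minimal sketch of the pattern, using illustrative names (`DeviceRepository` and `DeviceAuthService.hostname_for` are not the Engine's actual signatures):

```python
from dataclasses import dataclass
from typing import Optional, Protocol


@dataclass(frozen=True)
class DeviceRecord:
    guid: str
    hostname: str


class DeviceRepository(Protocol):
    """Port the service depends on; the SQLite adapter satisfies it."""

    def find_by_guid(self, guid: str) -> Optional[DeviceRecord]: ...


class DeviceAuthService:
    """Business logic only; persistence arrives through the constructor."""

    def __init__(self, devices: DeviceRepository) -> None:
        self._devices = devices

    def hostname_for(self, guid: str) -> Optional[str]:
        record = self._devices.find_by_guid(guid)
        return record.hostname if record else None
```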
## Environment configuration

The Engine mirrors the legacy defaults so it can boot without additional configuration. These environment variables are read by `Data/Engine/config/environment.py`:

| Variable | Purpose | Default |
| --- | --- | --- |
| `BOREALIS_ROOT` | Overrides automatic project root detection. Useful when running from a packaged location. | Directory two levels above `Data/Engine/` |
| `BOREALIS_DATABASE_PATH` | Path to the SQLite database. | `<project_root>/database.db` |
| `BOREALIS_ENGINE_AUTO_MIGRATE` | Run Engine-managed schema migrations during bootstrap (`true`/`false`). | `true` |
| `BOREALIS_STATIC_ROOT` | Directory that serves static assets for the SPA. | First existing path among `Engine/web-interface/build`, `Engine/web-interface/dist`, `Data/Engine/WebUI/build`, `Data/Server/web-interface/build`, `Data/Server/WebUI/build`, `Data/WebUI/build` |
| `BOREALIS_CORS_ALLOWED_ORIGINS` | Comma-delimited list of origins granted CORS access. Use `*` for all origins. | `*` |
| `BOREALIS_FLASK_SECRET_KEY` | Secret key for Flask session signing. | `change-me` |
| `BOREALIS_DEBUG` | Enables debug logging, disables secure-cookie requirements, and allows Werkzeug debug mode. | `false` |
| `BOREALIS_HOST` | Bind address for the HTTP/Socket.IO server. | `127.0.0.1` |
| `BOREALIS_PORT` | Bind port for the HTTP/Socket.IO server. | `5000` |
| `BOREALIS_REPO` | Default GitHub repository (`owner/name`) for artifact lookups. | `bunny-lab-io/Borealis` |
| `BOREALIS_REPO_BRANCH` | Default branch tracked by the Engine GitHub integration. | `main` |
| `BOREALIS_REPO_HASH_REFRESH` | Seconds between default repository head refresh attempts (clamped 30-3600). | `60` |
| `BOREALIS_CACHE_DIR` | Directory used to persist Engine cache files (GitHub repo head cache). | `<project_root>/Data/Engine/cache` |
| `BOREALIS_CERTIFICATES_ROOT` | Overrides where TLS certificates (root CA + leaf) are stored. | `<project_root>/Certificates` |
| `BOREALIS_SERVER_CERT_ROOT` | Directly points to the Engine server certificate directory if certificates are staged elsewhere. | `<project_root>/Certificates/Server` |
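
A minimal sketch of how a loader can fold these variables into immutable settings objects; the field names and the explicit `root` parameter here are illustrative, the real dataclasses live under `Data/Engine/config/`:

```python
import os
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class ServerSettings:
    host: str
    port: int


@dataclass(frozen=True)
class EngineSettings:
    database_path: Path
    apply_migrations: bool
    debug: bool
    server: ServerSettings


def load_environment(root: Path) -> EngineSettings:
    """Read BOREALIS_* variables, falling back to the documented defaults."""
    return EngineSettings(
        database_path=Path(os.environ.get("BOREALIS_DATABASE_PATH", root / "database.db")),
        apply_migrations=os.environ.get("BOREALIS_ENGINE_AUTO_MIGRATE", "true").lower() == "true",
        debug=os.environ.get("BOREALIS_DEBUG", "false").lower() == "true",
        server=ServerSettings(
            host=os.environ.get("BOREALIS_HOST", "127.0.0.1"),
            port=int(os.environ.get("BOREALIS_PORT", "5000")),
        ),
    )
```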
## TLS and transport stack

`Data/Engine/services/crypto/certificates.py` mirrors the legacy certificate
generator so the Engine always serves HTTPS with a self-managed root CA and
leaf certificate. During bootstrap the Engine:

1. Runs the certificate helper to ensure the root CA, server key, and bundle
   exist under `Certificates/Server/` (or the configured override path).
2. Exposes the resulting bundle via `BOREALIS_TLS_BUNDLE` so enrollment flows
   can deliver the pinned certificate to agents.
3. Launches Socket.IO/Eventlet with the generated cert/key pair. A fallback to
   Werkzeug's TLS support keeps HTTPS available even if Socket.IO is disabled.

`Data/Engine/interfaces/eventlet_compat.py` applies the same Eventlet monkey
patch as the legacy server so TLS handshakes presented to the HTTP listener are
handled quietly instead of surfacing `400 Bad Request` noise when non-TLS
clients connect.

## Logging expectations

`Data/Engine/config/logging.py` configures a timed rotating file handler that writes to `Logs/Server/engine.log`. Each entry follows the `<timestamp>-engine-<message>` format required by the project logging policy. The handler is attached to both the Engine logger (`borealis.engine`) and the root logger so that third-party frameworks share the same log destination.
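
A minimal sketch of that setup, assuming a nightly `TimedRotatingFileHandler` (the rotation interval and backup count are assumptions; the destination path and the `<timestamp>-engine-<message>` entry format come from the paragraph above):

```python
import logging
from logging.handlers import TimedRotatingFileHandler
from pathlib import Path


def configure_logging(log_dir: Path = Path("Logs/Server")) -> logging.Logger:
    log_dir.mkdir(parents=True, exist_ok=True)

    handler = TimedRotatingFileHandler(log_dir / "engine.log", when="midnight", backupCount=7)
    # Produces "<timestamp>-engine-<message>" entries as described above.
    handler.setFormatter(logging.Formatter("%(asctime)s-engine-%(message)s"))

    logger = logging.getLogger("borealis.engine")
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    logging.getLogger().addHandler(handler)  # root logger shares the same destination
    return logger


configure_logging().info("bootstrap-started")
```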
## Bootstrapping flow

1. `Data/Engine/bootstrapper.py` loads the environment, configures logging, prepares the SQLite connection factory, optionally applies schema migrations, and builds the Flask application via `Data/Engine/server.py`.
2. A service container is assembled (`Data/Engine/services/container.py`) that wires repositories, JWT/DPoP helpers, and Engine services (device auth, token refresh, enrollment). The container is stored on the Flask app for interface modules to consume.
3. HTTP and Socket.IO interfaces register against the new service container. The resulting runtime object exposes the Flask app, resolved settings, optional Socket.IO server, and the configured database connection factory. `bootstrapper.main()` runs the appropriate server based on whether Socket.IO is present.

As migration continues, services, repositories, interfaces, and integrations will live under their respective subpackages while maintaining isolation from the legacy server.
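
The same flow can be driven programmatically. `bootstrap()` returns the `EngineRuntime` dataclass defined in `Data/Engine/bootstrapper.py` (shown later in this diff), so a short script can reuse it:

```python
from Data.Engine.bootstrapper import bootstrap

runtime = bootstrap()  # settings, logging, migrations, TLS material, service container
print(runtime.settings.server.host, runtime.settings.server.port)
print(runtime.tls_bundle)  # pinned certificate bundle exposed to enrollment flows

# Mirrors bootstrapper.main(): prefer the Socket.IO server when it was created.
if runtime.socketio is not None:
    runtime.socketio.run(
        runtime.app,
        host=runtime.settings.server.host,
        port=runtime.settings.server.port,
        certfile=str(runtime.tls_bundle),
        keyfile=str(runtime.tls_key),
    )
else:
    runtime.app.run(
        host=runtime.settings.server.host,
        port=runtime.settings.server.port,
        ssl_context=(str(runtime.tls_bundle), str(runtime.tls_key)),
    )
```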
## Python dependencies

`Data/Engine/requirements.txt` mirrors the minimal runtime stack (Flask, Flask-SocketIO, CORS, requests, PyJWT, and cryptography) needed by the Engine entrypoint. The PowerShell launcher consumes this file when preparing the `Engine/` virtual environment so parity tests always run against an environment with the expected web and security packages preinstalled.

## HTTP interfaces

The Engine now exposes working HTTP routes alongside the remaining scaffolding:

- `Data/Engine/interfaces/http/health.py` implements `GET /health` for liveness probes.
- `Data/Engine/interfaces/http/tokens.py` ports the refresh-token endpoint (`POST /api/agent/token/refresh`) using the Engine `TokenService` and request builders.
- `Data/Engine/interfaces/http/enrollment.py` handles the enrollment handshake (`/api/agent/enroll/request` and `/api/agent/enroll/poll`) with rate limiting, nonce protection, and repository-backed approvals.
- The admin and agent blueprints remain placeholders until their services migrate.
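
A quick liveness check against a locally running Engine, using the default bind address and port from the table above; the bundle path below is a hypothetical example, so point it at whatever the certificate helper actually generated:

```python
import requests

# The Engine serves HTTPS with its own root CA, so verify against the generated bundle.
TLS_BUNDLE = "Certificates/Server/bundle.pem"  # hypothetical path; adjust to your deployment

response = requests.get("https://127.0.0.1:5000/health", verify=TLS_BUNDLE, timeout=5)
print(response.status_code, response.json())
```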
## WebSocket interfaces

Step 9 introduces real-time handlers backed by the new service container:

- `Data/Engine/services/realtime/agent_registry.py` manages connected-agent state, last-seen persistence, collector updates, and screenshot caches without sharing globals with the legacy server.
- `Data/Engine/interfaces/ws/agents/events.py` ports the agent namespace, handling connect/disconnect logging, heartbeat reconciliation, screenshot relays, macro status broadcasts, and provisioning lookups through the realtime service.
- `Data/Engine/interfaces/ws/job_management/events.py` now forwards scheduler updates and responds to job status requests, keeping WebSocket clients informed as new runs are simulated.

The WebSocket factory (`Data/Engine/interfaces/ws/__init__.py`) now accepts the Engine service container so namespaces can resolve dependencies just like their HTTP counterparts.
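
For ad-hoc staging checks, a `python-socketio` client can attach to a namespace and print what it receives. The `/agents` namespace path and the `job_status` event name below are assumptions for illustration, not a confirmed Engine contract:

```python
import socketio

sio = socketio.Client(ssl_verify=False)  # staging only; pin the CA bundle in real use


@sio.event(namespace="/agents")
def connect():
    print("connected to the agents namespace")


@sio.on("job_status", namespace="/agents")  # hypothetical event name
def on_job_status(payload):
    print("job update:", payload)


sio.connect("https://127.0.0.1:5000", namespaces=["/agents"])
sio.wait()
```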
## Authentication services

Step 6 introduces the first real Engine services:

- `Data/Engine/builders/device_auth.py` normalizes headers for access-token authentication and token refresh payloads.
- `Data/Engine/builders/device_enrollment.py` prepares enrollment payloads and nonce proof challenges for future migration steps.
- `Data/Engine/services/auth/device_auth_service.py` ports the legacy `DeviceAuthManager` into a repository-driven service that emits `DeviceAuthContext` instances from the new domain layer.
- `Data/Engine/services/auth/token_service.py` issues refreshed access tokens while enforcing DPoP bindings and repository lookups.

Interfaces now consume these services via the shared container, keeping business logic inside the Engine service layer while HTTP modules remain thin request/response translators.
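
Because the builders are plain fluent objects, an HTTP adapter can translate a request in a few lines. This usage sketch relies on the `DeviceAuthRequestBuilder` from `Data/Engine/builders/device_auth.py` (included later in this diff); the header values are placeholders:

```python
from Data.Engine.builders.device_auth import DeviceAuthRequestBuilder
from Data.Engine.domain.device_auth import DeviceAuthFailure

try:
    auth_request = (
        DeviceAuthRequestBuilder()
        .with_authorization("Bearer <access-token>")  # placeholder token
        .with_http_method("POST")
        .with_htu("https://borealis.example/api/agent/heartbeat")  # hypothetical URL
        .with_service_context(None)
        .with_dpop_proof(None)
        .build()
    )
except DeviceAuthFailure as failure:
    # Malformed headers surface as domain failures with stable error codes.
    print("rejected:", failure)
else:
    print(auth_request.http_method, auth_request.htu)
```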
## SQLite repositories

Step 7 ports the first persistence adapters into the Engine:

- `Data/Engine/repositories/sqlite/device_repository.py` exposes `SQLiteDeviceRepository`, mirroring the legacy device lookups and automatic record recovery used during authentication.
- `Data/Engine/repositories/sqlite/token_repository.py` provides `SQLiteRefreshTokenRepository` for refresh-token validation, DPoP binding management, and usage timestamps.
- `Data/Engine/repositories/sqlite/enrollment_repository.py` surfaces enrollment install-code counters and device approval records so future services can operate without touching raw SQL.

Each repository accepts the shared `SQLiteConnectionFactory`, keeping all SQL execution confined to the Engine layer while services depend only on protocol interfaces.
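
The underlying pattern is a callable that opens a fresh connection per use. A minimal sketch of the idea; the factory helper, the repository method, and the table/column names are illustrative rather than the Engine's actual signatures:

```python
import sqlite3
from contextlib import closing
from pathlib import Path
from typing import Callable, Optional

SQLiteConnectionFactory = Callable[[], sqlite3.Connection]


def connection_factory(db_path: Path) -> SQLiteConnectionFactory:
    def _open() -> sqlite3.Connection:
        conn = sqlite3.connect(db_path)
        conn.row_factory = sqlite3.Row
        return conn

    return _open


class SQLiteDeviceRepository:
    def __init__(self, factory: SQLiteConnectionFactory) -> None:
        self._factory = factory

    def hostname_for_guid(self, guid: str) -> Optional[str]:
        # Table and column names here are assumptions for illustration only.
        with closing(self._factory()) as conn:
            row = conn.execute(
                "SELECT hostname FROM devices WHERE guid = ?", (guid,)
            ).fetchone()
        return row["hostname"] if row else None
```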
## Job scheduling services

Step 10 migrates the foundational job scheduler into the Engine:

- `Data/Engine/builders/job_fabricator.py` transforms stored job definitions into immutable manifests, decoding scripts, resolving environment variables, and preparing execution metadata.
- `Data/Engine/repositories/sqlite/job_repository.py` encapsulates scheduled job persistence, run history, and status tracking in SQLite.
- `Data/Engine/services/jobs/scheduler_service.py` runs the background evaluation loop, emits Socket.IO lifecycle events, and exposes CRUD helpers for the HTTP and WebSocket interfaces.
- `Data/Engine/interfaces/http/job_management.py` mirrors the legacy REST surface for creating, updating, toggling, and inspecting scheduled jobs and their run history.

The scheduler service starts automatically from `Data/Engine/bootstrapper.py` once the Engine runtime builds the service container, ensuring a no-op scheduling loop executes independently of the legacy server.

## GitHub integration

Step 11 migrates the GitHub artifact provider into the Engine:

- `Data/Engine/integrations/github/artifact_provider.py` caches branch head lookups, verifies API tokens, and optionally refreshes the default repository in the background.
- `Data/Engine/repositories/sqlite/github_repository.py` persists the GitHub API token so HTTP handlers do not speak to SQLite directly.
- `Data/Engine/services/github/github_service.py` coordinates token caching, verification, and repo head lookups for both HTTP and background refresh flows.
- `Data/Engine/interfaces/http/github.py` exposes `/api/repo/current_hash` and `/api/github/token` through the Engine stack while keeping business logic in the service layer.

The service container now wires `github_service`, giving other interfaces and background jobs a clean entry point for GitHub functionality.
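
During staging the GitHub routes can be poked directly. The sketch below assumes the Engine is running locally over its generated TLS bundle and that `/api/repo/current_hash` answers with JSON; the exact response shape is not documented here:

```python
import requests

TLS_BUNDLE = "Certificates/Server/bundle.pem"  # hypothetical path to the generated bundle
BASE = "https://127.0.0.1:5000"

head = requests.get(f"{BASE}/api/repo/current_hash", verify=TLS_BUNDLE, timeout=10)
print(head.status_code, head.json())
```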
## Final parity checklist

Step 12 tracks the final integration work required before switching over to the
Engine entrypoint. Use the detailed playbook in
[`Data/Engine/STAGING_GUIDE.md`](./STAGING_GUIDE.md) to coordinate each
staging run:

1. Stand up the Engine in a staging environment and exercise enrollment, token
   refresh, scheduler operations, and the agent real-time channel side-by-side
   with the legacy server.
2. Capture any behavioural differences uncovered during staging using the
   divergence table in the staging guide and file them for follow-up fixes
   before the cut-over.
3. When satisfied with parity, coordinate the entrypoint swap (point production
   tooling at `Data/Engine/bootstrapper.py`) and plan the deprecation of
   `Data/Server`.

## Performing unit tests

Targeted unit tests cover the most important domain, builder, repository, and
migration behaviours without requiring Flask or external services. Run them
with the standard library test runner:

```bash
python -m unittest discover Data/Engine/tests
```

The suite currently validates:

- Domain normalization helpers for GUIDs, fingerprints, and authentication
  failures.
- Device authentication and refresh-token builders, including error handling for
  malformed requests.
- SQLite schema migrations to ensure the Engine can provision required tables in
  a fresh database.
- TLS certificate provisioning helpers to guarantee HTTPS material exists before
  the Engine starts serving requests.

Successful execution prints a summary similar to:

```
.............
----------------------------------------------------------------------
Ran 13 tests in <N>.<M>s

OK
```

Additional tests should follow the same pattern and live under
`Data/Engine/tests/` so this command remains the single entry point for Engine
unit verification.
Data/Engine/STAGING_GUIDE.md (new file, 116 lines)
@@ -0,0 +1,116 @@
# Engine Staging & Parity Guide

This guide supports Step 12 of the migration plan by walking operators through
standing up the Engine alongside the legacy server, validating core workflows,
and documenting any behavioural gaps before switching the production entrypoint
to `Data/Engine/bootstrapper.py`.

## 1. Prerequisites

- Python 3.11 or later available on the host.
- A clone of the Borealis repository with the Engine tree checked out.
- Access to the legacy runtime assets (certificates, TLS bundle, etc.).
- Optional: a staging agent install for end-to-end WebSocket validation.

Ensure the SQLite database lives at `<project_root>/database.db` and that the
Engine migrations have already run (they execute automatically when the
`BOREALIS_ENGINE_AUTO_MIGRATE` environment variable is left at its default
`true`).

## 2. Launching the Engine in staging mode

1. Open a terminal at the project root.
2. Set any environment overrides required for the test scenario (for example,
   `BOREALIS_DEBUG=true` to surface verbose logging, or
   `BOREALIS_CORS_ALLOWED_ORIGINS=https://localhost:3000` when pairing with the
   React UI).
3. Run the Engine entrypoint:

   ```bash
   python Data/Engine/bootstrapper.py
   ```

4. Verify `Logs/Server/engine.log` is created and that the startup entries are
   timestamped `<timestamp>-engine-<message>`.

Keep the legacy server running in a separate process if comparative testing is
required; they do not share global state.

## 3. Feature validation checklist

Work through the following areas and tick each box once verified. Capture any
issues in the log table in §4.

### Authentication and tokens

- [ ] `POST /api/agent/token/refresh` returns a new access token when supplied a
      valid refresh token + DPoP proof.
- [ ] Invalid DPoP proofs or revoked refresh tokens yield the expected HTTP 401
      responses and structured error payloads.
- [ ] Device last-seen metadata updates inside the database after a successful
      refresh.
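
To drive the first item by hand, the refresh endpoint takes a JSON body whose `guid` and `refresh_token` fields match the Engine's `RefreshTokenRequestBuilder`; carrying the proof in the standard `DPoP` header and the bundle path are assumptions in this sketch:

```python
import requests

TLS_BUNDLE = "Certificates/Server/bundle.pem"  # hypothetical path to the generated bundle

response = requests.post(
    "https://127.0.0.1:5000/api/agent/token/refresh",
    json={"guid": "<device-guid>", "refresh_token": "<refresh-token>"},
    headers={"DPoP": "<dpop-proof-jwt>"},  # proof bound to this POST and URL
    verify=TLS_BUNDLE,
    timeout=10,
)
print(response.status_code, response.json())
```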
### Enrollment

- [ ] `POST /api/agent/enroll/request` produces an enrollment ticket with the
      correct expiration and retry counters.
- [ ] `POST /api/agent/enroll/poll` transitions an approved device into an
      authenticated state and returns the TLS bundle.
- [ ] Audit logging for approvals lands in `Logs/Server/engine.log`.

### Job management

- [ ] `POST /api/jobs` (or UI equivalent) creates a scheduled job and returns a
      manifest identifier.
- [ ] `GET /api/jobs/<id>` surfaces the stored manifest with normalized
      schedules and environment variables.
- [ ] Job lifecycle events arrive over the `job_management` Socket.IO namespace
      when a job transitions between `pending`, `running`, and `completed`.

### Real-time agents

- [ ] Agents connecting to the `agents` namespace appear in the realtime roster
      with accurate hostname, username, and fingerprint details.
- [ ] Screenshot broadcasts relay from agents to the UI without residual cache
      bleed-through after disconnects.
- [ ] Macro execution responses round-trip through Socket.IO and reach the
      initiating client.

### GitHub integration

- [ ] `GET /api/repo/current_hash` reflects the latest branch head and caches
      repeated calls.
- [ ] `POST /api/github/token` persists a new token and survives Engine restarts
      (confirm via database inspection).
- [ ] The background refresher logs rate-limit warnings instead of raising
      uncaught exceptions when the GitHub API throttles requests.

## 4. Recording divergences

Use the table below to document behavioural differences or bugs uncovered during
staging. This artifact should accompany the staging run summary so follow-up
fixes can be triaged quickly.

| Area | Legacy Behaviour | Engine Behaviour | Notes / Links |
| --- | --- | --- | --- |
| Authentication | | | |
| Enrollment | | | |
| Scheduler | | | |
| Realtime | | | |
| GitHub | | | |
| Other | | | |

## 5. Cut-over readiness

Once every checklist item passes and no critical divergences remain:

1. Update `Data/Engine/CURRENT_STAGE.md` with the completion date for Step 12.
2. Coordinate with the operator to switch deployment scripts to
   `Data/Engine/bootstrapper.py`.
3. Plan a rollback strategy (typically re-launching the legacy server) should
   issues appear immediately after the cut-over.
4. Archive the filled divergence table alongside Engine logs for historical
   traceability.

Document the results in project tracking tools before moving on to deprecating
`Data/Server`.
Data/Engine/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
"""Borealis Engine package.
|
||||
|
||||
This namespace contains the next-generation server implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = [
|
||||
"bootstrapper",
|
||||
"server",
|
||||
]
|
||||
Data/Engine/bootstrapper.py (new file, 113 lines)
@@ -0,0 +1,113 @@
"""Entrypoint for the Borealis Engine server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from .config import EngineSettings, configure_logging, load_environment
|
||||
from .interfaces import (
|
||||
create_socket_server,
|
||||
register_http_interfaces,
|
||||
register_ws_interfaces,
|
||||
)
|
||||
from .interfaces.eventlet_compat import apply_eventlet_patches
|
||||
from .repositories.sqlite import connection as sqlite_connection
|
||||
from .repositories.sqlite import migrations as sqlite_migrations
|
||||
from .server import create_app
|
||||
from .services.container import build_service_container
|
||||
from .services.crypto.certificates import ensure_certificate
|
||||
|
||||
|
||||
apply_eventlet_patches()
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EngineRuntime:
|
||||
"""Aggregated runtime context produced by :func:`bootstrap`."""
|
||||
|
||||
app: Flask
|
||||
settings: EngineSettings
|
||||
socketio: Optional[object]
|
||||
db_factory: sqlite_connection.SQLiteConnectionFactory
|
||||
tls_certificate: Path
|
||||
tls_key: Path
|
||||
tls_bundle: Path
|
||||
|
||||
|
||||
def bootstrap() -> EngineRuntime:
|
||||
"""Construct the Flask application and supporting infrastructure."""
|
||||
|
||||
settings = load_environment()
|
||||
logger = configure_logging(settings)
|
||||
logger.info("bootstrap-started")
|
||||
|
||||
cert_path, key_path, bundle_path = ensure_certificate()
|
||||
os.environ.setdefault("BOREALIS_TLS_BUNDLE", str(bundle_path))
|
||||
logger.info(
|
||||
"tls-material-ready",
|
||||
extra={
|
||||
"cert_path": str(cert_path),
|
||||
"key_path": str(key_path),
|
||||
"bundle_path": str(bundle_path),
|
||||
},
|
||||
)
|
||||
|
||||
db_factory = sqlite_connection.connection_factory(settings.database_path)
|
||||
if settings.apply_migrations:
|
||||
logger.info("migrations-start")
|
||||
with sqlite_connection.connection_scope(settings.database_path) as conn:
|
||||
sqlite_migrations.apply_all(conn)
|
||||
logger.info("migrations-complete")
|
||||
else:
|
||||
logger.info("migrations-skipped")
|
||||
|
||||
app = create_app(settings, db_factory=db_factory)
|
||||
services = build_service_container(settings, db_factory=db_factory, logger=logger.getChild("services"))
|
||||
app.extensions["engine_services"] = services
|
||||
register_http_interfaces(app, services)
|
||||
socketio = create_socket_server(app, settings.socketio)
|
||||
register_ws_interfaces(socketio, services)
|
||||
services.scheduler_service.start(socketio)
|
||||
logger.info("bootstrap-complete")
|
||||
return EngineRuntime(
|
||||
app=app,
|
||||
settings=settings,
|
||||
socketio=socketio,
|
||||
db_factory=db_factory,
|
||||
tls_certificate=cert_path,
|
||||
tls_key=key_path,
|
||||
tls_bundle=bundle_path,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
runtime = bootstrap()
|
||||
socketio = runtime.socketio
|
||||
certfile = str(runtime.tls_bundle)
|
||||
keyfile = str(runtime.tls_key)
|
||||
|
||||
if socketio is not None:
|
||||
socketio.run( # type: ignore[call-arg]
|
||||
runtime.app,
|
||||
host=runtime.settings.server.host,
|
||||
port=runtime.settings.server.port,
|
||||
debug=runtime.settings.debug,
|
||||
certfile=certfile,
|
||||
keyfile=keyfile,
|
||||
)
|
||||
else:
|
||||
runtime.app.run(
|
||||
host=runtime.settings.server.host,
|
||||
port=runtime.settings.server.port,
|
||||
debug=runtime.settings.debug,
|
||||
ssl_context=(certfile, keyfile),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover - manual execution
|
||||
main()
|
||||
Data/Engine/builders/__init__.py (new file, 35 lines)
@@ -0,0 +1,35 @@
"""Builder utilities for constructing immutable Engine aggregates."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .device_auth import (
|
||||
DeviceAuthRequest,
|
||||
DeviceAuthRequestBuilder,
|
||||
RefreshTokenRequest,
|
||||
RefreshTokenRequestBuilder,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthRequest",
|
||||
"DeviceAuthRequestBuilder",
|
||||
"RefreshTokenRequest",
|
||||
"RefreshTokenRequestBuilder",
|
||||
]
|
||||
|
||||
try: # pragma: no cover - optional dependency shim
|
||||
from .device_enrollment import (
|
||||
EnrollmentRequestBuilder,
|
||||
ProofChallengeBuilder,
|
||||
)
|
||||
except ModuleNotFoundError as exc: # pragma: no cover - executed when crypto deps missing
|
||||
_missing_reason = str(exc)
|
||||
|
||||
def _missing_builder(*_args: object, **_kwargs: object) -> None:
|
||||
raise ModuleNotFoundError(
|
||||
"device enrollment builders require optional cryptography dependencies"
|
||||
) from exc
|
||||
|
||||
EnrollmentRequestBuilder = _missing_builder # type: ignore[assignment]
|
||||
ProofChallengeBuilder = _missing_builder # type: ignore[assignment]
|
||||
else:
|
||||
__all__ += ["EnrollmentRequestBuilder", "ProofChallengeBuilder"]
|
||||
Data/Engine/builders/device_auth.py (new file, 165 lines)
@@ -0,0 +1,165 @@
"""Builders for device authentication and token refresh inputs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from Data.Engine.domain.device_auth import (
|
||||
DeviceAuthErrorCode,
|
||||
DeviceAuthFailure,
|
||||
DeviceGuid,
|
||||
sanitize_service_context,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthRequest",
|
||||
"DeviceAuthRequestBuilder",
|
||||
"RefreshTokenRequest",
|
||||
"RefreshTokenRequestBuilder",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceAuthRequest:
|
||||
"""Normalized authentication inputs derived from an HTTP request."""
|
||||
|
||||
access_token: str
|
||||
http_method: str
|
||||
htu: str
|
||||
service_context: Optional[str]
|
||||
dpop_proof: Optional[str]
|
||||
|
||||
|
||||
class DeviceAuthRequestBuilder:
|
||||
"""Validate and normalize HTTP headers for device authentication."""
|
||||
|
||||
_authorization: Optional[str]
|
||||
_http_method: Optional[str]
|
||||
_htu: Optional[str]
|
||||
_service_context: Optional[str]
|
||||
_dpop_proof: Optional[str]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._authorization = None
|
||||
self._http_method = None
|
||||
self._htu = None
|
||||
self._service_context = None
|
||||
self._dpop_proof = None
|
||||
|
||||
def with_authorization(self, header_value: Optional[str]) -> "DeviceAuthRequestBuilder":
|
||||
if header_value is None:
|
||||
self._authorization = None
|
||||
else:
|
||||
self._authorization = header_value.strip()
|
||||
return self
|
||||
|
||||
def with_http_method(self, method: Optional[str]) -> "DeviceAuthRequestBuilder":
|
||||
self._http_method = (method or "").strip().upper()
|
||||
return self
|
||||
|
||||
def with_htu(self, url: Optional[str]) -> "DeviceAuthRequestBuilder":
|
||||
self._htu = (url or "").strip()
|
||||
return self
|
||||
|
||||
def with_service_context(self, header_value: Optional[str]) -> "DeviceAuthRequestBuilder":
|
||||
self._service_context = sanitize_service_context(header_value)
|
||||
return self
|
||||
|
||||
def with_dpop_proof(self, proof: Optional[str]) -> "DeviceAuthRequestBuilder":
|
||||
self._dpop_proof = (proof or "").strip() or None
|
||||
return self
|
||||
|
||||
def build(self) -> DeviceAuthRequest:
|
||||
token = self._parse_authorization(self._authorization)
|
||||
method = (self._http_method or "").strip().upper()
|
||||
if not method:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN, detail="missing HTTP method")
|
||||
url = (self._htu or "").strip()
|
||||
if not url:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN, detail="missing request URL")
|
||||
return DeviceAuthRequest(
|
||||
access_token=token,
|
||||
http_method=method,
|
||||
htu=url,
|
||||
service_context=self._service_context,
|
||||
dpop_proof=self._dpop_proof,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_authorization(header_value: Optional[str]) -> str:
|
||||
header = (header_value or "").strip()
|
||||
if not header:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.MISSING_AUTHORIZATION)
|
||||
prefix = "Bearer "
|
||||
if not header.startswith(prefix):
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.MISSING_AUTHORIZATION)
|
||||
token = header[len(prefix) :].strip()
|
||||
if not token:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.MISSING_AUTHORIZATION)
|
||||
return token
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RefreshTokenRequest:
|
||||
"""Validated refresh token payload supplied by an agent."""
|
||||
|
||||
guid: DeviceGuid
|
||||
refresh_token: str
|
||||
http_method: str
|
||||
htu: str
|
||||
dpop_proof: Optional[str]
|
||||
|
||||
|
||||
class RefreshTokenRequestBuilder:
|
||||
"""Helper to normalize refresh token JSON payloads."""
|
||||
|
||||
_guid: Optional[str]
|
||||
_refresh_token: Optional[str]
|
||||
_http_method: Optional[str]
|
||||
_htu: Optional[str]
|
||||
_dpop_proof: Optional[str]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._guid = None
|
||||
self._refresh_token = None
|
||||
self._http_method = None
|
||||
self._htu = None
|
||||
self._dpop_proof = None
|
||||
|
||||
def with_payload(self, payload: Optional[dict[str, object]]) -> "RefreshTokenRequestBuilder":
|
||||
payload = payload or {}
|
||||
self._guid = str(payload.get("guid") or "").strip()
|
||||
self._refresh_token = str(payload.get("refresh_token") or "").strip()
|
||||
return self
|
||||
|
||||
def with_http_method(self, method: Optional[str]) -> "RefreshTokenRequestBuilder":
|
||||
self._http_method = (method or "").strip().upper()
|
||||
return self
|
||||
|
||||
def with_htu(self, url: Optional[str]) -> "RefreshTokenRequestBuilder":
|
||||
self._htu = (url or "").strip()
|
||||
return self
|
||||
|
||||
def with_dpop_proof(self, proof: Optional[str]) -> "RefreshTokenRequestBuilder":
|
||||
self._dpop_proof = (proof or "").strip() or None
|
||||
return self
|
||||
|
||||
def build(self) -> RefreshTokenRequest:
|
||||
if not self._guid:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_CLAIMS, detail="missing guid")
|
||||
if not self._refresh_token:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_CLAIMS, detail="missing refresh token")
|
||||
method = (self._http_method or "").strip().upper()
|
||||
if not method:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN, detail="missing HTTP method")
|
||||
url = (self._htu or "").strip()
|
||||
if not url:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN, detail="missing request URL")
|
||||
return RefreshTokenRequest(
|
||||
guid=DeviceGuid(self._guid),
|
||||
refresh_token=self._refresh_token,
|
||||
http_method=method,
|
||||
htu=url,
|
||||
dpop_proof=self._dpop_proof,
|
||||
)
|
||||
Data/Engine/builders/device_enrollment.py (new file, 131 lines)
@@ -0,0 +1,131 @@
"""Builder utilities for device enrollment payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from Data.Engine.domain.device_auth import DeviceFingerprint, sanitize_service_context
|
||||
from Data.Engine.domain.device_enrollment import (
|
||||
EnrollmentValidationError,
|
||||
ProofChallenge,
|
||||
)
|
||||
from Data.Engine.integrations.crypto import keys as crypto_keys
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentRequestBuilder",
|
||||
"EnrollmentRequestInput",
|
||||
"ProofChallengeBuilder",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentRequestInput:
|
||||
"""Structured enrollment request payload ready for the service layer."""
|
||||
|
||||
hostname: str
|
||||
enrollment_code: str
|
||||
fingerprint: DeviceFingerprint
|
||||
client_nonce: bytes
|
||||
client_nonce_b64: str
|
||||
agent_public_key_der: bytes
|
||||
service_context: Optional[str]
|
||||
|
||||
|
||||
class EnrollmentRequestBuilder:
|
||||
"""Normalize agent enrollment JSON payloads into domain objects."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._hostname: Optional[str] = None
|
||||
self._enrollment_code: Optional[str] = None
|
||||
self._agent_pubkey_b64: Optional[str] = None
|
||||
self._client_nonce_b64: Optional[str] = None
|
||||
self._service_context: Optional[str] = None
|
||||
|
||||
def with_payload(self, payload: Optional[dict[str, object]]) -> "EnrollmentRequestBuilder":
|
||||
payload = payload or {}
|
||||
self._hostname = str(payload.get("hostname") or "").strip()
|
||||
self._enrollment_code = str(payload.get("enrollment_code") or "").strip()
|
||||
agent_pubkey = payload.get("agent_pubkey")
|
||||
self._agent_pubkey_b64 = agent_pubkey if isinstance(agent_pubkey, str) else None
|
||||
client_nonce = payload.get("client_nonce")
|
||||
self._client_nonce_b64 = client_nonce if isinstance(client_nonce, str) else None
|
||||
return self
|
||||
|
||||
def with_service_context(self, value: Optional[str]) -> "EnrollmentRequestBuilder":
|
||||
self._service_context = value
|
||||
return self
|
||||
|
||||
def build(self) -> EnrollmentRequestInput:
|
||||
if not self._hostname:
|
||||
raise EnrollmentValidationError("hostname_required")
|
||||
if not self._enrollment_code:
|
||||
raise EnrollmentValidationError("enrollment_code_required")
|
||||
if not self._agent_pubkey_b64:
|
||||
raise EnrollmentValidationError("agent_pubkey_required")
|
||||
if not self._client_nonce_b64:
|
||||
raise EnrollmentValidationError("client_nonce_required")
|
||||
|
||||
try:
|
||||
agent_pubkey_der = crypto_keys.spki_der_from_base64(self._agent_pubkey_b64)
|
||||
except Exception as exc: # pragma: no cover - invalid input path
|
||||
raise EnrollmentValidationError("invalid_agent_pubkey") from exc
|
||||
|
||||
if len(agent_pubkey_der) < 10:
|
||||
raise EnrollmentValidationError("invalid_agent_pubkey")
|
||||
|
||||
try:
|
||||
client_nonce_bytes = base64.b64decode(self._client_nonce_b64, validate=True)
|
||||
except Exception as exc: # pragma: no cover - invalid input path
|
||||
raise EnrollmentValidationError("invalid_client_nonce") from exc
|
||||
|
||||
if len(client_nonce_bytes) < 16:
|
||||
raise EnrollmentValidationError("invalid_client_nonce")
|
||||
|
||||
fingerprint_value = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
|
||||
|
||||
return EnrollmentRequestInput(
|
||||
hostname=self._hostname,
|
||||
enrollment_code=self._enrollment_code,
|
||||
fingerprint=DeviceFingerprint(fingerprint_value),
|
||||
client_nonce=client_nonce_bytes,
|
||||
client_nonce_b64=self._client_nonce_b64,
|
||||
agent_public_key_der=agent_pubkey_der,
|
||||
service_context=sanitize_service_context(self._service_context),
|
||||
)
|
||||
|
||||
|
||||
class ProofChallengeBuilder:
|
||||
"""Construct proof challenges during enrollment approval."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._server_nonce: Optional[bytes] = None
|
||||
self._client_nonce: Optional[bytes] = None
|
||||
self._fingerprint: Optional[DeviceFingerprint] = None
|
||||
|
||||
def with_server_nonce(self, nonce: Optional[bytes]) -> "ProofChallengeBuilder":
|
||||
self._server_nonce = bytes(nonce or b"")
|
||||
return self
|
||||
|
||||
def with_client_nonce(self, nonce: Optional[bytes]) -> "ProofChallengeBuilder":
|
||||
self._client_nonce = bytes(nonce or b"")
|
||||
return self
|
||||
|
||||
def with_fingerprint(self, fingerprint: Optional[str]) -> "ProofChallengeBuilder":
|
||||
if fingerprint:
|
||||
self._fingerprint = DeviceFingerprint(fingerprint)
|
||||
else:
|
||||
self._fingerprint = None
|
||||
return self
|
||||
|
||||
def build(self) -> ProofChallenge:
|
||||
if self._server_nonce is None or self._client_nonce is None:
|
||||
raise ValueError("both server and client nonces are required")
|
||||
if not self._fingerprint:
|
||||
raise ValueError("fingerprint is required")
|
||||
return ProofChallenge(
|
||||
client_nonce=self._client_nonce,
|
||||
server_nonce=self._server_nonce,
|
||||
fingerprint=self._fingerprint,
|
||||
)
|
||||
Data/Engine/builders/job_fabricator.py (new file, 382 lines)
@@ -0,0 +1,382 @@
"""Builders for Engine job manifests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
from Data.Engine.repositories.sqlite.job_repository import ScheduledJobRecord
|
||||
|
||||
__all__ = [
|
||||
"JobComponentManifest",
|
||||
"JobManifest",
|
||||
"JobFabricator",
|
||||
]
|
||||
|
||||
|
||||
_ENV_VAR_PATTERN = re.compile(r"(?i)\$env:(\{)?([A-Za-z0-9_\-]+)(?(1)\})")
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class JobComponentManifest:
|
||||
"""Materialized job component ready for execution."""
|
||||
|
||||
name: str
|
||||
path: str
|
||||
script_type: str
|
||||
script_content: str
|
||||
encoded_content: str
|
||||
environment: Dict[str, str]
|
||||
literal_environment: Dict[str, str]
|
||||
timeout_seconds: int
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class JobManifest:
|
||||
job_id: int
|
||||
name: str
|
||||
occurrence_ts: int
|
||||
execution_context: str
|
||||
targets: Tuple[str, ...]
|
||||
components: Tuple[JobComponentManifest, ...]
|
||||
|
||||
|
||||
class JobFabricator:
|
||||
"""Convert stored job records into immutable manifests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
assemblies_root: Path,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._assemblies_root = assemblies_root
|
||||
self._log = logger or logging.getLogger("borealis.engine.builders.jobs")
|
||||
|
||||
def build(
|
||||
self,
|
||||
job: ScheduledJobRecord,
|
||||
*,
|
||||
occurrence_ts: int,
|
||||
) -> JobManifest:
|
||||
components = tuple(self._materialize_component(job, component) for component in job.components)
|
||||
targets = tuple(str(t) for t in job.targets)
|
||||
return JobManifest(
|
||||
job_id=job.id,
|
||||
name=job.name,
|
||||
occurrence_ts=occurrence_ts,
|
||||
execution_context=job.execution_context,
|
||||
targets=targets,
|
||||
components=tuple(c for c in components if c is not None),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Component handling
|
||||
# ------------------------------------------------------------------
|
||||
def _materialize_component(
|
||||
self,
|
||||
job: ScheduledJobRecord,
|
||||
component: Mapping[str, Any],
|
||||
) -> Optional[JobComponentManifest]:
|
||||
if not isinstance(component, Mapping):
|
||||
return None
|
||||
|
||||
component_type = str(component.get("type") or "").strip().lower()
|
||||
if component_type not in {"script", "ansible"}:
|
||||
return None
|
||||
|
||||
path = str(component.get("path") or component.get("script_path") or "").strip()
|
||||
if not path:
|
||||
return None
|
||||
|
||||
try:
|
||||
abs_path = self._resolve_script_path(path)
|
||||
except FileNotFoundError:
|
||||
self._log.warning(
|
||||
"job component path invalid", extra={"job_id": job.id, "path": path}
|
||||
)
|
||||
return None
|
||||
script_type = self._detect_script_type(abs_path, component_type)
|
||||
script_content = self._load_script_content(abs_path, component)
|
||||
|
||||
doc_variables: List[Dict[str, Any]] = []
|
||||
if isinstance(component.get("variables"), list):
|
||||
doc_variables = [v for v in component["variables"] if isinstance(v, dict)]
|
||||
overrides = self._collect_overrides(component)
|
||||
env_map, _, literal_lookup = _prepare_variable_context(doc_variables, overrides)
|
||||
|
||||
rewritten = _rewrite_powershell_script(script_content, literal_lookup)
|
||||
encoded = _encode_script_content(rewritten)
|
||||
|
||||
timeout_seconds = _coerce_int(component.get("timeout_seconds"))
|
||||
if not timeout_seconds:
|
||||
timeout_seconds = _coerce_int(component.get("timeout"))
|
||||
|
||||
return JobComponentManifest(
|
||||
name=self._component_name(abs_path, component),
|
||||
path=path,
|
||||
script_type=script_type,
|
||||
script_content=rewritten,
|
||||
encoded_content=encoded,
|
||||
environment=env_map,
|
||||
literal_environment=literal_lookup,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
def _component_name(self, abs_path: Path, component: Mapping[str, Any]) -> str:
|
||||
if isinstance(component.get("name"), str) and component["name"].strip():
|
||||
return component["name"].strip()
|
||||
return abs_path.stem
|
||||
|
||||
def _resolve_script_path(self, rel_path: str) -> Path:
|
||||
candidate = Path(rel_path.replace("\\", "/").lstrip("/"))
|
||||
if candidate.parts and candidate.parts[0] != "Scripts":
|
||||
candidate = Path("Scripts") / candidate
|
||||
abs_path = (self._assemblies_root / candidate).resolve()
|
||||
try:
|
||||
abs_path.relative_to(self._assemblies_root)
|
||||
except ValueError:
|
||||
raise FileNotFoundError(rel_path)
|
||||
if not abs_path.is_file():
|
||||
raise FileNotFoundError(rel_path)
|
||||
return abs_path
|
||||
|
||||
def _load_script_content(self, abs_path: Path, component: Mapping[str, Any]) -> str:
|
||||
if isinstance(component.get("script"), str) and component["script"].strip():
|
||||
return _decode_script_content(component["script"], component.get("encoding") or "")
|
||||
try:
|
||||
return abs_path.read_text(encoding="utf-8")
|
||||
except Exception as exc:
|
||||
self._log.warning("unable to read script for job component: path=%s error=%s", abs_path, exc)
|
||||
return ""
|
||||
|
||||
def _detect_script_type(self, abs_path: Path, declared: str) -> str:
|
||||
lower = declared.lower()
|
||||
if lower in {"script", "powershell"}:
|
||||
return "powershell"
|
||||
suffix = abs_path.suffix.lower()
|
||||
if suffix == ".ps1":
|
||||
return "powershell"
|
||||
if suffix == ".yml":
|
||||
return "ansible"
|
||||
if suffix == ".json":
|
||||
try:
|
||||
data = json.loads(abs_path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, dict):
|
||||
t = str(data.get("type") or data.get("script_type") or "").strip().lower()
|
||||
if t:
|
||||
return t
|
||||
except Exception:
|
||||
pass
|
||||
return lower or "powershell"
|
||||
|
||||
def _collect_overrides(self, component: Mapping[str, Any]) -> Dict[str, Any]:
|
||||
overrides: Dict[str, Any] = {}
|
||||
values = component.get("variable_values")
|
||||
if isinstance(values, Mapping):
|
||||
for key, value in values.items():
|
||||
name = str(key or "").strip()
|
||||
if name:
|
||||
overrides[name] = value
|
||||
vars_inline = component.get("variables")
|
||||
if isinstance(vars_inline, Iterable):
|
||||
for var in vars_inline:
|
||||
if not isinstance(var, Mapping):
|
||||
continue
|
||||
name = str(var.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
if "value" in var:
|
||||
overrides[name] = var.get("value")
|
||||
return overrides
|
||||
|
||||
|
||||
def _coerce_int(value: Any) -> int:
|
||||
try:
|
||||
return int(value or 0)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def _env_string(value: Any) -> str:
|
||||
if isinstance(value, bool):
|
||||
return "True" if value else "False"
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value)
|
||||
|
||||
|
||||
def _decode_base64_text(value: Any) -> Optional[str]:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
return ""
|
||||
try:
|
||||
cleaned = re.sub(r"\s+", "", stripped)
|
||||
except Exception:
|
||||
cleaned = stripped
|
||||
try:
|
||||
decoded = base64.b64decode(cleaned, validate=True)
|
||||
except Exception:
|
||||
return None
|
||||
try:
|
||||
return decoded.decode("utf-8")
|
||||
except Exception:
|
||||
return decoded.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def _decode_script_content(value: Any, encoding_hint: Any = "") -> str:
|
||||
encoding = str(encoding_hint or "").strip().lower()
|
||||
if isinstance(value, str):
|
||||
if encoding in {"base64", "b64", "base-64"}:
|
||||
decoded = _decode_base64_text(value)
|
||||
if decoded is not None:
|
||||
return decoded.replace("\r\n", "\n")
|
||||
decoded = _decode_base64_text(value)
|
||||
if decoded is not None:
|
||||
return decoded.replace("\r\n", "\n")
|
||||
return value.replace("\r\n", "\n")
|
||||
return ""
|
||||
|
||||
|
||||
def _encode_script_content(script_text: Any) -> str:
|
||||
if not isinstance(script_text, str):
|
||||
if script_text is None:
|
||||
script_text = ""
|
||||
else:
|
||||
script_text = str(script_text)
|
||||
normalized = script_text.replace("\r\n", "\n")
|
||||
if not normalized:
|
||||
return ""
|
||||
encoded = base64.b64encode(normalized.encode("utf-8"))
|
||||
return encoded.decode("ascii")
|
||||
|
||||
|
||||
def _canonical_env_key(name: Any) -> str:
|
||||
try:
|
||||
return re.sub(r"[^A-Za-z0-9_]", "_", str(name or "").strip()).upper()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _expand_env_aliases(env_map: Dict[str, str], variables: Iterable[Mapping[str, Any]]) -> Dict[str, str]:
|
||||
expanded = dict(env_map or {})
|
||||
for var in variables:
|
||||
if not isinstance(var, Mapping):
|
||||
continue
|
||||
name = str(var.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
canonical = _canonical_env_key(name)
|
||||
if not canonical or canonical not in expanded:
|
||||
continue
|
||||
value = expanded[canonical]
|
||||
alias = re.sub(r"[^A-Za-z0-9_]", "_", name)
|
||||
if alias and alias not in expanded:
|
||||
expanded[alias] = value
|
||||
if alias != name and re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", name) and name not in expanded:
|
||||
expanded[name] = value
|
||||
return expanded
|
||||
|
||||
|
||||
def _powershell_literal(value: Any, var_type: str) -> str:
|
||||
typ = str(var_type or "string").lower()
|
||||
if typ == "boolean":
|
||||
if isinstance(value, bool):
|
||||
truthy = value
|
||||
elif value is None:
|
||||
truthy = False
|
||||
elif isinstance(value, (int, float)):
|
||||
truthy = value != 0
|
||||
else:
|
||||
s = str(value).strip().lower()
|
||||
if s in {"true", "1", "yes", "y", "on"}:
|
||||
truthy = True
|
||||
elif s in {"false", "0", "no", "n", "off", ""}:
|
||||
truthy = False
|
||||
else:
|
||||
truthy = bool(s)
|
||||
return "$true" if truthy else "$false"
|
||||
if typ == "number":
|
||||
if value is None or value == "":
|
||||
return "0"
|
||||
return str(value)
|
||||
s = "" if value is None else str(value)
|
||||
return "'" + s.replace("'", "''") + "'"
|
||||
|
||||
|
||||
def _extract_variable_default(var: Mapping[str, Any]) -> Any:
|
||||
for key in ("value", "default", "defaultValue", "default_value"):
|
||||
if key in var:
|
||||
val = var.get(key)
|
||||
return "" if val is None else val
|
||||
return ""
|
||||
|
||||
|
||||
def _prepare_variable_context(
|
||||
doc_variables: Iterable[Mapping[str, Any]],
|
||||
overrides: Mapping[str, Any],
|
||||
) -> Tuple[Dict[str, str], List[Dict[str, Any]], Dict[str, str]]:
|
||||
env_map: Dict[str, str] = {}
|
||||
variables: List[Dict[str, Any]] = []
|
||||
literal_lookup: Dict[str, str] = {}
|
||||
doc_names: Dict[str, bool] = {}
|
||||
|
||||
overrides = dict(overrides or {})
|
||||
|
||||
for var in doc_variables:
|
||||
if not isinstance(var, Mapping):
|
||||
continue
|
||||
name = str(var.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
doc_names[name] = True
|
||||
canonical = _canonical_env_key(name)
|
||||
var_type = str(var.get("type") or "string").lower()
|
||||
default_val = _extract_variable_default(var)
|
||||
final_val = overrides[name] if name in overrides else default_val
|
||||
if canonical:
|
||||
env_map[canonical] = _env_string(final_val)
|
||||
literal_lookup[canonical] = _powershell_literal(final_val, var_type)
|
||||
if name in overrides:
|
||||
new_var = dict(var)
|
||||
new_var["value"] = overrides[name]
|
||||
variables.append(new_var)
|
||||
else:
|
||||
variables.append(dict(var))
|
||||
|
||||
for name, val in overrides.items():
|
||||
if name in doc_names:
|
||||
continue
|
||||
canonical = _canonical_env_key(name)
|
||||
if canonical:
|
||||
env_map[canonical] = _env_string(val)
|
||||
literal_lookup[canonical] = _powershell_literal(val, "string")
|
||||
variables.append({"name": name, "value": val, "type": "string"})
|
||||
|
||||
env_map = _expand_env_aliases(env_map, variables)
|
||||
return env_map, variables, literal_lookup
|
||||
|
||||
|
||||
def _rewrite_powershell_script(content: str, literal_lookup: Mapping[str, str]) -> str:
|
||||
if not content or not literal_lookup:
|
||||
return content
|
||||
|
||||
def _replace(match: re.Match[str]) -> str:
|
||||
name = match.group(2)
|
||||
canonical = _canonical_env_key(name)
|
||||
if not canonical:
|
||||
return match.group(0)
|
||||
literal = literal_lookup.get(canonical)
|
||||
if literal is None:
|
||||
return match.group(0)
|
||||
return literal
|
||||
|
||||
return _ENV_VAR_PATTERN.sub(_replace, content)
|
||||
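A minimal usage sketch of the variable helpers above (illustrative only, not part of the commit). The variable name, the override value, and the assumption that _ENV_VAR_PATTERN captures the $env: variable name in group 2 are hypothetical:

# Illustrative sketch only; names and values below are made up.
doc_variables = [{"name": "Install Path", "type": "string", "default": "C:\\Temp"}]
overrides = {"Install Path": "D:\\Apps"}

env_map, variables, literal_lookup = _prepare_variable_context(doc_variables, overrides)
# env_map maps canonical keys (plus aliases) to strings,
#   e.g. {"INSTALL_PATH": "D:\\Apps", "Install_Path": "D:\\Apps"}
# literal_lookup maps canonical keys to PowerShell literals,
#   e.g. {"INSTALL_PATH": "'D:\\Apps'"}

rewritten = _rewrite_powershell_script("Write-Host $env:INSTALL_PATH", literal_lookup)
# Assuming _ENV_VAR_PATTERN matches "$env:INSTALL_PATH" with the name in group 2,
# the reference is replaced by the quoted literal: Write-Host 'D:\Apps'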
25 Data/Engine/config/__init__.py Normal file
@@ -0,0 +1,25 @@
"""Configuration primitives for the Borealis Engine."""

from __future__ import annotations

from .environment import (
    DatabaseSettings,
    EngineSettings,
    FlaskSettings,
    GitHubSettings,
    ServerSettings,
    SocketIOSettings,
    load_environment,
)
from .logging import configure_logging

__all__ = [
    "DatabaseSettings",
    "EngineSettings",
    "FlaskSettings",
    "GitHubSettings",
    "load_environment",
    "ServerSettings",
    "SocketIOSettings",
    "configure_logging",
]
206 Data/Engine/config/environment.py Normal file
@@ -0,0 +1,206 @@
|
||||
"""Environment detection for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DatabaseSettings:
|
||||
"""SQLite database configuration for the Engine."""
|
||||
|
||||
path: Path
|
||||
apply_migrations: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class FlaskSettings:
|
||||
"""Parameters that influence Flask application behavior."""
|
||||
|
||||
secret_key: str
|
||||
static_root: Path
|
||||
cors_allowed_origins: Tuple[str, ...]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SocketIOSettings:
|
||||
"""Configuration for the optional Socket.IO server."""
|
||||
|
||||
cors_allowed_origins: Tuple[str, ...]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ServerSettings:
|
||||
"""HTTP server binding configuration."""
|
||||
|
||||
host: str
|
||||
port: int
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GitHubSettings:
|
||||
"""Configuration surface for GitHub repository interactions."""
|
||||
|
||||
default_repo: str
|
||||
default_branch: str
|
||||
refresh_interval_seconds: int
|
||||
cache_root: Path
|
||||
|
||||
@property
|
||||
def cache_file(self) -> Path:
|
||||
"""Location of the persisted repository-head cache."""
|
||||
|
||||
return self.cache_root / "repo_hash_cache.json"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EngineSettings:
|
||||
"""Immutable container describing the Engine runtime configuration."""
|
||||
|
||||
project_root: Path
|
||||
debug: bool
|
||||
database: DatabaseSettings
|
||||
flask: FlaskSettings
|
||||
socketio: SocketIOSettings
|
||||
server: ServerSettings
|
||||
github: GitHubSettings
|
||||
|
||||
@property
|
||||
def logs_root(self) -> Path:
|
||||
"""Return the directory where Engine-specific logs should live."""
|
||||
|
||||
return self.project_root / "Logs" / "Server"
|
||||
|
||||
@property
|
||||
def database_path(self) -> Path:
|
||||
"""Convenience accessor for the database file path."""
|
||||
|
||||
return self.database.path
|
||||
|
||||
@property
|
||||
def apply_migrations(self) -> bool:
|
||||
"""Return whether schema migrations should run at bootstrap."""
|
||||
|
||||
return self.database.apply_migrations
|
||||
|
||||
|
||||
def _resolve_project_root() -> Path:
|
||||
candidate = os.getenv("BOREALIS_ROOT")
|
||||
if candidate:
|
||||
return Path(candidate).expanduser().resolve()
|
||||
return Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
def _resolve_database_path(project_root: Path) -> Path:
|
||||
candidate = os.getenv("BOREALIS_DATABASE_PATH")
|
||||
if candidate:
|
||||
return Path(candidate).expanduser().resolve()
|
||||
return (project_root / "database.db").resolve()
|
||||
|
||||
|
||||
def _should_apply_migrations() -> bool:
|
||||
raw = os.getenv("BOREALIS_ENGINE_AUTO_MIGRATE", "true")
|
||||
return raw.lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _resolve_static_root(project_root: Path) -> Path:
|
||||
candidate = os.getenv("BOREALIS_STATIC_ROOT")
|
||||
if candidate:
|
||||
return Path(candidate).expanduser().resolve()
|
||||
|
||||
candidates = (
|
||||
project_root / "Engine" / "web-interface" / "build",
|
||||
project_root / "Engine" / "web-interface" / "dist",
|
||||
project_root / "Data" / "Engine" / "WebUI" / "build",
|
||||
project_root / "Data" / "Server" / "web-interface" / "build",
|
||||
project_root / "Data" / "Server" / "WebUI" / "build",
|
||||
project_root / "Data" / "WebUI" / "build",
|
||||
)
|
||||
for path in candidates:
|
||||
resolved = path.resolve()
|
||||
if resolved.is_dir():
|
||||
return resolved
|
||||
|
||||
# Fall back to the first candidate even if it does not yet exist so the
|
||||
# Flask factory still initialises; individual requests will surface 404s
|
||||
# until an asset build is available, matching the legacy behaviour.
|
||||
return candidates[0].resolve()
|
||||
|
||||
|
||||
def _resolve_github_cache_root(project_root: Path) -> Path:
|
||||
candidate = os.getenv("BOREALIS_CACHE_DIR")
|
||||
if candidate:
|
||||
return Path(candidate).expanduser().resolve()
|
||||
return (project_root / "Data" / "Engine" / "cache").resolve()
|
||||
|
||||
|
||||
def _parse_refresh_interval(raw: str | None) -> int:
|
||||
if not raw:
|
||||
return 60
|
||||
try:
|
||||
value = int(raw)
|
||||
except ValueError:
|
||||
value = 60
|
||||
return max(30, min(value, 3600))
|
||||
|
||||
|
||||
def _parse_origins(raw: str | None) -> Tuple[str, ...]:
|
||||
if not raw:
|
||||
return ("*",)
|
||||
parts: Iterable[str] = (segment.strip() for segment in raw.split(","))
|
||||
filtered = tuple(part for part in parts if part)
|
||||
return filtered or ("*",)
|
||||
|
||||
|
||||
def load_environment() -> EngineSettings:
|
||||
"""Load Engine settings from environment variables and filesystem hints."""
|
||||
|
||||
project_root = _resolve_project_root()
|
||||
database = DatabaseSettings(
|
||||
path=_resolve_database_path(project_root),
|
||||
apply_migrations=_should_apply_migrations(),
|
||||
)
|
||||
cors_allowed_origins = _parse_origins(os.getenv("BOREALIS_CORS_ALLOWED_ORIGINS"))
|
||||
flask_settings = FlaskSettings(
|
||||
secret_key=os.getenv("BOREALIS_FLASK_SECRET_KEY", "change-me"),
|
||||
static_root=_resolve_static_root(project_root),
|
||||
cors_allowed_origins=cors_allowed_origins,
|
||||
)
|
||||
socket_settings = SocketIOSettings(cors_allowed_origins=cors_allowed_origins)
|
||||
debug = os.getenv("BOREALIS_DEBUG", "false").lower() in {"1", "true", "yes", "on"}
|
||||
host = os.getenv("BOREALIS_HOST", "127.0.0.1")
|
||||
try:
|
||||
port = int(os.getenv("BOREALIS_PORT", "5000"))
|
||||
except ValueError:
|
||||
port = 5000
|
||||
server_settings = ServerSettings(host=host, port=port)
|
||||
github_settings = GitHubSettings(
|
||||
default_repo=os.getenv("BOREALIS_REPO", "bunny-lab-io/Borealis"),
|
||||
default_branch=os.getenv("BOREALIS_REPO_BRANCH", "main"),
|
||||
refresh_interval_seconds=_parse_refresh_interval(os.getenv("BOREALIS_REPO_HASH_REFRESH")),
|
||||
cache_root=_resolve_github_cache_root(project_root),
|
||||
)
|
||||
|
||||
return EngineSettings(
|
||||
project_root=project_root,
|
||||
debug=debug,
|
||||
database=database,
|
||||
flask=flask_settings,
|
||||
socketio=socket_settings,
|
||||
server=server_settings,
|
||||
github=github_settings,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DatabaseSettings",
|
||||
"EngineSettings",
|
||||
"FlaskSettings",
|
||||
"GitHubSettings",
|
||||
"SocketIOSettings",
|
||||
"ServerSettings",
|
||||
"load_environment",
|
||||
]
|
||||
71 Data/Engine/config/logging.py Normal file
@@ -0,0 +1,71 @@
"""Logging bootstrap helpers for the Borealis Engine."""

from __future__ import annotations

import logging
from logging.handlers import TimedRotatingFileHandler
from pathlib import Path

from .environment import EngineSettings


_ENGINE_LOGGER_NAME = "borealis.engine"
_SERVICE_NAME = "engine"
_DEFAULT_FORMAT = "%(asctime)s-" + _SERVICE_NAME + "-%(message)s"


def _handler_already_attached(logger: logging.Logger, log_path: Path) -> bool:
    for handler in logger.handlers:
        if isinstance(handler, TimedRotatingFileHandler):
            handler_path = Path(getattr(handler, "baseFilename", ""))
            if handler_path == log_path:
                return True
    return False


def _build_handler(log_path: Path) -> TimedRotatingFileHandler:
    handler = TimedRotatingFileHandler(
        log_path,
        when="midnight",
        backupCount=30,
        encoding="utf-8",
    )
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter(_DEFAULT_FORMAT))
    return handler


def configure_logging(settings: EngineSettings) -> logging.Logger:
    """Configure a rotating log handler for the Engine."""

    logs_root = settings.logs_root
    logs_root.mkdir(parents=True, exist_ok=True)
    log_path = logs_root / "engine.log"

    logger = logging.getLogger(_ENGINE_LOGGER_NAME)
    logger.setLevel(logging.INFO if not settings.debug else logging.DEBUG)

    if not _handler_already_attached(logger, log_path):
        handler = _build_handler(log_path)
        logger.addHandler(handler)
    logger.propagate = False

    # Also ensure the root logger follows suit so third-party modules inherit the handler.
    root_logger = logging.getLogger()
    if not _handler_already_attached(root_logger, log_path):
        handler = _build_handler(log_path)
        root_logger.addHandler(handler)
    if root_logger.level == logging.WARNING:
        # Default level is WARNING; lower it to INFO so our handler captures application messages.
        root_logger.setLevel(logging.INFO if not settings.debug else logging.DEBUG)

    # Quieten overly chatty frameworks unless debugging is explicitly requested.
    if not settings.debug:
        logging.getLogger("werkzeug").setLevel(logging.WARNING)
        logging.getLogger("engineio").setLevel(logging.WARNING)
        logging.getLogger("socketio").setLevel(logging.WARNING)

    return logger


__all__ = ["configure_logging"]
49 Data/Engine/domain/__init__.py Normal file
@@ -0,0 +1,49 @@
"""Pure value objects and enums for the Borealis Engine."""

from __future__ import annotations

from .device_auth import (  # noqa: F401
    AccessTokenClaims,
    DeviceAuthContext,
    DeviceAuthErrorCode,
    DeviceAuthFailure,
    DeviceFingerprint,
    DeviceGuid,
    DeviceIdentity,
    DeviceStatus,
    sanitize_service_context,
)
from .device_enrollment import (  # noqa: F401
    EnrollmentApproval,
    EnrollmentApprovalStatus,
    EnrollmentCode,
    EnrollmentRequest,
    ProofChallenge,
)
from .github import (  # noqa: F401
    GitHubRateLimit,
    GitHubRepoRef,
    GitHubTokenStatus,
    RepoHeadSnapshot,
)

__all__ = [
    "AccessTokenClaims",
    "DeviceAuthContext",
    "DeviceAuthErrorCode",
    "DeviceAuthFailure",
    "DeviceFingerprint",
    "DeviceGuid",
    "DeviceIdentity",
    "DeviceStatus",
    "EnrollmentApproval",
    "EnrollmentApprovalStatus",
    "EnrollmentCode",
    "EnrollmentRequest",
    "ProofChallenge",
    "GitHubRateLimit",
    "GitHubRepoRef",
    "GitHubTokenStatus",
    "RepoHeadSnapshot",
    "sanitize_service_context",
]
242 Data/Engine/domain/device_auth.py Normal file
@@ -0,0 +1,242 @@
|
||||
"""Domain primitives for device authentication and token validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Mapping, Optional
|
||||
import string
|
||||
import uuid
|
||||
|
||||
__all__ = [
|
||||
"DeviceGuid",
|
||||
"DeviceFingerprint",
|
||||
"DeviceIdentity",
|
||||
"DeviceStatus",
|
||||
"DeviceAuthErrorCode",
|
||||
"DeviceAuthFailure",
|
||||
"AccessTokenClaims",
|
||||
"DeviceAuthContext",
|
||||
"sanitize_service_context",
|
||||
]
|
||||
|
||||
|
||||
def _normalize_guid(value: Optional[str]) -> str:
|
||||
"""Return a canonical GUID string or an empty string."""
|
||||
candidate = (value or "").strip()
|
||||
if not candidate:
|
||||
return ""
|
||||
candidate = candidate.strip("{}")
|
||||
try:
|
||||
return str(uuid.UUID(candidate)).upper()
|
||||
except Exception:
|
||||
cleaned = "".join(
|
||||
ch for ch in candidate if ch in string.hexdigits or ch == "-"
|
||||
).strip("-")
|
||||
if cleaned:
|
||||
try:
|
||||
return str(uuid.UUID(cleaned)).upper()
|
||||
except Exception:
|
||||
pass
|
||||
return candidate.upper()
|
||||
|
||||
|
||||
def _normalize_fingerprint(value: Optional[str]) -> str:
|
||||
return (value or "").strip().lower()
|
||||
|
||||
|
||||
def sanitize_service_context(value: Optional[str]) -> Optional[str]:
|
||||
"""Normalize the optional agent service context header value."""
|
||||
if not value:
|
||||
return None
|
||||
cleaned = "".join(
|
||||
ch for ch in str(value) if ch.isalnum() or ch in ("_", "-")
|
||||
)
|
||||
if not cleaned:
|
||||
return None
|
||||
return cleaned.upper()
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceGuid:
|
||||
"""Canonical GUID wrapper that enforces Borealis normalization."""
|
||||
|
||||
value: str
|
||||
|
||||
def __post_init__(self) -> None: # pragma: no cover - simple data normalization
|
||||
normalized = _normalize_guid(self.value)
|
||||
if not normalized:
|
||||
raise ValueError("device GUID is required")
|
||||
object.__setattr__(self, "value", normalized)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceFingerprint:
|
||||
"""Normalized TLS key fingerprint associated with a device."""
|
||||
|
||||
value: str
|
||||
|
||||
def __post_init__(self) -> None: # pragma: no cover - simple data normalization
|
||||
normalized = _normalize_fingerprint(self.value)
|
||||
if not normalized:
|
||||
raise ValueError("device fingerprint is required")
|
||||
object.__setattr__(self, "value", normalized)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceIdentity:
|
||||
"""Immutable pairing of device GUID and TLS key fingerprint."""
|
||||
|
||||
guid: DeviceGuid
|
||||
fingerprint: DeviceFingerprint
|
||||
|
||||
|
||||
class DeviceStatus(str, Enum):
|
||||
"""Lifecycle markers mirrored from the legacy devices table."""
|
||||
|
||||
ACTIVE = "active"
|
||||
QUARANTINED = "quarantined"
|
||||
REVOKED = "revoked"
|
||||
DECOMMISSIONED = "decommissioned"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: Optional[str]) -> "DeviceStatus":
|
||||
normalized = (value or "active").strip().lower()
|
||||
try:
|
||||
return cls(normalized)
|
||||
except ValueError:
|
||||
return cls.ACTIVE
|
||||
|
||||
@property
|
||||
def allows_access(self) -> bool:
|
||||
return self in {self.ACTIVE, self.QUARANTINED}
|
||||
|
||||
|
||||
class DeviceAuthErrorCode(str, Enum):
|
||||
"""Well-known authentication failure categories."""
|
||||
|
||||
MISSING_AUTHORIZATION = "missing_authorization"
|
||||
TOKEN_EXPIRED = "token_expired"
|
||||
INVALID_TOKEN = "invalid_token"
|
||||
INVALID_CLAIMS = "invalid_claims"
|
||||
RATE_LIMITED = "rate_limited"
|
||||
DEVICE_NOT_FOUND = "device_not_found"
|
||||
DEVICE_GUID_MISMATCH = "device_guid_mismatch"
|
||||
FINGERPRINT_MISMATCH = "fingerprint_mismatch"
|
||||
TOKEN_VERSION_REVOKED = "token_version_revoked"
|
||||
DEVICE_REVOKED = "device_revoked"
|
||||
DPOP_NOT_SUPPORTED = "dpop_not_supported"
|
||||
DPOP_REPLAYED = "dpop_replayed"
|
||||
DPOP_INVALID = "dpop_invalid"
|
||||
|
||||
|
||||
class DeviceAuthFailure(Exception):
|
||||
"""Domain-level authentication error with HTTP metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
code: DeviceAuthErrorCode,
|
||||
*,
|
||||
http_status: int = 401,
|
||||
retry_after: Optional[float] = None,
|
||||
detail: Optional[str] = None,
|
||||
) -> None:
|
||||
self.code = code
|
||||
self.http_status = int(http_status)
|
||||
self.retry_after = retry_after
|
||||
self.detail = detail or code.value
|
||||
super().__init__(self.detail)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {"error": self.code.value}
|
||||
if self.retry_after is not None:
|
||||
payload["retry_after"] = float(self.retry_after)
|
||||
if self.detail and self.detail != self.code.value:
|
||||
payload["detail"] = self.detail
|
||||
return payload
|
||||
|
||||
|
||||
def _coerce_int(value: Any, *, minimum: Optional[int] = None) -> int:
|
||||
try:
|
||||
result = int(value)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError("expected integer value") from None
|
||||
if minimum is not None and result < minimum:
|
||||
raise ValueError("integer below minimum")
|
||||
return result
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AccessTokenClaims:
|
||||
"""Validated subset of JWT claims issued to a device."""
|
||||
|
||||
subject: str
|
||||
guid: DeviceGuid
|
||||
fingerprint: DeviceFingerprint
|
||||
token_version: int
|
||||
issued_at: int
|
||||
not_before: int
|
||||
expires_at: int
|
||||
raw: Mapping[str, Any]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.token_version <= 0:
|
||||
raise ValueError("token_version must be positive")
|
||||
if self.issued_at <= 0 or self.not_before <= 0 or self.expires_at <= 0:
|
||||
raise ValueError("temporal claims must be positive integers")
|
||||
if self.expires_at <= self.not_before:
|
||||
raise ValueError("token expiration must be after not-before")
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls, claims: Mapping[str, Any]) -> "AccessTokenClaims":
|
||||
subject = str(claims.get("sub") or "").strip()
|
||||
if not subject:
|
||||
raise ValueError("missing token subject")
|
||||
guid = DeviceGuid(str(claims.get("guid") or ""))
|
||||
fingerprint = DeviceFingerprint(claims.get("ssl_key_fingerprint"))
|
||||
token_version = _coerce_int(claims.get("token_version"), minimum=1)
|
||||
issued_at = _coerce_int(claims.get("iat"), minimum=1)
|
||||
not_before = _coerce_int(claims.get("nbf"), minimum=1)
|
||||
expires_at = _coerce_int(claims.get("exp"), minimum=1)
|
||||
return cls(
|
||||
subject=subject,
|
||||
guid=guid,
|
||||
fingerprint=fingerprint,
|
||||
token_version=token_version,
|
||||
issued_at=issued_at,
|
||||
not_before=not_before,
|
||||
expires_at=expires_at,
|
||||
raw=dict(claims),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceAuthContext:
|
||||
"""Domain result emitted after successful authentication."""
|
||||
|
||||
identity: DeviceIdentity
|
||||
access_token: str
|
||||
claims: AccessTokenClaims
|
||||
status: DeviceStatus
|
||||
service_context: Optional[str]
|
||||
dpop_jkt: Optional[str] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.access_token:
|
||||
raise ValueError("access token is required")
|
||||
service = sanitize_service_context(self.service_context)
|
||||
object.__setattr__(self, "service_context", service)
|
||||
|
||||
@property
|
||||
def is_quarantined(self) -> bool:
|
||||
return self.status is DeviceStatus.QUARANTINED
|
||||
|
||||
@property
|
||||
def allows_access(self) -> bool:
|
||||
return self.status.allows_access
|
||||
261 Data/Engine/domain/device_enrollment.py Normal file
@@ -0,0 +1,261 @@
|
||||
"""Domain types describing device enrollment flows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
import base64
|
||||
from enum import Enum
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
from .device_auth import DeviceFingerprint, DeviceGuid, sanitize_service_context
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentCode",
|
||||
"EnrollmentApprovalStatus",
|
||||
"EnrollmentApproval",
|
||||
"EnrollmentRequest",
|
||||
"EnrollmentValidationError",
|
||||
"ProofChallenge",
|
||||
]
|
||||
|
||||
|
||||
def _parse_iso8601(value: Optional[str]) -> Optional[datetime]:
|
||||
if not value:
|
||||
return None
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
dt = datetime.fromisoformat(raw)
|
||||
except Exception as exc: # pragma: no cover - error path
|
||||
raise ValueError(f"invalid ISO8601 timestamp: {raw}") from exc
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
def _require(value: Optional[str], field: str) -> str:
|
||||
text = (value or "").strip()
|
||||
if not text:
|
||||
raise ValueError(f"{field} is required")
|
||||
return text
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentValidationError(Exception):
|
||||
"""Raised when enrollment input fails validation."""
|
||||
|
||||
code: str
|
||||
http_status: int = 400
|
||||
retry_after: Optional[float] = None
|
||||
|
||||
def to_response(self) -> dict[str, object]:
|
||||
payload: dict[str, object] = {"error": self.code}
|
||||
if self.retry_after is not None:
|
||||
payload["retry_after"] = self.retry_after
|
||||
return payload
|
||||
|
||||
def __str__(self) -> str: # pragma: no cover - debug helper
|
||||
return f"{self.code} (status={self.http_status})"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentCode:
|
||||
"""Installer code metadata loaded from the persistence layer."""
|
||||
|
||||
code: str
|
||||
expires_at: datetime
|
||||
max_uses: int
|
||||
use_count: int
|
||||
used_by_guid: Optional[DeviceGuid]
|
||||
last_used_at: Optional[datetime]
|
||||
used_at: Optional[datetime]
|
||||
record_id: Optional[str] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.code:
|
||||
raise ValueError("code is required")
|
||||
if self.max_uses < 1:
|
||||
raise ValueError("max_uses must be >= 1")
|
||||
if self.use_count < 0:
|
||||
raise ValueError("use_count cannot be negative")
|
||||
if self.use_count > self.max_uses:
|
||||
raise ValueError("use_count cannot exceed max_uses")
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls, record: Mapping[str, Any]) -> "EnrollmentCode":
|
||||
used_by = record.get("used_by_guid")
|
||||
used_by_guid = DeviceGuid(used_by) if used_by else None
|
||||
return cls(
|
||||
code=_require(record.get("code"), "code"),
|
||||
expires_at=_parse_iso8601(record.get("expires_at")) or datetime.now(tz=timezone.utc),
|
||||
max_uses=int(record.get("max_uses") or 1),
|
||||
use_count=int(record.get("use_count") or 0),
|
||||
used_by_guid=used_by_guid,
|
||||
last_used_at=_parse_iso8601(record.get("last_used_at")),
|
||||
used_at=_parse_iso8601(record.get("used_at")),
|
||||
record_id=str(record.get("id") or "") or None,
|
||||
)
|
||||
|
||||
@property
|
||||
def remaining_uses(self) -> int:
|
||||
return max(self.max_uses - self.use_count, 0)
|
||||
|
||||
@property
|
||||
def is_exhausted(self) -> bool:
|
||||
return self.remaining_uses == 0
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
return self.expires_at <= datetime.now(tz=timezone.utc)
|
||||
|
||||
@property
|
||||
def identifier(self) -> Optional[str]:
|
||||
return self.record_id
|
||||
|
||||
|
||||
class EnrollmentApprovalStatus(str, Enum):
|
||||
"""Possible states for a device approval entry."""
|
||||
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
COMPLETED = "completed"
|
||||
EXPIRED = "expired"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: Optional[str]) -> "EnrollmentApprovalStatus":
|
||||
normalized = (value or "pending").strip().lower()
|
||||
try:
|
||||
return cls(normalized)
|
||||
except ValueError:
|
||||
return cls.PENDING
|
||||
|
||||
@property
|
||||
def is_terminal(self) -> bool:
|
||||
return self in {self.APPROVED, self.DENIED, self.COMPLETED, self.EXPIRED}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ProofChallenge:
|
||||
"""Client/server nonce pair distributed during enrollment."""
|
||||
|
||||
client_nonce: bytes
|
||||
server_nonce: bytes
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.client_nonce or not self.server_nonce:
|
||||
raise ValueError("nonce payloads must be non-empty")
|
||||
|
||||
@classmethod
|
||||
def from_base64(cls, *, client: bytes, server: bytes) -> "ProofChallenge":
|
||||
return cls(client_nonce=client, server_nonce=server)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentRequest:
|
||||
"""Validated payload submitted by an agent during enrollment."""
|
||||
|
||||
hostname: str
|
||||
enrollment_code: str
|
||||
fingerprint: DeviceFingerprint
|
||||
proof: ProofChallenge
|
||||
service_context: Optional[str]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.hostname:
|
||||
raise ValueError("hostname is required")
|
||||
if not self.enrollment_code:
|
||||
raise ValueError("enrollment code is required")
|
||||
object.__setattr__(
|
||||
self,
|
||||
"service_context",
|
||||
sanitize_service_context(self.service_context),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_payload(
|
||||
cls,
|
||||
*,
|
||||
hostname: str,
|
||||
enrollment_code: str,
|
||||
fingerprint: str,
|
||||
client_nonce: bytes,
|
||||
server_nonce: bytes,
|
||||
service_context: Optional[str] = None,
|
||||
) -> "EnrollmentRequest":
|
||||
proof = ProofChallenge(client_nonce=client_nonce, server_nonce=server_nonce)
|
||||
return cls(
|
||||
hostname=_require(hostname, "hostname"),
|
||||
enrollment_code=_require(enrollment_code, "enrollment_code"),
|
||||
fingerprint=DeviceFingerprint(fingerprint),
|
||||
proof=proof,
|
||||
service_context=service_context,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentApproval:
|
||||
"""Pending or resolved approval tracked by operators."""
|
||||
|
||||
record_id: str
|
||||
reference: str
|
||||
status: EnrollmentApprovalStatus
|
||||
claimed_hostname: str
|
||||
claimed_fingerprint: DeviceFingerprint
|
||||
enrollment_code_id: Optional[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
client_nonce_b64: str
|
||||
server_nonce_b64: str
|
||||
agent_pubkey_der: bytes
|
||||
guid: Optional[DeviceGuid] = None
|
||||
approved_by: Optional[str] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.record_id:
|
||||
raise ValueError("record identifier is required")
|
||||
if not self.reference:
|
||||
raise ValueError("approval reference is required")
|
||||
if not self.claimed_hostname:
|
||||
raise ValueError("claimed hostname is required")
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls, record: Mapping[str, Any]) -> "EnrollmentApproval":
|
||||
guid_raw = record.get("guid")
|
||||
approved_raw = record.get("approved_by_user_id")
|
||||
return cls(
|
||||
record_id=_require(record.get("id"), "id"),
|
||||
reference=_require(record.get("approval_reference"), "approval_reference"),
|
||||
status=EnrollmentApprovalStatus.from_string(record.get("status")),
|
||||
claimed_hostname=_require(record.get("hostname_claimed"), "hostname_claimed"),
|
||||
claimed_fingerprint=DeviceFingerprint(record.get("ssl_key_fingerprint_claimed")),
|
||||
enrollment_code_id=record.get("enrollment_code_id"),
|
||||
created_at=_parse_iso8601(record.get("created_at")) or datetime.now(tz=timezone.utc),
|
||||
updated_at=_parse_iso8601(record.get("updated_at")) or datetime.now(tz=timezone.utc),
|
||||
guid=DeviceGuid(guid_raw) if guid_raw else None,
|
||||
approved_by=(approved_raw or None),
|
||||
client_nonce_b64=_require(record.get("client_nonce"), "client_nonce"),
|
||||
server_nonce_b64=_require(record.get("server_nonce"), "server_nonce"),
|
||||
agent_pubkey_der=bytes(record.get("agent_pubkey_der") or b""),
|
||||
)
|
||||
|
||||
@property
|
||||
def is_pending(self) -> bool:
|
||||
return self.status is EnrollmentApprovalStatus.PENDING
|
||||
|
||||
@property
|
||||
def is_completed(self) -> bool:
|
||||
return self.status in {
|
||||
EnrollmentApprovalStatus.APPROVED,
|
||||
EnrollmentApprovalStatus.COMPLETED,
|
||||
}
|
||||
|
||||
@property
|
||||
def client_nonce_bytes(self) -> bytes:
|
||||
return base64.b64decode(self.client_nonce_b64.encode("ascii"), validate=True)
|
||||
|
||||
@property
|
||||
def server_nonce_bytes(self) -> bytes:
|
||||
return base64.b64decode(self.server_nonce_b64.encode("ascii"), validate=True)
|
||||
103 Data/Engine/domain/github.py Normal file
@@ -0,0 +1,103 @@
"""Domain types for GitHub integrations."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True, slots=True)
class GitHubRepoRef:
    """Identify a GitHub repository and branch."""

    owner: str
    name: str
    branch: str

    @property
    def full_name(self) -> str:
        return f"{self.owner}/{self.name}".strip("/")

    @classmethod
    def parse(cls, owner_repo: str, branch: str) -> "GitHubRepoRef":
        owner_repo = (owner_repo or "").strip()
        if "/" not in owner_repo:
            raise ValueError("repo must be in the form owner/name")
        owner, repo = owner_repo.split("/", 1)
        return cls(owner=owner.strip(), name=repo.strip(), branch=(branch or "main").strip())


@dataclass(frozen=True, slots=True)
class RepoHeadSnapshot:
    """Snapshot describing the current head of a repository branch."""

    repository: GitHubRepoRef
    sha: Optional[str]
    cached: bool
    age_seconds: Optional[float]
    source: str
    error: Optional[str]

    def to_dict(self) -> Dict[str, object]:
        return {
            "repo": self.repository.full_name,
            "branch": self.repository.branch,
            "sha": self.sha,
            "cached": self.cached,
            "age_seconds": self.age_seconds,
            "source": self.source,
            "error": self.error,
        }


@dataclass(frozen=True, slots=True)
class GitHubRateLimit:
    """Subset of rate limit details returned by the GitHub API."""

    limit: Optional[int]
    remaining: Optional[int]
    reset_epoch: Optional[int]
    used: Optional[int]

    def to_dict(self) -> Dict[str, Optional[int]]:
        return {
            "limit": self.limit,
            "remaining": self.remaining,
            "reset": self.reset_epoch,
            "used": self.used,
        }


@dataclass(frozen=True, slots=True)
class GitHubTokenStatus:
    """Describe the verification result for a GitHub access token."""

    has_token: bool
    valid: bool
    status: str
    message: str
    rate_limit: Optional[GitHubRateLimit]
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, object]:
        payload: Dict[str, object] = {
            "has_token": self.has_token,
            "valid": self.valid,
            "status": self.status,
            "message": self.message,
            "error": self.error,
        }
        if self.rate_limit is not None:
            payload["rate_limit"] = self.rate_limit.to_dict()
        else:
            payload["rate_limit"] = None
        return payload


__all__ = [
    "GitHubRateLimit",
    "GitHubRepoRef",
    "GitHubTokenStatus",
    "RepoHeadSnapshot",
]
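A brief usage sketch for the value objects above (illustrative only, not part of the commit); the SHA and timing values are made up:

# Illustrative sketch only.
repo = GitHubRepoRef.parse("bunny-lab-io/Borealis", "main")
assert repo.full_name == "bunny-lab-io/Borealis"

snapshot = RepoHeadSnapshot(
    repository=repo,
    sha="abc1234",        # hypothetical commit SHA
    cached=True,
    age_seconds=12.5,
    source="cache",
    error=None,
)
print(snapshot.to_dict()["branch"])  # -> "main"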
7 Data/Engine/integrations/__init__.py Normal file
@@ -0,0 +1,7 @@
"""External system adapters for the Borealis Engine."""

from __future__ import annotations

from .github.artifact_provider import GitHubArtifactProvider

__all__ = ["GitHubArtifactProvider"]
25 Data/Engine/integrations/crypto/__init__.py Normal file
@@ -0,0 +1,25 @@
"""Crypto integration helpers for the Engine."""

from __future__ import annotations

from .keys import (
    base64_from_spki_der,
    fingerprint_from_base64_spki,
    fingerprint_from_spki_der,
    generate_ed25519_keypair,
    normalize_base64,
    private_key_to_pem,
    public_key_to_pem,
    spki_der_from_base64,
)

__all__ = [
    "base64_from_spki_der",
    "fingerprint_from_base64_spki",
    "fingerprint_from_spki_der",
    "generate_ed25519_keypair",
    "normalize_base64",
    "private_key_to_pem",
    "public_key_to_pem",
    "spki_der_from_base64",
]
70 Data/Engine/integrations/crypto/keys.py Normal file
@@ -0,0 +1,70 @@
"""Key utilities mirrored from the legacy crypto helpers."""

from __future__ import annotations

import base64
import hashlib
import re
from typing import Tuple

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from cryptography.hazmat.primitives.serialization import load_der_public_key

__all__ = [
    "base64_from_spki_der",
    "fingerprint_from_base64_spki",
    "fingerprint_from_spki_der",
    "generate_ed25519_keypair",
    "normalize_base64",
    "private_key_to_pem",
    "public_key_to_pem",
    "spki_der_from_base64",
]


def generate_ed25519_keypair() -> Tuple[ed25519.Ed25519PrivateKey, bytes]:
    private_key = ed25519.Ed25519PrivateKey.generate()
    public_key = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return private_key, public_key


def normalize_base64(data: str) -> str:
    cleaned = re.sub(r"\s+", "", data or "")
    return cleaned.replace("-", "+").replace("_", "/")


def spki_der_from_base64(spki_b64: str) -> bytes:
    return base64.b64decode(normalize_base64(spki_b64), validate=True)


def base64_from_spki_der(spki_der: bytes) -> str:
    return base64.b64encode(spki_der).decode("ascii")


def fingerprint_from_spki_der(spki_der: bytes) -> str:
    digest = hashlib.sha256(spki_der).hexdigest()
    return digest.lower()


def fingerprint_from_base64_spki(spki_b64: str) -> str:
    return fingerprint_from_spki_der(spki_der_from_base64(spki_b64))


def private_key_to_pem(private_key: ed25519.Ed25519PrivateKey) -> bytes:
    return private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )


def public_key_to_pem(public_spki_der: bytes) -> bytes:
    public_key = load_der_public_key(public_spki_der)
    return public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
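A short round-trip sketch for these key helpers (illustrative only, not part of the commit), assuming the cryptography package is installed:

# Illustrative sketch only.
private_key, spki_der = generate_ed25519_keypair()

# Base64 round trip preserves the DER-encoded public key.
spki_b64 = base64_from_spki_der(spki_der)
assert spki_der_from_base64(spki_b64) == spki_der

# Both fingerprint helpers yield the same lowercase SHA-256 hex digest.
assert fingerprint_from_base64_spki(spki_b64) == fingerprint_from_spki_der(spki_der)

pem_private = private_key_to_pem(private_key)  # unencrypted PKCS#8 PEM
pem_public = public_key_to_pem(spki_der)       # SubjectPublicKeyInfo PEM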
8 Data/Engine/integrations/github/__init__.py Normal file
@@ -0,0 +1,8 @@
"""GitHub integration surface for the Borealis Engine."""

from __future__ import annotations

from .artifact_provider import GitHubArtifactProvider

__all__ = ["GitHubArtifactProvider"]
275 Data/Engine/integrations/github/artifact_provider.py Normal file
@@ -0,0 +1,275 @@
|
||||
"""GitHub REST API integration with caching support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from Data.Engine.domain.github import GitHubRepoRef, GitHubTokenStatus, RepoHeadSnapshot, GitHubRateLimit
|
||||
|
||||
try: # pragma: no cover - optional dependency guard
|
||||
import requests
|
||||
from requests import Response
|
||||
except Exception: # pragma: no cover - fallback when requests is unavailable
|
||||
requests = None # type: ignore[assignment]
|
||||
Response = object # type: ignore[misc,assignment]
|
||||
|
||||
__all__ = ["GitHubArtifactProvider"]
|
||||
|
||||
|
||||
class GitHubArtifactProvider:
|
||||
"""Resolve repository heads and token metadata from the GitHub API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cache_file: Path,
|
||||
default_repo: str,
|
||||
default_branch: str,
|
||||
refresh_interval: int,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._cache_file = cache_file
|
||||
self._default_repo = default_repo
|
||||
self._default_branch = default_branch
|
||||
self._refresh_interval = max(30, min(refresh_interval, 3600))
|
||||
self._log = logger or logging.getLogger("borealis.engine.integrations.github")
|
||||
self._token: Optional[str] = None
|
||||
self._cache_lock = threading.Lock()
|
||||
self._cache: Dict[str, Dict[str, float | str]] = {}
|
||||
self._worker: Optional[threading.Thread] = None
|
||||
self._hydrate_cache_from_disk()
|
||||
|
||||
def set_token(self, token: Optional[str]) -> None:
|
||||
self._token = (token or "").strip() or None
|
||||
|
||||
@property
|
||||
def default_repo(self) -> str:
|
||||
return self._default_repo
|
||||
|
||||
@property
|
||||
def default_branch(self) -> str:
|
||||
return self._default_branch
|
||||
|
||||
@property
|
||||
def refresh_interval(self) -> int:
|
||||
return self._refresh_interval
|
||||
|
||||
def fetch_repo_head(
|
||||
self,
|
||||
repo: GitHubRepoRef,
|
||||
*,
|
||||
ttl_seconds: int,
|
||||
force_refresh: bool = False,
|
||||
) -> RepoHeadSnapshot:
|
||||
key = f"{repo.full_name}:{repo.branch}"
|
||||
now = time.time()
|
||||
|
||||
cached_entry = None
|
||||
with self._cache_lock:
|
||||
cached_entry = self._cache.get(key, {}).copy()
|
||||
|
||||
cached_sha = (cached_entry.get("sha") if cached_entry else None) # type: ignore[assignment]
|
||||
cached_ts = cached_entry.get("timestamp") if cached_entry else None # type: ignore[assignment]
|
||||
cached_age = None
|
||||
if isinstance(cached_ts, (int, float)):
|
||||
cached_age = max(0.0, now - float(cached_ts))
|
||||
|
||||
ttl = max(30, min(ttl_seconds, 3600))
|
||||
if cached_sha and not force_refresh and cached_age is not None and cached_age < ttl:
|
||||
return RepoHeadSnapshot(
|
||||
repository=repo,
|
||||
sha=str(cached_sha),
|
||||
cached=True,
|
||||
age_seconds=cached_age,
|
||||
source="cache",
|
||||
error=None,
|
||||
)
|
||||
|
||||
if requests is None:
|
||||
return RepoHeadSnapshot(
|
||||
repository=repo,
|
||||
sha=str(cached_sha) if cached_sha else None,
|
||||
cached=bool(cached_sha),
|
||||
age_seconds=cached_age,
|
||||
source="unavailable",
|
||||
error="requests library not available",
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"User-Agent": "Borealis-Engine",
|
||||
}
|
||||
if self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
|
||||
url = f"https://api.github.com/repos/{repo.full_name}/branches/{repo.branch}"
|
||||
error: Optional[str] = None
|
||||
sha: Optional[str] = None
|
||||
|
||||
try:
|
||||
response: Response = requests.get(url, headers=headers, timeout=20)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
sha = (data.get("commit", {}).get("sha") or "").strip() # type: ignore[assignment]
|
||||
else:
|
||||
error = f"GitHub REST API repo head lookup failed: HTTP {response.status_code} {response.text[:200]}"
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
error = f"GitHub REST API repo head lookup raised: {exc}"
|
||||
|
||||
if sha:
|
||||
payload = {"sha": sha, "timestamp": now}
|
||||
with self._cache_lock:
|
||||
self._cache[key] = payload
|
||||
self._persist_cache()
|
||||
return RepoHeadSnapshot(
|
||||
repository=repo,
|
||||
sha=sha,
|
||||
cached=False,
|
||||
age_seconds=0.0,
|
||||
source="github",
|
||||
error=None,
|
||||
)
|
||||
|
||||
if error:
|
||||
self._log.warning("repo-head-lookup failure repo=%s branch=%s error=%s", repo.full_name, repo.branch, error)
|
||||
|
||||
return RepoHeadSnapshot(
|
||||
repository=repo,
|
||||
sha=str(cached_sha) if cached_sha else None,
|
||||
cached=bool(cached_sha),
|
||||
age_seconds=cached_age,
|
||||
source="cache-stale" if cached_sha else "github",
|
||||
error=error or ("using cached value" if cached_sha else "unable to resolve repository head"),
|
||||
)
|
||||
|
||||
def refresh_default_repo_head(self, *, force: bool = False) -> RepoHeadSnapshot:
|
||||
repo = GitHubRepoRef.parse(self._default_repo, self._default_branch)
|
||||
return self.fetch_repo_head(repo, ttl_seconds=self._refresh_interval, force_refresh=force)
|
||||
|
||||
def verify_token(self, token: Optional[str]) -> GitHubTokenStatus:
|
||||
token = (token or "").strip()
|
||||
if not token:
|
||||
return GitHubTokenStatus(
|
||||
has_token=False,
|
||||
valid=False,
|
||||
status="missing",
|
||||
message="API Token Not Configured",
|
||||
rate_limit=None,
|
||||
error=None,
|
||||
)
|
||||
|
||||
if requests is None:
|
||||
return GitHubTokenStatus(
|
||||
has_token=True,
|
||||
valid=False,
|
||||
status="unknown",
|
||||
message="requests library not available",
|
||||
rate_limit=None,
|
||||
error="requests library not available",
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"Authorization": f"Bearer {token}",
|
||||
"User-Agent": "Borealis-Engine",
|
||||
}
|
||||
try:
|
||||
response: Response = requests.get("https://api.github.com/rate_limit", headers=headers, timeout=10)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
message = f"GitHub token verification raised: {exc}"
|
||||
self._log.warning("github-token-verify error=%s", message)
|
||||
return GitHubTokenStatus(
|
||||
has_token=True,
|
||||
valid=False,
|
||||
status="error",
|
||||
message="API Token Invalid",
|
||||
rate_limit=None,
|
||||
error=message,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
message = f"GitHub API error (HTTP {response.status_code})"
|
||||
self._log.warning("github-token-verify http_status=%s", response.status_code)
|
||||
return GitHubTokenStatus(
|
||||
has_token=True,
|
||||
valid=False,
|
||||
status="error",
|
||||
message="API Token Invalid",
|
||||
rate_limit=None,
|
||||
error=message,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
core = (data.get("resources", {}).get("core", {}) if isinstance(data, dict) else {})
|
||||
rate_limit = GitHubRateLimit(
|
||||
limit=_safe_int(core.get("limit")),
|
||||
remaining=_safe_int(core.get("remaining")),
|
||||
reset_epoch=_safe_int(core.get("reset")),
|
||||
used=_safe_int(core.get("used")),
|
||||
)
|
||||
|
||||
message = "API Token Valid" if rate_limit.remaining is not None else "API Token Verified"
|
||||
return GitHubTokenStatus(
|
||||
has_token=True,
|
||||
valid=True,
|
||||
status="valid",
|
||||
message=message,
|
||||
rate_limit=rate_limit,
|
||||
error=None,
|
||||
)
|
||||
|
||||
def start_background_refresh(self) -> None:
|
||||
if self._worker and self._worker.is_alive(): # pragma: no cover - guard
|
||||
return
|
||||
|
||||
def _loop() -> None:
|
||||
interval = max(30, self._refresh_interval)
|
||||
while True:
|
||||
try:
|
||||
self.refresh_default_repo_head(force=True)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("default-repo-refresh failure: %s", exc)
|
||||
time.sleep(interval)
|
||||
|
||||
self._worker = threading.Thread(target=_loop, name="github-repo-refresh", daemon=True)
|
||||
self._worker.start()
|
||||
|
||||
def _hydrate_cache_from_disk(self) -> None:
|
||||
path = self._cache_file
|
||||
try:
|
||||
if not path.exists():
|
||||
return
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, dict):
|
||||
with self._cache_lock:
|
||||
self._cache = {
|
||||
key: value
|
||||
for key, value in data.items()
|
||||
if isinstance(value, dict) and "sha" in value and "timestamp" in value
|
||||
}
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("failed to load repo cache: %s", exc)
|
||||
|
||||
def _persist_cache(self) -> None:
|
||||
path = self._cache_file
|
||||
try:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = json.dumps(self._cache, ensure_ascii=False)
|
||||
tmp = path.with_suffix(".tmp")
|
||||
tmp.write_text(payload, encoding="utf-8")
|
||||
tmp.replace(path)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("failed to persist repo cache: %s", exc)
|
||||
|
||||
|
||||
def _safe_int(value: object) -> Optional[int]:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
12 Data/Engine/interfaces/__init__.py Normal file
@@ -0,0 +1,12 @@
"""Interface adapters (HTTP, WebSocket, etc.) for the Borealis Engine."""

from __future__ import annotations

from .http import register_http_interfaces
from .ws import create_socket_server, register_ws_interfaces

__all__ = [
    "register_http_interfaces",
    "create_socket_server",
    "register_ws_interfaces",
]
75 Data/Engine/interfaces/eventlet_compat.py Normal file
@@ -0,0 +1,75 @@
|
||||
"""Compatibility helpers for running Socket.IO under eventlet."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ssl
|
||||
from typing import Any
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import eventlet # type: ignore
|
||||
except Exception: # pragma: no cover - optional dependency
|
||||
eventlet = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def _quiet_close(connection: Any) -> None:
|
||||
try:
|
||||
if hasattr(connection, "close"):
|
||||
connection.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _close_connection_quietly(protocol: Any) -> None:
|
||||
try:
|
||||
setattr(protocol, "close_connection", True)
|
||||
except Exception:
|
||||
pass
|
||||
conn = getattr(protocol, "socket", None) or getattr(protocol, "connection", None)
|
||||
if conn is not None:
|
||||
_quiet_close(conn)
|
||||
|
||||
|
||||
def apply_eventlet_patches() -> None:
|
||||
"""Apply Borealis-specific eventlet tweaks when the dependency is available."""
|
||||
|
||||
if eventlet is None: # pragma: no cover - guard for environments without eventlet
|
||||
return
|
||||
|
||||
eventlet.monkey_patch(thread=False)
|
||||
|
||||
try:
|
||||
from eventlet.wsgi import HttpProtocol # type: ignore
|
||||
except Exception: # pragma: no cover - import guard
|
||||
return
|
||||
|
||||
original = HttpProtocol.handle_one_request # type: ignore[attr-defined]
|
||||
|
||||
def _handle_one_request(self: Any, *args: Any, **kwargs: Any) -> Any: # type: ignore[override]
|
||||
try:
|
||||
return original(self, *args, **kwargs)
|
||||
except ssl.SSLError as exc: # type: ignore[arg-type]
|
||||
reason = getattr(exc, "reason", "") or ""
|
||||
message = " ".join(str(arg) for arg in exc.args if arg)
|
||||
lower_reason = str(reason).lower()
|
||||
lower_message = message.lower()
|
||||
if (
|
||||
"http_request" in lower_message
|
||||
or lower_reason == "http request"
|
||||
or "unknown ca" in lower_message
|
||||
or lower_reason == "unknown ca"
|
||||
or "unknown_ca" in lower_message
|
||||
):
|
||||
_close_connection_quietly(self)
|
||||
return None
|
||||
raise
|
||||
except ssl.SSLEOFError:
|
||||
_close_connection_quietly(self)
|
||||
return None
|
||||
except ConnectionAbortedError:
|
||||
_close_connection_quietly(self)
|
||||
return None
|
||||
|
||||
HttpProtocol.handle_one_request = _handle_one_request # type: ignore[assignment]
|
||||
|
||||
|
||||
__all__ = ["apply_eventlet_patches"]
|
||||
32 Data/Engine/interfaces/http/__init__.py Normal file
@@ -0,0 +1,32 @@
"""HTTP interface registration for the Borealis Engine."""

from __future__ import annotations

from flask import Flask

from Data.Engine.services.container import EngineServiceContainer

from . import admin, agents, enrollment, github, health, job_management, tokens

_REGISTRARS = (
    health.register,
    agents.register,
    enrollment.register,
    tokens.register,
    job_management.register,
    github.register,
    admin.register,
)


def register_http_interfaces(app: Flask, services: EngineServiceContainer) -> None:
    """Attach HTTP blueprints to *app*.

    The implementation is intentionally minimal for the initial scaffolding.
    """

    for registrar in _REGISTRARS:
        registrar(app, services)


__all__ = ["register_http_interfaces"]
23
Data/Engine/interfaces/http/admin.py
Normal file
23
Data/Engine/interfaces/http/admin.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Administrative HTTP interface placeholders for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
|
||||
blueprint = Blueprint("engine_admin", __name__, url_prefix="/api/admin")
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach administrative routes to *app*.
|
||||
|
||||
Concrete endpoints will be migrated in subsequent phases.
|
||||
"""
|
||||
|
||||
if "engine_admin" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
23
Data/Engine/interfaces/http/agents.py
Normal file
23
Data/Engine/interfaces/http/agents.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Agent HTTP interface placeholders for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
|
||||
blueprint = Blueprint("engine_agents", __name__, url_prefix="/api/agents")
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach agent management routes to *app*.
|
||||
|
||||
Implementation will be populated as services migrate from the legacy server.
|
||||
"""
|
||||
|
||||
if "engine_agents" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
111
Data/Engine/interfaces/http/enrollment.py
Normal file
111
Data/Engine/interfaces/http/enrollment.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Enrollment HTTP interface placeholders for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request
|
||||
|
||||
from Data.Engine.builders.device_enrollment import EnrollmentRequestBuilder
|
||||
from Data.Engine.services import EnrollmentValidationError
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||
|
||||
|
||||
blueprint = Blueprint("engine_enrollment", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach enrollment routes to *app*."""
|
||||
|
||||
if "engine_enrollment" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
|
||||
def enrollment_request() -> object:
|
||||
services: EngineServiceContainer = current_app.extensions["engine_services"]
|
||||
payload = request.get_json(force=True, silent=True)
|
||||
builder = EnrollmentRequestBuilder().with_payload(payload).with_service_context(
|
||||
request.headers.get(AGENT_CONTEXT_HEADER)
|
||||
)
|
||||
try:
|
||||
normalized = builder.build()
|
||||
result = services.enrollment_service.request_enrollment(
|
||||
normalized,
|
||||
remote_addr=_remote_addr(),
|
||||
)
|
||||
except EnrollmentValidationError as exc:
|
||||
response = jsonify(exc.to_response())
|
||||
response.status_code = exc.http_status
|
||||
if exc.retry_after is not None:
|
||||
response.headers["Retry-After"] = f"{int(exc.retry_after)}"
|
||||
return response
|
||||
|
||||
response_payload = {
|
||||
"status": result.status,
|
||||
"approval_reference": result.approval_reference,
|
||||
"server_nonce": result.server_nonce,
|
||||
"poll_after_ms": result.poll_after_ms,
|
||||
"server_certificate": result.server_certificate,
|
||||
"signing_key": result.signing_key,
|
||||
}
|
||||
response = jsonify(response_payload)
|
||||
response.status_code = result.http_status
|
||||
if result.retry_after is not None:
|
||||
response.headers["Retry-After"] = f"{int(result.retry_after)}"
|
||||
return response
|
||||
|
||||
|
||||
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
|
||||
def enrollment_poll() -> object:
|
||||
services: EngineServiceContainer = current_app.extensions["engine_services"]
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
approval_reference = str(payload.get("approval_reference") or "").strip()
|
||||
client_nonce = str(payload.get("client_nonce") or "").strip()
|
||||
proof_sig = str(payload.get("proof_sig") or "").strip()
|
||||
|
||||
try:
|
||||
result = services.enrollment_service.poll_enrollment(
|
||||
approval_reference=approval_reference,
|
||||
client_nonce_b64=client_nonce,
|
||||
proof_signature_b64=proof_sig,
|
||||
)
|
||||
except EnrollmentValidationError as exc:
|
||||
return jsonify(exc.to_response()), exc.http_status
|
||||
|
||||
body = {"status": result.status}
|
||||
if result.poll_after_ms is not None:
|
||||
body["poll_after_ms"] = result.poll_after_ms
|
||||
if result.reason:
|
||||
body["reason"] = result.reason
|
||||
if result.detail:
|
||||
body["detail"] = result.detail
|
||||
if result.tokens:
|
||||
body.update(
|
||||
{
|
||||
"guid": result.tokens.guid.value,
|
||||
"access_token": result.tokens.access_token,
|
||||
"refresh_token": result.tokens.refresh_token,
|
||||
"token_type": result.tokens.token_type,
|
||||
"expires_in": result.tokens.expires_in,
|
||||
"server_certificate": result.server_certificate or "",
|
||||
"signing_key": result.signing_key or "",
|
||||
}
|
||||
)
|
||||
else:
|
||||
if result.server_certificate:
|
||||
body["server_certificate"] = result.server_certificate
|
||||
if result.signing_key:
|
||||
body["signing_key"] = result.signing_key
|
||||
|
||||
return jsonify(body), result.http_status
|
||||
|
||||
|
||||
def _remote_addr() -> str:
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return (request.remote_addr or "unknown").strip()
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint", "enrollment_request", "enrollment_poll"]
|
||||
60
Data/Engine/interfaces/http/github.py
Normal file
60
Data/Engine/interfaces/http/github.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""GitHub-related HTTP endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
blueprint = Blueprint("engine_github", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
if "engine_github" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@blueprint.route("/api/repo/current_hash", methods=["GET"])
|
||||
def repo_current_hash() -> object:
|
||||
services: EngineServiceContainer = current_app.extensions["engine_services"]
|
||||
github = services.github_service
|
||||
|
||||
repo = (request.args.get("repo") or "").strip() or None
|
||||
branch = (request.args.get("branch") or "").strip() or None
|
||||
refresh_flag = (request.args.get("refresh") or "").strip().lower()
|
||||
ttl_raw = request.args.get("ttl")
|
||||
try:
|
||||
ttl = int(ttl_raw) if ttl_raw else github.default_refresh_interval
|
||||
except ValueError:
|
||||
ttl = github.default_refresh_interval
|
||||
force_refresh = refresh_flag in {"1", "true", "yes", "force", "refresh"}
|
||||
|
||||
snapshot = github.get_repo_head(repo, branch, ttl_seconds=ttl, force_refresh=force_refresh)
|
||||
payload = snapshot.to_dict()
|
||||
if not snapshot.sha:
|
||||
return jsonify(payload), 503
|
||||
return jsonify(payload)
|
||||
|
||||
|
||||
@blueprint.route("/api/github/token", methods=["GET", "POST"])
|
||||
def github_token() -> object:
|
||||
services: EngineServiceContainer = current_app.extensions["engine_services"]
|
||||
github = services.github_service
|
||||
|
||||
if request.method == "GET":
|
||||
payload = github.get_token_status(force_refresh=True).to_dict()
|
||||
return jsonify(payload)
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
token = data.get("token")
|
||||
normalized = str(token).strip() if token is not None else ""
|
||||
try:
|
||||
payload = github.update_token(normalized).to_dict()
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
current_app.logger.exception("failed to store GitHub token: %s", exc)
|
||||
return jsonify({"error": f"Failed to store token: {exc}"}), 500
|
||||
return jsonify(payload)
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint", "repo_current_hash", "github_token"]
|
||||
|
||||
26
Data/Engine/interfaces/http/health.py
Normal file
26
Data/Engine/interfaces/http/health.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Health check HTTP interface for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask, jsonify
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
blueprint = Blueprint("engine_health", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach health-related routes to *app*."""
|
||||
|
||||
if "engine_health" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@blueprint.route("/health", methods=["GET"])
|
||||
def health() -> object:
|
||||
"""Return a basic liveness response."""
|
||||
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
108
Data/Engine/interfaces/http/job_management.py
Normal file
108
Data/Engine/interfaces/http/job_management.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""HTTP routes for Engine job management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Blueprint, Flask, jsonify, request
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
bp = Blueprint("engine_job_management", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, services: EngineServiceContainer) -> None:
|
||||
bp.services = services # type: ignore[attr-defined]
|
||||
app.register_blueprint(bp)
|
||||
|
||||
|
||||
def _services() -> EngineServiceContainer:
|
||||
svc = getattr(bp, "services", None)
|
||||
if svc is None: # pragma: no cover - guard
|
||||
raise RuntimeError("job management blueprint not initialized")
|
||||
return svc
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs", methods=["GET"])
|
||||
def list_jobs() -> Any:
|
||||
jobs = _services().scheduler_service.list_jobs()
|
||||
return jsonify({"jobs": jobs})
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs", methods=["POST"])
|
||||
def create_job() -> Any:
|
||||
payload = _json_body()
|
||||
try:
|
||||
job = _services().scheduler_service.create_job(payload)
|
||||
except ValueError as exc:
|
||||
return jsonify({"error": str(exc)}), 400
|
||||
return jsonify({"job": job})
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs/<int:job_id>", methods=["GET"])
|
||||
def get_job(job_id: int) -> Any:
|
||||
job = _services().scheduler_service.get_job(job_id)
|
||||
if not job:
|
||||
return jsonify({"error": "job not found"}), 404
|
||||
return jsonify({"job": job})
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs/<int:job_id>", methods=["PUT"])
|
||||
def update_job(job_id: int) -> Any:
|
||||
payload = _json_body()
|
||||
try:
|
||||
job = _services().scheduler_service.update_job(job_id, payload)
|
||||
except ValueError as exc:
|
||||
return jsonify({"error": str(exc)}), 400
|
||||
if not job:
|
||||
return jsonify({"error": "job not found"}), 404
|
||||
return jsonify({"job": job})
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs/<int:job_id>", methods=["DELETE"])
|
||||
def delete_job(job_id: int) -> Any:
|
||||
_services().scheduler_service.delete_job(job_id)
|
||||
return ("", 204)
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs/<int:job_id>/toggle", methods=["POST"])
|
||||
def toggle_job(job_id: int) -> Any:
|
||||
payload = _json_body()
|
||||
enabled = bool(payload.get("enabled", True))
|
||||
_services().scheduler_service.toggle_job(job_id, enabled)
|
||||
job = _services().scheduler_service.get_job(job_id)
|
||||
if not job:
|
||||
return jsonify({"error": "job not found"}), 404
|
||||
return jsonify({"job": job})
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs/<int:job_id>/runs", methods=["GET"])
|
||||
def list_runs(job_id: int) -> Any:
|
||||
days = request.args.get("days")
|
||||
days_int: Optional[int] = None
|
||||
if days is not None:
|
||||
try:
|
||||
days_int = max(0, int(days))
|
||||
except Exception:
|
||||
return jsonify({"error": "invalid days parameter"}), 400
|
||||
runs = _services().scheduler_service.list_runs(job_id, days=days_int)
|
||||
return jsonify({"runs": runs})
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs/<int:job_id>/runs", methods=["DELETE"])
|
||||
def purge_runs(job_id: int) -> Any:
|
||||
_services().scheduler_service.purge_runs(job_id)
|
||||
return ("", 204)
|
||||
|
||||
|
||||
def _json_body() -> dict[str, Any]:
|
||||
if not request.data:
|
||||
return {}
|
||||
try:
|
||||
data = request.get_json(force=True, silent=False) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
__all__ = ["register"]
|
||||
52
Data/Engine/interfaces/http/tokens.py
Normal file
52
Data/Engine/interfaces/http/tokens.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Token management HTTP interface for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request
|
||||
|
||||
from Data.Engine.builders.device_auth import RefreshTokenRequestBuilder
|
||||
from Data.Engine.domain.device_auth import DeviceAuthFailure
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
from Data.Engine.services import TokenRefreshError
|
||||
|
||||
blueprint = Blueprint("engine_tokens", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach token management routes to *app*."""
|
||||
|
||||
if "engine_tokens" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@blueprint.route("/api/agent/token/refresh", methods=["POST"])
|
||||
def refresh_token() -> object:
|
||||
services: EngineServiceContainer = current_app.extensions["engine_services"]
|
||||
builder = (
|
||||
RefreshTokenRequestBuilder()
|
||||
.with_payload(request.get_json(force=True, silent=True))
|
||||
.with_http_method(request.method)
|
||||
.with_htu(request.url)
|
||||
.with_dpop_proof(request.headers.get("DPoP"))
|
||||
)
|
||||
try:
|
||||
refresh_request = builder.build()
|
||||
except DeviceAuthFailure as exc:
|
||||
payload = exc.to_dict()
|
||||
return jsonify(payload), exc.http_status
|
||||
|
||||
try:
|
||||
response = services.token_service.refresh_access_token(refresh_request)
|
||||
except TokenRefreshError as exc:
|
||||
return jsonify(exc.to_dict()), exc.http_status
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"access_token": response.access_token,
|
||||
"expires_in": response.expires_in,
|
||||
"token_type": response.token_type,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint", "refresh_token"]
|
||||
47
Data/Engine/interfaces/ws/__init__.py
Normal file
47
Data/Engine/interfaces/ws/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""WebSocket interface factory for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from ...config import SocketIOSettings
|
||||
from ...services.container import EngineServiceContainer
|
||||
from .agents import register as register_agent_events
|
||||
from .job_management import register as register_job_events
|
||||
|
||||
try: # pragma: no cover - import guard
|
||||
from flask_socketio import SocketIO
|
||||
except Exception: # pragma: no cover - optional dependency
|
||||
SocketIO = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def create_socket_server(app: Flask, settings: SocketIOSettings) -> Optional[SocketIO]:
|
||||
"""Create a Socket.IO server bound to *app* if dependencies are available."""
|
||||
|
||||
if SocketIO is None:
|
||||
return None
|
||||
|
||||
cors_allowed = settings.cors_allowed_origins or ("*",)
|
||||
socketio = SocketIO(
|
||||
app,
|
||||
cors_allowed_origins=cors_allowed,
|
||||
async_mode=None,
|
||||
logger=False,
|
||||
engineio_logger=False,
|
||||
)
|
||||
return socketio
|
||||
|
||||
|
||||
def register_ws_interfaces(socketio: Any, services: EngineServiceContainer) -> None:
|
||||
"""Attach namespaces for the Engine Socket.IO server."""
|
||||
|
||||
if socketio is None: # pragma: no cover - guard
|
||||
return
|
||||
|
||||
for registrar in (register_agent_events, register_job_events):
|
||||
registrar(socketio, services)
|
||||
|
||||
|
||||
__all__ = ["create_socket_server", "register_ws_interfaces"]
|
||||
16
Data/Engine/interfaces/ws/agents/__init__.py
Normal file
16
Data/Engine/interfaces/ws/agents/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Agent WebSocket namespace wiring for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from . import events
|
||||
|
||||
|
||||
def register(socketio: Any, services) -> None:
|
||||
"""Register agent namespaces on the given Socket.IO *socketio* instance."""
|
||||
|
||||
events.register(socketio, services)
|
||||
|
||||
|
||||
__all__ = ["register"]
|
||||
261
Data/Engine/interfaces/ws/agents/events.py
Normal file
261
Data/Engine/interfaces/ws/agents/events.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""Agent WebSocket event handlers for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
from flask import request
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
try: # pragma: no cover - optional dependency guard
|
||||
from flask_socketio import emit, join_room
|
||||
except Exception: # pragma: no cover - optional dependency guard
|
||||
emit = None # type: ignore[assignment]
|
||||
join_room = None # type: ignore[assignment]
|
||||
|
||||
_AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||
|
||||
|
||||
def register(socketio: Any, services: EngineServiceContainer) -> None:
|
||||
if socketio is None: # pragma: no cover - guard
|
||||
return
|
||||
|
||||
handlers = _AgentEventHandlers(socketio, services)
|
||||
socketio.on_event("connect", handlers.on_connect)
|
||||
socketio.on_event("disconnect", handlers.on_disconnect)
|
||||
socketio.on_event("agent_screenshot_task", handlers.on_agent_screenshot_task)
|
||||
socketio.on_event("connect_agent", handlers.on_connect_agent)
|
||||
socketio.on_event("agent_heartbeat", handlers.on_agent_heartbeat)
|
||||
socketio.on_event("collector_status", handlers.on_collector_status)
|
||||
socketio.on_event("request_config", handlers.on_request_config)
|
||||
socketio.on_event("screenshot", handlers.on_screenshot)
|
||||
socketio.on_event("macro_status", handlers.on_macro_status)
|
||||
socketio.on_event("list_agent_windows", handlers.on_list_agent_windows)
|
||||
socketio.on_event("agent_window_list", handlers.on_agent_window_list)
|
||||
socketio.on_event("ansible_playbook_cancel", handlers.on_ansible_playbook_cancel)
|
||||
socketio.on_event("ansible_playbook_run", handlers.on_ansible_playbook_run)
|
||||
|
||||
|
||||
class _AgentEventHandlers:
|
||||
def __init__(self, socketio: Any, services: EngineServiceContainer) -> None:
|
||||
self._socketio = socketio
|
||||
self._services = services
|
||||
self._realtime = services.agent_realtime
|
||||
self._log = logging.getLogger("borealis.engine.ws.agents")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connection lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
def on_connect(self) -> None:
|
||||
sid = getattr(request, "sid", "<unknown>")
|
||||
remote_addr = getattr(request, "remote_addr", None)
|
||||
transport = None
|
||||
try:
|
||||
transport = request.args.get("transport") # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
transport = None
|
||||
query = self._render_query()
|
||||
headers = _summarize_socket_headers(getattr(request, "headers", {}))
|
||||
scope = _canonical_scope(getattr(request.headers, "get", lambda *_: None)(_AGENT_CONTEXT_HEADER))
|
||||
self._log.info(
|
||||
"socket-connect sid=%s ip=%s transport=%r query=%s headers=%s scope=%s",
|
||||
sid,
|
||||
remote_addr,
|
||||
transport,
|
||||
query,
|
||||
headers,
|
||||
scope or "<none>",
|
||||
)
|
||||
|
||||
def on_disconnect(self) -> None:
|
||||
sid = getattr(request, "sid", "<unknown>")
|
||||
remote_addr = getattr(request, "remote_addr", None)
|
||||
self._log.info("socket-disconnect sid=%s ip=%s", sid, remote_addr)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Agent coordination
|
||||
# ------------------------------------------------------------------
|
||||
def on_agent_screenshot_task(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
agent_id = payload.get("agent_id")
|
||||
node_id = payload.get("node_id")
|
||||
image = payload.get("image_base64", "")
|
||||
|
||||
if not agent_id or not node_id:
|
||||
self._log.warning("screenshot-task missing identifiers: %s", payload)
|
||||
return
|
||||
|
||||
if image:
|
||||
self._realtime.store_task_screenshot(agent_id, node_id, image)
|
||||
|
||||
try:
|
||||
self._socketio.emit("agent_screenshot_task", payload)
|
||||
except Exception as exc: # pragma: no cover - network guard
|
||||
self._log.warning("socket emit failed for agent_screenshot_task: %s", exc)
|
||||
|
||||
def on_connect_agent(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
agent_id = payload.get("agent_id")
|
||||
if not agent_id:
|
||||
return
|
||||
|
||||
service_mode = payload.get("service_mode")
|
||||
record = self._realtime.register_connection(agent_id, service_mode)
|
||||
|
||||
if join_room is not None: # pragma: no branch - optional dependency guard
|
||||
try:
|
||||
join_room(agent_id)
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.debug("join_room failed for %s: %s", agent_id, exc)
|
||||
|
||||
self._log.info(
|
||||
"agent-connected agent_id=%s mode=%s status=%s",
|
||||
agent_id,
|
||||
record.service_mode,
|
||||
record.status,
|
||||
)
|
||||
|
||||
def on_agent_heartbeat(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
record = self._realtime.heartbeat(payload)
|
||||
if record:
|
||||
self._log.debug(
|
||||
"agent-heartbeat agent_id=%s host=%s mode=%s", record.agent_id, record.hostname, record.service_mode
|
||||
)
|
||||
|
||||
def on_collector_status(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
self._realtime.collector_status(payload)
|
||||
|
||||
def on_request_config(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
agent_id = payload.get("agent_id")
|
||||
if not agent_id:
|
||||
return
|
||||
config = self._realtime.get_agent_config(agent_id)
|
||||
if config and emit is not None:
|
||||
try:
|
||||
emit("agent_config", {**config, "agent_id": agent_id})
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.debug("emit(agent_config) failed for %s: %s", agent_id, exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Media + relay events
|
||||
# ------------------------------------------------------------------
|
||||
def on_screenshot(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
agent_id = payload.get("agent_id")
|
||||
image = payload.get("image_base64")
|
||||
if agent_id and image:
|
||||
self._realtime.store_agent_screenshot(agent_id, image)
|
||||
try:
|
||||
self._socketio.emit("new_screenshot", {"agent_id": agent_id, "image_base64": image})
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.warning("socket emit failed for new_screenshot: %s", exc)
|
||||
|
||||
def on_macro_status(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
agent_id = payload.get("agent_id")
|
||||
node_id = payload.get("node_id")
|
||||
success = payload.get("success")
|
||||
message = payload.get("message")
|
||||
self._log.info(
|
||||
"macro-status agent=%s node=%s success=%s message=%s",
|
||||
agent_id,
|
||||
node_id,
|
||||
success,
|
||||
message,
|
||||
)
|
||||
try:
|
||||
self._socketio.emit("macro_status", payload)
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.warning("socket emit failed for macro_status: %s", exc)
|
||||
|
||||
def on_list_agent_windows(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
try:
|
||||
self._socketio.emit("list_agent_windows", payload)
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.warning("socket emit failed for list_agent_windows: %s", exc)
|
||||
|
||||
def on_agent_window_list(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
try:
|
||||
self._socketio.emit("agent_window_list", payload)
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.warning("socket emit failed for agent_window_list: %s", exc)
|
||||
|
||||
def on_ansible_playbook_cancel(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
try:
|
||||
self._socketio.emit("ansible_playbook_cancel", data or {})
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.warning("socket emit failed for ansible_playbook_cancel: %s", exc)
|
||||
|
||||
def on_ansible_playbook_run(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
try:
|
||||
self._socketio.emit("ansible_playbook_run", data or {})
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.warning("socket emit failed for ansible_playbook_run: %s", exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _render_query(self) -> str:
|
||||
try:
|
||||
pairs = [f"{k}={v}" for k, v in request.args.items()] # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return "<unavailable>"
|
||||
return "&".join(pairs) if pairs else "<none>"
|
||||
|
||||
|
||||
def _canonical_scope(raw: Optional[str]) -> Optional[str]:
|
||||
if not raw:
|
||||
return None
|
||||
value = "".join(ch for ch in str(raw) if ch.isalnum() or ch in ("_", "-"))
|
||||
if not value:
|
||||
return None
|
||||
return value.upper()
|
||||
|
||||
|
||||
def _mask_value(value: str, *, prefix: int = 4, suffix: int = 4) -> str:
|
||||
try:
|
||||
if not value:
|
||||
return ""
|
||||
stripped = value.strip()
|
||||
if len(stripped) <= prefix + suffix:
|
||||
return "*" * len(stripped)
|
||||
return f"{stripped[:prefix]}***{stripped[-suffix:]}"
|
||||
except Exception:
|
||||
return "***"
|
||||
|
||||
|
||||
def _summarize_socket_headers(headers: Any) -> str:
|
||||
try:
|
||||
items: Iterable[tuple[str, Any]]
|
||||
if isinstance(headers, dict):
|
||||
items = headers.items()
|
||||
else:
|
||||
items = getattr(headers, "items", lambda: [])()
|
||||
except Exception:
|
||||
items = []
|
||||
|
||||
rendered = []
|
||||
for key, value in items:
|
||||
lowered = str(key).lower()
|
||||
display = value
|
||||
if lowered == "authorization":
|
||||
token = str(value or "")
|
||||
if token.lower().startswith("bearer "):
|
||||
display = f"Bearer {_mask_value(token.split(' ', 1)[1])}"
|
||||
else:
|
||||
display = _mask_value(token)
|
||||
elif lowered == "cookie":
|
||||
display = "<redacted>"
|
||||
rendered.append(f"{key}={display}")
|
||||
return ", ".join(rendered) if rendered else "<no-headers>"
|
||||
|
||||
|
||||
__all__ = ["register"]
|
||||
16
Data/Engine/interfaces/ws/job_management/__init__.py
Normal file
16
Data/Engine/interfaces/ws/job_management/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Job management WebSocket namespace wiring for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from . import events
|
||||
|
||||
|
||||
def register(socketio: Any, services) -> None:
|
||||
"""Register job management namespaces on the given Socket.IO *socketio*."""
|
||||
|
||||
events.register(socketio, services)
|
||||
|
||||
|
||||
__all__ = ["register"]
|
||||
38
Data/Engine/interfaces/ws/job_management/events.py
Normal file
38
Data/Engine/interfaces/ws/job_management/events.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Job management WebSocket event handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
|
||||
def register(socketio: Any, services: EngineServiceContainer) -> None:
|
||||
if socketio is None: # pragma: no cover - guard
|
||||
return
|
||||
|
||||
handlers = _JobEventHandlers(socketio, services)
|
||||
socketio.on_event("quick_job_result", handlers.on_quick_job_result)
|
||||
socketio.on_event("job_status_request", handlers.on_job_status_request)
|
||||
|
||||
|
||||
class _JobEventHandlers:
|
||||
def __init__(self, socketio: Any, services: EngineServiceContainer) -> None:
|
||||
self._socketio = socketio
|
||||
self._services = services
|
||||
self._log = logging.getLogger("borealis.engine.ws.jobs")
|
||||
|
||||
def on_quick_job_result(self, data: Optional[dict]) -> None:
|
||||
self._log.info("quick-job-result received; scheduler migration pending")
|
||||
# Step 10 will introduce full persistence + broadcast logic.
|
||||
|
||||
def on_job_status_request(self, _: Optional[dict]) -> None:
|
||||
jobs = self._services.scheduler_service.list_jobs()
|
||||
try:
|
||||
self._socketio.emit("job_status", {"jobs": jobs})
|
||||
except Exception:
|
||||
self._log.debug("job-status emit failed")
|
||||
|
||||
|
||||
__all__ = ["register"]
|
||||
7
Data/Engine/repositories/__init__.py
Normal file
7
Data/Engine/repositories/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Persistence adapters for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from . import sqlite
|
||||
|
||||
__all__ = ["sqlite"]
|
||||
47
Data/Engine/repositories/sqlite/__init__.py
Normal file
47
Data/Engine/repositories/sqlite/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""SQLite persistence helpers for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .connection import (
|
||||
SQLiteConnectionFactory,
|
||||
configure_connection,
|
||||
connect,
|
||||
connection_factory,
|
||||
connection_scope,
|
||||
)
|
||||
from .migrations import apply_all
|
||||
|
||||
__all__ = [
|
||||
"SQLiteConnectionFactory",
|
||||
"configure_connection",
|
||||
"connect",
|
||||
"connection_factory",
|
||||
"connection_scope",
|
||||
"apply_all",
|
||||
]
|
||||
|
||||
try: # pragma: no cover - optional dependency shim
|
||||
from .device_repository import SQLiteDeviceRepository
|
||||
from .enrollment_repository import SQLiteEnrollmentRepository
|
||||
from .github_repository import SQLiteGitHubRepository
|
||||
from .job_repository import SQLiteJobRepository
|
||||
from .token_repository import SQLiteRefreshTokenRepository
|
||||
except ModuleNotFoundError as exc: # pragma: no cover - triggered when auth deps missing
|
||||
def _missing_repo(*_args: object, **_kwargs: object) -> None:
|
||||
raise ModuleNotFoundError(
|
||||
"Engine SQLite repositories require optional authentication dependencies"
|
||||
) from exc
|
||||
|
||||
SQLiteDeviceRepository = _missing_repo # type: ignore[assignment]
|
||||
SQLiteEnrollmentRepository = _missing_repo # type: ignore[assignment]
|
||||
SQLiteGitHubRepository = _missing_repo # type: ignore[assignment]
|
||||
SQLiteJobRepository = _missing_repo # type: ignore[assignment]
|
||||
SQLiteRefreshTokenRepository = _missing_repo # type: ignore[assignment]
|
||||
else:
|
||||
__all__ += [
|
||||
"SQLiteDeviceRepository",
|
||||
"SQLiteRefreshTokenRepository",
|
||||
"SQLiteJobRepository",
|
||||
"SQLiteEnrollmentRepository",
|
||||
"SQLiteGitHubRepository",
|
||||
]
|
||||
67
Data/Engine/repositories/sqlite/connection.py
Normal file
67
Data/Engine/repositories/sqlite/connection.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""SQLite connection utilities for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Protocol
|
||||
|
||||
__all__ = [
|
||||
"SQLiteConnectionFactory",
|
||||
"configure_connection",
|
||||
"connect",
|
||||
"connection_factory",
|
||||
"connection_scope",
|
||||
]
|
||||
|
||||
|
||||
class SQLiteConnectionFactory(Protocol):
|
||||
"""Callable protocol for obtaining configured SQLite connections."""
|
||||
|
||||
def __call__(self) -> sqlite3.Connection:
|
||||
"""Return a new :class:`sqlite3.Connection`."""
|
||||
|
||||
|
||||
def configure_connection(conn: sqlite3.Connection) -> None:
|
||||
"""Apply the Borealis-standard pragmas to *conn*."""
|
||||
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute("PRAGMA journal_mode=WAL")
|
||||
cur.execute("PRAGMA busy_timeout=5000")
|
||||
cur.execute("PRAGMA synchronous=NORMAL")
|
||||
conn.commit()
|
||||
except Exception:
|
||||
# Pragmas are best-effort; failing to apply them should not block startup.
|
||||
conn.rollback()
|
||||
finally:
|
||||
cur.close()
|
||||
|
||||
|
||||
def connect(path: Path, *, timeout: float = 15.0) -> sqlite3.Connection:
|
||||
"""Create a new SQLite connection to *path* with Engine pragmas applied."""
|
||||
|
||||
conn = sqlite3.connect(str(path), timeout=timeout)
|
||||
configure_connection(conn)
|
||||
return conn
|
||||
|
||||
|
||||
def connection_factory(path: Path, *, timeout: float = 15.0) -> SQLiteConnectionFactory:
|
||||
"""Return a factory that opens connections to *path* when invoked."""
|
||||
|
||||
def factory() -> sqlite3.Connection:
|
||||
return connect(path, timeout=timeout)
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@contextmanager
|
||||
def connection_scope(path: Path, *, timeout: float = 15.0) -> Iterator[sqlite3.Connection]:
|
||||
"""Context manager yielding a configured connection to *path*."""
|
||||
|
||||
conn = connect(path, timeout=timeout)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
410
Data/Engine/repositories/sqlite/device_repository.py
Normal file
410
Data/Engine/repositories/sqlite/device_repository.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""SQLite-backed device repository for the Engine authentication services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import closing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from Data.Engine.domain.device_auth import (
|
||||
DeviceFingerprint,
|
||||
DeviceGuid,
|
||||
DeviceIdentity,
|
||||
DeviceStatus,
|
||||
)
|
||||
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||
from Data.Engine.services.auth.device_auth_service import DeviceRecord
|
||||
|
||||
__all__ = ["SQLiteDeviceRepository"]
|
||||
|
||||
|
||||
class SQLiteDeviceRepository:
|
||||
"""Persistence adapter that reads and recovers device rows."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._connections = connection_factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.devices")
|
||||
|
||||
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]:
|
||||
"""Fetch a device row by GUID, normalizing legacy case variance."""
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE UPPER(guid) = ?
|
||||
""",
|
||||
(guid.value.upper(),),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
for row in rows:
|
||||
record = self._row_to_record(row)
|
||||
if record and record.identity.guid.value == guid.value:
|
||||
return record
|
||||
|
||||
# Fall back to the first row if normalization failed to match exactly.
|
||||
return self._row_to_record(rows[0])
|
||||
|
||||
def recover_missing(
|
||||
self,
|
||||
guid: DeviceGuid,
|
||||
fingerprint: DeviceFingerprint,
|
||||
token_version: int,
|
||||
service_context: Optional[str],
|
||||
) -> Optional[DeviceRecord]:
|
||||
"""Attempt to recreate a missing device row for a valid token."""
|
||||
|
||||
now_ts = int(time.time())
|
||||
now_iso = datetime.now(tz=timezone.utc).isoformat()
|
||||
base_hostname = f"RECOVERED-{guid.value[:12]}"
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
for attempt in range(6):
|
||||
hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
|
||||
try:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO devices (
|
||||
guid,
|
||||
hostname,
|
||||
created_at,
|
||||
last_seen,
|
||||
ssl_key_fingerprint,
|
||||
token_version,
|
||||
status,
|
||||
key_added_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
|
||||
""",
|
||||
(
|
||||
guid.value,
|
||||
hostname,
|
||||
now_ts,
|
||||
now_ts,
|
||||
fingerprint.value,
|
||||
max(token_version or 1, 1),
|
||||
now_iso,
|
||||
),
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
message = str(exc).lower()
|
||||
if "hostname" in message and "unique" in message:
|
||||
continue
|
||||
self._log.warning(
|
||||
"device auth failed to recover guid=%s (context=%s): %s",
|
||||
guid.value,
|
||||
service_context or "none",
|
||||
exc,
|
||||
)
|
||||
conn.rollback()
|
||||
return None
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.exception(
|
||||
"device auth unexpected error recovering guid=%s (context=%s)",
|
||||
guid.value,
|
||||
service_context or "none",
|
||||
exc_info=exc,
|
||||
)
|
||||
conn.rollback()
|
||||
return None
|
||||
else:
|
||||
conn.commit()
|
||||
break
|
||||
else:
|
||||
self._log.warning(
|
||||
"device auth could not recover guid=%s; hostname collisions persisted",
|
||||
guid.value,
|
||||
)
|
||||
conn.rollback()
|
||||
return None
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(guid.value,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
self._log.warning(
|
||||
"device auth recovery committed but row missing for guid=%s",
|
||||
guid.value,
|
||||
)
|
||||
return None
|
||||
|
||||
return self._row_to_record(row)
|
||||
|
||||
def ensure_device_record(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
hostname: str,
|
||||
fingerprint: DeviceFingerprint,
|
||||
) -> DeviceRecord:
|
||||
now_iso = datetime.now(tz=timezone.utc).isoformat()
|
||||
now_ts = int(time.time())
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
|
||||
FROM devices
|
||||
WHERE UPPER(guid) = ?
|
||||
""",
|
||||
(guid.value.upper(),),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if row:
|
||||
stored_fp = (row[4] or "").strip().lower()
|
||||
new_fp = fingerprint.value
|
||||
if not stored_fp:
|
||||
cur.execute(
|
||||
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
|
||||
(new_fp, now_iso, row[0]),
|
||||
)
|
||||
elif stored_fp != new_fp:
|
||||
token_version = self._coerce_int(row[2], default=1) + 1
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE devices
|
||||
SET ssl_key_fingerprint = ?,
|
||||
key_added_at = ?,
|
||||
token_version = ?,
|
||||
status = 'active'
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(new_fp, now_iso, token_version, row[0]),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = ?
|
||||
WHERE guid = ?
|
||||
AND revoked_at IS NULL
|
||||
""",
|
||||
(now_iso, row[0]),
|
||||
)
|
||||
conn.commit()
|
||||
else:
|
||||
resolved_hostname = self._resolve_hostname(cur, hostname, guid)
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO devices (
|
||||
guid,
|
||||
hostname,
|
||||
created_at,
|
||||
last_seen,
|
||||
ssl_key_fingerprint,
|
||||
token_version,
|
||||
status,
|
||||
key_added_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
|
||||
""",
|
||||
(
|
||||
guid.value,
|
||||
resolved_hostname,
|
||||
now_ts,
|
||||
now_ts,
|
||||
fingerprint.value,
|
||||
now_iso,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE UPPER(guid) = ?
|
||||
""",
|
||||
(guid.value.upper(),),
|
||||
)
|
||||
latest = cur.fetchone()
|
||||
|
||||
if not latest:
|
||||
raise RuntimeError("device record could not be ensured")
|
||||
|
||||
record = self._row_to_record(latest)
|
||||
if record is None:
|
||||
raise RuntimeError("device record invalid after ensure")
|
||||
return record
|
||||
|
||||
def record_device_key(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
fingerprint: DeviceFingerprint,
|
||||
added_at: datetime,
|
||||
) -> None:
|
||||
added_iso = added_at.astimezone(timezone.utc).isoformat()
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(str(uuid.uuid4()), guid.value, fingerprint.value, added_iso),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_keys
|
||||
SET retired_at = ?
|
||||
WHERE guid = ?
|
||||
AND ssl_key_fingerprint != ?
|
||||
AND retired_at IS NULL
|
||||
""",
|
||||
(added_iso, guid.value, fingerprint.value),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def update_device_summary(
|
||||
self,
|
||||
*,
|
||||
hostname: Optional[str],
|
||||
last_seen: Optional[int] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
operating_system: Optional[str] = None,
|
||||
last_user: Optional[str] = None,
|
||||
) -> None:
|
||||
if not hostname:
|
||||
return
|
||||
|
||||
normalized_hostname = (hostname or "").strip()
|
||||
if not normalized_hostname:
|
||||
return
|
||||
|
||||
fields = []
|
||||
params = []
|
||||
|
||||
if last_seen is not None:
|
||||
try:
|
||||
fields.append("last_seen = ?")
|
||||
params.append(int(last_seen))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if agent_id:
|
||||
try:
|
||||
candidate = agent_id.strip()
|
||||
except Exception:
|
||||
candidate = agent_id
|
||||
if candidate:
|
||||
fields.append("agent_id = ?")
|
||||
params.append(candidate)
|
||||
|
||||
if operating_system:
|
||||
try:
|
||||
os_value = operating_system.strip()
|
||||
except Exception:
|
||||
os_value = operating_system
|
||||
if os_value:
|
||||
fields.append("operating_system = ?")
|
||||
params.append(os_value)
|
||||
|
||||
if last_user:
|
||||
try:
|
||||
user_value = last_user.strip()
|
||||
except Exception:
|
||||
user_value = last_user
|
||||
if user_value:
|
||||
fields.append("last_user = ?")
|
||||
params.append(user_value)
|
||||
|
||||
if not fields:
|
||||
return
|
||||
|
||||
params.append(normalized_hostname)
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
f"UPDATE devices SET {', '.join(fields)} WHERE LOWER(hostname) = LOWER(?)",
|
||||
params,
|
||||
)
|
||||
if cur.rowcount == 0 and agent_id:
|
||||
cur.execute(
|
||||
f"UPDATE devices SET {', '.join(fields)} WHERE agent_id = ?",
|
||||
params[:-1] + [agent_id],
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def _row_to_record(self, row: tuple) -> Optional[DeviceRecord]:
|
||||
try:
|
||||
guid = DeviceGuid(row[0])
|
||||
fingerprint_value = (row[1] or "").strip()
|
||||
if not fingerprint_value:
|
||||
self._log.warning(
|
||||
"device row %s missing TLS fingerprint; skipping",
|
||||
row[0],
|
||||
)
|
||||
return None
|
||||
fingerprint = DeviceFingerprint(fingerprint_value)
|
||||
except Exception as exc:
|
||||
self._log.warning("invalid device row for guid=%s: %s", row[0], exc)
|
||||
return None
|
||||
|
||||
token_version_raw = row[2]
|
||||
try:
|
||||
token_version = int(token_version_raw or 0)
|
||||
except Exception:
|
||||
token_version = 0
|
||||
|
||||
status = DeviceStatus.from_string(row[3])
|
||||
identity = DeviceIdentity(guid=guid, fingerprint=fingerprint)
|
||||
|
||||
return DeviceRecord(
|
||||
identity=identity,
|
||||
token_version=max(token_version, 1),
|
||||
status=status,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: object, *, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
def _resolve_hostname(self, cur: sqlite3.Cursor, hostname: str, guid: DeviceGuid) -> str:
|
||||
base = (hostname or "").strip() or guid.value
|
||||
base = base[:253]
|
||||
candidate = base
|
||||
suffix = 1
|
||||
while True:
|
||||
cur.execute(
|
||||
"SELECT guid FROM devices WHERE hostname = ?",
|
||||
(candidate,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return candidate
|
||||
existing = (row[0] or "").strip().upper()
|
||||
if existing == guid.value:
|
||||
return candidate
|
||||
candidate = f"{base}-{suffix}"
|
||||
suffix += 1
|
||||
if suffix > 50:
|
||||
return guid.value
|
||||
383
Data/Engine/repositories/sqlite/enrollment_repository.py
Normal file
383
Data/Engine/repositories/sqlite/enrollment_repository.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""SQLite-backed enrollment repository for Engine services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import closing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid
|
||||
from Data.Engine.domain.device_enrollment import (
|
||||
EnrollmentApproval,
|
||||
EnrollmentApprovalStatus,
|
||||
EnrollmentCode,
|
||||
)
|
||||
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||
|
||||
__all__ = ["SQLiteEnrollmentRepository"]
|
||||
|
||||
|
||||
class SQLiteEnrollmentRepository:
|
||||
"""Persistence adapter that manages enrollment codes and approvals."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._connections = connection_factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.enrollment")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Enrollment install codes
|
||||
# ------------------------------------------------------------------
|
||||
def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]:
|
||||
"""Load an enrollment install code by its public value."""
|
||||
|
||||
code_value = (code or "").strip()
|
||||
if not code_value:
|
||||
return None
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
WHERE code = ?
|
||||
""",
|
||||
(code_value,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
record = {
|
||||
"id": row[0],
|
||||
"code": row[1],
|
||||
"expires_at": row[2],
|
||||
"used_at": row[3],
|
||||
"used_by_guid": row[4],
|
||||
"max_uses": row[5],
|
||||
"use_count": row[6],
|
||||
"last_used_at": row[7],
|
||||
}
|
||||
try:
|
||||
return EnrollmentCode.from_mapping(record)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("invalid enrollment code record for code=%s: %s", code_value, exc)
|
||||
return None
|
||||
|
||||
def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]:
|
||||
record_value = (record_id or "").strip()
|
||||
if not record_value:
|
||||
return None
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
WHERE id = ?
|
||||
""",
|
||||
(record_value,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
record = {
|
||||
"id": row[0],
|
||||
"code": row[1],
|
||||
"expires_at": row[2],
|
||||
"used_at": row[3],
|
||||
"used_by_guid": row[4],
|
||||
"max_uses": row[5],
|
||||
"use_count": row[6],
|
||||
"last_used_at": row[7],
|
||||
}
|
||||
|
||||
try:
|
||||
return EnrollmentCode.from_mapping(record)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc)
|
||||
return None
|
||||
|
||||
def update_install_code_usage(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
use_count_increment: int,
|
||||
last_used_at: datetime,
|
||||
used_by_guid: Optional[DeviceGuid] = None,
|
||||
mark_first_use: bool = False,
|
||||
) -> None:
|
||||
"""Increment usage counters and usage metadata for an install code."""
|
||||
|
||||
if use_count_increment <= 0:
|
||||
raise ValueError("use_count_increment must be positive")
|
||||
|
||||
last_used_iso = self._isoformat(last_used_at)
|
||||
guid_value = used_by_guid.value if used_by_guid else ""
|
||||
mark_flag = 1 if mark_first_use else 0
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE enrollment_install_codes
|
||||
SET use_count = use_count + ?,
|
||||
last_used_at = ?,
|
||||
used_by_guid = COALESCE(NULLIF(?, ''), used_by_guid),
|
||||
used_at = CASE WHEN ? = 1 AND used_at IS NULL THEN ? ELSE used_at END
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
use_count_increment,
|
||||
last_used_iso,
|
||||
guid_value,
|
||||
mark_flag,
|
||||
last_used_iso,
|
||||
record_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Device approvals
|
||||
# ------------------------------------------------------------------
|
||||
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:
|
||||
"""Load a device approval using its operator-visible reference."""
|
||||
|
||||
ref_value = (reference or "").strip()
|
||||
if not ref_value:
|
||||
return None
|
||||
return self._fetch_device_approval("approval_reference = ?", (ref_value,))
|
||||
|
||||
def fetch_device_approval(self, record_id: str) -> Optional[EnrollmentApproval]:
|
||||
record_value = (record_id or "").strip()
|
||||
if not record_value:
|
||||
return None
|
||||
return self._fetch_device_approval("id = ?", (record_value,))
|
||||
|
||||
def fetch_pending_approval_by_fingerprint(
|
||||
self, fingerprint: DeviceFingerprint
|
||||
) -> Optional[EnrollmentApproval]:
|
||||
return self._fetch_device_approval(
|
||||
"ssl_key_fingerprint_claimed = ? AND status = 'pending'",
|
||||
(fingerprint.value,),
|
||||
)
|
||||
|
||||
def update_pending_approval(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
hostname: str,
|
||||
guid: Optional[DeviceGuid],
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
updated_at: datetime,
|
||||
) -> None:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET hostname_claimed = ?,
|
||||
guid = ?,
|
||||
enrollment_code_id = ?,
|
||||
client_nonce = ?,
|
||||
server_nonce = ?,
|
||||
agent_pubkey_der = ?,
|
||||
updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
hostname,
|
||||
guid.value if guid else None,
|
||||
enrollment_code_id,
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
self._isoformat(updated_at),
|
||||
record_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def create_device_approval(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
reference: str,
|
||||
claimed_hostname: str,
|
||||
claimed_fingerprint: DeviceFingerprint,
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
created_at: datetime,
|
||||
status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> EnrollmentApproval:
|
||||
created_iso = self._isoformat(created_at)
|
||||
guid_value = guid.value if guid else None
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO device_approvals (
|
||||
id,
|
||||
approval_reference,
|
||||
guid,
|
||||
hostname_claimed,
|
||||
ssl_key_fingerprint_claimed,
|
||||
enrollment_code_id,
|
||||
status,
|
||||
created_at,
|
||||
updated_at,
|
||||
client_nonce,
|
||||
server_nonce,
|
||||
agent_pubkey_der
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
record_id,
|
||||
reference,
|
||||
guid_value,
|
||||
claimed_hostname,
|
||||
claimed_fingerprint.value,
|
||||
enrollment_code_id,
|
||||
status.value,
|
||||
created_iso,
|
||||
created_iso,
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
approval = self.fetch_device_approval(record_id)
|
||||
if approval is None:
|
||||
raise RuntimeError("failed to load device approval after insert")
|
||||
return approval
|
||||
|
||||
def update_device_approval_status(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
status: EnrollmentApprovalStatus,
|
||||
updated_at: datetime,
|
||||
approved_by: Optional[str] = None,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> None:
|
||||
"""Transition an approval to a new status."""
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET status = ?,
|
||||
updated_at = ?,
|
||||
guid = COALESCE(?, guid),
|
||||
approved_by_user_id = COALESCE(?, approved_by_user_id)
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
status.value,
|
||||
self._isoformat(updated_at),
|
||||
guid.value if guid else None,
|
||||
approved_by,
|
||||
record_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _fetch_device_approval(self, where: str, params: tuple) -> Optional[EnrollmentApproval]:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT id,
|
||||
approval_reference,
|
||||
guid,
|
||||
hostname_claimed,
|
||||
ssl_key_fingerprint_claimed,
|
||||
enrollment_code_id,
|
||||
created_at,
|
||||
updated_at,
|
||||
status,
|
||||
approved_by_user_id,
|
||||
client_nonce,
|
||||
server_nonce,
|
||||
agent_pubkey_der
|
||||
FROM device_approvals
|
||||
WHERE {where}
|
||||
""",
|
||||
params,
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
record = {
|
||||
"id": row[0],
|
||||
"approval_reference": row[1],
|
||||
"guid": row[2],
|
||||
"hostname_claimed": row[3],
|
||||
"ssl_key_fingerprint_claimed": row[4],
|
||||
"enrollment_code_id": row[5],
|
||||
"created_at": row[6],
|
||||
"updated_at": row[7],
|
||||
"status": row[8],
|
||||
"approved_by_user_id": row[9],
|
||||
"client_nonce": row[10],
|
||||
"server_nonce": row[11],
|
||||
"agent_pubkey_der": row[12],
|
||||
}
|
||||
|
||||
try:
|
||||
return EnrollmentApproval.from_mapping(record)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning(
|
||||
"invalid device approval record id=%s reference=%s: %s",
|
||||
row[0],
|
||||
row[1],
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _isoformat(value: datetime) -> str:
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
return value.astimezone(timezone.utc).isoformat()
Data/Engine/repositories/sqlite/github_repository.py (new file, 53 lines)
@@ -0,0 +1,53 @@
|
||||
"""SQLite-backed GitHub token persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import closing
|
||||
from typing import Optional
|
||||
|
||||
from .connection import SQLiteConnectionFactory
|
||||
|
||||
__all__ = ["SQLiteGitHubRepository"]
|
||||
|
||||
|
||||
class SQLiteGitHubRepository:
|
||||
"""Store and retrieve GitHub API tokens for the Engine."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._connections = connection_factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.github")
|
||||
|
||||
def load_token(self) -> Optional[str]:
|
||||
"""Return the stored GitHub token if one exists."""
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT token FROM github_token LIMIT 1")
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
token = (row[0] or "").strip()
|
||||
return token or None
|
||||
|
||||
def store_token(self, token: Optional[str]) -> None:
|
||||
"""Persist *token*, replacing any prior value."""
|
||||
|
||||
normalized = (token or "").strip()
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("DELETE FROM github_token")
|
||||
if normalized:
|
||||
cur.execute("INSERT INTO github_token (token) VALUES (?)", (normalized,))
|
||||
conn.commit()
|
||||
|
||||
self._log.info("stored-token has_token=%s", bool(normalized))
Data/Engine/repositories/sqlite/job_repository.py (new file, 355 lines)
@@ -0,0 +1,355 @@
|
||||
"""SQLite-backed persistence for Engine job scheduling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterable, Optional, Sequence
|
||||
|
||||
import sqlite3
|
||||
|
||||
from .connection import SQLiteConnectionFactory
|
||||
|
||||
__all__ = [
|
||||
"ScheduledJobRecord",
|
||||
"ScheduledJobRunRecord",
|
||||
"SQLiteJobRepository",
|
||||
]
|
||||
|
||||
|
||||
def _now_ts() -> int:
|
||||
return int(time.time())
|
||||
|
||||
|
||||
def _json_dumps(value: Any) -> str:
|
||||
try:
|
||||
return json.dumps(value or [])
|
||||
except Exception:
|
||||
return "[]"
|
||||
|
||||
|
||||
def _json_loads(value: Optional[str]) -> list[Any]:
|
||||
if not value:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(value)
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
return []
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ScheduledJobRecord:
|
||||
id: int
|
||||
name: str
|
||||
components: list[dict[str, Any]]
|
||||
targets: list[str]
|
||||
schedule_type: str
|
||||
start_ts: Optional[int]
|
||||
duration_stop_enabled: bool
|
||||
expiration: Optional[str]
|
||||
execution_context: str
|
||||
credential_id: Optional[int]
|
||||
use_service_account: bool
|
||||
enabled: bool
|
||||
created_at: Optional[int]
|
||||
updated_at: Optional[int]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ScheduledJobRunRecord:
|
||||
id: int
|
||||
job_id: int
|
||||
scheduled_ts: Optional[int]
|
||||
started_ts: Optional[int]
|
||||
finished_ts: Optional[int]
|
||||
status: Optional[str]
|
||||
error: Optional[str]
|
||||
target_hostname: Optional[str]
|
||||
created_at: Optional[int]
|
||||
updated_at: Optional[int]
|
||||
|
||||
|
||||
class SQLiteJobRepository:
|
||||
"""Persistence adapter for Engine job scheduling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._factory = factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.jobs")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Job CRUD
|
||||
# ------------------------------------------------------------------
|
||||
def list_jobs(self) -> list[ScheduledJobRecord]:
|
||||
query = (
|
||||
"SELECT id, name, components_json, targets_json, schedule_type, start_ts, "
|
||||
"duration_stop_enabled, expiration, execution_context, credential_id, "
|
||||
"use_service_account, enabled, created_at, updated_at FROM scheduled_jobs "
|
||||
"ORDER BY id ASC"
|
||||
)
|
||||
return [self._row_to_job(row) for row in self._fetchall(query)]
|
||||
|
||||
def list_enabled_jobs(self) -> list[ScheduledJobRecord]:
|
||||
query = (
|
||||
"SELECT id, name, components_json, targets_json, schedule_type, start_ts, "
|
||||
"duration_stop_enabled, expiration, execution_context, credential_id, "
|
||||
"use_service_account, enabled, created_at, updated_at FROM scheduled_jobs "
|
||||
"WHERE enabled=1 ORDER BY id ASC"
|
||||
)
|
||||
return [self._row_to_job(row) for row in self._fetchall(query)]
|
||||
|
||||
def fetch_job(self, job_id: int) -> Optional[ScheduledJobRecord]:
|
||||
query = (
|
||||
"SELECT id, name, components_json, targets_json, schedule_type, start_ts, "
|
||||
"duration_stop_enabled, expiration, execution_context, credential_id, "
|
||||
"use_service_account, enabled, created_at, updated_at FROM scheduled_jobs "
|
||||
"WHERE id=?"
|
||||
)
|
||||
rows = self._fetchall(query, (job_id,))
|
||||
return self._row_to_job(rows[0]) if rows else None
|
||||
|
||||
def create_job(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
components: Sequence[dict[str, Any]],
|
||||
targets: Sequence[Any],
|
||||
schedule_type: str,
|
||||
start_ts: Optional[int],
|
||||
duration_stop_enabled: bool,
|
||||
expiration: Optional[str],
|
||||
execution_context: str,
|
||||
credential_id: Optional[int],
|
||||
use_service_account: bool,
|
||||
enabled: bool = True,
|
||||
) -> ScheduledJobRecord:
|
||||
now = _now_ts()
|
||||
payload = (
|
||||
name,
|
||||
_json_dumps(list(components)),
|
||||
_json_dumps(list(targets)),
|
||||
schedule_type,
|
||||
start_ts,
|
||||
1 if duration_stop_enabled else 0,
|
||||
expiration,
|
||||
execution_context,
|
||||
credential_id,
|
||||
1 if use_service_account else 0,
|
||||
1 if enabled else 0,
|
||||
now,
|
||||
now,
|
||||
)
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO scheduled_jobs
|
||||
(name, components_json, targets_json, schedule_type, start_ts,
|
||||
duration_stop_enabled, expiration, execution_context, credential_id,
|
||||
use_service_account, enabled, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
payload,
|
||||
)
|
||||
job_id = cur.lastrowid
|
||||
conn.commit()
|
||||
record = self.fetch_job(int(job_id))
|
||||
if record is None:
|
||||
raise RuntimeError("failed to create scheduled job")
|
||||
return record
|
||||
|
||||
def update_job(
|
||||
self,
|
||||
job_id: int,
|
||||
*,
|
||||
name: str,
|
||||
components: Sequence[dict[str, Any]],
|
||||
targets: Sequence[Any],
|
||||
schedule_type: str,
|
||||
start_ts: Optional[int],
|
||||
duration_stop_enabled: bool,
|
||||
expiration: Optional[str],
|
||||
execution_context: str,
|
||||
credential_id: Optional[int],
|
||||
use_service_account: bool,
|
||||
) -> Optional[ScheduledJobRecord]:
|
||||
now = _now_ts()
|
||||
payload = (
|
||||
name,
|
||||
_json_dumps(list(components)),
|
||||
_json_dumps(list(targets)),
|
||||
schedule_type,
|
||||
start_ts,
|
||||
1 if duration_stop_enabled else 0,
|
||||
expiration,
|
||||
execution_context,
|
||||
credential_id,
|
||||
1 if use_service_account else 0,
|
||||
now,
|
||||
job_id,
|
||||
)
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE scheduled_jobs
|
||||
SET name=?, components_json=?, targets_json=?, schedule_type=?,
|
||||
start_ts=?, duration_stop_enabled=?, expiration=?, execution_context=?,
|
||||
credential_id=?, use_service_account=?, updated_at=?
|
||||
WHERE id=?
|
||||
""",
|
||||
payload,
|
||||
)
|
||||
conn.commit()
|
||||
return self.fetch_job(job_id)
|
||||
|
||||
def set_enabled(self, job_id: int, enabled: bool) -> None:
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"UPDATE scheduled_jobs SET enabled=?, updated_at=? WHERE id=?",
|
||||
(1 if enabled else 0, _now_ts(), job_id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def delete_job(self, job_id: int) -> None:
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("DELETE FROM scheduled_jobs WHERE id=?", (job_id,))
|
||||
conn.commit()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Run history
|
||||
# ------------------------------------------------------------------
|
||||
def list_runs(self, job_id: int, *, days: Optional[int] = None) -> list[ScheduledJobRunRecord]:
|
||||
params: list[Any] = [job_id]
|
||||
where = "WHERE job_id=?"
|
||||
if days is not None and days > 0:
|
||||
cutoff = _now_ts() - (days * 86400)
|
||||
where += " AND COALESCE(finished_ts, scheduled_ts, started_ts, 0) >= ?"
|
||||
params.append(cutoff)
|
||||
|
||||
query = (
|
||||
"SELECT id, job_id, scheduled_ts, started_ts, finished_ts, status, error, "
|
||||
"target_hostname, created_at, updated_at FROM scheduled_job_runs "
|
||||
f"{where} ORDER BY COALESCE(scheduled_ts, created_at, id) DESC"
|
||||
)
|
||||
return [self._row_to_run(row) for row in self._fetchall(query, tuple(params))]
|
||||
|
||||
def fetch_last_run(self, job_id: int) -> Optional[ScheduledJobRunRecord]:
|
||||
query = (
|
||||
"SELECT id, job_id, scheduled_ts, started_ts, finished_ts, status, error, "
|
||||
"target_hostname, created_at, updated_at FROM scheduled_job_runs "
|
||||
"WHERE job_id=? ORDER BY COALESCE(started_ts, scheduled_ts, created_at, id) DESC LIMIT 1"
|
||||
)
|
||||
rows = self._fetchall(query, (job_id,))
|
||||
return self._row_to_run(rows[0]) if rows else None
|
||||
|
||||
def purge_runs(self, job_id: int) -> None:
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("DELETE FROM scheduled_job_run_activity WHERE run_id IN (SELECT id FROM scheduled_job_runs WHERE job_id=?)", (job_id,))
|
||||
cur.execute("DELETE FROM scheduled_job_runs WHERE job_id=?", (job_id,))
|
||||
conn.commit()
|
||||
|
||||
def create_run(self, job_id: int, scheduled_ts: int, *, target_hostname: Optional[str] = None) -> int:
|
||||
now = _now_ts()
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO scheduled_job_runs
|
||||
(job_id, scheduled_ts, created_at, updated_at, target_hostname, status)
|
||||
VALUES (?, ?, ?, ?, ?, 'Pending')
|
||||
""",
|
||||
(job_id, scheduled_ts, now, now, target_hostname),
|
||||
)
|
||||
run_id = int(cur.lastrowid)
|
||||
conn.commit()
|
||||
return run_id
|
||||
|
||||
def mark_run_started(self, run_id: int, *, started_ts: Optional[int] = None) -> None:
|
||||
started = started_ts or _now_ts()
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"UPDATE scheduled_job_runs SET started_ts=?, status='Running', updated_at=? WHERE id=?",
|
||||
(started, _now_ts(), run_id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def mark_run_finished(
|
||||
self,
|
||||
run_id: int,
|
||||
*,
|
||||
status: str,
|
||||
error: Optional[str] = None,
|
||||
finished_ts: Optional[int] = None,
|
||||
) -> None:
|
||||
finished = finished_ts or _now_ts()
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"UPDATE scheduled_job_runs SET finished_ts=?, status=?, error=?, updated_at=? WHERE id=?",
|
||||
(finished, status, error, _now_ts(), run_id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
return self._factory()
|
||||
|
||||
def _fetchall(self, query: str, params: Optional[Iterable[Any]] = None) -> list[sqlite3.Row]:
|
||||
with self._connect() as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cur = conn.cursor()
|
||||
cur.execute(query, tuple(params or ()))
|
||||
rows = cur.fetchall()
|
||||
return rows
|
||||
|
||||
def _row_to_job(self, row: sqlite3.Row) -> ScheduledJobRecord:
|
||||
components = _json_loads(row["components_json"])
|
||||
targets_raw = _json_loads(row["targets_json"])
|
||||
targets = [str(t) for t in targets_raw if isinstance(t, (str, int))]
|
||||
credential_id = row["credential_id"]
|
||||
return ScheduledJobRecord(
|
||||
id=int(row["id"]),
|
||||
name=str(row["name"] or ""),
|
||||
components=[c for c in components if isinstance(c, dict)],
|
||||
targets=targets,
|
||||
schedule_type=str(row["schedule_type"] or "immediately"),
|
||||
start_ts=int(row["start_ts"]) if row["start_ts"] is not None else None,
|
||||
duration_stop_enabled=bool(row["duration_stop_enabled"]),
|
||||
expiration=str(row["expiration"]) if row["expiration"] else None,
|
||||
execution_context=str(row["execution_context"] or "system"),
|
||||
credential_id=int(credential_id) if credential_id is not None else None,
|
||||
use_service_account=bool(row["use_service_account"]),
|
||||
enabled=bool(row["enabled"]),
|
||||
created_at=int(row["created_at"]) if row["created_at"] is not None else None,
|
||||
updated_at=int(row["updated_at"]) if row["updated_at"] is not None else None,
|
||||
)
|
||||
|
||||
def _row_to_run(self, row: sqlite3.Row) -> ScheduledJobRunRecord:
|
||||
return ScheduledJobRunRecord(
|
||||
id=int(row["id"]),
|
||||
job_id=int(row["job_id"]),
|
||||
scheduled_ts=int(row["scheduled_ts"]) if row["scheduled_ts"] is not None else None,
|
||||
started_ts=int(row["started_ts"]) if row["started_ts"] is not None else None,
|
||||
finished_ts=int(row["finished_ts"]) if row["finished_ts"] is not None else None,
|
||||
status=str(row["status"]) if row["status"] else None,
|
||||
error=str(row["error"]) if row["error"] else None,
|
||||
target_hostname=str(row["target_hostname"]) if row["target_hostname"] else None,
|
||||
created_at=int(row["created_at"]) if row["created_at"] is not None else None,
|
||||
updated_at=int(row["updated_at"]) if row["updated_at"] is not None else None,
|
||||
)
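Because the adapter returns plain dataclasses, its behaviour can be checked against a scratch database whose schema comes from migrations.apply_all (added later in this change set). A minimal sketch, assuming the Data.Engine package is importable; the job definition values are hypothetical:

import os
import sqlite3
import tempfile

from Data.Engine.repositories.sqlite import migrations

db_path = os.path.join(tempfile.mkdtemp(), "engine.db")

def factory() -> sqlite3.Connection:
    return sqlite3.connect(db_path)

conn = factory()
migrations.apply_all(conn)   # creates scheduled_jobs / scheduled_job_runs among other tables
conn.close()

repo = SQLiteJobRepository(factory)
job = repo.create_job(
    name="Nightly Inventory Scan",                       # hypothetical job definition
    components=[{"kind": "script", "path": "Scripts/inventory.ps1"}],
    targets=["HOST-01"],
    schedule_type="daily",
    start_ts=None,
    duration_stop_enabled=False,
    expiration=None,
    execution_context="system",
    credential_id=None,
    use_service_account=True,
)
run_id = repo.create_run(job.id, scheduled_ts=0, target_hostname="HOST-01")
repo.mark_run_started(run_id)
repo.mark_run_finished(run_id, status="Success")
print([j.name for j in repo.list_enabled_jobs()], repo.fetch_last_run(job.id).status)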
Data/Engine/repositories/sqlite/migrations.py (new file, 507 lines)
@@ -0,0 +1,507 @@
|
||||
"""SQLite schema migrations for the Borealis Engine.
|
||||
|
||||
This module centralises schema evolution so the Engine and its interfaces can stay
|
||||
focused on request handling. The migration functions are intentionally
|
||||
idempotent — they can run repeatedly without changing state once the schema
|
||||
matches the desired shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
|
||||
DEVICE_TABLE = "devices"
|
||||
|
||||
|
||||
def apply_all(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Run all known schema migrations against the provided sqlite3 connection.
|
||||
"""
|
||||
|
||||
_ensure_devices_table(conn)
|
||||
_ensure_device_aux_tables(conn)
|
||||
_ensure_refresh_token_table(conn)
|
||||
_ensure_install_code_table(conn)
|
||||
_ensure_device_approval_table(conn)
|
||||
_ensure_github_token_table(conn)
|
||||
_ensure_scheduled_jobs_table(conn)
|
||||
_ensure_scheduled_job_run_tables(conn)
|
||||
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _ensure_devices_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
if not _table_exists(cur, DEVICE_TABLE):
|
||||
_create_devices_table(cur)
|
||||
return
|
||||
|
||||
column_info = _table_info(cur, DEVICE_TABLE)
|
||||
col_names = [c[1] for c in column_info]
|
||||
pk_cols = [c[1] for c in column_info if c[5]]
|
||||
|
||||
needs_rebuild = pk_cols != ["guid"]
|
||||
required_columns = {
|
||||
"guid": "TEXT",
|
||||
"hostname": "TEXT",
|
||||
"description": "TEXT",
|
||||
"created_at": "INTEGER",
|
||||
"agent_hash": "TEXT",
|
||||
"memory": "TEXT",
|
||||
"network": "TEXT",
|
||||
"software": "TEXT",
|
||||
"storage": "TEXT",
|
||||
"cpu": "TEXT",
|
||||
"device_type": "TEXT",
|
||||
"domain": "TEXT",
|
||||
"external_ip": "TEXT",
|
||||
"internal_ip": "TEXT",
|
||||
"last_reboot": "TEXT",
|
||||
"last_seen": "INTEGER",
|
||||
"last_user": "TEXT",
|
||||
"operating_system": "TEXT",
|
||||
"uptime": "INTEGER",
|
||||
"agent_id": "TEXT",
|
||||
"ansible_ee_ver": "TEXT",
|
||||
"connection_type": "TEXT",
|
||||
"connection_endpoint": "TEXT",
|
||||
"ssl_key_fingerprint": "TEXT",
|
||||
"token_version": "INTEGER",
|
||||
"status": "TEXT",
|
||||
"key_added_at": "TEXT",
|
||||
}
|
||||
|
||||
missing_columns = [col for col in required_columns if col not in col_names]
|
||||
if missing_columns:
|
||||
needs_rebuild = True
|
||||
|
||||
if needs_rebuild:
|
||||
_rebuild_devices_table(conn, column_info)
|
||||
else:
|
||||
_ensure_column_defaults(cur)
|
||||
|
||||
_ensure_device_indexes(cur)
|
||||
|
||||
|
||||
def _ensure_device_aux_tables(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS device_keys (
|
||||
id TEXT PRIMARY KEY,
|
||||
guid TEXT NOT NULL,
|
||||
ssl_key_fingerprint TEXT NOT NULL,
|
||||
added_at TEXT NOT NULL,
|
||||
retired_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uq_device_keys_guid_fingerprint
|
||||
ON device_keys(guid, ssl_key_fingerprint)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_device_keys_guid
|
||||
ON device_keys(guid)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_refresh_token_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
id TEXT PRIMARY KEY,
|
||||
guid TEXT NOT NULL,
|
||||
token_hash TEXT NOT NULL,
|
||||
dpop_jkt TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
revoked_at TEXT,
|
||||
last_used_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_guid
|
||||
ON refresh_tokens(guid)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at
|
||||
ON refresh_tokens(expires_at)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS enrollment_install_codes (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL UNIQUE,
|
||||
expires_at TEXT NOT NULL,
|
||||
created_by_user_id TEXT,
|
||||
used_at TEXT,
|
||||
used_by_guid TEXT,
|
||||
max_uses INTEGER NOT NULL DEFAULT 1,
|
||||
use_count INTEGER NOT NULL DEFAULT 0,
|
||||
last_used_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_eic_expires_at
|
||||
ON enrollment_install_codes(expires_at)
|
||||
"""
|
||||
)
|
||||
|
||||
columns = {row[1] for row in _table_info(cur, "enrollment_install_codes")}
|
||||
if "max_uses" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes
|
||||
ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 1
|
||||
"""
|
||||
)
|
||||
if "use_count" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes
|
||||
ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0
|
||||
"""
|
||||
)
|
||||
if "last_used_at" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes
|
||||
ADD COLUMN last_used_at TEXT
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS device_approvals (
|
||||
id TEXT PRIMARY KEY,
|
||||
approval_reference TEXT NOT NULL UNIQUE,
|
||||
guid TEXT,
|
||||
hostname_claimed TEXT NOT NULL,
|
||||
ssl_key_fingerprint_claimed TEXT NOT NULL,
|
||||
enrollment_code_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
client_nonce TEXT NOT NULL,
|
||||
server_nonce TEXT NOT NULL,
|
||||
agent_pubkey_der BLOB NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
approved_by_user_id TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_da_status
|
||||
ON device_approvals(status)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_da_fp_status
|
||||
ON device_approvals(ssl_key_fingerprint_claimed, status)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_github_token_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS github_token (
|
||||
token TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_scheduled_jobs_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS scheduled_jobs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
components_json TEXT NOT NULL,
|
||||
targets_json TEXT NOT NULL,
|
||||
schedule_type TEXT NOT NULL,
|
||||
start_ts INTEGER,
|
||||
duration_stop_enabled INTEGER DEFAULT 0,
|
||||
expiration TEXT,
|
||||
execution_context TEXT NOT NULL,
|
||||
credential_id INTEGER,
|
||||
use_service_account INTEGER NOT NULL DEFAULT 1,
|
||||
enabled INTEGER DEFAULT 1,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
try:
|
||||
columns = {row[1] for row in _table_info(cur, "scheduled_jobs")}
|
||||
if "credential_id" not in columns:
|
||||
cur.execute("ALTER TABLE scheduled_jobs ADD COLUMN credential_id INTEGER")
|
||||
if "use_service_account" not in columns:
|
||||
cur.execute(
|
||||
"ALTER TABLE scheduled_jobs ADD COLUMN use_service_account INTEGER NOT NULL DEFAULT 1"
|
||||
)
|
||||
except Exception:
|
||||
# Legacy deployments may fail the ALTER TABLE calls; ignore silently.
|
||||
pass
|
||||
|
||||
|
||||
def _ensure_scheduled_job_run_tables(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS scheduled_job_runs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
job_id INTEGER NOT NULL,
|
||||
scheduled_ts INTEGER,
|
||||
started_ts INTEGER,
|
||||
finished_ts INTEGER,
|
||||
status TEXT,
|
||||
error TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER,
|
||||
target_hostname TEXT,
|
||||
FOREIGN KEY(job_id) REFERENCES scheduled_jobs(id) ON DELETE CASCADE
|
||||
)
|
||||
"""
|
||||
)
|
||||
try:
|
||||
cur.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_runs_job_sched_target ON scheduled_job_runs(job_id, scheduled_ts, target_hostname)"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS scheduled_job_run_activity (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
run_id INTEGER NOT NULL,
|
||||
activity_id INTEGER NOT NULL,
|
||||
component_kind TEXT,
|
||||
script_type TEXT,
|
||||
component_path TEXT,
|
||||
component_name TEXT,
|
||||
created_at INTEGER,
|
||||
FOREIGN KEY(run_id) REFERENCES scheduled_job_runs(id) ON DELETE CASCADE
|
||||
)
|
||||
"""
|
||||
)
|
||||
try:
|
||||
cur.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_run_activity_run ON scheduled_job_run_activity(run_id)"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
cur.execute(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_run_activity_activity ON scheduled_job_run_activity(activity_id)"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _create_devices_table(cur: sqlite3.Cursor) -> None:
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE devices (
|
||||
guid TEXT PRIMARY KEY,
|
||||
hostname TEXT,
|
||||
description TEXT,
|
||||
created_at INTEGER,
|
||||
agent_hash TEXT,
|
||||
memory TEXT,
|
||||
network TEXT,
|
||||
software TEXT,
|
||||
storage TEXT,
|
||||
cpu TEXT,
|
||||
device_type TEXT,
|
||||
domain TEXT,
|
||||
external_ip TEXT,
|
||||
internal_ip TEXT,
|
||||
last_reboot TEXT,
|
||||
last_seen INTEGER,
|
||||
last_user TEXT,
|
||||
operating_system TEXT,
|
||||
uptime INTEGER,
|
||||
agent_id TEXT,
|
||||
ansible_ee_ver TEXT,
|
||||
connection_type TEXT,
|
||||
connection_endpoint TEXT,
|
||||
ssl_key_fingerprint TEXT,
|
||||
token_version INTEGER DEFAULT 1,
|
||||
status TEXT DEFAULT 'active',
|
||||
key_added_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
_ensure_device_indexes(cur)
|
||||
|
||||
|
||||
def _ensure_device_indexes(cur: sqlite3.Cursor) -> None:
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uq_devices_hostname
|
||||
ON devices(hostname)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_devices_ssl_key
|
||||
ON devices(ssl_key_fingerprint)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_devices_status
|
||||
ON devices(status)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_column_defaults(cur: sqlite3.Cursor) -> None:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE devices
|
||||
SET token_version = COALESCE(token_version, 1)
|
||||
WHERE token_version IS NULL
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE devices
|
||||
SET status = COALESCE(status, 'active')
|
||||
WHERE status IS NULL OR status = ''
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _rebuild_devices_table(conn: sqlite3.Connection, column_info: Sequence[Tuple]) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute("PRAGMA foreign_keys=OFF")
|
||||
cur.execute("BEGIN IMMEDIATE")
|
||||
|
||||
cur.execute("ALTER TABLE devices RENAME TO devices_legacy")
|
||||
_create_devices_table(cur)
|
||||
|
||||
legacy_columns = [c[1] for c in column_info]
|
||||
cur.execute(f"SELECT {', '.join(legacy_columns)} FROM devices_legacy")
|
||||
rows = cur.fetchall()
|
||||
|
||||
insert_sql = (
|
||||
"""
|
||||
INSERT OR REPLACE INTO devices (
|
||||
guid, hostname, description, created_at, agent_hash, memory,
|
||||
network, software, storage, cpu, device_type, domain, external_ip,
|
||||
internal_ip, last_reboot, last_seen, last_user, operating_system,
|
||||
uptime, agent_id, ansible_ee_ver, connection_type, connection_endpoint,
|
||||
ssl_key_fingerprint, token_version, status, key_added_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
record = dict(zip(legacy_columns, row))
|
||||
guid = _normalized_guid(record.get("guid"))
|
||||
if not guid:
|
||||
guid = str(uuid.uuid4())
|
||||
hostname = record.get("hostname")
|
||||
created_at = record.get("created_at")
|
||||
key_added_at = record.get("key_added_at")
|
||||
if key_added_at is None:
|
||||
key_added_at = _default_key_added_at(created_at)
|
||||
|
||||
params: Tuple = (
|
||||
guid,
|
||||
hostname,
|
||||
record.get("description"),
|
||||
created_at,
|
||||
record.get("agent_hash"),
|
||||
record.get("memory"),
|
||||
record.get("network"),
|
||||
record.get("software"),
|
||||
record.get("storage"),
|
||||
record.get("cpu"),
|
||||
record.get("device_type"),
|
||||
record.get("domain"),
|
||||
record.get("external_ip"),
|
||||
record.get("internal_ip"),
|
||||
record.get("last_reboot"),
|
||||
record.get("last_seen"),
|
||||
record.get("last_user"),
|
||||
record.get("operating_system"),
|
||||
record.get("uptime"),
|
||||
record.get("agent_id"),
|
||||
record.get("ansible_ee_ver"),
|
||||
record.get("connection_type"),
|
||||
record.get("connection_endpoint"),
|
||||
record.get("ssl_key_fingerprint"),
|
||||
record.get("token_version") or 1,
|
||||
record.get("status") or "active",
|
||||
key_added_at,
|
||||
)
|
||||
cur.execute(insert_sql, params)
|
||||
|
||||
cur.execute("DROP TABLE devices_legacy")
|
||||
cur.execute("COMMIT")
|
||||
cur.execute("PRAGMA foreign_keys=ON")
|
||||
|
||||
|
||||
def _default_key_added_at(created_at: Optional[int]) -> Optional[str]:
|
||||
if created_at:
|
||||
try:
|
||||
dt = datetime.fromtimestamp(int(created_at), tz=timezone.utc)
|
||||
return dt.isoformat()
|
||||
except Exception:
|
||||
pass
|
||||
return datetime.now(tz=timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _table_exists(cur: sqlite3.Cursor, name: str) -> bool:
|
||||
cur.execute(
|
||||
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
|
||||
(name,),
|
||||
)
|
||||
return cur.fetchone() is not None
|
||||
|
||||
|
||||
def _table_info(cur: sqlite3.Cursor, name: str) -> List[Tuple]:
|
||||
cur.execute(f"PRAGMA table_info({name})")
|
||||
return cur.fetchall()
|
||||
|
||||
|
||||
def _normalized_guid(value: Optional[str]) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
return str(value).strip()
|
||||
|
||||
__all__ = ["apply_all"]
Data/Engine/repositories/sqlite/token_repository.py (new file, 153 lines)
@@ -0,0 +1,153 @@
|
||||
"""SQLite-backed refresh token repository for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import closing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from Data.Engine.domain.device_auth import DeviceGuid
|
||||
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||
from Data.Engine.services.auth.token_service import RefreshTokenRecord
|
||||
|
||||
__all__ = ["SQLiteRefreshTokenRepository"]
|
||||
|
||||
|
||||
class SQLiteRefreshTokenRepository:
|
||||
"""Persistence adapter for refresh token records."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._connections = connection_factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.tokens")
|
||||
|
||||
def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at
|
||||
FROM refresh_tokens
|
||||
WHERE guid = ?
|
||||
AND token_hash = ?
|
||||
""",
|
||||
(guid.value, token_hash),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return self._row_to_record(row)
|
||||
|
||||
def clear_dpop_binding(self, record_id: str) -> None:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?",
|
||||
(record_id,),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def touch(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
last_used_at: datetime,
|
||||
dpop_jkt: Optional[str],
|
||||
) -> None:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET last_used_at = ?,
|
||||
dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt)
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
self._isoformat(last_used_at),
|
||||
(dpop_jkt or "").strip(),
|
||||
record_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
guid: DeviceGuid,
|
||||
token_hash: str,
|
||||
created_at: datetime,
|
||||
expires_at: Optional[datetime],
|
||||
) -> None:
|
||||
created_iso = self._isoformat(created_at)
|
||||
expires_iso = self._isoformat(expires_at) if expires_at else None
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO refresh_tokens (
|
||||
id,
|
||||
guid,
|
||||
token_hash,
|
||||
created_at,
|
||||
expires_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(record_id, guid.value, token_hash, created_iso, expires_iso),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def _row_to_record(self, row: tuple) -> Optional[RefreshTokenRecord]:
|
||||
try:
|
||||
guid = DeviceGuid(row[1])
|
||||
except Exception as exc:
|
||||
self._log.warning("invalid refresh token row guid=%s: %s", row[1], exc)
|
||||
return None
|
||||
|
||||
created_at = self._parse_iso(row[4])
|
||||
expires_at = self._parse_iso(row[5])
|
||||
revoked_at = self._parse_iso(row[6])
|
||||
|
||||
if created_at is None:
|
||||
created_at = datetime.now(tz=timezone.utc)
|
||||
|
||||
return RefreshTokenRecord.from_row(
|
||||
record_id=str(row[0]),
|
||||
guid=guid,
|
||||
token_hash=str(row[2]),
|
||||
dpop_jkt=str(row[3]) if row[3] is not None else None,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
revoked_at=revoked_at,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_iso(value: Optional[str]) -> Optional[datetime]:
|
||||
if not value:
|
||||
return None
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
parsed = datetime.fromisoformat(raw)
|
||||
except Exception:
|
||||
return None
|
||||
if parsed.tzinfo is None:
|
||||
return parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed
|
||||
|
||||
@staticmethod
|
||||
def _isoformat(value: datetime) -> str:
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
return value.astimezone(timezone.utc).isoformat()
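A minimal sketch of the round trip, assuming DeviceGuid accepts a UUID string and that the domain and migration modules from this change set are importable; the token hash is a hypothetical opaque value. A fresh connection is opened per call because the repository closes each one it receives.

import os
import sqlite3
import tempfile
import uuid
from datetime import datetime, timedelta, timezone

from Data.Engine.domain.device_auth import DeviceGuid
from Data.Engine.repositories.sqlite.migrations import apply_all

db_path = os.path.join(tempfile.mkdtemp(), "engine.db")

def factory() -> sqlite3.Connection:
    return sqlite3.connect(db_path)

conn = factory()
apply_all(conn)
conn.close()

repo = SQLiteRefreshTokenRepository(factory)
guid = DeviceGuid(str(uuid.uuid4()))
now = datetime.now(tz=timezone.utc)
repo.create(
    record_id=str(uuid.uuid4()),
    guid=guid,
    token_hash="sha256:deadbeef",              # hypothetical opaque hash
    created_at=now,
    expires_at=now + timedelta(days=30),
)
record = repo.fetch(guid, "sha256:deadbeef")
print(record is not None)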
Data/Engine/requirements.txt (new file, 11 lines)
@@ -0,0 +1,11 @@
|
||||
#////////// PROJECT FILE SEPARATION LINE ////////// CODE AFTER THIS LINE ARE FROM: <ProjectRoot>/Data/Engine/requirements.txt
|
||||
# Core web stack
|
||||
Flask
|
||||
flask_socketio
|
||||
flask-cors
|
||||
eventlet
|
||||
requests
|
||||
|
||||
# Auth & security
|
||||
PyJWT[crypto]
|
||||
cryptography
Data/Engine/runtime.py (new file, 139 lines)
@@ -0,0 +1,139 @@
|
||||
"""Runtime filesystem helpers for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
__all__ = [
|
||||
"agent_certificates_path",
|
||||
"ensure_agent_certificates_dir",
|
||||
"ensure_certificates_dir",
|
||||
"ensure_runtime_dir",
|
||||
"ensure_server_certificates_dir",
|
||||
"project_root",
|
||||
"runtime_path",
|
||||
"server_certificates_path",
|
||||
]
|
||||
|
||||
|
||||
def _env_path(name: str) -> Optional[Path]:
|
||||
value = os.environ.get(name)
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return Path(value).expanduser().resolve()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def project_root() -> Path:
|
||||
env = _env_path("BOREALIS_PROJECT_ROOT")
|
||||
if env:
|
||||
return env
|
||||
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "Borealis.ps1").exists() or (parent / ".git").is_dir():
|
||||
return parent
|
||||
|
||||
try:
|
||||
return current.parents[1]
|
||||
except IndexError:
|
||||
return current.parent
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def server_runtime_root() -> Path:
|
||||
env = _env_path("BOREALIS_ENGINE_ROOT") or _env_path("BOREALIS_SERVER_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = project_root() / "Engine"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
def runtime_path(*parts: str) -> Path:
|
||||
return server_runtime_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_runtime_dir(*parts: str) -> Path:
|
||||
path = runtime_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def certificates_root() -> Path:
|
||||
env = _env_path("BOREALIS_CERTIFICATES_ROOT") or _env_path("BOREALIS_CERT_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = project_root() / "Certificates"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
for name in ("Server", "Agent"):
|
||||
try:
|
||||
(root / name).mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
return root
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def server_certificates_root() -> Path:
|
||||
env = _env_path("BOREALIS_SERVER_CERT_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = certificates_root() / "Server"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def agent_certificates_root() -> Path:
|
||||
env = _env_path("BOREALIS_AGENT_CERT_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = certificates_root() / "Agent"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
def certificates_path(*parts: str) -> Path:
|
||||
return certificates_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_certificates_dir(*parts: str) -> Path:
|
||||
path = certificates_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def server_certificates_path(*parts: str) -> Path:
|
||||
return server_certificates_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_server_certificates_dir(*parts: str) -> Path:
|
||||
path = server_certificates_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def agent_certificates_path(*parts: str) -> Path:
|
||||
return agent_certificates_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_agent_certificates_dir(*parts: str) -> Path:
|
||||
path = agent_certificates_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
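A minimal sketch of the override behaviour: because the *_root helpers are wrapped in lru_cache, the environment variables must be set before the first call. The scratch directory below is hypothetical.

import os
import tempfile

os.environ["BOREALIS_ENGINE_ROOT"] = os.path.join(tempfile.mkdtemp(), "EngineRuntime")

from Data.Engine.runtime import ensure_runtime_dir, ensure_server_certificates_dir, runtime_path

print(runtime_path("auth_keys"))            # path under the override; not created yet
print(ensure_runtime_dir("logs"))           # created on demand under the override
print(ensure_server_certificates_dir())     # Certificates/Server under the project root (or its override)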
Data/Engine/server.py (new file, 103 lines)
@@ -0,0 +1,103 @@
|
||||
"""Flask application factory for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask, request, send_from_directory
|
||||
from flask_cors import CORS
|
||||
from werkzeug.exceptions import NotFound
|
||||
from werkzeug.middleware.proxy_fix import ProxyFix
|
||||
|
||||
from .config import EngineSettings
|
||||
|
||||
|
||||
from .repositories.sqlite.connection import (
|
||||
SQLiteConnectionFactory,
|
||||
connection_factory as create_sqlite_connection_factory,
|
||||
)
|
||||
|
||||
|
||||
def _resolve_static_folder(static_root: Path) -> tuple[str, str]:
|
||||
return str(static_root), ""
|
||||
|
||||
|
||||
def _register_spa_routes(app: Flask, assets_root: Path) -> None:
|
||||
"""Serve the Borealis single-page application from *assets_root*.
|
||||
|
||||
The logic mirrors the legacy server by routing any unknown front-end paths
|
||||
back to ``index.html`` so the React router can take over.
|
||||
"""
|
||||
|
||||
static_folder = assets_root
|
||||
|
||||
@app.route("/", defaults={"path": ""})
|
||||
@app.route("/<path:path>")
|
||||
def serve_frontend(path: str) -> object:
|
||||
candidate = (static_folder / path).resolve()
|
||||
if path and candidate.is_file():
|
||||
return send_from_directory(str(static_folder), path)
|
||||
try:
|
||||
return send_from_directory(str(static_folder), "index.html")
|
||||
except Exception as exc: # pragma: no cover - passthrough
|
||||
raise NotFound() from exc
|
||||
|
||||
@app.errorhandler(404)
|
||||
def spa_fallback(error: Exception) -> object: # pragma: no cover - routing
|
||||
request_path = (request.path or "").strip()
|
||||
if request_path.startswith("/api") or request_path.startswith("/socket.io"):
|
||||
return error
|
||||
if "." in Path(request_path).name:
|
||||
return error
|
||||
if request.method not in {"GET", "HEAD"}:
|
||||
return error
|
||||
try:
|
||||
return send_from_directory(str(static_folder), "index.html")
|
||||
except Exception:
|
||||
return error
|
||||
|
||||
|
||||
def create_app(
|
||||
settings: EngineSettings,
|
||||
*,
|
||||
db_factory: Optional[SQLiteConnectionFactory] = None,
|
||||
) -> Flask:
|
||||
"""Create the Flask application instance for the Engine."""
|
||||
|
||||
if db_factory is None:
|
||||
db_factory = create_sqlite_connection_factory(settings.database_path)
|
||||
|
||||
static_folder, static_url_path = _resolve_static_folder(settings.flask.static_root)
|
||||
app = Flask(
|
||||
__name__,
|
||||
static_folder=static_folder,
|
||||
static_url_path=static_url_path,
|
||||
)
|
||||
|
||||
app.config.update(
|
||||
SECRET_KEY=settings.flask.secret_key,
|
||||
JSON_SORT_KEYS=False,
|
||||
SESSION_COOKIE_HTTPONLY=True,
|
||||
SESSION_COOKIE_SECURE=not settings.debug,
|
||||
SESSION_COOKIE_SAMESITE="Lax",
|
||||
ENGINE_DATABASE_PATH=str(settings.database_path),
|
||||
ENGINE_DB_CONN_FACTORY=db_factory,
|
||||
)
|
||||
app.config.setdefault("PREFERRED_URL_SCHEME", "https")
|
||||
|
||||
# Respect upstream proxy headers when Borealis is hosted behind a TLS terminator.
|
||||
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1) # type: ignore[assignment]
|
||||
|
||||
CORS(
|
||||
app,
|
||||
resources={r"/*": {"origins": list(settings.flask.cors_allowed_origins)}},
|
||||
supports_credentials=True,
|
||||
)
|
||||
|
||||
_register_spa_routes(app, Path(static_folder))
|
||||
|
||||
return app
|
||||
|
||||
|
||||
__all__ = ["create_app"]
Data/Engine/services/__init__.py (new file, 62 lines)
@@ -0,0 +1,62 @@
|
||||
"""Application services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthService",
|
||||
"DeviceRecord",
|
||||
"RefreshTokenRecord",
|
||||
"TokenRefreshError",
|
||||
"TokenRefreshErrorCode",
|
||||
"TokenService",
|
||||
"EnrollmentService",
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"EnrollmentValidationError",
|
||||
"PollingResult",
|
||||
"AgentRealtimeService",
|
||||
"AgentRecord",
|
||||
"SchedulerService",
|
||||
"GitHubService",
|
||||
"GitHubTokenPayload",
|
||||
]
|
||||
|
||||
_LAZY_TARGETS: Dict[str, Tuple[str, str]] = {
|
||||
"DeviceAuthService": ("Data.Engine.services.auth.device_auth_service", "DeviceAuthService"),
|
||||
"DeviceRecord": ("Data.Engine.services.auth.device_auth_service", "DeviceRecord"),
|
||||
"RefreshTokenRecord": ("Data.Engine.services.auth.device_auth_service", "RefreshTokenRecord"),
|
||||
"TokenService": ("Data.Engine.services.auth.token_service", "TokenService"),
|
||||
"TokenRefreshError": ("Data.Engine.services.auth.token_service", "TokenRefreshError"),
|
||||
"TokenRefreshErrorCode": ("Data.Engine.services.auth.token_service", "TokenRefreshErrorCode"),
|
||||
"EnrollmentService": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentService"),
|
||||
"EnrollmentRequestResult": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentRequestResult"),
|
||||
"EnrollmentStatus": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentStatus"),
|
||||
"EnrollmentTokenBundle": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentTokenBundle"),
|
||||
"PollingResult": ("Data.Engine.services.enrollment.enrollment_service", "PollingResult"),
|
||||
"EnrollmentValidationError": ("Data.Engine.domain.device_enrollment", "EnrollmentValidationError"),
|
||||
"AgentRealtimeService": ("Data.Engine.services.realtime.agent_registry", "AgentRealtimeService"),
|
||||
"AgentRecord": ("Data.Engine.services.realtime.agent_registry", "AgentRecord"),
|
||||
"SchedulerService": ("Data.Engine.services.jobs.scheduler_service", "SchedulerService"),
|
||||
"GitHubService": ("Data.Engine.services.github.github_service", "GitHubService"),
|
||||
"GitHubTokenPayload": ("Data.Engine.services.github.github_service", "GitHubTokenPayload"),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
try:
|
||||
module_name, attribute = _LAZY_TARGETS[name]
|
||||
except KeyError as exc:
|
||||
raise AttributeError(name) from exc
|
||||
|
||||
module = import_module(module_name)
|
||||
value = getattr(module, attribute)
|
||||
globals()[name] = value
|
||||
return value
|
||||
|
||||
|
||||
def __dir__() -> Any: # pragma: no cover - interactive helper
|
||||
return sorted(set(__all__))
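The __getattr__ hook above is the PEP 562 lazy-export pattern: a service module is only imported the first time one of the listed names is accessed, and the resolved value is cached in globals() so later lookups skip the hook. A standalone sketch of the same pattern with a hypothetical module path:

from importlib import import_module

__all__ = ["TokenService"]
_LAZY = {"TokenService": ("myapp.tokens", "TokenService")}  # hypothetical module path

def __getattr__(name):
    try:
        module_name, attribute = _LAZY[name]
    except KeyError as exc:
        raise AttributeError(name) from exc
    value = getattr(import_module(module_name), attribute)
    globals()[name] = value   # cache so __getattr__ is not consulted again for this name
    return value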
Data/Engine/services/auth/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
|
||||
"""Authentication services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .device_auth_service import DeviceAuthService, DeviceRecord
|
||||
from .dpop import DPoPReplayError, DPoPVerificationError, DPoPValidator
|
||||
from .jwt_service import JWTService, load_service as load_jwt_service
|
||||
from .token_service import (
|
||||
RefreshTokenRecord,
|
||||
TokenRefreshError,
|
||||
TokenRefreshErrorCode,
|
||||
TokenService,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthService",
|
||||
"DeviceRecord",
|
||||
"DPoPReplayError",
|
||||
"DPoPVerificationError",
|
||||
"DPoPValidator",
|
||||
"JWTService",
|
||||
"load_jwt_service",
|
||||
"RefreshTokenRecord",
|
||||
"TokenRefreshError",
|
||||
"TokenRefreshErrorCode",
|
||||
"TokenService",
|
||||
]
Data/Engine/services/auth/device_auth_service.py (new file, 237 lines)
@@ -0,0 +1,237 @@
|
||||
"""Device authentication service copied from the legacy server stack."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Mapping, Optional, Protocol
|
||||
import logging
|
||||
|
||||
from Data.Engine.builders.device_auth import DeviceAuthRequest
|
||||
from Data.Engine.domain.device_auth import (
|
||||
AccessTokenClaims,
|
||||
DeviceAuthContext,
|
||||
DeviceAuthErrorCode,
|
||||
DeviceAuthFailure,
|
||||
DeviceFingerprint,
|
||||
DeviceGuid,
|
||||
DeviceIdentity,
|
||||
DeviceStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthService",
|
||||
"DeviceRecord",
|
||||
"DPoPValidator",
|
||||
"DPoPVerificationError",
|
||||
"DPoPReplayError",
|
||||
"RateLimiter",
|
||||
"RateLimitDecision",
|
||||
"DeviceRepository",
|
||||
]
|
||||
|
||||
|
||||
class RateLimitDecision(Protocol):
|
||||
allowed: bool
|
||||
retry_after: Optional[float]
|
||||
|
||||
|
||||
class RateLimiter(Protocol):
|
||||
def check(self, key: str, max_requests: int, window_seconds: float) -> RateLimitDecision: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class JWTDecoder(Protocol):
|
||||
def decode(self, token: str) -> Mapping[str, object]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class DPoPValidator(Protocol):
|
||||
def verify(
|
||||
self,
|
||||
method: str,
|
||||
htu: str,
|
||||
proof: str,
|
||||
access_token: Optional[str] = None,
|
||||
) -> str: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class DPoPVerificationError(Exception):
|
||||
"""Raised when a DPoP proof fails validation."""
|
||||
|
||||
|
||||
class DPoPReplayError(DPoPVerificationError):
|
||||
"""Raised when a DPoP proof is replayed."""
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceRecord:
|
||||
"""Snapshot of a device record required for authentication."""
|
||||
|
||||
identity: DeviceIdentity
|
||||
token_version: int
|
||||
status: DeviceStatus
|
||||
|
||||
|
||||
class DeviceRepository(Protocol):
|
||||
"""Port that exposes the minimal device persistence operations."""
|
||||
|
||||
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def recover_missing(
|
||||
self,
|
||||
guid: DeviceGuid,
|
||||
fingerprint: DeviceFingerprint,
|
||||
token_version: int,
|
||||
service_context: Optional[str],
|
||||
) -> Optional[DeviceRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class DeviceAuthService:
|
||||
"""Authenticate devices using access tokens, repositories, and DPoP proofs."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
device_repository: DeviceRepository,
|
||||
jwt_service: JWTDecoder,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
rate_limiter: Optional[RateLimiter] = None,
|
||||
dpop_validator: Optional[DPoPValidator] = None,
|
||||
) -> None:
|
||||
self._repository = device_repository
|
||||
self._jwt = jwt_service
|
||||
self._log = logger or logging.getLogger("borealis.engine.auth")
|
||||
self._rate_limiter = rate_limiter
|
||||
self._dpop_validator = dpop_validator
|
||||
|
||||
def authenticate(self, request: DeviceAuthRequest, *, path: str) -> DeviceAuthContext:
|
||||
"""Authenticate an access token and return the resulting context."""
|
||||
|
||||
claims = self._decode_claims(request.access_token)
|
||||
rate_limit_key = f"fp:{claims.fingerprint.value}"
|
||||
if self._rate_limiter is not None:
|
||||
decision = self._rate_limiter.check(rate_limit_key, 60, 60.0)
|
||||
if not decision.allowed:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.RATE_LIMITED,
|
||||
http_status=429,
|
||||
retry_after=decision.retry_after,
|
||||
)
|
||||
|
||||
record = self._repository.fetch_by_guid(claims.guid)
|
||||
if record is None:
|
||||
record = self._repository.recover_missing(
|
||||
claims.guid,
|
||||
claims.fingerprint,
|
||||
claims.token_version,
|
||||
request.service_context,
|
||||
)
|
||||
|
||||
if record is None:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DEVICE_NOT_FOUND,
|
||||
http_status=403,
|
||||
)
|
||||
|
||||
self._validate_identity(record, claims)
|
||||
|
||||
dpop_jkt = self._validate_dpop(request, record, claims)
|
||||
|
||||
context = DeviceAuthContext(
|
||||
identity=record.identity,
|
||||
access_token=request.access_token,
|
||||
claims=claims,
|
||||
status=record.status,
|
||||
service_context=request.service_context,
|
||||
dpop_jkt=dpop_jkt,
|
||||
)
|
||||
|
||||
if context.is_quarantined:
|
||||
self._log.warning(
|
||||
"device %s is quarantined; limited access for %s",
|
||||
record.identity.guid,
|
||||
path,
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
def _decode_claims(self, token: str) -> AccessTokenClaims:
|
||||
try:
|
||||
raw_claims = self._jwt.decode(token)
|
||||
except Exception as exc: # pragma: no cover - defensive fallback
|
||||
if self._is_expired_signature(exc):
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.TOKEN_EXPIRED) from exc
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN) from exc
|
||||
|
||||
try:
|
||||
return AccessTokenClaims.from_mapping(raw_claims)
|
||||
except Exception as exc:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_CLAIMS) from exc
|
||||
|
||||
@staticmethod
|
||||
def _is_expired_signature(exc: Exception) -> bool:
|
||||
name = exc.__class__.__name__
|
||||
return name == "ExpiredSignatureError"
|
||||
|
||||
def _validate_identity(
|
||||
self,
|
||||
record: DeviceRecord,
|
||||
claims: AccessTokenClaims,
|
||||
) -> None:
|
||||
if record.identity.guid.value != claims.guid.value:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DEVICE_GUID_MISMATCH,
|
||||
http_status=403,
|
||||
)
|
||||
|
||||
if record.identity.fingerprint.value:
|
||||
if record.identity.fingerprint.value != claims.fingerprint.value:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.FINGERPRINT_MISMATCH,
|
||||
http_status=403,
|
||||
)
|
||||
|
||||
if record.token_version > claims.token_version:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.TOKEN_VERSION_REVOKED)
|
||||
|
||||
if not record.status.allows_access:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DEVICE_REVOKED,
|
||||
http_status=403,
|
||||
)
|
||||
|
||||
def _validate_dpop(
|
||||
self,
|
||||
request: DeviceAuthRequest,
|
||||
record: DeviceRecord,
|
||||
claims: AccessTokenClaims,
|
||||
) -> Optional[str]:
|
||||
if not request.dpop_proof:
|
||||
return None
|
||||
|
||||
if self._dpop_validator is None:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DPOP_NOT_SUPPORTED,
|
||||
http_status=400,
|
||||
)
|
||||
|
||||
try:
|
||||
return self._dpop_validator.verify(
|
||||
request.http_method,
|
||||
request.htu,
|
||||
request.dpop_proof,
|
||||
request.access_token,
|
||||
)
|
||||
except DPoPReplayError as exc:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DPOP_REPLAYED,
|
||||
http_status=400,
|
||||
) from exc
|
||||
except DPoPVerificationError as exc:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DPOP_INVALID,
|
||||
http_status=400,
|
||||
) from exc
Data/Engine/services/auth/dpop.py (new file, 105 lines)
@@ -0,0 +1,105 @@
|
||||
"""DPoP proof validation for Engine services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict, Optional
|
||||
|
||||
import jwt
|
||||
|
||||
__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"]
|
||||
|
||||
|
||||
_DPOP_MAX_SKEW = 300.0
|
||||
|
||||
|
||||
class DPoPVerificationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DPoPReplayError(DPoPVerificationError):
|
||||
pass
|
||||
|
||||
|
||||
class DPoPValidator:
|
||||
def __init__(self) -> None:
|
||||
self._observed_jti: Dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def verify(
|
||||
self,
|
||||
method: str,
|
||||
htu: str,
|
||||
proof: str,
|
||||
access_token: Optional[str] = None,
|
||||
) -> str:
|
||||
if not proof:
|
||||
raise DPoPVerificationError("DPoP proof missing")
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(proof)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP header") from exc
|
||||
|
||||
jwk = header.get("jwk")
|
||||
alg = header.get("alg")
|
||||
if not jwk or not isinstance(jwk, dict):
|
||||
raise DPoPVerificationError("missing jwk in DPoP header")
|
||||
if alg not in ("EdDSA", "ES256", "ES384", "ES512"):
|
||||
raise DPoPVerificationError(f"unsupported DPoP alg {alg}")
|
||||
|
||||
try:
|
||||
key = jwt.PyJWK(jwk)
|
||||
public_key = key.key
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid jwk in DPoP header") from exc
|
||||
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
proof,
|
||||
public_key,
|
||||
algorithms=[alg],
|
||||
options={"require": ["htm", "htu", "jti", "iat"]},
|
||||
)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP signature") from exc
|
||||
|
||||
htm = claims.get("htm")
|
||||
proof_htu = claims.get("htu")
|
||||
jti = claims.get("jti")
|
||||
iat = claims.get("iat")
|
||||
ath = claims.get("ath")
|
||||
|
||||
if not isinstance(htm, str) or htm.lower() != method.lower():
|
||||
raise DPoPVerificationError("DPoP htm mismatch")
|
||||
if not isinstance(proof_htu, str) or proof_htu != htu:
|
||||
raise DPoPVerificationError("DPoP htu mismatch")
|
||||
if not isinstance(jti, str):
|
||||
raise DPoPVerificationError("DPoP jti missing")
|
||||
if not isinstance(iat, (int, float)):
|
||||
raise DPoPVerificationError("DPoP iat missing")
|
||||
|
||||
now = time.time()
|
||||
if abs(now - float(iat)) > _DP0P_MAX_SKEW:
|
||||
raise DPoPVerificationError("DPoP proof outside allowed skew")
|
||||
|
||||
if ath and access_token:
|
||||
expected_ath = jwt.utils.base64url_encode(
|
||||
hashlib.sha256(access_token.encode("utf-8")).digest()
|
||||
).decode("ascii")
|
||||
if expected_ath != ath:
|
||||
raise DPoPVerificationError("DPoP ath mismatch")
|
||||
|
||||
with self._lock:
|
||||
expiry = self._observed_jti.get(jti)
|
||||
if expiry and expiry > now:
|
||||
raise DPoPReplayError("DPoP proof replay detected")
|
||||
self._observed_jti[jti] = now + _DP0P_MAX_SKEW
|
||||
stale = [key for key, exp in self._observed_jti.items() if exp <= now]
|
||||
for key in stale:
|
||||
self._observed_jti.pop(key, None)
|
||||
|
||||
thumbprint = jwt.PyJWK(jwk).thumbprint()
|
||||
return thumbprint.decode("ascii")
|
||||
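Illustrative usage (not part of the commit): a minimal sketch of exercising DPoPValidator with a freshly minted ES256 proof, assuming PyJWT and cryptography are installed and that the PyJWT build in use supports the calls the validator makes above. The endpoint URL is a hypothetical placeholder.

import json
import time
import uuid

import jwt as pyjwt
from cryptography.hazmat.primitives.asymmetric import ec

from Data.Engine.services.auth.dpop import DPoPValidator

# Agent-side key and its public JWK (embedded in the proof header).
private_key = ec.generate_private_key(ec.SECP256R1())
jwk = json.loads(pyjwt.algorithms.ECAlgorithm.to_jwk(private_key.public_key()))

htu = "https://localhost:5001/api/agent/heartbeat"  # hypothetical endpoint
proof = pyjwt.encode(
    {"htm": "POST", "htu": htu, "jti": str(uuid.uuid4()), "iat": int(time.time())},
    private_key,
    algorithm="ES256",
    headers={"typ": "dpop+jwt", "jwk": jwk},
)

validator = DPoPValidator()
jkt = validator.verify("POST", htu, proof)  # returns the key thumbprint on success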
124
Data/Engine/services/auth/jwt_service.py
Normal file
@@ -0,0 +1,124 @@
"""JWT issuance utilities for the Engine."""

from __future__ import annotations

import hashlib
import time
from typing import Any, Dict, Optional

import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519

from Data.Engine.runtime import ensure_runtime_dir, runtime_path

__all__ = ["JWTService", "load_service"]


_KEY_DIR = runtime_path("auth_keys")
_KEY_FILE = _KEY_DIR / "engine-jwt-ed25519.key"
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-jwt-ed25519.key"


class JWTService:
    def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str) -> None:
        self._private_key = private_key
        self._public_key = private_key.public_key()
        self._key_id = key_id

    @property
    def key_id(self) -> str:
        return self._key_id

    def issue_access_token(
        self,
        guid: str,
        ssl_key_fingerprint: str,
        token_version: int,
        expires_in: int = 900,
        extra_claims: Optional[Dict[str, Any]] = None,
    ) -> str:
        now = int(time.time())
        payload: Dict[str, Any] = {
            "sub": f"device:{guid}",
            "guid": guid,
            "ssl_key_fingerprint": ssl_key_fingerprint,
            "token_version": int(token_version),
            "iat": now,
            "nbf": now,
            "exp": now + int(expires_in),
        }
        if extra_claims:
            payload.update(extra_claims)

        token = jwt.encode(
            payload,
            self._private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            ),
            algorithm="EdDSA",
            headers={"kid": self._key_id},
        )
        return token

    def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]:
        options = {"require": ["exp", "iat", "sub"]}
        public_pem = self._public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        return jwt.decode(
            token,
            public_pem,
            algorithms=["EdDSA"],
            audience=audience,
            options=options,
        )

    def public_jwk(self) -> Dict[str, Any]:
        public_bytes = self._public_key.public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )
        jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii")
        return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x}


def load_service() -> JWTService:
    private_key = _load_or_create_private_key()
    public_bytes = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    key_id = hashlib.sha256(public_bytes).hexdigest()[:16]
    return JWTService(private_key, key_id)


def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey:
    ensure_runtime_dir("auth_keys")

    if _KEY_FILE.exists():
        with _KEY_FILE.open("rb") as fh:
            return serialization.load_pem_private_key(fh.read(), password=None)

    if _LEGACY_KEY_FILE.exists():
        with _LEGACY_KEY_FILE.open("rb") as fh:
            return serialization.load_pem_private_key(fh.read(), password=None)

    private_key = ed25519.Ed25519PrivateKey.generate()
    pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    _KEY_DIR.mkdir(parents=True, exist_ok=True)
    with _KEY_FILE.open("wb") as fh:
        fh.write(pem)
    try:
        if hasattr(_KEY_FILE, "chmod"):
            _KEY_FILE.chmod(0o600)
    except Exception:
        pass
    return private_key
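Illustrative usage (not part of the commit): a minimal sketch of issuing and decoding an access token with load_service. The GUID and fingerprint strings are hypothetical placeholders; on first use load_service creates Ed25519 key material under the Engine runtime "auth_keys" directory as shown above.

from Data.Engine.services.auth.jwt_service import load_service

svc = load_service()
token = svc.issue_access_token(
    "3f2b-example-guid",          # hypothetical device GUID
    "sha256:abcd-fingerprint",    # hypothetical TLS key fingerprint
    token_version=1,
)
claims = svc.decode(token)
assert claims["guid"] == "3f2b-example-guid"
assert claims["sub"] == "device:3f2b-example-guid"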
190
Data/Engine/services/auth/token_service.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Token refresh service extracted from the legacy blueprint."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Protocol
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from Data.Engine.builders.device_auth import RefreshTokenRequest
|
||||
from Data.Engine.domain.device_auth import DeviceGuid
|
||||
|
||||
from .device_auth_service import (
|
||||
DeviceRecord,
|
||||
DeviceRepository,
|
||||
DPoPReplayError,
|
||||
DPoPVerificationError,
|
||||
DPoPValidator,
|
||||
)
|
||||
|
||||
__all__ = ["RefreshTokenRecord", "TokenService", "TokenRefreshError", "TokenRefreshErrorCode"]
|
||||
|
||||
|
||||
class JWTIssuer(Protocol):
|
||||
def issue_access_token(self, guid: str, fingerprint: str, token_version: int) -> str: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class TokenRefreshErrorCode(str):
|
||||
INVALID_REFRESH_TOKEN = "invalid_refresh_token"
|
||||
REFRESH_TOKEN_REVOKED = "refresh_token_revoked"
|
||||
REFRESH_TOKEN_EXPIRED = "refresh_token_expired"
|
||||
DEVICE_NOT_FOUND = "device_not_found"
|
||||
DEVICE_REVOKED = "device_revoked"
|
||||
DPOP_REPLAYED = "dpop_replayed"
|
||||
DPOP_INVALID = "dpop_invalid"
|
||||
|
||||
|
||||
class TokenRefreshError(Exception):
|
||||
def __init__(self, code: str, *, http_status: int = 400) -> None:
|
||||
self.code = code
|
||||
self.http_status = http_status
|
||||
super().__init__(code)
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
return {"error": self.code}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RefreshTokenRecord:
|
||||
record_id: str
|
||||
guid: DeviceGuid
|
||||
token_hash: str
|
||||
dpop_jkt: Optional[str]
|
||||
created_at: datetime
|
||||
expires_at: Optional[datetime]
|
||||
revoked_at: Optional[datetime]
|
||||
|
||||
@classmethod
|
||||
def from_row(
|
||||
cls,
|
||||
*,
|
||||
record_id: str,
|
||||
guid: DeviceGuid,
|
||||
token_hash: str,
|
||||
dpop_jkt: Optional[str],
|
||||
created_at: datetime,
|
||||
expires_at: Optional[datetime],
|
||||
revoked_at: Optional[datetime],
|
||||
) -> "RefreshTokenRecord":
|
||||
return cls(
|
||||
record_id=record_id,
|
||||
guid=guid,
|
||||
token_hash=token_hash,
|
||||
dpop_jkt=dpop_jkt,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
revoked_at=revoked_at,
|
||||
)
|
||||
|
||||
|
||||
class RefreshTokenRepository(Protocol):
|
||||
def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def clear_dpop_binding(self, record_id: str) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def touch(self, record_id: str, *, last_used_at: datetime, dpop_jkt: Optional[str]) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AccessTokenResponse:
|
||||
access_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
refresh_token_repository: RefreshTokenRepository,
|
||||
device_repository: DeviceRepository,
|
||||
jwt_service: JWTIssuer,
|
||||
dpop_validator: Optional[DPoPValidator] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._refresh_tokens = refresh_token_repository
|
||||
self._devices = device_repository
|
||||
self._jwt = jwt_service
|
||||
self._dpop_validator = dpop_validator
|
||||
self._log = logger or logging.getLogger("borealis.engine.auth")
|
||||
|
||||
def refresh_access_token(
|
||||
self,
|
||||
request: RefreshTokenRequest,
|
||||
) -> AccessTokenResponse:
|
||||
record = self._refresh_tokens.fetch(
|
||||
request.guid,
|
||||
self._hash_token(request.refresh_token),
|
||||
)
|
||||
if record is None:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.INVALID_REFRESH_TOKEN, http_status=401)
|
||||
|
||||
if record.guid.value != request.guid.value:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.INVALID_REFRESH_TOKEN, http_status=401)
|
||||
|
||||
if record.revoked_at is not None:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.REFRESH_TOKEN_REVOKED, http_status=401)
|
||||
|
||||
if record.expires_at is not None and record.expires_at <= self._now():
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.REFRESH_TOKEN_EXPIRED, http_status=401)
|
||||
|
||||
device = self._devices.fetch_by_guid(request.guid)
|
||||
if device is None:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DEVICE_NOT_FOUND, http_status=404)
|
||||
|
||||
if not device.status.allows_access:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DEVICE_REVOKED, http_status=403)
|
||||
|
||||
dpop_jkt = record.dpop_jkt or ""
|
||||
if request.dpop_proof:
|
||||
if self._dpop_validator is None:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DPOP_INVALID)
|
||||
try:
|
||||
dpop_jkt = self._dpop_validator.verify(
|
||||
request.http_method,
|
||||
request.htu,
|
||||
request.dpop_proof,
|
||||
None,
|
||||
)
|
||||
except DPoPReplayError as exc:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DPOP_REPLAYED) from exc
|
||||
except DPoPVerificationError as exc:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DPOP_INVALID) from exc
|
||||
elif record.dpop_jkt:
|
||||
self._log.warning(
|
||||
"Clearing stored DPoP binding for guid=%s due to missing proof",
|
||||
request.guid.value,
|
||||
)
|
||||
self._refresh_tokens.clear_dpop_binding(record.record_id)
|
||||
|
||||
access_token = self._jwt.issue_access_token(
|
||||
request.guid.value,
|
||||
device.identity.fingerprint.value,
|
||||
max(device.token_version, 1),
|
||||
)
|
||||
|
||||
self._refresh_tokens.touch(
|
||||
record.record_id,
|
||||
last_used_at=self._now(),
|
||||
dpop_jkt=dpop_jkt or None,
|
||||
)
|
||||
|
||||
return AccessTokenResponse(
|
||||
access_token=access_token,
|
||||
expires_in=900,
|
||||
token_type="Bearer",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _hash_token(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
158
Data/Engine/services/container.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Service container assembly for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from Data.Engine.config import EngineSettings
|
||||
from Data.Engine.integrations.github import GitHubArtifactProvider
|
||||
from Data.Engine.repositories.sqlite import (
|
||||
SQLiteConnectionFactory,
|
||||
SQLiteDeviceRepository,
|
||||
SQLiteEnrollmentRepository,
|
||||
SQLiteGitHubRepository,
|
||||
SQLiteJobRepository,
|
||||
SQLiteRefreshTokenRepository,
|
||||
)
|
||||
from Data.Engine.services.auth import (
|
||||
DeviceAuthService,
|
||||
DPoPValidator,
|
||||
JWTService,
|
||||
TokenService,
|
||||
load_jwt_service,
|
||||
)
|
||||
from Data.Engine.services.crypto.signing import ScriptSigner, load_signer
|
||||
from Data.Engine.services.enrollment import EnrollmentService
|
||||
from Data.Engine.services.enrollment.nonce_cache import NonceCache
|
||||
from Data.Engine.services.github import GitHubService
|
||||
from Data.Engine.services.jobs import SchedulerService
|
||||
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
|
||||
from Data.Engine.services.realtime import AgentRealtimeService
|
||||
|
||||
__all__ = ["EngineServiceContainer", "build_service_container"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EngineServiceContainer:
|
||||
device_auth: DeviceAuthService
|
||||
token_service: TokenService
|
||||
enrollment_service: EnrollmentService
|
||||
jwt_service: JWTService
|
||||
dpop_validator: DPoPValidator
|
||||
agent_realtime: AgentRealtimeService
|
||||
scheduler_service: SchedulerService
|
||||
github_service: GitHubService
|
||||
|
||||
|
||||
def build_service_container(
|
||||
settings: EngineSettings,
|
||||
*,
|
||||
db_factory: SQLiteConnectionFactory,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> EngineServiceContainer:
|
||||
log = logger or logging.getLogger("borealis.engine.services")
|
||||
|
||||
device_repo = SQLiteDeviceRepository(db_factory, logger=log.getChild("devices"))
|
||||
token_repo = SQLiteRefreshTokenRepository(db_factory, logger=log.getChild("tokens"))
|
||||
enrollment_repo = SQLiteEnrollmentRepository(db_factory, logger=log.getChild("enrollment"))
|
||||
job_repo = SQLiteJobRepository(db_factory, logger=log.getChild("jobs"))
|
||||
github_repo = SQLiteGitHubRepository(db_factory, logger=log.getChild("github_repo"))
|
||||
|
||||
jwt_service = load_jwt_service()
|
||||
dpop_validator = DPoPValidator()
|
||||
rate_limiter = SlidingWindowRateLimiter()
|
||||
|
||||
token_service = TokenService(
|
||||
refresh_token_repository=token_repo,
|
||||
device_repository=device_repo,
|
||||
jwt_service=jwt_service,
|
||||
dpop_validator=dpop_validator,
|
||||
logger=log.getChild("token_service"),
|
||||
)
|
||||
|
||||
enrollment_service = EnrollmentService(
|
||||
device_repository=device_repo,
|
||||
enrollment_repository=enrollment_repo,
|
||||
token_repository=token_repo,
|
||||
jwt_service=jwt_service,
|
||||
tls_bundle_loader=_tls_bundle_loader(settings),
|
||||
ip_rate_limiter=SlidingWindowRateLimiter(),
|
||||
fingerprint_rate_limiter=SlidingWindowRateLimiter(),
|
||||
nonce_cache=NonceCache(),
|
||||
script_signer=_load_script_signer(log),
|
||||
logger=log.getChild("enrollment"),
|
||||
)
|
||||
|
||||
device_auth = DeviceAuthService(
|
||||
device_repository=device_repo,
|
||||
jwt_service=jwt_service,
|
||||
logger=log.getChild("device_auth"),
|
||||
rate_limiter=rate_limiter,
|
||||
dpop_validator=dpop_validator,
|
||||
)
|
||||
|
||||
agent_realtime = AgentRealtimeService(
|
||||
device_repository=device_repo,
|
||||
logger=log.getChild("agent_realtime"),
|
||||
)
|
||||
|
||||
scheduler_service = SchedulerService(
|
||||
job_repository=job_repo,
|
||||
assemblies_root=settings.project_root / "Assemblies",
|
||||
logger=log.getChild("scheduler"),
|
||||
)
|
||||
|
||||
github_provider = GitHubArtifactProvider(
|
||||
cache_file=settings.github.cache_file,
|
||||
default_repo=settings.github.default_repo,
|
||||
default_branch=settings.github.default_branch,
|
||||
refresh_interval=settings.github.refresh_interval_seconds,
|
||||
logger=log.getChild("github.provider"),
|
||||
)
|
||||
github_service = GitHubService(
|
||||
repository=github_repo,
|
||||
provider=github_provider,
|
||||
logger=log.getChild("github"),
|
||||
)
|
||||
github_service.start_background_refresh()
|
||||
|
||||
return EngineServiceContainer(
|
||||
device_auth=device_auth,
|
||||
token_service=token_service,
|
||||
enrollment_service=enrollment_service,
|
||||
jwt_service=jwt_service,
|
||||
dpop_validator=dpop_validator,
|
||||
agent_realtime=agent_realtime,
|
||||
scheduler_service=scheduler_service,
|
||||
github_service=github_service,
|
||||
)
|
||||
|
||||
|
||||
def _tls_bundle_loader(settings: EngineSettings) -> Callable[[], str]:
|
||||
candidates = [
|
||||
Path(os.getenv("BOREALIS_TLS_BUNDLE", "")),
|
||||
settings.project_root / "Certificates" / "Server" / "borealis-server-bundle.pem",
|
||||
]
|
||||
|
||||
def loader() -> str:
|
||||
for candidate in candidates:
|
||||
if candidate and candidate.is_file():
|
||||
try:
|
||||
return candidate.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
continue
|
||||
return ""
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
def _load_script_signer(logger: logging.Logger) -> Optional[ScriptSigner]:
|
||||
try:
|
||||
return load_signer()
|
||||
except Exception as exc:
|
||||
logger.warning("script signer unavailable: %s", exc)
|
||||
return None
|
||||
366
Data/Engine/services/crypto/certificates.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""Server TLS certificate management for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import ssl
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
|
||||
|
||||
from Data.Engine.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path
|
||||
|
||||
__all__ = [
|
||||
"build_ssl_context",
|
||||
"certificate_paths",
|
||||
"ensure_certificate",
|
||||
]
|
||||
|
||||
_CERT_DIR = server_certificates_path()
|
||||
_CERT_FILE = _CERT_DIR / "borealis-server-cert.pem"
|
||||
_KEY_FILE = _CERT_DIR / "borealis-server-key.pem"
|
||||
_BUNDLE_FILE = _CERT_DIR / "borealis-server-bundle.pem"
|
||||
_CA_KEY_FILE = _CERT_DIR / "borealis-root-ca-key.pem"
|
||||
_CA_CERT_FILE = _CERT_DIR / "borealis-root-ca.pem"
|
||||
|
||||
_LEGACY_CERT_DIR = runtime_path("certs")
|
||||
_LEGACY_CERT_FILE = _LEGACY_CERT_DIR / "borealis-server-cert.pem"
|
||||
_LEGACY_KEY_FILE = _LEGACY_CERT_DIR / "borealis-server-key.pem"
|
||||
_LEGACY_BUNDLE_FILE = _LEGACY_CERT_DIR / "borealis-server-bundle.pem"
|
||||
|
||||
_ROOT_COMMON_NAME = "Borealis Root CA"
|
||||
_ORG_NAME = "Borealis"
|
||||
_ROOT_VALIDITY = timedelta(days=365 * 100)
|
||||
_SERVER_VALIDITY = timedelta(days=365 * 5)
|
||||
|
||||
|
||||
def ensure_certificate(common_name: str = "Borealis Engine") -> Tuple[Path, Path, Path]:
|
||||
"""Ensure the root CA, server certificate, and bundle exist on disk."""
|
||||
|
||||
ensure_server_certificates_dir()
|
||||
_migrate_legacy_material_if_present()
|
||||
|
||||
ca_key, ca_cert, ca_regenerated = _ensure_root_ca()
|
||||
|
||||
server_cert = _load_certificate(_CERT_FILE)
|
||||
needs_regen = ca_regenerated or _server_certificate_needs_regeneration(server_cert, ca_cert)
|
||||
if needs_regen:
|
||||
server_cert = _generate_server_certificate(common_name, ca_key, ca_cert)
|
||||
|
||||
if server_cert is None:
|
||||
server_cert = _generate_server_certificate(common_name, ca_key, ca_cert)
|
||||
|
||||
_write_bundle(server_cert, ca_cert)
|
||||
|
||||
return _CERT_FILE, _KEY_FILE, _BUNDLE_FILE
|
||||
|
||||
|
||||
def _migrate_legacy_material_if_present() -> None:
|
||||
if not _CERT_FILE.exists() or not _KEY_FILE.exists():
|
||||
legacy_cert = _LEGACY_CERT_FILE
|
||||
legacy_key = _LEGACY_KEY_FILE
|
||||
if legacy_cert.exists() and legacy_key.exists():
|
||||
try:
|
||||
ensure_server_certificates_dir()
|
||||
if not _CERT_FILE.exists():
|
||||
_safe_copy(legacy_cert, _CERT_FILE)
|
||||
if not _KEY_FILE.exists():
|
||||
_safe_copy(legacy_key, _KEY_FILE)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _ensure_root_ca() -> Tuple[ec.EllipticCurvePrivateKey, x509.Certificate, bool]:
|
||||
regenerated = False
|
||||
|
||||
ca_key: Optional[ec.EllipticCurvePrivateKey] = None
|
||||
ca_cert: Optional[x509.Certificate] = None
|
||||
|
||||
if _CA_KEY_FILE.exists() and _CA_CERT_FILE.exists():
|
||||
try:
|
||||
ca_key = _load_private_key(_CA_KEY_FILE)
|
||||
ca_cert = _load_certificate(_CA_CERT_FILE)
|
||||
if ca_cert is not None and ca_key is not None:
|
||||
expiry = _cert_not_after(ca_cert)
|
||||
subject = ca_cert.subject
|
||||
subject_cn = ""
|
||||
try:
|
||||
subject_cn = subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value # type: ignore[index]
|
||||
except Exception:
|
||||
subject_cn = ""
|
||||
try:
|
||||
basic = ca_cert.extensions.get_extension_for_class(x509.BasicConstraints).value # type: ignore[attr-defined]
|
||||
is_ca = bool(basic.ca)
|
||||
except Exception:
|
||||
is_ca = False
|
||||
if (
|
||||
expiry <= datetime.now(tz=timezone.utc)
|
||||
or not is_ca
|
||||
or subject_cn != _ROOT_COMMON_NAME
|
||||
):
|
||||
regenerated = True
|
||||
else:
|
||||
regenerated = True
|
||||
except Exception:
|
||||
regenerated = True
|
||||
else:
|
||||
regenerated = True
|
||||
|
||||
if regenerated or ca_key is None or ca_cert is None:
|
||||
ca_key = ec.generate_private_key(ec.SECP384R1())
|
||||
public_key = ca_key.public_key()
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
builder = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, _ROOT_COMMON_NAME),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
|
||||
]
|
||||
)
|
||||
)
|
||||
.issuer_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, _ROOT_COMMON_NAME),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
|
||||
]
|
||||
)
|
||||
)
|
||||
.public_key(public_key)
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now - timedelta(minutes=5))
|
||||
.not_valid_after(now + _ROOT_VALIDITY)
|
||||
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
content_commitment=False,
|
||||
key_encipherment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
key_cert_sign=True,
|
||||
crl_sign=True,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.SubjectKeyIdentifier.from_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
)
|
||||
|
||||
builder = builder.add_extension(
|
||||
x509.AuthorityKeyIdentifier.from_issuer_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
|
||||
ca_cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA384())
|
||||
|
||||
_CA_KEY_FILE.write_bytes(
|
||||
ca_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
)
|
||||
_CA_CERT_FILE.write_bytes(ca_cert.public_bytes(serialization.Encoding.PEM))
|
||||
|
||||
_tighten_permissions(_CA_KEY_FILE)
|
||||
_tighten_permissions(_CA_CERT_FILE)
|
||||
else:
|
||||
regenerated = False
|
||||
|
||||
return ca_key, ca_cert, regenerated
|
||||
|
||||
|
||||
def _server_certificate_needs_regeneration(
|
||||
server_cert: Optional[x509.Certificate],
|
||||
ca_cert: x509.Certificate,
|
||||
) -> bool:
|
||||
if server_cert is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
if server_cert.issuer != ca_cert.subject:
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
try:
|
||||
expiry = _cert_not_after(server_cert)
|
||||
if expiry <= datetime.now(tz=timezone.utc):
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
try:
|
||||
basic = server_cert.extensions.get_extension_for_class(x509.BasicConstraints).value # type: ignore[attr-defined]
|
||||
if basic.ca:
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
try:
|
||||
eku = server_cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value # type: ignore[attr-defined]
|
||||
if ExtendedKeyUsageOID.SERVER_AUTH not in eku:
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _generate_server_certificate(
|
||||
common_name: str,
|
||||
ca_key: ec.EllipticCurvePrivateKey,
|
||||
ca_cert: x509.Certificate,
|
||||
) -> x509.Certificate:
|
||||
private_key = ec.generate_private_key(ec.SECP384R1())
|
||||
public_key = private_key.public_key()
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
ca_expiry = _cert_not_after(ca_cert)
|
||||
candidate_expiry = now + _SERVER_VALIDITY
|
||||
not_after = min(ca_expiry - timedelta(days=1), candidate_expiry)
|
||||
|
||||
builder = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
|
||||
]
|
||||
)
|
||||
)
|
||||
.issuer_name(ca_cert.subject)
|
||||
.public_key(public_key)
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now - timedelta(minutes=5))
|
||||
.not_valid_after(not_after)
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName(
|
||||
[
|
||||
x509.DNSName("localhost"),
|
||||
x509.DNSName("127.0.0.1"),
|
||||
x509.DNSName("::1"),
|
||||
]
|
||||
),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
content_commitment=False,
|
||||
key_encipherment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
key_cert_sign=False,
|
||||
crl_sign=False,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.SubjectKeyIdentifier.from_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_key.public_key()),
|
||||
critical=False,
|
||||
)
|
||||
)
|
||||
|
||||
certificate = builder.sign(private_key=ca_key, algorithm=hashes.SHA384())
|
||||
|
||||
_KEY_FILE.write_bytes(
|
||||
private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
)
|
||||
_CERT_FILE.write_bytes(certificate.public_bytes(serialization.Encoding.PEM))
|
||||
|
||||
_tighten_permissions(_KEY_FILE)
|
||||
_tighten_permissions(_CERT_FILE)
|
||||
|
||||
return certificate
|
||||
|
||||
|
||||
def _write_bundle(server_cert: x509.Certificate, ca_cert: x509.Certificate) -> None:
|
||||
try:
|
||||
server_pem = server_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8").strip()
|
||||
ca_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8").strip()
|
||||
except Exception:
|
||||
return
|
||||
|
||||
bundle = f"{server_pem}\n{ca_pem}\n"
|
||||
_BUNDLE_FILE.write_text(bundle, encoding="utf-8")
|
||||
_tighten_permissions(_BUNDLE_FILE)
|
||||
|
||||
|
||||
def _safe_copy(src: Path, dst: Path) -> None:
|
||||
try:
|
||||
dst.write_bytes(src.read_bytes())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _tighten_permissions(path: Path) -> None:
|
||||
try:
|
||||
if os.name == "posix":
|
||||
path.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _load_private_key(path: Path) -> ec.EllipticCurvePrivateKey:
|
||||
with path.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
|
||||
def _load_certificate(path: Path) -> Optional[x509.Certificate]:
|
||||
try:
|
||||
return x509.load_pem_x509_certificate(path.read_bytes())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _cert_not_after(cert: x509.Certificate) -> datetime:
|
||||
try:
|
||||
return cert.not_valid_after_utc # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
value = cert.not_valid_after
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=timezone.utc)
|
||||
return value
|
||||
|
||||
|
||||
def build_ssl_context() -> ssl.SSLContext:
|
||||
cert_path, key_path, bundle_path = ensure_certificate()
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.minimum_version = ssl.TLSVersion.TLSv1_3
|
||||
context.load_cert_chain(certfile=str(bundle_path), keyfile=str(key_path))
|
||||
return context
|
||||
|
||||
|
||||
def certificate_paths() -> Tuple[str, str, str]:
|
||||
cert_path, key_path, bundle_path = ensure_certificate()
|
||||
return str(cert_path), str(key_path), str(bundle_path)
|
||||
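Illustrative usage (not part of the commit): a minimal sketch of serving TLS with build_ssl_context. The address and port are arbitrary examples; ensure_certificate runs implicitly and generates the CA and server material if it is missing.

import socket

from Data.Engine.services.crypto.certificates import build_ssl_context, certificate_paths

cert_file, key_file, bundle_file = certificate_paths()  # creates the files on first call
context = build_ssl_context()                           # TLS 1.3+ server context using the bundle

with socket.create_server(("127.0.0.1", 8443)) as listener:          # example port
    with context.wrap_socket(listener, server_side=True) as tls_listener:
        conn, addr = tls_listener.accept()                            # TLS handshake happens here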
75
Data/Engine/services/crypto/signing.py
Normal file
@@ -0,0 +1,75 @@
"""Script signing utilities for the Engine."""

from __future__ import annotations

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519

from Data.Engine.integrations.crypto.keys import base64_from_spki_der
from Data.Engine.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path

__all__ = ["ScriptSigner", "load_signer"]


_KEY_DIR = server_certificates_path("Code-Signing")
_SIGNING_KEY_FILE = _KEY_DIR / "engine-script-ed25519.key"
_SIGNING_PUB_FILE = _KEY_DIR / "engine-script-ed25519.pub"
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-script-ed25519.key"
_LEGACY_PUB_FILE = runtime_path("keys") / "borealis-script-ed25519.pub"


class ScriptSigner:
    def __init__(self, private_key: ed25519.Ed25519PrivateKey) -> None:
        self._private = private_key
        self._public = private_key.public_key()

    def sign(self, payload: bytes) -> bytes:
        return self._private.sign(payload)

    def public_spki_der(self) -> bytes:
        return self._public.public_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )

    def public_base64_spki(self) -> str:
        return base64_from_spki_der(self.public_spki_der())


def load_signer() -> ScriptSigner:
    private_key = _load_or_create()
    return ScriptSigner(private_key)


def _load_or_create() -> ed25519.Ed25519PrivateKey:
    ensure_server_certificates_dir("Code-Signing")

    if _SIGNING_KEY_FILE.exists():
        with _SIGNING_KEY_FILE.open("rb") as fh:
            return serialization.load_pem_private_key(fh.read(), password=None)

    if _LEGACY_KEY_FILE.exists():
        with _LEGACY_KEY_FILE.open("rb") as fh:
            return serialization.load_pem_private_key(fh.read(), password=None)

    private_key = ed25519.Ed25519PrivateKey.generate()
    pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    _KEY_DIR.mkdir(parents=True, exist_ok=True)
    _SIGNING_KEY_FILE.write_bytes(pem)
    try:
        if hasattr(_SIGNING_KEY_FILE, "chmod"):
            _SIGNING_KEY_FILE.chmod(0o600)
    except Exception:
        pass

    pub_der = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    _SIGNING_PUB_FILE.write_bytes(pub_der)

    return private_key
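Illustrative usage (not part of the commit): a minimal sketch of signing a payload with load_signer and verifying it against the exported SPKI public key. The payload bytes are a hypothetical example.

from cryptography.hazmat.primitives.serialization import load_der_public_key

from Data.Engine.services.crypto.signing import load_signer

signer = load_signer()                      # loads or creates the Ed25519 signing key
payload = b"example script body"            # hypothetical payload
signature = signer.sign(payload)

public_key = load_der_public_key(signer.public_spki_der())
public_key.verify(signature, payload)       # raises InvalidSignature if tampered with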
21
Data/Engine/services/enrollment/__init__.py
Normal file
@@ -0,0 +1,21 @@
"""Enrollment services for the Borealis Engine."""

from __future__ import annotations

from .enrollment_service import (
    EnrollmentRequestResult,
    EnrollmentService,
    EnrollmentStatus,
    EnrollmentTokenBundle,
    PollingResult,
)
from Data.Engine.domain.device_enrollment import EnrollmentValidationError

__all__ = [
    "EnrollmentRequestResult",
    "EnrollmentService",
    "EnrollmentStatus",
    "EnrollmentTokenBundle",
    "EnrollmentValidationError",
    "PollingResult",
]
487
Data/Engine/services/enrollment/enrollment_service.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""Enrollment workflow orchestration for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable, Optional, Protocol
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
from Data.Engine.builders.device_enrollment import EnrollmentRequestInput
|
||||
from Data.Engine.domain.device_auth import (
|
||||
DeviceFingerprint,
|
||||
DeviceGuid,
|
||||
sanitize_service_context,
|
||||
)
|
||||
from Data.Engine.domain.device_enrollment import (
|
||||
EnrollmentApproval,
|
||||
EnrollmentApprovalStatus,
|
||||
EnrollmentCode,
|
||||
EnrollmentValidationError,
|
||||
)
|
||||
from Data.Engine.services.auth.device_auth_service import DeviceRecord
|
||||
from Data.Engine.services.auth.token_service import JWTIssuer
|
||||
from Data.Engine.services.enrollment.nonce_cache import NonceCache
|
||||
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentService",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"PollingResult",
|
||||
]
|
||||
|
||||
|
||||
class DeviceRepository(Protocol):
|
||||
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def ensure_device_record(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
hostname: str,
|
||||
fingerprint: DeviceFingerprint,
|
||||
) -> DeviceRecord: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def record_device_key(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
fingerprint: DeviceFingerprint,
|
||||
added_at: datetime,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class EnrollmentRepository(Protocol):
|
||||
def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_install_code_usage(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
use_count_increment: int,
|
||||
last_used_at: datetime,
|
||||
used_by_guid: Optional[DeviceGuid] = None,
|
||||
mark_first_use: bool = False,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_pending_approval_by_fingerprint(
|
||||
self, fingerprint: DeviceFingerprint
|
||||
) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_pending_approval(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
hostname: str,
|
||||
guid: Optional[DeviceGuid],
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
updated_at: datetime,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def create_device_approval(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
reference: str,
|
||||
claimed_hostname: str,
|
||||
claimed_fingerprint: DeviceFingerprint,
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
created_at: datetime,
|
||||
status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> EnrollmentApproval: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_device_approval_status(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
status: EnrollmentApprovalStatus,
|
||||
updated_at: datetime,
|
||||
approved_by: Optional[str] = None,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class RefreshTokenRepository(Protocol):
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
guid: DeviceGuid,
|
||||
token_hash: str,
|
||||
created_at: datetime,
|
||||
expires_at: Optional[datetime],
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class ScriptSigner(Protocol):
|
||||
def public_base64_spki(self) -> str: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class EnrollmentStatus(str):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
EXPIRED = "expired"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentTokenBundle:
|
||||
guid: DeviceGuid
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_in: int
|
||||
token_type: str = "Bearer"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentRequestResult:
|
||||
status: EnrollmentStatus
|
||||
server_certificate: str
|
||||
signing_key: str
|
||||
approval_reference: Optional[str] = None
|
||||
server_nonce: Optional[str] = None
|
||||
poll_after_ms: Optional[int] = None
|
||||
http_status: int = 200
|
||||
retry_after: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PollingResult:
|
||||
status: EnrollmentStatus
|
||||
http_status: int
|
||||
poll_after_ms: Optional[int] = None
|
||||
reason: Optional[str] = None
|
||||
detail: Optional[str] = None
|
||||
tokens: Optional[EnrollmentTokenBundle] = None
|
||||
server_certificate: Optional[str] = None
|
||||
signing_key: Optional[str] = None
|
||||
|
||||
|
||||
class EnrollmentService:
|
||||
"""Coordinate the Borealis device enrollment handshake."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
device_repository: DeviceRepository,
|
||||
enrollment_repository: EnrollmentRepository,
|
||||
token_repository: RefreshTokenRepository,
|
||||
jwt_service: JWTIssuer,
|
||||
tls_bundle_loader: Callable[[], str],
|
||||
ip_rate_limiter: SlidingWindowRateLimiter,
|
||||
fingerprint_rate_limiter: SlidingWindowRateLimiter,
|
||||
nonce_cache: NonceCache,
|
||||
script_signer: Optional[ScriptSigner] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._devices = device_repository
|
||||
self._enrollment = enrollment_repository
|
||||
self._tokens = token_repository
|
||||
self._jwt = jwt_service
|
||||
self._load_tls_bundle = tls_bundle_loader
|
||||
self._ip_rate_limiter = ip_rate_limiter
|
||||
self._fp_rate_limiter = fingerprint_rate_limiter
|
||||
self._nonce_cache = nonce_cache
|
||||
self._signer = script_signer
|
||||
self._log = logger or logging.getLogger("borealis.engine.enrollment")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
def request_enrollment(
|
||||
self,
|
||||
payload: EnrollmentRequestInput,
|
||||
*,
|
||||
remote_addr: str,
|
||||
) -> EnrollmentRequestResult:
|
||||
context_hint = sanitize_service_context(payload.service_context)
|
||||
self._log.info(
|
||||
"enrollment-request ip=%s host=%s code_mask=%s", remote_addr, payload.hostname, self._mask_code(payload.enrollment_code)
|
||||
)
|
||||
|
||||
self._enforce_rate_limit(self._ip_rate_limiter, f"ip:{remote_addr}")
|
||||
self._enforce_rate_limit(self._fp_rate_limiter, f"fp:{payload.fingerprint.value}")
|
||||
|
||||
install_code = self._enrollment.fetch_install_code(payload.enrollment_code)
|
||||
reuse_guid = self._determine_reuse_guid(install_code, payload.fingerprint)
|
||||
|
||||
server_nonce_bytes = secrets.token_bytes(32)
|
||||
server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii")
|
||||
|
||||
now = self._now()
|
||||
approval = self._enrollment.fetch_pending_approval_by_fingerprint(payload.fingerprint)
|
||||
if approval:
|
||||
self._enrollment.update_pending_approval(
|
||||
approval.record_id,
|
||||
hostname=payload.hostname,
|
||||
guid=reuse_guid,
|
||||
enrollment_code_id=install_code.identifier if install_code else None,
|
||||
client_nonce_b64=payload.client_nonce_b64,
|
||||
server_nonce_b64=server_nonce_b64,
|
||||
agent_pubkey_der=payload.agent_public_key_der,
|
||||
updated_at=now,
|
||||
)
|
||||
approval_reference = approval.reference
|
||||
else:
|
||||
record_id = str(uuid.uuid4())
|
||||
approval_reference = str(uuid.uuid4())
|
||||
approval = self._enrollment.create_device_approval(
|
||||
record_id=record_id,
|
||||
reference=approval_reference,
|
||||
claimed_hostname=payload.hostname,
|
||||
claimed_fingerprint=payload.fingerprint,
|
||||
enrollment_code_id=install_code.identifier if install_code else None,
|
||||
client_nonce_b64=payload.client_nonce_b64,
|
||||
server_nonce_b64=server_nonce_b64,
|
||||
agent_pubkey_der=payload.agent_public_key_der,
|
||||
created_at=now,
|
||||
guid=reuse_guid,
|
||||
)
|
||||
|
||||
signing_key = self._signer.public_base64_spki() if self._signer else ""
|
||||
certificate = self._load_tls_bundle()
|
||||
|
||||
return EnrollmentRequestResult(
|
||||
status=EnrollmentStatus.PENDING,
|
||||
approval_reference=approval.reference,
|
||||
server_nonce=server_nonce_b64,
|
||||
poll_after_ms=3000,
|
||||
server_certificate=certificate,
|
||||
signing_key=signing_key,
|
||||
)
|
||||
|
||||
def poll_enrollment(
|
||||
self,
|
||||
*,
|
||||
approval_reference: str,
|
||||
client_nonce_b64: str,
|
||||
proof_signature_b64: str,
|
||||
) -> PollingResult:
|
||||
if not approval_reference:
|
||||
raise EnrollmentValidationError("approval_reference_required")
|
||||
if not client_nonce_b64:
|
||||
raise EnrollmentValidationError("client_nonce_required")
|
||||
if not proof_signature_b64:
|
||||
raise EnrollmentValidationError("proof_sig_required")
|
||||
|
||||
approval = self._enrollment.fetch_device_approval_by_reference(approval_reference)
|
||||
if approval is None:
|
||||
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=404)
|
||||
|
||||
client_nonce = self._decode_base64(client_nonce_b64, "invalid_client_nonce")
|
||||
server_nonce = self._decode_base64(approval.server_nonce_b64, "server_nonce_invalid")
|
||||
proof_sig = self._decode_base64(proof_signature_b64, "invalid_proof_sig")
|
||||
|
||||
if approval.client_nonce_b64 != client_nonce_b64:
|
||||
raise EnrollmentValidationError("nonce_mismatch")
|
||||
|
||||
self._verify_proof_signature(
|
||||
approval=approval,
|
||||
client_nonce=client_nonce,
|
||||
server_nonce=server_nonce,
|
||||
signature=proof_sig,
|
||||
)
|
||||
|
||||
status = approval.status
|
||||
if status is EnrollmentApprovalStatus.PENDING:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.PENDING,
|
||||
http_status=200,
|
||||
poll_after_ms=5000,
|
||||
)
|
||||
if status is EnrollmentApprovalStatus.DENIED:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.DENIED,
|
||||
http_status=200,
|
||||
reason="operator_denied",
|
||||
)
|
||||
if status is EnrollmentApprovalStatus.EXPIRED:
|
||||
return PollingResult(status=EnrollmentStatus.EXPIRED, http_status=200)
|
||||
if status is EnrollmentApprovalStatus.COMPLETED:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.APPROVED,
|
||||
http_status=200,
|
||||
detail="finalized",
|
||||
)
|
||||
if status is not EnrollmentApprovalStatus.APPROVED:
|
||||
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=400)
|
||||
|
||||
nonce_key = f"{approval.reference}:{proof_signature_b64}"
|
||||
if not self._nonce_cache.consume(nonce_key):
|
||||
raise EnrollmentValidationError("proof_replayed", http_status=409)
|
||||
|
||||
token_bundle = self._finalize_approval(approval)
|
||||
signing_key = self._signer.public_base64_spki() if self._signer else ""
|
||||
certificate = self._load_tls_bundle()
|
||||
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.APPROVED,
|
||||
http_status=200,
|
||||
tokens=token_bundle,
|
||||
server_certificate=certificate,
|
||||
signing_key=signing_key,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internals
|
||||
# ------------------------------------------------------------------
|
||||
def _enforce_rate_limit(
|
||||
self,
|
||||
limiter: SlidingWindowRateLimiter,
|
||||
key: str,
|
||||
*,
|
||||
limit: int = 60,
|
||||
window_seconds: float = 60.0,
|
||||
) -> None:
|
||||
decision = limiter.check(key, limit, window_seconds)
|
||||
if not decision.allowed:
|
||||
raise EnrollmentValidationError(
|
||||
"rate_limited", http_status=429, retry_after=max(decision.retry_after, 1.0)
|
||||
)
|
||||
|
||||
def _determine_reuse_guid(
|
||||
self,
|
||||
install_code: Optional[EnrollmentCode],
|
||||
fingerprint: DeviceFingerprint,
|
||||
) -> Optional[DeviceGuid]:
|
||||
if install_code is None:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
if install_code.is_expired:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
if not install_code.is_exhausted:
|
||||
return None
|
||||
if not install_code.used_by_guid:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
|
||||
existing = self._devices.fetch_by_guid(install_code.used_by_guid)
|
||||
if existing and existing.identity.fingerprint.value == fingerprint.value:
|
||||
return install_code.used_by_guid
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
|
||||
def _finalize_approval(self, approval: EnrollmentApproval) -> EnrollmentTokenBundle:
|
||||
now = self._now()
|
||||
effective_guid = approval.guid or DeviceGuid(str(uuid.uuid4()))
|
||||
device_record = self._devices.ensure_device_record(
|
||||
guid=effective_guid,
|
||||
hostname=approval.claimed_hostname,
|
||||
fingerprint=approval.claimed_fingerprint,
|
||||
)
|
||||
self._devices.record_device_key(
|
||||
guid=effective_guid,
|
||||
fingerprint=approval.claimed_fingerprint,
|
||||
added_at=now,
|
||||
)
|
||||
|
||||
if approval.enrollment_code_id:
|
||||
code = self._enrollment.fetch_install_code_by_id(approval.enrollment_code_id)
|
||||
if code is not None:
|
||||
mark_first = code.used_at is None
|
||||
self._enrollment.update_install_code_usage(
|
||||
approval.enrollment_code_id,
|
||||
use_count_increment=1,
|
||||
last_used_at=now,
|
||||
used_by_guid=effective_guid,
|
||||
mark_first_use=mark_first,
|
||||
)
|
||||
|
||||
refresh_token = secrets.token_urlsafe(48)
|
||||
refresh_id = str(uuid.uuid4())
|
||||
expires_at = now + timedelta(days=30)
|
||||
token_hash = hashlib.sha256(refresh_token.encode("utf-8")).hexdigest()
|
||||
self._tokens.create(
|
||||
record_id=refresh_id,
|
||||
guid=effective_guid,
|
||||
token_hash=token_hash,
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
access_token = self._jwt.issue_access_token(
|
||||
effective_guid.value,
|
||||
device_record.identity.fingerprint.value,
|
||||
max(device_record.token_version, 1),
|
||||
)
|
||||
|
||||
self._enrollment.update_device_approval_status(
|
||||
approval.record_id,
|
||||
status=EnrollmentApprovalStatus.COMPLETED,
|
||||
updated_at=now,
|
||||
guid=effective_guid,
|
||||
)
|
||||
|
||||
return EnrollmentTokenBundle(
|
||||
guid=effective_guid,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=900,
|
||||
)
|
||||
|
||||
def _verify_proof_signature(
|
||||
self,
|
||||
*,
|
||||
approval: EnrollmentApproval,
|
||||
client_nonce: bytes,
|
||||
server_nonce: bytes,
|
||||
signature: bytes,
|
||||
) -> None:
|
||||
message = server_nonce + approval.reference.encode("utf-8") + client_nonce
|
||||
try:
|
||||
public_key = serialization.load_der_public_key(approval.agent_pubkey_der)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError("agent_pubkey_invalid") from exc
|
||||
|
||||
try:
|
||||
public_key.verify(signature, message)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError("invalid_proof") from exc
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64(value: str, error_code: str) -> bytes:
|
||||
try:
|
||||
return base64.b64decode(value, validate=True)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError(error_code) from exc
|
||||
|
||||
@staticmethod
|
||||
def _mask_code(code: str) -> str:
|
||||
trimmed = (code or "").strip()
|
||||
if len(trimmed) <= 6:
|
||||
return "***"
|
||||
return f"{trimmed[:3]}***{trimmed[-3:]}"
|
||||
|
||||
@staticmethod
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
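Illustrative note (not part of the commit): _verify_proof_signature above calls public_key.verify(signature, message) with no hash or padding arguments, which implies an Ed25519 agent key. Under that assumption, an agent could produce the matching proof roughly as sketched below; the nonce and reference values are placeholders standing in for what request_enrollment returned.

import base64
import secrets

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519

agent_key = ed25519.Ed25519PrivateKey.generate()
agent_pubkey_der = agent_key.public_key().public_bytes(
    encoding=serialization.Encoding.DER,
    format=serialization.PublicFormat.SubjectPublicKeyInfo,
)  # submitted with the enrollment request

# Placeholders for values exchanged during request_enrollment:
server_nonce = secrets.token_bytes(32)            # returned base64-encoded as server_nonce
approval_reference = "approval-reference-uuid"    # hypothetical reference
client_nonce = secrets.token_bytes(32)            # the nonce the agent originally sent

# Same message layout the service verifies: server_nonce + reference + client_nonce.
message = server_nonce + approval_reference.encode("utf-8") + client_nonce
proof_signature_b64 = base64.b64encode(agent_key.sign(message)).decode("ascii")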
32
Data/Engine/services/enrollment/nonce_cache.py
Normal file
@@ -0,0 +1,32 @@
"""Nonce replay protection for enrollment workflows."""

from __future__ import annotations

import time
from threading import Lock
from typing import Dict

__all__ = ["NonceCache"]


class NonceCache:
    """Track recently observed nonces to prevent replay."""

    def __init__(self, ttl_seconds: float = 300.0) -> None:
        self._ttl = ttl_seconds
        self._entries: Dict[str, float] = {}
        self._lock = Lock()

    def consume(self, key: str) -> bool:
        """Consume *key* if it has not been seen recently."""

        now = time.monotonic()
        with self._lock:
            expiry = self._entries.get(key)
            if expiry and expiry > now:
                return False
            self._entries[key] = now + self._ttl
            stale = [nonce for nonce, ttl in self._entries.items() if ttl <= now]
            for nonce in stale:
                self._entries.pop(nonce, None)
            return True
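Illustrative usage (not part of the commit): NonceCache.consume accepts a key only once per TTL window, which is how poll_enrollment rejects replayed proof signatures.

from Data.Engine.services.enrollment.nonce_cache import NonceCache

cache = NonceCache(ttl_seconds=300.0)
assert cache.consume("approval-ref:proof-sig") is True   # first sighting is accepted
assert cache.consume("approval-ref:proof-sig") is False  # replay within the TTL is rejected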
8
Data/Engine/services/github/__init__.py
Normal file
@@ -0,0 +1,8 @@
"""GitHub-oriented services for the Borealis Engine."""

from __future__ import annotations

from .github_service import GitHubService, GitHubTokenPayload

__all__ = ["GitHubService", "GitHubTokenPayload"]
106
Data/Engine/services/github/github_service.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""GitHub service layer bridging repositories and integrations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
from Data.Engine.domain.github import GitHubRepoRef, GitHubTokenStatus, RepoHeadSnapshot
|
||||
from Data.Engine.integrations.github import GitHubArtifactProvider
|
||||
from Data.Engine.repositories.sqlite.github_repository import SQLiteGitHubRepository
|
||||
|
||||
__all__ = ["GitHubService", "GitHubTokenPayload"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GitHubTokenPayload:
|
||||
token: Optional[str]
|
||||
status: GitHubTokenStatus
|
||||
checked_at: int
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
payload = self.status.to_dict()
|
||||
payload.update(
|
||||
{
|
||||
"token": self.token or "",
|
||||
"checked_at": self.checked_at,
|
||||
}
|
||||
)
|
||||
return payload
|
||||
|
||||
|
||||
class GitHubService:
|
||||
"""Coordinate GitHub caching, verification, and persistence."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: SQLiteGitHubRepository,
|
||||
provider: GitHubArtifactProvider,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
clock: Optional[Callable[[], float]] = None,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._provider = provider
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.github")
|
||||
self._clock = clock or time.time
|
||||
self._token_cache: Optional[str] = None
|
||||
self._token_loaded_at: float = 0.0
|
||||
|
||||
initial_token = self._repository.load_token()
|
||||
self._apply_token(initial_token)
|
||||
|
||||
def get_repo_head(
|
||||
self,
|
||||
owner_repo: Optional[str],
|
||||
branch: Optional[str],
|
||||
*,
|
||||
ttl_seconds: int,
|
||||
force_refresh: bool = False,
|
||||
) -> RepoHeadSnapshot:
|
||||
repo_str = (owner_repo or self._provider.default_repo).strip()
|
||||
branch_name = (branch or self._provider.default_branch).strip()
|
||||
repo = GitHubRepoRef.parse(repo_str, branch_name)
|
||||
ttl = max(30, min(ttl_seconds, 3600))
|
||||
return self._provider.fetch_repo_head(repo, ttl_seconds=ttl, force_refresh=force_refresh)
|
||||
|
||||
def refresh_default_repo(self, *, force: bool = False) -> RepoHeadSnapshot:
|
||||
return self._provider.refresh_default_repo_head(force=force)
|
||||
|
||||
def get_token_status(self, *, force_refresh: bool = False) -> GitHubTokenPayload:
|
||||
token = self._load_token(force_refresh=force_refresh)
|
||||
status = self._provider.verify_token(token)
|
||||
return GitHubTokenPayload(token=token, status=status, checked_at=int(self._clock()))
|
||||
|
||||
def update_token(self, token: Optional[str]) -> GitHubTokenPayload:
|
||||
normalized = (token or "").strip()
|
||||
self._repository.store_token(normalized)
|
||||
self._apply_token(normalized)
|
||||
status = self._provider.verify_token(normalized)
|
||||
self._provider.start_background_refresh()
|
||||
self._log.info("github-token updated valid=%s", status.valid)
|
||||
return GitHubTokenPayload(token=normalized or None, status=status, checked_at=int(self._clock()))
|
||||
|
||||
def start_background_refresh(self) -> None:
|
||||
self._provider.start_background_refresh()
|
||||
|
||||
@property
|
||||
def default_refresh_interval(self) -> int:
|
||||
return self._provider.refresh_interval
|
||||
|
||||
def _load_token(self, *, force_refresh: bool = False) -> Optional[str]:
|
||||
now = self._clock()
|
||||
if not force_refresh and self._token_cache is not None and (now - self._token_loaded_at) < 15.0:
|
||||
return self._token_cache
|
||||
|
||||
token = self._repository.load_token()
|
||||
self._apply_token(token)
|
||||
return token
|
||||
|
||||
def _apply_token(self, token: Optional[str]) -> None:
|
||||
self._token_cache = (token or "").strip() or None
|
||||
self._token_loaded_at = self._clock()
|
||||
self._provider.set_token(self._token_cache)
|
||||
|
||||
5
Data/Engine/services/jobs/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""Job-related services for the Borealis Engine."""

from .scheduler_service import SchedulerService

__all__ = ["SchedulerService"]
373
Data/Engine/services/jobs/scheduler_service.py
Normal file
373
Data/Engine/services/jobs/scheduler_service.py
Normal file
@@ -0,0 +1,373 @@
"""Background scheduler service for the Borealis Engine."""

from __future__ import annotations

import calendar
import logging
import threading
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable, Mapping, Optional

from Data.Engine.builders.job_fabricator import JobFabricator, JobManifest
from Data.Engine.repositories.sqlite.job_repository import (
    ScheduledJobRecord,
    ScheduledJobRunRecord,
    SQLiteJobRepository,
)

__all__ = ["SchedulerService", "SchedulerRuntime"]


def _now_ts() -> int:
    return int(time.time())


def _floor_minute(ts: int) -> int:
    ts = int(ts or 0)
    return ts - (ts % 60)


def _parse_ts(val: Any) -> Optional[int]:
    if val is None:
        return None
    if isinstance(val, (int, float)):
        return int(val)
    try:
        s = str(val).strip().replace("Z", "+00:00")
        return int(datetime.fromisoformat(s).timestamp())
    except Exception:
        return None


def _parse_expiration(raw: Optional[str]) -> Optional[int]:
    if not raw or raw == "no_expire":
        return None
    try:
        s = raw.strip().lower()
        unit = s[-1]
        value = int(s[:-1])
        if unit == "m":
            return value * 60
        if unit == "h":
            return value * 3600
        if unit == "d":
            return value * 86400
        return int(s) * 60
    except Exception:
        return None


def _add_months(dt: datetime, months: int) -> datetime:
    year = dt.year + (dt.month - 1 + months) // 12
    month = ((dt.month - 1 + months) % 12) + 1
    last_day = calendar.monthrange(year, month)[1]
    day = min(dt.day, last_day)
    return dt.replace(year=year, month=month, day=day)


def _add_years(dt: datetime, years: int) -> datetime:
    year = dt.year + years
    month = dt.month
    last_day = calendar.monthrange(year, month)[1]
    day = min(dt.day, last_day)
    return dt.replace(year=year, month=month, day=day)


def _compute_next_run(
    schedule_type: str,
    start_ts: Optional[int],
    last_run_ts: Optional[int],
    now_ts: int,
) -> Optional[int]:
    st = (schedule_type or "immediately").strip().lower()
    start_floor = _floor_minute(start_ts) if start_ts else None
    last_floor = _floor_minute(last_run_ts) if last_run_ts else None
    now_floor = _floor_minute(now_ts)

    if st == "immediately":
        return None if last_floor else now_floor
    if st == "once":
        if not start_floor:
            return None
        return start_floor if not last_floor else None
    if not start_floor:
        return None

    last = last_floor if last_floor is not None else None
    if st in {
        "every_5_minutes",
        "every_10_minutes",
        "every_15_minutes",
        "every_30_minutes",
        "every_hour",
    }:
        period_map = {
            "every_5_minutes": 5 * 60,
            "every_10_minutes": 10 * 60,
            "every_15_minutes": 15 * 60,
            "every_30_minutes": 30 * 60,
            "every_hour": 60 * 60,
        }
        period = period_map.get(st)
        candidate = (last + period) if last else start_floor
        while candidate is not None and candidate <= now_floor - 1:
            candidate += period
        return candidate
    if st == "daily":
        period = 86400
        candidate = (last + period) if last else start_floor
        while candidate is not None and candidate <= now_floor - 1:
            candidate += period
        return candidate
    if st == "weekly":
        period = 7 * 86400
        candidate = (last + period) if last else start_floor
        while candidate is not None and candidate <= now_floor - 1:
            candidate += period
        return candidate
    if st == "monthly":
        base = datetime.utcfromtimestamp(last) if last else datetime.utcfromtimestamp(start_floor)
        candidate = _add_months(base, 1 if last else 0)
        while int(candidate.timestamp()) <= now_floor - 1:
            candidate = _add_months(candidate, 1)
        return int(candidate.timestamp())
    if st == "yearly":
        base = datetime.utcfromtimestamp(last) if last else datetime.utcfromtimestamp(start_floor)
        candidate = _add_years(base, 1 if last else 0)
        while int(candidate.timestamp()) <= now_floor - 1:
            candidate = _add_years(candidate, 1)
        return int(candidate.timestamp())
    return None


@dataclass
class SchedulerRuntime:
    thread: Optional[threading.Thread]
    stop_event: threading.Event


class SchedulerService:
    """Evaluate and dispatch scheduled jobs using Engine repositories."""

    def __init__(
        self,
        *,
        job_repository: SQLiteJobRepository,
        assemblies_root: Path,
        logger: Optional[logging.Logger] = None,
        poll_interval: int = 30,
    ) -> None:
        self._jobs = job_repository
        self._fabricator = JobFabricator(assemblies_root=assemblies_root, logger=logger)
        self._log = logger or logging.getLogger("borealis.engine.scheduler")
        self._poll_interval = max(5, poll_interval)
        self._socketio: Optional[Any] = None
        self._runtime = SchedulerRuntime(thread=None, stop_event=threading.Event())

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def start(self, socketio: Optional[Any] = None) -> None:
        self._socketio = socketio
        if self._runtime.thread and self._runtime.thread.is_alive():
            return
        self._runtime.stop_event.clear()
        thread = threading.Thread(target=self._run_loop, name="borealis-engine-scheduler", daemon=True)
        thread.start()
        self._runtime.thread = thread
        self._log.info("scheduler-started")

    def stop(self) -> None:
        self._runtime.stop_event.set()
        thread = self._runtime.thread
        if thread and thread.is_alive():
            thread.join(timeout=5)
        self._log.info("scheduler-stopped")

    # ------------------------------------------------------------------
    # HTTP orchestration helpers
    # ------------------------------------------------------------------
    def list_jobs(self) -> list[dict[str, Any]]:
        return [self._serialize_job(job) for job in self._jobs.list_jobs()]

    def get_job(self, job_id: int) -> Optional[dict[str, Any]]:
        record = self._jobs.fetch_job(job_id)
        return self._serialize_job(record) if record else None

    def create_job(self, payload: Mapping[str, Any]) -> dict[str, Any]:
        fields = self._normalize_payload(payload)
        record = self._jobs.create_job(**fields)
        return self._serialize_job(record)

    def update_job(self, job_id: int, payload: Mapping[str, Any]) -> Optional[dict[str, Any]]:
        fields = self._normalize_payload(payload)
        record = self._jobs.update_job(job_id, **fields)
        return self._serialize_job(record) if record else None

    def toggle_job(self, job_id: int, enabled: bool) -> None:
        self._jobs.set_enabled(job_id, enabled)

    def delete_job(self, job_id: int) -> None:
        self._jobs.delete_job(job_id)

    def list_runs(self, job_id: int, *, days: Optional[int] = None) -> list[dict[str, Any]]:
        runs = self._jobs.list_runs(job_id, days=days)
        return [self._serialize_run(run) for run in runs]

    def purge_runs(self, job_id: int) -> None:
        self._jobs.purge_runs(job_id)

    # ------------------------------------------------------------------
    # Scheduling loop
    # ------------------------------------------------------------------
    def tick(self, *, now_ts: Optional[int] = None) -> None:
        self._evaluate_jobs(now_ts=now_ts or _now_ts())

    def _run_loop(self) -> None:
        stop_event = self._runtime.stop_event
        while not stop_event.wait(timeout=self._poll_interval):
            try:
                self._evaluate_jobs(now_ts=_now_ts())
            except Exception as exc:  # pragma: no cover - safety net
                self._log.exception("scheduler-loop-error: %s", exc)

    def _evaluate_jobs(self, *, now_ts: int) -> None:
        for job in self._jobs.list_enabled_jobs():
            try:
                self._evaluate_job(job, now_ts=now_ts)
            except Exception as exc:
                self._log.exception("job-evaluation-error job_id=%s error=%s", job.id, exc)

    def _evaluate_job(self, job: ScheduledJobRecord, *, now_ts: int) -> None:
        last_run = self._jobs.fetch_last_run(job.id)
        last_ts = None
        if last_run:
            last_ts = last_run.scheduled_ts or last_run.started_ts or last_run.created_at
        next_run = _compute_next_run(job.schedule_type, job.start_ts, last_ts, now_ts)
        if next_run is None or next_run > now_ts:
            return

        expiration_window = _parse_expiration(job.expiration)
        if expiration_window and job.start_ts:
            if job.start_ts + expiration_window <= now_ts:
                self._log.info(
                    "job-expired",
                    extra={"job_id": job.id, "start_ts": job.start_ts, "expiration": job.expiration},
                )
                return

        manifest = self._fabricator.build(job, occurrence_ts=next_run)
        targets = manifest.targets or ("<unassigned>",)
        for target in targets:
            run_id = self._jobs.create_run(job.id, next_run, target_hostname=None if target == "<unassigned>" else target)
            self._jobs.mark_run_started(run_id, started_ts=now_ts)
            self._emit_run_event("job_run_started", job, run_id, target, manifest)
            self._jobs.mark_run_finished(run_id, status="Success", finished_ts=now_ts)
            self._emit_run_event("job_run_completed", job, run_id, target, manifest)

    def _emit_run_event(
        self,
        event: str,
        job: ScheduledJobRecord,
        run_id: int,
        target: str,
        manifest: JobManifest,
    ) -> None:
        payload = {
            "job_id": job.id,
            "run_id": run_id,
            "target": target,
            "schedule_type": job.schedule_type,
            "occurrence_ts": manifest.occurrence_ts,
        }
        if self._socketio is not None:
            try:
                self._socketio.emit(event, payload)  # type: ignore[attr-defined]
            except Exception:
                self._log.debug("socketio-emit-failed event=%s payload=%s", event, payload)

    # ------------------------------------------------------------------
    # Serialization helpers
    # ------------------------------------------------------------------
    def _serialize_job(self, job: ScheduledJobRecord) -> dict[str, Any]:
        return {
            "id": job.id,
            "name": job.name,
            "components": job.components,
            "targets": job.targets,
            "schedule": {
                "type": job.schedule_type,
                "start": job.start_ts,
            },
            "schedule_type": job.schedule_type,
            "start_ts": job.start_ts,
            "duration_stop_enabled": job.duration_stop_enabled,
            "expiration": job.expiration or "no_expire",
            "execution_context": job.execution_context,
            "credential_id": job.credential_id,
            "use_service_account": job.use_service_account,
            "enabled": job.enabled,
            "created_at": job.created_at,
            "updated_at": job.updated_at,
        }

    def _serialize_run(self, run: ScheduledJobRunRecord) -> dict[str, Any]:
        return {
            "id": run.id,
            "job_id": run.job_id,
            "scheduled_ts": run.scheduled_ts,
            "started_ts": run.started_ts,
            "finished_ts": run.finished_ts,
            "status": run.status,
            "error": run.error,
            "target_hostname": run.target_hostname,
            "created_at": run.created_at,
            "updated_at": run.updated_at,
        }

    def _normalize_payload(self, payload: Mapping[str, Any]) -> dict[str, Any]:
        name = str(payload.get("name") or "").strip()
        components = payload.get("components") or []
        targets = payload.get("targets") or []
        schedule_block = payload.get("schedule") if isinstance(payload.get("schedule"), Mapping) else {}
        schedule_type = str(schedule_block.get("type") or payload.get("schedule_type") or "immediately").strip().lower()
        start_value = schedule_block.get("start") if isinstance(schedule_block, Mapping) else None
        if start_value is None:
            start_value = payload.get("start")
        start_ts = _parse_ts(start_value)
        duration_block = payload.get("duration") if isinstance(payload.get("duration"), Mapping) else {}
        duration_stop = bool(duration_block.get("stopAfterEnabled") or payload.get("duration_stop_enabled"))
        expiration = str(duration_block.get("expiration") or payload.get("expiration") or "no_expire").strip()
        execution_context = str(payload.get("execution_context") or "system").strip().lower()
        credential_id = payload.get("credential_id")
        try:
            credential_id = int(credential_id) if credential_id is not None else None
        except Exception:
            credential_id = None
        use_service_account_raw = payload.get("use_service_account")
        use_service_account = bool(use_service_account_raw) if execution_context == "winrm" else False
        enabled = bool(payload.get("enabled", True))

        if not name:
            raise ValueError("job name is required")
        if not isinstance(components, Iterable) or not list(components):
            raise ValueError("at least one component is required")
        if not isinstance(targets, Iterable) or not list(targets):
            raise ValueError("at least one target is required")

        return {
            "name": name,
            "components": list(components),
            "targets": list(targets),
            "schedule_type": schedule_type,
            "start_ts": start_ts,
            "duration_stop_enabled": duration_stop,
            "expiration": expiration,
            "execution_context": execution_context,
            "credential_id": credential_id,
            "use_service_account": use_service_account,
            "enabled": enabled,
        }
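A minimal sketch of the schedule math added above, assuming the Borealis repository root is on sys.path so the Data.Engine package resolves; it exercises the module-private helpers purely for illustration.

# Illustrative only: not part of the diff. Timestamps are invented values.
import time

from Data.Engine.services.jobs.scheduler_service import _compute_next_run, _floor_minute

now = int(time.time())
start = now - 3 * 86400   # the job was first scheduled three days ago
last = now - 86400        # it last ran one day ago

# A daily job advances in whole-day steps from the last run, floored to the minute.
assert _compute_next_run("daily", start, last, now) == _floor_minute(last) + 86400

# An "immediately" job fires once: with a prior run recorded it yields None.
assert _compute_next_run("immediately", None, last, now) is None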
45
Data/Engine/services/rate_limit.py
Normal file
45
Data/Engine/services/rate_limit.py
Normal file
@@ -0,0 +1,45 @@
"""In-process rate limiting utilities for the Borealis Engine."""

from __future__ import annotations

import time
from collections import deque
from dataclasses import dataclass
from threading import Lock
from typing import Deque, Dict

__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"]


@dataclass(frozen=True, slots=True)
class RateLimitDecision:
    """Result of a rate limit check."""

    allowed: bool
    retry_after: float


class SlidingWindowRateLimiter:
    """Tiny in-memory sliding window limiter suitable for single-process use."""

    def __init__(self) -> None:
        self._buckets: Dict[str, Deque[float]] = {}
        self._lock = Lock()

    def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
        now = time.monotonic()
        with self._lock:
            bucket = self._buckets.get(key)
            if bucket is None:
                bucket = deque()
                self._buckets[key] = bucket

            while bucket and now - bucket[0] > window_seconds:
                bucket.popleft()

            if len(bucket) >= limit:
                retry_after = max(0.0, window_seconds - (now - bucket[0]))
                return RateLimitDecision(False, retry_after)

            bucket.append(now)
            return RateLimitDecision(True, 0.0)
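A short usage sketch for the limiter above; the key string and the 5-per-minute budget are illustrative values, not settings taken from Borealis.

# Illustrative only: not part of the diff.
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter

limiter = SlidingWindowRateLimiter()

for attempt in range(7):
    decision = limiter.check("device-auth:10.0.0.5", limit=5, window_seconds=60.0)
    if not decision.allowed:
        # Attempts 6 and 7 exceed the window and report how long to back off.
        print(f"attempt {attempt + 1} rejected, retry in {decision.retry_after:.1f}s")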
10
Data/Engine/services/realtime/__init__.py
Normal file
10
Data/Engine/services/realtime/__init__.py
Normal file
@@ -0,0 +1,10 @@
"""Realtime coordination services for the Borealis Engine."""

from __future__ import annotations

from .agent_registry import AgentRealtimeService, AgentRecord

__all__ = [
    "AgentRealtimeService",
    "AgentRecord",
]
301
Data/Engine/services/realtime/agent_registry.py
Normal file
301
Data/Engine/services/realtime/agent_registry.py
Normal file
@@ -0,0 +1,301 @@
from __future__ import annotations

import logging
import time
from dataclasses import dataclass
from typing import Any, Dict, Mapping, Optional, Tuple

from Data.Engine.repositories.sqlite import SQLiteDeviceRepository

__all__ = ["AgentRealtimeService", "AgentRecord"]


@dataclass(slots=True)
class AgentRecord:
    """In-memory representation of a connected agent."""

    agent_id: str
    hostname: str = "unknown"
    agent_operating_system: str = "-"
    last_seen: int = 0
    status: str = "orphaned"
    service_mode: str = "currentuser"
    is_script_agent: bool = False
    collector_active_ts: Optional[float] = None


class AgentRealtimeService:
    """Track realtime agent presence and provide persistence hooks."""

    def __init__(
        self,
        *,
        device_repository: SQLiteDeviceRepository,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        self._device_repository = device_repository
        self._log = logger or logging.getLogger("borealis.engine.services.realtime.agents")
        self._agents: Dict[str, AgentRecord] = {}
        self._configs: Dict[str, Dict[str, Any]] = {}
        self._screenshots: Dict[str, Dict[str, Any]] = {}
        self._task_screenshots: Dict[Tuple[str, str], Dict[str, Any]] = {}

    # ------------------------------------------------------------------
    # Agent presence management
    # ------------------------------------------------------------------
    def register_connection(self, agent_id: str, service_mode: Optional[str]) -> AgentRecord:
        record = self._agents.get(agent_id) or AgentRecord(agent_id=agent_id)
        mode = self.normalize_service_mode(service_mode, agent_id)
        now = int(time.time())

        record.service_mode = mode
        record.is_script_agent = self._is_script_agent(agent_id)
        record.last_seen = now
        record.status = "provisioned" if agent_id in self._configs else "orphaned"
        self._agents[agent_id] = record

        self._persist_activity(
            hostname=record.hostname,
            last_seen=record.last_seen,
            agent_id=agent_id,
            operating_system=record.agent_operating_system,
        )
        return record

    def heartbeat(self, payload: Mapping[str, Any]) -> Optional[AgentRecord]:
        if not payload:
            return None

        agent_id = payload.get("agent_id")
        if not agent_id:
            return None

        hostname = payload.get("hostname") or ""
        mode = self.normalize_service_mode(payload.get("service_mode"), agent_id)
        is_script_agent = self._is_script_agent(agent_id)
        last_seen = self._coerce_int(payload.get("last_seen"), default=int(time.time()))
        operating_system = (payload.get("agent_operating_system") or "-").strip() or "-"

        if hostname:
            self._reconcile_hostname_collisions(
                hostname=hostname,
                agent_id=agent_id,
                incoming_mode=mode,
                is_script_agent=is_script_agent,
                last_seen=last_seen,
            )

        record = self._agents.get(agent_id) or AgentRecord(agent_id=agent_id)
        if hostname:
            record.hostname = hostname
        record.agent_operating_system = operating_system
        record.last_seen = last_seen
        record.service_mode = mode
        record.is_script_agent = is_script_agent
        record.status = "provisioned" if agent_id in self._configs else record.status or "orphaned"
        self._agents[agent_id] = record

        self._persist_activity(
            hostname=record.hostname or hostname,
            last_seen=record.last_seen,
            agent_id=agent_id,
            operating_system=record.agent_operating_system,
        )
        return record

    def collector_status(self, payload: Mapping[str, Any]) -> None:
        if not payload:
            return

        agent_id = payload.get("agent_id")
        if not agent_id:
            return

        hostname = payload.get("hostname") or ""
        mode = self.normalize_service_mode(payload.get("service_mode"), agent_id)
        active = bool(payload.get("active"))
        last_user = (payload.get("last_user") or "").strip()

        record = self._agents.get(agent_id) or AgentRecord(agent_id=agent_id)
        if hostname:
            record.hostname = hostname
        if mode:
            record.service_mode = mode
        record.is_script_agent = self._is_script_agent(agent_id) or record.is_script_agent
        if active:
            record.collector_active_ts = time.time()
        self._agents[agent_id] = record

        if (
            last_user
            and hostname
            and self._is_valid_interactive_user(last_user)
            and not self._is_system_service_agent(agent_id, record.is_script_agent)
        ):
            self._persist_activity(
                hostname=hostname,
                last_seen=int(time.time()),
                agent_id=agent_id,
                operating_system=record.agent_operating_system,
                last_user=last_user,
            )

    # ------------------------------------------------------------------
    # Configuration management
    # ------------------------------------------------------------------
    def set_agent_config(self, agent_id: str, config: Mapping[str, Any]) -> None:
        self._configs[agent_id] = dict(config)
        record = self._agents.get(agent_id)
        if record:
            record.status = "provisioned"

    def get_agent_config(self, agent_id: str) -> Optional[Dict[str, Any]]:
        config = self._configs.get(agent_id)
        if config is None:
            return None
        return dict(config)

    # ------------------------------------------------------------------
    # Screenshot caches
    # ------------------------------------------------------------------
    def store_agent_screenshot(self, agent_id: str, image_base64: str) -> None:
        self._screenshots[agent_id] = {
            "image_base64": image_base64,
            "timestamp": time.time(),
        }

    def store_task_screenshot(self, agent_id: str, node_id: str, image_base64: str) -> None:
        self._task_screenshots[(agent_id, node_id)] = {
            "image_base64": image_base64,
            "timestamp": time.time(),
        }

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def normalize_service_mode(value: Optional[str], agent_id: Optional[str] = None) -> str:
        text = ""
        try:
            if isinstance(value, str):
                text = value.strip().lower()
        except Exception:
            text = ""

        if not text and agent_id:
            try:
                lowered = agent_id.lower()
                if "-svc-" in lowered or lowered.endswith("-svc"):
                    return "system"
            except Exception:
                pass

        if text in {"system", "svc", "service", "system_service"}:
            return "system"
        if text in {"interactive", "currentuser", "user", "current_user"}:
            return "currentuser"
        return "currentuser"

    @staticmethod
    def _coerce_int(value: Any, *, default: int = 0) -> int:
        try:
            return int(value)
        except Exception:
            return default

    @staticmethod
    def _is_script_agent(agent_id: Optional[str]) -> bool:
        try:
            return bool(isinstance(agent_id, str) and agent_id.lower().endswith("-script"))
        except Exception:
            return False

    @staticmethod
    def _is_valid_interactive_user(candidate: Optional[str]) -> bool:
        if not candidate:
            return False
        try:
            text = str(candidate).strip()
        except Exception:
            return False
        if not text:
            return False
        upper = text.upper()
        if text.endswith("$"):
            return False
        if "NT AUTHORITY\\" in upper or "NT SERVICE\\" in upper:
            return False
        if upper.endswith("\\SYSTEM"):
            return False
        if upper.endswith("\\LOCAL SERVICE"):
            return False
        if upper.endswith("\\NETWORK SERVICE"):
            return False
        if upper == "ANONYMOUS LOGON":
            return False
        return True

    @staticmethod
    def _is_system_service_agent(agent_id: str, is_script_agent: bool) -> bool:
        try:
            lowered = agent_id.lower()
        except Exception:
            lowered = ""
        if is_script_agent:
            return False
        return "-svc-" in lowered or lowered.endswith("-svc")

    def _reconcile_hostname_collisions(
        self,
        *,
        hostname: str,
        agent_id: str,
        incoming_mode: str,
        is_script_agent: bool,
        last_seen: int,
    ) -> None:
        transferred_config = False
        for existing_id, info in list(self._agents.items()):
            if existing_id == agent_id:
                continue
            if info.hostname != hostname:
                continue
            existing_mode = self.normalize_service_mode(info.service_mode, existing_id)
            if existing_mode != incoming_mode:
                continue
            if is_script_agent and not info.is_script_agent:
                self._persist_activity(
                    hostname=hostname,
                    last_seen=last_seen,
                    agent_id=existing_id,
                    operating_system=info.agent_operating_system,
                )
                return
            if not transferred_config and existing_id in self._configs and agent_id not in self._configs:
                self._configs[agent_id] = dict(self._configs[existing_id])
                transferred_config = True
            self._agents.pop(existing_id, None)
            if existing_id != agent_id:
                self._configs.pop(existing_id, None)

    def _persist_activity(
        self,
        *,
        hostname: Optional[str],
        last_seen: Optional[int],
        agent_id: Optional[str],
        operating_system: Optional[str],
        last_user: Optional[str] = None,
    ) -> None:
        if not hostname:
            return
        try:
            self._device_repository.update_device_summary(
                hostname=hostname,
                last_seen=last_seen,
                agent_id=agent_id,
                operating_system=operating_system,
                last_user=last_user,
            )
        except Exception as exc:  # pragma: no cover - defensive logging
            self._log.debug("failed to persist device activity for %s: %s", hostname, exc)
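A quick sanity sketch of the service-mode normalisation rules above; it needs no device repository because normalize_service_mode is a staticmethod, and the agent id shown is an invented example.

# Illustrative only: not part of the diff.
from Data.Engine.services.realtime.agent_registry import AgentRealtimeService

assert AgentRealtimeService.normalize_service_mode("SYSTEM_SERVICE") == "system"
assert AgentRealtimeService.normalize_service_mode("CurrentUser") == "currentuser"
# With no explicit mode, an agent id ending in "-svc" falls back to system mode.
assert AgentRealtimeService.normalize_service_mode(None, "HOST01-svc") == "system"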
1
Data/Engine/tests/__init__.py
Normal file
1
Data/Engine/tests/__init__.py
Normal file
@@ -0,0 +1 @@
"""Test suite for the Borealis Engine."""
74
Data/Engine/tests/test_builders_device_auth.py
Normal file
74
Data/Engine/tests/test_builders_device_auth.py
Normal file
@@ -0,0 +1,74 @@
import unittest

from Data.Engine.builders.device_auth import (
    DeviceAuthRequestBuilder,
    RefreshTokenRequestBuilder,
)
from Data.Engine.domain.device_auth import DeviceAuthErrorCode, DeviceAuthFailure


class DeviceAuthRequestBuilderTests(unittest.TestCase):
    def test_build_successful_request(self) -> None:
        request = (
            DeviceAuthRequestBuilder()
            .with_authorization("Bearer abc123")
            .with_http_method("post")
            .with_htu("https://example.test/api")
            .with_service_context("currentUser")
            .with_dpop_proof("proof")
            .build()
        )

        self.assertEqual(request.access_token, "abc123")
        self.assertEqual(request.http_method, "POST")
        self.assertEqual(request.htu, "https://example.test/api")
        self.assertEqual(request.service_context, "CURRENTUSER")
        self.assertEqual(request.dpop_proof, "proof")

    def test_missing_authorization_raises_failure(self) -> None:
        builder = (
            DeviceAuthRequestBuilder()
            .with_http_method("GET")
            .with_htu("/health")
        )

        with self.assertRaises(DeviceAuthFailure) as ctx:
            builder.build()

        self.assertEqual(ctx.exception.code, DeviceAuthErrorCode.MISSING_AUTHORIZATION)


class RefreshTokenRequestBuilderTests(unittest.TestCase):
    def test_refresh_request_requires_all_fields(self) -> None:
        request = (
            RefreshTokenRequestBuilder()
            .with_payload({"guid": "de305d54-75b4-431b-adb2-eb6b9e546014", "refresh_token": "tok"})
            .with_http_method("post")
            .with_htu("https://example.test/api")
            .with_dpop_proof("proof")
            .build()
        )

        self.assertEqual(request.guid.value, "DE305D54-75B4-431B-ADB2-EB6B9E546014")
        self.assertEqual(request.refresh_token, "tok")
        self.assertEqual(request.http_method, "POST")
        self.assertEqual(request.htu, "https://example.test/api")
        self.assertEqual(request.dpop_proof, "proof")

    def test_refresh_request_missing_guid_raises_failure(self) -> None:
        builder = (
            RefreshTokenRequestBuilder()
            .with_payload({"refresh_token": "tok"})
            .with_http_method("POST")
            .with_htu("https://example.test/api")
        )

        with self.assertRaises(DeviceAuthFailure) as ctx:
            builder.build()

        self.assertEqual(ctx.exception.code, DeviceAuthErrorCode.INVALID_CLAIMS)
        self.assertIn("missing guid", ctx.exception.detail)


if __name__ == "__main__":  # pragma: no cover - convenience for local runs
    unittest.main()
44
Data/Engine/tests/test_config_environment.py
Normal file
44
Data/Engine/tests/test_config_environment.py
Normal file
@@ -0,0 +1,44 @@
"""Tests for environment configuration helpers."""

from __future__ import annotations

from Data.Engine.config.environment import load_environment


def test_static_root_prefers_engine_runtime(tmp_path, monkeypatch):
    """Engine static root should prefer the staged web-interface build."""

    engine_build = tmp_path / "Engine" / "web-interface" / "build"
    engine_build.mkdir(parents=True)
    (engine_build / "index.html").write_text("<html></html>", encoding="utf-8")

    # Ensure other fallbacks exist but should not be selected while the Engine
    # runtime assets are present.
    legacy_build = tmp_path / "Data" / "Server" / "WebUI" / "build"
    legacy_build.mkdir(parents=True)
    (legacy_build / "index.html").write_text("legacy", encoding="utf-8")

    monkeypatch.setenv("BOREALIS_ROOT", str(tmp_path))
    monkeypatch.delenv("BOREALIS_STATIC_ROOT", raising=False)

    settings = load_environment()

    assert settings.flask.static_root == engine_build.resolve()


def test_static_root_env_override(tmp_path, monkeypatch):
    """Explicit overrides should win over filesystem detection."""

    override = tmp_path / "custom" / "build"
    override.mkdir(parents=True)
    (override / "index.html").write_text("override", encoding="utf-8")

    monkeypatch.setenv("BOREALIS_ROOT", str(tmp_path))
    monkeypatch.setenv("BOREALIS_STATIC_ROOT", str(override))

    settings = load_environment()

    assert settings.flask.static_root == override.resolve()

    monkeypatch.delenv("BOREALIS_STATIC_ROOT", raising=False)
    monkeypatch.delenv("BOREALIS_ROOT", raising=False)
65
Data/Engine/tests/test_crypto_certificates.py
Normal file
65
Data/Engine/tests/test_crypto_certificates.py
Normal file
@@ -0,0 +1,65 @@
from __future__ import annotations

import importlib
import os
import shutil
import ssl
import sys
import tempfile
import unittest
from pathlib import Path

from Data.Engine import runtime


class CertificateGenerationTests(unittest.TestCase):
    def setUp(self) -> None:
        self._tmpdir = Path(tempfile.mkdtemp(prefix="engine-cert-tests-"))
        self.addCleanup(lambda: shutil.rmtree(self._tmpdir, ignore_errors=True))

        self._previous_env: dict[str, str | None] = {}
        for name in ("BOREALIS_CERTIFICATES_ROOT", "BOREALIS_SERVER_CERT_ROOT"):
            self._previous_env[name] = os.environ.get(name)
            os.environ[name] = str(self._tmpdir / name.lower())

        runtime.certificates_root.cache_clear()
        runtime.server_certificates_root.cache_clear()

        module_name = "Data.Engine.services.crypto.certificates"
        if module_name in sys.modules:
            del sys.modules[module_name]

        try:
            self.certificates = importlib.import_module(module_name)
        except ModuleNotFoundError as exc:  # pragma: no cover - optional deps absent
            self.skipTest(f"cryptography dependency unavailable: {exc}")

    def tearDown(self) -> None:  # pragma: no cover - environment cleanup
        for name, value in self._previous_env.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
        runtime.certificates_root.cache_clear()
        runtime.server_certificates_root.cache_clear()

    def test_ensure_certificate_creates_material(self) -> None:
        cert_path, key_path, bundle_path = self.certificates.ensure_certificate()

        self.assertTrue(cert_path.exists(), "certificate was not generated")
        self.assertTrue(key_path.exists(), "private key was not generated")
        self.assertTrue(bundle_path.exists(), "bundle was not generated")

        context = self.certificates.build_ssl_context()
        self.assertIsInstance(context, ssl.SSLContext)
        self.assertEqual(context.minimum_version, ssl.TLSVersion.TLSv1_3)

    def test_certificate_paths_returns_strings(self) -> None:
        cert_path, key_path, bundle_path = self.certificates.certificate_paths()
        self.assertIsInstance(cert_path, str)
        self.assertIsInstance(key_path, str)
        self.assertIsInstance(bundle_path, str)


if __name__ == "__main__":  # pragma: no cover - convenience
    unittest.main()
59
Data/Engine/tests/test_domain_device_auth.py
Normal file
59
Data/Engine/tests/test_domain_device_auth.py
Normal file
@@ -0,0 +1,59 @@
import unittest

from Data.Engine.domain.device_auth import (
    DeviceAuthErrorCode,
    DeviceAuthFailure,
    DeviceFingerprint,
    DeviceGuid,
    sanitize_service_context,
)


class DeviceGuidTests(unittest.TestCase):
    def test_guid_normalization_accepts_braces_and_lowercase(self) -> None:
        guid = DeviceGuid("{de305d54-75b4-431b-adb2-eb6b9e546014}")
        self.assertEqual(guid.value, "DE305D54-75B4-431B-ADB2-EB6B9E546014")

    def test_guid_rejects_empty_string(self) -> None:
        with self.assertRaises(ValueError):
            DeviceGuid("")


class DeviceFingerprintTests(unittest.TestCase):
    def test_fingerprint_normalization_trims_and_lowercases(self) -> None:
        fingerprint = DeviceFingerprint(" AA:BB:CC ")
        self.assertEqual(fingerprint.value, "aa:bb:cc")

    def test_fingerprint_rejects_blank_input(self) -> None:
        with self.assertRaises(ValueError):
            DeviceFingerprint(" ")


class ServiceContextTests(unittest.TestCase):
    def test_sanitize_service_context_returns_uppercase_only(self) -> None:
        self.assertEqual(sanitize_service_context("system"), "SYSTEM")

    def test_sanitize_service_context_filters_invalid_chars(self) -> None:
        self.assertEqual(sanitize_service_context("sys tem!"), "SYSTEM")

    def test_sanitize_service_context_returns_none_for_empty_result(self) -> None:
        self.assertIsNone(sanitize_service_context("@@@"))


class DeviceAuthFailureTests(unittest.TestCase):
    def test_to_dict_includes_retry_after_and_detail(self) -> None:
        failure = DeviceAuthFailure(
            DeviceAuthErrorCode.RATE_LIMITED,
            http_status=429,
            retry_after=30,
            detail="too many attempts",
        )
        payload = failure.to_dict()
        self.assertEqual(
            payload,
            {"error": "rate_limited", "retry_after": 30.0, "detail": "too many attempts"},
        )


if __name__ == "__main__":  # pragma: no cover - convenience for local runs
    unittest.main()
32
Data/Engine/tests/test_sqlite_migrations.py
Normal file
32
Data/Engine/tests/test_sqlite_migrations.py
Normal file
@@ -0,0 +1,32 @@
import sqlite3
import unittest

from Data.Engine.repositories.sqlite import migrations


class MigrationTests(unittest.TestCase):
    def test_apply_all_creates_expected_tables(self) -> None:
        conn = sqlite3.connect(":memory:")
        try:
            migrations.apply_all(conn)
            cursor = conn.cursor()
            tables = {
                row[0]
                for row in cursor.execute(
                    "SELECT name FROM sqlite_master WHERE type='table'"
                )
            }

            self.assertIn("devices", tables)
            self.assertIn("refresh_tokens", tables)
            self.assertIn("enrollment_install_codes", tables)
            self.assertIn("device_approvals", tables)
            self.assertIn("scheduled_jobs", tables)
            self.assertIn("scheduled_job_runs", tables)
            self.assertIn("github_token", tables)
        finally:
            conn.close()


if __name__ == "__main__":  # pragma: no cover - convenience for local runs
    unittest.main()