diff --git a/Borealis.ps1 b/Borealis.ps1
index ef3cb12..e911197 100644
--- a/Borealis.ps1
+++ b/Borealis.ps1
@@ -977,6 +977,8 @@ if (-not $choice) {
     Write-Host "[Experimental]" -ForegroundColor Red
     Write-Host " 4) Package Self-Contained EXE of Server or Agent " -NoNewline -ForegroundColor DarkGray
     Write-Host "[Experimental]" -ForegroundColor Red
+    Write-Host " 5) Borealis Engine " -NoNewline -ForegroundColor DarkGray
+    Write-Host "[In Development]" -ForegroundColor Cyan
     # (Removed) AutoHotKey experimental testing
     Write-Host "Type a number and press " -NoNewLine
     Write-Host "<ENTER>" -ForegroundColor DarkCyan
@@ -1214,6 +1216,162 @@ switch ($choice) {
         }
     }
 
+    "5" {
+        $host.UI.RawUI.WindowTitle = "Borealis Engine"
+        Write-Host "Ensuring Engine Dependencies Exist..." -ForegroundColor DarkCyan
+
+        Install_Shared_Dependencies
+        Install_Server_Dependencies
+
+        foreach ($tool in @($pythonExe, $nodeExe, $npmCmd, $npxCmd)) {
+            if (-not (Test-Path $tool)) {
+                Write-Host "`r$($symbols.Fail) Bundled executable not found at '$tool'." -ForegroundColor Red
+                exit 1
+            }
+        }
+        $env:PATH = '{0};{1};{2}' -f (Split-Path $pythonExe), (Split-Path $nodeExe), $env:PATH
+
+        Write-Host " "
+        Write-Host "Configure Borealis Engine Mode:" -ForegroundColor DarkYellow
+        Write-Host " 1) Build & Launch > Production Flask Server @ http://localhost:5001" -ForegroundColor DarkCyan
+        Write-Host " 2) [Skip Build] & Immediately Launch > Production Flask Server @ http://localhost:5001" -ForegroundColor DarkCyan
+        Write-Host " 3) Launch > [Hotload-Ready] Vite Dev Server @ http://localhost:5173" -ForegroundColor DarkCyan
+        $engineModeChoice = Read-Host "Enter choice [1/2/3]"
+
+        $engineOperationMode = "production"
+        $engineImmediateLaunch = $false
+        switch ($engineModeChoice) {
+            "1" { $engineOperationMode = "production" }
+            "2" { $engineImmediateLaunch = $true }
+            "3" { $engineOperationMode = "developer" }
+            default {
+                Write-Host "Invalid mode choice: $engineModeChoice" -ForegroundColor Red
+                break
+            }
+        }
+
+        if ($engineModeChoice -notin @('1','2','3')) {
+            break
+        }
+
+        if ($engineImmediateLaunch) {
+            Run-Step "Borealis Engine: Launch Flask Server" {
+                Push-Location (Join-Path $scriptDir "Engine")
+                $py = Join-Path $scriptDir "Engine\Scripts\python.exe"
+                Write-Host "`nLaunching Borealis Engine..." -ForegroundColor Green
+                Write-Host "===================================================================================="
+                Write-Host "$($symbols.Running) Engine Socket Server Started..."
+                & $py -m Data.Engine.bootstrapper
+                Pop-Location
+            }
+            break
+        }
+
+        Write-Host "Deploying Borealis Engine in '$engineOperationMode' mode" -ForegroundColor Blue
+
+        $venvFolder = "Engine"
+        $dataSource = "Data"
+        $engineSource = "$dataSource\Engine"
+        $engineDataDestination = "$venvFolder\Data\Engine"
+        $webUIFallbackSource = "$dataSource\Server\WebUI"
+        $webUIDestination = "$venvFolder\web-interface"
+        $venvPython = Join-Path $venvFolder 'Scripts\python.exe'
+        $engineSourceAbsolute = Join-Path $scriptDir $engineSource
+        $webUIFallbackAbsolute = Join-Path $scriptDir $webUIFallbackSource
+
+        Run-Step "Create Borealis Engine Virtual Python Environment" {
+            $venvActivate = Join-Path $venvFolder 'Scripts\Activate'
+            if (-not (Test-Path $venvActivate)) {
+                & $pythonExe -m venv $venvFolder | Out-Null
+            }
+
+            $engineDataRoot = Join-Path $venvFolder 'Data'
+            if (-not (Test-Path $engineDataRoot)) {
+                New-Item -Path $engineDataRoot -ItemType Directory -Force | Out-Null
+            }
+
+            if (Test-Path (Join-Path $scriptDir $engineDataDestination)) {
+                Remove-Item (Join-Path $scriptDir $engineDataDestination) -Recurse -Force -ErrorAction SilentlyContinue
+            }
+            New-Item -Path (Join-Path $scriptDir $engineDataDestination) -ItemType Directory -Force | Out-Null
+
+            if (-not (Test-Path $engineSourceAbsolute)) {
+                throw "Engine source directory '$engineSourceAbsolute' not found."
+            }
+            Copy-Item (Join-Path $engineSourceAbsolute '*') (Join-Path $scriptDir $engineDataDestination) -Recurse -Force
+
+            . (Join-Path $venvFolder 'Scripts\Activate')
+        }
+
+        Run-Step "Install Engine Python Dependencies into Virtual Python Environment" {
+            $engineRequirements = @(
+                (Join-Path $engineSourceAbsolute 'engine-requirements.txt'),
+                (Join-Path $engineSourceAbsolute 'requirements.txt')
+            )
+            $requirementsPath = $engineRequirements | Where-Object { Test-Path $_ } | Select-Object -First 1
+            if ($requirementsPath) {
+                & $venvPython -m pip install --disable-pip-version-check -q -r $requirementsPath | Out-Null
+            }
+        }
+
+        Run-Step "Copy Borealis Engine WebUI Files into: $webUIDestination" {
+            $engineWebUISource = Join-Path $engineSourceAbsolute 'WebUI'
+            if (Test-Path $engineWebUISource) {
+                $webUIDestinationAbsolute = Join-Path $scriptDir $webUIDestination
+                if (Test-Path $webUIDestinationAbsolute) {
+                    Remove-Item (Join-Path $webUIDestinationAbsolute '*') -Recurse -Force -ErrorAction SilentlyContinue
+                } else {
+                    New-Item -Path $webUIDestinationAbsolute -ItemType Directory -Force | Out-Null
+                }
+                Copy-Item (Join-Path $engineWebUISource '*') $webUIDestinationAbsolute -Recurse -Force
+            } elseif (-not (Test-Path (Join-Path $scriptDir $webUIDestination)) -or -not (Get-ChildItem -Path (Join-Path $scriptDir $webUIDestination) -ErrorAction SilentlyContinue | Select-Object -First 1)) {
+                if (Test-Path $webUIFallbackAbsolute) {
+                    $webUIDestinationAbsolute = Join-Path $scriptDir $webUIDestination
+                    if (-not (Test-Path $webUIDestinationAbsolute)) {
+                        New-Item -Path $webUIDestinationAbsolute -ItemType Directory -Force | Out-Null
+                    }
+                    Copy-Item (Join-Path $webUIFallbackAbsolute '*') $webUIDestinationAbsolute -Recurse -Force
+                } else {
+                    Write-Host "Fallback WebUI source not found at '$webUIFallbackAbsolute'." -ForegroundColor Yellow
+                }
+            } else {
+                Write-Host "Existing Engine web interface detected; skipping fallback copy." -ForegroundColor DarkYellow
+            }
+        }
+
+        Run-Step "Vite Web Frontend: Install NPM Packages" {
+            $webUIDestinationAbsolute = Join-Path $scriptDir $webUIDestination
+            if (Test-Path $webUIDestinationAbsolute) {
+                Push-Location $webUIDestinationAbsolute
+                $env:npm_config_loglevel = "silent"
+                & $npmCmd install --silent --no-fund --audit=false | Out-Null
+                Pop-Location
+            } else {
+                Write-Host "Web interface destination '$webUIDestinationAbsolute' not found." -ForegroundColor Yellow
+            }
+        }
+
+        Run-Step "Vite Web Frontend: Start ($engineOperationMode)" {
+            $webUIDestinationAbsolute = Join-Path $scriptDir $webUIDestination
+            if (Test-Path $webUIDestinationAbsolute) {
+                Push-Location $webUIDestinationAbsolute
+                if ($engineOperationMode -eq "developer") { $viteSubCommand = "dev" } else { $viteSubCommand = "build" }
+                Start-Process -NoNewWindow -FilePath $npmCmd -ArgumentList @("run", $viteSubCommand)
+                Pop-Location
+            }
+        }
+
+        Run-Step "Borealis Engine: Launch Flask Server" {
+            Push-Location (Join-Path $scriptDir "Engine")
+            $py = Join-Path $scriptDir "Engine\Scripts\python.exe"
+            Write-Host "`nLaunching Borealis Engine..." -ForegroundColor Green
+            Write-Host "===================================================================================="
+            Write-Host "$($symbols.Running) Engine Socket Server Started..."
+            & $py -m Data.Engine.bootstrapper
+            Pop-Location
+        }
+    }
+
     # (Removed) case "6" experimental AHK test
     default { Write-Host "Invalid selection. Exiting..." -ForegroundColor Red; exit 1 }
diff --git a/Data/Engine/CURRENT_STAGE.md b/Data/Engine/CURRENT_STAGE.md
new file mode 100644
index 0000000..043d18b
--- /dev/null
+++ b/Data/Engine/CURRENT_STAGE.md
@@ -0,0 +1,67 @@
+# Borealis Engine Migration Progress
+
+[COMPLETED] 1. Stabilize Engine foundation (baseline already in repo)
+- 1.1 Confirm `Data/Engine/bootstrapper.py` launches the placeholder Flask app without side effects.
+- 1.2 Document environment variables/settings expected by the Engine to keep parity with legacy defaults.
+- 1.3 Verify Engine logging produces `Logs/Server/engine.log` entries alongside the legacy server.
+
+[COMPLETED] 2. Introduce configuration & dependency wiring
+- 2.1 Create `config/environment.py` loaders mirroring legacy defaults (TLS paths, feature flags).
+- 2.2 Add settings dataclasses for Flask, Socket.IO, and DB paths; inject them via `server.py`.
+- 2.3 Commit once the Engine can start with equivalent config but no real routes.
+
+[COMPLETED] 3. Copy Flask application scaffolding
+- 3.1 Port proxy/CORS/static setup from `Data/Server/server.py` into Engine `server.py` using dependency injection.
+- 3.2 Stub out blueprint/Socket.IO registration hooks that mirror names from legacy code (no logic yet).
+- 3.3 Smoke-test app startup via `python Data/Engine/bootstrapper.py` (or Flask CLI) to ensure no regressions.
+
+[COMPLETED] 4. Establish SQLite infrastructure
+- 4.1 Copy `_db_conn` logic into `repositories/sqlite/connection.py`, parameterized by database path (`<project root>/database.db`).
+- 4.2 Port migration helpers into `repositories/sqlite/migrations.py`; expose an `apply_all()` callable.
+- 4.3 Wire migrations to run during Engine bootstrap (behind a flag) and confirm tables initialize in a sandbox DB.
+- 4.4 Commit once DB connection + migrations succeed independently of legacy server.
+
+[COMPLETED] 5. Extract authentication/enrollment domain surface
+- 5.1 Define immutable dataclasses in `domain/device_auth.py`, `domain/device_enrollment.py` for tokens, GUIDs, approvals.
+- 5.2 Map legacy error codes/enums into domain exceptions or enums in the same modules.
+- 5.3 Commit after unit tests (or doctests) validate dataclass invariants.
+
+[COMPLETED] 6. Port authentication services
+- 6.1 Copy `DeviceAuthManager` logic into `services/auth/device_auth_service.py`, refactoring to use new repositories and domain types.
+- 6.2 Create `builders/device_auth.py` to assemble `DeviceAuthContext` from headers/DPoP proof.
+- 6.3 Mirror refresh token issuance into `services/auth/token_service.py`; use `builders/device_enrollment.py` for payload assembly.
+- 6.4 Commit once services pass targeted unit tests and integrate with placeholder repositories.
+
+[COMPLETED] 7. Implement SQLite repositories
+- 7.1 Introduce `repositories/sqlite/device_repository.py`, `token_repository.py`, `enrollment_repository.py` using copied SQL.
+- 7.2 Write integration tests exercising CRUD against a temporary SQLite file.
+- 7.3 Commit when repositories provide the required ports used by services.
+
+[COMPLETED] 8. Recreate HTTP interfaces
+- 8.1 Port health/enrollment/token blueprints into `interfaces/http/<feature>/routes.py`, calling Engine services only.
+- 8.2 Ensure request validation occurs via builders; response schemas stay aligned with legacy JSON.
+- 8.3 Register blueprints through Engine `server.py`; confirm endpoints respond via manual or automated tests.
+- 8.4 Commit after each major blueprint migration for clear milestones.
+
+[COMPLETED] 9. Rebuild WebSocket interfaces
+- 9.1 Establish feature-scoped modules (e.g., `interfaces/ws/agents/events.py`) and copy event handlers.
+- 9.2 Replace global state with repository/service calls where feasible; otherwise encapsulate in Engine-managed caches.
+- 9.3 Validate namespace registration with Socket.IO test clients before committing.
+
+[COMPLETED] 10. Scheduler & job management
+- 10.1 Port scheduler core into `services/jobs/scheduler_service.py`; wrap job state persistence via new repositories.
+- 10.2 Implement `builders/job_fabricator.py` for manifest assembly; ensure immutability and validation.
+- 10.3 Expose HTTP orchestration via `interfaces/http/job_management.py` and WS notifications via dedicated modules.
+- 10.4 Commit after scheduler can run a no-op job loop independently.
+
+[COMPLETED] 11. GitHub integration
+- 11.1 Copy GitHub helper logic into `integrations/github/artifact_provider.py` with proper configuration injection.
+- 11.2 Provide repository/service hooks for fetching artifacts or repo heads; add resilience logging.
+- 11.3 Commit after integration tests (or mocked unit tests) confirm API workflows.
+
+[IN PROGRESS] 12. Final parity verification
+- 12.1 Follow the staging playbook in `Data/Engine/STAGING_GUIDE.md` to stand up the Engine end-to-end and exercise enrollment, token refresh, agent connections, GitHub integration, and scheduler flows.
+- 12.2 Record any divergences in the staging guide’s table and address them with follow-up commits before cut-over.
+- 12.3 Once parity is confirmed, coordinate entrypoint switching (point deployment at `Data/Engine/bootstrapper.py`) and plan the legacy server deprecation.
+- Supporting documentation and unit tests live in `Data/Engine/README.md`, `Data/Engine/STAGING_GUIDE.md`, and `Data/Engine/tests/` to guide the remaining staging work.
+- Engine deployment now installs dependencies via `Data/Engine/requirements.txt` so parity runs include Flask, Socket.IO, and security packages.
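+
+For quick spot-checks while working Step 12, the unit suite referenced above can also be run in-process rather than from the CLI. A minimal sketch using only the standard library (it assumes `Data/Engine/tests/` is discoverable from the project root, matching the documented `python -m unittest discover Data/Engine/tests` command):
+
+```python
+# Minimal sketch: run the Engine unit suite in-process, equivalent to
+# `python -m unittest discover Data/Engine/tests` from the project root.
+import unittest
+
+suite = unittest.defaultTestLoader.discover("Data/Engine/tests")
+result = unittest.TextTestRunner(verbosity=1).run(suite)
+raise SystemExit(0 if result.wasSuccessful() else 1)
+```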
diff --git a/Data/Engine/README.md b/Data/Engine/README.md
new file mode 100644
index 0000000..b83cb85
--- /dev/null
+++ b/Data/Engine/README.md
@@ -0,0 +1,204 @@
+# Borealis Engine Overview
+
+The Engine is an additive server stack that will ultimately replace the legacy Flask app under `Data/Server`. It is safe to run the Engine entrypoint (`Data/Engine/bootstrapper.py`) side-by-side with the legacy server while we migrate functionality feature-by-feature.
+
+## Architectural roles
+
+The Engine is organized around explicit dependency layers so each concern stays testable and replaceable:
+
+- **Configuration (`Data/Engine/config/`)** parses environment variables into immutable settings objects that the bootstrapper hands to factories and integrations.
+- **Builders (`Data/Engine/builders/`)** transform external inputs (HTTP headers, JSON payloads, scheduled job definitions) into validated immutable records that services can trust.
+- **Domain models (`Data/Engine/domain/`)** house pure value objects, enums, and error types with no I/O so services can express intent without depending on Flask or SQLite.
+- **Repositories (`Data/Engine/repositories/`)** encapsulate all SQLite access and expose protocol methods that return domain models. They are injected into services through the container so persistence can be swapped or mocked.
+- **Services (`Data/Engine/services/`)** host business logic such as device authentication, enrollment, job scheduling, GitHub artifact lookups, and real-time agent coordination. Services depend only on repositories, integrations, and builders.
+- **Integrations (`Data/Engine/integrations/`)** wrap external systems (GitHub today) and keep HTTP/token handling outside the services that consume them.
+- **Interfaces (`Data/Engine/interfaces/`)** provide thin HTTP/Socket.IO adapters that translate requests to builder/service calls and serialize responses. They contain no business rules of their own.
+
+The runtime factory (`Data/Engine/runtime.py`) wires these layers together and attaches the resulting container to the Flask app created in `Data/Engine/server.py`.
+
+## Environment configuration
+
+The Engine mirrors the legacy defaults so it can boot without additional configuration. These environment variables are read by `Data/Engine/config/environment.py`:
+
+| Variable | Purpose | Default |
+| --- | --- | --- |
+| `BOREALIS_ROOT` | Overrides automatic project root detection. Useful when running from a packaged location. | Directory two levels above `Data/Engine/` |
+| `BOREALIS_DATABASE_PATH` | Path to the SQLite database. | `<project root>/database.db` |
+| `BOREALIS_ENGINE_AUTO_MIGRATE` | Run Engine-managed schema migrations during bootstrap (`true`/`false`). | `true` |
+| `BOREALIS_STATIC_ROOT` | Directory that serves static assets for the SPA. | First existing path among `Engine/web-interface/build`, `Engine/web-interface/dist`, `Data/Engine/WebUI/build`, `Data/Server/web-interface/build`, `Data/Server/WebUI/build`, `Data/WebUI/build` |
+| `BOREALIS_CORS_ALLOWED_ORIGINS` | Comma-delimited list of origins granted CORS access. Use `*` for all origins. | `*` |
+| `BOREALIS_FLASK_SECRET_KEY` | Secret key for Flask session signing. | `change-me` |
+| `BOREALIS_DEBUG` | Enables debug logging, disables secure-cookie requirements, and allows Werkzeug debug mode. | `false` |
+| `BOREALIS_HOST` | Bind address for the HTTP/Socket.IO server. | `127.0.0.1` |
+| `BOREALIS_PORT` | Bind port for the HTTP/Socket.IO server. | `5000` |
+| `BOREALIS_REPO` | Default GitHub repository (`owner/name`) for artifact lookups. | `bunny-lab-io/Borealis` |
+| `BOREALIS_REPO_BRANCH` | Default branch tracked by the Engine GitHub integration. | `main` |
+| `BOREALIS_REPO_HASH_REFRESH` | Seconds between default repository head refresh attempts (clamped 30-3600). | `60` |
+| `BOREALIS_CACHE_DIR` | Directory used to persist Engine cache files (GitHub repo head cache). | `<project root>/Data/Engine/cache` |
+| `BOREALIS_CERTIFICATES_ROOT` | Overrides where TLS certificates (root CA + leaf) are stored. | `<project root>/Certificates` |
+| `BOREALIS_SERVER_CERT_ROOT` | Directly points to the Engine server certificate directory if certificates are staged elsewhere. | `<project root>/Certificates/Server` |
+
+## TLS and transport stack
+
+`Data/Engine/services/crypto/certificates.py` mirrors the legacy certificate generator so the Engine always serves HTTPS with a self-managed root CA and leaf certificate. During bootstrap the Engine:
+
+1. Runs the certificate helper to ensure the root CA, server key, and bundle exist under `Certificates/Server/` (or the configured override path).
+2. Exposes the resulting bundle via `BOREALIS_TLS_BUNDLE` so enrollment flows can deliver the pinned certificate to agents.
+3. Launches Socket.IO/Eventlet with the generated cert/key pair. A fallback to Werkzeug’s TLS support keeps HTTPS available even if Socket.IO is disabled.
+
+`Data/Engine/interfaces/eventlet_compat.py` applies the same Eventlet monkey patch as the legacy server so TLS handshakes presented to the HTTP listener are handled quietly instead of surfacing `400 Bad Request` noise when non-TLS clients connect.
+
+## Logging expectations
+
+`Data/Engine/config/logging.py` configures a timed rotating file handler that writes to `Logs/Server/engine.log`. Each entry follows the `<timestamp>-engine-<message>` format required by the project logging policy. The handler is attached to both the Engine logger (`borealis.engine`) and the root logger so that third-party frameworks share the same log destination.
+
+## Bootstrapping flow
+
+1. `Data/Engine/bootstrapper.py` loads the environment, configures logging, prepares the SQLite connection factory, optionally applies schema migrations, and builds the Flask application via `Data/Engine/server.py`.
+2. A service container is assembled (`Data/Engine/services/container.py`) that wires repositories, JWT/DPoP helpers, and Engine services (device auth, token refresh, enrollment). The container is stored on the Flask app for interface modules to consume.
+3. HTTP and Socket.IO interfaces register against the new service container. The resulting runtime object exposes the Flask app, resolved settings, optional Socket.IO server, and the configured database connection factory. `bootstrapper.main()` runs the appropriate server based on whether Socket.IO is present.
+
+As migration continues, services, repositories, interfaces, and integrations will live under their respective subpackages while maintaining isolation from the legacy server.
+
+## Python dependencies
+
+`Data/Engine/requirements.txt` mirrors the minimal runtime stack (Flask, Flask-SocketIO, CORS, requests, PyJWT, and cryptography) needed by the Engine entrypoint. The PowerShell launcher consumes this file when preparing the `Engine/` virtual environment so parity tests always run against an environment with the expected web and security packages preinstalled.
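+
+Building on the bootstrapping flow above, the runtime can also be driven programmatically for embedding or smoke tests. A minimal sketch, grounded in `Data/Engine/bootstrapper.py` and the `/health` route described below (note that `bootstrap()` provisions TLS material and, by default, applies migrations as side effects):
+
+```python
+# Minimal sketch: build the Engine runtime without binding a socket.
+from Data.Engine.bootstrapper import bootstrap
+
+runtime = bootstrap()  # wires settings, repositories, services, and the Flask app
+print(runtime.settings.server.host, runtime.settings.server.port)
+
+# Exercise the liveness probe in-process via Flask's test client.
+with runtime.app.test_client() as client:
+    print(client.get("/health").status_code)
+```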
+
+## HTTP interfaces
+
+The Engine now exposes working HTTP routes alongside the remaining scaffolding:
+
+- `Data/Engine/interfaces/http/health.py` implements `GET /health` for liveness probes.
+- `Data/Engine/interfaces/http/tokens.py` ports the refresh-token endpoint (`POST /api/agent/token/refresh`) using the Engine `TokenService` and request builders.
+- `Data/Engine/interfaces/http/enrollment.py` handles the enrollment handshake (`/api/agent/enroll/request` and `/api/agent/enroll/poll`) with rate limiting, nonce protection, and repository-backed approvals.
+- The admin and agent blueprints remain placeholders until their services migrate.
+
+## WebSocket interfaces
+
+Step 9 introduces real-time handlers backed by the new service container:
+
+- `Data/Engine/services/realtime/agent_registry.py` manages connected-agent state, last-seen persistence, collector updates, and screenshot caches without sharing globals with the legacy server.
+- `Data/Engine/interfaces/ws/agents/events.py` ports the agent namespace, handling connect/disconnect logging, heartbeat reconciliation, screenshot relays, macro status broadcasts, and provisioning lookups through the realtime service.
+- `Data/Engine/interfaces/ws/job_management/events.py` now forwards scheduler updates and responds to job status requests, keeping WebSocket clients informed as new runs are simulated.
+
+The WebSocket factory (`Data/Engine/interfaces/ws/__init__.py`) now accepts the Engine service container so namespaces can resolve dependencies just like their HTTP counterparts.
+
+## Authentication services
+
+Step 6 introduces the first real Engine services:
+
+- `Data/Engine/builders/device_auth.py` normalizes headers for access-token authentication and token refresh payloads.
+- `Data/Engine/builders/device_enrollment.py` prepares enrollment payloads and nonce proof challenges for future migration steps.
+- `Data/Engine/services/auth/device_auth_service.py` ports the legacy `DeviceAuthManager` into a repository-driven service that emits `DeviceAuthContext` instances from the new domain layer.
+- `Data/Engine/services/auth/token_service.py` issues refreshed access tokens while enforcing DPoP bindings and repository lookups.
+
+Interfaces now consume these services via the shared container, keeping business logic inside the Engine service layer while HTTP modules remain thin request/response translators.
+
+## SQLite repositories
+
+Step 7 ports the first persistence adapters into the Engine:
+
+- `Data/Engine/repositories/sqlite/device_repository.py` exposes `SQLiteDeviceRepository`, mirroring the legacy device lookups and automatic record recovery used during authentication.
+- `Data/Engine/repositories/sqlite/token_repository.py` provides `SQLiteRefreshTokenRepository` for refresh-token validation, DPoP binding management, and usage timestamps.
+- `Data/Engine/repositories/sqlite/enrollment_repository.py` surfaces enrollment install-code counters and device approval records so future services can operate without touching raw SQL.
+
+Each repository accepts the shared `SQLiteConnectionFactory`, keeping all SQL execution confined to the Engine layer while services depend only on protocol interfaces.
+
+## Job scheduling services
+
+Step 10 migrates the foundational job scheduler into the Engine:
+
+- `Data/Engine/builders/job_fabricator.py` transforms stored job definitions into immutable manifests, decoding scripts, resolving environment variables, and preparing execution metadata.
+- `Data/Engine/repositories/sqlite/job_repository.py` encapsulates scheduled job persistence, run history, and status tracking in SQLite.
+- `Data/Engine/services/jobs/scheduler_service.py` runs the background evaluation loop, emits Socket.IO lifecycle events, and exposes CRUD helpers for the HTTP and WebSocket interfaces.
+- `Data/Engine/interfaces/http/job_management.py` mirrors the legacy REST surface for creating, updating, toggling, and inspecting scheduled jobs and their run history.
+
+The scheduler service starts automatically from `Data/Engine/bootstrapper.py` once the Engine runtime builds the service container, ensuring a no-op scheduling loop executes independently of the legacy server.
+
+## GitHub integration
+
+Step 11 migrates the GitHub artifact provider into the Engine:
+
+- `Data/Engine/integrations/github/artifact_provider.py` caches branch head lookups, verifies API tokens, and optionally refreshes the default repository in the background.
+- `Data/Engine/repositories/sqlite/github_repository.py` persists the GitHub API token so HTTP handlers do not speak to SQLite directly.
+- `Data/Engine/services/github/github_service.py` coordinates token caching, verification, and repo head lookups for both HTTP and background refresh flows.
+- `Data/Engine/interfaces/http/github.py` exposes `/api/repo/current_hash` and `/api/github/token` through the Engine stack while keeping business logic in the service layer.
+
+The service container now wires `github_service`, giving other interfaces and background jobs a clean entry point for GitHub functionality.
+
+## Final parity checklist
+
+Step 12 tracks the final integration work required before switching over to the Engine entrypoint. Use the detailed playbook in [`Data/Engine/STAGING_GUIDE.md`](./STAGING_GUIDE.md) to coordinate each staging run:
+
+1. Stand up the Engine in a staging environment and exercise enrollment, token refresh, scheduler operations, and the agent real-time channel side-by-side with the legacy server.
+2. Capture any behavioural differences uncovered during staging using the divergence table in the staging guide and file them for follow-up fixes before the cut-over.
+3. When satisfied with parity, coordinate the entrypoint swap (point production tooling at `Data/Engine/bootstrapper.py`) and plan the deprecation of `Data/Server`.
+
+## Performing unit tests
+
+Targeted unit tests cover the most important domain, builder, repository, and migration behaviours without requiring Flask or external services. Run them with the standard library test runner:
+
+```bash
+python -m unittest discover Data/Engine/tests
+```
+
+The suite currently validates:
+
+- Domain normalization helpers for GUIDs, fingerprints, and authentication failures.
+- Device authentication and refresh-token builders, including error handling for malformed requests.
+- SQLite schema migrations to ensure the Engine can provision required tables in a fresh database.
+- TLS certificate provisioning helpers to guarantee HTTPS material exists before the Engine starts serving requests.
+
+Successful execution prints a summary similar to:
+
+```
+.............
+----------------------------------------------------------------------
+Ran 13 tests in .s
+
+OK
+```
+
+Additional tests should follow the same pattern and live under `Data/Engine/tests/` so this command remains the single entry point for Engine unit verification.
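+
+As a shape reference for new suites, a test module under `Data/Engine/tests/` might look like the sketch below; the builder and domain imports come from this change set, while the file and test names are illustrative only:
+
+```python
+# Illustrative module shape (hypothetical file): Data/Engine/tests/test_device_auth_builder.py
+import unittest
+
+from Data.Engine.builders.device_auth import DeviceAuthRequestBuilder
+from Data.Engine.domain.device_auth import DeviceAuthFailure
+
+
+class DeviceAuthRequestBuilderTests(unittest.TestCase):
+    def test_missing_authorization_header_is_rejected(self) -> None:
+        builder = (
+            DeviceAuthRequestBuilder()
+            .with_authorization(None)
+            .with_http_method("POST")
+            .with_htu("https://localhost:5000/api/agent/token/refresh")
+        )
+        # build() parses the Authorization header first, so a missing
+        # bearer token surfaces as a DeviceAuthFailure.
+        with self.assertRaises(DeviceAuthFailure):
+            builder.build()
+
+
+if __name__ == "__main__":
+    unittest.main()
+```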
diff --git a/Data/Engine/STAGING_GUIDE.md b/Data/Engine/STAGING_GUIDE.md
new file mode 100644
index 0000000..60ae3b5
--- /dev/null
+++ b/Data/Engine/STAGING_GUIDE.md
@@ -0,0 +1,116 @@
+# Engine Staging & Parity Guide
+
+This guide supports Step 12 of the migration plan by walking operators through standing up the Engine alongside the legacy server, validating core workflows, and documenting any behavioural gaps before switching the production entrypoint to `Data/Engine/bootstrapper.py`.
+
+## 1. Prerequisites
+
+- Python 3.11 or later available on the host.
+- A clone of the Borealis repository with the Engine tree checked out.
+- Access to the legacy runtime assets (certificates, TLS bundle, etc.).
+- Optional: a staging agent install for end-to-end WebSocket validation.
+
+Ensure the SQLite database lives at `<project root>/database.db` and that the Engine migrations have already run (they execute automatically when the `BOREALIS_ENGINE_AUTO_MIGRATE` environment variable is left at its default `true`).
+
+## 2. Launching the Engine in staging mode
+
+1. Open a terminal at the project root.
+2. Set any environment overrides required for the test scenario (for example, `BOREALIS_DEBUG=true` to surface verbose logging, or `BOREALIS_CORS_ALLOWED_ORIGINS=https://localhost:3000` when pairing with the React UI).
+3. Run the Engine entrypoint:
+
+   ```bash
+   python Data/Engine/bootstrapper.py
+   ```
+
+4. Verify `Logs/Server/engine.log` is created and that the startup entries follow the `<timestamp>-engine-<message>` format.
+
+Keep the legacy server running in a separate process if comparative testing is required; they do not share global state.
+
+## 3. Feature validation checklist
+
+Work through the following areas and tick each box once verified. Capture any issues in the log table in §4.
+
+### Authentication and tokens
+
+- [ ] `POST /api/agent/token/refresh` returns a new access token when supplied a valid refresh token + DPoP proof.
+- [ ] Invalid DPoP proofs or revoked refresh tokens yield the expected HTTP 401 responses and structured error payloads.
+- [ ] Device last-seen metadata updates inside the database after a successful refresh.
+
+### Enrollment
+
+- [ ] `POST /api/agent/enroll/request` produces an enrollment ticket with the correct expiration and retry counters.
+- [ ] `POST /api/agent/enroll/poll` transitions an approved device into an authenticated state and returns the TLS bundle.
+- [ ] Audit logging for approvals lands in `Logs/Server/engine.log`.
+
+### Job management
+
+- [ ] `POST /api/jobs` (or UI equivalent) creates a scheduled job and returns a manifest identifier.
+- [ ] `GET /api/jobs/<id>` surfaces the stored manifest with normalized schedules and environment variables.
+- [ ] Job lifecycle events arrive over the `job_management` Socket.IO namespace when a job transitions between `pending`, `running`, and `completed`.
+
+### Real-time agents
+
+- [ ] Agents connecting to the `agents` namespace appear in the realtime roster with accurate hostname, username, and fingerprint details.
+- [ ] Screenshot broadcasts relay from agents to the UI without residual cache bleed-through after disconnects.
+- [ ] Macro execution responses round-trip through Socket.IO and reach the initiating client.
+
+### GitHub integration
+
+- [ ] `GET /api/repo/current_hash` reflects the latest branch head and caches repeated calls.
+- [ ] `POST /api/github/token` persists a new token and survives Engine restarts (confirm via database inspection).
+- [ ] The background refresher logs rate-limit warnings instead of raising uncaught exceptions when the GitHub API throttles requests.
+
+## 4. Recording divergences
+
+Use the table below to document behavioural differences or bugs uncovered during staging. This artifact should accompany the staging run summary so follow-up fixes can be triaged quickly.
+
+| Area | Legacy Behaviour | Engine Behaviour | Notes / Links |
+| --- | --- | --- | --- |
+| Authentication | | | |
+| Enrollment | | | |
+| Scheduler | | | |
+| Realtime | | | |
+| GitHub | | | |
+| Other | | | |
+
+## 5. Cut-over readiness
+
+Once every checklist item passes and no critical divergences remain:
+
+1. Update `Data/Engine/CURRENT_STAGE.md` with the completion date for Step 12.
+2. Coordinate with the operator to switch deployment scripts to `Data/Engine/bootstrapper.py`.
+3. Plan a rollback strategy (typically re-launching the legacy server) should issues appear immediately after the cut-over.
+4. Archive the filled divergence table alongside Engine logs for historical traceability.
+
+Document the results in project tracking tools before moving on to deprecating `Data/Server`.
diff --git a/Data/Engine/__init__.py b/Data/Engine/__init__.py
new file mode 100644
index 0000000..afc216c
--- /dev/null
+++ b/Data/Engine/__init__.py
@@ -0,0 +1,11 @@
+"""Borealis Engine package.
+
+This namespace contains the next-generation server implementation.
+"""
+
+from __future__ import annotations
+
+__all__ = [
+    "bootstrapper",
+    "server",
+]
diff --git a/Data/Engine/bootstrapper.py b/Data/Engine/bootstrapper.py
new file mode 100644
index 0000000..e16b272
--- /dev/null
+++ b/Data/Engine/bootstrapper.py
@@ -0,0 +1,113 @@
+"""Entrypoint for the Borealis Engine server."""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+from flask import Flask
+
+from .config import EngineSettings, configure_logging, load_environment
+from .interfaces import (
+    create_socket_server,
+    register_http_interfaces,
+    register_ws_interfaces,
+)
+from .interfaces.eventlet_compat import apply_eventlet_patches
+from .repositories.sqlite import connection as sqlite_connection
+from .repositories.sqlite import migrations as sqlite_migrations
+from .server import create_app
+from .services.container import build_service_container
+from .services.crypto.certificates import ensure_certificate
+
+
+apply_eventlet_patches()
+
+
+@dataclass(frozen=True, slots=True)
+class EngineRuntime:
+    """Aggregated runtime context produced by :func:`bootstrap`."""
+
+    app: Flask
+    settings: EngineSettings
+    socketio: Optional[object]
+    db_factory: sqlite_connection.SQLiteConnectionFactory
+    tls_certificate: Path
+    tls_key: Path
+    tls_bundle: Path
+
+
+def bootstrap() -> EngineRuntime:
+    """Construct the Flask application and supporting infrastructure."""
+
+    settings = load_environment()
+    logger = configure_logging(settings)
+    logger.info("bootstrap-started")
+
+    cert_path, key_path, bundle_path = ensure_certificate()
+    os.environ.setdefault("BOREALIS_TLS_BUNDLE", str(bundle_path))
+    logger.info(
+        "tls-material-ready",
+        extra={
+            "cert_path": str(cert_path),
+            "key_path": str(key_path),
+            "bundle_path": str(bundle_path),
+        },
+    )
+
+    db_factory = sqlite_connection.connection_factory(settings.database_path)
+    if settings.apply_migrations:
+        logger.info("migrations-start")
+        with sqlite_connection.connection_scope(settings.database_path) as conn:
sqlite_migrations.apply_all(conn) + logger.info("migrations-complete") + else: + logger.info("migrations-skipped") + + app = create_app(settings, db_factory=db_factory) + services = build_service_container(settings, db_factory=db_factory, logger=logger.getChild("services")) + app.extensions["engine_services"] = services + register_http_interfaces(app, services) + socketio = create_socket_server(app, settings.socketio) + register_ws_interfaces(socketio, services) + services.scheduler_service.start(socketio) + logger.info("bootstrap-complete") + return EngineRuntime( + app=app, + settings=settings, + socketio=socketio, + db_factory=db_factory, + tls_certificate=cert_path, + tls_key=key_path, + tls_bundle=bundle_path, + ) + + +def main() -> None: + runtime = bootstrap() + socketio = runtime.socketio + certfile = str(runtime.tls_bundle) + keyfile = str(runtime.tls_key) + + if socketio is not None: + socketio.run( # type: ignore[call-arg] + runtime.app, + host=runtime.settings.server.host, + port=runtime.settings.server.port, + debug=runtime.settings.debug, + certfile=certfile, + keyfile=keyfile, + ) + else: + runtime.app.run( + host=runtime.settings.server.host, + port=runtime.settings.server.port, + debug=runtime.settings.debug, + ssl_context=(certfile, keyfile), + ) + + +if __name__ == "__main__": # pragma: no cover - manual execution + main() diff --git a/Data/Engine/builders/__init__.py b/Data/Engine/builders/__init__.py new file mode 100644 index 0000000..0f9b02a --- /dev/null +++ b/Data/Engine/builders/__init__.py @@ -0,0 +1,35 @@ +"""Builder utilities for constructing immutable Engine aggregates.""" + +from __future__ import annotations + +from .device_auth import ( + DeviceAuthRequest, + DeviceAuthRequestBuilder, + RefreshTokenRequest, + RefreshTokenRequestBuilder, +) + +__all__ = [ + "DeviceAuthRequest", + "DeviceAuthRequestBuilder", + "RefreshTokenRequest", + "RefreshTokenRequestBuilder", +] + +try: # pragma: no cover - optional dependency shim + from .device_enrollment import ( + EnrollmentRequestBuilder, + ProofChallengeBuilder, + ) +except ModuleNotFoundError as exc: # pragma: no cover - executed when crypto deps missing + _missing_reason = str(exc) + + def _missing_builder(*_args: object, **_kwargs: object) -> None: + raise ModuleNotFoundError( + "device enrollment builders require optional cryptography dependencies" + ) from exc + + EnrollmentRequestBuilder = _missing_builder # type: ignore[assignment] + ProofChallengeBuilder = _missing_builder # type: ignore[assignment] +else: + __all__ += ["EnrollmentRequestBuilder", "ProofChallengeBuilder"] diff --git a/Data/Engine/builders/device_auth.py b/Data/Engine/builders/device_auth.py new file mode 100644 index 0000000..abc09e8 --- /dev/null +++ b/Data/Engine/builders/device_auth.py @@ -0,0 +1,165 @@ +"""Builders for device authentication and token refresh inputs.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +from Data.Engine.domain.device_auth import ( + DeviceAuthErrorCode, + DeviceAuthFailure, + DeviceGuid, + sanitize_service_context, +) + +__all__ = [ + "DeviceAuthRequest", + "DeviceAuthRequestBuilder", + "RefreshTokenRequest", + "RefreshTokenRequestBuilder", +] + + +@dataclass(frozen=True, slots=True) +class DeviceAuthRequest: + """Normalized authentication inputs derived from an HTTP request.""" + + access_token: str + http_method: str + htu: str + service_context: Optional[str] + dpop_proof: Optional[str] + + +class DeviceAuthRequestBuilder: + """Validate and 
normalize HTTP headers for device authentication.""" + + _authorization: Optional[str] + _http_method: Optional[str] + _htu: Optional[str] + _service_context: Optional[str] + _dpop_proof: Optional[str] + + def __init__(self) -> None: + self._authorization = None + self._http_method = None + self._htu = None + self._service_context = None + self._dpop_proof = None + + def with_authorization(self, header_value: Optional[str]) -> "DeviceAuthRequestBuilder": + if header_value is None: + self._authorization = None + else: + self._authorization = header_value.strip() + return self + + def with_http_method(self, method: Optional[str]) -> "DeviceAuthRequestBuilder": + self._http_method = (method or "").strip().upper() + return self + + def with_htu(self, url: Optional[str]) -> "DeviceAuthRequestBuilder": + self._htu = (url or "").strip() + return self + + def with_service_context(self, header_value: Optional[str]) -> "DeviceAuthRequestBuilder": + self._service_context = sanitize_service_context(header_value) + return self + + def with_dpop_proof(self, proof: Optional[str]) -> "DeviceAuthRequestBuilder": + self._dpop_proof = (proof or "").strip() or None + return self + + def build(self) -> DeviceAuthRequest: + token = self._parse_authorization(self._authorization) + method = (self._http_method or "").strip().upper() + if not method: + raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN, detail="missing HTTP method") + url = (self._htu or "").strip() + if not url: + raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN, detail="missing request URL") + return DeviceAuthRequest( + access_token=token, + http_method=method, + htu=url, + service_context=self._service_context, + dpop_proof=self._dpop_proof, + ) + + @staticmethod + def _parse_authorization(header_value: Optional[str]) -> str: + header = (header_value or "").strip() + if not header: + raise DeviceAuthFailure(DeviceAuthErrorCode.MISSING_AUTHORIZATION) + prefix = "Bearer " + if not header.startswith(prefix): + raise DeviceAuthFailure(DeviceAuthErrorCode.MISSING_AUTHORIZATION) + token = header[len(prefix) :].strip() + if not token: + raise DeviceAuthFailure(DeviceAuthErrorCode.MISSING_AUTHORIZATION) + return token + + +@dataclass(frozen=True, slots=True) +class RefreshTokenRequest: + """Validated refresh token payload supplied by an agent.""" + + guid: DeviceGuid + refresh_token: str + http_method: str + htu: str + dpop_proof: Optional[str] + + +class RefreshTokenRequestBuilder: + """Helper to normalize refresh token JSON payloads.""" + + _guid: Optional[str] + _refresh_token: Optional[str] + _http_method: Optional[str] + _htu: Optional[str] + _dpop_proof: Optional[str] + + def __init__(self) -> None: + self._guid = None + self._refresh_token = None + self._http_method = None + self._htu = None + self._dpop_proof = None + + def with_payload(self, payload: Optional[dict[str, object]]) -> "RefreshTokenRequestBuilder": + payload = payload or {} + self._guid = str(payload.get("guid") or "").strip() + self._refresh_token = str(payload.get("refresh_token") or "").strip() + return self + + def with_http_method(self, method: Optional[str]) -> "RefreshTokenRequestBuilder": + self._http_method = (method or "").strip().upper() + return self + + def with_htu(self, url: Optional[str]) -> "RefreshTokenRequestBuilder": + self._htu = (url or "").strip() + return self + + def with_dpop_proof(self, proof: Optional[str]) -> "RefreshTokenRequestBuilder": + self._dpop_proof = (proof or "").strip() or None + return self + + def build(self) -> 
RefreshTokenRequest: + if not self._guid: + raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_CLAIMS, detail="missing guid") + if not self._refresh_token: + raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_CLAIMS, detail="missing refresh token") + method = (self._http_method or "").strip().upper() + if not method: + raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN, detail="missing HTTP method") + url = (self._htu or "").strip() + if not url: + raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN, detail="missing request URL") + return RefreshTokenRequest( + guid=DeviceGuid(self._guid), + refresh_token=self._refresh_token, + http_method=method, + htu=url, + dpop_proof=self._dpop_proof, + ) diff --git a/Data/Engine/builders/device_enrollment.py b/Data/Engine/builders/device_enrollment.py new file mode 100644 index 0000000..92f1217 --- /dev/null +++ b/Data/Engine/builders/device_enrollment.py @@ -0,0 +1,131 @@ +"""Builder utilities for device enrollment payloads.""" + +from __future__ import annotations + +import base64 +from dataclasses import dataclass +from typing import Optional + +from Data.Engine.domain.device_auth import DeviceFingerprint, sanitize_service_context +from Data.Engine.domain.device_enrollment import ( + EnrollmentValidationError, + ProofChallenge, +) +from Data.Engine.integrations.crypto import keys as crypto_keys + +__all__ = [ + "EnrollmentRequestBuilder", + "EnrollmentRequestInput", + "ProofChallengeBuilder", +] + + +@dataclass(frozen=True, slots=True) +class EnrollmentRequestInput: + """Structured enrollment request payload ready for the service layer.""" + + hostname: str + enrollment_code: str + fingerprint: DeviceFingerprint + client_nonce: bytes + client_nonce_b64: str + agent_public_key_der: bytes + service_context: Optional[str] + + +class EnrollmentRequestBuilder: + """Normalize agent enrollment JSON payloads into domain objects.""" + + def __init__(self) -> None: + self._hostname: Optional[str] = None + self._enrollment_code: Optional[str] = None + self._agent_pubkey_b64: Optional[str] = None + self._client_nonce_b64: Optional[str] = None + self._service_context: Optional[str] = None + + def with_payload(self, payload: Optional[dict[str, object]]) -> "EnrollmentRequestBuilder": + payload = payload or {} + self._hostname = str(payload.get("hostname") or "").strip() + self._enrollment_code = str(payload.get("enrollment_code") or "").strip() + agent_pubkey = payload.get("agent_pubkey") + self._agent_pubkey_b64 = agent_pubkey if isinstance(agent_pubkey, str) else None + client_nonce = payload.get("client_nonce") + self._client_nonce_b64 = client_nonce if isinstance(client_nonce, str) else None + return self + + def with_service_context(self, value: Optional[str]) -> "EnrollmentRequestBuilder": + self._service_context = value + return self + + def build(self) -> EnrollmentRequestInput: + if not self._hostname: + raise EnrollmentValidationError("hostname_required") + if not self._enrollment_code: + raise EnrollmentValidationError("enrollment_code_required") + if not self._agent_pubkey_b64: + raise EnrollmentValidationError("agent_pubkey_required") + if not self._client_nonce_b64: + raise EnrollmentValidationError("client_nonce_required") + + try: + agent_pubkey_der = crypto_keys.spki_der_from_base64(self._agent_pubkey_b64) + except Exception as exc: # pragma: no cover - invalid input path + raise EnrollmentValidationError("invalid_agent_pubkey") from exc + + if len(agent_pubkey_der) < 10: + raise EnrollmentValidationError("invalid_agent_pubkey") + 
+ try: + client_nonce_bytes = base64.b64decode(self._client_nonce_b64, validate=True) + except Exception as exc: # pragma: no cover - invalid input path + raise EnrollmentValidationError("invalid_client_nonce") from exc + + if len(client_nonce_bytes) < 16: + raise EnrollmentValidationError("invalid_client_nonce") + + fingerprint_value = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der) + + return EnrollmentRequestInput( + hostname=self._hostname, + enrollment_code=self._enrollment_code, + fingerprint=DeviceFingerprint(fingerprint_value), + client_nonce=client_nonce_bytes, + client_nonce_b64=self._client_nonce_b64, + agent_public_key_der=agent_pubkey_der, + service_context=sanitize_service_context(self._service_context), + ) + + +class ProofChallengeBuilder: + """Construct proof challenges during enrollment approval.""" + + def __init__(self) -> None: + self._server_nonce: Optional[bytes] = None + self._client_nonce: Optional[bytes] = None + self._fingerprint: Optional[DeviceFingerprint] = None + + def with_server_nonce(self, nonce: Optional[bytes]) -> "ProofChallengeBuilder": + self._server_nonce = bytes(nonce or b"") + return self + + def with_client_nonce(self, nonce: Optional[bytes]) -> "ProofChallengeBuilder": + self._client_nonce = bytes(nonce or b"") + return self + + def with_fingerprint(self, fingerprint: Optional[str]) -> "ProofChallengeBuilder": + if fingerprint: + self._fingerprint = DeviceFingerprint(fingerprint) + else: + self._fingerprint = None + return self + + def build(self) -> ProofChallenge: + if self._server_nonce is None or self._client_nonce is None: + raise ValueError("both server and client nonces are required") + if not self._fingerprint: + raise ValueError("fingerprint is required") + return ProofChallenge( + client_nonce=self._client_nonce, + server_nonce=self._server_nonce, + fingerprint=self._fingerprint, + ) diff --git a/Data/Engine/builders/job_fabricator.py b/Data/Engine/builders/job_fabricator.py new file mode 100644 index 0000000..9ca8eef --- /dev/null +++ b/Data/Engine/builders/job_fabricator.py @@ -0,0 +1,382 @@ +"""Builders for Engine job manifests.""" + +from __future__ import annotations + +import base64 +import json +import logging +import os +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple + +from Data.Engine.repositories.sqlite.job_repository import ScheduledJobRecord + +__all__ = [ + "JobComponentManifest", + "JobManifest", + "JobFabricator", +] + + +_ENV_VAR_PATTERN = re.compile(r"(?i)\$env:(\{)?([A-Za-z0-9_\-]+)(?(1)\})") + + +@dataclass(frozen=True, slots=True) +class JobComponentManifest: + """Materialized job component ready for execution.""" + + name: str + path: str + script_type: str + script_content: str + encoded_content: str + environment: Dict[str, str] + literal_environment: Dict[str, str] + timeout_seconds: int + + +@dataclass(frozen=True, slots=True) +class JobManifest: + job_id: int + name: str + occurrence_ts: int + execution_context: str + targets: Tuple[str, ...] + components: Tuple[JobComponentManifest, ...] 
+ + +class JobFabricator: + """Convert stored job records into immutable manifests.""" + + def __init__( + self, + *, + assemblies_root: Path, + logger: Optional[logging.Logger] = None, + ) -> None: + self._assemblies_root = assemblies_root + self._log = logger or logging.getLogger("borealis.engine.builders.jobs") + + def build( + self, + job: ScheduledJobRecord, + *, + occurrence_ts: int, + ) -> JobManifest: + components = tuple(self._materialize_component(job, component) for component in job.components) + targets = tuple(str(t) for t in job.targets) + return JobManifest( + job_id=job.id, + name=job.name, + occurrence_ts=occurrence_ts, + execution_context=job.execution_context, + targets=targets, + components=tuple(c for c in components if c is not None), + ) + + # ------------------------------------------------------------------ + # Component handling + # ------------------------------------------------------------------ + def _materialize_component( + self, + job: ScheduledJobRecord, + component: Mapping[str, Any], + ) -> Optional[JobComponentManifest]: + if not isinstance(component, Mapping): + return None + + component_type = str(component.get("type") or "").strip().lower() + if component_type not in {"script", "ansible"}: + return None + + path = str(component.get("path") or component.get("script_path") or "").strip() + if not path: + return None + + try: + abs_path = self._resolve_script_path(path) + except FileNotFoundError: + self._log.warning( + "job component path invalid", extra={"job_id": job.id, "path": path} + ) + return None + script_type = self._detect_script_type(abs_path, component_type) + script_content = self._load_script_content(abs_path, component) + + doc_variables: List[Dict[str, Any]] = [] + if isinstance(component.get("variables"), list): + doc_variables = [v for v in component["variables"] if isinstance(v, dict)] + overrides = self._collect_overrides(component) + env_map, _, literal_lookup = _prepare_variable_context(doc_variables, overrides) + + rewritten = _rewrite_powershell_script(script_content, literal_lookup) + encoded = _encode_script_content(rewritten) + + timeout_seconds = _coerce_int(component.get("timeout_seconds")) + if not timeout_seconds: + timeout_seconds = _coerce_int(component.get("timeout")) + + return JobComponentManifest( + name=self._component_name(abs_path, component), + path=path, + script_type=script_type, + script_content=rewritten, + encoded_content=encoded, + environment=env_map, + literal_environment=literal_lookup, + timeout_seconds=timeout_seconds, + ) + + def _component_name(self, abs_path: Path, component: Mapping[str, Any]) -> str: + if isinstance(component.get("name"), str) and component["name"].strip(): + return component["name"].strip() + return abs_path.stem + + def _resolve_script_path(self, rel_path: str) -> Path: + candidate = Path(rel_path.replace("\\", "/").lstrip("/")) + if candidate.parts and candidate.parts[0] != "Scripts": + candidate = Path("Scripts") / candidate + abs_path = (self._assemblies_root / candidate).resolve() + try: + abs_path.relative_to(self._assemblies_root) + except ValueError: + raise FileNotFoundError(rel_path) + if not abs_path.is_file(): + raise FileNotFoundError(rel_path) + return abs_path + + def _load_script_content(self, abs_path: Path, component: Mapping[str, Any]) -> str: + if isinstance(component.get("script"), str) and component["script"].strip(): + return _decode_script_content(component["script"], component.get("encoding") or "") + try: + return abs_path.read_text(encoding="utf-8") + 
except Exception as exc: + self._log.warning("unable to read script for job component: path=%s error=%s", abs_path, exc) + return "" + + def _detect_script_type(self, abs_path: Path, declared: str) -> str: + lower = declared.lower() + if lower in {"script", "powershell"}: + return "powershell" + suffix = abs_path.suffix.lower() + if suffix == ".ps1": + return "powershell" + if suffix == ".yml": + return "ansible" + if suffix == ".json": + try: + data = json.loads(abs_path.read_text(encoding="utf-8")) + if isinstance(data, dict): + t = str(data.get("type") or data.get("script_type") or "").strip().lower() + if t: + return t + except Exception: + pass + return lower or "powershell" + + def _collect_overrides(self, component: Mapping[str, Any]) -> Dict[str, Any]: + overrides: Dict[str, Any] = {} + values = component.get("variable_values") + if isinstance(values, Mapping): + for key, value in values.items(): + name = str(key or "").strip() + if name: + overrides[name] = value + vars_inline = component.get("variables") + if isinstance(vars_inline, Iterable): + for var in vars_inline: + if not isinstance(var, Mapping): + continue + name = str(var.get("name") or "").strip() + if not name: + continue + if "value" in var: + overrides[name] = var.get("value") + return overrides + + +def _coerce_int(value: Any) -> int: + try: + return int(value or 0) + except Exception: + return 0 + + +def _env_string(value: Any) -> str: + if isinstance(value, bool): + return "True" if value else "False" + if value is None: + return "" + return str(value) + + +def _decode_base64_text(value: Any) -> Optional[str]: + if not isinstance(value, str): + return None + stripped = value.strip() + if not stripped: + return "" + try: + cleaned = re.sub(r"\s+", "", stripped) + except Exception: + cleaned = stripped + try: + decoded = base64.b64decode(cleaned, validate=True) + except Exception: + return None + try: + return decoded.decode("utf-8") + except Exception: + return decoded.decode("utf-8", errors="replace") + + +def _decode_script_content(value: Any, encoding_hint: Any = "") -> str: + encoding = str(encoding_hint or "").strip().lower() + if isinstance(value, str): + if encoding in {"base64", "b64", "base-64"}: + decoded = _decode_base64_text(value) + if decoded is not None: + return decoded.replace("\r\n", "\n") + decoded = _decode_base64_text(value) + if decoded is not None: + return decoded.replace("\r\n", "\n") + return value.replace("\r\n", "\n") + return "" + + +def _encode_script_content(script_text: Any) -> str: + if not isinstance(script_text, str): + if script_text is None: + script_text = "" + else: + script_text = str(script_text) + normalized = script_text.replace("\r\n", "\n") + if not normalized: + return "" + encoded = base64.b64encode(normalized.encode("utf-8")) + return encoded.decode("ascii") + + +def _canonical_env_key(name: Any) -> str: + try: + return re.sub(r"[^A-Za-z0-9_]", "_", str(name or "").strip()).upper() + except Exception: + return "" + + +def _expand_env_aliases(env_map: Dict[str, str], variables: Iterable[Mapping[str, Any]]) -> Dict[str, str]: + expanded = dict(env_map or {}) + for var in variables: + if not isinstance(var, Mapping): + continue + name = str(var.get("name") or "").strip() + if not name: + continue + canonical = _canonical_env_key(name) + if not canonical or canonical not in expanded: + continue + value = expanded[canonical] + alias = re.sub(r"[^A-Za-z0-9_]", "_", name) + if alias and alias not in expanded: + expanded[alias] = value + if alias != name and 
re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", name) and name not in expanded: + expanded[name] = value + return expanded + + +def _powershell_literal(value: Any, var_type: str) -> str: + typ = str(var_type or "string").lower() + if typ == "boolean": + if isinstance(value, bool): + truthy = value + elif value is None: + truthy = False + elif isinstance(value, (int, float)): + truthy = value != 0 + else: + s = str(value).strip().lower() + if s in {"true", "1", "yes", "y", "on"}: + truthy = True + elif s in {"false", "0", "no", "n", "off", ""}: + truthy = False + else: + truthy = bool(s) + return "$true" if truthy else "$false" + if typ == "number": + if value is None or value == "": + return "0" + return str(value) + s = "" if value is None else str(value) + return "'" + s.replace("'", "''") + "'" + + +def _extract_variable_default(var: Mapping[str, Any]) -> Any: + for key in ("value", "default", "defaultValue", "default_value"): + if key in var: + val = var.get(key) + return "" if val is None else val + return "" + + +def _prepare_variable_context( + doc_variables: Iterable[Mapping[str, Any]], + overrides: Mapping[str, Any], +) -> Tuple[Dict[str, str], List[Dict[str, Any]], Dict[str, str]]: + env_map: Dict[str, str] = {} + variables: List[Dict[str, Any]] = [] + literal_lookup: Dict[str, str] = {} + doc_names: Dict[str, bool] = {} + + overrides = dict(overrides or {}) + + for var in doc_variables: + if not isinstance(var, Mapping): + continue + name = str(var.get("name") or "").strip() + if not name: + continue + doc_names[name] = True + canonical = _canonical_env_key(name) + var_type = str(var.get("type") or "string").lower() + default_val = _extract_variable_default(var) + final_val = overrides[name] if name in overrides else default_val + if canonical: + env_map[canonical] = _env_string(final_val) + literal_lookup[canonical] = _powershell_literal(final_val, var_type) + if name in overrides: + new_var = dict(var) + new_var["value"] = overrides[name] + variables.append(new_var) + else: + variables.append(dict(var)) + + for name, val in overrides.items(): + if name in doc_names: + continue + canonical = _canonical_env_key(name) + if canonical: + env_map[canonical] = _env_string(val) + literal_lookup[canonical] = _powershell_literal(val, "string") + variables.append({"name": name, "value": val, "type": "string"}) + + env_map = _expand_env_aliases(env_map, variables) + return env_map, variables, literal_lookup + + +def _rewrite_powershell_script(content: str, literal_lookup: Mapping[str, str]) -> str: + if not content or not literal_lookup: + return content + + def _replace(match: re.Match[str]) -> str: + name = match.group(2) + canonical = _canonical_env_key(name) + if not canonical: + return match.group(0) + literal = literal_lookup.get(canonical) + if literal is None: + return match.group(0) + return literal + + return _ENV_VAR_PATTERN.sub(_replace, content) diff --git a/Data/Engine/config/__init__.py b/Data/Engine/config/__init__.py new file mode 100644 index 0000000..8074ffa --- /dev/null +++ b/Data/Engine/config/__init__.py @@ -0,0 +1,25 @@ +"""Configuration primitives for the Borealis Engine.""" + +from __future__ import annotations + +from .environment import ( + DatabaseSettings, + EngineSettings, + FlaskSettings, + GitHubSettings, + ServerSettings, + SocketIOSettings, + load_environment, +) +from .logging import configure_logging + +__all__ = [ + "DatabaseSettings", + "EngineSettings", + "FlaskSettings", + "GitHubSettings", + "load_environment", + "ServerSettings", + "SocketIOSettings", + 
"configure_logging", +] diff --git a/Data/Engine/config/environment.py b/Data/Engine/config/environment.py new file mode 100644 index 0000000..58ccf66 --- /dev/null +++ b/Data/Engine/config/environment.py @@ -0,0 +1,206 @@ +"""Environment detection for the Borealis Engine.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, Tuple + + +@dataclass(frozen=True, slots=True) +class DatabaseSettings: + """SQLite database configuration for the Engine.""" + + path: Path + apply_migrations: bool + + +@dataclass(frozen=True, slots=True) +class FlaskSettings: + """Parameters that influence Flask application behavior.""" + + secret_key: str + static_root: Path + cors_allowed_origins: Tuple[str, ...] + + +@dataclass(frozen=True, slots=True) +class SocketIOSettings: + """Configuration for the optional Socket.IO server.""" + + cors_allowed_origins: Tuple[str, ...] + + +@dataclass(frozen=True, slots=True) +class ServerSettings: + """HTTP server binding configuration.""" + + host: str + port: int + + +@dataclass(frozen=True, slots=True) +class GitHubSettings: + """Configuration surface for GitHub repository interactions.""" + + default_repo: str + default_branch: str + refresh_interval_seconds: int + cache_root: Path + + @property + def cache_file(self) -> Path: + """Location of the persisted repository-head cache.""" + + return self.cache_root / "repo_hash_cache.json" + + +@dataclass(frozen=True, slots=True) +class EngineSettings: + """Immutable container describing the Engine runtime configuration.""" + + project_root: Path + debug: bool + database: DatabaseSettings + flask: FlaskSettings + socketio: SocketIOSettings + server: ServerSettings + github: GitHubSettings + + @property + def logs_root(self) -> Path: + """Return the directory where Engine-specific logs should live.""" + + return self.project_root / "Logs" / "Server" + + @property + def database_path(self) -> Path: + """Convenience accessor for the database file path.""" + + return self.database.path + + @property + def apply_migrations(self) -> bool: + """Return whether schema migrations should run at bootstrap.""" + + return self.database.apply_migrations + + +def _resolve_project_root() -> Path: + candidate = os.getenv("BOREALIS_ROOT") + if candidate: + return Path(candidate).expanduser().resolve() + return Path(__file__).resolve().parents[2] + + +def _resolve_database_path(project_root: Path) -> Path: + candidate = os.getenv("BOREALIS_DATABASE_PATH") + if candidate: + return Path(candidate).expanduser().resolve() + return (project_root / "database.db").resolve() + + +def _should_apply_migrations() -> bool: + raw = os.getenv("BOREALIS_ENGINE_AUTO_MIGRATE", "true") + return raw.lower() in {"1", "true", "yes", "on"} + + +def _resolve_static_root(project_root: Path) -> Path: + candidate = os.getenv("BOREALIS_STATIC_ROOT") + if candidate: + return Path(candidate).expanduser().resolve() + + candidates = ( + project_root / "Engine" / "web-interface" / "build", + project_root / "Engine" / "web-interface" / "dist", + project_root / "Data" / "Engine" / "WebUI" / "build", + project_root / "Data" / "Server" / "web-interface" / "build", + project_root / "Data" / "Server" / "WebUI" / "build", + project_root / "Data" / "WebUI" / "build", + ) + for path in candidates: + resolved = path.resolve() + if resolved.is_dir(): + return resolved + + # Fall back to the first candidate even if it does not yet exist so the + # Flask factory still initialises; individual 
requests will surface 404s + # until an asset build is available, matching the legacy behaviour. + return candidates[0].resolve() + + +def _resolve_github_cache_root(project_root: Path) -> Path: + candidate = os.getenv("BOREALIS_CACHE_DIR") + if candidate: + return Path(candidate).expanduser().resolve() + return (project_root / "Data" / "Engine" / "cache").resolve() + + +def _parse_refresh_interval(raw: str | None) -> int: + if not raw: + return 60 + try: + value = int(raw) + except ValueError: + value = 60 + return max(30, min(value, 3600)) + + +def _parse_origins(raw: str | None) -> Tuple[str, ...]: + if not raw: + return ("*",) + parts: Iterable[str] = (segment.strip() for segment in raw.split(",")) + filtered = tuple(part for part in parts if part) + return filtered or ("*",) + + +def load_environment() -> EngineSettings: + """Load Engine settings from environment variables and filesystem hints.""" + + project_root = _resolve_project_root() + database = DatabaseSettings( + path=_resolve_database_path(project_root), + apply_migrations=_should_apply_migrations(), + ) + cors_allowed_origins = _parse_origins(os.getenv("BOREALIS_CORS_ALLOWED_ORIGINS")) + flask_settings = FlaskSettings( + secret_key=os.getenv("BOREALIS_FLASK_SECRET_KEY", "change-me"), + static_root=_resolve_static_root(project_root), + cors_allowed_origins=cors_allowed_origins, + ) + socket_settings = SocketIOSettings(cors_allowed_origins=cors_allowed_origins) + debug = os.getenv("BOREALIS_DEBUG", "false").lower() in {"1", "true", "yes", "on"} + host = os.getenv("BOREALIS_HOST", "127.0.0.1") + try: + port = int(os.getenv("BOREALIS_PORT", "5000")) + except ValueError: + port = 5000 + server_settings = ServerSettings(host=host, port=port) + github_settings = GitHubSettings( + default_repo=os.getenv("BOREALIS_REPO", "bunny-lab-io/Borealis"), + default_branch=os.getenv("BOREALIS_REPO_BRANCH", "main"), + refresh_interval_seconds=_parse_refresh_interval(os.getenv("BOREALIS_REPO_HASH_REFRESH")), + cache_root=_resolve_github_cache_root(project_root), + ) + + return EngineSettings( + project_root=project_root, + debug=debug, + database=database, + flask=flask_settings, + socketio=socket_settings, + server=server_settings, + github=github_settings, + ) + + +__all__ = [ + "DatabaseSettings", + "EngineSettings", + "FlaskSettings", + "GitHubSettings", + "SocketIOSettings", + "ServerSettings", + "load_environment", +] diff --git a/Data/Engine/config/logging.py b/Data/Engine/config/logging.py new file mode 100644 index 0000000..f07b764 --- /dev/null +++ b/Data/Engine/config/logging.py @@ -0,0 +1,71 @@ +"""Logging bootstrap helpers for the Borealis Engine.""" + +from __future__ import annotations + +import logging +from logging.handlers import TimedRotatingFileHandler +from pathlib import Path + +from .environment import EngineSettings + + +_ENGINE_LOGGER_NAME = "borealis.engine" +_SERVICE_NAME = "engine" +_DEFAULT_FORMAT = "%(asctime)s-" + _SERVICE_NAME + "-%(message)s" + + +def _handler_already_attached(logger: logging.Logger, log_path: Path) -> bool: + for handler in logger.handlers: + if isinstance(handler, TimedRotatingFileHandler): + handler_path = Path(getattr(handler, "baseFilename", "")) + if handler_path == log_path: + return True + return False + + +def _build_handler(log_path: Path) -> TimedRotatingFileHandler: + handler = TimedRotatingFileHandler( + log_path, + when="midnight", + backupCount=30, + encoding="utf-8", + ) + handler.setLevel(logging.INFO) + handler.setFormatter(logging.Formatter(_DEFAULT_FORMAT)) + return handler + 
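# Editor's note: the sketch below is not part of this patch. It illustrates how the
# two config helpers introduced in this diff, load_environment() and configure_logging()
# (the latter defined just below), are meant to compose at process start; the
# _example_bootstrap() wrapper and the log message are hypothetical.

from Data.Engine.config import configure_logging, load_environment


def _example_bootstrap() -> None:
    settings = load_environment()      # resolve project root, DB path, and bind address from env vars
    log = configure_logging(settings)  # attach the rotating file handler to "borealis.engine"
    log.info(
        "engine configured host=%s port=%s debug=%s",
        settings.server.host,
        settings.server.port,
        settings.debug,
    )


if __name__ == "__main__":
    _example_bootstrap()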
+ +def configure_logging(settings: EngineSettings) -> logging.Logger: + """Configure a rotating log handler for the Engine.""" + + logs_root = settings.logs_root + logs_root.mkdir(parents=True, exist_ok=True) + log_path = logs_root / "engine.log" + + logger = logging.getLogger(_ENGINE_LOGGER_NAME) + logger.setLevel(logging.INFO if not settings.debug else logging.DEBUG) + + if not _handler_already_attached(logger, log_path): + handler = _build_handler(log_path) + logger.addHandler(handler) + logger.propagate = False + + # Also ensure the root logger follows suit so third-party modules inherit the handler. + root_logger = logging.getLogger() + if not _handler_already_attached(root_logger, log_path): + handler = _build_handler(log_path) + root_logger.addHandler(handler) + if root_logger.level == logging.WARNING: + # Default level is WARNING; lower it to INFO so our handler captures application messages. + root_logger.setLevel(logging.INFO if not settings.debug else logging.DEBUG) + + # Quieten overly chatty frameworks unless debugging is explicitly requested. + if not settings.debug: + logging.getLogger("werkzeug").setLevel(logging.WARNING) + logging.getLogger("engineio").setLevel(logging.WARNING) + logging.getLogger("socketio").setLevel(logging.WARNING) + + return logger + + +__all__ = ["configure_logging"] diff --git a/Data/Engine/domain/__init__.py b/Data/Engine/domain/__init__.py new file mode 100644 index 0000000..077ce2f --- /dev/null +++ b/Data/Engine/domain/__init__.py @@ -0,0 +1,49 @@ +"""Pure value objects and enums for the Borealis Engine.""" + +from __future__ import annotations + +from .device_auth import ( # noqa: F401 + AccessTokenClaims, + DeviceAuthContext, + DeviceAuthErrorCode, + DeviceAuthFailure, + DeviceFingerprint, + DeviceGuid, + DeviceIdentity, + DeviceStatus, + sanitize_service_context, +) +from .device_enrollment import ( # noqa: F401 + EnrollmentApproval, + EnrollmentApprovalStatus, + EnrollmentCode, + EnrollmentRequest, + ProofChallenge, +) +from .github import ( # noqa: F401 + GitHubRateLimit, + GitHubRepoRef, + GitHubTokenStatus, + RepoHeadSnapshot, +) + +__all__ = [ + "AccessTokenClaims", + "DeviceAuthContext", + "DeviceAuthErrorCode", + "DeviceAuthFailure", + "DeviceFingerprint", + "DeviceGuid", + "DeviceIdentity", + "DeviceStatus", + "EnrollmentApproval", + "EnrollmentApprovalStatus", + "EnrollmentCode", + "EnrollmentRequest", + "ProofChallenge", + "GitHubRateLimit", + "GitHubRepoRef", + "GitHubTokenStatus", + "RepoHeadSnapshot", + "sanitize_service_context", +] diff --git a/Data/Engine/domain/device_auth.py b/Data/Engine/domain/device_auth.py new file mode 100644 index 0000000..d377e52 --- /dev/null +++ b/Data/Engine/domain/device_auth.py @@ -0,0 +1,242 @@ +"""Domain primitives for device authentication and token validation.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Mapping, Optional +import string +import uuid + +__all__ = [ + "DeviceGuid", + "DeviceFingerprint", + "DeviceIdentity", + "DeviceStatus", + "DeviceAuthErrorCode", + "DeviceAuthFailure", + "AccessTokenClaims", + "DeviceAuthContext", + "sanitize_service_context", +] + + +def _normalize_guid(value: Optional[str]) -> str: + """Return a canonical GUID string or an empty string.""" + candidate = (value or "").strip() + if not candidate: + return "" + candidate = candidate.strip("{}") + try: + return str(uuid.UUID(candidate)).upper() + except Exception: + cleaned = "".join( + ch for ch in candidate if ch in 
string.hexdigits or ch == "-" + ).strip("-") + if cleaned: + try: + return str(uuid.UUID(cleaned)).upper() + except Exception: + pass + return candidate.upper() + + +def _normalize_fingerprint(value: Optional[str]) -> str: + return (value or "").strip().lower() + + +def sanitize_service_context(value: Optional[str]) -> Optional[str]: + """Normalize the optional agent service context header value.""" + if not value: + return None + cleaned = "".join( + ch for ch in str(value) if ch.isalnum() or ch in ("_", "-") + ) + if not cleaned: + return None + return cleaned.upper() + + +@dataclass(frozen=True, slots=True) +class DeviceGuid: + """Canonical GUID wrapper that enforces Borealis normalization.""" + + value: str + + def __post_init__(self) -> None: # pragma: no cover - simple data normalization + normalized = _normalize_guid(self.value) + if not normalized: + raise ValueError("device GUID is required") + object.__setattr__(self, "value", normalized) + + def __str__(self) -> str: + return self.value + + +@dataclass(frozen=True, slots=True) +class DeviceFingerprint: + """Normalized TLS key fingerprint associated with a device.""" + + value: str + + def __post_init__(self) -> None: # pragma: no cover - simple data normalization + normalized = _normalize_fingerprint(self.value) + if not normalized: + raise ValueError("device fingerprint is required") + object.__setattr__(self, "value", normalized) + + def __str__(self) -> str: + return self.value + + +@dataclass(frozen=True, slots=True) +class DeviceIdentity: + """Immutable pairing of device GUID and TLS key fingerprint.""" + + guid: DeviceGuid + fingerprint: DeviceFingerprint + + +class DeviceStatus(str, Enum): + """Lifecycle markers mirrored from the legacy devices table.""" + + ACTIVE = "active" + QUARANTINED = "quarantined" + REVOKED = "revoked" + DECOMMISSIONED = "decommissioned" + + @classmethod + def from_string(cls, value: Optional[str]) -> "DeviceStatus": + normalized = (value or "active").strip().lower() + try: + return cls(normalized) + except ValueError: + return cls.ACTIVE + + @property + def allows_access(self) -> bool: + return self in {self.ACTIVE, self.QUARANTINED} + + +class DeviceAuthErrorCode(str, Enum): + """Well-known authentication failure categories.""" + + MISSING_AUTHORIZATION = "missing_authorization" + TOKEN_EXPIRED = "token_expired" + INVALID_TOKEN = "invalid_token" + INVALID_CLAIMS = "invalid_claims" + RATE_LIMITED = "rate_limited" + DEVICE_NOT_FOUND = "device_not_found" + DEVICE_GUID_MISMATCH = "device_guid_mismatch" + FINGERPRINT_MISMATCH = "fingerprint_mismatch" + TOKEN_VERSION_REVOKED = "token_version_revoked" + DEVICE_REVOKED = "device_revoked" + DPOP_NOT_SUPPORTED = "dpop_not_supported" + DPOP_REPLAYED = "dpop_replayed" + DPOP_INVALID = "dpop_invalid" + + +class DeviceAuthFailure(Exception): + """Domain-level authentication error with HTTP metadata.""" + + def __init__( + self, + code: DeviceAuthErrorCode, + *, + http_status: int = 401, + retry_after: Optional[float] = None, + detail: Optional[str] = None, + ) -> None: + self.code = code + self.http_status = int(http_status) + self.retry_after = retry_after + self.detail = detail or code.value + super().__init__(self.detail) + + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = {"error": self.code.value} + if self.retry_after is not None: + payload["retry_after"] = float(self.retry_after) + if self.detail and self.detail != self.code.value: + payload["detail"] = self.detail + return payload + + +def _coerce_int(value: Any, *, minimum: 
Optional[int] = None) -> int: + try: + result = int(value) + except (TypeError, ValueError): + raise ValueError("expected integer value") from None + if minimum is not None and result < minimum: + raise ValueError("integer below minimum") + return result + + +@dataclass(frozen=True, slots=True) +class AccessTokenClaims: + """Validated subset of JWT claims issued to a device.""" + + subject: str + guid: DeviceGuid + fingerprint: DeviceFingerprint + token_version: int + issued_at: int + not_before: int + expires_at: int + raw: Mapping[str, Any] + + def __post_init__(self) -> None: + if self.token_version <= 0: + raise ValueError("token_version must be positive") + if self.issued_at <= 0 or self.not_before <= 0 or self.expires_at <= 0: + raise ValueError("temporal claims must be positive integers") + if self.expires_at <= self.not_before: + raise ValueError("token expiration must be after not-before") + + @classmethod + def from_mapping(cls, claims: Mapping[str, Any]) -> "AccessTokenClaims": + subject = str(claims.get("sub") or "").strip() + if not subject: + raise ValueError("missing token subject") + guid = DeviceGuid(str(claims.get("guid") or "")) + fingerprint = DeviceFingerprint(claims.get("ssl_key_fingerprint")) + token_version = _coerce_int(claims.get("token_version"), minimum=1) + issued_at = _coerce_int(claims.get("iat"), minimum=1) + not_before = _coerce_int(claims.get("nbf"), minimum=1) + expires_at = _coerce_int(claims.get("exp"), minimum=1) + return cls( + subject=subject, + guid=guid, + fingerprint=fingerprint, + token_version=token_version, + issued_at=issued_at, + not_before=not_before, + expires_at=expires_at, + raw=dict(claims), + ) + + +@dataclass(frozen=True, slots=True) +class DeviceAuthContext: + """Domain result emitted after successful authentication.""" + + identity: DeviceIdentity + access_token: str + claims: AccessTokenClaims + status: DeviceStatus + service_context: Optional[str] + dpop_jkt: Optional[str] = None + + def __post_init__(self) -> None: + if not self.access_token: + raise ValueError("access token is required") + service = sanitize_service_context(self.service_context) + object.__setattr__(self, "service_context", service) + + @property + def is_quarantined(self) -> bool: + return self.status is DeviceStatus.QUARANTINED + + @property + def allows_access(self) -> bool: + return self.status.allows_access diff --git a/Data/Engine/domain/device_enrollment.py b/Data/Engine/domain/device_enrollment.py new file mode 100644 index 0000000..85b680f --- /dev/null +++ b/Data/Engine/domain/device_enrollment.py @@ -0,0 +1,261 @@ +"""Domain types describing device enrollment flows.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +import base64 +from enum import Enum +from typing import Any, Mapping, Optional + +from .device_auth import DeviceFingerprint, DeviceGuid, sanitize_service_context + +__all__ = [ + "EnrollmentCode", + "EnrollmentApprovalStatus", + "EnrollmentApproval", + "EnrollmentRequest", + "EnrollmentValidationError", + "ProofChallenge", +] + + +def _parse_iso8601(value: Optional[str]) -> Optional[datetime]: + if not value: + return None + raw = str(value).strip() + if not raw: + return None + try: + dt = datetime.fromisoformat(raw) + except Exception as exc: # pragma: no cover - error path + raise ValueError(f"invalid ISO8601 timestamp: {raw}") from exc + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + return dt + + +def _require(value: Optional[str], field: str) -> 
str: + text = (value or "").strip() + if not text: + raise ValueError(f"{field} is required") + return text + + +@dataclass(frozen=True, slots=True) +class EnrollmentValidationError(Exception): + """Raised when enrollment input fails validation.""" + + code: str + http_status: int = 400 + retry_after: Optional[float] = None + + def to_response(self) -> dict[str, object]: + payload: dict[str, object] = {"error": self.code} + if self.retry_after is not None: + payload["retry_after"] = self.retry_after + return payload + + def __str__(self) -> str: # pragma: no cover - debug helper + return f"{self.code} (status={self.http_status})" + + +@dataclass(frozen=True, slots=True) +class EnrollmentCode: + """Installer code metadata loaded from the persistence layer.""" + + code: str + expires_at: datetime + max_uses: int + use_count: int + used_by_guid: Optional[DeviceGuid] + last_used_at: Optional[datetime] + used_at: Optional[datetime] + record_id: Optional[str] = None + + def __post_init__(self) -> None: + if not self.code: + raise ValueError("code is required") + if self.max_uses < 1: + raise ValueError("max_uses must be >= 1") + if self.use_count < 0: + raise ValueError("use_count cannot be negative") + if self.use_count > self.max_uses: + raise ValueError("use_count cannot exceed max_uses") + + @classmethod + def from_mapping(cls, record: Mapping[str, Any]) -> "EnrollmentCode": + used_by = record.get("used_by_guid") + used_by_guid = DeviceGuid(used_by) if used_by else None + return cls( + code=_require(record.get("code"), "code"), + expires_at=_parse_iso8601(record.get("expires_at")) or datetime.now(tz=timezone.utc), + max_uses=int(record.get("max_uses") or 1), + use_count=int(record.get("use_count") or 0), + used_by_guid=used_by_guid, + last_used_at=_parse_iso8601(record.get("last_used_at")), + used_at=_parse_iso8601(record.get("used_at")), + record_id=str(record.get("id") or "") or None, + ) + + @property + def remaining_uses(self) -> int: + return max(self.max_uses - self.use_count, 0) + + @property + def is_exhausted(self) -> bool: + return self.remaining_uses == 0 + + @property + def is_expired(self) -> bool: + return self.expires_at <= datetime.now(tz=timezone.utc) + + @property + def identifier(self) -> Optional[str]: + return self.record_id + + +class EnrollmentApprovalStatus(str, Enum): + """Possible states for a device approval entry.""" + + PENDING = "pending" + APPROVED = "approved" + DENIED = "denied" + COMPLETED = "completed" + EXPIRED = "expired" + + @classmethod + def from_string(cls, value: Optional[str]) -> "EnrollmentApprovalStatus": + normalized = (value or "pending").strip().lower() + try: + return cls(normalized) + except ValueError: + return cls.PENDING + + @property + def is_terminal(self) -> bool: + return self in {self.APPROVED, self.DENIED, self.COMPLETED, self.EXPIRED} + + +@dataclass(frozen=True, slots=True) +class ProofChallenge: + """Client/server nonce pair distributed during enrollment.""" + + client_nonce: bytes + server_nonce: bytes + + def __post_init__(self) -> None: + if not self.client_nonce or not self.server_nonce: + raise ValueError("nonce payloads must be non-empty") + + @classmethod + def from_base64(cls, *, client: bytes, server: bytes) -> "ProofChallenge": + return cls(client_nonce=client, server_nonce=server) + + +@dataclass(frozen=True, slots=True) +class EnrollmentRequest: + """Validated payload submitted by an agent during enrollment.""" + + hostname: str + enrollment_code: str + fingerprint: DeviceFingerprint + proof: ProofChallenge + 
service_context: Optional[str] + + def __post_init__(self) -> None: + if not self.hostname: + raise ValueError("hostname is required") + if not self.enrollment_code: + raise ValueError("enrollment code is required") + object.__setattr__( + self, + "service_context", + sanitize_service_context(self.service_context), + ) + + @classmethod + def from_payload( + cls, + *, + hostname: str, + enrollment_code: str, + fingerprint: str, + client_nonce: bytes, + server_nonce: bytes, + service_context: Optional[str] = None, + ) -> "EnrollmentRequest": + proof = ProofChallenge(client_nonce=client_nonce, server_nonce=server_nonce) + return cls( + hostname=_require(hostname, "hostname"), + enrollment_code=_require(enrollment_code, "enrollment_code"), + fingerprint=DeviceFingerprint(fingerprint), + proof=proof, + service_context=service_context, + ) + + +@dataclass(frozen=True, slots=True) +class EnrollmentApproval: + """Pending or resolved approval tracked by operators.""" + + record_id: str + reference: str + status: EnrollmentApprovalStatus + claimed_hostname: str + claimed_fingerprint: DeviceFingerprint + enrollment_code_id: Optional[str] + created_at: datetime + updated_at: datetime + client_nonce_b64: str + server_nonce_b64: str + agent_pubkey_der: bytes + guid: Optional[DeviceGuid] = None + approved_by: Optional[str] = None + + def __post_init__(self) -> None: + if not self.record_id: + raise ValueError("record identifier is required") + if not self.reference: + raise ValueError("approval reference is required") + if not self.claimed_hostname: + raise ValueError("claimed hostname is required") + + @classmethod + def from_mapping(cls, record: Mapping[str, Any]) -> "EnrollmentApproval": + guid_raw = record.get("guid") + approved_raw = record.get("approved_by_user_id") + return cls( + record_id=_require(record.get("id"), "id"), + reference=_require(record.get("approval_reference"), "approval_reference"), + status=EnrollmentApprovalStatus.from_string(record.get("status")), + claimed_hostname=_require(record.get("hostname_claimed"), "hostname_claimed"), + claimed_fingerprint=DeviceFingerprint(record.get("ssl_key_fingerprint_claimed")), + enrollment_code_id=record.get("enrollment_code_id"), + created_at=_parse_iso8601(record.get("created_at")) or datetime.now(tz=timezone.utc), + updated_at=_parse_iso8601(record.get("updated_at")) or datetime.now(tz=timezone.utc), + guid=DeviceGuid(guid_raw) if guid_raw else None, + approved_by=(approved_raw or None), + client_nonce_b64=_require(record.get("client_nonce"), "client_nonce"), + server_nonce_b64=_require(record.get("server_nonce"), "server_nonce"), + agent_pubkey_der=bytes(record.get("agent_pubkey_der") or b""), + ) + + @property + def is_pending(self) -> bool: + return self.status is EnrollmentApprovalStatus.PENDING + + @property + def is_completed(self) -> bool: + return self.status in { + EnrollmentApprovalStatus.APPROVED, + EnrollmentApprovalStatus.COMPLETED, + } + + @property + def client_nonce_bytes(self) -> bytes: + return base64.b64decode(self.client_nonce_b64.encode("ascii"), validate=True) + + @property + def server_nonce_bytes(self) -> bytes: + return base64.b64decode(self.server_nonce_b64.encode("ascii"), validate=True) diff --git a/Data/Engine/domain/github.py b/Data/Engine/domain/github.py new file mode 100644 index 0000000..f2e60b1 --- /dev/null +++ b/Data/Engine/domain/github.py @@ -0,0 +1,103 @@ +"""Domain types for GitHub integrations.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, 
Optional + + +@dataclass(frozen=True, slots=True) +class GitHubRepoRef: + """Identify a GitHub repository and branch.""" + + owner: str + name: str + branch: str + + @property + def full_name(self) -> str: + return f"{self.owner}/{self.name}".strip("/") + + @classmethod + def parse(cls, owner_repo: str, branch: str) -> "GitHubRepoRef": + owner_repo = (owner_repo or "").strip() + if "/" not in owner_repo: + raise ValueError("repo must be in the form owner/name") + owner, repo = owner_repo.split("/", 1) + return cls(owner=owner.strip(), name=repo.strip(), branch=(branch or "main").strip()) + + +@dataclass(frozen=True, slots=True) +class RepoHeadSnapshot: + """Snapshot describing the current head of a repository branch.""" + + repository: GitHubRepoRef + sha: Optional[str] + cached: bool + age_seconds: Optional[float] + source: str + error: Optional[str] + + def to_dict(self) -> Dict[str, object]: + return { + "repo": self.repository.full_name, + "branch": self.repository.branch, + "sha": self.sha, + "cached": self.cached, + "age_seconds": self.age_seconds, + "source": self.source, + "error": self.error, + } + + +@dataclass(frozen=True, slots=True) +class GitHubRateLimit: + """Subset of rate limit details returned by the GitHub API.""" + + limit: Optional[int] + remaining: Optional[int] + reset_epoch: Optional[int] + used: Optional[int] + + def to_dict(self) -> Dict[str, Optional[int]]: + return { + "limit": self.limit, + "remaining": self.remaining, + "reset": self.reset_epoch, + "used": self.used, + } + + +@dataclass(frozen=True, slots=True) +class GitHubTokenStatus: + """Describe the verification result for a GitHub access token.""" + + has_token: bool + valid: bool + status: str + message: str + rate_limit: Optional[GitHubRateLimit] + error: Optional[str] = None + + def to_dict(self) -> Dict[str, object]: + payload: Dict[str, object] = { + "has_token": self.has_token, + "valid": self.valid, + "status": self.status, + "message": self.message, + "error": self.error, + } + if self.rate_limit is not None: + payload["rate_limit"] = self.rate_limit.to_dict() + else: + payload["rate_limit"] = None + return payload + + +__all__ = [ + "GitHubRateLimit", + "GitHubRepoRef", + "GitHubTokenStatus", + "RepoHeadSnapshot", +] + diff --git a/Data/Engine/integrations/__init__.py b/Data/Engine/integrations/__init__.py new file mode 100644 index 0000000..d2fc960 --- /dev/null +++ b/Data/Engine/integrations/__init__.py @@ -0,0 +1,7 @@ +"""External system adapters for the Borealis Engine.""" + +from __future__ import annotations + +from .github.artifact_provider import GitHubArtifactProvider + +__all__ = ["GitHubArtifactProvider"] diff --git a/Data/Engine/integrations/crypto/__init__.py b/Data/Engine/integrations/crypto/__init__.py new file mode 100644 index 0000000..0d790bf --- /dev/null +++ b/Data/Engine/integrations/crypto/__init__.py @@ -0,0 +1,25 @@ +"""Crypto integration helpers for the Engine.""" + +from __future__ import annotations + +from .keys import ( + base64_from_spki_der, + fingerprint_from_base64_spki, + fingerprint_from_spki_der, + generate_ed25519_keypair, + normalize_base64, + private_key_to_pem, + public_key_to_pem, + spki_der_from_base64, +) + +__all__ = [ + "base64_from_spki_der", + "fingerprint_from_base64_spki", + "fingerprint_from_spki_der", + "generate_ed25519_keypair", + "normalize_base64", + "private_key_to_pem", + "public_key_to_pem", + "spki_der_from_base64", +] diff --git a/Data/Engine/integrations/crypto/keys.py b/Data/Engine/integrations/crypto/keys.py new file mode 100644 
index 0000000..076a520 --- /dev/null +++ b/Data/Engine/integrations/crypto/keys.py @@ -0,0 +1,70 @@ +"""Key utilities mirrored from the legacy crypto helpers.""" + +from __future__ import annotations + +import base64 +import hashlib +import re +from typing import Tuple + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 +from cryptography.hazmat.primitives.serialization import load_der_public_key + +__all__ = [ + "base64_from_spki_der", + "fingerprint_from_base64_spki", + "fingerprint_from_spki_der", + "generate_ed25519_keypair", + "normalize_base64", + "private_key_to_pem", + "public_key_to_pem", + "spki_der_from_base64", +] + + +def generate_ed25519_keypair() -> Tuple[ed25519.Ed25519PrivateKey, bytes]: + private_key = ed25519.Ed25519PrivateKey.generate() + public_key = private_key.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return private_key, public_key + + +def normalize_base64(data: str) -> str: + cleaned = re.sub(r"\s+", "", data or "") + return cleaned.replace("-", "+").replace("_", "/") + + +def spki_der_from_base64(spki_b64: str) -> bytes: + return base64.b64decode(normalize_base64(spki_b64), validate=True) + + +def base64_from_spki_der(spki_der: bytes) -> str: + return base64.b64encode(spki_der).decode("ascii") + + +def fingerprint_from_spki_der(spki_der: bytes) -> str: + digest = hashlib.sha256(spki_der).hexdigest() + return digest.lower() + + +def fingerprint_from_base64_spki(spki_b64: str) -> str: + return fingerprint_from_spki_der(spki_der_from_base64(spki_b64)) + + +def private_key_to_pem(private_key: ed25519.Ed25519PrivateKey) -> bytes: + return private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + +def public_key_to_pem(public_spki_der: bytes) -> bytes: + public_key = load_der_public_key(public_spki_der) + return public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) diff --git a/Data/Engine/integrations/github/__init__.py b/Data/Engine/integrations/github/__init__.py new file mode 100644 index 0000000..00facf8 --- /dev/null +++ b/Data/Engine/integrations/github/__init__.py @@ -0,0 +1,8 @@ +"""GitHub integration surface for the Borealis Engine.""" + +from __future__ import annotations + +from .artifact_provider import GitHubArtifactProvider + +__all__ = ["GitHubArtifactProvider"] + diff --git a/Data/Engine/integrations/github/artifact_provider.py b/Data/Engine/integrations/github/artifact_provider.py new file mode 100644 index 0000000..91d1a3b --- /dev/null +++ b/Data/Engine/integrations/github/artifact_provider.py @@ -0,0 +1,275 @@ +"""GitHub REST API integration with caching support.""" + +from __future__ import annotations + +import json +import logging +import threading +import time +from pathlib import Path +from typing import Dict, Optional + +from Data.Engine.domain.github import GitHubRepoRef, GitHubTokenStatus, RepoHeadSnapshot, GitHubRateLimit + +try: # pragma: no cover - optional dependency guard + import requests + from requests import Response +except Exception: # pragma: no cover - fallback when requests is unavailable + requests = None # type: ignore[assignment] + Response = object # type: ignore[misc,assignment] + +__all__ = ["GitHubArtifactProvider"] + + +class GitHubArtifactProvider: + """Resolve repository heads and 
token metadata from the GitHub API.""" + + def __init__( + self, + *, + cache_file: Path, + default_repo: str, + default_branch: str, + refresh_interval: int, + logger: Optional[logging.Logger] = None, + ) -> None: + self._cache_file = cache_file + self._default_repo = default_repo + self._default_branch = default_branch + self._refresh_interval = max(30, min(refresh_interval, 3600)) + self._log = logger or logging.getLogger("borealis.engine.integrations.github") + self._token: Optional[str] = None + self._cache_lock = threading.Lock() + self._cache: Dict[str, Dict[str, float | str]] = {} + self._worker: Optional[threading.Thread] = None + self._hydrate_cache_from_disk() + + def set_token(self, token: Optional[str]) -> None: + self._token = (token or "").strip() or None + + @property + def default_repo(self) -> str: + return self._default_repo + + @property + def default_branch(self) -> str: + return self._default_branch + + @property + def refresh_interval(self) -> int: + return self._refresh_interval + + def fetch_repo_head( + self, + repo: GitHubRepoRef, + *, + ttl_seconds: int, + force_refresh: bool = False, + ) -> RepoHeadSnapshot: + key = f"{repo.full_name}:{repo.branch}" + now = time.time() + + cached_entry = None + with self._cache_lock: + cached_entry = self._cache.get(key, {}).copy() + + cached_sha = (cached_entry.get("sha") if cached_entry else None) # type: ignore[assignment] + cached_ts = cached_entry.get("timestamp") if cached_entry else None # type: ignore[assignment] + cached_age = None + if isinstance(cached_ts, (int, float)): + cached_age = max(0.0, now - float(cached_ts)) + + ttl = max(30, min(ttl_seconds, 3600)) + if cached_sha and not force_refresh and cached_age is not None and cached_age < ttl: + return RepoHeadSnapshot( + repository=repo, + sha=str(cached_sha), + cached=True, + age_seconds=cached_age, + source="cache", + error=None, + ) + + if requests is None: + return RepoHeadSnapshot( + repository=repo, + sha=str(cached_sha) if cached_sha else None, + cached=bool(cached_sha), + age_seconds=cached_age, + source="unavailable", + error="requests library not available", + ) + + headers = { + "Accept": "application/vnd.github+json", + "User-Agent": "Borealis-Engine", + } + if self._token: + headers["Authorization"] = f"Bearer {self._token}" + + url = f"https://api.github.com/repos/{repo.full_name}/branches/{repo.branch}" + error: Optional[str] = None + sha: Optional[str] = None + + try: + response: Response = requests.get(url, headers=headers, timeout=20) + if response.status_code == 200: + data = response.json() + sha = (data.get("commit", {}).get("sha") or "").strip() # type: ignore[assignment] + else: + error = f"GitHub REST API repo head lookup failed: HTTP {response.status_code} {response.text[:200]}" + except Exception as exc: # pragma: no cover - defensive logging + error = f"GitHub REST API repo head lookup raised: {exc}" + + if sha: + payload = {"sha": sha, "timestamp": now} + with self._cache_lock: + self._cache[key] = payload + self._persist_cache() + return RepoHeadSnapshot( + repository=repo, + sha=sha, + cached=False, + age_seconds=0.0, + source="github", + error=None, + ) + + if error: + self._log.warning("repo-head-lookup failure repo=%s branch=%s error=%s", repo.full_name, repo.branch, error) + + return RepoHeadSnapshot( + repository=repo, + sha=str(cached_sha) if cached_sha else None, + cached=bool(cached_sha), + age_seconds=cached_age, + source="cache-stale" if cached_sha else "github", + error=error or ("using cached value" if cached_sha else 
"unable to resolve repository head"), + ) + + def refresh_default_repo_head(self, *, force: bool = False) -> RepoHeadSnapshot: + repo = GitHubRepoRef.parse(self._default_repo, self._default_branch) + return self.fetch_repo_head(repo, ttl_seconds=self._refresh_interval, force_refresh=force) + + def verify_token(self, token: Optional[str]) -> GitHubTokenStatus: + token = (token or "").strip() + if not token: + return GitHubTokenStatus( + has_token=False, + valid=False, + status="missing", + message="API Token Not Configured", + rate_limit=None, + error=None, + ) + + if requests is None: + return GitHubTokenStatus( + has_token=True, + valid=False, + status="unknown", + message="requests library not available", + rate_limit=None, + error="requests library not available", + ) + + headers = { + "Accept": "application/vnd.github+json", + "Authorization": f"Bearer {token}", + "User-Agent": "Borealis-Engine", + } + try: + response: Response = requests.get("https://api.github.com/rate_limit", headers=headers, timeout=10) + except Exception as exc: # pragma: no cover - defensive logging + message = f"GitHub token verification raised: {exc}" + self._log.warning("github-token-verify error=%s", message) + return GitHubTokenStatus( + has_token=True, + valid=False, + status="error", + message="API Token Invalid", + rate_limit=None, + error=message, + ) + + if response.status_code != 200: + message = f"GitHub API error (HTTP {response.status_code})" + self._log.warning("github-token-verify http_status=%s", response.status_code) + return GitHubTokenStatus( + has_token=True, + valid=False, + status="error", + message="API Token Invalid", + rate_limit=None, + error=message, + ) + + data = response.json() + core = (data.get("resources", {}).get("core", {}) if isinstance(data, dict) else {}) + rate_limit = GitHubRateLimit( + limit=_safe_int(core.get("limit")), + remaining=_safe_int(core.get("remaining")), + reset_epoch=_safe_int(core.get("reset")), + used=_safe_int(core.get("used")), + ) + + message = "API Token Valid" if rate_limit.remaining is not None else "API Token Verified" + return GitHubTokenStatus( + has_token=True, + valid=True, + status="valid", + message=message, + rate_limit=rate_limit, + error=None, + ) + + def start_background_refresh(self) -> None: + if self._worker and self._worker.is_alive(): # pragma: no cover - guard + return + + def _loop() -> None: + interval = max(30, self._refresh_interval) + while True: + try: + self.refresh_default_repo_head(force=True) + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning("default-repo-refresh failure: %s", exc) + time.sleep(interval) + + self._worker = threading.Thread(target=_loop, name="github-repo-refresh", daemon=True) + self._worker.start() + + def _hydrate_cache_from_disk(self) -> None: + path = self._cache_file + try: + if not path.exists(): + return + data = json.loads(path.read_text(encoding="utf-8")) + if isinstance(data, dict): + with self._cache_lock: + self._cache = { + key: value + for key, value in data.items() + if isinstance(value, dict) and "sha" in value and "timestamp" in value + } + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning("failed to load repo cache: %s", exc) + + def _persist_cache(self) -> None: + path = self._cache_file + try: + path.parent.mkdir(parents=True, exist_ok=True) + payload = json.dumps(self._cache, ensure_ascii=False) + tmp = path.with_suffix(".tmp") + tmp.write_text(payload, encoding="utf-8") + tmp.replace(path) + except Exception as exc: # 
pragma: no cover - defensive logging + self._log.warning("failed to persist repo cache: %s", exc) + + +def _safe_int(value: object) -> Optional[int]: + try: + return int(value) + except Exception: + return None + diff --git a/Data/Engine/interfaces/__init__.py b/Data/Engine/interfaces/__init__.py new file mode 100644 index 0000000..d7ffdba --- /dev/null +++ b/Data/Engine/interfaces/__init__.py @@ -0,0 +1,12 @@ +"""Interface adapters (HTTP, WebSocket, etc.) for the Borealis Engine.""" + +from __future__ import annotations + +from .http import register_http_interfaces +from .ws import create_socket_server, register_ws_interfaces + +__all__ = [ + "register_http_interfaces", + "create_socket_server", + "register_ws_interfaces", +] diff --git a/Data/Engine/interfaces/eventlet_compat.py b/Data/Engine/interfaces/eventlet_compat.py new file mode 100644 index 0000000..66b1f5e --- /dev/null +++ b/Data/Engine/interfaces/eventlet_compat.py @@ -0,0 +1,75 @@ +"""Compatibility helpers for running Socket.IO under eventlet.""" + +from __future__ import annotations + +import ssl +from typing import Any + +try: # pragma: no cover - optional dependency + import eventlet # type: ignore +except Exception: # pragma: no cover - optional dependency + eventlet = None # type: ignore[assignment] + + +def _quiet_close(connection: Any) -> None: + try: + if hasattr(connection, "close"): + connection.close() + except Exception: + pass + + +def _close_connection_quietly(protocol: Any) -> None: + try: + setattr(protocol, "close_connection", True) + except Exception: + pass + conn = getattr(protocol, "socket", None) or getattr(protocol, "connection", None) + if conn is not None: + _quiet_close(conn) + + +def apply_eventlet_patches() -> None: + """Apply Borealis-specific eventlet tweaks when the dependency is available.""" + + if eventlet is None: # pragma: no cover - guard for environments without eventlet + return + + eventlet.monkey_patch(thread=False) + + try: + from eventlet.wsgi import HttpProtocol # type: ignore + except Exception: # pragma: no cover - import guard + return + + original = HttpProtocol.handle_one_request # type: ignore[attr-defined] + + def _handle_one_request(self: Any, *args: Any, **kwargs: Any) -> Any: # type: ignore[override] + try: + return original(self, *args, **kwargs) + except ssl.SSLError as exc: # type: ignore[arg-type] + reason = getattr(exc, "reason", "") or "" + message = " ".join(str(arg) for arg in exc.args if arg) + lower_reason = str(reason).lower() + lower_message = message.lower() + if ( + "http_request" in lower_message + or lower_reason == "http request" + or "unknown ca" in lower_message + or lower_reason == "unknown ca" + or "unknown_ca" in lower_message + ): + _close_connection_quietly(self) + return None + raise + except ssl.SSLEOFError: + _close_connection_quietly(self) + return None + except ConnectionAbortedError: + _close_connection_quietly(self) + return None + + HttpProtocol.handle_one_request = _handle_one_request # type: ignore[assignment] + + +__all__ = ["apply_eventlet_patches"] diff --git a/Data/Engine/interfaces/http/__init__.py b/Data/Engine/interfaces/http/__init__.py new file mode 100644 index 0000000..ce80a82 --- /dev/null +++ b/Data/Engine/interfaces/http/__init__.py @@ -0,0 +1,32 @@ +"""HTTP interface registration for the Borealis Engine.""" + +from __future__ import annotations + +from flask import Flask + +from Data.Engine.services.container import EngineServiceContainer + +from . 
import admin, agents, enrollment, github, health, job_management, tokens + +_REGISTRARS = ( + health.register, + agents.register, + enrollment.register, + tokens.register, + job_management.register, + github.register, + admin.register, +) + + +def register_http_interfaces(app: Flask, services: EngineServiceContainer) -> None: + """Attach HTTP blueprints to *app*. + + The implementation is intentionally minimal for the initial scaffolding. + """ + + for registrar in _REGISTRARS: + registrar(app, services) + + +__all__ = ["register_http_interfaces"] diff --git a/Data/Engine/interfaces/http/admin.py b/Data/Engine/interfaces/http/admin.py new file mode 100644 index 0000000..2da2ec2 --- /dev/null +++ b/Data/Engine/interfaces/http/admin.py @@ -0,0 +1,23 @@ +"""Administrative HTTP interface placeholders for the Engine.""" + +from __future__ import annotations + +from flask import Blueprint, Flask + +from Data.Engine.services.container import EngineServiceContainer + + +blueprint = Blueprint("engine_admin", __name__, url_prefix="/api/admin") + + +def register(app: Flask, _services: EngineServiceContainer) -> None: + """Attach administrative routes to *app*. + + Concrete endpoints will be migrated in subsequent phases. + """ + + if "engine_admin" not in app.blueprints: + app.register_blueprint(blueprint) + + +__all__ = ["register", "blueprint"] diff --git a/Data/Engine/interfaces/http/agents.py b/Data/Engine/interfaces/http/agents.py new file mode 100644 index 0000000..0485bd0 --- /dev/null +++ b/Data/Engine/interfaces/http/agents.py @@ -0,0 +1,23 @@ +"""Agent HTTP interface placeholders for the Engine.""" + +from __future__ import annotations + +from flask import Blueprint, Flask + +from Data.Engine.services.container import EngineServiceContainer + + +blueprint = Blueprint("engine_agents", __name__, url_prefix="/api/agents") + + +def register(app: Flask, _services: EngineServiceContainer) -> None: + """Attach agent management routes to *app*. + + Implementation will be populated as services migrate from the legacy server. 
+ """ + + if "engine_agents" not in app.blueprints: + app.register_blueprint(blueprint) + + +__all__ = ["register", "blueprint"] diff --git a/Data/Engine/interfaces/http/enrollment.py b/Data/Engine/interfaces/http/enrollment.py new file mode 100644 index 0000000..5d65ff5 --- /dev/null +++ b/Data/Engine/interfaces/http/enrollment.py @@ -0,0 +1,111 @@ +"""Enrollment HTTP interface placeholders for the Engine.""" + +from __future__ import annotations + +from flask import Blueprint, Flask, current_app, jsonify, request + +from Data.Engine.builders.device_enrollment import EnrollmentRequestBuilder +from Data.Engine.services import EnrollmentValidationError +from Data.Engine.services.container import EngineServiceContainer + +AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context" + + +blueprint = Blueprint("engine_enrollment", __name__) + + +def register(app: Flask, _services: EngineServiceContainer) -> None: + """Attach enrollment routes to *app*.""" + + if "engine_enrollment" not in app.blueprints: + app.register_blueprint(blueprint) + + +@blueprint.route("/api/agent/enroll/request", methods=["POST"]) +def enrollment_request() -> object: + services: EngineServiceContainer = current_app.extensions["engine_services"] + payload = request.get_json(force=True, silent=True) + builder = EnrollmentRequestBuilder().with_payload(payload).with_service_context( + request.headers.get(AGENT_CONTEXT_HEADER) + ) + try: + normalized = builder.build() + result = services.enrollment_service.request_enrollment( + normalized, + remote_addr=_remote_addr(), + ) + except EnrollmentValidationError as exc: + response = jsonify(exc.to_response()) + response.status_code = exc.http_status + if exc.retry_after is not None: + response.headers["Retry-After"] = f"{int(exc.retry_after)}" + return response + + response_payload = { + "status": result.status, + "approval_reference": result.approval_reference, + "server_nonce": result.server_nonce, + "poll_after_ms": result.poll_after_ms, + "server_certificate": result.server_certificate, + "signing_key": result.signing_key, + } + response = jsonify(response_payload) + response.status_code = result.http_status + if result.retry_after is not None: + response.headers["Retry-After"] = f"{int(result.retry_after)}" + return response + + +@blueprint.route("/api/agent/enroll/poll", methods=["POST"]) +def enrollment_poll() -> object: + services: EngineServiceContainer = current_app.extensions["engine_services"] + payload = request.get_json(force=True, silent=True) or {} + approval_reference = str(payload.get("approval_reference") or "").strip() + client_nonce = str(payload.get("client_nonce") or "").strip() + proof_sig = str(payload.get("proof_sig") or "").strip() + + try: + result = services.enrollment_service.poll_enrollment( + approval_reference=approval_reference, + client_nonce_b64=client_nonce, + proof_signature_b64=proof_sig, + ) + except EnrollmentValidationError as exc: + return jsonify(exc.to_response()), exc.http_status + + body = {"status": result.status} + if result.poll_after_ms is not None: + body["poll_after_ms"] = result.poll_after_ms + if result.reason: + body["reason"] = result.reason + if result.detail: + body["detail"] = result.detail + if result.tokens: + body.update( + { + "guid": result.tokens.guid.value, + "access_token": result.tokens.access_token, + "refresh_token": result.tokens.refresh_token, + "token_type": result.tokens.token_type, + "expires_in": result.tokens.expires_in, + "server_certificate": result.server_certificate or "", + "signing_key": 
result.signing_key or "", + } + ) + else: + if result.server_certificate: + body["server_certificate"] = result.server_certificate + if result.signing_key: + body["signing_key"] = result.signing_key + + return jsonify(body), result.http_status + + +def _remote_addr() -> str: + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + return forwarded.split(",")[0].strip() + return (request.remote_addr or "unknown").strip() + + +__all__ = ["register", "blueprint", "enrollment_request", "enrollment_poll"] diff --git a/Data/Engine/interfaces/http/github.py b/Data/Engine/interfaces/http/github.py new file mode 100644 index 0000000..93a1095 --- /dev/null +++ b/Data/Engine/interfaces/http/github.py @@ -0,0 +1,60 @@ +"""GitHub-related HTTP endpoints.""" + +from __future__ import annotations + +from flask import Blueprint, Flask, current_app, jsonify, request + +from Data.Engine.services.container import EngineServiceContainer + +blueprint = Blueprint("engine_github", __name__) + + +def register(app: Flask, _services: EngineServiceContainer) -> None: + if "engine_github" not in app.blueprints: + app.register_blueprint(blueprint) + + +@blueprint.route("/api/repo/current_hash", methods=["GET"]) +def repo_current_hash() -> object: + services: EngineServiceContainer = current_app.extensions["engine_services"] + github = services.github_service + + repo = (request.args.get("repo") or "").strip() or None + branch = (request.args.get("branch") or "").strip() or None + refresh_flag = (request.args.get("refresh") or "").strip().lower() + ttl_raw = request.args.get("ttl") + try: + ttl = int(ttl_raw) if ttl_raw else github.default_refresh_interval + except ValueError: + ttl = github.default_refresh_interval + force_refresh = refresh_flag in {"1", "true", "yes", "force", "refresh"} + + snapshot = github.get_repo_head(repo, branch, ttl_seconds=ttl, force_refresh=force_refresh) + payload = snapshot.to_dict() + if not snapshot.sha: + return jsonify(payload), 503 + return jsonify(payload) + + +@blueprint.route("/api/github/token", methods=["GET", "POST"]) +def github_token() -> object: + services: EngineServiceContainer = current_app.extensions["engine_services"] + github = services.github_service + + if request.method == "GET": + payload = github.get_token_status(force_refresh=True).to_dict() + return jsonify(payload) + + data = request.get_json(silent=True) or {} + token = data.get("token") + normalized = str(token).strip() if token is not None else "" + try: + payload = github.update_token(normalized).to_dict() + except Exception as exc: # pragma: no cover - defensive logging + current_app.logger.exception("failed to store GitHub token: %s", exc) + return jsonify({"error": f"Failed to store token: {exc}"}), 500 + return jsonify(payload) + + +__all__ = ["register", "blueprint", "repo_current_hash", "github_token"] + diff --git a/Data/Engine/interfaces/http/health.py b/Data/Engine/interfaces/http/health.py new file mode 100644 index 0000000..37e74a7 --- /dev/null +++ b/Data/Engine/interfaces/http/health.py @@ -0,0 +1,26 @@ +"""Health check HTTP interface for the Engine.""" + +from __future__ import annotations + +from flask import Blueprint, Flask, jsonify + +from Data.Engine.services.container import EngineServiceContainer + +blueprint = Blueprint("engine_health", __name__) + + +def register(app: Flask, _services: EngineServiceContainer) -> None: + """Attach health-related routes to *app*.""" + + if "engine_health" not in app.blueprints: + app.register_blueprint(blueprint) + + 
+@blueprint.route("/health", methods=["GET"]) +def health() -> object: + """Return a basic liveness response.""" + + return jsonify({"status": "ok"}) + + +__all__ = ["register", "blueprint"] diff --git a/Data/Engine/interfaces/http/job_management.py b/Data/Engine/interfaces/http/job_management.py new file mode 100644 index 0000000..93c30ab --- /dev/null +++ b/Data/Engine/interfaces/http/job_management.py @@ -0,0 +1,108 @@ +"""HTTP routes for Engine job management.""" + +from __future__ import annotations + +from typing import Any, Optional + +from flask import Blueprint, Flask, jsonify, request + +from Data.Engine.services.container import EngineServiceContainer + +bp = Blueprint("engine_job_management", __name__) + + +def register(app: Flask, services: EngineServiceContainer) -> None: + bp.services = services # type: ignore[attr-defined] + app.register_blueprint(bp) + + +def _services() -> EngineServiceContainer: + svc = getattr(bp, "services", None) + if svc is None: # pragma: no cover - guard + raise RuntimeError("job management blueprint not initialized") + return svc + + +@bp.route("/api/scheduled_jobs", methods=["GET"]) +def list_jobs() -> Any: + jobs = _services().scheduler_service.list_jobs() + return jsonify({"jobs": jobs}) + + +@bp.route("/api/scheduled_jobs", methods=["POST"]) +def create_job() -> Any: + payload = _json_body() + try: + job = _services().scheduler_service.create_job(payload) + except ValueError as exc: + return jsonify({"error": str(exc)}), 400 + return jsonify({"job": job}) + + +@bp.route("/api/scheduled_jobs/", methods=["GET"]) +def get_job(job_id: int) -> Any: + job = _services().scheduler_service.get_job(job_id) + if not job: + return jsonify({"error": "job not found"}), 404 + return jsonify({"job": job}) + + +@bp.route("/api/scheduled_jobs/", methods=["PUT"]) +def update_job(job_id: int) -> Any: + payload = _json_body() + try: + job = _services().scheduler_service.update_job(job_id, payload) + except ValueError as exc: + return jsonify({"error": str(exc)}), 400 + if not job: + return jsonify({"error": "job not found"}), 404 + return jsonify({"job": job}) + + +@bp.route("/api/scheduled_jobs/", methods=["DELETE"]) +def delete_job(job_id: int) -> Any: + _services().scheduler_service.delete_job(job_id) + return ("", 204) + + +@bp.route("/api/scheduled_jobs//toggle", methods=["POST"]) +def toggle_job(job_id: int) -> Any: + payload = _json_body() + enabled = bool(payload.get("enabled", True)) + _services().scheduler_service.toggle_job(job_id, enabled) + job = _services().scheduler_service.get_job(job_id) + if not job: + return jsonify({"error": "job not found"}), 404 + return jsonify({"job": job}) + + +@bp.route("/api/scheduled_jobs//runs", methods=["GET"]) +def list_runs(job_id: int) -> Any: + days = request.args.get("days") + days_int: Optional[int] = None + if days is not None: + try: + days_int = max(0, int(days)) + except Exception: + return jsonify({"error": "invalid days parameter"}), 400 + runs = _services().scheduler_service.list_runs(job_id, days=days_int) + return jsonify({"runs": runs}) + + +@bp.route("/api/scheduled_jobs//runs", methods=["DELETE"]) +def purge_runs(job_id: int) -> Any: + _services().scheduler_service.purge_runs(job_id) + return ("", 204) + + +def _json_body() -> dict[str, Any]: + if not request.data: + return {} + try: + data = request.get_json(force=True, silent=False) # type: ignore[arg-type] + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +__all__ = ["register"] diff --git 
a/Data/Engine/interfaces/http/tokens.py b/Data/Engine/interfaces/http/tokens.py new file mode 100644 index 0000000..89bbc3e --- /dev/null +++ b/Data/Engine/interfaces/http/tokens.py @@ -0,0 +1,52 @@ +"""Token management HTTP interface for the Engine.""" + +from __future__ import annotations + +from flask import Blueprint, Flask, current_app, jsonify, request + +from Data.Engine.builders.device_auth import RefreshTokenRequestBuilder +from Data.Engine.domain.device_auth import DeviceAuthFailure +from Data.Engine.services.container import EngineServiceContainer +from Data.Engine.services import TokenRefreshError + +blueprint = Blueprint("engine_tokens", __name__) + + +def register(app: Flask, _services: EngineServiceContainer) -> None: + """Attach token management routes to *app*.""" + + if "engine_tokens" not in app.blueprints: + app.register_blueprint(blueprint) + + +@blueprint.route("/api/agent/token/refresh", methods=["POST"]) +def refresh_token() -> object: + services: EngineServiceContainer = current_app.extensions["engine_services"] + builder = ( + RefreshTokenRequestBuilder() + .with_payload(request.get_json(force=True, silent=True)) + .with_http_method(request.method) + .with_htu(request.url) + .with_dpop_proof(request.headers.get("DPoP")) + ) + try: + refresh_request = builder.build() + except DeviceAuthFailure as exc: + payload = exc.to_dict() + return jsonify(payload), exc.http_status + + try: + response = services.token_service.refresh_access_token(refresh_request) + except TokenRefreshError as exc: + return jsonify(exc.to_dict()), exc.http_status + + return jsonify( + { + "access_token": response.access_token, + "expires_in": response.expires_in, + "token_type": response.token_type, + } + ) + + +__all__ = ["register", "blueprint", "refresh_token"] diff --git a/Data/Engine/interfaces/ws/__init__.py b/Data/Engine/interfaces/ws/__init__.py new file mode 100644 index 0000000..5ed3340 --- /dev/null +++ b/Data/Engine/interfaces/ws/__init__.py @@ -0,0 +1,47 @@ +"""WebSocket interface factory for the Borealis Engine.""" + +from __future__ import annotations + +from typing import Any, Optional + +from flask import Flask + +from ...config import SocketIOSettings +from ...services.container import EngineServiceContainer +from .agents import register as register_agent_events +from .job_management import register as register_job_events + +try: # pragma: no cover - import guard + from flask_socketio import SocketIO +except Exception: # pragma: no cover - optional dependency + SocketIO = None # type: ignore[assignment] + + +def create_socket_server(app: Flask, settings: SocketIOSettings) -> Optional[SocketIO]: + """Create a Socket.IO server bound to *app* if dependencies are available.""" + + if SocketIO is None: + return None + + cors_allowed = settings.cors_allowed_origins or ("*",) + socketio = SocketIO( + app, + cors_allowed_origins=cors_allowed, + async_mode=None, + logger=False, + engineio_logger=False, + ) + return socketio + + +def register_ws_interfaces(socketio: Any, services: EngineServiceContainer) -> None: + """Attach namespaces for the Engine Socket.IO server.""" + + if socketio is None: # pragma: no cover - guard + return + + for registrar in (register_agent_events, register_job_events): + registrar(socketio, services) + + +__all__ = ["create_socket_server", "register_ws_interfaces"] diff --git a/Data/Engine/interfaces/ws/agents/__init__.py b/Data/Engine/interfaces/ws/agents/__init__.py new file mode 100644 index 0000000..cc52081 --- /dev/null +++ 
b/Data/Engine/interfaces/ws/agents/__init__.py @@ -0,0 +1,16 @@ +"""Agent WebSocket namespace wiring for the Engine.""" + +from __future__ import annotations + +from typing import Any + +from . import events + + +def register(socketio: Any, services) -> None: + """Register agent namespaces on the given Socket.IO *socketio* instance.""" + + events.register(socketio, services) + + +__all__ = ["register"] diff --git a/Data/Engine/interfaces/ws/agents/events.py b/Data/Engine/interfaces/ws/agents/events.py new file mode 100644 index 0000000..0515de5 --- /dev/null +++ b/Data/Engine/interfaces/ws/agents/events.py @@ -0,0 +1,261 @@ +"""Agent WebSocket event handlers for the Borealis Engine.""" + +from __future__ import annotations + +import logging +import time +from typing import Any, Dict, Iterable, Optional + +from flask import request + +from Data.Engine.services.container import EngineServiceContainer + +try: # pragma: no cover - optional dependency guard + from flask_socketio import emit, join_room +except Exception: # pragma: no cover - optional dependency guard + emit = None # type: ignore[assignment] + join_room = None # type: ignore[assignment] + +_AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context" + + +def register(socketio: Any, services: EngineServiceContainer) -> None: + if socketio is None: # pragma: no cover - guard + return + + handlers = _AgentEventHandlers(socketio, services) + socketio.on_event("connect", handlers.on_connect) + socketio.on_event("disconnect", handlers.on_disconnect) + socketio.on_event("agent_screenshot_task", handlers.on_agent_screenshot_task) + socketio.on_event("connect_agent", handlers.on_connect_agent) + socketio.on_event("agent_heartbeat", handlers.on_agent_heartbeat) + socketio.on_event("collector_status", handlers.on_collector_status) + socketio.on_event("request_config", handlers.on_request_config) + socketio.on_event("screenshot", handlers.on_screenshot) + socketio.on_event("macro_status", handlers.on_macro_status) + socketio.on_event("list_agent_windows", handlers.on_list_agent_windows) + socketio.on_event("agent_window_list", handlers.on_agent_window_list) + socketio.on_event("ansible_playbook_cancel", handlers.on_ansible_playbook_cancel) + socketio.on_event("ansible_playbook_run", handlers.on_ansible_playbook_run) + + +class _AgentEventHandlers: + def __init__(self, socketio: Any, services: EngineServiceContainer) -> None: + self._socketio = socketio + self._services = services + self._realtime = services.agent_realtime + self._log = logging.getLogger("borealis.engine.ws.agents") + + # ------------------------------------------------------------------ + # Connection lifecycle + # ------------------------------------------------------------------ + def on_connect(self) -> None: + sid = getattr(request, "sid", "") + remote_addr = getattr(request, "remote_addr", None) + transport = None + try: + transport = request.args.get("transport") # type: ignore[attr-defined] + except Exception: + transport = None + query = self._render_query() + headers = _summarize_socket_headers(getattr(request, "headers", {})) + scope = _canonical_scope(getattr(request.headers, "get", lambda *_: None)(_AGENT_CONTEXT_HEADER)) + self._log.info( + "socket-connect sid=%s ip=%s transport=%r query=%s headers=%s scope=%s", + sid, + remote_addr, + transport, + query, + headers, + scope or "", + ) + + def on_disconnect(self) -> None: + sid = getattr(request, "sid", "") + remote_addr = getattr(request, "remote_addr", None) + self._log.info("socket-disconnect sid=%s ip=%s", sid, 
remote_addr) + + # ------------------------------------------------------------------ + # Agent coordination + # ------------------------------------------------------------------ + def on_agent_screenshot_task(self, data: Optional[Dict[str, Any]]) -> None: + payload = data or {} + agent_id = payload.get("agent_id") + node_id = payload.get("node_id") + image = payload.get("image_base64", "") + + if not agent_id or not node_id: + self._log.warning("screenshot-task missing identifiers: %s", payload) + return + + if image: + self._realtime.store_task_screenshot(agent_id, node_id, image) + + try: + self._socketio.emit("agent_screenshot_task", payload) + except Exception as exc: # pragma: no cover - network guard + self._log.warning("socket emit failed for agent_screenshot_task: %s", exc) + + def on_connect_agent(self, data: Optional[Dict[str, Any]]) -> None: + payload = data or {} + agent_id = payload.get("agent_id") + if not agent_id: + return + + service_mode = payload.get("service_mode") + record = self._realtime.register_connection(agent_id, service_mode) + + if join_room is not None: # pragma: no branch - optional dependency guard + try: + join_room(agent_id) + except Exception as exc: # pragma: no cover - dependency guard + self._log.debug("join_room failed for %s: %s", agent_id, exc) + + self._log.info( + "agent-connected agent_id=%s mode=%s status=%s", + agent_id, + record.service_mode, + record.status, + ) + + def on_agent_heartbeat(self, data: Optional[Dict[str, Any]]) -> None: + payload = data or {} + record = self._realtime.heartbeat(payload) + if record: + self._log.debug( + "agent-heartbeat agent_id=%s host=%s mode=%s", record.agent_id, record.hostname, record.service_mode + ) + + def on_collector_status(self, data: Optional[Dict[str, Any]]) -> None: + payload = data or {} + self._realtime.collector_status(payload) + + def on_request_config(self, data: Optional[Dict[str, Any]]) -> None: + payload = data or {} + agent_id = payload.get("agent_id") + if not agent_id: + return + config = self._realtime.get_agent_config(agent_id) + if config and emit is not None: + try: + emit("agent_config", {**config, "agent_id": agent_id}) + except Exception as exc: # pragma: no cover - dependency guard + self._log.debug("emit(agent_config) failed for %s: %s", agent_id, exc) + + # ------------------------------------------------------------------ + # Media + relay events + # ------------------------------------------------------------------ + def on_screenshot(self, data: Optional[Dict[str, Any]]) -> None: + payload = data or {} + agent_id = payload.get("agent_id") + image = payload.get("image_base64") + if agent_id and image: + self._realtime.store_agent_screenshot(agent_id, image) + try: + self._socketio.emit("new_screenshot", {"agent_id": agent_id, "image_base64": image}) + except Exception as exc: # pragma: no cover - dependency guard + self._log.warning("socket emit failed for new_screenshot: %s", exc) + + def on_macro_status(self, data: Optional[Dict[str, Any]]) -> None: + payload = data or {} + agent_id = payload.get("agent_id") + node_id = payload.get("node_id") + success = payload.get("success") + message = payload.get("message") + self._log.info( + "macro-status agent=%s node=%s success=%s message=%s", + agent_id, + node_id, + success, + message, + ) + try: + self._socketio.emit("macro_status", payload) + except Exception as exc: # pragma: no cover - dependency guard + self._log.warning("socket emit failed for macro_status: %s", exc) + + def on_list_agent_windows(self, data: 
Optional[Dict[str, Any]]) -> None: + payload = data or {} + try: + self._socketio.emit("list_agent_windows", payload) + except Exception as exc: # pragma: no cover - dependency guard + self._log.warning("socket emit failed for list_agent_windows: %s", exc) + + def on_agent_window_list(self, data: Optional[Dict[str, Any]]) -> None: + payload = data or {} + try: + self._socketio.emit("agent_window_list", payload) + except Exception as exc: # pragma: no cover - dependency guard + self._log.warning("socket emit failed for agent_window_list: %s", exc) + + def on_ansible_playbook_cancel(self, data: Optional[Dict[str, Any]]) -> None: + try: + self._socketio.emit("ansible_playbook_cancel", data or {}) + except Exception as exc: # pragma: no cover - dependency guard + self._log.warning("socket emit failed for ansible_playbook_cancel: %s", exc) + + def on_ansible_playbook_run(self, data: Optional[Dict[str, Any]]) -> None: + try: + self._socketio.emit("ansible_playbook_run", data or {}) + except Exception as exc: # pragma: no cover - dependency guard + self._log.warning("socket emit failed for ansible_playbook_run: %s", exc) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _render_query(self) -> str: + try: + pairs = [f"{k}={v}" for k, v in request.args.items()] # type: ignore[attr-defined] + except Exception: + return "" + return "&".join(pairs) if pairs else "" + + +def _canonical_scope(raw: Optional[str]) -> Optional[str]: + if not raw: + return None + value = "".join(ch for ch in str(raw) if ch.isalnum() or ch in ("_", "-")) + if not value: + return None + return value.upper() + + +def _mask_value(value: str, *, prefix: int = 4, suffix: int = 4) -> str: + try: + if not value: + return "" + stripped = value.strip() + if len(stripped) <= prefix + suffix: + return "*" * len(stripped) + return f"{stripped[:prefix]}***{stripped[-suffix:]}" + except Exception: + return "***" + + +def _summarize_socket_headers(headers: Any) -> str: + try: + items: Iterable[tuple[str, Any]] + if isinstance(headers, dict): + items = headers.items() + else: + items = getattr(headers, "items", lambda: [])() + except Exception: + items = [] + + rendered = [] + for key, value in items: + lowered = str(key).lower() + display = value + if lowered == "authorization": + token = str(value or "") + if token.lower().startswith("bearer "): + display = f"Bearer {_mask_value(token.split(' ', 1)[1])}" + else: + display = _mask_value(token) + elif lowered == "cookie": + display = "" + rendered.append(f"{key}={display}") + return ", ".join(rendered) if rendered else "" + + +__all__ = ["register"] diff --git a/Data/Engine/interfaces/ws/job_management/__init__.py b/Data/Engine/interfaces/ws/job_management/__init__.py new file mode 100644 index 0000000..225073f --- /dev/null +++ b/Data/Engine/interfaces/ws/job_management/__init__.py @@ -0,0 +1,16 @@ +"""Job management WebSocket namespace wiring for the Engine.""" + +from __future__ import annotations + +from typing import Any + +from . 
import events
+
+
+def register(socketio: Any, services) -> None:
+    """Register job management namespaaces on the given Socket.IO *socketio*."""
+
+    events.register(socketio, services)
+
+
+__all__ = ["register"]
diff --git a/Data/Engine/interfaces/ws/job_management/events.py b/Data/Engine/interfaces/ws/job_management/events.py
new file mode 100644
index 0000000..9c77852
--- /dev/null
+++ b/Data/Engine/interfaces/ws/job_management/events.py
@@ -0,0 +1,38 @@
+"""Job management WebSocket event handlers."""
+
+from __future__ import annotations
+
+import logging
+from typing import Any, Optional
+
+from Data.Engine.services.container import EngineServiceContainer
+
+
+def register(socketio: Any, services: EngineServiceContainer) -> None:
+    if socketio is None:  # pragma: no cover - guard
+        return
+
+    handlers = _JobEventHandlers(socketio, services)
+    socketio.on_event("quick_job_result", handlers.on_quick_job_result)
+    socketio.on_event("job_status_request", handlers.on_job_status_request)
+
+
+class _JobEventHandlers:
+    def __init__(self, socketio: Any, services: EngineServiceContainer) -> None:
+        self._socketio = socketio
+        self._services = services
+        self._log = logging.getLogger("borealis.engine.ws.jobs")
+
+    def on_quick_job_result(self, data: Optional[dict]) -> None:
+        self._log.info("quick-job-result received; scheduler migration pending")
+        # Step 10 will introduce full persistence + broadcast logic.
+
+    def on_job_status_request(self, _: Optional[dict]) -> None:
+        jobs = self._services.scheduler_service.list_jobs()
+        try:
+            self._socketio.emit("job_status", {"jobs": jobs})
+        except Exception:
+            self._log.debug("job-status emit failed")
+
+
+__all__ = ["register"]
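[Review note] A sketch of a client-side consumer for these events, using the python-socketio client as an assumed test dependency; the endpoint and port are illustrative:

    import socketio  # pip install python-socketio (assumption: not pinned by this patch)

    sio = socketio.Client()

    @sio.on("job_status")
    def on_job_status(payload):
        # The server replies to job_status_request with {"jobs": [...]}.
        print("jobs:", payload["jobs"])

    sio.connect("http://localhost:5001")
    sio.emit("job_status_request", {})

diff --git a/Data/Engine/repositories/__init__.py b/Data/Engine/repositories/__init__.py
new file mode 100644
index 0000000..ebc3372
--- /dev/null
+++ b/Data/Engine/repositories/__init__.py
@@ -0,0 +1,7 @@
+"""Persistence adapters for the Borealis Engine."""
+
+from __future__ import annotations
+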
+from . import sqlite
+
+__all__ = ["sqlite"]
diff --git a/Data/Engine/repositories/sqlite/__init__.py b/Data/Engine/repositories/sqlite/__init__.py
new file mode 100644
index 0000000..ceef224
--- /dev/null
+++ b/Data/Engine/repositories/sqlite/__init__.py
@@ -0,0 +1,47 @@
+"""SQLite persistence helpers for the Borealis Engine."""
+
+from __future__ import annotations
+
+from .connection import (
+    SQLiteConnectionFactory,
+    configure_connection,
+    connect,
+    connection_factory,
+    connection_scope,
+)
+from .migrations import apply_all
+
+__all__ = [
+    "SQLiteConnectionFactory",
+    "configure_connection",
+    "connect",
+    "connection_factory",
+    "connection_scope",
+    "apply_all",
+]
+
+try:  # pragma: no cover - optional dependency shim
+    from .device_repository import SQLiteDeviceRepository
+    from .enrollment_repository import SQLiteEnrollmentRepository
+    from .github_repository import SQLiteGitHubRepository
+    from .job_repository import SQLiteJobRepository
+    from .token_repository import SQLiteRefreshTokenRepository
+except ModuleNotFoundError as exc:  # pragma: no cover - triggered when auth deps missing
+    def _missing_repo(*_args: object, **_kwargs: object) -> None:
+        raise ModuleNotFoundError(
+            "Engine SQLite repositories require optional authentication dependencies"
+        ) from exc
+
+    SQLiteDeviceRepository = _missing_repo  # type: ignore[assignment]
+    SQLiteEnrollmentRepository = _missing_repo  # type: ignore[assignment]
+    SQLiteGitHubRepository = _missing_repo  # type: ignore[assignment]
+    SQLiteJobRepository = _missing_repo  # type: ignore[assignment]
+    SQLiteRefreshTokenRepository = _missing_repo  # type: ignore[assignment]
+else:
+    __all__ += [
+        "SQLiteDeviceRepository",
+        "SQLiteRefreshTokenRepository",
+        "SQLiteJobRepository",
+        "SQLiteEnrollmentRepository",
+        "SQLiteGitHubRepository",
+    ]
diff --git a/Data/Engine/repositories/sqlite/connection.py b/Data/Engine/repositories/sqlite/connection.py
new file mode 100644
index 0000000..0e59302
--- /dev/null
+++ b/Data/Engine/repositories/sqlite/connection.py
@@ -0,0 +1,67 @@
+"""SQLite connection utilities for the Borealis Engine."""
+
+from __future__ import annotations
+
+import sqlite3
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Iterator, Protocol
+
+__all__ = [
+    "SQLiteConnectionFactory",
+    "configure_connection",
+    "connect",
+    "connection_factory",
+    "connection_scope",
+]
+
+
+class SQLiteConnectionFactory(Protocol):
+    """Callable protocol for obtaining configured SQLite connections."""
+
+    def __call__(self) -> sqlite3.Connection:
+        """Return a new :class:`sqlite3.Connection`."""
+
+
+def configure_connection(conn: sqlite3.Connection) -> None:
+    """Apply the Borealis-standard pragmas to *conn*."""
+
+    cur = conn.cursor()
+    try:
+        cur.execute("PRAGMA journal_mode=WAL")
+        cur.execute("PRAGMA busy_timeout=5000")
+        cur.execute("PRAGMA synchronous=NORMAL")
+        conn.commit()
+    except Exception:
+        # Pragmas are best-effort; failing to apply them should not block startup.
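+        # Reset any implicit transaction state; the connection stays usable
+        # without the pragmas applied.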
+ conn.rollback() + finally: + cur.close() + + +def connect(path: Path, *, timeout: float = 15.0) -> sqlite3.Connection: + """Create a new SQLite connection to *path* with Engine pragmas applied.""" + + conn = sqlite3.connect(str(path), timeout=timeout) + configure_connection(conn) + return conn + + +def connection_factory(path: Path, *, timeout: float = 15.0) -> SQLiteConnectionFactory: + """Return a factory that opens connections to *path* when invoked.""" + + def factory() -> sqlite3.Connection: + return connect(path, timeout=timeout) + + return factory + + +@contextmanager +def connection_scope(path: Path, *, timeout: float = 15.0) -> Iterator[sqlite3.Connection]: + """Context manager yielding a configured connection to *path*.""" + + conn = connect(path, timeout=timeout) + try: + yield conn + finally: + conn.close() diff --git a/Data/Engine/repositories/sqlite/device_repository.py b/Data/Engine/repositories/sqlite/device_repository.py new file mode 100644 index 0000000..b88bb6e --- /dev/null +++ b/Data/Engine/repositories/sqlite/device_repository.py @@ -0,0 +1,410 @@ +"""SQLite-backed device repository for the Engine authentication services.""" + +from __future__ import annotations + +import logging +import sqlite3 +import time +import uuid +from contextlib import closing +from datetime import datetime, timezone +from typing import Optional + +from Data.Engine.domain.device_auth import ( + DeviceFingerprint, + DeviceGuid, + DeviceIdentity, + DeviceStatus, +) +from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory +from Data.Engine.services.auth.device_auth_service import DeviceRecord + +__all__ = ["SQLiteDeviceRepository"] + + +class SQLiteDeviceRepository: + """Persistence adapter that reads and recovers device rows.""" + + def __init__( + self, + connection_factory: SQLiteConnectionFactory, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._connections = connection_factory + self._log = logger or logging.getLogger("borealis.engine.repositories.devices") + + def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: + """Fetch a device row by GUID, normalizing legacy case variance.""" + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT guid, ssl_key_fingerprint, token_version, status + FROM devices + WHERE UPPER(guid) = ? + """, + (guid.value.upper(),), + ) + rows = cur.fetchall() + + if not rows: + return None + + for row in rows: + record = self._row_to_record(row) + if record and record.identity.guid.value == guid.value: + return record + + # Fall back to the first row if normalization failed to match exactly. + return self._row_to_record(rows[0]) + + def recover_missing( + self, + guid: DeviceGuid, + fingerprint: DeviceFingerprint, + token_version: int, + service_context: Optional[str], + ) -> Optional[DeviceRecord]: + """Attempt to recreate a missing device row for a valid token.""" + + now_ts = int(time.time()) + now_iso = datetime.now(tz=timezone.utc).isoformat() + base_hostname = f"RECOVERED-{guid.value[:12]}" + + with closing(self._connections()) as conn: + cur = conn.cursor() + for attempt in range(6): + hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}" + try: + cur.execute( + """ + INSERT INTO devices ( + guid, + hostname, + created_at, + last_seen, + ssl_key_fingerprint, + token_version, + status, + key_added_at + ) + VALUES (?, ?, ?, ?, ?, ?, 'active', ?) 
+ """, + ( + guid.value, + hostname, + now_ts, + now_ts, + fingerprint.value, + max(token_version or 1, 1), + now_iso, + ), + ) + except sqlite3.IntegrityError as exc: + message = str(exc).lower() + if "hostname" in message and "unique" in message: + continue + self._log.warning( + "device auth failed to recover guid=%s (context=%s): %s", + guid.value, + service_context or "none", + exc, + ) + conn.rollback() + return None + except Exception as exc: # pragma: no cover - defensive logging + self._log.exception( + "device auth unexpected error recovering guid=%s (context=%s)", + guid.value, + service_context or "none", + exc_info=exc, + ) + conn.rollback() + return None + else: + conn.commit() + break + else: + self._log.warning( + "device auth could not recover guid=%s; hostname collisions persisted", + guid.value, + ) + conn.rollback() + return None + + cur.execute( + """ + SELECT guid, ssl_key_fingerprint, token_version, status + FROM devices + WHERE guid = ? + """, + (guid.value,), + ) + row = cur.fetchone() + + if not row: + self._log.warning( + "device auth recovery committed but row missing for guid=%s", + guid.value, + ) + return None + + return self._row_to_record(row) + + def ensure_device_record( + self, + *, + guid: DeviceGuid, + hostname: str, + fingerprint: DeviceFingerprint, + ) -> DeviceRecord: + now_iso = datetime.now(tz=timezone.utc).isoformat() + now_ts = int(time.time()) + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at + FROM devices + WHERE UPPER(guid) = ? + """, + (guid.value.upper(),), + ) + row = cur.fetchone() + + if row: + stored_fp = (row[4] or "").strip().lower() + new_fp = fingerprint.value + if not stored_fp: + cur.execute( + "UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?", + (new_fp, now_iso, row[0]), + ) + elif stored_fp != new_fp: + token_version = self._coerce_int(row[2], default=1) + 1 + cur.execute( + """ + UPDATE devices + SET ssl_key_fingerprint = ?, + key_added_at = ?, + token_version = ?, + status = 'active' + WHERE guid = ? + """, + (new_fp, now_iso, token_version, row[0]), + ) + cur.execute( + """ + UPDATE refresh_tokens + SET revoked_at = ? + WHERE guid = ? + AND revoked_at IS NULL + """, + (now_iso, row[0]), + ) + conn.commit() + else: + resolved_hostname = self._resolve_hostname(cur, hostname, guid) + cur.execute( + """ + INSERT INTO devices ( + guid, + hostname, + created_at, + last_seen, + ssl_key_fingerprint, + token_version, + status, + key_added_at + ) + VALUES (?, ?, ?, ?, ?, 1, 'active', ?) + """, + ( + guid.value, + resolved_hostname, + now_ts, + now_ts, + fingerprint.value, + now_iso, + ), + ) + conn.commit() + + cur.execute( + """ + SELECT guid, ssl_key_fingerprint, token_version, status + FROM devices + WHERE UPPER(guid) = ? + """, + (guid.value.upper(),), + ) + latest = cur.fetchone() + + if not latest: + raise RuntimeError("device record could not be ensured") + + record = self._row_to_record(latest) + if record is None: + raise RuntimeError("device record invalid after ensure") + return record + + def record_device_key( + self, + *, + guid: DeviceGuid, + fingerprint: DeviceFingerprint, + added_at: datetime, + ) -> None: + added_iso = added_at.astimezone(timezone.utc).isoformat() + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at) + VALUES (?, ?, ?, ?) 
+ """, + (str(uuid.uuid4()), guid.value, fingerprint.value, added_iso), + ) + cur.execute( + """ + UPDATE device_keys + SET retired_at = ? + WHERE guid = ? + AND ssl_key_fingerprint != ? + AND retired_at IS NULL + """, + (added_iso, guid.value, fingerprint.value), + ) + conn.commit() + + def update_device_summary( + self, + *, + hostname: Optional[str], + last_seen: Optional[int] = None, + agent_id: Optional[str] = None, + operating_system: Optional[str] = None, + last_user: Optional[str] = None, + ) -> None: + if not hostname: + return + + normalized_hostname = (hostname or "").strip() + if not normalized_hostname: + return + + fields = [] + params = [] + + if last_seen is not None: + try: + fields.append("last_seen = ?") + params.append(int(last_seen)) + except Exception: + pass + + if agent_id: + try: + candidate = agent_id.strip() + except Exception: + candidate = agent_id + if candidate: + fields.append("agent_id = ?") + params.append(candidate) + + if operating_system: + try: + os_value = operating_system.strip() + except Exception: + os_value = operating_system + if os_value: + fields.append("operating_system = ?") + params.append(os_value) + + if last_user: + try: + user_value = last_user.strip() + except Exception: + user_value = last_user + if user_value: + fields.append("last_user = ?") + params.append(user_value) + + if not fields: + return + + params.append(normalized_hostname) + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + f"UPDATE devices SET {', '.join(fields)} WHERE LOWER(hostname) = LOWER(?)", + params, + ) + if cur.rowcount == 0 and agent_id: + cur.execute( + f"UPDATE devices SET {', '.join(fields)} WHERE agent_id = ?", + params[:-1] + [agent_id], + ) + conn.commit() + + def _row_to_record(self, row: tuple) -> Optional[DeviceRecord]: + try: + guid = DeviceGuid(row[0]) + fingerprint_value = (row[1] or "").strip() + if not fingerprint_value: + self._log.warning( + "device row %s missing TLS fingerprint; skipping", + row[0], + ) + return None + fingerprint = DeviceFingerprint(fingerprint_value) + except Exception as exc: + self._log.warning("invalid device row for guid=%s: %s", row[0], exc) + return None + + token_version_raw = row[2] + try: + token_version = int(token_version_raw or 0) + except Exception: + token_version = 0 + + status = DeviceStatus.from_string(row[3]) + identity = DeviceIdentity(guid=guid, fingerprint=fingerprint) + + return DeviceRecord( + identity=identity, + token_version=max(token_version, 1), + status=status, + ) + + @staticmethod + def _coerce_int(value: object, *, default: int = 0) -> int: + try: + return int(value) + except Exception: + return default + + def _resolve_hostname(self, cur: sqlite3.Cursor, hostname: str, guid: DeviceGuid) -> str: + base = (hostname or "").strip() or guid.value + base = base[:253] + candidate = base + suffix = 1 + while True: + cur.execute( + "SELECT guid FROM devices WHERE hostname = ?", + (candidate,), + ) + row = cur.fetchone() + if not row: + return candidate + existing = (row[0] or "").strip().upper() + if existing == guid.value: + return candidate + candidate = f"{base}-{suffix}" + suffix += 1 + if suffix > 50: + return guid.value diff --git a/Data/Engine/repositories/sqlite/enrollment_repository.py b/Data/Engine/repositories/sqlite/enrollment_repository.py new file mode 100644 index 0000000..a6549ec --- /dev/null +++ b/Data/Engine/repositories/sqlite/enrollment_repository.py @@ -0,0 +1,383 @@ +"""SQLite-backed enrollment repository for Engine services.""" + +from 
__future__ import annotations + +import logging +from contextlib import closing +from datetime import datetime, timezone +from typing import Optional + +from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid +from Data.Engine.domain.device_enrollment import ( + EnrollmentApproval, + EnrollmentApprovalStatus, + EnrollmentCode, +) +from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory + +__all__ = ["SQLiteEnrollmentRepository"] + + +class SQLiteEnrollmentRepository: + """Persistence adapter that manages enrollment codes and approvals.""" + + def __init__( + self, + connection_factory: SQLiteConnectionFactory, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._connections = connection_factory + self._log = logger or logging.getLogger("borealis.engine.repositories.enrollment") + + # ------------------------------------------------------------------ + # Enrollment install codes + # ------------------------------------------------------------------ + def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]: + """Load an enrollment install code by its public value.""" + + code_value = (code or "").strip() + if not code_value: + return None + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT id, + code, + expires_at, + used_at, + used_by_guid, + max_uses, + use_count, + last_used_at + FROM enrollment_install_codes + WHERE code = ? + """, + (code_value,), + ) + row = cur.fetchone() + + if not row: + return None + + record = { + "id": row[0], + "code": row[1], + "expires_at": row[2], + "used_at": row[3], + "used_by_guid": row[4], + "max_uses": row[5], + "use_count": row[6], + "last_used_at": row[7], + } + try: + return EnrollmentCode.from_mapping(record) + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning("invalid enrollment code record for code=%s: %s", code_value, exc) + return None + + def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]: + record_value = (record_id or "").strip() + if not record_value: + return None + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT id, + code, + expires_at, + used_at, + used_by_guid, + max_uses, + use_count, + last_used_at + FROM enrollment_install_codes + WHERE id = ? 
+ """, + (record_value,), + ) + row = cur.fetchone() + + if not row: + return None + + record = { + "id": row[0], + "code": row[1], + "expires_at": row[2], + "used_at": row[3], + "used_by_guid": row[4], + "max_uses": row[5], + "use_count": row[6], + "last_used_at": row[7], + } + + try: + return EnrollmentCode.from_mapping(record) + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc) + return None + + def update_install_code_usage( + self, + record_id: str, + *, + use_count_increment: int, + last_used_at: datetime, + used_by_guid: Optional[DeviceGuid] = None, + mark_first_use: bool = False, + ) -> None: + """Increment usage counters and usage metadata for an install code.""" + + if use_count_increment <= 0: + raise ValueError("use_count_increment must be positive") + + last_used_iso = self._isoformat(last_used_at) + guid_value = used_by_guid.value if used_by_guid else "" + mark_flag = 1 if mark_first_use else 0 + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + UPDATE enrollment_install_codes + SET use_count = use_count + ?, + last_used_at = ?, + used_by_guid = COALESCE(NULLIF(?, ''), used_by_guid), + used_at = CASE WHEN ? = 1 AND used_at IS NULL THEN ? ELSE used_at END + WHERE id = ? + """, + ( + use_count_increment, + last_used_iso, + guid_value, + mark_flag, + last_used_iso, + record_id, + ), + ) + conn.commit() + + # ------------------------------------------------------------------ + # Device approvals + # ------------------------------------------------------------------ + def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: + """Load a device approval using its operator-visible reference.""" + + ref_value = (reference or "").strip() + if not ref_value: + return None + return self._fetch_device_approval("approval_reference = ?", (ref_value,)) + + def fetch_device_approval(self, record_id: str) -> Optional[EnrollmentApproval]: + record_value = (record_id or "").strip() + if not record_value: + return None + return self._fetch_device_approval("id = ?", (record_value,)) + + def fetch_pending_approval_by_fingerprint( + self, fingerprint: DeviceFingerprint + ) -> Optional[EnrollmentApproval]: + return self._fetch_device_approval( + "ssl_key_fingerprint_claimed = ? AND status = 'pending'", + (fingerprint.value,), + ) + + def update_pending_approval( + self, + record_id: str, + *, + hostname: str, + guid: Optional[DeviceGuid], + enrollment_code_id: Optional[str], + client_nonce_b64: str, + server_nonce_b64: str, + agent_pubkey_der: bytes, + updated_at: datetime, + ) -> None: + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + UPDATE device_approvals + SET hostname_claimed = ?, + guid = ?, + enrollment_code_id = ?, + client_nonce = ?, + server_nonce = ?, + agent_pubkey_der = ?, + updated_at = ? + WHERE id = ? 
+ """, + ( + hostname, + guid.value if guid else None, + enrollment_code_id, + client_nonce_b64, + server_nonce_b64, + agent_pubkey_der, + self._isoformat(updated_at), + record_id, + ), + ) + conn.commit() + + def create_device_approval( + self, + *, + record_id: str, + reference: str, + claimed_hostname: str, + claimed_fingerprint: DeviceFingerprint, + enrollment_code_id: Optional[str], + client_nonce_b64: str, + server_nonce_b64: str, + agent_pubkey_der: bytes, + created_at: datetime, + status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING, + guid: Optional[DeviceGuid] = None, + ) -> EnrollmentApproval: + created_iso = self._isoformat(created_at) + guid_value = guid.value if guid else None + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT INTO device_approvals ( + id, + approval_reference, + guid, + hostname_claimed, + ssl_key_fingerprint_claimed, + enrollment_code_id, + status, + created_at, + updated_at, + client_nonce, + server_nonce, + agent_pubkey_der + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + record_id, + reference, + guid_value, + claimed_hostname, + claimed_fingerprint.value, + enrollment_code_id, + status.value, + created_iso, + created_iso, + client_nonce_b64, + server_nonce_b64, + agent_pubkey_der, + ), + ) + conn.commit() + + approval = self.fetch_device_approval(record_id) + if approval is None: + raise RuntimeError("failed to load device approval after insert") + return approval + + def update_device_approval_status( + self, + record_id: str, + *, + status: EnrollmentApprovalStatus, + updated_at: datetime, + approved_by: Optional[str] = None, + guid: Optional[DeviceGuid] = None, + ) -> None: + """Transition an approval to a new status.""" + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + UPDATE device_approvals + SET status = ?, + updated_at = ?, + guid = COALESCE(?, guid), + approved_by_user_id = COALESCE(?, approved_by_user_id) + WHERE id = ? 
+ """, + ( + status.value, + self._isoformat(updated_at), + guid.value if guid else None, + approved_by, + record_id, + ), + ) + conn.commit() + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _fetch_device_approval(self, where: str, params: tuple) -> Optional[EnrollmentApproval]: + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + f""" + SELECT id, + approval_reference, + guid, + hostname_claimed, + ssl_key_fingerprint_claimed, + enrollment_code_id, + created_at, + updated_at, + status, + approved_by_user_id, + client_nonce, + server_nonce, + agent_pubkey_der + FROM device_approvals + WHERE {where} + """, + params, + ) + row = cur.fetchone() + + if not row: + return None + + record = { + "id": row[0], + "approval_reference": row[1], + "guid": row[2], + "hostname_claimed": row[3], + "ssl_key_fingerprint_claimed": row[4], + "enrollment_code_id": row[5], + "created_at": row[6], + "updated_at": row[7], + "status": row[8], + "approved_by_user_id": row[9], + "client_nonce": row[10], + "server_nonce": row[11], + "agent_pubkey_der": row[12], + } + + try: + return EnrollmentApproval.from_mapping(record) + except Exception as exc: # pragma: no cover - defensive logging + self._log.warning( + "invalid device approval record id=%s reference=%s: %s", + row[0], + row[1], + exc, + ) + return None + + @staticmethod + def _isoformat(value: datetime) -> str: + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc).isoformat() diff --git a/Data/Engine/repositories/sqlite/github_repository.py b/Data/Engine/repositories/sqlite/github_repository.py new file mode 100644 index 0000000..f8e6912 --- /dev/null +++ b/Data/Engine/repositories/sqlite/github_repository.py @@ -0,0 +1,53 @@ +"""SQLite-backed GitHub token persistence.""" + +from __future__ import annotations + +import logging +from contextlib import closing +from typing import Optional + +from .connection import SQLiteConnectionFactory + +__all__ = ["SQLiteGitHubRepository"] + + +class SQLiteGitHubRepository: + """Store and retrieve GitHub API tokens for the Engine.""" + + def __init__( + self, + connection_factory: SQLiteConnectionFactory, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._connections = connection_factory + self._log = logger or logging.getLogger("borealis.engine.repositories.github") + + def load_token(self) -> Optional[str]: + """Return the stored GitHub token if one exists.""" + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute("SELECT token FROM github_token LIMIT 1") + row = cur.fetchone() + + if not row: + return None + + token = (row[0] or "").strip() + return token or None + + def store_token(self, token: Optional[str]) -> None: + """Persist *token*, replacing any prior value.""" + + normalized = (token or "").strip() + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute("DELETE FROM github_token") + if normalized: + cur.execute("INSERT INTO github_token (token) VALUES (?)", (normalized,)) + conn.commit() + + self._log.info("stored-token has_token=%s", bool(normalized)) + diff --git a/Data/Engine/repositories/sqlite/job_repository.py b/Data/Engine/repositories/sqlite/job_repository.py new file mode 100644 index 0000000..c8c3913 --- /dev/null +++ b/Data/Engine/repositories/sqlite/job_repository.py @@ -0,0 +1,355 @@ +"""SQLite-backed persistence for Engine job 
scheduling.""" + +from __future__ import annotations + +import json +import logging +import time +from dataclasses import dataclass +from typing import Any, Iterable, Optional, Sequence + +import sqlite3 + +from .connection import SQLiteConnectionFactory + +__all__ = [ + "ScheduledJobRecord", + "ScheduledJobRunRecord", + "SQLiteJobRepository", +] + + +def _now_ts() -> int: + return int(time.time()) + + +def _json_dumps(value: Any) -> str: + try: + return json.dumps(value or []) + except Exception: + return "[]" + + +def _json_loads(value: Optional[str]) -> list[Any]: + if not value: + return [] + try: + data = json.loads(value) + if isinstance(data, list): + return data + return [] + except Exception: + return [] + + +@dataclass(frozen=True, slots=True) +class ScheduledJobRecord: + id: int + name: str + components: list[dict[str, Any]] + targets: list[str] + schedule_type: str + start_ts: Optional[int] + duration_stop_enabled: bool + expiration: Optional[str] + execution_context: str + credential_id: Optional[int] + use_service_account: bool + enabled: bool + created_at: Optional[int] + updated_at: Optional[int] + + +@dataclass(frozen=True, slots=True) +class ScheduledJobRunRecord: + id: int + job_id: int + scheduled_ts: Optional[int] + started_ts: Optional[int] + finished_ts: Optional[int] + status: Optional[str] + error: Optional[str] + target_hostname: Optional[str] + created_at: Optional[int] + updated_at: Optional[int] + + +class SQLiteJobRepository: + """Persistence adapter for Engine job scheduling.""" + + def __init__( + self, + factory: SQLiteConnectionFactory, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._factory = factory + self._log = logger or logging.getLogger("borealis.engine.repositories.jobs") + + # ------------------------------------------------------------------ + # Job CRUD + # ------------------------------------------------------------------ + def list_jobs(self) -> list[ScheduledJobRecord]: + query = ( + "SELECT id, name, components_json, targets_json, schedule_type, start_ts, " + "duration_stop_enabled, expiration, execution_context, credential_id, " + "use_service_account, enabled, created_at, updated_at FROM scheduled_jobs " + "ORDER BY id ASC" + ) + return [self._row_to_job(row) for row in self._fetchall(query)] + + def list_enabled_jobs(self) -> list[ScheduledJobRecord]: + query = ( + "SELECT id, name, components_json, targets_json, schedule_type, start_ts, " + "duration_stop_enabled, expiration, execution_context, credential_id, " + "use_service_account, enabled, created_at, updated_at FROM scheduled_jobs " + "WHERE enabled=1 ORDER BY id ASC" + ) + return [self._row_to_job(row) for row in self._fetchall(query)] + + def fetch_job(self, job_id: int) -> Optional[ScheduledJobRecord]: + query = ( + "SELECT id, name, components_json, targets_json, schedule_type, start_ts, " + "duration_stop_enabled, expiration, execution_context, credential_id, " + "use_service_account, enabled, created_at, updated_at FROM scheduled_jobs " + "WHERE id=?" 
+ ) + rows = self._fetchall(query, (job_id,)) + return self._row_to_job(rows[0]) if rows else None + + def create_job( + self, + *, + name: str, + components: Sequence[dict[str, Any]], + targets: Sequence[Any], + schedule_type: str, + start_ts: Optional[int], + duration_stop_enabled: bool, + expiration: Optional[str], + execution_context: str, + credential_id: Optional[int], + use_service_account: bool, + enabled: bool = True, + ) -> ScheduledJobRecord: + now = _now_ts() + payload = ( + name, + _json_dumps(list(components)), + _json_dumps(list(targets)), + schedule_type, + start_ts, + 1 if duration_stop_enabled else 0, + expiration, + execution_context, + credential_id, + 1 if use_service_account else 0, + 1 if enabled else 0, + now, + now, + ) + with self._connect() as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT INTO scheduled_jobs + (name, components_json, targets_json, schedule_type, start_ts, + duration_stop_enabled, expiration, execution_context, credential_id, + use_service_account, enabled, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + payload, + ) + job_id = cur.lastrowid + conn.commit() + record = self.fetch_job(int(job_id)) + if record is None: + raise RuntimeError("failed to create scheduled job") + return record + + def update_job( + self, + job_id: int, + *, + name: str, + components: Sequence[dict[str, Any]], + targets: Sequence[Any], + schedule_type: str, + start_ts: Optional[int], + duration_stop_enabled: bool, + expiration: Optional[str], + execution_context: str, + credential_id: Optional[int], + use_service_account: bool, + ) -> Optional[ScheduledJobRecord]: + now = _now_ts() + payload = ( + name, + _json_dumps(list(components)), + _json_dumps(list(targets)), + schedule_type, + start_ts, + 1 if duration_stop_enabled else 0, + expiration, + execution_context, + credential_id, + 1 if use_service_account else 0, + now, + job_id, + ) + with self._connect() as conn: + cur = conn.cursor() + cur.execute( + """ + UPDATE scheduled_jobs + SET name=?, components_json=?, targets_json=?, schedule_type=?, + start_ts=?, duration_stop_enabled=?, expiration=?, execution_context=?, + credential_id=?, use_service_account=?, updated_at=? + WHERE id=? + """, + payload, + ) + conn.commit() + return self.fetch_job(job_id) + + def set_enabled(self, job_id: int, enabled: bool) -> None: + with self._connect() as conn: + cur = conn.cursor() + cur.execute( + "UPDATE scheduled_jobs SET enabled=?, updated_at=? WHERE id=?", + (1 if enabled else 0, _now_ts(), job_id), + ) + conn.commit() + + def delete_job(self, job_id: int) -> None: + with self._connect() as conn: + cur = conn.cursor() + cur.execute("DELETE FROM scheduled_jobs WHERE id=?", (job_id,)) + conn.commit() + + # ------------------------------------------------------------------ + # Run history + # ------------------------------------------------------------------ + def list_runs(self, job_id: int, *, days: Optional[int] = None) -> list[ScheduledJobRunRecord]: + params: list[Any] = [job_id] + where = "WHERE job_id=?" + if days is not None and days > 0: + cutoff = _now_ts() - (days * 86400) + where += " AND COALESCE(finished_ts, scheduled_ts, started_ts, 0) >= ?" 
+ params.append(cutoff) + + query = ( + "SELECT id, job_id, scheduled_ts, started_ts, finished_ts, status, error, " + "target_hostname, created_at, updated_at FROM scheduled_job_runs " + f"{where} ORDER BY COALESCE(scheduled_ts, created_at, id) DESC" + ) + return [self._row_to_run(row) for row in self._fetchall(query, tuple(params))] + + def fetch_last_run(self, job_id: int) -> Optional[ScheduledJobRunRecord]: + query = ( + "SELECT id, job_id, scheduled_ts, started_ts, finished_ts, status, error, " + "target_hostname, created_at, updated_at FROM scheduled_job_runs " + "WHERE job_id=? ORDER BY COALESCE(started_ts, scheduled_ts, created_at, id) DESC LIMIT 1" + ) + rows = self._fetchall(query, (job_id,)) + return self._row_to_run(rows[0]) if rows else None + + def purge_runs(self, job_id: int) -> None: + with self._connect() as conn: + cur = conn.cursor() + cur.execute("DELETE FROM scheduled_job_run_activity WHERE run_id IN (SELECT id FROM scheduled_job_runs WHERE job_id=?)", (job_id,)) + cur.execute("DELETE FROM scheduled_job_runs WHERE job_id=?", (job_id,)) + conn.commit() + + def create_run(self, job_id: int, scheduled_ts: int, *, target_hostname: Optional[str] = None) -> int: + now = _now_ts() + with self._connect() as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT INTO scheduled_job_runs + (job_id, scheduled_ts, created_at, updated_at, target_hostname, status) + VALUES (?, ?, ?, ?, ?, 'Pending') + """, + (job_id, scheduled_ts, now, now, target_hostname), + ) + run_id = int(cur.lastrowid) + conn.commit() + return run_id + + def mark_run_started(self, run_id: int, *, started_ts: Optional[int] = None) -> None: + started = started_ts or _now_ts() + with self._connect() as conn: + cur = conn.cursor() + cur.execute( + "UPDATE scheduled_job_runs SET started_ts=?, status='Running', updated_at=? WHERE id=?", + (started, _now_ts(), run_id), + ) + conn.commit() + + def mark_run_finished( + self, + run_id: int, + *, + status: str, + error: Optional[str] = None, + finished_ts: Optional[int] = None, + ) -> None: + finished = finished_ts or _now_ts() + with self._connect() as conn: + cur = conn.cursor() + cur.execute( + "UPDATE scheduled_job_runs SET finished_ts=?, status=?, error=?, updated_at=? 
WHERE id=?", + (finished, status, error, _now_ts(), run_id), + ) + conn.commit() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _connect(self) -> sqlite3.Connection: + return self._factory() + + def _fetchall(self, query: str, params: Optional[Iterable[Any]] = None) -> list[sqlite3.Row]: + with self._connect() as conn: + conn.row_factory = sqlite3.Row + cur = conn.cursor() + cur.execute(query, tuple(params or ())) + rows = cur.fetchall() + return rows + + def _row_to_job(self, row: sqlite3.Row) -> ScheduledJobRecord: + components = _json_loads(row["components_json"]) + targets_raw = _json_loads(row["targets_json"]) + targets = [str(t) for t in targets_raw if isinstance(t, (str, int))] + credential_id = row["credential_id"] + return ScheduledJobRecord( + id=int(row["id"]), + name=str(row["name"] or ""), + components=[c for c in components if isinstance(c, dict)], + targets=targets, + schedule_type=str(row["schedule_type"] or "immediately"), + start_ts=int(row["start_ts"]) if row["start_ts"] is not None else None, + duration_stop_enabled=bool(row["duration_stop_enabled"]), + expiration=str(row["expiration"]) if row["expiration"] else None, + execution_context=str(row["execution_context"] or "system"), + credential_id=int(credential_id) if credential_id is not None else None, + use_service_account=bool(row["use_service_account"]), + enabled=bool(row["enabled"]), + created_at=int(row["created_at"]) if row["created_at"] is not None else None, + updated_at=int(row["updated_at"]) if row["updated_at"] is not None else None, + ) + + def _row_to_run(self, row: sqlite3.Row) -> ScheduledJobRunRecord: + return ScheduledJobRunRecord( + id=int(row["id"]), + job_id=int(row["job_id"]), + scheduled_ts=int(row["scheduled_ts"]) if row["scheduled_ts"] is not None else None, + started_ts=int(row["started_ts"]) if row["started_ts"] is not None else None, + finished_ts=int(row["finished_ts"]) if row["finished_ts"] is not None else None, + status=str(row["status"]) if row["status"] else None, + error=str(row["error"]) if row["error"] else None, + target_hostname=str(row["target_hostname"]) if row["target_hostname"] else None, + created_at=int(row["created_at"]) if row["created_at"] is not None else None, + updated_at=int(row["updated_at"]) if row["updated_at"] is not None else None, + ) diff --git a/Data/Engine/repositories/sqlite/migrations.py b/Data/Engine/repositories/sqlite/migrations.py new file mode 100644 index 0000000..4dddca0 --- /dev/null +++ b/Data/Engine/repositories/sqlite/migrations.py @@ -0,0 +1,507 @@ +"""SQLite schema migrations for the Borealis Engine. + +This module centralises schema evolution so the Engine and its interfaces can stay +focused on request handling. The migration functions are intentionally +idempotent — they can run repeatedly without changing state once the schema +matches the desired shape. +""" + +from __future__ import annotations + +import sqlite3 +import uuid +from datetime import datetime, timezone +from typing import List, Optional, Sequence, Tuple + + +DEVICE_TABLE = "devices" + + +def apply_all(conn: sqlite3.Connection) -> None: + """ + Run all known schema migrations against the provided sqlite3 connection. 
+ """ + + _ensure_devices_table(conn) + _ensure_device_aux_tables(conn) + _ensure_refresh_token_table(conn) + _ensure_install_code_table(conn) + _ensure_device_approval_table(conn) + _ensure_github_token_table(conn) + _ensure_scheduled_jobs_table(conn) + _ensure_scheduled_job_run_tables(conn) + + conn.commit() + + +def _ensure_devices_table(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + if not _table_exists(cur, DEVICE_TABLE): + _create_devices_table(cur) + return + + column_info = _table_info(cur, DEVICE_TABLE) + col_names = [c[1] for c in column_info] + pk_cols = [c[1] for c in column_info if c[5]] + + needs_rebuild = pk_cols != ["guid"] + required_columns = { + "guid": "TEXT", + "hostname": "TEXT", + "description": "TEXT", + "created_at": "INTEGER", + "agent_hash": "TEXT", + "memory": "TEXT", + "network": "TEXT", + "software": "TEXT", + "storage": "TEXT", + "cpu": "TEXT", + "device_type": "TEXT", + "domain": "TEXT", + "external_ip": "TEXT", + "internal_ip": "TEXT", + "last_reboot": "TEXT", + "last_seen": "INTEGER", + "last_user": "TEXT", + "operating_system": "TEXT", + "uptime": "INTEGER", + "agent_id": "TEXT", + "ansible_ee_ver": "TEXT", + "connection_type": "TEXT", + "connection_endpoint": "TEXT", + "ssl_key_fingerprint": "TEXT", + "token_version": "INTEGER", + "status": "TEXT", + "key_added_at": "TEXT", + } + + missing_columns = [col for col in required_columns if col not in col_names] + if missing_columns: + needs_rebuild = True + + if needs_rebuild: + _rebuild_devices_table(conn, column_info) + else: + _ensure_column_defaults(cur) + + _ensure_device_indexes(cur) + + +def _ensure_device_aux_tables(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS device_keys ( + id TEXT PRIMARY KEY, + guid TEXT NOT NULL, + ssl_key_fingerprint TEXT NOT NULL, + added_at TEXT NOT NULL, + retired_at TEXT + ) + """ + ) + cur.execute( + """ + CREATE UNIQUE INDEX IF NOT EXISTS uq_device_keys_guid_fingerprint + ON device_keys(guid, ssl_key_fingerprint) + """ + ) + cur.execute( + """ + CREATE INDEX IF NOT EXISTS idx_device_keys_guid + ON device_keys(guid) + """ + ) + + +def _ensure_refresh_token_table(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS refresh_tokens ( + id TEXT PRIMARY KEY, + guid TEXT NOT NULL, + token_hash TEXT NOT NULL, + dpop_jkt TEXT, + created_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + revoked_at TEXT, + last_used_at TEXT + ) + """ + ) + cur.execute( + """ + CREATE INDEX IF NOT EXISTS idx_refresh_tokens_guid + ON refresh_tokens(guid) + """ + ) + cur.execute( + """ + CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at + ON refresh_tokens(expires_at) + """ + ) + + +def _ensure_install_code_table(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS enrollment_install_codes ( + id TEXT PRIMARY KEY, + code TEXT NOT NULL UNIQUE, + expires_at TEXT NOT NULL, + created_by_user_id TEXT, + used_at TEXT, + used_by_guid TEXT, + max_uses INTEGER NOT NULL DEFAULT 1, + use_count INTEGER NOT NULL DEFAULT 0, + last_used_at TEXT + ) + """ + ) + cur.execute( + """ + CREATE INDEX IF NOT EXISTS idx_eic_expires_at + ON enrollment_install_codes(expires_at) + """ + ) + + columns = {row[1] for row in _table_info(cur, "enrollment_install_codes")} + if "max_uses" not in columns: + cur.execute( + """ + ALTER TABLE enrollment_install_codes + ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 1 + """ + ) + if "use_count" not in columns: 
+ cur.execute( + """ + ALTER TABLE enrollment_install_codes + ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0 + """ + ) + if "last_used_at" not in columns: + cur.execute( + """ + ALTER TABLE enrollment_install_codes + ADD COLUMN last_used_at TEXT + """ + ) + + +def _ensure_device_approval_table(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS device_approvals ( + id TEXT PRIMARY KEY, + approval_reference TEXT NOT NULL UNIQUE, + guid TEXT, + hostname_claimed TEXT NOT NULL, + ssl_key_fingerprint_claimed TEXT NOT NULL, + enrollment_code_id TEXT NOT NULL, + status TEXT NOT NULL, + client_nonce TEXT NOT NULL, + server_nonce TEXT NOT NULL, + agent_pubkey_der BLOB NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + approved_by_user_id TEXT + ) + """ + ) + cur.execute( + """ + CREATE INDEX IF NOT EXISTS idx_da_status + ON device_approvals(status) + """ + ) + cur.execute( + """ + CREATE INDEX IF NOT EXISTS idx_da_fp_status + ON device_approvals(ssl_key_fingerprint_claimed, status) + """ + ) + + +def _ensure_github_token_table(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS github_token ( + token TEXT + ) + """ + ) + + +def _ensure_scheduled_jobs_table(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS scheduled_jobs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + components_json TEXT NOT NULL, + targets_json TEXT NOT NULL, + schedule_type TEXT NOT NULL, + start_ts INTEGER, + duration_stop_enabled INTEGER DEFAULT 0, + expiration TEXT, + execution_context TEXT NOT NULL, + credential_id INTEGER, + use_service_account INTEGER NOT NULL DEFAULT 1, + enabled INTEGER DEFAULT 1, + created_at INTEGER, + updated_at INTEGER + ) + """ + ) + + try: + columns = {row[1] for row in _table_info(cur, "scheduled_jobs")} + if "credential_id" not in columns: + cur.execute("ALTER TABLE scheduled_jobs ADD COLUMN credential_id INTEGER") + if "use_service_account" not in columns: + cur.execute( + "ALTER TABLE scheduled_jobs ADD COLUMN use_service_account INTEGER NOT NULL DEFAULT 1" + ) + except Exception: + # Legacy deployments may fail the ALTER TABLE calls; ignore silently. 
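+        # (typically sqlite3.OperationalError "duplicate column name" when a
+        # concurrent process has already added these columns)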
+ pass + + +def _ensure_scheduled_job_run_tables(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS scheduled_job_runs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + job_id INTEGER NOT NULL, + scheduled_ts INTEGER, + started_ts INTEGER, + finished_ts INTEGER, + status TEXT, + error TEXT, + created_at INTEGER, + updated_at INTEGER, + target_hostname TEXT, + FOREIGN KEY(job_id) REFERENCES scheduled_jobs(id) ON DELETE CASCADE + ) + """ + ) + try: + cur.execute( + "CREATE INDEX IF NOT EXISTS idx_runs_job_sched_target ON scheduled_job_runs(job_id, scheduled_ts, target_hostname)" + ) + except Exception: + pass + + cur.execute( + """ + CREATE TABLE IF NOT EXISTS scheduled_job_run_activity ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id INTEGER NOT NULL, + activity_id INTEGER NOT NULL, + component_kind TEXT, + script_type TEXT, + component_path TEXT, + component_name TEXT, + created_at INTEGER, + FOREIGN KEY(run_id) REFERENCES scheduled_job_runs(id) ON DELETE CASCADE + ) + """ + ) + try: + cur.execute( + "CREATE INDEX IF NOT EXISTS idx_run_activity_run ON scheduled_job_run_activity(run_id)" + ) + except Exception: + pass + try: + cur.execute( + "CREATE UNIQUE INDEX IF NOT EXISTS idx_run_activity_activity ON scheduled_job_run_activity(activity_id)" + ) + except Exception: + pass + + +def _create_devices_table(cur: sqlite3.Cursor) -> None: + cur.execute( + """ + CREATE TABLE devices ( + guid TEXT PRIMARY KEY, + hostname TEXT, + description TEXT, + created_at INTEGER, + agent_hash TEXT, + memory TEXT, + network TEXT, + software TEXT, + storage TEXT, + cpu TEXT, + device_type TEXT, + domain TEXT, + external_ip TEXT, + internal_ip TEXT, + last_reboot TEXT, + last_seen INTEGER, + last_user TEXT, + operating_system TEXT, + uptime INTEGER, + agent_id TEXT, + ansible_ee_ver TEXT, + connection_type TEXT, + connection_endpoint TEXT, + ssl_key_fingerprint TEXT, + token_version INTEGER DEFAULT 1, + status TEXT DEFAULT 'active', + key_added_at TEXT + ) + """ + ) + _ensure_device_indexes(cur) + + +def _ensure_device_indexes(cur: sqlite3.Cursor) -> None: + cur.execute( + """ + CREATE UNIQUE INDEX IF NOT EXISTS uq_devices_hostname + ON devices(hostname) + """ + ) + cur.execute( + """ + CREATE INDEX IF NOT EXISTS idx_devices_ssl_key + ON devices(ssl_key_fingerprint) + """ + ) + cur.execute( + """ + CREATE INDEX IF NOT EXISTS idx_devices_status + ON devices(status) + """ + ) + + +def _ensure_column_defaults(cur: sqlite3.Cursor) -> None: + cur.execute( + """ + UPDATE devices + SET token_version = COALESCE(token_version, 1) + WHERE token_version IS NULL + """ + ) + cur.execute( + """ + UPDATE devices + SET status = COALESCE(status, 'active') + WHERE status IS NULL OR status = '' + """ + ) + + +def _rebuild_devices_table(conn: sqlite3.Connection, column_info: Sequence[Tuple]) -> None: + cur = conn.cursor() + cur.execute("PRAGMA foreign_keys=OFF") + cur.execute("BEGIN IMMEDIATE") + + cur.execute("ALTER TABLE devices RENAME TO devices_legacy") + _create_devices_table(cur) + + legacy_columns = [c[1] for c in column_info] + cur.execute(f"SELECT {', '.join(legacy_columns)} FROM devices_legacy") + rows = cur.fetchall() + + insert_sql = ( + """ + INSERT OR REPLACE INTO devices ( + guid, hostname, description, created_at, agent_hash, memory, + network, software, storage, cpu, device_type, domain, external_ip, + internal_ip, last_reboot, last_seen, last_user, operating_system, + uptime, agent_id, ansible_ee_ver, connection_type, connection_endpoint, + ssl_key_fingerprint, 
token_version, status, key_added_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """ + ) + + for row in rows: + record = dict(zip(legacy_columns, row)) + guid = _normalized_guid(record.get("guid")) + if not guid: + guid = str(uuid.uuid4()) + hostname = record.get("hostname") + created_at = record.get("created_at") + key_added_at = record.get("key_added_at") + if key_added_at is None: + key_added_at = _default_key_added_at(created_at) + + params: Tuple = ( + guid, + hostname, + record.get("description"), + created_at, + record.get("agent_hash"), + record.get("memory"), + record.get("network"), + record.get("software"), + record.get("storage"), + record.get("cpu"), + record.get("device_type"), + record.get("domain"), + record.get("external_ip"), + record.get("internal_ip"), + record.get("last_reboot"), + record.get("last_seen"), + record.get("last_user"), + record.get("operating_system"), + record.get("uptime"), + record.get("agent_id"), + record.get("ansible_ee_ver"), + record.get("connection_type"), + record.get("connection_endpoint"), + record.get("ssl_key_fingerprint"), + record.get("token_version") or 1, + record.get("status") or "active", + key_added_at, + ) + cur.execute(insert_sql, params) + + cur.execute("DROP TABLE devices_legacy") + cur.execute("COMMIT") + cur.execute("PRAGMA foreign_keys=ON") + + +def _default_key_added_at(created_at: Optional[int]) -> Optional[str]: + if created_at: + try: + dt = datetime.fromtimestamp(int(created_at), tz=timezone.utc) + return dt.isoformat() + except Exception: + pass + return datetime.now(tz=timezone.utc).isoformat() + + +def _table_exists(cur: sqlite3.Cursor, name: str) -> bool: + cur.execute( + "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?", + (name,), + ) + return cur.fetchone() is not None + + +def _table_info(cur: sqlite3.Cursor, name: str) -> List[Tuple]: + cur.execute(f"PRAGMA table_info({name})") + return cur.fetchall() + + +def _normalized_guid(value: Optional[str]) -> str: + if not value: + return "" + return str(value).strip() + +__all__ = ["apply_all"] diff --git a/Data/Engine/repositories/sqlite/token_repository.py b/Data/Engine/repositories/sqlite/token_repository.py new file mode 100644 index 0000000..fb2f605 --- /dev/null +++ b/Data/Engine/repositories/sqlite/token_repository.py @@ -0,0 +1,153 @@ +"""SQLite-backed refresh token repository for the Engine.""" + +from __future__ import annotations + +import logging +from contextlib import closing +from datetime import datetime, timezone +from typing import Optional + +from Data.Engine.domain.device_auth import DeviceGuid +from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory +from Data.Engine.services.auth.token_service import RefreshTokenRecord + +__all__ = ["SQLiteRefreshTokenRepository"] + + +class SQLiteRefreshTokenRepository: + """Persistence adapter for refresh token records.""" + + def __init__( + self, + connection_factory: SQLiteConnectionFactory, + *, + logger: Optional[logging.Logger] = None, + ) -> None: + self._connections = connection_factory + self._log = logger or logging.getLogger("borealis.engine.repositories.tokens") + + def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]: + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at + FROM refresh_tokens + WHERE guid = ? + AND token_hash = ? 
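+                  -- token_hash is the SHA-256 hex digest of the token; raw refresh tokens are never stored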
+ """, + (guid.value, token_hash), + ) + row = cur.fetchone() + + if not row: + return None + + return self._row_to_record(row) + + def clear_dpop_binding(self, record_id: str) -> None: + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + "UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?", + (record_id,), + ) + conn.commit() + + def touch( + self, + record_id: str, + *, + last_used_at: datetime, + dpop_jkt: Optional[str], + ) -> None: + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + UPDATE refresh_tokens + SET last_used_at = ?, + dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt) + WHERE id = ? + """, + ( + self._isoformat(last_used_at), + (dpop_jkt or "").strip(), + record_id, + ), + ) + conn.commit() + + def create( + self, + *, + record_id: str, + guid: DeviceGuid, + token_hash: str, + created_at: datetime, + expires_at: Optional[datetime], + ) -> None: + created_iso = self._isoformat(created_at) + expires_iso = self._isoformat(expires_at) if expires_at else None + + with closing(self._connections()) as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT INTO refresh_tokens ( + id, + guid, + token_hash, + created_at, + expires_at + ) + VALUES (?, ?, ?, ?, ?) + """, + (record_id, guid.value, token_hash, created_iso, expires_iso), + ) + conn.commit() + + def _row_to_record(self, row: tuple) -> Optional[RefreshTokenRecord]: + try: + guid = DeviceGuid(row[1]) + except Exception as exc: + self._log.warning("invalid refresh token row guid=%s: %s", row[1], exc) + return None + + created_at = self._parse_iso(row[4]) + expires_at = self._parse_iso(row[5]) + revoked_at = self._parse_iso(row[6]) + + if created_at is None: + created_at = datetime.now(tz=timezone.utc) + + return RefreshTokenRecord.from_row( + record_id=str(row[0]), + guid=guid, + token_hash=str(row[2]), + dpop_jkt=str(row[3]) if row[3] is not None else None, + created_at=created_at, + expires_at=expires_at, + revoked_at=revoked_at, + ) + + @staticmethod + def _parse_iso(value: Optional[str]) -> Optional[datetime]: + if not value: + return None + raw = str(value).strip() + if not raw: + return None + try: + parsed = datetime.fromisoformat(raw) + except Exception: + return None + if parsed.tzinfo is None: + return parsed.replace(tzinfo=timezone.utc) + return parsed + + @staticmethod + def _isoformat(value: datetime) -> str: + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc).isoformat() diff --git a/Data/Engine/requirements.txt b/Data/Engine/requirements.txt new file mode 100644 index 0000000..6f7693c --- /dev/null +++ b/Data/Engine/requirements.txt @@ -0,0 +1,11 @@ +#////////// PROJECT FILE SEPARATION LINE ////////// CODE AFTER THIS LINE ARE FROM: /Data/Engine/requirements.txt +# Core web stack +Flask +flask_socketio +flask-cors +eventlet +requests + +# Auth & security +PyJWT[crypto] +cryptography diff --git a/Data/Engine/runtime.py b/Data/Engine/runtime.py new file mode 100644 index 0000000..7e0167e --- /dev/null +++ b/Data/Engine/runtime.py @@ -0,0 +1,139 @@ +"""Runtime filesystem helpers for the Borealis Engine.""" + +from __future__ import annotations + +import os +from functools import lru_cache +from pathlib import Path +from typing import Optional + +__all__ = [ + "agent_certificates_path", + "ensure_agent_certificates_dir", + "ensure_certificates_dir", + "ensure_runtime_dir", + "ensure_server_certificates_dir", + "project_root", + "runtime_path", + "server_certificates_path", +] + + +def 
_env_path(name: str) -> Optional[Path]: + value = os.environ.get(name) + if not value: + return None + try: + return Path(value).expanduser().resolve() + except Exception: + return None + + +@lru_cache(maxsize=None) +def project_root() -> Path: + env = _env_path("BOREALIS_PROJECT_ROOT") + if env: + return env + + current = Path(__file__).resolve() + for parent in current.parents: + if (parent / "Borealis.ps1").exists() or (parent / ".git").is_dir(): + return parent + + try: + return current.parents[1] + except IndexError: + return current.parent + + +@lru_cache(maxsize=None) +def server_runtime_root() -> Path: + env = _env_path("BOREALIS_ENGINE_ROOT") or _env_path("BOREALIS_SERVER_ROOT") + if env: + env.mkdir(parents=True, exist_ok=True) + return env + + root = project_root() / "Engine" + root.mkdir(parents=True, exist_ok=True) + return root + + +def runtime_path(*parts: str) -> Path: + return server_runtime_root().joinpath(*parts) + + +def ensure_runtime_dir(*parts: str) -> Path: + path = runtime_path(*parts) + path.mkdir(parents=True, exist_ok=True) + return path + + +@lru_cache(maxsize=None) +def certificates_root() -> Path: + env = _env_path("BOREALIS_CERTIFICATES_ROOT") or _env_path("BOREALIS_CERT_ROOT") + if env: + env.mkdir(parents=True, exist_ok=True) + return env + + root = project_root() / "Certificates" + root.mkdir(parents=True, exist_ok=True) + for name in ("Server", "Agent"): + try: + (root / name).mkdir(parents=True, exist_ok=True) + except Exception: + pass + return root + + +@lru_cache(maxsize=None) +def server_certificates_root() -> Path: + env = _env_path("BOREALIS_SERVER_CERT_ROOT") + if env: + env.mkdir(parents=True, exist_ok=True) + return env + + root = certificates_root() / "Server" + root.mkdir(parents=True, exist_ok=True) + return root + + +@lru_cache(maxsize=None) +def agent_certificates_root() -> Path: + env = _env_path("BOREALIS_AGENT_CERT_ROOT") + if env: + env.mkdir(parents=True, exist_ok=True) + return env + + root = certificates_root() / "Agent" + root.mkdir(parents=True, exist_ok=True) + return root + + +def certificates_path(*parts: str) -> Path: + return certificates_root().joinpath(*parts) + + +def ensure_certificates_dir(*parts: str) -> Path: + path = certificates_path(*parts) + path.mkdir(parents=True, exist_ok=True) + return path + + +def server_certificates_path(*parts: str) -> Path: + return server_certificates_root().joinpath(*parts) + + +def ensure_server_certificates_dir(*parts: str) -> Path: + path = server_certificates_path(*parts) + path.mkdir(parents=True, exist_ok=True) + return path + + +def agent_certificates_path(*parts: str) -> Path: + return agent_certificates_root().joinpath(*parts) + + +def ensure_agent_certificates_dir(*parts: str) -> Path: + path = agent_certificates_path(*parts) + path.mkdir(parents=True, exist_ok=True) + return path diff --git a/Data/Engine/server.py b/Data/Engine/server.py new file mode 100644 index 0000000..77fb8ea --- /dev/null +++ b/Data/Engine/server.py @@ -0,0 +1,103 @@ +"""Flask application factory for the Borealis Engine.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Optional + +from flask import Flask, request, send_from_directory +from flask_cors import CORS +from werkzeug.exceptions import NotFound +from werkzeug.middleware.proxy_fix import ProxyFix + +from .config import EngineSettings + + +from .repositories.sqlite.connection import ( + SQLiteConnectionFactory, + connection_factory as create_sqlite_connection_factory, +) + + +def 
_resolve_static_folder(static_root: Path) -> tuple[str, str]:
+    return str(static_root), ""
+
+
+def _register_spa_routes(app: Flask, assets_root: Path) -> None:
+    """Serve the Borealis single-page application from *assets_root*.
+
+    The logic mirrors the legacy server by routing any unknown front-end path
+    back to ``index.html`` so the React router can take over.
+    """
+
+    static_folder = assets_root
+
+    @app.route("/", defaults={"path": ""})
+    @app.route("/<path:path>")
+    def serve_frontend(path: str) -> object:
+        candidate = (static_folder / path).resolve()
+        if path and candidate.is_file():
+            return send_from_directory(str(static_folder), path)
+        try:
+            return send_from_directory(str(static_folder), "index.html")
+        except Exception as exc:  # pragma: no cover - passthrough
+            raise NotFound() from exc
+
+    @app.errorhandler(404)
+    def spa_fallback(error: Exception) -> object:  # pragma: no cover - routing
+        request_path = (request.path or "").strip()
+        if request_path.startswith("/api") or request_path.startswith("/socket.io"):
+            return error
+        if "." in Path(request_path).name:
+            return error
+        if request.method not in {"GET", "HEAD"}:
+            return error
+        try:
+            return send_from_directory(str(static_folder), "index.html")
+        except Exception:
+            return error
+
+
+def create_app(
+    settings: EngineSettings,
+    *,
+    db_factory: Optional[SQLiteConnectionFactory] = None,
+) -> Flask:
+    """Create the Flask application instance for the Engine."""
+
+    if db_factory is None:
+        db_factory = create_sqlite_connection_factory(settings.database_path)
+
+    static_folder, static_url_path = _resolve_static_folder(settings.flask.static_root)
+    app = Flask(
+        __name__,
+        static_folder=static_folder,
+        static_url_path=static_url_path,
+    )
+
+    app.config.update(
+        SECRET_KEY=settings.flask.secret_key,
+        JSON_SORT_KEYS=False,
+        SESSION_COOKIE_HTTPONLY=True,
+        SESSION_COOKIE_SECURE=not settings.debug,
+        SESSION_COOKIE_SAMESITE="Lax",
+        ENGINE_DATABASE_PATH=str(settings.database_path),
+        ENGINE_DB_CONN_FACTORY=db_factory,
+    )
+    app.config.setdefault("PREFERRED_URL_SCHEME", "https")
+
+    # Respect upstream proxy headers when Borealis is hosted behind a TLS terminator.
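+    # ProxyFix below trusts exactly one hop for X-Forwarded-For/-Proto/-Host;
+    # raise x_for/x_proto/x_host if additional proxies sit in front of the terminator.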
+ app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1) # type: ignore[assignment] + + CORS( + app, + resources={r"/*": {"origins": list(settings.flask.cors_allowed_origins)}}, + supports_credentials=True, + ) + + _register_spa_routes(app, Path(static_folder)) + + return app + + +__all__ = ["create_app"] diff --git a/Data/Engine/services/__init__.py b/Data/Engine/services/__init__.py new file mode 100644 index 0000000..3e216c7 --- /dev/null +++ b/Data/Engine/services/__init__.py @@ -0,0 +1,62 @@ +"""Application services for the Borealis Engine.""" + +from __future__ import annotations + +from importlib import import_module +from typing import Any, Dict, Tuple + +__all__ = [ + "DeviceAuthService", + "DeviceRecord", + "RefreshTokenRecord", + "TokenRefreshError", + "TokenRefreshErrorCode", + "TokenService", + "EnrollmentService", + "EnrollmentRequestResult", + "EnrollmentStatus", + "EnrollmentTokenBundle", + "EnrollmentValidationError", + "PollingResult", + "AgentRealtimeService", + "AgentRecord", + "SchedulerService", + "GitHubService", + "GitHubTokenPayload", +] + +_LAZY_TARGETS: Dict[str, Tuple[str, str]] = { + "DeviceAuthService": ("Data.Engine.services.auth.device_auth_service", "DeviceAuthService"), + "DeviceRecord": ("Data.Engine.services.auth.device_auth_service", "DeviceRecord"), + "RefreshTokenRecord": ("Data.Engine.services.auth.device_auth_service", "RefreshTokenRecord"), + "TokenService": ("Data.Engine.services.auth.token_service", "TokenService"), + "TokenRefreshError": ("Data.Engine.services.auth.token_service", "TokenRefreshError"), + "TokenRefreshErrorCode": ("Data.Engine.services.auth.token_service", "TokenRefreshErrorCode"), + "EnrollmentService": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentService"), + "EnrollmentRequestResult": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentRequestResult"), + "EnrollmentStatus": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentStatus"), + "EnrollmentTokenBundle": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentTokenBundle"), + "PollingResult": ("Data.Engine.services.enrollment.enrollment_service", "PollingResult"), + "EnrollmentValidationError": ("Data.Engine.domain.device_enrollment", "EnrollmentValidationError"), + "AgentRealtimeService": ("Data.Engine.services.realtime.agent_registry", "AgentRealtimeService"), + "AgentRecord": ("Data.Engine.services.realtime.agent_registry", "AgentRecord"), + "SchedulerService": ("Data.Engine.services.jobs.scheduler_service", "SchedulerService"), + "GitHubService": ("Data.Engine.services.github.github_service", "GitHubService"), + "GitHubTokenPayload": ("Data.Engine.services.github.github_service", "GitHubTokenPayload"), +} + + +def __getattr__(name: str) -> Any: + try: + module_name, attribute = _LAZY_TARGETS[name] + except KeyError as exc: + raise AttributeError(name) from exc + + module = import_module(module_name) + value = getattr(module, attribute) + globals()[name] = value + return value + + +def __dir__() -> Any: # pragma: no cover - interactive helper + return sorted(set(__all__)) diff --git a/Data/Engine/services/auth/__init__.py b/Data/Engine/services/auth/__init__.py new file mode 100644 index 0000000..f24d072 --- /dev/null +++ b/Data/Engine/services/auth/__init__.py @@ -0,0 +1,27 @@ +"""Authentication services for the Borealis Engine.""" + +from __future__ import annotations + +from .device_auth_service import DeviceAuthService, DeviceRecord +from .dpop import DPoPReplayError, DPoPVerificationError, 
DPoPValidator +from .jwt_service import JWTService, load_service as load_jwt_service +from .token_service import ( + RefreshTokenRecord, + TokenRefreshError, + TokenRefreshErrorCode, + TokenService, +) + +__all__ = [ + "DeviceAuthService", + "DeviceRecord", + "DPoPReplayError", + "DPoPVerificationError", + "DPoPValidator", + "JWTService", + "load_jwt_service", + "RefreshTokenRecord", + "TokenRefreshError", + "TokenRefreshErrorCode", + "TokenService", +] diff --git a/Data/Engine/services/auth/device_auth_service.py b/Data/Engine/services/auth/device_auth_service.py new file mode 100644 index 0000000..0bc67f4 --- /dev/null +++ b/Data/Engine/services/auth/device_auth_service.py @@ -0,0 +1,237 @@ +"""Device authentication service copied from the legacy server stack.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Mapping, Optional, Protocol +import logging + +from Data.Engine.builders.device_auth import DeviceAuthRequest +from Data.Engine.domain.device_auth import ( + AccessTokenClaims, + DeviceAuthContext, + DeviceAuthErrorCode, + DeviceAuthFailure, + DeviceFingerprint, + DeviceGuid, + DeviceIdentity, + DeviceStatus, +) + +__all__ = [ + "DeviceAuthService", + "DeviceRecord", + "DPoPValidator", + "DPoPVerificationError", + "DPoPReplayError", + "RateLimiter", + "RateLimitDecision", + "DeviceRepository", +] + + +class RateLimitDecision(Protocol): + allowed: bool + retry_after: Optional[float] + + +class RateLimiter(Protocol): + def check(self, key: str, max_requests: int, window_seconds: float) -> RateLimitDecision: # pragma: no cover - protocol + ... + + +class JWTDecoder(Protocol): + def decode(self, token: str) -> Mapping[str, object]: # pragma: no cover - protocol + ... + + +class DPoPValidator(Protocol): + def verify( + self, + method: str, + htu: str, + proof: str, + access_token: Optional[str] = None, + ) -> str: # pragma: no cover - protocol + ... + + +class DPoPVerificationError(Exception): + """Raised when a DPoP proof fails validation.""" + + +class DPoPReplayError(DPoPVerificationError): + """Raised when a DPoP proof is replayed.""" + + +@dataclass(frozen=True, slots=True) +class DeviceRecord: + """Snapshot of a device record required for authentication.""" + + identity: DeviceIdentity + token_version: int + status: DeviceStatus + + +class DeviceRepository(Protocol): + """Port that exposes the minimal device persistence operations.""" + + def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: # pragma: no cover - protocol + ... + + def recover_missing( + self, + guid: DeviceGuid, + fingerprint: DeviceFingerprint, + token_version: int, + service_context: Optional[str], + ) -> Optional[DeviceRecord]: # pragma: no cover - protocol + ... 
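+
+
+# Illustrative only (not part of the Engine): a minimal in-memory double that
+# satisfies the DeviceRepository protocol for unit tests; the production
+# adapter is SQLiteDeviceRepository in Data.Engine.repositories.sqlite.
+#
+#     class InMemoryDeviceRepository:
+#         def __init__(self) -> None:
+#             self._records: dict[str, DeviceRecord] = {}
+#
+#         def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]:
+#             return self._records.get(guid.value)
+#
+#         def recover_missing(self, guid, fingerprint, token_version, service_context):
+#             return None  # no recovery path in tests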
+ + +class DeviceAuthService: + """Authenticate devices using access tokens, repositories, and DPoP proofs.""" + + def __init__( + self, + *, + device_repository: DeviceRepository, + jwt_service: JWTDecoder, + logger: Optional[logging.Logger] = None, + rate_limiter: Optional[RateLimiter] = None, + dpop_validator: Optional[DPoPValidator] = None, + ) -> None: + self._repository = device_repository + self._jwt = jwt_service + self._log = logger or logging.getLogger("borealis.engine.auth") + self._rate_limiter = rate_limiter + self._dpop_validator = dpop_validator + + def authenticate(self, request: DeviceAuthRequest, *, path: str) -> DeviceAuthContext: + """Authenticate an access token and return the resulting context.""" + + claims = self._decode_claims(request.access_token) + rate_limit_key = f"fp:{claims.fingerprint.value}" + if self._rate_limiter is not None: + decision = self._rate_limiter.check(rate_limit_key, 60, 60.0) + if not decision.allowed: + raise DeviceAuthFailure( + DeviceAuthErrorCode.RATE_LIMITED, + http_status=429, + retry_after=decision.retry_after, + ) + + record = self._repository.fetch_by_guid(claims.guid) + if record is None: + record = self._repository.recover_missing( + claims.guid, + claims.fingerprint, + claims.token_version, + request.service_context, + ) + + if record is None: + raise DeviceAuthFailure( + DeviceAuthErrorCode.DEVICE_NOT_FOUND, + http_status=403, + ) + + self._validate_identity(record, claims) + + dpop_jkt = self._validate_dpop(request, record, claims) + + context = DeviceAuthContext( + identity=record.identity, + access_token=request.access_token, + claims=claims, + status=record.status, + service_context=request.service_context, + dpop_jkt=dpop_jkt, + ) + + if context.is_quarantined: + self._log.warning( + "device %s is quarantined; limited access for %s", + record.identity.guid, + path, + ) + + return context + + def _decode_claims(self, token: str) -> AccessTokenClaims: + try: + raw_claims = self._jwt.decode(token) + except Exception as exc: # pragma: no cover - defensive fallback + if self._is_expired_signature(exc): + raise DeviceAuthFailure(DeviceAuthErrorCode.TOKEN_EXPIRED) from exc + raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN) from exc + + try: + return AccessTokenClaims.from_mapping(raw_claims) + except Exception as exc: + raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_CLAIMS) from exc + + @staticmethod + def _is_expired_signature(exc: Exception) -> bool: + name = exc.__class__.__name__ + return name == "ExpiredSignatureError" + + def _validate_identity( + self, + record: DeviceRecord, + claims: AccessTokenClaims, + ) -> None: + if record.identity.guid.value != claims.guid.value: + raise DeviceAuthFailure( + DeviceAuthErrorCode.DEVICE_GUID_MISMATCH, + http_status=403, + ) + + if record.identity.fingerprint.value: + if record.identity.fingerprint.value != claims.fingerprint.value: + raise DeviceAuthFailure( + DeviceAuthErrorCode.FINGERPRINT_MISMATCH, + http_status=403, + ) + + if record.token_version > claims.token_version: + raise DeviceAuthFailure(DeviceAuthErrorCode.TOKEN_VERSION_REVOKED) + + if not record.status.allows_access: + raise DeviceAuthFailure( + DeviceAuthErrorCode.DEVICE_REVOKED, + http_status=403, + ) + + def _validate_dpop( + self, + request: DeviceAuthRequest, + record: DeviceRecord, + claims: AccessTokenClaims, + ) -> Optional[str]: + if not request.dpop_proof: + return None + + if self._dpop_validator is None: + raise DeviceAuthFailure( + DeviceAuthErrorCode.DPOP_NOT_SUPPORTED, + http_status=400, + ) 
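+        # On success the validator returns the proof key's JWK thumbprint (jkt);
+        # the access token is passed along so the proof's `ath` claim can be
+        # bound to the exact token being presented.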
+ + try: + return self._dpop_validator.verify( + request.http_method, + request.htu, + request.dpop_proof, + request.access_token, + ) + except DPoPReplayError as exc: + raise DeviceAuthFailure( + DeviceAuthErrorCode.DPOP_REPLAYED, + http_status=400, + ) from exc + except DPoPVerificationError as exc: + raise DeviceAuthFailure( + DeviceAuthErrorCode.DPOP_INVALID, + http_status=400, + ) from exc diff --git a/Data/Engine/services/auth/dpop.py b/Data/Engine/services/auth/dpop.py new file mode 100644 index 0000000..2ea7e02 --- /dev/null +++ b/Data/Engine/services/auth/dpop.py @@ -0,0 +1,105 @@ +"""DPoP proof validation for Engine services.""" + +from __future__ import annotations + +import hashlib +import time +from threading import Lock +from typing import Dict, Optional + +import jwt + +__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"] + + +_DP0P_MAX_SKEW = 300.0 + + +class DPoPVerificationError(Exception): + pass + + +class DPoPReplayError(DPoPVerificationError): + pass + + +class DPoPValidator: + def __init__(self) -> None: + self._observed_jti: Dict[str, float] = {} + self._lock = Lock() + + def verify( + self, + method: str, + htu: str, + proof: str, + access_token: Optional[str] = None, + ) -> str: + if not proof: + raise DPoPVerificationError("DPoP proof missing") + + try: + header = jwt.get_unverified_header(proof) + except Exception as exc: + raise DPoPVerificationError("invalid DPoP header") from exc + + jwk = header.get("jwk") + alg = header.get("alg") + if not jwk or not isinstance(jwk, dict): + raise DPoPVerificationError("missing jwk in DPoP header") + if alg not in ("EdDSA", "ES256", "ES384", "ES512"): + raise DPoPVerificationError(f"unsupported DPoP alg {alg}") + + try: + key = jwt.PyJWK(jwk) + public_key = key.key + except Exception as exc: + raise DPoPVerificationError("invalid jwk in DPoP header") from exc + + try: + claims = jwt.decode( + proof, + public_key, + algorithms=[alg], + options={"require": ["htm", "htu", "jti", "iat"]}, + ) + except Exception as exc: + raise DPoPVerificationError("invalid DPoP signature") from exc + + htm = claims.get("htm") + proof_htu = claims.get("htu") + jti = claims.get("jti") + iat = claims.get("iat") + ath = claims.get("ath") + + if not isinstance(htm, str) or htm.lower() != method.lower(): + raise DPoPVerificationError("DPoP htm mismatch") + if not isinstance(proof_htu, str) or proof_htu != htu: + raise DPoPVerificationError("DPoP htu mismatch") + if not isinstance(jti, str): + raise DPoPVerificationError("DPoP jti missing") + if not isinstance(iat, (int, float)): + raise DPoPVerificationError("DPoP iat missing") + + now = time.time() + if abs(now - float(iat)) > _DP0P_MAX_SKEW: + raise DPoPVerificationError("DPoP proof outside allowed skew") + + if ath and access_token: + expected_ath = jwt.utils.base64url_encode( + hashlib.sha256(access_token.encode("utf-8")).digest() + ).decode("ascii") + if expected_ath != ath: + raise DPoPVerificationError("DPoP ath mismatch") + + with self._lock: + expiry = self._observed_jti.get(jti) + if expiry and expiry > now: + raise DPoPReplayError("DPoP proof replay detected") + self._observed_jti[jti] = now + _DP0P_MAX_SKEW + stale = [key for key, exp in self._observed_jti.items() if exp <= now] + for key in stale: + self._observed_jti.pop(key, None) + + thumbprint = jwt.PyJWK(jwk).thumbprint() + return thumbprint.decode("ascii") diff --git a/Data/Engine/services/auth/jwt_service.py b/Data/Engine/services/auth/jwt_service.py new file mode 100644 index 0000000..6a9d2e9 --- /dev/null +++ 
b/Data/Engine/services/auth/jwt_service.py @@ -0,0 +1,124 @@ +"""JWT issuance utilities for the Engine.""" + +from __future__ import annotations + +import hashlib +import time +from typing import Any, Dict, Optional + +import jwt +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 + +from Data.Engine.runtime import ensure_runtime_dir, runtime_path + +__all__ = ["JWTService", "load_service"] + + +_KEY_DIR = runtime_path("auth_keys") +_KEY_FILE = _KEY_DIR / "engine-jwt-ed25519.key" +_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-jwt-ed25519.key" + + +class JWTService: + def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str) -> None: + self._private_key = private_key + self._public_key = private_key.public_key() + self._key_id = key_id + + @property + def key_id(self) -> str: + return self._key_id + + def issue_access_token( + self, + guid: str, + ssl_key_fingerprint: str, + token_version: int, + expires_in: int = 900, + extra_claims: Optional[Dict[str, Any]] = None, + ) -> str: + now = int(time.time()) + payload: Dict[str, Any] = { + "sub": f"device:{guid}", + "guid": guid, + "ssl_key_fingerprint": ssl_key_fingerprint, + "token_version": int(token_version), + "iat": now, + "nbf": now, + "exp": now + int(expires_in), + } + if extra_claims: + payload.update(extra_claims) + + token = jwt.encode( + payload, + self._private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ), + algorithm="EdDSA", + headers={"kid": self._key_id}, + ) + return token + + def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]: + options = {"require": ["exp", "iat", "sub"]} + public_pem = self._public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return jwt.decode( + token, + public_pem, + algorithms=["EdDSA"], + audience=audience, + options=options, + ) + + def public_jwk(self) -> Dict[str, Any]: + public_bytes = self._public_key.public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, + ) + jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii") + return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x} + + +def load_service() -> JWTService: + private_key = _load_or_create_private_key() + public_bytes = private_key.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + key_id = hashlib.sha256(public_bytes).hexdigest()[:16] + return JWTService(private_key, key_id) + + +def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey: + ensure_runtime_dir("auth_keys") + + if _KEY_FILE.exists(): + with _KEY_FILE.open("rb") as fh: + return serialization.load_pem_private_key(fh.read(), password=None) + + if _LEGACY_KEY_FILE.exists(): + with _LEGACY_KEY_FILE.open("rb") as fh: + return serialization.load_pem_private_key(fh.read(), password=None) + + private_key = ed25519.Ed25519PrivateKey.generate() + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + _KEY_DIR.mkdir(parents=True, exist_ok=True) + with _KEY_FILE.open("wb") as fh: + fh.write(pem) + try: + if hasattr(_KEY_FILE, "chmod"): + _KEY_FILE.chmod(0o600) + except Exception: + pass + return 
private_key diff --git a/Data/Engine/services/auth/token_service.py b/Data/Engine/services/auth/token_service.py new file mode 100644 index 0000000..934db2c --- /dev/null +++ b/Data/Engine/services/auth/token_service.py @@ -0,0 +1,190 @@ +"""Token refresh service extracted from the legacy blueprint.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Optional, Protocol +import hashlib +import logging + +from Data.Engine.builders.device_auth import RefreshTokenRequest +from Data.Engine.domain.device_auth import DeviceGuid + +from .device_auth_service import ( + DeviceRecord, + DeviceRepository, + DPoPReplayError, + DPoPVerificationError, + DPoPValidator, +) + +__all__ = ["RefreshTokenRecord", "TokenService", "TokenRefreshError", "TokenRefreshErrorCode"] + + +class JWTIssuer(Protocol): + def issue_access_token(self, guid: str, fingerprint: str, token_version: int) -> str: # pragma: no cover - protocol + ... + + +class TokenRefreshErrorCode(str): + INVALID_REFRESH_TOKEN = "invalid_refresh_token" + REFRESH_TOKEN_REVOKED = "refresh_token_revoked" + REFRESH_TOKEN_EXPIRED = "refresh_token_expired" + DEVICE_NOT_FOUND = "device_not_found" + DEVICE_REVOKED = "device_revoked" + DPOP_REPLAYED = "dpop_replayed" + DPOP_INVALID = "dpop_invalid" + + +class TokenRefreshError(Exception): + def __init__(self, code: str, *, http_status: int = 400) -> None: + self.code = code + self.http_status = http_status + super().__init__(code) + + def to_dict(self) -> dict[str, str]: + return {"error": self.code} + + +@dataclass(frozen=True, slots=True) +class RefreshTokenRecord: + record_id: str + guid: DeviceGuid + token_hash: str + dpop_jkt: Optional[str] + created_at: datetime + expires_at: Optional[datetime] + revoked_at: Optional[datetime] + + @classmethod + def from_row( + cls, + *, + record_id: str, + guid: DeviceGuid, + token_hash: str, + dpop_jkt: Optional[str], + created_at: datetime, + expires_at: Optional[datetime], + revoked_at: Optional[datetime], + ) -> "RefreshTokenRecord": + return cls( + record_id=record_id, + guid=guid, + token_hash=token_hash, + dpop_jkt=dpop_jkt, + created_at=created_at, + expires_at=expires_at, + revoked_at=revoked_at, + ) + + +class RefreshTokenRepository(Protocol): + def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]: # pragma: no cover - protocol + ... + + def clear_dpop_binding(self, record_id: str) -> None: # pragma: no cover - protocol + ... + + def touch(self, record_id: str, *, last_used_at: datetime, dpop_jkt: Optional[str]) -> None: # pragma: no cover - protocol + ... 
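+
+
+# Sketch of the intended wiring (variable names here are illustrative): the
+# service container builds a TokenService from the SQLite adapters plus the
+# JWT issuer, e.g.
+#
+#     service = TokenService(
+#         refresh_token_repository=token_repo,
+#         device_repository=device_repo,
+#         jwt_service=jwt_service,
+#         dpop_validator=DPoPValidator(),
+#     )
+#     response = service.refresh_access_token(request)  # -> AccessTokenResponse
+#
+# Refresh tokens are looked up by SHA-256 hash, never by raw value, and a
+# successful refresh mints a 900-second Bearer access token.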
+ + +@dataclass(frozen=True, slots=True) +class AccessTokenResponse: + access_token: str + expires_in: int + token_type: str + + +class TokenService: + def __init__( + self, + *, + refresh_token_repository: RefreshTokenRepository, + device_repository: DeviceRepository, + jwt_service: JWTIssuer, + dpop_validator: Optional[DPoPValidator] = None, + logger: Optional[logging.Logger] = None, + ) -> None: + self._refresh_tokens = refresh_token_repository + self._devices = device_repository + self._jwt = jwt_service + self._dpop_validator = dpop_validator + self._log = logger or logging.getLogger("borealis.engine.auth") + + def refresh_access_token( + self, + request: RefreshTokenRequest, + ) -> AccessTokenResponse: + record = self._refresh_tokens.fetch( + request.guid, + self._hash_token(request.refresh_token), + ) + if record is None: + raise TokenRefreshError(TokenRefreshErrorCode.INVALID_REFRESH_TOKEN, http_status=401) + + if record.guid.value != request.guid.value: + raise TokenRefreshError(TokenRefreshErrorCode.INVALID_REFRESH_TOKEN, http_status=401) + + if record.revoked_at is not None: + raise TokenRefreshError(TokenRefreshErrorCode.REFRESH_TOKEN_REVOKED, http_status=401) + + if record.expires_at is not None and record.expires_at <= self._now(): + raise TokenRefreshError(TokenRefreshErrorCode.REFRESH_TOKEN_EXPIRED, http_status=401) + + device = self._devices.fetch_by_guid(request.guid) + if device is None: + raise TokenRefreshError(TokenRefreshErrorCode.DEVICE_NOT_FOUND, http_status=404) + + if not device.status.allows_access: + raise TokenRefreshError(TokenRefreshErrorCode.DEVICE_REVOKED, http_status=403) + + dpop_jkt = record.dpop_jkt or "" + if request.dpop_proof: + if self._dpop_validator is None: + raise TokenRefreshError(TokenRefreshErrorCode.DPOP_INVALID) + try: + dpop_jkt = self._dpop_validator.verify( + request.http_method, + request.htu, + request.dpop_proof, + None, + ) + except DPoPReplayError as exc: + raise TokenRefreshError(TokenRefreshErrorCode.DPOP_REPLAYED) from exc + except DPoPVerificationError as exc: + raise TokenRefreshError(TokenRefreshErrorCode.DPOP_INVALID) from exc + elif record.dpop_jkt: + self._log.warning( + "Clearing stored DPoP binding for guid=%s due to missing proof", + request.guid.value, + ) + self._refresh_tokens.clear_dpop_binding(record.record_id) + + access_token = self._jwt.issue_access_token( + request.guid.value, + device.identity.fingerprint.value, + max(device.token_version, 1), + ) + + self._refresh_tokens.touch( + record.record_id, + last_used_at=self._now(), + dpop_jkt=dpop_jkt or None, + ) + + return AccessTokenResponse( + access_token=access_token, + expires_in=900, + token_type="Bearer", + ) + + @staticmethod + def _hash_token(token: str) -> str: + return hashlib.sha256(token.encode("utf-8")).hexdigest() + + @staticmethod + def _now() -> datetime: + return datetime.now(tz=timezone.utc) diff --git a/Data/Engine/services/container.py b/Data/Engine/services/container.py new file mode 100644 index 0000000..756c44f --- /dev/null +++ b/Data/Engine/services/container.py @@ -0,0 +1,158 @@ +"""Service container assembly for the Borealis Engine.""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Optional + +from Data.Engine.config import EngineSettings +from Data.Engine.integrations.github import GitHubArtifactProvider +from Data.Engine.repositories.sqlite import ( + SQLiteConnectionFactory, + SQLiteDeviceRepository, + 
SQLiteEnrollmentRepository, + SQLiteGitHubRepository, + SQLiteJobRepository, + SQLiteRefreshTokenRepository, +) +from Data.Engine.services.auth import ( + DeviceAuthService, + DPoPValidator, + JWTService, + TokenService, + load_jwt_service, +) +from Data.Engine.services.crypto.signing import ScriptSigner, load_signer +from Data.Engine.services.enrollment import EnrollmentService +from Data.Engine.services.enrollment.nonce_cache import NonceCache +from Data.Engine.services.github import GitHubService +from Data.Engine.services.jobs import SchedulerService +from Data.Engine.services.rate_limit import SlidingWindowRateLimiter +from Data.Engine.services.realtime import AgentRealtimeService + +__all__ = ["EngineServiceContainer", "build_service_container"] + + +@dataclass(frozen=True, slots=True) +class EngineServiceContainer: + device_auth: DeviceAuthService + token_service: TokenService + enrollment_service: EnrollmentService + jwt_service: JWTService + dpop_validator: DPoPValidator + agent_realtime: AgentRealtimeService + scheduler_service: SchedulerService + github_service: GitHubService + + +def build_service_container( + settings: EngineSettings, + *, + db_factory: SQLiteConnectionFactory, + logger: Optional[logging.Logger] = None, +) -> EngineServiceContainer: + log = logger or logging.getLogger("borealis.engine.services") + + device_repo = SQLiteDeviceRepository(db_factory, logger=log.getChild("devices")) + token_repo = SQLiteRefreshTokenRepository(db_factory, logger=log.getChild("tokens")) + enrollment_repo = SQLiteEnrollmentRepository(db_factory, logger=log.getChild("enrollment")) + job_repo = SQLiteJobRepository(db_factory, logger=log.getChild("jobs")) + github_repo = SQLiteGitHubRepository(db_factory, logger=log.getChild("github_repo")) + + jwt_service = load_jwt_service() + dpop_validator = DPoPValidator() + rate_limiter = SlidingWindowRateLimiter() + + token_service = TokenService( + refresh_token_repository=token_repo, + device_repository=device_repo, + jwt_service=jwt_service, + dpop_validator=dpop_validator, + logger=log.getChild("token_service"), + ) + + enrollment_service = EnrollmentService( + device_repository=device_repo, + enrollment_repository=enrollment_repo, + token_repository=token_repo, + jwt_service=jwt_service, + tls_bundle_loader=_tls_bundle_loader(settings), + ip_rate_limiter=SlidingWindowRateLimiter(), + fingerprint_rate_limiter=SlidingWindowRateLimiter(), + nonce_cache=NonceCache(), + script_signer=_load_script_signer(log), + logger=log.getChild("enrollment"), + ) + + device_auth = DeviceAuthService( + device_repository=device_repo, + jwt_service=jwt_service, + logger=log.getChild("device_auth"), + rate_limiter=rate_limiter, + dpop_validator=dpop_validator, + ) + + agent_realtime = AgentRealtimeService( + device_repository=device_repo, + logger=log.getChild("agent_realtime"), + ) + + scheduler_service = SchedulerService( + job_repository=job_repo, + assemblies_root=settings.project_root / "Assemblies", + logger=log.getChild("scheduler"), + ) + + github_provider = GitHubArtifactProvider( + cache_file=settings.github.cache_file, + default_repo=settings.github.default_repo, + default_branch=settings.github.default_branch, + refresh_interval=settings.github.refresh_interval_seconds, + logger=log.getChild("github.provider"), + ) + github_service = GitHubService( + repository=github_repo, + provider=github_provider, + logger=log.getChild("github"), + ) + github_service.start_background_refresh() + + return EngineServiceContainer( + device_auth=device_auth, + 
token_service=token_service, + enrollment_service=enrollment_service, + jwt_service=jwt_service, + dpop_validator=dpop_validator, + agent_realtime=agent_realtime, + scheduler_service=scheduler_service, + github_service=github_service, + ) + + +def _tls_bundle_loader(settings: EngineSettings) -> Callable[[], str]: + candidates = [ + Path(os.getenv("BOREALIS_TLS_BUNDLE", "")), + settings.project_root / "Certificates" / "Server" / "borealis-server-bundle.pem", + ] + + def loader() -> str: + for candidate in candidates: + if candidate and candidate.is_file(): + try: + return candidate.read_text(encoding="utf-8") + except Exception: + continue + return "" + + return loader + + +def _load_script_signer(logger: logging.Logger) -> Optional[ScriptSigner]: + try: + return load_signer() + except Exception as exc: + logger.warning("script signer unavailable: %s", exc) + return None diff --git a/Data/Engine/services/crypto/certificates.py b/Data/Engine/services/crypto/certificates.py new file mode 100644 index 0000000..1865a7a --- /dev/null +++ b/Data/Engine/services/crypto/certificates.py @@ -0,0 +1,366 @@ +"""Server TLS certificate management for the Borealis Engine.""" + +from __future__ import annotations + +import os +import ssl +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Optional, Tuple + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID + +from Data.Engine.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path + +__all__ = [ + "build_ssl_context", + "certificate_paths", + "ensure_certificate", +] + +_CERT_DIR = server_certificates_path() +_CERT_FILE = _CERT_DIR / "borealis-server-cert.pem" +_KEY_FILE = _CERT_DIR / "borealis-server-key.pem" +_BUNDLE_FILE = _CERT_DIR / "borealis-server-bundle.pem" +_CA_KEY_FILE = _CERT_DIR / "borealis-root-ca-key.pem" +_CA_CERT_FILE = _CERT_DIR / "borealis-root-ca.pem" + +_LEGACY_CERT_DIR = runtime_path("certs") +_LEGACY_CERT_FILE = _LEGACY_CERT_DIR / "borealis-server-cert.pem" +_LEGACY_KEY_FILE = _LEGACY_CERT_DIR / "borealis-server-key.pem" +_LEGACY_BUNDLE_FILE = _LEGACY_CERT_DIR / "borealis-server-bundle.pem" + +_ROOT_COMMON_NAME = "Borealis Root CA" +_ORG_NAME = "Borealis" +_ROOT_VALIDITY = timedelta(days=365 * 100) +_SERVER_VALIDITY = timedelta(days=365 * 5) + + +def ensure_certificate(common_name: str = "Borealis Engine") -> Tuple[Path, Path, Path]: + """Ensure the root CA, server certificate, and bundle exist on disk.""" + + ensure_server_certificates_dir() + _migrate_legacy_material_if_present() + + ca_key, ca_cert, ca_regenerated = _ensure_root_ca() + + server_cert = _load_certificate(_CERT_FILE) + needs_regen = ca_regenerated or _server_certificate_needs_regeneration(server_cert, ca_cert) + if needs_regen: + server_cert = _generate_server_certificate(common_name, ca_key, ca_cert) + + if server_cert is None: + server_cert = _generate_server_certificate(common_name, ca_key, ca_cert) + + _write_bundle(server_cert, ca_cert) + + return _CERT_FILE, _KEY_FILE, _BUNDLE_FILE + + +def _migrate_legacy_material_if_present() -> None: + if not _CERT_FILE.exists() or not _KEY_FILE.exists(): + legacy_cert = _LEGACY_CERT_FILE + legacy_key = _LEGACY_KEY_FILE + if legacy_cert.exists() and legacy_key.exists(): + try: + ensure_server_certificates_dir() + if not _CERT_FILE.exists(): + _safe_copy(legacy_cert, _CERT_FILE) + if 
not _KEY_FILE.exists(): + _safe_copy(legacy_key, _KEY_FILE) + except Exception: + pass + + +def _ensure_root_ca() -> Tuple[ec.EllipticCurvePrivateKey, x509.Certificate, bool]: + regenerated = False + + ca_key: Optional[ec.EllipticCurvePrivateKey] = None + ca_cert: Optional[x509.Certificate] = None + + if _CA_KEY_FILE.exists() and _CA_CERT_FILE.exists(): + try: + ca_key = _load_private_key(_CA_KEY_FILE) + ca_cert = _load_certificate(_CA_CERT_FILE) + if ca_cert is not None and ca_key is not None: + expiry = _cert_not_after(ca_cert) + subject = ca_cert.subject + subject_cn = "" + try: + subject_cn = subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value # type: ignore[index] + except Exception: + subject_cn = "" + try: + basic = ca_cert.extensions.get_extension_for_class(x509.BasicConstraints).value # type: ignore[attr-defined] + is_ca = bool(basic.ca) + except Exception: + is_ca = False + if ( + expiry <= datetime.now(tz=timezone.utc) + or not is_ca + or subject_cn != _ROOT_COMMON_NAME + ): + regenerated = True + else: + regenerated = True + except Exception: + regenerated = True + else: + regenerated = True + + if regenerated or ca_key is None or ca_cert is None: + ca_key = ec.generate_private_key(ec.SECP384R1()) + public_key = ca_key.public_key() + + now = datetime.now(tz=timezone.utc) + builder = ( + x509.CertificateBuilder() + .subject_name( + x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, _ROOT_COMMON_NAME), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME), + ] + ) + ) + .issuer_name( + x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, _ROOT_COMMON_NAME), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME), + ] + ) + ) + .public_key(public_key) + .serial_number(x509.random_serial_number()) + .not_valid_before(now - timedelta(minutes=5)) + .not_valid_after(now + _ROOT_VALIDITY) + .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True) + .add_extension( + x509.KeyUsage( + digital_signature=True, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + .add_extension( + x509.SubjectKeyIdentifier.from_public_key(public_key), + critical=False, + ) + ) + + builder = builder.add_extension( + x509.AuthorityKeyIdentifier.from_issuer_public_key(public_key), + critical=False, + ) + + ca_cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA384()) + + _CA_KEY_FILE.write_bytes( + ca_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + _CA_CERT_FILE.write_bytes(ca_cert.public_bytes(serialization.Encoding.PEM)) + + _tighten_permissions(_CA_KEY_FILE) + _tighten_permissions(_CA_CERT_FILE) + else: + regenerated = False + + return ca_key, ca_cert, regenerated + + +def _server_certificate_needs_regeneration( + server_cert: Optional[x509.Certificate], + ca_cert: x509.Certificate, +) -> bool: + if server_cert is None: + return True + + try: + if server_cert.issuer != ca_cert.subject: + return True + except Exception: + return True + + try: + expiry = _cert_not_after(server_cert) + if expiry <= datetime.now(tz=timezone.utc): + return True + except Exception: + return True + + try: + basic = server_cert.extensions.get_extension_for_class(x509.BasicConstraints).value # type: ignore[attr-defined] + if basic.ca: + return True + except Exception: + return True + + 
try: + eku = server_cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value # type: ignore[attr-defined] + if ExtendedKeyUsageOID.SERVER_AUTH not in eku: + return True + except Exception: + return True + + return False + + +def _generate_server_certificate( + common_name: str, + ca_key: ec.EllipticCurvePrivateKey, + ca_cert: x509.Certificate, +) -> x509.Certificate: + private_key = ec.generate_private_key(ec.SECP384R1()) + public_key = private_key.public_key() + + now = datetime.now(tz=timezone.utc) + ca_expiry = _cert_not_after(ca_cert) + candidate_expiry = now + _SERVER_VALIDITY + not_after = min(ca_expiry - timedelta(days=1), candidate_expiry) + + builder = ( + x509.CertificateBuilder() + .subject_name( + x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME), + ] + ) + ) + .issuer_name(ca_cert.subject) + .public_key(public_key) + .serial_number(x509.random_serial_number()) + .not_valid_before(now - timedelta(minutes=5)) + .not_valid_after(not_after) + .add_extension( + x509.SubjectAlternativeName( + [ + x509.DNSName("localhost"), + x509.DNSName("127.0.0.1"), + x509.DNSName("::1"), + ] + ), + critical=False, + ) + .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True) + .add_extension( + x509.KeyUsage( + digital_signature=True, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + .add_extension( + x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), + critical=False, + ) + .add_extension( + x509.SubjectKeyIdentifier.from_public_key(public_key), + critical=False, + ) + .add_extension( + x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_key.public_key()), + critical=False, + ) + ) + + certificate = builder.sign(private_key=ca_key, algorithm=hashes.SHA384()) + + _KEY_FILE.write_bytes( + private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + _CERT_FILE.write_bytes(certificate.public_bytes(serialization.Encoding.PEM)) + + _tighten_permissions(_KEY_FILE) + _tighten_permissions(_CERT_FILE) + + return certificate + + +def _write_bundle(server_cert: x509.Certificate, ca_cert: x509.Certificate) -> None: + try: + server_pem = server_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8").strip() + ca_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8").strip() + except Exception: + return + + bundle = f"{server_pem}\n{ca_pem}\n" + _BUNDLE_FILE.write_text(bundle, encoding="utf-8") + _tighten_permissions(_BUNDLE_FILE) + + +def _safe_copy(src: Path, dst: Path) -> None: + try: + dst.write_bytes(src.read_bytes()) + except Exception: + pass + + +def _tighten_permissions(path: Path) -> None: + try: + if os.name == "posix": + path.chmod(0o600) + except Exception: + pass + + +def _load_private_key(path: Path) -> ec.EllipticCurvePrivateKey: + with path.open("rb") as fh: + return serialization.load_pem_private_key(fh.read(), password=None) + + +def _load_certificate(path: Path) -> Optional[x509.Certificate]: + try: + return x509.load_pem_x509_certificate(path.read_bytes()) + except Exception: + return None + + +def _cert_not_after(cert: x509.Certificate) -> datetime: + try: + return cert.not_valid_after_utc # type: ignore[attr-defined] + except AttributeError: + value = 
cert.not_valid_after + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value + + +def build_ssl_context() -> ssl.SSLContext: + cert_path, key_path, bundle_path = ensure_certificate() + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.minimum_version = ssl.TLSVersion.TLSv1_3 + context.load_cert_chain(certfile=str(bundle_path), keyfile=str(key_path)) + return context + + +def certificate_paths() -> Tuple[str, str, str]: + cert_path, key_path, bundle_path = ensure_certificate() + return str(cert_path), str(key_path), str(bundle_path) diff --git a/Data/Engine/services/crypto/signing.py b/Data/Engine/services/crypto/signing.py new file mode 100644 index 0000000..17d8875 --- /dev/null +++ b/Data/Engine/services/crypto/signing.py @@ -0,0 +1,75 @@ +"""Script signing utilities for the Engine.""" + +from __future__ import annotations + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 + +from Data.Engine.integrations.crypto.keys import base64_from_spki_der +from Data.Engine.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path + +__all__ = ["ScriptSigner", "load_signer"] + + +_KEY_DIR = server_certificates_path("Code-Signing") +_SIGNING_KEY_FILE = _KEY_DIR / "engine-script-ed25519.key" +_SIGNING_PUB_FILE = _KEY_DIR / "engine-script-ed25519.pub" +_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-script-ed25519.key" +_LEGACY_PUB_FILE = runtime_path("keys") / "borealis-script-ed25519.pub" + + +class ScriptSigner: + def __init__(self, private_key: ed25519.Ed25519PrivateKey) -> None: + self._private = private_key + self._public = private_key.public_key() + + def sign(self, payload: bytes) -> bytes: + return self._private.sign(payload) + + def public_spki_der(self) -> bytes: + return self._public.public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + def public_base64_spki(self) -> str: + return base64_from_spki_der(self.public_spki_der()) + + +def load_signer() -> ScriptSigner: + private_key = _load_or_create() + return ScriptSigner(private_key) + + +def _load_or_create() -> ed25519.Ed25519PrivateKey: + ensure_server_certificates_dir("Code-Signing") + + if _SIGNING_KEY_FILE.exists(): + with _SIGNING_KEY_FILE.open("rb") as fh: + return serialization.load_pem_private_key(fh.read(), password=None) + + if _LEGACY_KEY_FILE.exists(): + with _LEGACY_KEY_FILE.open("rb") as fh: + return serialization.load_pem_private_key(fh.read(), password=None) + + private_key = ed25519.Ed25519PrivateKey.generate() + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + _KEY_DIR.mkdir(parents=True, exist_ok=True) + _SIGNING_KEY_FILE.write_bytes(pem) + try: + if hasattr(_SIGNING_KEY_FILE, "chmod"): + _SIGNING_KEY_FILE.chmod(0o600) + except Exception: + pass + + pub_der = private_key.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + _SIGNING_PUB_FILE.write_bytes(pub_der) + + return private_key diff --git a/Data/Engine/services/enrollment/__init__.py b/Data/Engine/services/enrollment/__init__.py new file mode 100644 index 0000000..063cd7b --- /dev/null +++ b/Data/Engine/services/enrollment/__init__.py @@ -0,0 +1,21 @@ +"""Enrollment services for the Borealis Engine.""" + +from __future__ import annotations + +from 
.enrollment_service import ( + EnrollmentRequestResult, + EnrollmentService, + EnrollmentStatus, + EnrollmentTokenBundle, + PollingResult, +) +from Data.Engine.domain.device_enrollment import EnrollmentValidationError + +__all__ = [ + "EnrollmentRequestResult", + "EnrollmentService", + "EnrollmentStatus", + "EnrollmentTokenBundle", + "EnrollmentValidationError", + "PollingResult", +] diff --git a/Data/Engine/services/enrollment/enrollment_service.py b/Data/Engine/services/enrollment/enrollment_service.py new file mode 100644 index 0000000..9712169 --- /dev/null +++ b/Data/Engine/services/enrollment/enrollment_service.py @@ -0,0 +1,487 @@ +"""Enrollment workflow orchestration for the Borealis Engine.""" + +from __future__ import annotations + +import base64 +import hashlib +import logging +import secrets +import uuid +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Callable, Optional, Protocol + +from cryptography.hazmat.primitives import serialization + +from Data.Engine.builders.device_enrollment import EnrollmentRequestInput +from Data.Engine.domain.device_auth import ( + DeviceFingerprint, + DeviceGuid, + sanitize_service_context, +) +from Data.Engine.domain.device_enrollment import ( + EnrollmentApproval, + EnrollmentApprovalStatus, + EnrollmentCode, + EnrollmentValidationError, +) +from Data.Engine.services.auth.device_auth_service import DeviceRecord +from Data.Engine.services.auth.token_service import JWTIssuer +from Data.Engine.services.enrollment.nonce_cache import NonceCache +from Data.Engine.services.rate_limit import SlidingWindowRateLimiter + +__all__ = [ + "EnrollmentRequestResult", + "EnrollmentService", + "EnrollmentStatus", + "EnrollmentTokenBundle", + "PollingResult", +] + + +class DeviceRepository(Protocol): + def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: # pragma: no cover - protocol + ... + + def ensure_device_record( + self, + *, + guid: DeviceGuid, + hostname: str, + fingerprint: DeviceFingerprint, + ) -> DeviceRecord: # pragma: no cover - protocol + ... + + def record_device_key( + self, + *, + guid: DeviceGuid, + fingerprint: DeviceFingerprint, + added_at: datetime, + ) -> None: # pragma: no cover - protocol + ... + + +class EnrollmentRepository(Protocol): + def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol + ... + + def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol + ... + + def update_install_code_usage( + self, + record_id: str, + *, + use_count_increment: int, + last_used_at: datetime, + used_by_guid: Optional[DeviceGuid] = None, + mark_first_use: bool = False, + ) -> None: # pragma: no cover - protocol + ... + + def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol + ... + + def fetch_pending_approval_by_fingerprint( + self, fingerprint: DeviceFingerprint + ) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol + ... + + def update_pending_approval( + self, + record_id: str, + *, + hostname: str, + guid: Optional[DeviceGuid], + enrollment_code_id: Optional[str], + client_nonce_b64: str, + server_nonce_b64: str, + agent_pubkey_der: bytes, + updated_at: datetime, + ) -> None: # pragma: no cover - protocol + ... 
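+    # Note: request_enrollment re-uses the pending approval row for a retrying
+    # fingerprint, refreshing its nonces and hostname in place rather than
+    # inserting a duplicate record.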
+ + def create_device_approval( + self, + *, + record_id: str, + reference: str, + claimed_hostname: str, + claimed_fingerprint: DeviceFingerprint, + enrollment_code_id: Optional[str], + client_nonce_b64: str, + server_nonce_b64: str, + agent_pubkey_der: bytes, + created_at: datetime, + status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING, + guid: Optional[DeviceGuid] = None, + ) -> EnrollmentApproval: # pragma: no cover - protocol + ... + + def update_device_approval_status( + self, + record_id: str, + *, + status: EnrollmentApprovalStatus, + updated_at: datetime, + approved_by: Optional[str] = None, + guid: Optional[DeviceGuid] = None, + ) -> None: # pragma: no cover - protocol + ... + + +class RefreshTokenRepository(Protocol): + def create( + self, + *, + record_id: str, + guid: DeviceGuid, + token_hash: str, + created_at: datetime, + expires_at: Optional[datetime], + ) -> None: # pragma: no cover - protocol + ... + + +class ScriptSigner(Protocol): + def public_base64_spki(self) -> str: # pragma: no cover - protocol + ... + + +class EnrollmentStatus(str): + PENDING = "pending" + APPROVED = "approved" + DENIED = "denied" + EXPIRED = "expired" + UNKNOWN = "unknown" + + +@dataclass(frozen=True, slots=True) +class EnrollmentTokenBundle: + guid: DeviceGuid + access_token: str + refresh_token: str + expires_in: int + token_type: str = "Bearer" + + +@dataclass(frozen=True, slots=True) +class EnrollmentRequestResult: + status: EnrollmentStatus + server_certificate: str + signing_key: str + approval_reference: Optional[str] = None + server_nonce: Optional[str] = None + poll_after_ms: Optional[int] = None + http_status: int = 200 + retry_after: Optional[float] = None + + +@dataclass(frozen=True, slots=True) +class PollingResult: + status: EnrollmentStatus + http_status: int + poll_after_ms: Optional[int] = None + reason: Optional[str] = None + detail: Optional[str] = None + tokens: Optional[EnrollmentTokenBundle] = None + server_certificate: Optional[str] = None + signing_key: Optional[str] = None + + +class EnrollmentService: + """Coordinate the Borealis device enrollment handshake.""" + + def __init__( + self, + *, + device_repository: DeviceRepository, + enrollment_repository: EnrollmentRepository, + token_repository: RefreshTokenRepository, + jwt_service: JWTIssuer, + tls_bundle_loader: Callable[[], str], + ip_rate_limiter: SlidingWindowRateLimiter, + fingerprint_rate_limiter: SlidingWindowRateLimiter, + nonce_cache: NonceCache, + script_signer: Optional[ScriptSigner] = None, + logger: Optional[logging.Logger] = None, + ) -> None: + self._devices = device_repository + self._enrollment = enrollment_repository + self._tokens = token_repository + self._jwt = jwt_service + self._load_tls_bundle = tls_bundle_loader + self._ip_rate_limiter = ip_rate_limiter + self._fp_rate_limiter = fingerprint_rate_limiter + self._nonce_cache = nonce_cache + self._signer = script_signer + self._log = logger or logging.getLogger("borealis.engine.enrollment") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def request_enrollment( + self, + payload: EnrollmentRequestInput, + *, + remote_addr: str, + ) -> EnrollmentRequestResult: + context_hint = sanitize_service_context(payload.service_context) + self._log.info( + "enrollment-request ip=%s host=%s code_mask=%s", remote_addr, payload.hostname, self._mask_code(payload.enrollment_code) + ) + + self._enforce_rate_limit(self._ip_rate_limiter, 
f"ip:{remote_addr}") + self._enforce_rate_limit(self._fp_rate_limiter, f"fp:{payload.fingerprint.value}") + + install_code = self._enrollment.fetch_install_code(payload.enrollment_code) + reuse_guid = self._determine_reuse_guid(install_code, payload.fingerprint) + + server_nonce_bytes = secrets.token_bytes(32) + server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii") + + now = self._now() + approval = self._enrollment.fetch_pending_approval_by_fingerprint(payload.fingerprint) + if approval: + self._enrollment.update_pending_approval( + approval.record_id, + hostname=payload.hostname, + guid=reuse_guid, + enrollment_code_id=install_code.identifier if install_code else None, + client_nonce_b64=payload.client_nonce_b64, + server_nonce_b64=server_nonce_b64, + agent_pubkey_der=payload.agent_public_key_der, + updated_at=now, + ) + approval_reference = approval.reference + else: + record_id = str(uuid.uuid4()) + approval_reference = str(uuid.uuid4()) + approval = self._enrollment.create_device_approval( + record_id=record_id, + reference=approval_reference, + claimed_hostname=payload.hostname, + claimed_fingerprint=payload.fingerprint, + enrollment_code_id=install_code.identifier if install_code else None, + client_nonce_b64=payload.client_nonce_b64, + server_nonce_b64=server_nonce_b64, + agent_pubkey_der=payload.agent_public_key_der, + created_at=now, + guid=reuse_guid, + ) + + signing_key = self._signer.public_base64_spki() if self._signer else "" + certificate = self._load_tls_bundle() + + return EnrollmentRequestResult( + status=EnrollmentStatus.PENDING, + approval_reference=approval.reference, + server_nonce=server_nonce_b64, + poll_after_ms=3000, + server_certificate=certificate, + signing_key=signing_key, + ) + + def poll_enrollment( + self, + *, + approval_reference: str, + client_nonce_b64: str, + proof_signature_b64: str, + ) -> PollingResult: + if not approval_reference: + raise EnrollmentValidationError("approval_reference_required") + if not client_nonce_b64: + raise EnrollmentValidationError("client_nonce_required") + if not proof_signature_b64: + raise EnrollmentValidationError("proof_sig_required") + + approval = self._enrollment.fetch_device_approval_by_reference(approval_reference) + if approval is None: + return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=404) + + client_nonce = self._decode_base64(client_nonce_b64, "invalid_client_nonce") + server_nonce = self._decode_base64(approval.server_nonce_b64, "server_nonce_invalid") + proof_sig = self._decode_base64(proof_signature_b64, "invalid_proof_sig") + + if approval.client_nonce_b64 != client_nonce_b64: + raise EnrollmentValidationError("nonce_mismatch") + + self._verify_proof_signature( + approval=approval, + client_nonce=client_nonce, + server_nonce=server_nonce, + signature=proof_sig, + ) + + status = approval.status + if status is EnrollmentApprovalStatus.PENDING: + return PollingResult( + status=EnrollmentStatus.PENDING, + http_status=200, + poll_after_ms=5000, + ) + if status is EnrollmentApprovalStatus.DENIED: + return PollingResult( + status=EnrollmentStatus.DENIED, + http_status=200, + reason="operator_denied", + ) + if status is EnrollmentApprovalStatus.EXPIRED: + return PollingResult(status=EnrollmentStatus.EXPIRED, http_status=200) + if status is EnrollmentApprovalStatus.COMPLETED: + return PollingResult( + status=EnrollmentStatus.APPROVED, + http_status=200, + detail="finalized", + ) + if status is not EnrollmentApprovalStatus.APPROVED: + return 
PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=400) + + nonce_key = f"{approval.reference}:{proof_signature_b64}" + if not self._nonce_cache.consume(nonce_key): + raise EnrollmentValidationError("proof_replayed", http_status=409) + + token_bundle = self._finalize_approval(approval) + signing_key = self._signer.public_base64_spki() if self._signer else "" + certificate = self._load_tls_bundle() + + return PollingResult( + status=EnrollmentStatus.APPROVED, + http_status=200, + tokens=token_bundle, + server_certificate=certificate, + signing_key=signing_key, + ) + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + def _enforce_rate_limit( + self, + limiter: SlidingWindowRateLimiter, + key: str, + *, + limit: int = 60, + window_seconds: float = 60.0, + ) -> None: + decision = limiter.check(key, limit, window_seconds) + if not decision.allowed: + raise EnrollmentValidationError( + "rate_limited", http_status=429, retry_after=max(decision.retry_after, 1.0) + ) + + def _determine_reuse_guid( + self, + install_code: Optional[EnrollmentCode], + fingerprint: DeviceFingerprint, + ) -> Optional[DeviceGuid]: + if install_code is None: + raise EnrollmentValidationError("invalid_enrollment_code") + if install_code.is_expired: + raise EnrollmentValidationError("invalid_enrollment_code") + if not install_code.is_exhausted: + return None + if not install_code.used_by_guid: + raise EnrollmentValidationError("invalid_enrollment_code") + + existing = self._devices.fetch_by_guid(install_code.used_by_guid) + if existing and existing.identity.fingerprint.value == fingerprint.value: + return install_code.used_by_guid + raise EnrollmentValidationError("invalid_enrollment_code") + + def _finalize_approval(self, approval: EnrollmentApproval) -> EnrollmentTokenBundle: + now = self._now() + effective_guid = approval.guid or DeviceGuid(str(uuid.uuid4())) + device_record = self._devices.ensure_device_record( + guid=effective_guid, + hostname=approval.claimed_hostname, + fingerprint=approval.claimed_fingerprint, + ) + self._devices.record_device_key( + guid=effective_guid, + fingerprint=approval.claimed_fingerprint, + added_at=now, + ) + + if approval.enrollment_code_id: + code = self._enrollment.fetch_install_code_by_id(approval.enrollment_code_id) + if code is not None: + mark_first = code.used_at is None + self._enrollment.update_install_code_usage( + approval.enrollment_code_id, + use_count_increment=1, + last_used_at=now, + used_by_guid=effective_guid, + mark_first_use=mark_first, + ) + + refresh_token = secrets.token_urlsafe(48) + refresh_id = str(uuid.uuid4()) + expires_at = now + timedelta(days=30) + token_hash = hashlib.sha256(refresh_token.encode("utf-8")).hexdigest() + self._tokens.create( + record_id=refresh_id, + guid=effective_guid, + token_hash=token_hash, + created_at=now, + expires_at=expires_at, + ) + + access_token = self._jwt.issue_access_token( + effective_guid.value, + device_record.identity.fingerprint.value, + max(device_record.token_version, 1), + ) + + self._enrollment.update_device_approval_status( + approval.record_id, + status=EnrollmentApprovalStatus.COMPLETED, + updated_at=now, + guid=effective_guid, + ) + + return EnrollmentTokenBundle( + guid=effective_guid, + access_token=access_token, + refresh_token=refresh_token, + expires_in=900, + ) + + def _verify_proof_signature( + self, + *, + approval: EnrollmentApproval, + client_nonce: bytes, + server_nonce: bytes, + signature: 
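Note: `_verify_proof_signature` checks an Ed25519 signature over `server_nonce || reference || client_nonce`, so the agent side presumably produces its proof like this (sketch; `build_proof` is a hypothetical helper, not part of this diff):

    import base64

    from cryptography.hazmat.primitives.asymmetric import ed25519

    def build_proof(key: ed25519.Ed25519PrivateKey, server_nonce_b64: str,
                    reference: str, client_nonce: bytes) -> str:
        # Same byte layout the service verifies: server_nonce + reference + client_nonce.
        message = base64.b64decode(server_nonce_b64) + reference.encode("utf-8") + client_nonce
        return base64.b64encode(key.sign(message)).decode("ascii")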
diff --git a/Data/Engine/services/enrollment/nonce_cache.py b/Data/Engine/services/enrollment/nonce_cache.py
new file mode 100644
index 0000000..6653a7d
--- /dev/null
+++ b/Data/Engine/services/enrollment/nonce_cache.py
@@ -0,0 +1,32 @@
+"""Nonce replay protection for enrollment workflows."""
+
+from __future__ import annotations
+
+import time
+from threading import Lock
+from typing import Dict
+
+__all__ = ["NonceCache"]
+
+
+class NonceCache:
+    """Track recently observed nonces to prevent replay."""
+
+    def __init__(self, ttl_seconds: float = 300.0) -> None:
+        self._ttl = ttl_seconds
+        self._entries: Dict[str, float] = {}
+        self._lock = Lock()
+
+    def consume(self, key: str) -> bool:
+        """Consume *key* if it has not been seen recently."""
+
+        now = time.monotonic()
+        with self._lock:
+            expiry = self._entries.get(key)
+            if expiry and expiry > now:
+                return False
+            self._entries[key] = now + self._ttl
+            stale = [nonce for nonce, ttl in self._entries.items() if ttl <= now]
+            for nonce in stale:
+                self._entries.pop(nonce, None)
+            return True
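Note: the TTL semantics are easiest to see in use; a first `consume` of a key succeeds and any replay within the TTL fails:

    cache = NonceCache(ttl_seconds=300.0)
    key = "approval-ref:proof-sig-b64"  # key shape mirrors poll_enrollment's nonce_key
    assert cache.consume(key) is True   # first observation is accepted
    assert cache.consume(key) is False  # replay inside the TTL window is rejected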
diff --git a/Data/Engine/services/github/__init__.py b/Data/Engine/services/github/__init__.py
new file mode 100644
index 0000000..f2c48ba
--- /dev/null
+++ b/Data/Engine/services/github/__init__.py
@@ -0,0 +1,8 @@
+"""GitHub-oriented services for the Borealis Engine."""
+
+from __future__ import annotations
+
+from .github_service import GitHubService, GitHubTokenPayload
+
+__all__ = ["GitHubService", "GitHubTokenPayload"]
+
diff --git a/Data/Engine/services/github/github_service.py b/Data/Engine/services/github/github_service.py
new file mode 100644
index 0000000..157208e
--- /dev/null
+++ b/Data/Engine/services/github/github_service.py
@@ -0,0 +1,106 @@
+"""GitHub service layer bridging repositories and integrations."""
+
+from __future__ import annotations
+
+import logging
+import time
+from dataclasses import dataclass
+from typing import Callable, Optional
+
+from Data.Engine.domain.github import GitHubRepoRef, GitHubTokenStatus, RepoHeadSnapshot
+from Data.Engine.integrations.github import GitHubArtifactProvider
+from Data.Engine.repositories.sqlite.github_repository import SQLiteGitHubRepository
+
+__all__ = ["GitHubService", "GitHubTokenPayload"]
+
+
+@dataclass(frozen=True, slots=True)
+class GitHubTokenPayload:
+    token: Optional[str]
+    status: GitHubTokenStatus
+    checked_at: int
+
+    def to_dict(self) -> dict:
+        payload = self.status.to_dict()
+        payload.update(
+            {
+                "token": self.token or "",
+                "checked_at": self.checked_at,
+            }
+        )
+        return payload
+
+
+class GitHubService:
+    """Coordinate GitHub caching, verification, and persistence."""
+
+    def __init__(
+        self,
+        *,
+        repository: SQLiteGitHubRepository,
+        provider: GitHubArtifactProvider,
+        logger: Optional[logging.Logger] = None,
+        clock: Optional[Callable[[], float]] = None,
+    ) -> None:
+        self._repository = repository
+        self._provider = provider
+        self._log = logger or logging.getLogger("borealis.engine.services.github")
+        self._clock = clock or time.time
+        self._token_cache: Optional[str] = None
+        self._token_loaded_at: float = 0.0
+
+        initial_token = self._repository.load_token()
+        self._apply_token(initial_token)
+
+    def get_repo_head(
+        self,
+        owner_repo: Optional[str],
+        branch: Optional[str],
+        *,
+        ttl_seconds: int,
+        force_refresh: bool = False,
+    ) -> RepoHeadSnapshot:
+        repo_str = (owner_repo or self._provider.default_repo).strip()
+        branch_name = (branch or self._provider.default_branch).strip()
+        repo = GitHubRepoRef.parse(repo_str, branch_name)
+        ttl = max(30, min(ttl_seconds, 3600))
+        return self._provider.fetch_repo_head(repo, ttl_seconds=ttl, force_refresh=force_refresh)
+
+    def refresh_default_repo(self, *, force: bool = False) -> RepoHeadSnapshot:
+        return self._provider.refresh_default_repo_head(force=force)
+
+    def get_token_status(self, *, force_refresh: bool = False) -> GitHubTokenPayload:
+        token = self._load_token(force_refresh=force_refresh)
+        status = self._provider.verify_token(token)
+        return GitHubTokenPayload(token=token, status=status, checked_at=int(self._clock()))
+
+    def update_token(self, token: Optional[str]) -> GitHubTokenPayload:
+        normalized = (token or "").strip()
+        self._repository.store_token(normalized)
+        self._apply_token(normalized)
+        status = self._provider.verify_token(normalized)
+        self._provider.start_background_refresh()
+        self._log.info("github-token updated valid=%s", status.valid)
+        return GitHubTokenPayload(token=normalized or None, status=status, checked_at=int(self._clock()))
+
+    def start_background_refresh(self) -> None:
+        self._provider.start_background_refresh()
+
+    @property
+    def default_refresh_interval(self) -> int:
+        return self._provider.refresh_interval
+
+    def _load_token(self, *, force_refresh: bool = False) -> Optional[str]:
+        now = self._clock()
+        if not force_refresh and self._token_cache is not None and (now - self._token_loaded_at) < 15.0:
+            return self._token_cache
+
+        token = self._repository.load_token()
+        self._apply_token(token)
+        return token
+
+    def _apply_token(self, token: Optional[str]) -> None:
+        self._token_cache = (token or "").strip() or None
+        self._token_loaded_at = self._clock()
+        self._provider.set_token(self._token_cache)
+
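Note: a rough wiring sketch for the service (the repository/provider construction and the repo name are assumptions, not part of this diff):

    service = GitHubService(repository=repo, provider=provider)
    head = service.get_repo_head("acme/widgets", "main", ttl_seconds=300)  # hypothetical repo
    payload = service.update_token("ghp_example")  # persists, re-verifies, restarts background refresh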
diff --git a/Data/Engine/services/jobs/__init__.py b/Data/Engine/services/jobs/__init__.py
new file mode 100644
index 0000000..93e793b
--- /dev/null
+++ b/Data/Engine/services/jobs/__init__.py
@@ -0,0 +1,5 @@
+"""Job-related services for the Borealis Engine."""
+
+from .scheduler_service import SchedulerService
+
+__all__ = ["SchedulerService"]
diff --git a/Data/Engine/services/jobs/scheduler_service.py b/Data/Engine/services/jobs/scheduler_service.py
new file mode 100644
index 0000000..35d6f03
--- /dev/null
+++ b/Data/Engine/services/jobs/scheduler_service.py
@@ -0,0 +1,373 @@
+"""Background scheduler service for the Borealis Engine."""
+
+from __future__ import annotations
+
+import calendar
+import logging
+import threading
+import time
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Iterable, Mapping, Optional
+
+from Data.Engine.builders.job_fabricator import JobFabricator, JobManifest
+from Data.Engine.repositories.sqlite.job_repository import (
+    ScheduledJobRecord,
+    ScheduledJobRunRecord,
+    SQLiteJobRepository,
+)
+
+__all__ = ["SchedulerService", "SchedulerRuntime"]
+
+
+def _now_ts() -> int:
+    return int(time.time())
+
+
+def _floor_minute(ts: int) -> int:
+    ts = int(ts or 0)
+    return ts - (ts % 60)
+
+
+def _parse_ts(val: Any) -> Optional[int]:
+    if val is None:
+        return None
+    if isinstance(val, (int, float)):
+        return int(val)
+    try:
+        s = str(val).strip().replace("Z", "+00:00")
+        return int(datetime.fromisoformat(s).timestamp())
+    except Exception:
+        return None
+
+
+def _parse_expiration(raw: Optional[str]) -> Optional[int]:
+    if not raw or raw == "no_expire":
+        return None
+    try:
+        s = raw.strip().lower()
+        unit = s[-1]
+        value = int(s[:-1])
+        if unit == "m":
+            return value * 60
+        if unit == "h":
+            return value * 3600
+        if unit == "d":
+            return value * 86400
+        return int(s) * 60
+    except Exception:
+        return None
+
+
+def _add_months(dt: datetime, months: int) -> datetime:
+    year = dt.year + (dt.month - 1 + months) // 12
+    month = ((dt.month - 1 + months) % 12) + 1
+    last_day = calendar.monthrange(year, month)[1]
+    day = min(dt.day, last_day)
+    return dt.replace(year=year, month=month, day=day)
+
+
+def _add_years(dt: datetime, years: int) -> datetime:
+    year = dt.year + years
+    month = dt.month
+    last_day = calendar.monthrange(year, month)[1]
+    day = min(dt.day, last_day)
+    return dt.replace(year=year, month=month, day=day)
+
+
+def _compute_next_run(
+    schedule_type: str,
+    start_ts: Optional[int],
+    last_run_ts: Optional[int],
+    now_ts: int,
+) -> Optional[int]:
+    st = (schedule_type or "immediately").strip().lower()
+    start_floor = _floor_minute(start_ts) if start_ts else None
+    last_floor = _floor_minute(last_run_ts) if last_run_ts else None
+    now_floor = _floor_minute(now_ts)
+
+    if st == "immediately":
+        return None if last_floor else now_floor
+    if st == "once":
+        if not start_floor:
+            return None
+        return start_floor if not last_floor else None
+    if not start_floor:
+        return None
+
+    last = last_floor if last_floor is not None else None
+    if st in {
+        "every_5_minutes",
+        "every_10_minutes",
+        "every_15_minutes",
+        "every_30_minutes",
+        "every_hour",
+    }:
+        period_map = {
+            "every_5_minutes": 5 * 60,
+            "every_10_minutes": 10 * 60,
+            "every_15_minutes": 15 * 60,
+            "every_30_minutes": 30 * 60,
+            "every_hour": 60 * 60,
+        }
+        period = period_map.get(st)
+        candidate = (last + period) if last else start_floor
+        while candidate is not None and candidate <= now_floor - 1:
+            candidate += period
+        return candidate
+    if st == "daily":
+        period = 86400
+        candidate = (last + period) if last else start_floor
+        while candidate is not None and candidate <= now_floor - 1:
+            candidate += period
+        return candidate
+    if st == "weekly":
+        period = 7 * 86400
+        candidate = (last + period) if last else start_floor
+        while candidate is not None and candidate <= now_floor - 1:
+            candidate += period
+        return candidate
+    if st == "monthly":
+        base = datetime.fromtimestamp(last if last else start_floor, tz=timezone.utc)  # aware UTC so .timestamp() round-trips
+        candidate = _add_months(base, 1 if last else 0)
+        while int(candidate.timestamp()) <= now_floor - 1:
+            candidate = _add_months(candidate, 1)
+        return int(candidate.timestamp())
+    if st == "yearly":
+        base = datetime.fromtimestamp(last if last else start_floor, tz=timezone.utc)  # aware UTC so .timestamp() round-trips
+        candidate = _add_years(base, 1 if last else 0)
+        while int(candidate.timestamp()) <= now_floor - 1:
+            candidate = _add_years(candidate, 1)
+        return int(candidate.timestamp())
+    return None
+
+
+@dataclass
+class SchedulerRuntime:
+    thread: Optional[threading.Thread]
+    stop_event: threading.Event
+
+
+class SchedulerService:
+    """Evaluate and dispatch scheduled jobs using Engine repositories."""
+
+    def __init__(
+        self,
+        *,
+        job_repository: SQLiteJobRepository,
+        assemblies_root: Path,
+        logger: Optional[logging.Logger] = None,
+        poll_interval: int = 30,
+    ) -> None:
+        self._jobs = job_repository
+        self._fabricator = JobFabricator(assemblies_root=assemblies_root, logger=logger)
+        self._log = logger or logging.getLogger("borealis.engine.scheduler")
+        self._poll_interval = max(5, poll_interval)
+        self._socketio: Optional[Any] = None
+        self._runtime = SchedulerRuntime(thread=None, stop_event=threading.Event())
+
+    # ------------------------------------------------------------------
+    # Lifecycle
+    # ------------------------------------------------------------------
+    def start(self, socketio: Optional[Any] = None) -> None:
+        self._socketio = socketio
+        if self._runtime.thread and self._runtime.thread.is_alive():
+            return
+        self._runtime.stop_event.clear()
+        thread = threading.Thread(target=self._run_loop, name="borealis-engine-scheduler", daemon=True)
+        thread.start()
+        self._runtime.thread = thread
+        self._log.info("scheduler-started")
+
+    def stop(self) -> None:
+        self._runtime.stop_event.set()
+        thread = self._runtime.thread
+        if thread and thread.is_alive():
+            thread.join(timeout=5)
+        self._log.info("scheduler-stopped")
+
+    # ------------------------------------------------------------------
+    # HTTP orchestration helpers
+    # ------------------------------------------------------------------
+    def list_jobs(self) -> list[dict[str, Any]]:
+        return [self._serialize_job(job) for job in self._jobs.list_jobs()]
+
+    def get_job(self, job_id: int) -> Optional[dict[str, Any]]:
+        record = self._jobs.fetch_job(job_id)
+        return self._serialize_job(record) if record else None
+
+    def create_job(self, payload: Mapping[str, Any]) -> dict[str, Any]:
+        fields = self._normalize_payload(payload)
+        record = self._jobs.create_job(**fields)
+        return self._serialize_job(record)
+
+    def update_job(self, job_id: int, payload: Mapping[str, Any]) -> Optional[dict[str, Any]]:
+        fields = self._normalize_payload(payload)
+        record = self._jobs.update_job(job_id, **fields)
+        return self._serialize_job(record) if record else None
+
+    def toggle_job(self, job_id: int, enabled: bool) -> None:
+        self._jobs.set_enabled(job_id, enabled)
+
+    def delete_job(self, job_id: int) -> None:
+        self._jobs.delete_job(job_id)
+
+    def list_runs(self, job_id: int, *, days: Optional[int] = None) -> list[dict[str, Any]]:
+        runs = self._jobs.list_runs(job_id, days=days)
+        return [self._serialize_run(run) for run in runs]
+
+    def purge_runs(self, job_id: int) -> None:
+        self._jobs.purge_runs(job_id)
+
+    # ------------------------------------------------------------------
+    # Scheduling loop
+    # ------------------------------------------------------------------
+    def tick(self, *, now_ts: Optional[int] = None) -> None:
+        self._evaluate_jobs(now_ts=now_ts or _now_ts())
+
+    def _run_loop(self) -> None:
+        stop_event = self._runtime.stop_event
+        while not stop_event.wait(timeout=self._poll_interval):
+            try:
+                self._evaluate_jobs(now_ts=_now_ts())
+            except Exception as exc:  # pragma: no cover - safety net
+                self._log.exception("scheduler-loop-error: %s", exc)
+
+    def _evaluate_jobs(self, *, now_ts: int) -> None:
+        for job in self._jobs.list_enabled_jobs():
+            try:
+                self._evaluate_job(job, now_ts=now_ts)
+            except Exception as exc:
+                self._log.exception("job-evaluation-error job_id=%s error=%s", job.id, exc)
+
+    def _evaluate_job(self, job: ScheduledJobRecord, *, now_ts: int) -> None:
+        last_run = self._jobs.fetch_last_run(job.id)
+        last_ts = None
+        if last_run:
+            last_ts = last_run.scheduled_ts or last_run.started_ts or last_run.created_at
+        next_run = _compute_next_run(job.schedule_type, job.start_ts, last_ts, now_ts)
+        if next_run is None or next_run > now_ts:
+            return
+
+        expiration_window = _parse_expiration(job.expiration)
+        if expiration_window and job.start_ts:
+            if job.start_ts + expiration_window <= now_ts:
+                self._log.info(
+                    "job-expired",
+                    extra={"job_id": job.id, "start_ts": job.start_ts, "expiration": job.expiration},
+                )
+                return
+
+        manifest = self._fabricator.build(job, occurrence_ts=next_run)
+        targets = manifest.targets or ("",)
+        for target in targets:
+            run_id = self._jobs.create_run(job.id, next_run, target_hostname=None if target == "" else target)
+            self._jobs.mark_run_started(run_id, started_ts=now_ts)
+            self._emit_run_event("job_run_started", job, run_id, target, manifest)
+            self._jobs.mark_run_finished(run_id, status="Success", finished_ts=now_ts)
+            self._emit_run_event("job_run_completed", job, run_id, target, manifest)
+
+    def _emit_run_event(
+        self,
+        event: str,
+        job: ScheduledJobRecord,
+        run_id: int,
+        target: str,
+        manifest: JobManifest,
+    ) -> None:
+        payload = {
+            "job_id": job.id,
+            "run_id": run_id,
+            "target": target,
+            "schedule_type": job.schedule_type,
+            "occurrence_ts": manifest.occurrence_ts,
+        }
+        if self._socketio is not None:
+            try:
+                self._socketio.emit(event, payload)  # type: ignore[attr-defined]
+            except Exception:
+                self._log.debug("socketio-emit-failed event=%s payload=%s", event, payload)
+
+    # ------------------------------------------------------------------
+    # Serialization helpers
+    # ------------------------------------------------------------------
+    def _serialize_job(self, job: ScheduledJobRecord) -> dict[str, Any]:
+        return {
+            "id": job.id,
+            "name": job.name,
+            "components": job.components,
+            "targets": job.targets,
+            "schedule": {
+                "type": job.schedule_type,
+                "start": job.start_ts,
+            },
+            "schedule_type": job.schedule_type,
+            "start_ts": job.start_ts,
+            "duration_stop_enabled": job.duration_stop_enabled,
+            "expiration": job.expiration or "no_expire",
+            "execution_context": job.execution_context,
+            "credential_id": job.credential_id,
+            "use_service_account": job.use_service_account,
+            "enabled": job.enabled,
+            "created_at": job.created_at,
+            "updated_at": job.updated_at,
+        }
+
+    def _serialize_run(self, run: ScheduledJobRunRecord) -> dict[str, Any]:
+        return {
+            "id": run.id,
+            "job_id": run.job_id,
+            "scheduled_ts": run.scheduled_ts,
+            "started_ts": run.started_ts,
+            "finished_ts": run.finished_ts,
+            "status": run.status,
+            "error": run.error,
+            "target_hostname": run.target_hostname,
+            "created_at": run.created_at,
+            "updated_at": run.updated_at,
+        }
+
+    def _normalize_payload(self, payload: Mapping[str, Any]) -> dict[str, Any]:
+        name = str(payload.get("name") or "").strip()
+        components = payload.get("components") or []
+        targets = payload.get("targets") or []
+        schedule_block = payload.get("schedule") if isinstance(payload.get("schedule"), Mapping) else {}
+        schedule_type = str(schedule_block.get("type") or payload.get("schedule_type") or "immediately").strip().lower()
+        start_value = schedule_block.get("start") if isinstance(schedule_block, Mapping) else None
+        if start_value is None:
+            start_value = payload.get("start")
+        start_ts = _parse_ts(start_value)
+        duration_block = payload.get("duration") if isinstance(payload.get("duration"), Mapping) else {}
+        duration_stop = bool(duration_block.get("stopAfterEnabled") or payload.get("duration_stop_enabled"))
+        expiration = str(duration_block.get("expiration") or payload.get("expiration") or "no_expire").strip()
+        execution_context = str(payload.get("execution_context") or "system").strip().lower()
+        credential_id = payload.get("credential_id")
+        try:
+            credential_id = int(credential_id) if credential_id is not None else None
+        except Exception:
+            credential_id = None
+        use_service_account_raw = payload.get("use_service_account")
+        use_service_account = bool(use_service_account_raw) if execution_context == "winrm" else False
+        enabled = bool(payload.get("enabled", True))
+
+        if not name:
+            raise ValueError("job name is required")
+        if not isinstance(components, Iterable) or not list(components):
+            raise ValueError("at least one component is required")
+        if not isinstance(targets, Iterable) or not list(targets):
+            raise ValueError("at least one target is required")
+
+        return {
+            "name": name,
+            "components": list(components),
+            "targets": list(targets),
+            "schedule_type": schedule_type,
+            "start_ts": start_ts,
+            "duration_stop_enabled": duration_stop,
+            "expiration": expiration,
+            "execution_context": execution_context,
+            "credential_id": credential_id,
+            "use_service_account": use_service_account,
+            "enabled": enabled,
+        }
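Note: `_compute_next_run` skips overdue occurrences rather than back-filling them. A quick worked example with hypothetical epoch values:

    start = 1_700_000_000 - (1_700_000_000 % 60)  # minute-aligned start
    # Daily job whose last run was at `start`, evaluated 25 hours later:
    nxt = _compute_next_run("daily", start, start, start + 25 * 3600)
    assert nxt == start + 2 * 86400  # the start+86400 slot is already past and is skipped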
diff --git a/Data/Engine/services/rate_limit.py b/Data/Engine/services/rate_limit.py
new file mode 100644
index 0000000..49b8fd8
--- /dev/null
+++ b/Data/Engine/services/rate_limit.py
@@ -0,0 +1,45 @@
+"""In-process rate limiting utilities for the Borealis Engine."""
+
+from __future__ import annotations
+
+import time
+from collections import deque
+from dataclasses import dataclass
+from threading import Lock
+from typing import Deque, Dict
+
+__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"]
+
+
+@dataclass(frozen=True, slots=True)
+class RateLimitDecision:
+    """Result of a rate limit check."""
+
+    allowed: bool
+    retry_after: float
+
+
+class SlidingWindowRateLimiter:
+    """Tiny in-memory sliding window limiter suitable for single-process use."""
+
+    def __init__(self) -> None:
+        self._buckets: Dict[str, Deque[float]] = {}
+        self._lock = Lock()
+
+    def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
+        now = time.monotonic()
+        with self._lock:
+            bucket = self._buckets.get(key)
+            if bucket is None:
+                bucket = deque()
+                self._buckets[key] = bucket
+
+            while bucket and now - bucket[0] > window_seconds:
+                bucket.popleft()
+
+            if len(bucket) >= limit:
+                retry_after = max(0.0, window_seconds - (now - bucket[0]))
+                return RateLimitDecision(False, retry_after)
+
+            bucket.append(now)
+            return RateLimitDecision(True, 0.0)
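Note: quick usage of the limiter with the same parameters the enrollment service uses (60 requests per 60-second window):

    limiter = SlidingWindowRateLimiter()
    for _ in range(61):
        decision = limiter.check("ip:203.0.113.7", 60, 60.0)
    assert decision.allowed is False   # the 61st call in the window is rejected
    assert decision.retry_after > 0.0  # callers get a retry hint in seconds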
diff --git a/Data/Engine/services/realtime/__init__.py b/Data/Engine/services/realtime/__init__.py
new file mode 100644
index 0000000..661b469
--- /dev/null
+++ b/Data/Engine/services/realtime/__init__.py
@@ -0,0 +1,10 @@
+"""Realtime coordination services for the Borealis Engine."""
+
+from __future__ import annotations
+
+from .agent_registry import AgentRealtimeService, AgentRecord
+
+__all__ = [
+    "AgentRealtimeService",
+    "AgentRecord",
+]
diff --git a/Data/Engine/services/realtime/agent_registry.py b/Data/Engine/services/realtime/agent_registry.py
new file mode 100644
index 0000000..5b9bb15
--- /dev/null
+++ b/Data/Engine/services/realtime/agent_registry.py
@@ -0,0 +1,301 @@
+from __future__ import annotations
+
+import logging
+import time
+from dataclasses import dataclass
+from typing import Any, Dict, Mapping, Optional, Tuple
+
+from Data.Engine.repositories.sqlite import SQLiteDeviceRepository
+
+__all__ = ["AgentRealtimeService", "AgentRecord"]
+
+
+@dataclass(slots=True)
+class AgentRecord:
+    """In-memory representation of a connected agent."""
+
+    agent_id: str
+    hostname: str = "unknown"
+    agent_operating_system: str = "-"
+    last_seen: int = 0
+    status: str = "orphaned"
+    service_mode: str = "currentuser"
+    is_script_agent: bool = False
+    collector_active_ts: Optional[float] = None
+
+
+class AgentRealtimeService:
+    """Track realtime agent presence and provide persistence hooks."""
+
+    def __init__(
+        self,
+        *,
+        device_repository: SQLiteDeviceRepository,
+        logger: Optional[logging.Logger] = None,
+    ) -> None:
+        self._device_repository = device_repository
+        self._log = logger or logging.getLogger("borealis.engine.services.realtime.agents")
+        self._agents: Dict[str, AgentRecord] = {}
+        self._configs: Dict[str, Dict[str, Any]] = {}
+        self._screenshots: Dict[str, Dict[str, Any]] = {}
+        self._task_screenshots: Dict[Tuple[str, str], Dict[str, Any]] = {}
+
+    # ------------------------------------------------------------------
+    # Agent presence management
+    # ------------------------------------------------------------------
+    def register_connection(self, agent_id: str, service_mode: Optional[str]) -> AgentRecord:
+        record = self._agents.get(agent_id) or AgentRecord(agent_id=agent_id)
+        mode = self.normalize_service_mode(service_mode, agent_id)
+        now = int(time.time())
+
+        record.service_mode = mode
+        record.is_script_agent = self._is_script_agent(agent_id)
+        record.last_seen = now
+        record.status = "provisioned" if agent_id in self._configs else "orphaned"
+        self._agents[agent_id] = record
+
+        self._persist_activity(
+            hostname=record.hostname,
+            last_seen=record.last_seen,
+            agent_id=agent_id,
+            operating_system=record.agent_operating_system,
+        )
+        return record
+
+    def heartbeat(self, payload: Mapping[str, Any]) -> Optional[AgentRecord]:
+        if not payload:
+            return None
+
+        agent_id = payload.get("agent_id")
+        if not agent_id:
+            return None
+
+        hostname = payload.get("hostname") or ""
+        mode = self.normalize_service_mode(payload.get("service_mode"), agent_id)
+        is_script_agent = self._is_script_agent(agent_id)
+        last_seen = self._coerce_int(payload.get("last_seen"), default=int(time.time()))
+        operating_system = (payload.get("agent_operating_system") or "-").strip() or "-"
+
+        if hostname:
+            self._reconcile_hostname_collisions(
+                hostname=hostname,
+                agent_id=agent_id,
+                incoming_mode=mode,
+                is_script_agent=is_script_agent,
+                last_seen=last_seen,
+            )
+
+        record = self._agents.get(agent_id) or AgentRecord(agent_id=agent_id)
+        if hostname:
+            record.hostname = hostname
+        record.agent_operating_system = operating_system
+        record.last_seen = last_seen
+        record.service_mode = mode
+        record.is_script_agent = is_script_agent
+        record.status = "provisioned" if agent_id in self._configs else record.status or "orphaned"
+        self._agents[agent_id] = record
+
+        self._persist_activity(
+            hostname=record.hostname or hostname,
+            last_seen=record.last_seen,
+            agent_id=agent_id,
+            operating_system=record.agent_operating_system,
+        )
+        return record
+
+    def collector_status(self, payload: Mapping[str, Any]) -> None:
+        if not payload:
+            return
+
+        agent_id = payload.get("agent_id")
+        if not agent_id:
+            return
+
+        hostname = payload.get("hostname") or ""
+        mode = self.normalize_service_mode(payload.get("service_mode"), agent_id)
+        active = bool(payload.get("active"))
+        last_user = (payload.get("last_user") or "").strip()
+
+        record = self._agents.get(agent_id) or AgentRecord(agent_id=agent_id)
+        if hostname:
+            record.hostname = hostname
+        if mode:
+            record.service_mode = mode
+        record.is_script_agent = self._is_script_agent(agent_id) or record.is_script_agent
+        if active:
+            record.collector_active_ts = time.time()
+        self._agents[agent_id] = record
+
+        if (
+            last_user
+            and hostname
+            and self._is_valid_interactive_user(last_user)
+            and not self._is_system_service_agent(agent_id, record.is_script_agent)
+        ):
+            self._persist_activity(
+                hostname=hostname,
+                last_seen=int(time.time()),
+                agent_id=agent_id,
+                operating_system=record.agent_operating_system,
+                last_user=last_user,
+            )
+
+    # ------------------------------------------------------------------
+    # Configuration management
+    # ------------------------------------------------------------------
+    def set_agent_config(self, agent_id: str, config: Mapping[str, Any]) -> None:
+        self._configs[agent_id] = dict(config)
+        record = self._agents.get(agent_id)
+        if record:
+            record.status = "provisioned"
+
+    def get_agent_config(self, agent_id: str) -> Optional[Dict[str, Any]]:
+        config = self._configs.get(agent_id)
+        if config is None:
+            return None
+        return dict(config)
+
+    # ------------------------------------------------------------------
+    # Screenshot caches
+    # ------------------------------------------------------------------
+    def store_agent_screenshot(self, agent_id: str, image_base64: str) -> None:
+        self._screenshots[agent_id] = {
+            "image_base64": image_base64,
+            "timestamp": time.time(),
+        }
+
+    def store_task_screenshot(self, agent_id: str, node_id: str, image_base64: str) -> None:
+        self._task_screenshots[(agent_id, node_id)] = {
+            "image_base64": image_base64,
+            "timestamp": time.time(),
+        }
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+    @staticmethod
+    def normalize_service_mode(value: Optional[str], agent_id: Optional[str] = None) -> str:
+        text = ""
+        try:
+            if isinstance(value, str):
+                text = value.strip().lower()
+        except Exception:
+            text = ""
+
+        if not text and agent_id:
+            try:
+                lowered = agent_id.lower()
+                if "-svc-" in lowered or lowered.endswith("-svc"):
+                    return "system"
+            except Exception:
+                pass
+
+        if text in {"system", "svc", "service", "system_service"}:
+            return "system"
+        if text in {"interactive", "currentuser", "user", "current_user"}:
+            return "currentuser"
+        return "currentuser"
+
+    @staticmethod
+    def _coerce_int(value: Any, *, default: int = 0) -> int:
+        try:
+            return int(value)
+        except Exception:
+            return default
+
+    @staticmethod
+    def _is_script_agent(agent_id: Optional[str]) -> bool:
+        try:
+            return bool(isinstance(agent_id, str) and agent_id.lower().endswith("-script"))
+        except Exception:
+            return False
+
+    @staticmethod
+    def _is_valid_interactive_user(candidate: Optional[str]) -> bool:
+        if not candidate:
+            return False
+        try:
+            text = str(candidate).strip()
+        except Exception:
+            return False
+        if not text:
+            return False
+        upper = text.upper()
+        if text.endswith("$"):
+            return False
+        if "NT AUTHORITY\\" in upper or "NT SERVICE\\" in upper:
+            return False
+        if upper.endswith("\\SYSTEM"):
+            return False
+        if upper.endswith("\\LOCAL SERVICE"):
+            return False
+        if upper.endswith("\\NETWORK SERVICE"):
+            return False
+        if upper == "ANONYMOUS LOGON":
+            return False
+        return True
+
+    @staticmethod
+    def _is_system_service_agent(agent_id: str, is_script_agent: bool) -> bool:
+        try:
+            lowered = agent_id.lower()
+        except Exception:
+            lowered = ""
+        if is_script_agent:
+            return False
+        return "-svc-" in lowered or lowered.endswith("-svc")
+
+    def _reconcile_hostname_collisions(
+        self,
+        *,
+        hostname: str,
+        agent_id: str,
+        incoming_mode: str,
+        is_script_agent: bool,
+        last_seen: int,
+    ) -> None:
+        transferred_config = False
+        for existing_id, info in list(self._agents.items()):
+            if existing_id == agent_id:
+                continue
+            if info.hostname != hostname:
+                continue
+            existing_mode = self.normalize_service_mode(info.service_mode, existing_id)
+            if existing_mode != incoming_mode:
+                continue
+            if is_script_agent and not info.is_script_agent:
+                self._persist_activity(
+                    hostname=hostname,
+                    last_seen=last_seen,
+                    agent_id=existing_id,
+                    operating_system=info.agent_operating_system,
+                )
+                return
+            if not transferred_config and existing_id in self._configs and agent_id not in self._configs:
+                self._configs[agent_id] = dict(self._configs[existing_id])
+                transferred_config = True
+            self._agents.pop(existing_id, None)
+            if existing_id != agent_id:
+                self._configs.pop(existing_id, None)
+
+    def _persist_activity(
+        self,
+        *,
+        hostname: Optional[str],
+        last_seen: Optional[int],
+        agent_id: Optional[str],
+        operating_system: Optional[str],
+        last_user: Optional[str] = None,
+    ) -> None:
+        if not hostname:
+            return
+        try:
+            self._device_repository.update_device_summary(
+                hostname=hostname,
+                last_seen=last_seen,
+                agent_id=agent_id,
+                operating_system=operating_system,
+                last_user=last_user,
+            )
+        except Exception as exc:  # pragma: no cover - defensive logging
+            self._log.debug("failed to persist device activity for %s: %s", hostname, exc)
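Note: the service-mode fallback is driven by the agent id when the payload omits a mode, e.g.:

    assert AgentRealtimeService.normalize_service_mode(None, "HOST1-svc-01") == "system"
    assert AgentRealtimeService.normalize_service_mode("Interactive", "HOST1") == "currentuser"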
diff --git a/Data/Engine/tests/__init__.py b/Data/Engine/tests/__init__.py
new file mode 100644
index 0000000..9840a5e
--- /dev/null
+++ b/Data/Engine/tests/__init__.py
@@ -0,0 +1 @@
+"""Test suite for the Borealis Engine."""
diff --git a/Data/Engine/tests/test_builders_device_auth.py b/Data/Engine/tests/test_builders_device_auth.py
new file mode 100644
index 0000000..39d70d0
--- /dev/null
+++ b/Data/Engine/tests/test_builders_device_auth.py
@@ -0,0 +1,74 @@
+import unittest
+
+from Data.Engine.builders.device_auth import (
+    DeviceAuthRequestBuilder,
+    RefreshTokenRequestBuilder,
+)
+from Data.Engine.domain.device_auth import DeviceAuthErrorCode, DeviceAuthFailure
+
+
+class DeviceAuthRequestBuilderTests(unittest.TestCase):
+    def test_build_successful_request(self) -> None:
+        request = (
+            DeviceAuthRequestBuilder()
+            .with_authorization("Bearer abc123")
+            .with_http_method("post")
+            .with_htu("https://example.test/api")
+            .with_service_context("currentUser")
+            .with_dpop_proof("proof")
+            .build()
+        )
+
+        self.assertEqual(request.access_token, "abc123")
+        self.assertEqual(request.http_method, "POST")
+        self.assertEqual(request.htu, "https://example.test/api")
+        self.assertEqual(request.service_context, "CURRENTUSER")
+        self.assertEqual(request.dpop_proof, "proof")
+
+    def test_missing_authorization_raises_failure(self) -> None:
+        builder = (
+            DeviceAuthRequestBuilder()
+            .with_http_method("GET")
+            .with_htu("/health")
+        )
+
+        with self.assertRaises(DeviceAuthFailure) as ctx:
+            builder.build()
+
+        self.assertEqual(ctx.exception.code, DeviceAuthErrorCode.MISSING_AUTHORIZATION)
+
+
+class RefreshTokenRequestBuilderTests(unittest.TestCase):
+    def test_refresh_request_requires_all_fields(self) -> None:
+        request = (
+            RefreshTokenRequestBuilder()
+            .with_payload({"guid": "de305d54-75b4-431b-adb2-eb6b9e546014", "refresh_token": "tok"})
+            .with_http_method("post")
+            .with_htu("https://example.test/api")
+            .with_dpop_proof("proof")
+            .build()
+        )
+
+        self.assertEqual(request.guid.value, "DE305D54-75B4-431B-ADB2-EB6B9E546014")
+        self.assertEqual(request.refresh_token, "tok")
+        self.assertEqual(request.http_method, "POST")
+        self.assertEqual(request.htu, "https://example.test/api")
+        self.assertEqual(request.dpop_proof, "proof")
+
+    def test_refresh_request_missing_guid_raises_failure(self) -> None:
+        builder = (
+            RefreshTokenRequestBuilder()
+            .with_payload({"refresh_token": "tok"})
+            .with_http_method("POST")
+            .with_htu("https://example.test/api")
+        )
+
+        with self.assertRaises(DeviceAuthFailure) as ctx:
+            builder.build()
+
+        self.assertEqual(ctx.exception.code, DeviceAuthErrorCode.INVALID_CLAIMS)
+        self.assertIn("missing guid", ctx.exception.detail)
+
+
+if __name__ == "__main__":  # pragma: no cover - convenience for local runs
+    unittest.main()
diff --git a/Data/Engine/tests/test_config_environment.py b/Data/Engine/tests/test_config_environment.py
new file mode 100644
index 0000000..b5b46a1
--- /dev/null
+++ b/Data/Engine/tests/test_config_environment.py
@@ -0,0 +1,44 @@
+"""Tests for environment configuration helpers."""
+
+from __future__ import annotations
+
+from Data.Engine.config.environment import load_environment
+
+
+def test_static_root_prefers_engine_runtime(tmp_path, monkeypatch):
+    """Engine static root should prefer the staged web-interface build."""
+
+    engine_build = tmp_path / "Engine" / "web-interface" / "build"
+    engine_build.mkdir(parents=True)
+    (engine_build / "index.html").write_text("", encoding="utf-8")
+
+    # Ensure other fallbacks exist but should not be selected while the Engine
+    # runtime assets are present.
+    legacy_build = tmp_path / "Data" / "Server" / "WebUI" / "build"
+    legacy_build.mkdir(parents=True)
+    (legacy_build / "index.html").write_text("legacy", encoding="utf-8")
+
+    monkeypatch.setenv("BOREALIS_ROOT", str(tmp_path))
+    monkeypatch.delenv("BOREALIS_STATIC_ROOT", raising=False)
+
+    settings = load_environment()
+
+    assert settings.flask.static_root == engine_build.resolve()
+
+
+def test_static_root_env_override(tmp_path, monkeypatch):
+    """Explicit overrides should win over filesystem detection."""
+
+    override = tmp_path / "custom" / "build"
+    override.mkdir(parents=True)
+    (override / "index.html").write_text("override", encoding="utf-8")
+
+    monkeypatch.setenv("BOREALIS_ROOT", str(tmp_path))
+    monkeypatch.setenv("BOREALIS_STATIC_ROOT", str(override))
+
+    settings = load_environment()
+
+    assert settings.flask.static_root == override.resolve()
+
+    monkeypatch.delenv("BOREALIS_STATIC_ROOT", raising=False)
+    monkeypatch.delenv("BOREALIS_ROOT", raising=False)
diff --git a/Data/Engine/tests/test_crypto_certificates.py b/Data/Engine/tests/test_crypto_certificates.py
new file mode 100644
index 0000000..4fa2fa7
--- /dev/null
+++ b/Data/Engine/tests/test_crypto_certificates.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+import importlib
+import os
+import shutil
+import ssl
+import sys
+import tempfile
+import unittest
+from pathlib import Path
+
+from Data.Engine import runtime
+
+
+class CertificateGenerationTests(unittest.TestCase):
+    def setUp(self) -> None:
+        self._tmpdir = Path(tempfile.mkdtemp(prefix="engine-cert-tests-"))
+        self.addCleanup(lambda: shutil.rmtree(self._tmpdir, ignore_errors=True))
+
+        self._previous_env: dict[str, str | None] = {}
+        for name in ("BOREALIS_CERTIFICATES_ROOT", "BOREALIS_SERVER_CERT_ROOT"):
+            self._previous_env[name] = os.environ.get(name)
+            os.environ[name] = str(self._tmpdir / name.lower())
+
+        runtime.certificates_root.cache_clear()
+        runtime.server_certificates_root.cache_clear()
+
+        module_name = "Data.Engine.services.crypto.certificates"
+        if module_name in sys.modules:
+            del sys.modules[module_name]
+
+        try:
+            self.certificates = importlib.import_module(module_name)
+        except ModuleNotFoundError as exc:  # pragma: no cover - optional deps absent
+            self.skipTest(f"cryptography dependency unavailable: {exc}")
+
+    def tearDown(self) -> None:  # pragma: no cover - environment cleanup
+        for name, value in self._previous_env.items():
+            if value is None:
+                os.environ.pop(name, None)
+            else:
+                os.environ[name] = value
+        runtime.certificates_root.cache_clear()
+        runtime.server_certificates_root.cache_clear()
+
+    def test_ensure_certificate_creates_material(self) -> None:
+        cert_path, key_path, bundle_path = self.certificates.ensure_certificate()
+
+        self.assertTrue(cert_path.exists(), "certificate was not generated")
+        self.assertTrue(key_path.exists(), "private key was not generated")
+        self.assertTrue(bundle_path.exists(), "bundle was not generated")
+
+        context = self.certificates.build_ssl_context()
+        self.assertIsInstance(context, ssl.SSLContext)
+        self.assertEqual(context.minimum_version, ssl.TLSVersion.TLSv1_3)
+
+    def test_certificate_paths_returns_strings(self) -> None:
+        cert_path, key_path, bundle_path = self.certificates.certificate_paths()
+        self.assertIsInstance(cert_path, str)
+        self.assertIsInstance(key_path, str)
+        self.assertIsInstance(bundle_path, str)
+
+
+if __name__ == "__main__":  # pragma: no cover - convenience
+    unittest.main()
diff --git a/Data/Engine/tests/test_domain_device_auth.py b/Data/Engine/tests/test_domain_device_auth.py
new file mode 100644
index 0000000..bcd7e9f
--- /dev/null
+++ b/Data/Engine/tests/test_domain_device_auth.py
@@ -0,0 +1,59 @@
+import unittest
+
+from Data.Engine.domain.device_auth import (
+    DeviceAuthErrorCode,
+    DeviceAuthFailure,
+    DeviceFingerprint,
+    DeviceGuid,
+    sanitize_service_context,
+)
+
+
+class DeviceGuidTests(unittest.TestCase):
+    def test_guid_normalization_accepts_braces_and_lowercase(self) -> None:
+        guid = DeviceGuid("{de305d54-75b4-431b-adb2-eb6b9e546014}")
+        self.assertEqual(guid.value, "DE305D54-75B4-431B-ADB2-EB6B9E546014")
+
+    def test_guid_rejects_empty_string(self) -> None:
+        with self.assertRaises(ValueError):
+            DeviceGuid("")
+
+
+class DeviceFingerprintTests(unittest.TestCase):
+    def test_fingerprint_normalization_trims_and_lowercases(self) -> None:
+        fingerprint = DeviceFingerprint(" AA:BB:CC ")
+        self.assertEqual(fingerprint.value, "aa:bb:cc")
+
+    def test_fingerprint_rejects_blank_input(self) -> None:
+        with self.assertRaises(ValueError):
+            DeviceFingerprint(" ")
+
+
+class ServiceContextTests(unittest.TestCase):
+    def test_sanitize_service_context_returns_uppercase_only(self) -> None:
+        self.assertEqual(sanitize_service_context("system"), "SYSTEM")
+
+    def test_sanitize_service_context_filters_invalid_chars(self) -> None:
+        self.assertEqual(sanitize_service_context("sys tem!"), "SYSTEM")
+
+    def test_sanitize_service_context_returns_none_for_empty_result(self) -> None:
+        self.assertIsNone(sanitize_service_context("@@@"))
+
+
+class DeviceAuthFailureTests(unittest.TestCase):
+    def test_to_dict_includes_retry_after_and_detail(self) -> None:
+        failure = DeviceAuthFailure(
+            DeviceAuthErrorCode.RATE_LIMITED,
+            http_status=429,
+            retry_after=30,
+            detail="too many attempts",
+        )
+        payload = failure.to_dict()
+        self.assertEqual(
+            payload,
+            {"error": "rate_limited", "retry_after": 30.0, "detail": "too many attempts"},
+        )
+
+
+if __name__ == "__main__":  # pragma: no cover - convenience for local runs
+    unittest.main()
diff --git a/Data/Engine/tests/test_sqlite_migrations.py b/Data/Engine/tests/test_sqlite_migrations.py
new file mode 100644
index 0000000..6361616
--- /dev/null
+++ b/Data/Engine/tests/test_sqlite_migrations.py
@@ -0,0 +1,32 @@
+import sqlite3
+import unittest
+
+from Data.Engine.repositories.sqlite import migrations
+
+
+class MigrationTests(unittest.TestCase):
+    def test_apply_all_creates_expected_tables(self) -> None:
+        conn = sqlite3.connect(":memory:")
+        try:
+            migrations.apply_all(conn)
+            cursor = conn.cursor()
+            tables = {
+                row[0]
+                for row in cursor.execute(
+                    "SELECT name FROM sqlite_master WHERE type='table'"
+                )
+            }
+
+            self.assertIn("devices", tables)
+            self.assertIn("refresh_tokens", tables)
+            self.assertIn("enrollment_install_codes", tables)
+            self.assertIn("device_approvals", tables)
+            self.assertIn("scheduled_jobs", tables)
+            self.assertIn("scheduled_job_runs", tables)
+            self.assertIn("github_token", tables)
+        finally:
+            conn.close()
+
+
+if __name__ == "__main__":  # pragma: no cover - convenience for local runs
+    unittest.main()
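Note: test_config_environment.py relies on pytest fixtures (`tmp_path`, `monkeypatch`) while the other modules are unittest cases, so the whole suite is presumably run with pytest from the repository root (assumption; adjust the path if the package root differs):

    python -m pytest Data/Engine/tests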