mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 13:21:57 -06:00
Removed Experimental Engine
This commit is contained in:
@@ -1,67 +0,0 @@
|
||||
# Borealis Engine Migration Progress
|
||||
|
||||
[COMPLETED] 1. Stabilize Engine foundation (baseline already in repo)
|
||||
- 1.1 Confirm `Data/Engine/bootstrapper.py` launches the placeholder Flask app without side effects.
|
||||
- 1.2 Document environment variables/settings expected by the Engine to keep parity with legacy defaults.
|
||||
- 1.3 Verify Engine logging produces `Logs/Server/engine.log` entries alongside the legacy server.
|
||||
|
||||
[COMPLETED] 2. Introduce configuration & dependency wiring
|
||||
- 2.1 Create `config/environment.py` loaders mirroring legacy defaults (TLS paths, feature flags).
|
||||
- 2.2 Add settings dataclasses for Flask, Socket.IO, and DB paths; inject them via `server.py`.
|
||||
- 2.3 Commit once the Engine can start with equivalent config but no real routes.
|
||||
|
||||
[COMPLETED] 3. Copy Flask application scaffolding
|
||||
- 3.1 Port proxy/CORS/static setup from `Data/Server/server.py` into Engine `server.py` using dependency injection.
|
||||
- 3.2 Stub out blueprint/Socket.IO registration hooks that mirror names from legacy code (no logic yet).
|
||||
- 3.3 Smoke-test app startup via `python Data/Engine/bootstrapper.py` (or Flask CLI) to ensure no regressions.
|
||||
|
||||
[COMPLETED] 4. Establish SQLite infrastructure
|
||||
- 4.1 Copy `_db_conn` logic into `repositories/sqlite/connection.py`, parameterized by database path (`<root>/database.db`).
|
||||
- 4.2 Port migration helpers into `repositories/sqlite/migrations.py`; expose an `apply_all()` callable.
|
||||
- 4.3 Wire migrations to run during Engine bootstrap (behind a flag) and confirm tables initialize in a sandbox DB.
|
||||
- 4.4 Commit once DB connection + migrations succeed independently of legacy server.
|
||||
|
||||
[COMPLETED] 5. Extract authentication/enrollment domain surface
|
||||
- 5.1 Define immutable dataclasses in `domain/device_auth.py`, `domain/device_enrollment.py` for tokens, GUIDs, approvals.
|
||||
- 5.2 Map legacy error codes/enums into domain exceptions or enums in the same modules.
|
||||
- 5.3 Commit after unit tests (or doctests) validate dataclass invariants.
|
||||
|
||||
[COMPLETED] 6. Port authentication services
|
||||
- 6.1 Copy `DeviceAuthManager` logic into `services/auth/device_auth_service.py`, refactoring to use new repositories and domain types.
|
||||
- 6.2 Create `builders/device_auth.py` to assemble `DeviceAuthContext` from headers/DPoP proof.
|
||||
- 6.3 Mirror refresh token issuance into `services/auth/token_service.py`; use `builders/device_enrollment.py` for payload assembly.
|
||||
- 6.4 Commit once services pass targeted unit tests and integrate with placeholder repositories.
|
||||
|
||||
[COMPLETED] 7. Implement SQLite repositories
|
||||
- 7.1 Introduce `repositories/sqlite/device_repository.py`, `token_repository.py`, `enrollment_repository.py` using copied SQL.
|
||||
- 7.2 Write integration tests exercising CRUD against a temporary SQLite file.
|
||||
- 7.3 Commit when repositories provide the required ports used by services.
|
||||
|
||||
[COMPLETED] 8. Recreate HTTP interfaces
|
||||
- 8.1 Port health/enrollment/token blueprints into `interfaces/http/<feature>/routes.py`, calling Engine services only.
|
||||
- 8.2 Ensure request validation occurs via builders; response schemas stay aligned with legacy JSON.
|
||||
- 8.3 Register blueprints through Engine `server.py`; confirm endpoints respond via manual or automated tests.
|
||||
- 8.4 Commit after each major blueprint migration for clear milestones.
|
||||
|
||||
[COMPLETED] 9. Rebuild WebSocket interfaces
|
||||
- 9.1 Establish feature-scoped modules (e.g., `interfaces/ws/agents/events.py`) and copy event handlers.
|
||||
- 9.2 Replace global state with repository/service calls where feasible; otherwise encapsulate in Engine-managed caches.
|
||||
- 9.3 Validate namespace registration with Socket.IO test clients before committing.
|
||||
|
||||
[COMPLETED] 10. Scheduler & job management
|
||||
- 10.1 Port scheduler core into `services/jobs/scheduler_service.py`; wrap job state persistence via new repositories.
|
||||
- 10.2 Implement `builders/job_fabricator.py` for manifest assembly; ensure immutability and validation.
|
||||
- 10.3 Expose HTTP orchestration via `interfaces/http/job_management.py` and WS notifications via dedicated modules.
|
||||
- 10.4 Commit after scheduler can run a no-op job loop independently.
|
||||
|
||||
[COMPLETED] 11. GitHub integration
|
||||
- 11.1 Copy GitHub helper logic into `integrations/github/artifact_provider.py` with proper configuration injection.
|
||||
- 11.2 Provide repository/service hooks for fetching artifacts or repo heads; add resilience logging.
|
||||
- 11.3 Commit after integration tests (or mocked unit tests) confirm API workflows.
|
||||
|
||||
[COMPLETE] 12. Final parity verification
|
||||
- 12.1 Follow the staging playbook in `Data/Engine/STAGING_GUIDE.md` to stand up the Engine end-to-end and exercise enrollment, token refresh, agent connections, GitHub integration, and scheduler flows.
|
||||
- 12.2 Record any divergences in the staging guide’s table and address them with follow-up commits before cut-over.
|
||||
- 12.3 Once parity is confirmed, coordinate entrypoint switching (point deployment at `Data/Engine/bootstrapper.py`) and plan the legacy server deprecation.
|
||||
- Supporting documentation and unit tests live in `Data/Engine/README.md`, `Data/Engine/STAGING_GUIDE.md`, and `Data/Engine/tests/` to guide the remaining staging work.
|
||||
- Engine deployment now installs dependencies via `Data/Engine/requirements.txt` so parity runs include Flask, Socket.IO, and security packages.
|
||||
@@ -1,210 +0,0 @@
|
||||
# Borealis Engine Overview
|
||||
|
||||
The Engine is an additive server stack that will ultimately replace the legacy Flask app under `Data/Server`. It is safe to run the Engine entrypoint (`Data/Engine/bootstrapper.py`) side-by-side with the legacy server while we migrate functionality feature-by-feature.
|
||||
|
||||
## Architectural roles
|
||||
|
||||
The Engine is organized around explicit dependency layers so each concern stays
|
||||
testable and replaceable:
|
||||
|
||||
- **Configuration (`Data/Engine/config/`)** parses environment variables into
|
||||
immutable settings objects that the bootstrapper hands to factories and
|
||||
integrations.
|
||||
- **Builders (`Data/Engine/builders/`)** transform external inputs (HTTP
|
||||
headers, JSON payloads, scheduled job definitions) into validated immutable
|
||||
records that services can trust.
|
||||
- **Domain models (`Data/Engine/domain/`)** house pure value objects, enums, and
|
||||
error types with no I/O so services can express intent without depending on
|
||||
Flask or SQLite.
|
||||
- **Repositories (`Data/Engine/repositories/`)** encapsulate all SQLite access
|
||||
and expose protocol methods that return domain models. They are injected into
|
||||
services through the container so persistence can be swapped or mocked.
|
||||
- **Services (`Data/Engine/services/`)** host business logic such as device
|
||||
authentication, enrollment, job scheduling, GitHub artifact lookups, and
|
||||
real-time agent coordination. Services depend only on repositories,
|
||||
integrations, and builders.
|
||||
- **Integrations (`Data/Engine/integrations/`)** wrap external systems (GitHub
|
||||
today) and keep HTTP/token handling outside the services that consume them.
|
||||
- **Interfaces (`Data/Engine/interfaces/`)** provide thin HTTP/Socket.IO
|
||||
adapters that translate requests to builder/service calls and serialize
|
||||
responses. They contain no business rules of their own.
|
||||
|
||||
The runtime factory (`Data/Engine/runtime.py`) wires these layers together and
|
||||
attaches the resulting container to the Flask app created in
|
||||
`Data/Engine/server.py`.
|
||||
|
||||
## Environment configuration
|
||||
|
||||
The Engine mirrors the legacy defaults so it can boot without additional configuration. These environment variables are read by `Data/Engine/config/environment.py`:
|
||||
|
||||
| Variable | Purpose | Default |
|
||||
| --- | --- | --- |
|
||||
| `BOREALIS_ROOT` | Overrides automatic project root detection. Useful when running from a packaged location. | Directory two levels above `Data/Engine/` |
|
||||
| `BOREALIS_DATABASE_PATH` | Path to the SQLite database. | `<project_root>/database.db` |
|
||||
| `BOREALIS_ENGINE_AUTO_MIGRATE` | Run Engine-managed schema migrations during bootstrap (`true`/`false`). | `true` |
|
||||
| `BOREALIS_STATIC_ROOT` | Directory that serves static assets for the SPA. | First existing path among `Engine/web-interface/build`, `Engine/web-interface/dist`, `Data/Engine/web-interface/build`, `Data/Server/WebUI/build`, `Data/Server/web-interface/build`, `Data/WebUI/build` |
|
||||
| `BOREALIS_CORS_ALLOWED_ORIGINS` | Comma-delimited list of origins granted CORS access. Use `*` for all origins. | `*` |
|
||||
| `BOREALIS_FLASK_SECRET_KEY` | Secret key for Flask session signing. | `change-me` |
|
||||
| `BOREALIS_DEBUG` | Enables debug logging, disables secure-cookie requirements, and allows Werkzeug debug mode. | `false` |
|
||||
| `BOREALIS_HOST` | Bind address for the HTTP/Socket.IO server. | `127.0.0.1` |
|
||||
| `BOREALIS_PORT` | Bind port for the HTTP/Socket.IO server. | `5000` |
|
||||
| `BOREALIS_REPO` | Default GitHub repository (`owner/name`) for artifact lookups. | `bunny-lab-io/Borealis` |
|
||||
| `BOREALIS_REPO_BRANCH` | Default branch tracked by the Engine GitHub integration. | `main` |
|
||||
| `BOREALIS_REPO_HASH_REFRESH` | Seconds between default repository head refresh attempts (clamped 30-3600). | `60` |
|
||||
| `BOREALIS_CACHE_DIR` | Directory used to persist Engine cache files (GitHub repo head cache). | `<project_root>/Data/Engine/cache` |
|
||||
| `BOREALIS_CERTIFICATES_ROOT` | Overrides where TLS certificates (root CA + leaf) are stored. | `<project_root>/Certificates` |
|
||||
| `BOREALIS_SERVER_CERT_ROOT` | Directly points to the Engine server certificate directory if certificates are staged elsewhere. | `<project_root>/Certificates/Server` |
|
||||
|
||||
The launch scripts (`Borealis.ps1` / `Borealis.sh`) automatically synchronize
|
||||
`Data/Server/WebUI` into `Data/Engine/web-interface` when the Engine’s copy is
|
||||
missing. The repository keeps that directory mostly empty (except for
|
||||
documentation) so Git history does not duplicate the large SPA payload, but the
|
||||
runtime staging still ensures Vite reads from the Engine tree.
|
||||
|
||||
## TLS and transport stack
|
||||
|
||||
`Data/Engine/services/crypto/certificates.py` mirrors the legacy certificate
|
||||
generator so the Engine always serves HTTPS with a self-managed root CA and
|
||||
leaf certificate. During bootstrap the Engine:
|
||||
|
||||
1. Runs the certificate helper to ensure the root CA, server key, and bundle
|
||||
exist under `Certificates/Server/` (or the configured override path).
|
||||
2. Exposes the resulting bundle via `BOREALIS_TLS_BUNDLE` so enrollment flows
|
||||
can deliver the pinned certificate to agents.
|
||||
3. Launches Socket.IO/Eventlet with the generated cert/key pair. A fallback to
|
||||
Werkzeug’s TLS support keeps HTTPS available even if Socket.IO is disabled.
|
||||
|
||||
`Data/Engine/interfaces/eventlet_compat.py` applies the same Eventlet monkey
|
||||
patch as the legacy server so TLS handshakes presented to the HTTP listener are
|
||||
handled quietly instead of surfacing `400 Bad Request` noise when non-TLS
|
||||
clients connect.
|
||||
|
||||
## Logging expectations
|
||||
|
||||
`Data/Engine/config/logging.py` configures a timed rotating file handler that writes to `Logs/Server/engine.log`. Each entry follows the `<timestamp>-engine-<message>` format required by the project logging policy. The handler is attached to both the Engine logger (`borealis.engine`) and the root logger so that third-party frameworks share the same log destination.
|
||||
|
||||
## Bootstrapping flow
|
||||
|
||||
1. `Data/Engine/bootstrapper.py` loads the environment, configures logging, prepares the SQLite connection factory, optionally applies schema migrations, and builds the Flask application via `Data/Engine/server.py`.
|
||||
2. A service container is assembled (`Data/Engine/services/container.py`) that wires repositories, JWT/DPoP helpers, and Engine services (device auth, token refresh, enrollment). The container is stored on the Flask app for interface modules to consume.
|
||||
3. HTTP and Socket.IO interfaces register against the new service container. The resulting runtime object exposes the Flask app, resolved settings, optional Socket.IO server, and the configured database connection factory. `bootstrapper.main()` runs the appropriate server based on whether Socket.IO is present.
|
||||
|
||||
As migration continues, services, repositories, interfaces, and integrations will live under their respective subpackages while maintaining isolation from the legacy server.
|
||||
|
||||
## Python dependencies
|
||||
|
||||
`Data/Engine/requirements.txt` mirrors the minimal runtime stack (Flask, Flask-SocketIO, CORS, requests, PyJWT, and cryptography) needed by the Engine entrypoint. The PowerShell launcher consumes this file when preparing the `Engine/` virtual environment so parity tests always run against an environment with the expected web and security packages preinstalled.
|
||||
|
||||
## HTTP interfaces
|
||||
|
||||
The Engine now exposes working HTTP routes alongside the remaining scaffolding:
|
||||
|
||||
- `Data/Engine/interfaces/http/health.py` implements `GET /health` for liveness probes.
|
||||
- `Data/Engine/interfaces/http/tokens.py` ports the refresh-token endpoint (`POST /api/agent/token/refresh`) using the Engine `TokenService` and request builders.
|
||||
- `Data/Engine/interfaces/http/enrollment.py` handles the enrollment handshake (`/api/agent/enroll/request` and `/api/agent/enroll/poll`) with rate limiting, nonce protection, and repository-backed approvals.
|
||||
- The admin and agent blueprints remain placeholders until their services migrate.
|
||||
|
||||
## WebSocket interfaces
|
||||
|
||||
Step 9 introduces real-time handlers backed by the new service container:
|
||||
|
||||
- `Data/Engine/services/realtime/agent_registry.py` manages connected-agent state, last-seen persistence, collector updates, and screenshot caches without sharing globals with the legacy server.
|
||||
- `Data/Engine/interfaces/ws/agents/events.py` ports the agent namespace, handling connect/disconnect logging, heartbeat reconciliation, screenshot relays, macro status broadcasts, and provisioning lookups through the realtime service.
|
||||
- `Data/Engine/interfaces/ws/job_management/events.py` now forwards scheduler updates and responds to job status requests, keeping WebSocket clients informed as new runs are simulated.
|
||||
|
||||
The WebSocket factory (`Data/Engine/interfaces/ws/__init__.py`) now accepts the Engine service container so namespaces can resolve dependencies just like their HTTP counterparts.
|
||||
|
||||
## Authentication services
|
||||
|
||||
Step 6 introduces the first real Engine services:
|
||||
|
||||
- `Data/Engine/builders/device_auth.py` normalizes headers for access-token authentication and token refresh payloads.
|
||||
- `Data/Engine/builders/device_enrollment.py` prepares enrollment payloads and nonce proof challenges for future migration steps.
|
||||
- `Data/Engine/services/auth/device_auth_service.py` ports the legacy `DeviceAuthManager` into a repository-driven service that emits `DeviceAuthContext` instances from the new domain layer.
|
||||
- `Data/Engine/services/auth/token_service.py` issues refreshed access tokens while enforcing DPoP bindings and repository lookups.
|
||||
|
||||
Interfaces now consume these services via the shared container, keeping business logic inside the Engine service layer while HTTP modules remain thin request/response translators.
|
||||
|
||||
## SQLite repositories
|
||||
|
||||
Step 7 ports the first persistence adapters into the Engine:
|
||||
|
||||
- `Data/Engine/repositories/sqlite/device_repository.py` exposes `SQLiteDeviceRepository`, mirroring the legacy device lookups and automatic record recovery used during authentication.
|
||||
- `Data/Engine/repositories/sqlite/token_repository.py` provides `SQLiteRefreshTokenRepository` for refresh-token validation, DPoP binding management, and usage timestamps.
|
||||
- `Data/Engine/repositories/sqlite/enrollment_repository.py` surfaces enrollment install-code counters and device approval records so future services can operate without touching raw SQL.
|
||||
|
||||
Each repository accepts the shared `SQLiteConnectionFactory`, keeping all SQL execution confined to the Engine layer while services depend only on protocol interfaces.
|
||||
|
||||
## Job scheduling services
|
||||
|
||||
Step 10 migrates the foundational job scheduler into the Engine:
|
||||
|
||||
- `Data/Engine/builders/job_fabricator.py` transforms stored job definitions into immutable manifests, decoding scripts, resolving environment variables, and preparing execution metadata.
|
||||
- `Data/Engine/repositories/sqlite/job_repository.py` encapsulates scheduled job persistence, run history, and status tracking in SQLite.
|
||||
- `Data/Engine/services/jobs/scheduler_service.py` runs the background evaluation loop, emits Socket.IO lifecycle events, and exposes CRUD helpers for the HTTP and WebSocket interfaces.
|
||||
- `Data/Engine/interfaces/http/job_management.py` mirrors the legacy REST surface for creating, updating, toggling, and inspecting scheduled jobs and their run history.
|
||||
|
||||
The scheduler service starts automatically from `Data/Engine/bootstrapper.py` once the Engine runtime builds the service container, ensuring a no-op scheduling loop executes independently of the legacy server.
|
||||
|
||||
## GitHub integration
|
||||
|
||||
Step 11 migrates the GitHub artifact provider into the Engine:
|
||||
|
||||
- `Data/Engine/integrations/github/artifact_provider.py` caches branch head lookups, verifies API tokens, and optionally refreshes the default repository in the background.
|
||||
- `Data/Engine/repositories/sqlite/github_repository.py` persists the GitHub API token so HTTP handlers do not speak to SQLite directly.
|
||||
- `Data/Engine/services/github/github_service.py` coordinates token caching, verification, and repo head lookups for both HTTP and background refresh flows.
|
||||
- `Data/Engine/interfaces/http/github.py` exposes `/api/repo/current_hash` and `/api/github/token` through the Engine stack while keeping business logic in the service layer.
|
||||
|
||||
The service container now wires `github_service`, giving other interfaces and background jobs a clean entry point for GitHub functionality.
|
||||
|
||||
## Final parity checklist
|
||||
|
||||
Step 12 tracks the final integration work required before switching over to the
|
||||
Engine entrypoint. Use the detailed playbook in
|
||||
[`Data/Engine/STAGING_GUIDE.md`](./STAGING_GUIDE.md) to coordinate each
|
||||
staging run:
|
||||
|
||||
1. Stand up the Engine in a staging environment and exercise enrollment, token
|
||||
refresh, scheduler operations, and the agent real-time channel side-by-side
|
||||
with the legacy server.
|
||||
2. Capture any behavioural differences uncovered during staging using the
|
||||
divergence table in the staging guide and file them for follow-up fixes
|
||||
before the cut-over.
|
||||
3. When satisfied with parity, coordinate the entrypoint swap (point production
|
||||
tooling at `Data/Engine/bootstrapper.py`) and plan the deprecation of
|
||||
`Data/Server`.
|
||||
|
||||
## Performing unit tests
|
||||
|
||||
Targeted unit tests cover the most important domain, builder, repository, and
|
||||
migration behaviours without requiring Flask or external services. Run them
|
||||
with the standard library test runner:
|
||||
|
||||
```bash
|
||||
python -m unittest discover Data/Engine/tests
|
||||
```
|
||||
|
||||
The suite currently validates:
|
||||
|
||||
- Domain normalization helpers for GUIDs, fingerprints, and authentication
|
||||
failures.
|
||||
- Device authentication and refresh-token builders, including error handling for
|
||||
malformed requests.
|
||||
- SQLite schema migrations to ensure the Engine can provision required tables in
|
||||
a fresh database.
|
||||
- TLS certificate provisioning helpers to guarantee HTTPS material exists before
|
||||
the Engine starts serving requests.
|
||||
|
||||
Successful execution prints a summary similar to:
|
||||
|
||||
```
|
||||
.............
|
||||
----------------------------------------------------------------------
|
||||
Ran 13 tests in <N>.<M>s
|
||||
|
||||
OK
|
||||
```
|
||||
|
||||
Additional tests should follow the same pattern and live under
|
||||
`Data/Engine/tests/` so this command remains the single entry point for Engine
|
||||
unit verification.
|
||||
@@ -1,116 +0,0 @@
|
||||
# Engine Staging & Parity Guide
|
||||
|
||||
This guide supports Step 12 of the migration plan by walking operators through
|
||||
standing up the Engine alongside the legacy server, validating core workflows,
|
||||
and documenting any behavioural gaps before switching the production entrypoint
|
||||
to `Data/Engine/bootstrapper.py`.
|
||||
|
||||
## 1. Prerequisites
|
||||
|
||||
- Python 3.11 or later available on the host.
|
||||
- A clone of the Borealis repository with the Engine tree checked out.
|
||||
- Access to the legacy runtime assets (certificates, TLS bundle, etc.).
|
||||
- Optional: a staging agent install for end-to-end WebSocket validation.
|
||||
|
||||
Ensure the SQLite database lives at `<project_root>/database.db` and that the
|
||||
Engine migrations have already run (they execute automatically when the
|
||||
`BOREALIS_ENGINE_AUTO_MIGRATE` environment variable is left at its default
|
||||
`true`).
|
||||
|
||||
## 2. Launching the Engine in staging mode
|
||||
|
||||
1. Open a terminal at the project root.
|
||||
2. Set any environment overrides required for the test scenario (for example,
|
||||
`BOREALIS_DEBUG=true` to surface verbose logging, or
|
||||
`BOREALIS_CORS_ALLOWED_ORIGINS=https://localhost:3000` when pairing with the
|
||||
React UI).
|
||||
3. Run the Engine entrypoint:
|
||||
|
||||
```bash
|
||||
python Data/Engine/bootstrapper.py
|
||||
```
|
||||
|
||||
4. Verify `Logs/Server/engine.log` is created and that the startup entries are
|
||||
timestamped `<timestamp>-engine-<message>`.
|
||||
|
||||
Keep the legacy server running in a separate process if comparative testing is
|
||||
required; they do not share global state.
|
||||
|
||||
## 3. Feature validation checklist
|
||||
|
||||
Work through the following areas and tick each box once verified. Capture any
|
||||
issues in the log table in §4.
|
||||
|
||||
### Authentication and tokens
|
||||
|
||||
- [ ] `POST /api/agent/token/refresh` returns a new access token when supplied a
|
||||
valid refresh token + DPoP proof.
|
||||
- [ ] Invalid DPoP proofs or revoked refresh tokens yield the expected HTTP 401
|
||||
responses and structured error payloads.
|
||||
- [ ] Device last-seen metadata updates inside the database after a successful
|
||||
refresh.
|
||||
|
||||
### Enrollment
|
||||
|
||||
- [ ] `POST /api/agent/enroll/request` produces an enrollment ticket with the
|
||||
correct expiration and retry counters.
|
||||
- [ ] `POST /api/agent/enroll/poll` transitions an approved device into an
|
||||
authenticated state and returns the TLS bundle.
|
||||
- [ ] Audit logging for approvals lands in `Logs/Server/engine.log`.
|
||||
|
||||
### Job management
|
||||
|
||||
- [ ] `POST /api/jobs` (or UI equivalent) creates a scheduled job and returns a
|
||||
manifest identifier.
|
||||
- [ ] `GET /api/jobs/<id>` surfaces the stored manifest with normalized
|
||||
schedules and environment variables.
|
||||
- [ ] Job lifecycle events arrive over the `job_management` Socket.IO namespace
|
||||
when a job transitions between `pending`, `running`, and `completed`.
|
||||
|
||||
### Real-time agents
|
||||
|
||||
- [ ] Agents connecting to the `agents` namespace appear in the realtime roster
|
||||
with accurate hostname, username, and fingerprint details.
|
||||
- [ ] Screenshot broadcasts relay from agents to the UI without residual cache
|
||||
bleed-through after disconnects.
|
||||
- [ ] Macro execution responses round-trip through Socket.IO and reach the
|
||||
initiating client.
|
||||
|
||||
### GitHub integration
|
||||
|
||||
- [ ] `GET /api/repo/current_hash` reflects the latest branch head and caches
|
||||
repeated calls.
|
||||
- [ ] `POST /api/github/token` persists a new token and survives Engine restarts
|
||||
(confirm via database inspection).
|
||||
- [ ] The background refresher logs rate-limit warnings instead of raising
|
||||
uncaught exceptions when the GitHub API throttles requests.
|
||||
|
||||
## 4. Recording divergences
|
||||
|
||||
Use the table below to document behavioural differences or bugs uncovered during
|
||||
staging. This artifact should accompany the staging run summary so follow-up
|
||||
fixes can be triaged quickly.
|
||||
|
||||
| Area | Legacy Behaviour | Engine Behaviour | Notes / Links |
|
||||
| --- | --- | --- | --- |
|
||||
| Authentication | | | |
|
||||
| Enrollment | | | |
|
||||
| Scheduler | | | |
|
||||
| Realtime | | | |
|
||||
| GitHub | | | |
|
||||
| Other | | | |
|
||||
|
||||
## 5. Cut-over readiness
|
||||
|
||||
Once every checklist item passes and no critical divergences remain:
|
||||
|
||||
1. Update `Data/Engine/CURRENT_STAGE.md` with the completion date for Step 12.
|
||||
2. Coordinate with the operator to switch deployment scripts to
|
||||
`Data/Engine/bootstrapper.py`.
|
||||
3. Plan a rollback strategy (typically re-launching the legacy server) should
|
||||
issues appear immediately after the cut-over.
|
||||
4. Archive the filled divergence table alongside Engine logs for historical
|
||||
traceability.
|
||||
|
||||
Document the results in project tracking tools before moving on to deprecating
|
||||
`Data/Server`.
|
||||
@@ -1,11 +0,0 @@
|
||||
"""Borealis Engine package.
|
||||
|
||||
This namespace contains the next-generation server implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = [
|
||||
"bootstrapper",
|
||||
"server",
|
||||
]
|
||||
@@ -1,118 +0,0 @@
|
||||
"""Entrypoint for the Borealis Engine server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from .config import EngineSettings, configure_logging, load_environment
|
||||
from .interfaces import (
|
||||
create_socket_server,
|
||||
register_http_interfaces,
|
||||
register_ws_interfaces,
|
||||
)
|
||||
from .interfaces.eventlet_compat import apply_eventlet_patches
|
||||
from .repositories.sqlite import connection as sqlite_connection
|
||||
from .repositories.sqlite import migrations as sqlite_migrations
|
||||
from .server import create_app
|
||||
from .services.container import build_service_container
|
||||
from .services.crypto.certificates import ensure_certificate
|
||||
|
||||
|
||||
apply_eventlet_patches()
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EngineRuntime:
|
||||
"""Aggregated runtime context produced by :func:`bootstrap`."""
|
||||
|
||||
app: Flask
|
||||
settings: EngineSettings
|
||||
socketio: Optional[object]
|
||||
db_factory: sqlite_connection.SQLiteConnectionFactory
|
||||
tls_certificate: Path
|
||||
tls_key: Path
|
||||
tls_bundle: Path
|
||||
|
||||
|
||||
def bootstrap() -> EngineRuntime:
|
||||
"""Construct the Flask application and supporting infrastructure."""
|
||||
|
||||
settings = load_environment()
|
||||
logger = configure_logging(settings)
|
||||
logger.info("bootstrap-started")
|
||||
|
||||
cert_path, key_path, bundle_path = ensure_certificate()
|
||||
os.environ.setdefault("BOREALIS_TLS_BUNDLE", str(bundle_path))
|
||||
logger.info(
|
||||
"tls-material-ready",
|
||||
extra={
|
||||
"cert_path": str(cert_path),
|
||||
"key_path": str(key_path),
|
||||
"bundle_path": str(bundle_path),
|
||||
},
|
||||
)
|
||||
|
||||
db_factory = sqlite_connection.connection_factory(settings.database_path)
|
||||
if settings.apply_migrations:
|
||||
logger.info("migrations-start")
|
||||
with sqlite_connection.connection_scope(settings.database_path) as conn:
|
||||
sqlite_migrations.apply_all(conn)
|
||||
logger.info("migrations-complete")
|
||||
else:
|
||||
logger.info("migrations-skipped")
|
||||
|
||||
with sqlite_connection.connection_scope(settings.database_path) as conn:
|
||||
sqlite_migrations.ensure_default_admin(conn)
|
||||
logger.info("default-admin-ensured")
|
||||
|
||||
app = create_app(settings, db_factory=db_factory)
|
||||
services = build_service_container(settings, db_factory=db_factory, logger=logger.getChild("services"))
|
||||
app.extensions["engine_services"] = services
|
||||
register_http_interfaces(app, services)
|
||||
|
||||
socketio = create_socket_server(app, settings.socketio)
|
||||
register_ws_interfaces(socketio, services)
|
||||
services.scheduler_service.start(socketio)
|
||||
logger.info("bootstrap-complete")
|
||||
return EngineRuntime(
|
||||
app=app,
|
||||
settings=settings,
|
||||
socketio=socketio,
|
||||
db_factory=db_factory,
|
||||
tls_certificate=cert_path,
|
||||
tls_key=key_path,
|
||||
tls_bundle=bundle_path,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
runtime = bootstrap()
|
||||
socketio = runtime.socketio
|
||||
certfile = str(runtime.tls_bundle)
|
||||
keyfile = str(runtime.tls_key)
|
||||
|
||||
if socketio is not None:
|
||||
socketio.run( # type: ignore[call-arg]
|
||||
runtime.app,
|
||||
host=runtime.settings.server.host,
|
||||
port=runtime.settings.server.port,
|
||||
debug=runtime.settings.debug,
|
||||
certfile=certfile,
|
||||
keyfile=keyfile,
|
||||
)
|
||||
else:
|
||||
runtime.app.run(
|
||||
host=runtime.settings.server.host,
|
||||
port=runtime.settings.server.port,
|
||||
debug=runtime.settings.debug,
|
||||
ssl_context=(certfile, keyfile),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover - manual execution
|
||||
main()
|
||||
@@ -1,45 +0,0 @@
|
||||
"""Builder utilities for constructing immutable Engine aggregates."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .device_auth import (
|
||||
DeviceAuthRequest,
|
||||
DeviceAuthRequestBuilder,
|
||||
RefreshTokenRequest,
|
||||
RefreshTokenRequestBuilder,
|
||||
)
|
||||
from .operator_auth import (
|
||||
OperatorLoginRequest,
|
||||
OperatorMFAVerificationRequest,
|
||||
build_login_request,
|
||||
build_mfa_request,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthRequest",
|
||||
"DeviceAuthRequestBuilder",
|
||||
"RefreshTokenRequest",
|
||||
"RefreshTokenRequestBuilder",
|
||||
"OperatorLoginRequest",
|
||||
"OperatorMFAVerificationRequest",
|
||||
"build_login_request",
|
||||
"build_mfa_request",
|
||||
]
|
||||
|
||||
try: # pragma: no cover - optional dependency shim
|
||||
from .device_enrollment import (
|
||||
EnrollmentRequestBuilder,
|
||||
ProofChallengeBuilder,
|
||||
)
|
||||
except ModuleNotFoundError as exc: # pragma: no cover - executed when crypto deps missing
|
||||
_missing_reason = str(exc)
|
||||
|
||||
def _missing_builder(*_args: object, **_kwargs: object) -> None:
|
||||
raise ModuleNotFoundError(
|
||||
"device enrollment builders require optional cryptography dependencies"
|
||||
) from exc
|
||||
|
||||
EnrollmentRequestBuilder = _missing_builder # type: ignore[assignment]
|
||||
ProofChallengeBuilder = _missing_builder # type: ignore[assignment]
|
||||
else:
|
||||
__all__ += ["EnrollmentRequestBuilder", "ProofChallengeBuilder"]
|
||||
@@ -1,165 +0,0 @@
|
||||
"""Builders for device authentication and token refresh inputs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from Data.Engine.domain.device_auth import (
|
||||
DeviceAuthErrorCode,
|
||||
DeviceAuthFailure,
|
||||
DeviceGuid,
|
||||
sanitize_service_context,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthRequest",
|
||||
"DeviceAuthRequestBuilder",
|
||||
"RefreshTokenRequest",
|
||||
"RefreshTokenRequestBuilder",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceAuthRequest:
|
||||
"""Normalized authentication inputs derived from an HTTP request."""
|
||||
|
||||
access_token: str
|
||||
http_method: str
|
||||
htu: str
|
||||
service_context: Optional[str]
|
||||
dpop_proof: Optional[str]
|
||||
|
||||
|
||||
class DeviceAuthRequestBuilder:
|
||||
"""Validate and normalize HTTP headers for device authentication."""
|
||||
|
||||
_authorization: Optional[str]
|
||||
_http_method: Optional[str]
|
||||
_htu: Optional[str]
|
||||
_service_context: Optional[str]
|
||||
_dpop_proof: Optional[str]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._authorization = None
|
||||
self._http_method = None
|
||||
self._htu = None
|
||||
self._service_context = None
|
||||
self._dpop_proof = None
|
||||
|
||||
def with_authorization(self, header_value: Optional[str]) -> "DeviceAuthRequestBuilder":
|
||||
if header_value is None:
|
||||
self._authorization = None
|
||||
else:
|
||||
self._authorization = header_value.strip()
|
||||
return self
|
||||
|
||||
def with_http_method(self, method: Optional[str]) -> "DeviceAuthRequestBuilder":
|
||||
self._http_method = (method or "").strip().upper()
|
||||
return self
|
||||
|
||||
def with_htu(self, url: Optional[str]) -> "DeviceAuthRequestBuilder":
|
||||
self._htu = (url or "").strip()
|
||||
return self
|
||||
|
||||
def with_service_context(self, header_value: Optional[str]) -> "DeviceAuthRequestBuilder":
|
||||
self._service_context = sanitize_service_context(header_value)
|
||||
return self
|
||||
|
||||
def with_dpop_proof(self, proof: Optional[str]) -> "DeviceAuthRequestBuilder":
|
||||
self._dpop_proof = (proof or "").strip() or None
|
||||
return self
|
||||
|
||||
def build(self) -> DeviceAuthRequest:
|
||||
token = self._parse_authorization(self._authorization)
|
||||
method = (self._http_method or "").strip().upper()
|
||||
if not method:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN, detail="missing HTTP method")
|
||||
url = (self._htu or "").strip()
|
||||
if not url:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN, detail="missing request URL")
|
||||
return DeviceAuthRequest(
|
||||
access_token=token,
|
||||
http_method=method,
|
||||
htu=url,
|
||||
service_context=self._service_context,
|
||||
dpop_proof=self._dpop_proof,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_authorization(header_value: Optional[str]) -> str:
|
||||
header = (header_value or "").strip()
|
||||
if not header:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.MISSING_AUTHORIZATION)
|
||||
prefix = "Bearer "
|
||||
if not header.startswith(prefix):
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.MISSING_AUTHORIZATION)
|
||||
token = header[len(prefix) :].strip()
|
||||
if not token:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.MISSING_AUTHORIZATION)
|
||||
return token
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RefreshTokenRequest:
|
||||
"""Validated refresh token payload supplied by an agent."""
|
||||
|
||||
guid: DeviceGuid
|
||||
refresh_token: str
|
||||
http_method: str
|
||||
htu: str
|
||||
dpop_proof: Optional[str]
|
||||
|
||||
|
||||
class RefreshTokenRequestBuilder:
|
||||
"""Helper to normalize refresh token JSON payloads."""
|
||||
|
||||
_guid: Optional[str]
|
||||
_refresh_token: Optional[str]
|
||||
_http_method: Optional[str]
|
||||
_htu: Optional[str]
|
||||
_dpop_proof: Optional[str]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._guid = None
|
||||
self._refresh_token = None
|
||||
self._http_method = None
|
||||
self._htu = None
|
||||
self._dpop_proof = None
|
||||
|
||||
def with_payload(self, payload: Optional[dict[str, object]]) -> "RefreshTokenRequestBuilder":
|
||||
payload = payload or {}
|
||||
self._guid = str(payload.get("guid") or "").strip()
|
||||
self._refresh_token = str(payload.get("refresh_token") or "").strip()
|
||||
return self
|
||||
|
||||
def with_http_method(self, method: Optional[str]) -> "RefreshTokenRequestBuilder":
|
||||
self._http_method = (method or "").strip().upper()
|
||||
return self
|
||||
|
||||
def with_htu(self, url: Optional[str]) -> "RefreshTokenRequestBuilder":
|
||||
self._htu = (url or "").strip()
|
||||
return self
|
||||
|
||||
def with_dpop_proof(self, proof: Optional[str]) -> "RefreshTokenRequestBuilder":
|
||||
self._dpop_proof = (proof or "").strip() or None
|
||||
return self
|
||||
|
||||
def build(self) -> RefreshTokenRequest:
|
||||
if not self._guid:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_CLAIMS, detail="missing guid")
|
||||
if not self._refresh_token:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_CLAIMS, detail="missing refresh token")
|
||||
method = (self._http_method or "").strip().upper()
|
||||
if not method:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN, detail="missing HTTP method")
|
||||
url = (self._htu or "").strip()
|
||||
if not url:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN, detail="missing request URL")
|
||||
return RefreshTokenRequest(
|
||||
guid=DeviceGuid(self._guid),
|
||||
refresh_token=self._refresh_token,
|
||||
http_method=method,
|
||||
htu=url,
|
||||
dpop_proof=self._dpop_proof,
|
||||
)
|
||||
@@ -1,131 +0,0 @@
|
||||
"""Builder utilities for device enrollment payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from Data.Engine.domain.device_auth import DeviceFingerprint, sanitize_service_context
|
||||
from Data.Engine.domain.device_enrollment import (
|
||||
EnrollmentValidationError,
|
||||
ProofChallenge,
|
||||
)
|
||||
from Data.Engine.integrations.crypto import keys as crypto_keys
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentRequestBuilder",
|
||||
"EnrollmentRequestInput",
|
||||
"ProofChallengeBuilder",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentRequestInput:
|
||||
"""Structured enrollment request payload ready for the service layer."""
|
||||
|
||||
hostname: str
|
||||
enrollment_code: str
|
||||
fingerprint: DeviceFingerprint
|
||||
client_nonce: bytes
|
||||
client_nonce_b64: str
|
||||
agent_public_key_der: bytes
|
||||
service_context: Optional[str]
|
||||
|
||||
|
||||
class EnrollmentRequestBuilder:
|
||||
"""Normalize agent enrollment JSON payloads into domain objects."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._hostname: Optional[str] = None
|
||||
self._enrollment_code: Optional[str] = None
|
||||
self._agent_pubkey_b64: Optional[str] = None
|
||||
self._client_nonce_b64: Optional[str] = None
|
||||
self._service_context: Optional[str] = None
|
||||
|
||||
def with_payload(self, payload: Optional[dict[str, object]]) -> "EnrollmentRequestBuilder":
|
||||
payload = payload or {}
|
||||
self._hostname = str(payload.get("hostname") or "").strip()
|
||||
self._enrollment_code = str(payload.get("enrollment_code") or "").strip()
|
||||
agent_pubkey = payload.get("agent_pubkey")
|
||||
self._agent_pubkey_b64 = agent_pubkey if isinstance(agent_pubkey, str) else None
|
||||
client_nonce = payload.get("client_nonce")
|
||||
self._client_nonce_b64 = client_nonce if isinstance(client_nonce, str) else None
|
||||
return self
|
||||
|
||||
def with_service_context(self, value: Optional[str]) -> "EnrollmentRequestBuilder":
|
||||
self._service_context = value
|
||||
return self
|
||||
|
||||
def build(self) -> EnrollmentRequestInput:
|
||||
if not self._hostname:
|
||||
raise EnrollmentValidationError("hostname_required")
|
||||
if not self._enrollment_code:
|
||||
raise EnrollmentValidationError("enrollment_code_required")
|
||||
if not self._agent_pubkey_b64:
|
||||
raise EnrollmentValidationError("agent_pubkey_required")
|
||||
if not self._client_nonce_b64:
|
||||
raise EnrollmentValidationError("client_nonce_required")
|
||||
|
||||
try:
|
||||
agent_pubkey_der = crypto_keys.spki_der_from_base64(self._agent_pubkey_b64)
|
||||
except Exception as exc: # pragma: no cover - invalid input path
|
||||
raise EnrollmentValidationError("invalid_agent_pubkey") from exc
|
||||
|
||||
if len(agent_pubkey_der) < 10:
|
||||
raise EnrollmentValidationError("invalid_agent_pubkey")
|
||||
|
||||
try:
|
||||
client_nonce_bytes = base64.b64decode(self._client_nonce_b64, validate=True)
|
||||
except Exception as exc: # pragma: no cover - invalid input path
|
||||
raise EnrollmentValidationError("invalid_client_nonce") from exc
|
||||
|
||||
if len(client_nonce_bytes) < 16:
|
||||
raise EnrollmentValidationError("invalid_client_nonce")
|
||||
|
||||
fingerprint_value = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
|
||||
|
||||
return EnrollmentRequestInput(
|
||||
hostname=self._hostname,
|
||||
enrollment_code=self._enrollment_code,
|
||||
fingerprint=DeviceFingerprint(fingerprint_value),
|
||||
client_nonce=client_nonce_bytes,
|
||||
client_nonce_b64=self._client_nonce_b64,
|
||||
agent_public_key_der=agent_pubkey_der,
|
||||
service_context=sanitize_service_context(self._service_context),
|
||||
)
|
||||
|
||||
|
||||
class ProofChallengeBuilder:
|
||||
"""Construct proof challenges during enrollment approval."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._server_nonce: Optional[bytes] = None
|
||||
self._client_nonce: Optional[bytes] = None
|
||||
self._fingerprint: Optional[DeviceFingerprint] = None
|
||||
|
||||
def with_server_nonce(self, nonce: Optional[bytes]) -> "ProofChallengeBuilder":
|
||||
self._server_nonce = bytes(nonce or b"")
|
||||
return self
|
||||
|
||||
def with_client_nonce(self, nonce: Optional[bytes]) -> "ProofChallengeBuilder":
|
||||
self._client_nonce = bytes(nonce or b"")
|
||||
return self
|
||||
|
||||
def with_fingerprint(self, fingerprint: Optional[str]) -> "ProofChallengeBuilder":
|
||||
if fingerprint:
|
||||
self._fingerprint = DeviceFingerprint(fingerprint)
|
||||
else:
|
||||
self._fingerprint = None
|
||||
return self
|
||||
|
||||
def build(self) -> ProofChallenge:
|
||||
if self._server_nonce is None or self._client_nonce is None:
|
||||
raise ValueError("both server and client nonces are required")
|
||||
if not self._fingerprint:
|
||||
raise ValueError("fingerprint is required")
|
||||
return ProofChallenge(
|
||||
client_nonce=self._client_nonce,
|
||||
server_nonce=self._server_nonce,
|
||||
fingerprint=self._fingerprint,
|
||||
)
|
||||
@@ -1,382 +0,0 @@
|
||||
"""Builders for Engine job manifests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
from Data.Engine.repositories.sqlite.job_repository import ScheduledJobRecord
|
||||
|
||||
__all__ = [
|
||||
"JobComponentManifest",
|
||||
"JobManifest",
|
||||
"JobFabricator",
|
||||
]
|
||||
|
||||
|
||||
_ENV_VAR_PATTERN = re.compile(r"(?i)\$env:(\{)?([A-Za-z0-9_\-]+)(?(1)\})")
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class JobComponentManifest:
|
||||
"""Materialized job component ready for execution."""
|
||||
|
||||
name: str
|
||||
path: str
|
||||
script_type: str
|
||||
script_content: str
|
||||
encoded_content: str
|
||||
environment: Dict[str, str]
|
||||
literal_environment: Dict[str, str]
|
||||
timeout_seconds: int
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class JobManifest:
|
||||
job_id: int
|
||||
name: str
|
||||
occurrence_ts: int
|
||||
execution_context: str
|
||||
targets: Tuple[str, ...]
|
||||
components: Tuple[JobComponentManifest, ...]
|
||||
|
||||
|
||||
class JobFabricator:
|
||||
"""Convert stored job records into immutable manifests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
assemblies_root: Path,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._assemblies_root = assemblies_root
|
||||
self._log = logger or logging.getLogger("borealis.engine.builders.jobs")
|
||||
|
||||
def build(
|
||||
self,
|
||||
job: ScheduledJobRecord,
|
||||
*,
|
||||
occurrence_ts: int,
|
||||
) -> JobManifest:
|
||||
components = tuple(self._materialize_component(job, component) for component in job.components)
|
||||
targets = tuple(str(t) for t in job.targets)
|
||||
return JobManifest(
|
||||
job_id=job.id,
|
||||
name=job.name,
|
||||
occurrence_ts=occurrence_ts,
|
||||
execution_context=job.execution_context,
|
||||
targets=targets,
|
||||
components=tuple(c for c in components if c is not None),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Component handling
|
||||
# ------------------------------------------------------------------
|
||||
def _materialize_component(
|
||||
self,
|
||||
job: ScheduledJobRecord,
|
||||
component: Mapping[str, Any],
|
||||
) -> Optional[JobComponentManifest]:
|
||||
if not isinstance(component, Mapping):
|
||||
return None
|
||||
|
||||
component_type = str(component.get("type") or "").strip().lower()
|
||||
if component_type not in {"script", "ansible"}:
|
||||
return None
|
||||
|
||||
path = str(component.get("path") or component.get("script_path") or "").strip()
|
||||
if not path:
|
||||
return None
|
||||
|
||||
try:
|
||||
abs_path = self._resolve_script_path(path)
|
||||
except FileNotFoundError:
|
||||
self._log.warning(
|
||||
"job component path invalid", extra={"job_id": job.id, "path": path}
|
||||
)
|
||||
return None
|
||||
script_type = self._detect_script_type(abs_path, component_type)
|
||||
script_content = self._load_script_content(abs_path, component)
|
||||
|
||||
doc_variables: List[Dict[str, Any]] = []
|
||||
if isinstance(component.get("variables"), list):
|
||||
doc_variables = [v for v in component["variables"] if isinstance(v, dict)]
|
||||
overrides = self._collect_overrides(component)
|
||||
env_map, _, literal_lookup = _prepare_variable_context(doc_variables, overrides)
|
||||
|
||||
rewritten = _rewrite_powershell_script(script_content, literal_lookup)
|
||||
encoded = _encode_script_content(rewritten)
|
||||
|
||||
timeout_seconds = _coerce_int(component.get("timeout_seconds"))
|
||||
if not timeout_seconds:
|
||||
timeout_seconds = _coerce_int(component.get("timeout"))
|
||||
|
||||
return JobComponentManifest(
|
||||
name=self._component_name(abs_path, component),
|
||||
path=path,
|
||||
script_type=script_type,
|
||||
script_content=rewritten,
|
||||
encoded_content=encoded,
|
||||
environment=env_map,
|
||||
literal_environment=literal_lookup,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
def _component_name(self, abs_path: Path, component: Mapping[str, Any]) -> str:
|
||||
if isinstance(component.get("name"), str) and component["name"].strip():
|
||||
return component["name"].strip()
|
||||
return abs_path.stem
|
||||
|
||||
def _resolve_script_path(self, rel_path: str) -> Path:
|
||||
candidate = Path(rel_path.replace("\\", "/").lstrip("/"))
|
||||
if candidate.parts and candidate.parts[0] != "Scripts":
|
||||
candidate = Path("Scripts") / candidate
|
||||
abs_path = (self._assemblies_root / candidate).resolve()
|
||||
try:
|
||||
abs_path.relative_to(self._assemblies_root)
|
||||
except ValueError:
|
||||
raise FileNotFoundError(rel_path)
|
||||
if not abs_path.is_file():
|
||||
raise FileNotFoundError(rel_path)
|
||||
return abs_path
|
||||
|
||||
def _load_script_content(self, abs_path: Path, component: Mapping[str, Any]) -> str:
|
||||
if isinstance(component.get("script"), str) and component["script"].strip():
|
||||
return _decode_script_content(component["script"], component.get("encoding") or "")
|
||||
try:
|
||||
return abs_path.read_text(encoding="utf-8")
|
||||
except Exception as exc:
|
||||
self._log.warning("unable to read script for job component: path=%s error=%s", abs_path, exc)
|
||||
return ""
|
||||
|
||||
def _detect_script_type(self, abs_path: Path, declared: str) -> str:
|
||||
lower = declared.lower()
|
||||
if lower in {"script", "powershell"}:
|
||||
return "powershell"
|
||||
suffix = abs_path.suffix.lower()
|
||||
if suffix == ".ps1":
|
||||
return "powershell"
|
||||
if suffix == ".yml":
|
||||
return "ansible"
|
||||
if suffix == ".json":
|
||||
try:
|
||||
data = json.loads(abs_path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, dict):
|
||||
t = str(data.get("type") or data.get("script_type") or "").strip().lower()
|
||||
if t:
|
||||
return t
|
||||
except Exception:
|
||||
pass
|
||||
return lower or "powershell"
|
||||
|
||||
def _collect_overrides(self, component: Mapping[str, Any]) -> Dict[str, Any]:
|
||||
overrides: Dict[str, Any] = {}
|
||||
values = component.get("variable_values")
|
||||
if isinstance(values, Mapping):
|
||||
for key, value in values.items():
|
||||
name = str(key or "").strip()
|
||||
if name:
|
||||
overrides[name] = value
|
||||
vars_inline = component.get("variables")
|
||||
if isinstance(vars_inline, Iterable):
|
||||
for var in vars_inline:
|
||||
if not isinstance(var, Mapping):
|
||||
continue
|
||||
name = str(var.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
if "value" in var:
|
||||
overrides[name] = var.get("value")
|
||||
return overrides
|
||||
|
||||
|
||||
def _coerce_int(value: Any) -> int:
|
||||
try:
|
||||
return int(value or 0)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def _env_string(value: Any) -> str:
|
||||
if isinstance(value, bool):
|
||||
return "True" if value else "False"
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value)
|
||||
|
||||
|
||||
def _decode_base64_text(value: Any) -> Optional[str]:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
return ""
|
||||
try:
|
||||
cleaned = re.sub(r"\s+", "", stripped)
|
||||
except Exception:
|
||||
cleaned = stripped
|
||||
try:
|
||||
decoded = base64.b64decode(cleaned, validate=True)
|
||||
except Exception:
|
||||
return None
|
||||
try:
|
||||
return decoded.decode("utf-8")
|
||||
except Exception:
|
||||
return decoded.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def _decode_script_content(value: Any, encoding_hint: Any = "") -> str:
|
||||
encoding = str(encoding_hint or "").strip().lower()
|
||||
if isinstance(value, str):
|
||||
if encoding in {"base64", "b64", "base-64"}:
|
||||
decoded = _decode_base64_text(value)
|
||||
if decoded is not None:
|
||||
return decoded.replace("\r\n", "\n")
|
||||
decoded = _decode_base64_text(value)
|
||||
if decoded is not None:
|
||||
return decoded.replace("\r\n", "\n")
|
||||
return value.replace("\r\n", "\n")
|
||||
return ""
|
||||
|
||||
|
||||
def _encode_script_content(script_text: Any) -> str:
|
||||
if not isinstance(script_text, str):
|
||||
if script_text is None:
|
||||
script_text = ""
|
||||
else:
|
||||
script_text = str(script_text)
|
||||
normalized = script_text.replace("\r\n", "\n")
|
||||
if not normalized:
|
||||
return ""
|
||||
encoded = base64.b64encode(normalized.encode("utf-8"))
|
||||
return encoded.decode("ascii")
|
||||
|
||||
|
||||
def _canonical_env_key(name: Any) -> str:
|
||||
try:
|
||||
return re.sub(r"[^A-Za-z0-9_]", "_", str(name or "").strip()).upper()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _expand_env_aliases(env_map: Dict[str, str], variables: Iterable[Mapping[str, Any]]) -> Dict[str, str]:
|
||||
expanded = dict(env_map or {})
|
||||
for var in variables:
|
||||
if not isinstance(var, Mapping):
|
||||
continue
|
||||
name = str(var.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
canonical = _canonical_env_key(name)
|
||||
if not canonical or canonical not in expanded:
|
||||
continue
|
||||
value = expanded[canonical]
|
||||
alias = re.sub(r"[^A-Za-z0-9_]", "_", name)
|
||||
if alias and alias not in expanded:
|
||||
expanded[alias] = value
|
||||
if alias != name and re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", name) and name not in expanded:
|
||||
expanded[name] = value
|
||||
return expanded
|
||||
|
||||
|
||||
def _powershell_literal(value: Any, var_type: str) -> str:
|
||||
typ = str(var_type or "string").lower()
|
||||
if typ == "boolean":
|
||||
if isinstance(value, bool):
|
||||
truthy = value
|
||||
elif value is None:
|
||||
truthy = False
|
||||
elif isinstance(value, (int, float)):
|
||||
truthy = value != 0
|
||||
else:
|
||||
s = str(value).strip().lower()
|
||||
if s in {"true", "1", "yes", "y", "on"}:
|
||||
truthy = True
|
||||
elif s in {"false", "0", "no", "n", "off", ""}:
|
||||
truthy = False
|
||||
else:
|
||||
truthy = bool(s)
|
||||
return "$true" if truthy else "$false"
|
||||
if typ == "number":
|
||||
if value is None or value == "":
|
||||
return "0"
|
||||
return str(value)
|
||||
s = "" if value is None else str(value)
|
||||
return "'" + s.replace("'", "''") + "'"
|
||||
|
||||
|
||||
def _extract_variable_default(var: Mapping[str, Any]) -> Any:
|
||||
for key in ("value", "default", "defaultValue", "default_value"):
|
||||
if key in var:
|
||||
val = var.get(key)
|
||||
return "" if val is None else val
|
||||
return ""
|
||||
|
||||
|
||||
def _prepare_variable_context(
|
||||
doc_variables: Iterable[Mapping[str, Any]],
|
||||
overrides: Mapping[str, Any],
|
||||
) -> Tuple[Dict[str, str], List[Dict[str, Any]], Dict[str, str]]:
|
||||
env_map: Dict[str, str] = {}
|
||||
variables: List[Dict[str, Any]] = []
|
||||
literal_lookup: Dict[str, str] = {}
|
||||
doc_names: Dict[str, bool] = {}
|
||||
|
||||
overrides = dict(overrides or {})
|
||||
|
||||
for var in doc_variables:
|
||||
if not isinstance(var, Mapping):
|
||||
continue
|
||||
name = str(var.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
doc_names[name] = True
|
||||
canonical = _canonical_env_key(name)
|
||||
var_type = str(var.get("type") or "string").lower()
|
||||
default_val = _extract_variable_default(var)
|
||||
final_val = overrides[name] if name in overrides else default_val
|
||||
if canonical:
|
||||
env_map[canonical] = _env_string(final_val)
|
||||
literal_lookup[canonical] = _powershell_literal(final_val, var_type)
|
||||
if name in overrides:
|
||||
new_var = dict(var)
|
||||
new_var["value"] = overrides[name]
|
||||
variables.append(new_var)
|
||||
else:
|
||||
variables.append(dict(var))
|
||||
|
||||
for name, val in overrides.items():
|
||||
if name in doc_names:
|
||||
continue
|
||||
canonical = _canonical_env_key(name)
|
||||
if canonical:
|
||||
env_map[canonical] = _env_string(val)
|
||||
literal_lookup[canonical] = _powershell_literal(val, "string")
|
||||
variables.append({"name": name, "value": val, "type": "string"})
|
||||
|
||||
env_map = _expand_env_aliases(env_map, variables)
|
||||
return env_map, variables, literal_lookup
|
||||
|
||||
|
||||
def _rewrite_powershell_script(content: str, literal_lookup: Mapping[str, str]) -> str:
|
||||
if not content or not literal_lookup:
|
||||
return content
|
||||
|
||||
def _replace(match: re.Match[str]) -> str:
|
||||
name = match.group(2)
|
||||
canonical = _canonical_env_key(name)
|
||||
if not canonical:
|
||||
return match.group(0)
|
||||
literal = literal_lookup.get(canonical)
|
||||
if literal is None:
|
||||
return match.group(0)
|
||||
return literal
|
||||
|
||||
return _ENV_VAR_PATTERN.sub(_replace, content)
|
||||
@@ -1,72 +0,0 @@
|
||||
"""Builders for operator authentication payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Mapping
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class OperatorLoginRequest:
|
||||
"""Normalized operator login credentials."""
|
||||
|
||||
username: str
|
||||
password_sha512: str
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class OperatorMFAVerificationRequest:
|
||||
"""Normalized MFA verification payload."""
|
||||
|
||||
pending_token: str
|
||||
code: str
|
||||
|
||||
|
||||
def _sha512_hex(raw: str) -> str:
|
||||
digest = hashlib.sha512()
|
||||
digest.update(raw.encode("utf-8"))
|
||||
return digest.hexdigest()
|
||||
|
||||
|
||||
def build_login_request(payload: Mapping[str, object]) -> OperatorLoginRequest:
|
||||
"""Validate and normalize the login *payload*."""
|
||||
|
||||
username = str(payload.get("username") or "").strip()
|
||||
password_sha512 = str(payload.get("password_sha512") or "").strip().lower()
|
||||
password = payload.get("password")
|
||||
|
||||
if not username:
|
||||
raise ValueError("username is required")
|
||||
|
||||
if password_sha512:
|
||||
normalized_hash = password_sha512
|
||||
else:
|
||||
if not isinstance(password, str) or not password:
|
||||
raise ValueError("password is required")
|
||||
normalized_hash = _sha512_hex(password)
|
||||
|
||||
return OperatorLoginRequest(username=username, password_sha512=normalized_hash)
|
||||
|
||||
|
||||
def build_mfa_request(payload: Mapping[str, object]) -> OperatorMFAVerificationRequest:
|
||||
"""Validate and normalize the MFA verification *payload*."""
|
||||
|
||||
pending_token = str(payload.get("pending_token") or "").strip()
|
||||
raw_code = str(payload.get("code") or "").strip()
|
||||
digits = "".join(ch for ch in raw_code if ch.isdigit())
|
||||
|
||||
if not pending_token:
|
||||
raise ValueError("pending_token is required")
|
||||
if len(digits) < 6:
|
||||
raise ValueError("code must contain 6 digits")
|
||||
|
||||
return OperatorMFAVerificationRequest(pending_token=pending_token, code=digits)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OperatorLoginRequest",
|
||||
"OperatorMFAVerificationRequest",
|
||||
"build_login_request",
|
||||
"build_mfa_request",
|
||||
]
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Configuration primitives for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .environment import (
|
||||
DatabaseSettings,
|
||||
EngineSettings,
|
||||
FlaskSettings,
|
||||
GitHubSettings,
|
||||
ServerSettings,
|
||||
SocketIOSettings,
|
||||
load_environment,
|
||||
)
|
||||
from .logging import configure_logging
|
||||
|
||||
__all__ = [
|
||||
"DatabaseSettings",
|
||||
"EngineSettings",
|
||||
"FlaskSettings",
|
||||
"GitHubSettings",
|
||||
"load_environment",
|
||||
"ServerSettings",
|
||||
"SocketIOSettings",
|
||||
"configure_logging",
|
||||
]
|
||||
@@ -1,218 +0,0 @@
|
||||
"""Environment detection for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DatabaseSettings:
|
||||
"""SQLite database configuration for the Engine."""
|
||||
|
||||
path: Path
|
||||
apply_migrations: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class FlaskSettings:
|
||||
"""Parameters that influence Flask application behavior."""
|
||||
|
||||
secret_key: str
|
||||
static_root: Path
|
||||
cors_allowed_origins: Tuple[str, ...]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SocketIOSettings:
|
||||
"""Configuration for the optional Socket.IO server."""
|
||||
|
||||
cors_allowed_origins: Tuple[str, ...]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ServerSettings:
|
||||
"""HTTP server binding configuration."""
|
||||
|
||||
host: str
|
||||
port: int
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GitHubSettings:
|
||||
"""Configuration surface for GitHub repository interactions."""
|
||||
|
||||
default_repo: str
|
||||
default_branch: str
|
||||
refresh_interval_seconds: int
|
||||
cache_root: Path
|
||||
|
||||
@property
|
||||
def cache_file(self) -> Path:
|
||||
"""Location of the persisted repository-head cache."""
|
||||
|
||||
return self.cache_root / "repo_hash_cache.json"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EngineSettings:
|
||||
"""Immutable container describing the Engine runtime configuration."""
|
||||
|
||||
project_root: Path
|
||||
debug: bool
|
||||
database: DatabaseSettings
|
||||
flask: FlaskSettings
|
||||
socketio: SocketIOSettings
|
||||
server: ServerSettings
|
||||
github: GitHubSettings
|
||||
|
||||
@property
|
||||
def logs_root(self) -> Path:
|
||||
"""Return the directory where Engine-specific logs should live."""
|
||||
|
||||
return self.project_root / "Logs" / "Server"
|
||||
|
||||
@property
|
||||
def database_path(self) -> Path:
|
||||
"""Convenience accessor for the database file path."""
|
||||
|
||||
return self.database.path
|
||||
|
||||
@property
|
||||
def apply_migrations(self) -> bool:
|
||||
"""Return whether schema migrations should run at bootstrap."""
|
||||
|
||||
return self.database.apply_migrations
|
||||
|
||||
|
||||
def _resolve_project_root() -> Path:
|
||||
candidate = os.getenv("BOREALIS_ROOT")
|
||||
if candidate:
|
||||
return Path(candidate).expanduser().resolve()
|
||||
# ``environment.py`` lives under ``Data/Engine/config``. The project
|
||||
# root is three levels above this module (the repository checkout). The
|
||||
# previous implementation only walked up two levels which incorrectly
|
||||
# treated ``Data/`` as the root, breaking all filesystem discovery logic
|
||||
# that expects peers such as ``Data/Server`` to be available.
|
||||
return Path(__file__).resolve().parents[3]
|
||||
|
||||
|
||||
def _resolve_database_path(project_root: Path) -> Path:
|
||||
candidate = os.getenv("BOREALIS_DATABASE_PATH")
|
||||
if candidate:
|
||||
return Path(candidate).expanduser().resolve()
|
||||
return (project_root / "database.db").resolve()
|
||||
|
||||
|
||||
def _should_apply_migrations() -> bool:
|
||||
raw = os.getenv("BOREALIS_ENGINE_AUTO_MIGRATE", "true")
|
||||
return raw.lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _resolve_static_root(project_root: Path) -> Path:
|
||||
candidate = os.getenv("BOREALIS_STATIC_ROOT")
|
||||
if candidate:
|
||||
return Path(candidate).expanduser().resolve()
|
||||
|
||||
candidates = (
|
||||
project_root / "Engine" / "web-interface" / "build",
|
||||
project_root / "Engine" / "web-interface" / "dist",
|
||||
project_root / "Engine" / "web-interface",
|
||||
project_root / "Data" / "Engine" / "web-interface" / "build",
|
||||
project_root / "Data" / "Engine" / "web-interface",
|
||||
project_root / "Server" / "web-interface" / "build",
|
||||
project_root / "Server" / "web-interface",
|
||||
project_root / "Data" / "Server" / "WebUI" / "build",
|
||||
project_root / "Data" / "Server" / "WebUI",
|
||||
project_root / "Data" / "Server" / "web-interface" / "build",
|
||||
project_root / "Data" / "Server" / "web-interface",
|
||||
project_root / "Data" / "WebUI" / "build",
|
||||
project_root / "Data" / "WebUI",
|
||||
)
|
||||
for path in candidates:
|
||||
resolved = path.resolve()
|
||||
if resolved.is_dir():
|
||||
return resolved
|
||||
|
||||
# Fall back to the first candidate even if it does not yet exist so the
|
||||
# Flask factory still initialises; individual requests will surface 404s
|
||||
# until an asset build is available, matching the legacy behaviour.
|
||||
return candidates[0].resolve()
|
||||
|
||||
|
||||
def _resolve_github_cache_root(project_root: Path) -> Path:
|
||||
candidate = os.getenv("BOREALIS_CACHE_DIR")
|
||||
if candidate:
|
||||
return Path(candidate).expanduser().resolve()
|
||||
return (project_root / "Data" / "Engine" / "cache").resolve()
|
||||
|
||||
|
||||
def _parse_refresh_interval(raw: str | None) -> int:
|
||||
if not raw:
|
||||
return 60
|
||||
try:
|
||||
value = int(raw)
|
||||
except ValueError:
|
||||
value = 60
|
||||
return max(30, min(value, 3600))
|
||||
|
||||
|
||||
def _parse_origins(raw: str | None) -> Tuple[str, ...]:
|
||||
if not raw:
|
||||
return ("*",)
|
||||
parts: Iterable[str] = (segment.strip() for segment in raw.split(","))
|
||||
filtered = tuple(part for part in parts if part)
|
||||
return filtered or ("*",)
|
||||
|
||||
|
||||
def load_environment() -> EngineSettings:
|
||||
"""Load Engine settings from environment variables and filesystem hints."""
|
||||
|
||||
project_root = _resolve_project_root()
|
||||
database = DatabaseSettings(
|
||||
path=_resolve_database_path(project_root),
|
||||
apply_migrations=_should_apply_migrations(),
|
||||
)
|
||||
cors_allowed_origins = _parse_origins(os.getenv("BOREALIS_CORS_ALLOWED_ORIGINS"))
|
||||
flask_settings = FlaskSettings(
|
||||
secret_key=os.getenv("BOREALIS_FLASK_SECRET_KEY", "change-me"),
|
||||
static_root=_resolve_static_root(project_root),
|
||||
cors_allowed_origins=cors_allowed_origins,
|
||||
)
|
||||
socket_settings = SocketIOSettings(cors_allowed_origins=cors_allowed_origins)
|
||||
debug = os.getenv("BOREALIS_DEBUG", "false").lower() in {"1", "true", "yes", "on"}
|
||||
host = os.getenv("BOREALIS_HOST", "127.0.0.1")
|
||||
try:
|
||||
port = int(os.getenv("BOREALIS_PORT", "5000"))
|
||||
except ValueError:
|
||||
port = 5000
|
||||
server_settings = ServerSettings(host=host, port=port)
|
||||
github_settings = GitHubSettings(
|
||||
default_repo=os.getenv("BOREALIS_REPO", "bunny-lab-io/Borealis"),
|
||||
default_branch=os.getenv("BOREALIS_REPO_BRANCH", "main"),
|
||||
refresh_interval_seconds=_parse_refresh_interval(os.getenv("BOREALIS_REPO_HASH_REFRESH")),
|
||||
cache_root=_resolve_github_cache_root(project_root),
|
||||
)
|
||||
|
||||
return EngineSettings(
|
||||
project_root=project_root,
|
||||
debug=debug,
|
||||
database=database,
|
||||
flask=flask_settings,
|
||||
socketio=socket_settings,
|
||||
server=server_settings,
|
||||
github=github_settings,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DatabaseSettings",
|
||||
"EngineSettings",
|
||||
"FlaskSettings",
|
||||
"GitHubSettings",
|
||||
"SocketIOSettings",
|
||||
"ServerSettings",
|
||||
"load_environment",
|
||||
]
|
||||
@@ -1,71 +0,0 @@
|
||||
"""Logging bootstrap helpers for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
from pathlib import Path
|
||||
|
||||
from .environment import EngineSettings
|
||||
|
||||
|
||||
_ENGINE_LOGGER_NAME = "borealis.engine"
|
||||
_SERVICE_NAME = "engine"
|
||||
_DEFAULT_FORMAT = "%(asctime)s-" + _SERVICE_NAME + "-%(message)s"
|
||||
|
||||
|
||||
def _handler_already_attached(logger: logging.Logger, log_path: Path) -> bool:
|
||||
for handler in logger.handlers:
|
||||
if isinstance(handler, TimedRotatingFileHandler):
|
||||
handler_path = Path(getattr(handler, "baseFilename", ""))
|
||||
if handler_path == log_path:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _build_handler(log_path: Path) -> TimedRotatingFileHandler:
|
||||
handler = TimedRotatingFileHandler(
|
||||
log_path,
|
||||
when="midnight",
|
||||
backupCount=30,
|
||||
encoding="utf-8",
|
||||
)
|
||||
handler.setLevel(logging.INFO)
|
||||
handler.setFormatter(logging.Formatter(_DEFAULT_FORMAT))
|
||||
return handler
|
||||
|
||||
|
||||
def configure_logging(settings: EngineSettings) -> logging.Logger:
|
||||
"""Configure a rotating log handler for the Engine."""
|
||||
|
||||
logs_root = settings.logs_root
|
||||
logs_root.mkdir(parents=True, exist_ok=True)
|
||||
log_path = logs_root / "engine.log"
|
||||
|
||||
logger = logging.getLogger(_ENGINE_LOGGER_NAME)
|
||||
logger.setLevel(logging.INFO if not settings.debug else logging.DEBUG)
|
||||
|
||||
if not _handler_already_attached(logger, log_path):
|
||||
handler = _build_handler(log_path)
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False
|
||||
|
||||
# Also ensure the root logger follows suit so third-party modules inherit the handler.
|
||||
root_logger = logging.getLogger()
|
||||
if not _handler_already_attached(root_logger, log_path):
|
||||
handler = _build_handler(log_path)
|
||||
root_logger.addHandler(handler)
|
||||
if root_logger.level == logging.WARNING:
|
||||
# Default level is WARNING; lower it to INFO so our handler captures application messages.
|
||||
root_logger.setLevel(logging.INFO if not settings.debug else logging.DEBUG)
|
||||
|
||||
# Quieten overly chatty frameworks unless debugging is explicitly requested.
|
||||
if not settings.debug:
|
||||
logging.getLogger("werkzeug").setLevel(logging.WARNING)
|
||||
logging.getLogger("engineio").setLevel(logging.WARNING)
|
||||
logging.getLogger("socketio").setLevel(logging.WARNING)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
__all__ = ["configure_logging"]
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Pure value objects and enums for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .device_auth import ( # noqa: F401
|
||||
AccessTokenClaims,
|
||||
DeviceAuthContext,
|
||||
DeviceAuthErrorCode,
|
||||
DeviceAuthFailure,
|
||||
DeviceFingerprint,
|
||||
DeviceGuid,
|
||||
DeviceIdentity,
|
||||
DeviceStatus,
|
||||
sanitize_service_context,
|
||||
)
|
||||
from .device_enrollment import ( # noqa: F401
|
||||
EnrollmentApproval,
|
||||
EnrollmentApprovalStatus,
|
||||
EnrollmentCode,
|
||||
EnrollmentRequest,
|
||||
ProofChallenge,
|
||||
)
|
||||
from .github import ( # noqa: F401
|
||||
GitHubRateLimit,
|
||||
GitHubRepoRef,
|
||||
GitHubTokenStatus,
|
||||
RepoHeadSnapshot,
|
||||
)
|
||||
from .operator import ( # noqa: F401
|
||||
OperatorAccount,
|
||||
OperatorLoginSuccess,
|
||||
OperatorMFAChallenge,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AccessTokenClaims",
|
||||
"DeviceAuthContext",
|
||||
"DeviceAuthErrorCode",
|
||||
"DeviceAuthFailure",
|
||||
"DeviceFingerprint",
|
||||
"DeviceGuid",
|
||||
"DeviceIdentity",
|
||||
"DeviceStatus",
|
||||
"EnrollmentApproval",
|
||||
"EnrollmentApprovalStatus",
|
||||
"EnrollmentCode",
|
||||
"EnrollmentRequest",
|
||||
"ProofChallenge",
|
||||
"GitHubRateLimit",
|
||||
"GitHubRepoRef",
|
||||
"GitHubTokenStatus",
|
||||
"RepoHeadSnapshot",
|
||||
"OperatorAccount",
|
||||
"OperatorLoginSuccess",
|
||||
"OperatorMFAChallenge",
|
||||
"sanitize_service_context",
|
||||
]
|
||||
@@ -1,249 +0,0 @@
|
||||
"""Domain primitives for device authentication and token validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Mapping, Optional
|
||||
import string
|
||||
import uuid
|
||||
|
||||
__all__ = [
|
||||
"DeviceGuid",
|
||||
"DeviceFingerprint",
|
||||
"DeviceIdentity",
|
||||
"DeviceStatus",
|
||||
"DeviceAuthErrorCode",
|
||||
"DeviceAuthFailure",
|
||||
"AccessTokenClaims",
|
||||
"DeviceAuthContext",
|
||||
"sanitize_service_context",
|
||||
"normalize_guid",
|
||||
]
|
||||
|
||||
|
||||
def _normalize_guid(value: Optional[str]) -> str:
|
||||
"""Return a canonical GUID string or an empty string."""
|
||||
candidate = (value or "").strip()
|
||||
if not candidate:
|
||||
return ""
|
||||
candidate = candidate.strip("{}")
|
||||
try:
|
||||
return str(uuid.UUID(candidate)).upper()
|
||||
except Exception:
|
||||
cleaned = "".join(
|
||||
ch for ch in candidate if ch in string.hexdigits or ch == "-"
|
||||
).strip("-")
|
||||
if cleaned:
|
||||
try:
|
||||
return str(uuid.UUID(cleaned)).upper()
|
||||
except Exception:
|
||||
pass
|
||||
return candidate.upper()
|
||||
|
||||
|
||||
def _normalize_fingerprint(value: Optional[str]) -> str:
|
||||
return (value or "").strip().lower()
|
||||
|
||||
|
||||
def sanitize_service_context(value: Optional[str]) -> Optional[str]:
|
||||
"""Normalize the optional agent service context header value."""
|
||||
if not value:
|
||||
return None
|
||||
cleaned = "".join(
|
||||
ch for ch in str(value) if ch.isalnum() or ch in ("_", "-")
|
||||
)
|
||||
if not cleaned:
|
||||
return None
|
||||
return cleaned.upper()
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceGuid:
|
||||
"""Canonical GUID wrapper that enforces Borealis normalization."""
|
||||
|
||||
value: str
|
||||
|
||||
def __post_init__(self) -> None: # pragma: no cover - simple data normalization
|
||||
normalized = _normalize_guid(self.value)
|
||||
if not normalized:
|
||||
raise ValueError("device GUID is required")
|
||||
object.__setattr__(self, "value", normalized)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
def normalize_guid(value: Optional[str]) -> str:
|
||||
"""Expose GUID normalization for administrative helpers."""
|
||||
|
||||
return _normalize_guid(value)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceFingerprint:
|
||||
"""Normalized TLS key fingerprint associated with a device."""
|
||||
|
||||
value: str
|
||||
|
||||
def __post_init__(self) -> None: # pragma: no cover - simple data normalization
|
||||
normalized = _normalize_fingerprint(self.value)
|
||||
if not normalized:
|
||||
raise ValueError("device fingerprint is required")
|
||||
object.__setattr__(self, "value", normalized)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceIdentity:
|
||||
"""Immutable pairing of device GUID and TLS key fingerprint."""
|
||||
|
||||
guid: DeviceGuid
|
||||
fingerprint: DeviceFingerprint
|
||||
|
||||
|
||||
class DeviceStatus(str, Enum):
|
||||
"""Lifecycle markers mirrored from the legacy devices table."""
|
||||
|
||||
ACTIVE = "active"
|
||||
QUARANTINED = "quarantined"
|
||||
REVOKED = "revoked"
|
||||
DECOMMISSIONED = "decommissioned"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: Optional[str]) -> "DeviceStatus":
|
||||
normalized = (value or "active").strip().lower()
|
||||
try:
|
||||
return cls(normalized)
|
||||
except ValueError:
|
||||
return cls.ACTIVE
|
||||
|
||||
@property
|
||||
def allows_access(self) -> bool:
|
||||
return self in {self.ACTIVE, self.QUARANTINED}
|
||||
|
||||
|
||||
class DeviceAuthErrorCode(str, Enum):
|
||||
"""Well-known authentication failure categories."""
|
||||
|
||||
MISSING_AUTHORIZATION = "missing_authorization"
|
||||
TOKEN_EXPIRED = "token_expired"
|
||||
INVALID_TOKEN = "invalid_token"
|
||||
INVALID_CLAIMS = "invalid_claims"
|
||||
RATE_LIMITED = "rate_limited"
|
||||
DEVICE_NOT_FOUND = "device_not_found"
|
||||
DEVICE_GUID_MISMATCH = "device_guid_mismatch"
|
||||
FINGERPRINT_MISMATCH = "fingerprint_mismatch"
|
||||
TOKEN_VERSION_REVOKED = "token_version_revoked"
|
||||
DEVICE_REVOKED = "device_revoked"
|
||||
DPOP_NOT_SUPPORTED = "dpop_not_supported"
|
||||
DPOP_REPLAYED = "dpop_replayed"
|
||||
DPOP_INVALID = "dpop_invalid"
|
||||
|
||||
|
||||
class DeviceAuthFailure(Exception):
|
||||
"""Domain-level authentication error with HTTP metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
code: DeviceAuthErrorCode,
|
||||
*,
|
||||
http_status: int = 401,
|
||||
retry_after: Optional[float] = None,
|
||||
detail: Optional[str] = None,
|
||||
) -> None:
|
||||
self.code = code
|
||||
self.http_status = int(http_status)
|
||||
self.retry_after = retry_after
|
||||
self.detail = detail or code.value
|
||||
super().__init__(self.detail)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {"error": self.code.value}
|
||||
if self.retry_after is not None:
|
||||
payload["retry_after"] = float(self.retry_after)
|
||||
if self.detail and self.detail != self.code.value:
|
||||
payload["detail"] = self.detail
|
||||
return payload
|
||||
|
||||
|
||||
def _coerce_int(value: Any, *, minimum: Optional[int] = None) -> int:
|
||||
try:
|
||||
result = int(value)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError("expected integer value") from None
|
||||
if minimum is not None and result < minimum:
|
||||
raise ValueError("integer below minimum")
|
||||
return result
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AccessTokenClaims:
|
||||
"""Validated subset of JWT claims issued to a device."""
|
||||
|
||||
subject: str
|
||||
guid: DeviceGuid
|
||||
fingerprint: DeviceFingerprint
|
||||
token_version: int
|
||||
issued_at: int
|
||||
not_before: int
|
||||
expires_at: int
|
||||
raw: Mapping[str, Any]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.token_version <= 0:
|
||||
raise ValueError("token_version must be positive")
|
||||
if self.issued_at <= 0 or self.not_before <= 0 or self.expires_at <= 0:
|
||||
raise ValueError("temporal claims must be positive integers")
|
||||
if self.expires_at <= self.not_before:
|
||||
raise ValueError("token expiration must be after not-before")
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls, claims: Mapping[str, Any]) -> "AccessTokenClaims":
|
||||
subject = str(claims.get("sub") or "").strip()
|
||||
if not subject:
|
||||
raise ValueError("missing token subject")
|
||||
guid = DeviceGuid(str(claims.get("guid") or ""))
|
||||
fingerprint = DeviceFingerprint(claims.get("ssl_key_fingerprint"))
|
||||
token_version = _coerce_int(claims.get("token_version"), minimum=1)
|
||||
issued_at = _coerce_int(claims.get("iat"), minimum=1)
|
||||
not_before = _coerce_int(claims.get("nbf"), minimum=1)
|
||||
expires_at = _coerce_int(claims.get("exp"), minimum=1)
|
||||
return cls(
|
||||
subject=subject,
|
||||
guid=guid,
|
||||
fingerprint=fingerprint,
|
||||
token_version=token_version,
|
||||
issued_at=issued_at,
|
||||
not_before=not_before,
|
||||
expires_at=expires_at,
|
||||
raw=dict(claims),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceAuthContext:
|
||||
"""Domain result emitted after successful authentication."""
|
||||
|
||||
identity: DeviceIdentity
|
||||
access_token: str
|
||||
claims: AccessTokenClaims
|
||||
status: DeviceStatus
|
||||
service_context: Optional[str]
|
||||
dpop_jkt: Optional[str] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.access_token:
|
||||
raise ValueError("access token is required")
|
||||
service = sanitize_service_context(self.service_context)
|
||||
object.__setattr__(self, "service_context", service)
|
||||
|
||||
@property
|
||||
def is_quarantined(self) -> bool:
|
||||
return self.status is DeviceStatus.QUARANTINED
|
||||
|
||||
@property
|
||||
def allows_access(self) -> bool:
|
||||
return self.status.allows_access
|
||||
@@ -1,261 +0,0 @@
|
||||
"""Domain types describing device enrollment flows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
import base64
|
||||
from enum import Enum
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
from .device_auth import DeviceFingerprint, DeviceGuid, sanitize_service_context
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentCode",
|
||||
"EnrollmentApprovalStatus",
|
||||
"EnrollmentApproval",
|
||||
"EnrollmentRequest",
|
||||
"EnrollmentValidationError",
|
||||
"ProofChallenge",
|
||||
]
|
||||
|
||||
|
||||
def _parse_iso8601(value: Optional[str]) -> Optional[datetime]:
|
||||
if not value:
|
||||
return None
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
dt = datetime.fromisoformat(raw)
|
||||
except Exception as exc: # pragma: no cover - error path
|
||||
raise ValueError(f"invalid ISO8601 timestamp: {raw}") from exc
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
def _require(value: Optional[str], field: str) -> str:
|
||||
text = (value or "").strip()
|
||||
if not text:
|
||||
raise ValueError(f"{field} is required")
|
||||
return text
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentValidationError(Exception):
|
||||
"""Raised when enrollment input fails validation."""
|
||||
|
||||
code: str
|
||||
http_status: int = 400
|
||||
retry_after: Optional[float] = None
|
||||
|
||||
def to_response(self) -> dict[str, object]:
|
||||
payload: dict[str, object] = {"error": self.code}
|
||||
if self.retry_after is not None:
|
||||
payload["retry_after"] = self.retry_after
|
||||
return payload
|
||||
|
||||
def __str__(self) -> str: # pragma: no cover - debug helper
|
||||
return f"{self.code} (status={self.http_status})"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentCode:
|
||||
"""Installer code metadata loaded from the persistence layer."""
|
||||
|
||||
code: str
|
||||
expires_at: datetime
|
||||
max_uses: int
|
||||
use_count: int
|
||||
used_by_guid: Optional[DeviceGuid]
|
||||
last_used_at: Optional[datetime]
|
||||
used_at: Optional[datetime]
|
||||
record_id: Optional[str] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.code:
|
||||
raise ValueError("code is required")
|
||||
if self.max_uses < 1:
|
||||
raise ValueError("max_uses must be >= 1")
|
||||
if self.use_count < 0:
|
||||
raise ValueError("use_count cannot be negative")
|
||||
if self.use_count > self.max_uses:
|
||||
raise ValueError("use_count cannot exceed max_uses")
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls, record: Mapping[str, Any]) -> "EnrollmentCode":
|
||||
used_by = record.get("used_by_guid")
|
||||
used_by_guid = DeviceGuid(used_by) if used_by else None
|
||||
return cls(
|
||||
code=_require(record.get("code"), "code"),
|
||||
expires_at=_parse_iso8601(record.get("expires_at")) or datetime.now(tz=timezone.utc),
|
||||
max_uses=int(record.get("max_uses") or 1),
|
||||
use_count=int(record.get("use_count") or 0),
|
||||
used_by_guid=used_by_guid,
|
||||
last_used_at=_parse_iso8601(record.get("last_used_at")),
|
||||
used_at=_parse_iso8601(record.get("used_at")),
|
||||
record_id=str(record.get("id") or "") or None,
|
||||
)
|
||||
|
||||
@property
|
||||
def remaining_uses(self) -> int:
|
||||
return max(self.max_uses - self.use_count, 0)
|
||||
|
||||
@property
|
||||
def is_exhausted(self) -> bool:
|
||||
return self.remaining_uses == 0
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
return self.expires_at <= datetime.now(tz=timezone.utc)
|
||||
|
||||
@property
|
||||
def identifier(self) -> Optional[str]:
|
||||
return self.record_id
|
||||
|
||||
|
||||
class EnrollmentApprovalStatus(str, Enum):
|
||||
"""Possible states for a device approval entry."""
|
||||
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
COMPLETED = "completed"
|
||||
EXPIRED = "expired"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: Optional[str]) -> "EnrollmentApprovalStatus":
|
||||
normalized = (value or "pending").strip().lower()
|
||||
try:
|
||||
return cls(normalized)
|
||||
except ValueError:
|
||||
return cls.PENDING
|
||||
|
||||
@property
|
||||
def is_terminal(self) -> bool:
|
||||
return self in {self.APPROVED, self.DENIED, self.COMPLETED, self.EXPIRED}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ProofChallenge:
|
||||
"""Client/server nonce pair distributed during enrollment."""
|
||||
|
||||
client_nonce: bytes
|
||||
server_nonce: bytes
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.client_nonce or not self.server_nonce:
|
||||
raise ValueError("nonce payloads must be non-empty")
|
||||
|
||||
@classmethod
|
||||
def from_base64(cls, *, client: bytes, server: bytes) -> "ProofChallenge":
|
||||
return cls(client_nonce=client, server_nonce=server)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentRequest:
|
||||
"""Validated payload submitted by an agent during enrollment."""
|
||||
|
||||
hostname: str
|
||||
enrollment_code: str
|
||||
fingerprint: DeviceFingerprint
|
||||
proof: ProofChallenge
|
||||
service_context: Optional[str]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.hostname:
|
||||
raise ValueError("hostname is required")
|
||||
if not self.enrollment_code:
|
||||
raise ValueError("enrollment code is required")
|
||||
object.__setattr__(
|
||||
self,
|
||||
"service_context",
|
||||
sanitize_service_context(self.service_context),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_payload(
|
||||
cls,
|
||||
*,
|
||||
hostname: str,
|
||||
enrollment_code: str,
|
||||
fingerprint: str,
|
||||
client_nonce: bytes,
|
||||
server_nonce: bytes,
|
||||
service_context: Optional[str] = None,
|
||||
) -> "EnrollmentRequest":
|
||||
proof = ProofChallenge(client_nonce=client_nonce, server_nonce=server_nonce)
|
||||
return cls(
|
||||
hostname=_require(hostname, "hostname"),
|
||||
enrollment_code=_require(enrollment_code, "enrollment_code"),
|
||||
fingerprint=DeviceFingerprint(fingerprint),
|
||||
proof=proof,
|
||||
service_context=service_context,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentApproval:
|
||||
"""Pending or resolved approval tracked by operators."""
|
||||
|
||||
record_id: str
|
||||
reference: str
|
||||
status: EnrollmentApprovalStatus
|
||||
claimed_hostname: str
|
||||
claimed_fingerprint: DeviceFingerprint
|
||||
enrollment_code_id: Optional[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
client_nonce_b64: str
|
||||
server_nonce_b64: str
|
||||
agent_pubkey_der: bytes
|
||||
guid: Optional[DeviceGuid] = None
|
||||
approved_by: Optional[str] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.record_id:
|
||||
raise ValueError("record identifier is required")
|
||||
if not self.reference:
|
||||
raise ValueError("approval reference is required")
|
||||
if not self.claimed_hostname:
|
||||
raise ValueError("claimed hostname is required")
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls, record: Mapping[str, Any]) -> "EnrollmentApproval":
|
||||
guid_raw = record.get("guid")
|
||||
approved_raw = record.get("approved_by_user_id")
|
||||
return cls(
|
||||
record_id=_require(record.get("id"), "id"),
|
||||
reference=_require(record.get("approval_reference"), "approval_reference"),
|
||||
status=EnrollmentApprovalStatus.from_string(record.get("status")),
|
||||
claimed_hostname=_require(record.get("hostname_claimed"), "hostname_claimed"),
|
||||
claimed_fingerprint=DeviceFingerprint(record.get("ssl_key_fingerprint_claimed")),
|
||||
enrollment_code_id=record.get("enrollment_code_id"),
|
||||
created_at=_parse_iso8601(record.get("created_at")) or datetime.now(tz=timezone.utc),
|
||||
updated_at=_parse_iso8601(record.get("updated_at")) or datetime.now(tz=timezone.utc),
|
||||
guid=DeviceGuid(guid_raw) if guid_raw else None,
|
||||
approved_by=(approved_raw or None),
|
||||
client_nonce_b64=_require(record.get("client_nonce"), "client_nonce"),
|
||||
server_nonce_b64=_require(record.get("server_nonce"), "server_nonce"),
|
||||
agent_pubkey_der=bytes(record.get("agent_pubkey_der") or b""),
|
||||
)
|
||||
|
||||
@property
|
||||
def is_pending(self) -> bool:
|
||||
return self.status is EnrollmentApprovalStatus.PENDING
|
||||
|
||||
@property
|
||||
def is_completed(self) -> bool:
|
||||
return self.status in {
|
||||
EnrollmentApprovalStatus.APPROVED,
|
||||
EnrollmentApprovalStatus.COMPLETED,
|
||||
}
|
||||
|
||||
@property
|
||||
def client_nonce_bytes(self) -> bytes:
|
||||
return base64.b64decode(self.client_nonce_b64.encode("ascii"), validate=True)
|
||||
|
||||
@property
|
||||
def server_nonce_bytes(self) -> bytes:
|
||||
return base64.b64decode(self.server_nonce_b64.encode("ascii"), validate=True)
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Domain objects for saved device list views."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
__all__ = ["DeviceListView"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceListView:
|
||||
id: int
|
||||
name: str
|
||||
columns: List[str]
|
||||
filters: Dict[str, object]
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
def to_dict(self) -> Dict[str, object]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"columns": self.columns,
|
||||
"filters": self.filters,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
}
|
||||
@@ -1,323 +0,0 @@
|
||||
"""Device domain helpers mirroring the legacy server payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence
|
||||
|
||||
from Data.Engine.domain.device_auth import normalize_guid
|
||||
|
||||
__all__ = [
|
||||
"DEVICE_TABLE_COLUMNS",
|
||||
"DEVICE_TABLE",
|
||||
"DeviceSnapshot",
|
||||
"assemble_device_snapshot",
|
||||
"row_to_device_dict",
|
||||
"serialize_device_json",
|
||||
"clean_device_str",
|
||||
"coerce_int",
|
||||
"ts_to_iso",
|
||||
"device_column_sql",
|
||||
"ts_to_human",
|
||||
]
|
||||
|
||||
|
||||
DEVICE_TABLE = "devices"
|
||||
|
||||
DEVICE_JSON_LIST_FIELDS: Mapping[str, List[Any]] = {
|
||||
"memory": [],
|
||||
"network": [],
|
||||
"software": [],
|
||||
"storage": [],
|
||||
}
|
||||
|
||||
DEVICE_JSON_OBJECT_FIELDS: Mapping[str, Dict[str, Any]] = {
|
||||
"cpu": {},
|
||||
}
|
||||
|
||||
DEVICE_TABLE_COLUMNS: Sequence[str] = (
|
||||
"guid",
|
||||
"hostname",
|
||||
"description",
|
||||
"created_at",
|
||||
"agent_hash",
|
||||
"memory",
|
||||
"network",
|
||||
"software",
|
||||
"storage",
|
||||
"cpu",
|
||||
"device_type",
|
||||
"domain",
|
||||
"external_ip",
|
||||
"internal_ip",
|
||||
"last_reboot",
|
||||
"last_seen",
|
||||
"last_user",
|
||||
"operating_system",
|
||||
"uptime",
|
||||
"agent_id",
|
||||
"ansible_ee_ver",
|
||||
"connection_type",
|
||||
"connection_endpoint",
|
||||
"ssl_key_fingerprint",
|
||||
"token_version",
|
||||
"status",
|
||||
"key_added_at",
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DeviceSnapshot:
|
||||
hostname: str
|
||||
description: str
|
||||
created_at: int
|
||||
created_at_iso: str
|
||||
agent_hash: str
|
||||
agent_guid: str
|
||||
guid: str
|
||||
memory: List[Dict[str, Any]]
|
||||
network: List[Dict[str, Any]]
|
||||
software: List[Dict[str, Any]]
|
||||
storage: List[Dict[str, Any]]
|
||||
cpu: Dict[str, Any]
|
||||
device_type: str
|
||||
domain: str
|
||||
external_ip: str
|
||||
internal_ip: str
|
||||
last_reboot: str
|
||||
last_seen: int
|
||||
last_seen_iso: str
|
||||
last_user: str
|
||||
operating_system: str
|
||||
uptime: int
|
||||
agent_id: str
|
||||
ansible_ee_ver: str
|
||||
connection_type: str
|
||||
connection_endpoint: str
|
||||
ssl_key_fingerprint: str
|
||||
token_version: int
|
||||
status: str
|
||||
key_added_at: str
|
||||
details: Dict[str, Any]
|
||||
summary: Dict[str, Any]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"hostname": self.hostname,
|
||||
"description": self.description,
|
||||
"created_at": self.created_at,
|
||||
"created_at_iso": self.created_at_iso,
|
||||
"agent_hash": self.agent_hash,
|
||||
"agent_guid": self.agent_guid,
|
||||
"guid": self.guid,
|
||||
"memory": self.memory,
|
||||
"network": self.network,
|
||||
"software": self.software,
|
||||
"storage": self.storage,
|
||||
"cpu": self.cpu,
|
||||
"device_type": self.device_type,
|
||||
"domain": self.domain,
|
||||
"external_ip": self.external_ip,
|
||||
"internal_ip": self.internal_ip,
|
||||
"last_reboot": self.last_reboot,
|
||||
"last_seen": self.last_seen,
|
||||
"last_seen_iso": self.last_seen_iso,
|
||||
"last_user": self.last_user,
|
||||
"operating_system": self.operating_system,
|
||||
"uptime": self.uptime,
|
||||
"agent_id": self.agent_id,
|
||||
"ansible_ee_ver": self.ansible_ee_ver,
|
||||
"connection_type": self.connection_type,
|
||||
"connection_endpoint": self.connection_endpoint,
|
||||
"ssl_key_fingerprint": self.ssl_key_fingerprint,
|
||||
"token_version": self.token_version,
|
||||
"status": self.status,
|
||||
"key_added_at": self.key_added_at,
|
||||
"details": self.details,
|
||||
"summary": self.summary,
|
||||
}
|
||||
|
||||
|
||||
def ts_to_iso(ts: Optional[int]) -> str:
|
||||
if not ts:
|
||||
return ""
|
||||
try:
|
||||
return datetime.fromtimestamp(int(ts), timezone.utc).isoformat()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _ts_to_human(ts: Optional[int]) -> str:
|
||||
if not ts:
|
||||
return ""
|
||||
try:
|
||||
return datetime.utcfromtimestamp(int(ts)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _parse_device_json(raw: Optional[str], default: Any) -> Any:
|
||||
if raw is None:
|
||||
return json.loads(json.dumps(default)) if isinstance(default, (list, dict)) else default
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except Exception:
|
||||
data = None
|
||||
if isinstance(default, list):
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
return []
|
||||
if isinstance(default, dict):
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
return {}
|
||||
return default
|
||||
|
||||
|
||||
def serialize_device_json(value: Any, default: Any) -> str:
|
||||
candidate = value
|
||||
if candidate is None:
|
||||
candidate = default
|
||||
if not isinstance(candidate, (list, dict)):
|
||||
candidate = default
|
||||
try:
|
||||
return json.dumps(candidate)
|
||||
except Exception:
|
||||
try:
|
||||
return json.dumps(default)
|
||||
except Exception:
|
||||
return "{}" if isinstance(default, dict) else "[]"
|
||||
|
||||
|
||||
def clean_device_str(value: Any) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (int, float)) and not isinstance(value, bool):
|
||||
text = str(value)
|
||||
elif isinstance(value, str):
|
||||
text = value
|
||||
else:
|
||||
try:
|
||||
text = str(value)
|
||||
except Exception:
|
||||
return None
|
||||
text = text.strip()
|
||||
return text or None
|
||||
|
||||
|
||||
def coerce_int(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
if isinstance(value, str) and value.strip() == "":
|
||||
return None
|
||||
return int(float(value))
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def row_to_device_dict(row: Sequence[Any], columns: Sequence[str]) -> Dict[str, Any]:
|
||||
return {columns[idx]: row[idx] for idx in range(min(len(row), len(columns)))}
|
||||
|
||||
|
||||
def assemble_device_snapshot(record: Mapping[str, Any]) -> Dict[str, Any]:
|
||||
hostname = clean_device_str(record.get("hostname")) or ""
|
||||
description = clean_device_str(record.get("description")) or ""
|
||||
agent_hash = clean_device_str(record.get("agent_hash")) or ""
|
||||
raw_guid = clean_device_str(record.get("guid"))
|
||||
normalized_guid = normalize_guid(raw_guid)
|
||||
|
||||
created_ts = coerce_int(record.get("created_at")) or 0
|
||||
last_seen_ts = coerce_int(record.get("last_seen")) or 0
|
||||
uptime_val = coerce_int(record.get("uptime")) or 0
|
||||
token_version = coerce_int(record.get("token_version")) or 0
|
||||
|
||||
parsed_lists = {
|
||||
key: _parse_device_json(record.get(key), default)
|
||||
for key, default in DEVICE_JSON_LIST_FIELDS.items()
|
||||
}
|
||||
cpu_obj = _parse_device_json(record.get("cpu"), DEVICE_JSON_OBJECT_FIELDS["cpu"])
|
||||
|
||||
summary: Dict[str, Any] = {
|
||||
"hostname": hostname,
|
||||
"description": description,
|
||||
"agent_hash": agent_hash,
|
||||
"agent_guid": normalized_guid or "",
|
||||
"agent_id": clean_device_str(record.get("agent_id")) or "",
|
||||
"device_type": clean_device_str(record.get("device_type")) or "",
|
||||
"domain": clean_device_str(record.get("domain")) or "",
|
||||
"external_ip": clean_device_str(record.get("external_ip")) or "",
|
||||
"internal_ip": clean_device_str(record.get("internal_ip")) or "",
|
||||
"last_reboot": clean_device_str(record.get("last_reboot")) or "",
|
||||
"last_seen": last_seen_ts,
|
||||
"last_user": clean_device_str(record.get("last_user")) or "",
|
||||
"operating_system": clean_device_str(record.get("operating_system")) or "",
|
||||
"uptime": uptime_val,
|
||||
"uptime_sec": uptime_val,
|
||||
"ansible_ee_ver": clean_device_str(record.get("ansible_ee_ver")) or "",
|
||||
"connection_type": clean_device_str(record.get("connection_type")) or "",
|
||||
"connection_endpoint": clean_device_str(record.get("connection_endpoint")) or "",
|
||||
"ssl_key_fingerprint": clean_device_str(record.get("ssl_key_fingerprint")) or "",
|
||||
"status": clean_device_str(record.get("status")) or "",
|
||||
"token_version": token_version,
|
||||
"key_added_at": clean_device_str(record.get("key_added_at")) or "",
|
||||
"created_at": created_ts,
|
||||
"created": ts_to_human(created_ts),
|
||||
}
|
||||
|
||||
details = {
|
||||
"memory": parsed_lists["memory"],
|
||||
"network": parsed_lists["network"],
|
||||
"software": parsed_lists["software"],
|
||||
"storage": parsed_lists["storage"],
|
||||
"cpu": cpu_obj,
|
||||
"summary": dict(summary),
|
||||
}
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"hostname": hostname,
|
||||
"description": description,
|
||||
"created_at": created_ts,
|
||||
"created_at_iso": ts_to_iso(created_ts),
|
||||
"agent_hash": agent_hash,
|
||||
"agent_guid": summary.get("agent_guid", ""),
|
||||
"guid": summary.get("agent_guid", ""),
|
||||
"memory": parsed_lists["memory"],
|
||||
"network": parsed_lists["network"],
|
||||
"software": parsed_lists["software"],
|
||||
"storage": parsed_lists["storage"],
|
||||
"cpu": cpu_obj,
|
||||
"device_type": summary.get("device_type", ""),
|
||||
"domain": summary.get("domain", ""),
|
||||
"external_ip": summary.get("external_ip", ""),
|
||||
"internal_ip": summary.get("internal_ip", ""),
|
||||
"last_reboot": summary.get("last_reboot", ""),
|
||||
"last_seen": last_seen_ts,
|
||||
"last_seen_iso": ts_to_iso(last_seen_ts),
|
||||
"last_user": summary.get("last_user", ""),
|
||||
"operating_system": summary.get("operating_system", ""),
|
||||
"uptime": uptime_val,
|
||||
"agent_id": summary.get("agent_id", ""),
|
||||
"ansible_ee_ver": summary.get("ansible_ee_ver", ""),
|
||||
"connection_type": summary.get("connection_type", ""),
|
||||
"connection_endpoint": summary.get("connection_endpoint", ""),
|
||||
"ssl_key_fingerprint": summary.get("ssl_key_fingerprint", ""),
|
||||
"token_version": summary.get("token_version", 0),
|
||||
"status": summary.get("status", ""),
|
||||
"key_added_at": summary.get("key_added_at", ""),
|
||||
"details": details,
|
||||
"summary": summary,
|
||||
}
|
||||
return payload
|
||||
|
||||
|
||||
def device_column_sql(alias: Optional[str] = None) -> str:
|
||||
if alias:
|
||||
return ", ".join(f"{alias}.{col}" for col in DEVICE_TABLE_COLUMNS)
|
||||
return ", ".join(DEVICE_TABLE_COLUMNS)
|
||||
|
||||
|
||||
def ts_to_human(ts: Optional[int]) -> str:
|
||||
return _ts_to_human(ts)
|
||||
@@ -1,206 +0,0 @@
|
||||
"""Administrative enrollment domain models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
from Data.Engine.domain.device_auth import DeviceGuid, normalize_guid
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentCodeRecord",
|
||||
"DeviceApprovalRecord",
|
||||
"HostnameConflict",
|
||||
]
|
||||
|
||||
|
||||
def _parse_iso8601(value: Optional[str]) -> Optional[datetime]:
|
||||
if not value:
|
||||
return None
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
dt = datetime.fromisoformat(raw)
|
||||
except Exception as exc: # pragma: no cover - defensive parsing
|
||||
raise ValueError(f"invalid ISO8601 timestamp: {raw}") from exc
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
return dt.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def _isoformat(value: Optional[datetime]) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
return value.astimezone(timezone.utc).isoformat()
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentCodeRecord:
|
||||
"""Installer code metadata exposed to administrative clients."""
|
||||
|
||||
record_id: str
|
||||
code: str
|
||||
expires_at: datetime
|
||||
max_uses: int
|
||||
use_count: int
|
||||
created_by_user_id: Optional[str]
|
||||
used_at: Optional[datetime]
|
||||
used_by_guid: Optional[DeviceGuid]
|
||||
last_used_at: Optional[datetime]
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row: Mapping[str, Any]) -> "EnrollmentCodeRecord":
|
||||
record_id = str(row.get("id") or "").strip()
|
||||
code = str(row.get("code") or "").strip()
|
||||
if not record_id or not code:
|
||||
raise ValueError("invalid enrollment install code record")
|
||||
|
||||
used_by = row.get("used_by_guid")
|
||||
used_by_guid = DeviceGuid(str(used_by)) if used_by else None
|
||||
|
||||
return cls(
|
||||
record_id=record_id,
|
||||
code=code,
|
||||
expires_at=_parse_iso8601(row.get("expires_at")) or datetime.now(tz=timezone.utc),
|
||||
max_uses=int(row.get("max_uses") or 1),
|
||||
use_count=int(row.get("use_count") or 0),
|
||||
created_by_user_id=str(row.get("created_by_user_id") or "").strip() or None,
|
||||
used_at=_parse_iso8601(row.get("used_at")),
|
||||
used_by_guid=used_by_guid,
|
||||
last_used_at=_parse_iso8601(row.get("last_used_at")),
|
||||
)
|
||||
|
||||
def status(self, *, now: Optional[datetime] = None) -> str:
|
||||
reference = now or datetime.now(tz=timezone.utc)
|
||||
if self.use_count >= self.max_uses:
|
||||
return "used"
|
||||
if self.expires_at <= reference:
|
||||
return "expired"
|
||||
return "active"
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.record_id,
|
||||
"code": self.code,
|
||||
"expires_at": _isoformat(self.expires_at),
|
||||
"max_uses": self.max_uses,
|
||||
"use_count": self.use_count,
|
||||
"created_by_user_id": self.created_by_user_id,
|
||||
"used_at": _isoformat(self.used_at),
|
||||
"used_by_guid": self.used_by_guid.value if self.used_by_guid else None,
|
||||
"last_used_at": _isoformat(self.last_used_at),
|
||||
"status": self.status(),
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class HostnameConflict:
|
||||
"""Existing device details colliding with a pending approval."""
|
||||
|
||||
guid: Optional[str]
|
||||
ssl_key_fingerprint: Optional[str]
|
||||
site_id: Optional[int]
|
||||
site_name: str
|
||||
fingerprint_match: bool
|
||||
requires_prompt: bool
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"guid": self.guid,
|
||||
"ssl_key_fingerprint": self.ssl_key_fingerprint,
|
||||
"site_id": self.site_id,
|
||||
"site_name": self.site_name,
|
||||
"fingerprint_match": self.fingerprint_match,
|
||||
"requires_prompt": self.requires_prompt,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceApprovalRecord:
|
||||
"""Administrative projection of a device approval entry."""
|
||||
|
||||
record_id: str
|
||||
reference: str
|
||||
status: str
|
||||
claimed_hostname: str
|
||||
claimed_fingerprint: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
enrollment_code_id: Optional[str]
|
||||
guid: Optional[str]
|
||||
approved_by_user_id: Optional[str]
|
||||
approved_by_username: Optional[str]
|
||||
client_nonce: str
|
||||
server_nonce: str
|
||||
hostname_conflict: Optional[HostnameConflict]
|
||||
alternate_hostname: Optional[str]
|
||||
conflict_requires_prompt: bool
|
||||
fingerprint_match: bool
|
||||
|
||||
@classmethod
|
||||
def from_row(
|
||||
cls,
|
||||
row: Mapping[str, Any],
|
||||
*,
|
||||
conflict: Optional[HostnameConflict] = None,
|
||||
alternate_hostname: Optional[str] = None,
|
||||
fingerprint_match: bool = False,
|
||||
requires_prompt: bool = False,
|
||||
) -> "DeviceApprovalRecord":
|
||||
record_id = str(row.get("id") or "").strip()
|
||||
reference = str(row.get("approval_reference") or "").strip()
|
||||
hostname = str(row.get("hostname_claimed") or "").strip()
|
||||
fingerprint = str(row.get("ssl_key_fingerprint_claimed") or "").strip().lower()
|
||||
if not record_id or not reference or not hostname or not fingerprint:
|
||||
raise ValueError("invalid device approval record")
|
||||
|
||||
guid_raw = normalize_guid(row.get("guid")) or None
|
||||
|
||||
return cls(
|
||||
record_id=record_id,
|
||||
reference=reference,
|
||||
status=str(row.get("status") or "pending").strip().lower(),
|
||||
claimed_hostname=hostname,
|
||||
claimed_fingerprint=fingerprint,
|
||||
created_at=_parse_iso8601(row.get("created_at")) or datetime.now(tz=timezone.utc),
|
||||
updated_at=_parse_iso8601(row.get("updated_at")) or datetime.now(tz=timezone.utc),
|
||||
enrollment_code_id=str(row.get("enrollment_code_id") or "").strip() or None,
|
||||
guid=guid_raw,
|
||||
approved_by_user_id=str(row.get("approved_by_user_id") or "").strip() or None,
|
||||
approved_by_username=str(row.get("approved_by_username") or "").strip() or None,
|
||||
client_nonce=str(row.get("client_nonce") or "").strip(),
|
||||
server_nonce=str(row.get("server_nonce") or "").strip(),
|
||||
hostname_conflict=conflict,
|
||||
alternate_hostname=alternate_hostname,
|
||||
conflict_requires_prompt=requires_prompt,
|
||||
fingerprint_match=fingerprint_match,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
"id": self.record_id,
|
||||
"approval_reference": self.reference,
|
||||
"status": self.status,
|
||||
"hostname_claimed": self.claimed_hostname,
|
||||
"ssl_key_fingerprint_claimed": self.claimed_fingerprint,
|
||||
"created_at": _isoformat(self.created_at),
|
||||
"updated_at": _isoformat(self.updated_at),
|
||||
"enrollment_code_id": self.enrollment_code_id,
|
||||
"guid": self.guid,
|
||||
"approved_by_user_id": self.approved_by_user_id,
|
||||
"approved_by_username": self.approved_by_username,
|
||||
"client_nonce": self.client_nonce,
|
||||
"server_nonce": self.server_nonce,
|
||||
"conflict_requires_prompt": self.conflict_requires_prompt,
|
||||
"fingerprint_match": self.fingerprint_match,
|
||||
}
|
||||
if self.hostname_conflict is not None:
|
||||
payload["hostname_conflict"] = self.hostname_conflict.to_dict()
|
||||
if self.alternate_hostname:
|
||||
payload["alternate_hostname"] = self.alternate_hostname
|
||||
return payload
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
"""Domain types for GitHub integrations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GitHubRepoRef:
|
||||
"""Identify a GitHub repository and branch."""
|
||||
|
||||
owner: str
|
||||
name: str
|
||||
branch: str
|
||||
|
||||
@property
|
||||
def full_name(self) -> str:
|
||||
return f"{self.owner}/{self.name}".strip("/")
|
||||
|
||||
@classmethod
|
||||
def parse(cls, owner_repo: str, branch: str) -> "GitHubRepoRef":
|
||||
owner_repo = (owner_repo or "").strip()
|
||||
if "/" not in owner_repo:
|
||||
raise ValueError("repo must be in the form owner/name")
|
||||
owner, repo = owner_repo.split("/", 1)
|
||||
return cls(owner=owner.strip(), name=repo.strip(), branch=(branch or "main").strip())
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RepoHeadSnapshot:
|
||||
"""Snapshot describing the current head of a repository branch."""
|
||||
|
||||
repository: GitHubRepoRef
|
||||
sha: Optional[str]
|
||||
cached: bool
|
||||
age_seconds: Optional[float]
|
||||
source: str
|
||||
error: Optional[str]
|
||||
|
||||
def to_dict(self) -> Dict[str, object]:
|
||||
return {
|
||||
"repo": self.repository.full_name,
|
||||
"branch": self.repository.branch,
|
||||
"sha": self.sha,
|
||||
"cached": self.cached,
|
||||
"age_seconds": self.age_seconds,
|
||||
"source": self.source,
|
||||
"error": self.error,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GitHubRateLimit:
|
||||
"""Subset of rate limit details returned by the GitHub API."""
|
||||
|
||||
limit: Optional[int]
|
||||
remaining: Optional[int]
|
||||
reset_epoch: Optional[int]
|
||||
used: Optional[int]
|
||||
|
||||
def to_dict(self) -> Dict[str, Optional[int]]:
|
||||
return {
|
||||
"limit": self.limit,
|
||||
"remaining": self.remaining,
|
||||
"reset": self.reset_epoch,
|
||||
"used": self.used,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GitHubTokenStatus:
|
||||
"""Describe the verification result for a GitHub access token."""
|
||||
|
||||
has_token: bool
|
||||
valid: bool
|
||||
status: str
|
||||
message: str
|
||||
rate_limit: Optional[GitHubRateLimit]
|
||||
error: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, object]:
|
||||
payload: Dict[str, object] = {
|
||||
"has_token": self.has_token,
|
||||
"valid": self.valid,
|
||||
"status": self.status,
|
||||
"message": self.message,
|
||||
"error": self.error,
|
||||
}
|
||||
if self.rate_limit is not None:
|
||||
payload["rate_limit"] = self.rate_limit.to_dict()
|
||||
else:
|
||||
payload["rate_limit"] = None
|
||||
return payload
|
||||
|
||||
|
||||
__all__ = [
|
||||
"GitHubRateLimit",
|
||||
"GitHubRepoRef",
|
||||
"GitHubTokenStatus",
|
||||
"RepoHeadSnapshot",
|
||||
]
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
"""Domain models for operator authentication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class OperatorAccount:
|
||||
"""Snapshot of an operator account stored in SQLite."""
|
||||
|
||||
username: str
|
||||
display_name: str
|
||||
password_sha512: str
|
||||
role: str
|
||||
last_login: int
|
||||
created_at: int
|
||||
updated_at: int
|
||||
mfa_enabled: bool
|
||||
mfa_secret: Optional[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class OperatorLoginSuccess:
|
||||
"""Successful login payload for the caller."""
|
||||
|
||||
username: str
|
||||
role: str
|
||||
token: str
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class OperatorMFAChallenge:
|
||||
"""Details describing an in-progress MFA challenge."""
|
||||
|
||||
username: str
|
||||
role: str
|
||||
stage: Literal["setup", "verify"]
|
||||
pending_token: str
|
||||
expires_at: int
|
||||
secret: Optional[str] = None
|
||||
otpauth_url: Optional[str] = None
|
||||
qr_image: Optional[str] = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OperatorAccount",
|
||||
"OperatorLoginSuccess",
|
||||
"OperatorMFAChallenge",
|
||||
]
|
||||
@@ -1,43 +0,0 @@
|
||||
"""Domain models for operator site management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
__all__ = ["SiteSummary", "SiteDeviceMapping"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SiteSummary:
|
||||
"""Representation of a site record including device counts."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
created_at: int
|
||||
device_count: int
|
||||
|
||||
def to_dict(self) -> Dict[str, object]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"created_at": self.created_at,
|
||||
"device_count": self.device_count,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SiteDeviceMapping:
|
||||
"""Mapping entry describing which site a device belongs to."""
|
||||
|
||||
hostname: str
|
||||
site_id: Optional[int]
|
||||
site_name: str
|
||||
|
||||
def to_dict(self) -> Dict[str, object]:
|
||||
return {
|
||||
"site_id": self.site_id,
|
||||
"site_name": self.site_name,
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
"""External system adapters for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .github.artifact_provider import GitHubArtifactProvider
|
||||
|
||||
__all__ = ["GitHubArtifactProvider"]
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Crypto integration helpers for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .keys import (
|
||||
base64_from_spki_der,
|
||||
fingerprint_from_base64_spki,
|
||||
fingerprint_from_spki_der,
|
||||
generate_ed25519_keypair,
|
||||
normalize_base64,
|
||||
private_key_to_pem,
|
||||
public_key_to_pem,
|
||||
spki_der_from_base64,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"base64_from_spki_der",
|
||||
"fingerprint_from_base64_spki",
|
||||
"fingerprint_from_spki_der",
|
||||
"generate_ed25519_keypair",
|
||||
"normalize_base64",
|
||||
"private_key_to_pem",
|
||||
"public_key_to_pem",
|
||||
"spki_der_from_base64",
|
||||
]
|
||||
@@ -1,70 +0,0 @@
|
||||
"""Key utilities mirrored from the legacy crypto helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
||||
|
||||
__all__ = [
|
||||
"base64_from_spki_der",
|
||||
"fingerprint_from_base64_spki",
|
||||
"fingerprint_from_spki_der",
|
||||
"generate_ed25519_keypair",
|
||||
"normalize_base64",
|
||||
"private_key_to_pem",
|
||||
"public_key_to_pem",
|
||||
"spki_der_from_base64",
|
||||
]
|
||||
|
||||
|
||||
def generate_ed25519_keypair() -> Tuple[ed25519.Ed25519PrivateKey, bytes]:
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
public_key = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
return private_key, public_key
|
||||
|
||||
|
||||
def normalize_base64(data: str) -> str:
|
||||
cleaned = re.sub(r"\s+", "", data or "")
|
||||
return cleaned.replace("-", "+").replace("_", "/")
|
||||
|
||||
|
||||
def spki_der_from_base64(spki_b64: str) -> bytes:
|
||||
return base64.b64decode(normalize_base64(spki_b64), validate=True)
|
||||
|
||||
|
||||
def base64_from_spki_der(spki_der: bytes) -> str:
|
||||
return base64.b64encode(spki_der).decode("ascii")
|
||||
|
||||
|
||||
def fingerprint_from_spki_der(spki_der: bytes) -> str:
|
||||
digest = hashlib.sha256(spki_der).hexdigest()
|
||||
return digest.lower()
|
||||
|
||||
|
||||
def fingerprint_from_base64_spki(spki_b64: str) -> str:
|
||||
return fingerprint_from_spki_der(spki_der_from_base64(spki_b64))
|
||||
|
||||
|
||||
def private_key_to_pem(private_key: ed25519.Ed25519PrivateKey) -> bytes:
|
||||
return private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
|
||||
|
||||
def public_key_to_pem(public_spki_der: bytes) -> bytes:
|
||||
public_key = load_der_public_key(public_spki_der)
|
||||
return public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
@@ -1,8 +0,0 @@
|
||||
"""GitHub integration surface for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .artifact_provider import GitHubArtifactProvider
|
||||
|
||||
__all__ = ["GitHubArtifactProvider"]
|
||||
|
||||
@@ -1,275 +0,0 @@
|
||||
"""GitHub REST API integration with caching support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from Data.Engine.domain.github import GitHubRepoRef, GitHubTokenStatus, RepoHeadSnapshot, GitHubRateLimit
|
||||
|
||||
try: # pragma: no cover - optional dependency guard
|
||||
import requests
|
||||
from requests import Response
|
||||
except Exception: # pragma: no cover - fallback when requests is unavailable
|
||||
requests = None # type: ignore[assignment]
|
||||
Response = object # type: ignore[misc,assignment]
|
||||
|
||||
__all__ = ["GitHubArtifactProvider"]
|
||||
|
||||
|
||||
class GitHubArtifactProvider:
|
||||
"""Resolve repository heads and token metadata from the GitHub API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cache_file: Path,
|
||||
default_repo: str,
|
||||
default_branch: str,
|
||||
refresh_interval: int,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._cache_file = cache_file
|
||||
self._default_repo = default_repo
|
||||
self._default_branch = default_branch
|
||||
self._refresh_interval = max(30, min(refresh_interval, 3600))
|
||||
self._log = logger or logging.getLogger("borealis.engine.integrations.github")
|
||||
self._token: Optional[str] = None
|
||||
self._cache_lock = threading.Lock()
|
||||
self._cache: Dict[str, Dict[str, float | str]] = {}
|
||||
self._worker: Optional[threading.Thread] = None
|
||||
self._hydrate_cache_from_disk()
|
||||
|
||||
def set_token(self, token: Optional[str]) -> None:
|
||||
self._token = (token or "").strip() or None
|
||||
|
||||
@property
|
||||
def default_repo(self) -> str:
|
||||
return self._default_repo
|
||||
|
||||
@property
|
||||
def default_branch(self) -> str:
|
||||
return self._default_branch
|
||||
|
||||
@property
|
||||
def refresh_interval(self) -> int:
|
||||
return self._refresh_interval
|
||||
|
||||
def fetch_repo_head(
|
||||
self,
|
||||
repo: GitHubRepoRef,
|
||||
*,
|
||||
ttl_seconds: int,
|
||||
force_refresh: bool = False,
|
||||
) -> RepoHeadSnapshot:
|
||||
key = f"{repo.full_name}:{repo.branch}"
|
||||
now = time.time()
|
||||
|
||||
cached_entry = None
|
||||
with self._cache_lock:
|
||||
cached_entry = self._cache.get(key, {}).copy()
|
||||
|
||||
cached_sha = (cached_entry.get("sha") if cached_entry else None) # type: ignore[assignment]
|
||||
cached_ts = cached_entry.get("timestamp") if cached_entry else None # type: ignore[assignment]
|
||||
cached_age = None
|
||||
if isinstance(cached_ts, (int, float)):
|
||||
cached_age = max(0.0, now - float(cached_ts))
|
||||
|
||||
ttl = max(30, min(ttl_seconds, 3600))
|
||||
if cached_sha and not force_refresh and cached_age is not None and cached_age < ttl:
|
||||
return RepoHeadSnapshot(
|
||||
repository=repo,
|
||||
sha=str(cached_sha),
|
||||
cached=True,
|
||||
age_seconds=cached_age,
|
||||
source="cache",
|
||||
error=None,
|
||||
)
|
||||
|
||||
if requests is None:
|
||||
return RepoHeadSnapshot(
|
||||
repository=repo,
|
||||
sha=str(cached_sha) if cached_sha else None,
|
||||
cached=bool(cached_sha),
|
||||
age_seconds=cached_age,
|
||||
source="unavailable",
|
||||
error="requests library not available",
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"User-Agent": "Borealis-Engine",
|
||||
}
|
||||
if self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
|
||||
url = f"https://api.github.com/repos/{repo.full_name}/branches/{repo.branch}"
|
||||
error: Optional[str] = None
|
||||
sha: Optional[str] = None
|
||||
|
||||
try:
|
||||
response: Response = requests.get(url, headers=headers, timeout=20)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
sha = (data.get("commit", {}).get("sha") or "").strip() # type: ignore[assignment]
|
||||
else:
|
||||
error = f"GitHub REST API repo head lookup failed: HTTP {response.status_code} {response.text[:200]}"
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
error = f"GitHub REST API repo head lookup raised: {exc}"
|
||||
|
||||
if sha:
|
||||
payload = {"sha": sha, "timestamp": now}
|
||||
with self._cache_lock:
|
||||
self._cache[key] = payload
|
||||
self._persist_cache()
|
||||
return RepoHeadSnapshot(
|
||||
repository=repo,
|
||||
sha=sha,
|
||||
cached=False,
|
||||
age_seconds=0.0,
|
||||
source="github",
|
||||
error=None,
|
||||
)
|
||||
|
||||
if error:
|
||||
self._log.warning("repo-head-lookup failure repo=%s branch=%s error=%s", repo.full_name, repo.branch, error)
|
||||
|
||||
return RepoHeadSnapshot(
|
||||
repository=repo,
|
||||
sha=str(cached_sha) if cached_sha else None,
|
||||
cached=bool(cached_sha),
|
||||
age_seconds=cached_age,
|
||||
source="cache-stale" if cached_sha else "github",
|
||||
error=error or ("using cached value" if cached_sha else "unable to resolve repository head"),
|
||||
)
|
||||
|
||||
def refresh_default_repo_head(self, *, force: bool = False) -> RepoHeadSnapshot:
|
||||
repo = GitHubRepoRef.parse(self._default_repo, self._default_branch)
|
||||
return self.fetch_repo_head(repo, ttl_seconds=self._refresh_interval, force_refresh=force)
|
||||
|
||||
def verify_token(self, token: Optional[str]) -> GitHubTokenStatus:
|
||||
token = (token or "").strip()
|
||||
if not token:
|
||||
return GitHubTokenStatus(
|
||||
has_token=False,
|
||||
valid=False,
|
||||
status="missing",
|
||||
message="API Token Not Configured",
|
||||
rate_limit=None,
|
||||
error=None,
|
||||
)
|
||||
|
||||
if requests is None:
|
||||
return GitHubTokenStatus(
|
||||
has_token=True,
|
||||
valid=False,
|
||||
status="unknown",
|
||||
message="requests library not available",
|
||||
rate_limit=None,
|
||||
error="requests library not available",
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"Authorization": f"Bearer {token}",
|
||||
"User-Agent": "Borealis-Engine",
|
||||
}
|
||||
try:
|
||||
response: Response = requests.get("https://api.github.com/rate_limit", headers=headers, timeout=10)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
message = f"GitHub token verification raised: {exc}"
|
||||
self._log.warning("github-token-verify error=%s", message)
|
||||
return GitHubTokenStatus(
|
||||
has_token=True,
|
||||
valid=False,
|
||||
status="error",
|
||||
message="API Token Invalid",
|
||||
rate_limit=None,
|
||||
error=message,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
message = f"GitHub API error (HTTP {response.status_code})"
|
||||
self._log.warning("github-token-verify http_status=%s", response.status_code)
|
||||
return GitHubTokenStatus(
|
||||
has_token=True,
|
||||
valid=False,
|
||||
status="error",
|
||||
message="API Token Invalid",
|
||||
rate_limit=None,
|
||||
error=message,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
core = (data.get("resources", {}).get("core", {}) if isinstance(data, dict) else {})
|
||||
rate_limit = GitHubRateLimit(
|
||||
limit=_safe_int(core.get("limit")),
|
||||
remaining=_safe_int(core.get("remaining")),
|
||||
reset_epoch=_safe_int(core.get("reset")),
|
||||
used=_safe_int(core.get("used")),
|
||||
)
|
||||
|
||||
message = "API Token Valid" if rate_limit.remaining is not None else "API Token Verified"
|
||||
return GitHubTokenStatus(
|
||||
has_token=True,
|
||||
valid=True,
|
||||
status="valid",
|
||||
message=message,
|
||||
rate_limit=rate_limit,
|
||||
error=None,
|
||||
)
|
||||
|
||||
def start_background_refresh(self) -> None:
|
||||
if self._worker and self._worker.is_alive(): # pragma: no cover - guard
|
||||
return
|
||||
|
||||
def _loop() -> None:
|
||||
interval = max(30, self._refresh_interval)
|
||||
while True:
|
||||
try:
|
||||
self.refresh_default_repo_head(force=True)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("default-repo-refresh failure: %s", exc)
|
||||
time.sleep(interval)
|
||||
|
||||
self._worker = threading.Thread(target=_loop, name="github-repo-refresh", daemon=True)
|
||||
self._worker.start()
|
||||
|
||||
def _hydrate_cache_from_disk(self) -> None:
|
||||
path = self._cache_file
|
||||
try:
|
||||
if not path.exists():
|
||||
return
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, dict):
|
||||
with self._cache_lock:
|
||||
self._cache = {
|
||||
key: value
|
||||
for key, value in data.items()
|
||||
if isinstance(value, dict) and "sha" in value and "timestamp" in value
|
||||
}
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("failed to load repo cache: %s", exc)
|
||||
|
||||
def _persist_cache(self) -> None:
|
||||
path = self._cache_file
|
||||
try:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = json.dumps(self._cache, ensure_ascii=False)
|
||||
tmp = path.with_suffix(".tmp")
|
||||
tmp.write_text(payload, encoding="utf-8")
|
||||
tmp.replace(path)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("failed to persist repo cache: %s", exc)
|
||||
|
||||
|
||||
def _safe_int(value: object) -> Optional[int]:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
"""Interface adapters (HTTP, WebSocket, etc.) for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .http import register_http_interfaces
|
||||
from .ws import create_socket_server, register_ws_interfaces
|
||||
|
||||
__all__ = [
|
||||
"register_http_interfaces",
|
||||
"create_socket_server",
|
||||
"register_ws_interfaces",
|
||||
]
|
||||
@@ -1,75 +0,0 @@
|
||||
"""Compatibility helpers for running Socket.IO under eventlet."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ssl
|
||||
from typing import Any
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import eventlet # type: ignore
|
||||
except Exception: # pragma: no cover - optional dependency
|
||||
eventlet = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def _quiet_close(connection: Any) -> None:
|
||||
try:
|
||||
if hasattr(connection, "close"):
|
||||
connection.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _close_connection_quietly(protocol: Any) -> None:
|
||||
try:
|
||||
setattr(protocol, "close_connection", True)
|
||||
except Exception:
|
||||
pass
|
||||
conn = getattr(protocol, "socket", None) or getattr(protocol, "connection", None)
|
||||
if conn is not None:
|
||||
_quiet_close(conn)
|
||||
|
||||
|
||||
def apply_eventlet_patches() -> None:
|
||||
"""Apply Borealis-specific eventlet tweaks when the dependency is available."""
|
||||
|
||||
if eventlet is None: # pragma: no cover - guard for environments without eventlet
|
||||
return
|
||||
|
||||
eventlet.monkey_patch(thread=False)
|
||||
|
||||
try:
|
||||
from eventlet.wsgi import HttpProtocol # type: ignore
|
||||
except Exception: # pragma: no cover - import guard
|
||||
return
|
||||
|
||||
original = HttpProtocol.handle_one_request # type: ignore[attr-defined]
|
||||
|
||||
def _handle_one_request(self: Any, *args: Any, **kwargs: Any) -> Any: # type: ignore[override]
|
||||
try:
|
||||
return original(self, *args, **kwargs)
|
||||
except ssl.SSLError as exc: # type: ignore[arg-type]
|
||||
reason = getattr(exc, "reason", "") or ""
|
||||
message = " ".join(str(arg) for arg in exc.args if arg)
|
||||
lower_reason = str(reason).lower()
|
||||
lower_message = message.lower()
|
||||
if (
|
||||
"http_request" in lower_message
|
||||
or lower_reason == "http request"
|
||||
or "unknown ca" in lower_message
|
||||
or lower_reason == "unknown ca"
|
||||
or "unknown_ca" in lower_message
|
||||
):
|
||||
_close_connection_quietly(self)
|
||||
return None
|
||||
raise
|
||||
except ssl.SSLEOFError:
|
||||
_close_connection_quietly(self)
|
||||
return None
|
||||
except ConnectionAbortedError:
|
||||
_close_connection_quietly(self)
|
||||
return None
|
||||
|
||||
HttpProtocol.handle_one_request = _handle_one_request # type: ignore[assignment]
|
||||
|
||||
|
||||
__all__ = ["apply_eventlet_patches"]
|
||||
@@ -1,56 +0,0 @@
|
||||
"""HTTP interface registration for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
from . import (
|
||||
admin,
|
||||
agent,
|
||||
agents,
|
||||
auth,
|
||||
enrollment,
|
||||
github,
|
||||
health,
|
||||
job_management,
|
||||
tokens,
|
||||
users,
|
||||
sites,
|
||||
devices,
|
||||
credentials,
|
||||
assemblies,
|
||||
server_info,
|
||||
)
|
||||
|
||||
_REGISTRARS = (
|
||||
health.register,
|
||||
agent.register,
|
||||
agents.register,
|
||||
enrollment.register,
|
||||
tokens.register,
|
||||
job_management.register,
|
||||
github.register,
|
||||
auth.register,
|
||||
admin.register,
|
||||
users.register,
|
||||
sites.register,
|
||||
devices.register,
|
||||
credentials.register,
|
||||
assemblies.register,
|
||||
server_info.register,
|
||||
)
|
||||
|
||||
|
||||
def register_http_interfaces(app: Flask, services: EngineServiceContainer) -> None:
|
||||
"""Attach HTTP blueprints to *app*.
|
||||
|
||||
The implementation is intentionally minimal for the initial scaffolding.
|
||||
"""
|
||||
|
||||
for registrar in _REGISTRARS:
|
||||
registrar(app, services)
|
||||
|
||||
|
||||
__all__ = ["register_http_interfaces"]
|
||||
@@ -1,173 +0,0 @@
|
||||
"""Administrative HTTP endpoints for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request, session
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
|
||||
blueprint = Blueprint("engine_admin", __name__, url_prefix="/api/admin")
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach administrative routes to *app*."""
|
||||
|
||||
if "engine_admin" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
def _services() -> EngineServiceContainer:
|
||||
services = current_app.extensions.get("engine_services")
|
||||
if services is None: # pragma: no cover - defensive
|
||||
raise RuntimeError("engine services not initialized")
|
||||
return services
|
||||
|
||||
|
||||
def _admin_service():
|
||||
return _services().enrollment_admin_service
|
||||
|
||||
|
||||
def _require_admin():
|
||||
username = session.get("username")
|
||||
role = (session.get("role") or "").strip().lower()
|
||||
if not isinstance(username, str) or not username:
|
||||
return jsonify({"error": "not_authenticated"}), 401
|
||||
if role != "admin":
|
||||
return jsonify({"error": "forbidden"}), 403
|
||||
return None
|
||||
|
||||
|
||||
@blueprint.route("/enrollment-codes", methods=["GET"])
|
||||
def list_enrollment_codes() -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
|
||||
status = request.args.get("status")
|
||||
records = _admin_service().list_install_codes(status=status)
|
||||
return jsonify({"codes": [record.to_dict() for record in records]})
|
||||
|
||||
|
||||
@blueprint.route("/enrollment-codes", methods=["POST"])
|
||||
def create_enrollment_code() -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
|
||||
ttl_value = payload.get("ttl_hours")
|
||||
if ttl_value is None:
|
||||
ttl_value = payload.get("ttl") or 1
|
||||
try:
|
||||
ttl_hours = int(ttl_value)
|
||||
except (TypeError, ValueError):
|
||||
ttl_hours = 1
|
||||
|
||||
max_uses_value = payload.get("max_uses")
|
||||
if max_uses_value is None:
|
||||
max_uses_value = payload.get("allowed_uses", 2)
|
||||
try:
|
||||
max_uses = int(max_uses_value)
|
||||
except (TypeError, ValueError):
|
||||
max_uses = 2
|
||||
|
||||
creator = session.get("username") if isinstance(session.get("username"), str) else None
|
||||
|
||||
try:
|
||||
record = _admin_service().create_install_code(
|
||||
ttl_hours=ttl_hours,
|
||||
max_uses=max_uses,
|
||||
created_by=creator,
|
||||
)
|
||||
except ValueError as exc:
|
||||
if str(exc) == "invalid_ttl":
|
||||
return jsonify({"error": "invalid_ttl"}), 400
|
||||
raise
|
||||
|
||||
response = jsonify(record.to_dict())
|
||||
response.status_code = 201
|
||||
return response
|
||||
|
||||
|
||||
@blueprint.route("/enrollment-codes/<code_id>", methods=["DELETE"])
|
||||
def delete_enrollment_code(code_id: str) -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
|
||||
if not _admin_service().delete_install_code(code_id):
|
||||
return jsonify({"error": "not_found"}), 404
|
||||
return jsonify({"status": "deleted"})
|
||||
|
||||
|
||||
@blueprint.route("/device-approvals", methods=["GET"])
|
||||
def list_device_approvals() -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
|
||||
status = request.args.get("status")
|
||||
records = _admin_service().list_device_approvals(status=status)
|
||||
return jsonify({"approvals": [record.to_dict() for record in records]})
|
||||
|
||||
|
||||
@blueprint.route("/device-approvals/<approval_id>/approve", methods=["POST"])
|
||||
def approve_device_approval(approval_id: str) -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
guid = payload.get("guid")
|
||||
resolution_raw = payload.get("conflict_resolution") or payload.get("resolution")
|
||||
resolution = resolution_raw.strip().lower() if isinstance(resolution_raw, str) else None
|
||||
|
||||
actor = session.get("username") if isinstance(session.get("username"), str) else None
|
||||
|
||||
try:
|
||||
result = _admin_service().approve_device_approval(
|
||||
approval_id,
|
||||
actor=actor,
|
||||
guid=guid,
|
||||
conflict_resolution=resolution,
|
||||
)
|
||||
except LookupError:
|
||||
return jsonify({"error": "not_found"}), 404
|
||||
except ValueError as exc:
|
||||
code = str(exc)
|
||||
if code == "approval_not_pending":
|
||||
return jsonify({"error": "approval_not_pending"}), 409
|
||||
if code == "conflict_resolution_required":
|
||||
return jsonify({"error": "conflict_resolution_required"}), 409
|
||||
if code == "invalid_guid":
|
||||
return jsonify({"error": "invalid_guid"}), 400
|
||||
raise
|
||||
|
||||
response = jsonify(result.to_dict())
|
||||
response.status_code = 200
|
||||
return response
|
||||
|
||||
|
||||
@blueprint.route("/device-approvals/<approval_id>/deny", methods=["POST"])
|
||||
def deny_device_approval(approval_id: str) -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
|
||||
actor = session.get("username") if isinstance(session.get("username"), str) else None
|
||||
|
||||
try:
|
||||
result = _admin_service().deny_device_approval(approval_id, actor=actor)
|
||||
except LookupError:
|
||||
return jsonify({"error": "not_found"}), 404
|
||||
except ValueError as exc:
|
||||
if str(exc) == "approval_not_pending":
|
||||
return jsonify({"error": "approval_not_pending"}), 409
|
||||
raise
|
||||
|
||||
return jsonify(result.to_dict())
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
@@ -1,148 +0,0 @@
|
||||
"""Agent REST endpoints for device communication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Optional, TypeVar, cast
|
||||
|
||||
from flask import Blueprint, Flask, current_app, g, jsonify, request
|
||||
|
||||
from Data.Engine.builders.device_auth import DeviceAuthRequestBuilder
|
||||
from Data.Engine.domain.device_auth import DeviceAuthContext, DeviceAuthFailure
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
from Data.Engine.services.devices.device_inventory_service import (
|
||||
DeviceDetailsError,
|
||||
DeviceHeartbeatError,
|
||||
)
|
||||
|
||||
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||
|
||||
blueprint = Blueprint("engine_agent", __name__)
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def _services() -> EngineServiceContainer:
|
||||
return cast(EngineServiceContainer, current_app.extensions["engine_services"])
|
||||
|
||||
|
||||
def require_device_auth(func: F) -> F:
|
||||
@wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any):
|
||||
services = _services()
|
||||
builder = (
|
||||
DeviceAuthRequestBuilder()
|
||||
.with_authorization(request.headers.get("Authorization"))
|
||||
.with_http_method(request.method)
|
||||
.with_htu(request.url)
|
||||
.with_service_context(request.headers.get(AGENT_CONTEXT_HEADER))
|
||||
.with_dpop_proof(request.headers.get("DPoP"))
|
||||
)
|
||||
try:
|
||||
auth_request = builder.build()
|
||||
context = services.device_auth.authenticate(auth_request, path=request.path)
|
||||
except DeviceAuthFailure as exc:
|
||||
payload = exc.to_dict()
|
||||
response = jsonify(payload)
|
||||
if exc.retry_after is not None:
|
||||
response.headers["Retry-After"] = str(int(math.ceil(exc.retry_after)))
|
||||
return response, exc.http_status
|
||||
|
||||
g.device_auth = context
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
g.pop("device_auth", None)
|
||||
|
||||
return cast(F, wrapper)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
if "engine_agent" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
|
||||
@require_device_auth
|
||||
def heartbeat() -> Any:
|
||||
services = _services()
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
context = cast(DeviceAuthContext, g.device_auth)
|
||||
|
||||
try:
|
||||
services.device_inventory.record_heartbeat(context=context, payload=payload)
|
||||
except DeviceHeartbeatError as exc:
|
||||
error_payload = {"error": exc.code}
|
||||
if exc.code == "device_not_registered":
|
||||
return jsonify(error_payload), 404
|
||||
if exc.code == "storage_conflict":
|
||||
return jsonify(error_payload), 409
|
||||
current_app.logger.exception(
|
||||
"device-heartbeat-error guid=%s code=%s", context.identity.guid.value, exc.code
|
||||
)
|
||||
return jsonify(error_payload), 500
|
||||
|
||||
return jsonify({"status": "ok", "poll_after_ms": 15000})
|
||||
|
||||
|
||||
@blueprint.route("/api/agent/script/request", methods=["POST"])
|
||||
@require_device_auth
|
||||
def script_request() -> Any:
|
||||
services = _services()
|
||||
context = cast(DeviceAuthContext, g.device_auth)
|
||||
|
||||
signing_key: Optional[str] = None
|
||||
signer = services.script_signer
|
||||
if signer is not None:
|
||||
try:
|
||||
signing_key = signer.public_base64_spki()
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
current_app.logger.warning("script-signer-unavailable: %s", exc)
|
||||
|
||||
status = "quarantined" if context.is_quarantined else "idle"
|
||||
poll_after = 60000 if context.is_quarantined else 30000
|
||||
|
||||
response = {
|
||||
"status": status,
|
||||
"poll_after_ms": poll_after,
|
||||
"sig_alg": "ed25519",
|
||||
}
|
||||
if signing_key:
|
||||
response["signing_key"] = signing_key
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
@blueprint.route("/api/agent/details", methods=["POST"])
|
||||
@require_device_auth
|
||||
def save_details() -> Any:
|
||||
services = _services()
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
context = cast(DeviceAuthContext, g.device_auth)
|
||||
|
||||
try:
|
||||
services.device_inventory.save_agent_details(context=context, payload=payload)
|
||||
except DeviceDetailsError as exc:
|
||||
error_payload = {"error": exc.code}
|
||||
if exc.code == "invalid_payload":
|
||||
return jsonify(error_payload), 400
|
||||
if exc.code in {"fingerprint_mismatch", "guid_mismatch"}:
|
||||
return jsonify(error_payload), 403
|
||||
if exc.code == "device_not_registered":
|
||||
return jsonify(error_payload), 404
|
||||
current_app.logger.exception(
|
||||
"device-details-error guid=%s code=%s", context.identity.guid.value, exc.code
|
||||
)
|
||||
return jsonify(error_payload), 500
|
||||
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
__all__ = [
|
||||
"register",
|
||||
"blueprint",
|
||||
"heartbeat",
|
||||
"script_request",
|
||||
"save_details",
|
||||
"require_device_auth",
|
||||
]
|
||||
@@ -1,23 +0,0 @@
|
||||
"""Agent HTTP interface placeholders for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
|
||||
blueprint = Blueprint("engine_agents", __name__, url_prefix="/api/agents")
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach agent management routes to *app*.
|
||||
|
||||
Implementation will be populated as services migrate from the legacy server.
|
||||
"""
|
||||
|
||||
if "engine_agents" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
@@ -1,182 +0,0 @@
|
||||
"""HTTP endpoints for assembly management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
blueprint = Blueprint("engine_assemblies", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
if "engine_assemblies" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
def _services() -> EngineServiceContainer:
|
||||
services = current_app.extensions.get("engine_services")
|
||||
if services is None: # pragma: no cover - defensive
|
||||
raise RuntimeError("engine services not initialized")
|
||||
return services
|
||||
|
||||
|
||||
def _assembly_service():
|
||||
return _services().assembly_service
|
||||
|
||||
|
||||
def _value_error_response(exc: ValueError):
|
||||
code = str(exc)
|
||||
if code == "invalid_island":
|
||||
return jsonify({"error": "invalid island"}), 400
|
||||
if code == "path_required":
|
||||
return jsonify({"error": "path required"}), 400
|
||||
if code == "invalid_kind":
|
||||
return jsonify({"error": "invalid kind"}), 400
|
||||
if code == "invalid_destination":
|
||||
return jsonify({"error": "invalid destination"}), 400
|
||||
if code == "invalid_path":
|
||||
return jsonify({"error": "invalid path"}), 400
|
||||
if code == "cannot_delete_root":
|
||||
return jsonify({"error": "cannot delete root"}), 400
|
||||
return jsonify({"error": code or "invalid request"}), 400
|
||||
|
||||
|
||||
def _not_found_response(exc: FileNotFoundError):
|
||||
code = str(exc)
|
||||
if code == "file_not_found":
|
||||
return jsonify({"error": "file not found"}), 404
|
||||
if code == "folder_not_found":
|
||||
return jsonify({"error": "folder not found"}), 404
|
||||
return jsonify({"error": "not found"}), 404
|
||||
|
||||
|
||||
@blueprint.route("/api/assembly/list", methods=["GET"])
|
||||
def list_assemblies() -> object:
|
||||
island = (request.args.get("island") or "").strip()
|
||||
try:
|
||||
listing = _assembly_service().list_items(island)
|
||||
except ValueError as exc:
|
||||
return _value_error_response(exc)
|
||||
return jsonify(listing.to_dict())
|
||||
|
||||
|
||||
@blueprint.route("/api/assembly/load", methods=["GET"])
|
||||
def load_assembly() -> object:
|
||||
island = (request.args.get("island") or "").strip()
|
||||
rel_path = (request.args.get("path") or "").strip()
|
||||
try:
|
||||
result = _assembly_service().load_item(island, rel_path)
|
||||
except ValueError as exc:
|
||||
return _value_error_response(exc)
|
||||
except FileNotFoundError as exc:
|
||||
return _not_found_response(exc)
|
||||
return jsonify(result.to_dict())
|
||||
|
||||
|
||||
@blueprint.route("/api/assembly/create", methods=["POST"])
|
||||
def create_assembly() -> object:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
island = (payload.get("island") or "").strip()
|
||||
kind = (payload.get("kind") or "").strip().lower()
|
||||
rel_path = (payload.get("path") or "").strip()
|
||||
content = payload.get("content")
|
||||
item_type = payload.get("type")
|
||||
try:
|
||||
result = _assembly_service().create_item(
|
||||
island,
|
||||
kind=kind,
|
||||
rel_path=rel_path,
|
||||
content=content,
|
||||
item_type=item_type if isinstance(item_type, str) else None,
|
||||
)
|
||||
except ValueError as exc:
|
||||
return _value_error_response(exc)
|
||||
return jsonify(result.to_dict())
|
||||
|
||||
|
||||
@blueprint.route("/api/assembly/edit", methods=["POST"])
|
||||
def edit_assembly() -> object:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
island = (payload.get("island") or "").strip()
|
||||
rel_path = (payload.get("path") or "").strip()
|
||||
content = payload.get("content")
|
||||
item_type = payload.get("type")
|
||||
try:
|
||||
result = _assembly_service().edit_item(
|
||||
island,
|
||||
rel_path=rel_path,
|
||||
content=content,
|
||||
item_type=item_type if isinstance(item_type, str) else None,
|
||||
)
|
||||
except ValueError as exc:
|
||||
return _value_error_response(exc)
|
||||
except FileNotFoundError as exc:
|
||||
return _not_found_response(exc)
|
||||
return jsonify(result.to_dict())
|
||||
|
||||
|
||||
@blueprint.route("/api/assembly/rename", methods=["POST"])
|
||||
def rename_assembly() -> object:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
island = (payload.get("island") or "").strip()
|
||||
kind = (payload.get("kind") or "").strip().lower()
|
||||
rel_path = (payload.get("path") or "").strip()
|
||||
new_name = (payload.get("new_name") or "").strip()
|
||||
item_type = payload.get("type")
|
||||
try:
|
||||
result = _assembly_service().rename_item(
|
||||
island,
|
||||
kind=kind,
|
||||
rel_path=rel_path,
|
||||
new_name=new_name,
|
||||
item_type=item_type if isinstance(item_type, str) else None,
|
||||
)
|
||||
except ValueError as exc:
|
||||
return _value_error_response(exc)
|
||||
except FileNotFoundError as exc:
|
||||
return _not_found_response(exc)
|
||||
return jsonify(result.to_dict())
|
||||
|
||||
|
||||
@blueprint.route("/api/assembly/move", methods=["POST"])
|
||||
def move_assembly() -> object:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
island = (payload.get("island") or "").strip()
|
||||
rel_path = (payload.get("path") or "").strip()
|
||||
new_path = (payload.get("new_path") or "").strip()
|
||||
kind = (payload.get("kind") or "").strip().lower()
|
||||
try:
|
||||
result = _assembly_service().move_item(
|
||||
island,
|
||||
rel_path=rel_path,
|
||||
new_path=new_path,
|
||||
kind=kind,
|
||||
)
|
||||
except ValueError as exc:
|
||||
return _value_error_response(exc)
|
||||
except FileNotFoundError as exc:
|
||||
return _not_found_response(exc)
|
||||
return jsonify(result.to_dict())
|
||||
|
||||
|
||||
@blueprint.route("/api/assembly/delete", methods=["POST"])
|
||||
def delete_assembly() -> object:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
island = (payload.get("island") or "").strip()
|
||||
rel_path = (payload.get("path") or "").strip()
|
||||
kind = (payload.get("kind") or "").strip().lower()
|
||||
try:
|
||||
result = _assembly_service().delete_item(
|
||||
island,
|
||||
rel_path=rel_path,
|
||||
kind=kind,
|
||||
)
|
||||
except ValueError as exc:
|
||||
return _value_error_response(exc)
|
||||
except FileNotFoundError as exc:
|
||||
return _not_found_response(exc)
|
||||
return jsonify(result.to_dict())
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
@@ -1,195 +0,0 @@
|
||||
"""Operator authentication HTTP endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request, session
|
||||
|
||||
from Data.Engine.builders import build_login_request, build_mfa_request
|
||||
from Data.Engine.domain import OperatorLoginSuccess, OperatorMFAChallenge
|
||||
from Data.Engine.services.auth import (
|
||||
InvalidCredentialsError,
|
||||
InvalidMFACodeError,
|
||||
MFAUnavailableError,
|
||||
MFASessionError,
|
||||
OperatorAuthService,
|
||||
)
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
|
||||
def _service(container: EngineServiceContainer) -> OperatorAuthService:
|
||||
return container.operator_auth_service
|
||||
|
||||
|
||||
def register(app: Flask, services: EngineServiceContainer) -> None:
|
||||
bp = Blueprint("auth", __name__)
|
||||
|
||||
@bp.route("/api/auth/login", methods=["POST"])
|
||||
def login() -> Any:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
try:
|
||||
login_request = build_login_request(payload)
|
||||
except ValueError as exc:
|
||||
return jsonify({"error": str(exc)}), 400
|
||||
|
||||
service = _service(services)
|
||||
|
||||
try:
|
||||
result = service.authenticate(login_request)
|
||||
except InvalidCredentialsError:
|
||||
return jsonify({"error": "invalid username or password"}), 401
|
||||
except MFAUnavailableError as exc:
|
||||
current_app.logger.error("mfa unavailable: %s", exc)
|
||||
return jsonify({"error": str(exc)}), 500
|
||||
|
||||
session.pop("username", None)
|
||||
session.pop("role", None)
|
||||
|
||||
if isinstance(result, OperatorLoginSuccess):
|
||||
session.pop("mfa_pending", None)
|
||||
session["username"] = result.username
|
||||
session["role"] = result.role or "User"
|
||||
response = jsonify(
|
||||
{"status": "ok", "username": result.username, "role": result.role, "token": result.token}
|
||||
)
|
||||
_set_auth_cookie(response, result.token)
|
||||
return response
|
||||
|
||||
challenge = result
|
||||
session["mfa_pending"] = {
|
||||
"username": challenge.username,
|
||||
"role": challenge.role,
|
||||
"stage": challenge.stage,
|
||||
"token": challenge.pending_token,
|
||||
"expires": challenge.expires_at,
|
||||
"secret": challenge.secret,
|
||||
}
|
||||
session.modified = True
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"status": "mfa_required",
|
||||
"stage": challenge.stage,
|
||||
"pending_token": challenge.pending_token,
|
||||
"username": challenge.username,
|
||||
"role": challenge.role,
|
||||
}
|
||||
if challenge.stage == "setup":
|
||||
if challenge.secret:
|
||||
payload["secret"] = challenge.secret
|
||||
if challenge.otpauth_url:
|
||||
payload["otpauth_url"] = challenge.otpauth_url
|
||||
if challenge.qr_image:
|
||||
payload["qr_image"] = challenge.qr_image
|
||||
return jsonify(payload)
|
||||
|
||||
@bp.route("/api/auth/logout", methods=["POST"])
|
||||
def logout() -> Any:
|
||||
session.clear()
|
||||
response = jsonify({"status": "ok"})
|
||||
_set_auth_cookie(response, "", expires=0)
|
||||
return response
|
||||
|
||||
@bp.route("/api/auth/me", methods=["GET"])
|
||||
def me() -> Any:
|
||||
service = _service(services)
|
||||
|
||||
account = None
|
||||
username = session.get("username")
|
||||
if isinstance(username, str) and username:
|
||||
account = service.fetch_account(username)
|
||||
|
||||
if account is None:
|
||||
token = request.cookies.get("borealis_auth", "")
|
||||
if not token:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.lower().startswith("bearer "):
|
||||
token = auth_header.split(None, 1)[1]
|
||||
account = service.resolve_token(token)
|
||||
if account is not None:
|
||||
session["username"] = account.username
|
||||
session["role"] = account.role or "User"
|
||||
|
||||
if account is None:
|
||||
return jsonify({"error": "not_authenticated"}), 401
|
||||
|
||||
payload = {
|
||||
"username": account.username,
|
||||
"display_name": account.display_name or account.username,
|
||||
"role": account.role,
|
||||
}
|
||||
return jsonify(payload)
|
||||
|
||||
@bp.route("/api/auth/mfa/verify", methods=["POST"])
|
||||
def verify_mfa() -> Any:
|
||||
pending = session.get("mfa_pending")
|
||||
if not isinstance(pending, dict):
|
||||
return jsonify({"error": "mfa_pending"}), 401
|
||||
|
||||
try:
|
||||
request_payload = build_mfa_request(request.get_json(silent=True) or {})
|
||||
except ValueError as exc:
|
||||
return jsonify({"error": str(exc)}), 400
|
||||
|
||||
challenge = OperatorMFAChallenge(
|
||||
username=str(pending.get("username") or ""),
|
||||
role=str(pending.get("role") or "User"),
|
||||
stage=str(pending.get("stage") or "verify"),
|
||||
pending_token=str(pending.get("token") or ""),
|
||||
expires_at=int(pending.get("expires") or 0),
|
||||
secret=str(pending.get("secret") or "") or None,
|
||||
)
|
||||
|
||||
service = _service(services)
|
||||
|
||||
try:
|
||||
result = service.verify_mfa(challenge, request_payload)
|
||||
except MFASessionError as exc:
|
||||
error_key = str(exc)
|
||||
status = 401 if error_key != "mfa_not_configured" else 403
|
||||
if error_key not in {"expired", "invalid_session", "mfa_not_configured"}:
|
||||
error_key = "invalid_session"
|
||||
session.pop("mfa_pending", None)
|
||||
return jsonify({"error": error_key}), status
|
||||
except InvalidMFACodeError as exc:
|
||||
return jsonify({"error": str(exc) or "invalid_code"}), 401
|
||||
except MFAUnavailableError as exc:
|
||||
current_app.logger.error("mfa unavailable: %s", exc)
|
||||
return jsonify({"error": str(exc)}), 500
|
||||
except InvalidCredentialsError:
|
||||
session.pop("mfa_pending", None)
|
||||
return jsonify({"error": "invalid username or password"}), 401
|
||||
|
||||
session.pop("mfa_pending", None)
|
||||
session["username"] = result.username
|
||||
session["role"] = result.role or "User"
|
||||
payload = {
|
||||
"status": "ok",
|
||||
"username": result.username,
|
||||
"role": result.role,
|
||||
"token": result.token,
|
||||
}
|
||||
response = jsonify(payload)
|
||||
_set_auth_cookie(response, result.token)
|
||||
return response
|
||||
|
||||
app.register_blueprint(bp)
|
||||
|
||||
|
||||
def _set_auth_cookie(response, value: str, *, expires: int | None = None) -> None:
|
||||
same_site = current_app.config.get("SESSION_COOKIE_SAMESITE", "Lax")
|
||||
secure = bool(current_app.config.get("SESSION_COOKIE_SECURE", False))
|
||||
domain = current_app.config.get("SESSION_COOKIE_DOMAIN", None)
|
||||
response.set_cookie(
|
||||
"borealis_auth",
|
||||
value,
|
||||
httponly=False,
|
||||
samesite=same_site,
|
||||
secure=secure,
|
||||
domain=domain,
|
||||
path="/",
|
||||
expires=expires,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["register"]
|
||||
@@ -1,70 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request, session
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
blueprint = Blueprint("engine_credentials", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
if "engine_credentials" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
def _services() -> EngineServiceContainer:
|
||||
services = current_app.extensions.get("engine_services")
|
||||
if services is None: # pragma: no cover - defensive
|
||||
raise RuntimeError("engine services not initialized")
|
||||
return services
|
||||
|
||||
|
||||
def _credentials_service():
|
||||
return _services().credential_service
|
||||
|
||||
|
||||
def _require_admin():
|
||||
username = session.get("username")
|
||||
role = (session.get("role") or "").strip().lower()
|
||||
if not isinstance(username, str) or not username:
|
||||
return jsonify({"error": "not_authenticated"}), 401
|
||||
if role != "admin":
|
||||
return jsonify({"error": "forbidden"}), 403
|
||||
return None
|
||||
|
||||
|
||||
@blueprint.route("/api/credentials", methods=["GET"])
|
||||
def list_credentials() -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
|
||||
site_id_param = request.args.get("site_id")
|
||||
connection_type = (request.args.get("connection_type") or "").strip() or None
|
||||
try:
|
||||
site_id = int(site_id_param) if site_id_param not in (None, "") else None
|
||||
except (TypeError, ValueError):
|
||||
site_id = None
|
||||
|
||||
records = _credentials_service().list_credentials(
|
||||
site_id=site_id,
|
||||
connection_type=connection_type,
|
||||
)
|
||||
return jsonify({"credentials": records})
|
||||
|
||||
|
||||
@blueprint.route("/api/credentials", methods=["POST"])
|
||||
def create_credential() -> object: # pragma: no cover - placeholder
|
||||
return jsonify({"error": "not implemented"}), 501
|
||||
|
||||
|
||||
@blueprint.route("/api/credentials/<int:credential_id>", methods=["GET", "PUT", "DELETE"])
|
||||
def credential_detail(credential_id: int) -> object: # pragma: no cover - placeholder
|
||||
if request.method == "GET":
|
||||
return jsonify({"error": "not implemented"}), 501
|
||||
if request.method == "DELETE":
|
||||
return jsonify({"error": "not implemented"}), 501
|
||||
return jsonify({"error": "not implemented"}), 501
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
@@ -1,325 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import ip_address
|
||||
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request, session
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
from Data.Engine.services.devices import DeviceDescriptionError, RemoteDeviceError
|
||||
|
||||
blueprint = Blueprint("engine_devices", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
if "engine_devices" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
def _services() -> EngineServiceContainer:
|
||||
services = current_app.extensions.get("engine_services")
|
||||
if services is None: # pragma: no cover - defensive
|
||||
raise RuntimeError("engine services not initialized")
|
||||
return services
|
||||
|
||||
|
||||
def _inventory():
|
||||
return _services().device_inventory
|
||||
|
||||
|
||||
def _views():
|
||||
return _services().device_view_service
|
||||
|
||||
|
||||
def _require_admin():
|
||||
username = session.get("username")
|
||||
role = (session.get("role") or "").strip().lower()
|
||||
if not isinstance(username, str) or not username:
|
||||
return jsonify({"error": "not_authenticated"}), 401
|
||||
if role != "admin":
|
||||
return jsonify({"error": "forbidden"}), 403
|
||||
return None
|
||||
|
||||
|
||||
def _is_internal_request(req: request) -> bool:
|
||||
remote = (req.remote_addr or "").strip()
|
||||
if not remote:
|
||||
return False
|
||||
try:
|
||||
return ip_address(remote).is_loopback
|
||||
except ValueError:
|
||||
return remote in {"localhost"}
|
||||
|
||||
|
||||
@blueprint.route("/api/devices", methods=["GET"])
|
||||
def list_devices() -> object:
|
||||
devices = _inventory().list_devices()
|
||||
return jsonify({"devices": devices})
|
||||
|
||||
|
||||
@blueprint.route("/api/devices/<guid>", methods=["GET"])
|
||||
def get_device_by_guid(guid: str) -> object:
|
||||
device = _inventory().get_device_by_guid(guid)
|
||||
if not device:
|
||||
return jsonify({"error": "not found"}), 404
|
||||
return jsonify(device)
|
||||
|
||||
|
||||
@blueprint.route("/api/device/details/<hostname>", methods=["GET"])
|
||||
def get_device_details(hostname: str) -> object:
|
||||
payload = _inventory().get_device_details(hostname)
|
||||
return jsonify(payload)
|
||||
|
||||
|
||||
@blueprint.route("/api/device/description/<hostname>", methods=["POST"])
|
||||
def set_device_description(hostname: str) -> object:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
description = payload.get("description")
|
||||
try:
|
||||
_inventory().update_device_description(hostname, description)
|
||||
except DeviceDescriptionError as exc:
|
||||
if exc.code == "invalid_hostname":
|
||||
return jsonify({"error": "invalid hostname"}), 400
|
||||
if exc.code == "not_found":
|
||||
return jsonify({"error": "not found"}), 404
|
||||
current_app.logger.exception(
|
||||
"device-description-error host=%s code=%s", hostname, exc.code
|
||||
)
|
||||
return jsonify({"error": "internal error"}), 500
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
@blueprint.route("/api/agent_devices", methods=["GET"])
|
||||
def list_agent_devices() -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
devices = _inventory().list_agent_devices()
|
||||
return jsonify({"devices": devices})
|
||||
|
||||
|
||||
@blueprint.route("/api/ssh_devices", methods=["GET", "POST"])
|
||||
def ssh_devices() -> object:
|
||||
return _remote_devices_endpoint("ssh")
|
||||
|
||||
|
||||
@blueprint.route("/api/winrm_devices", methods=["GET", "POST"])
|
||||
def winrm_devices() -> object:
|
||||
return _remote_devices_endpoint("winrm")
|
||||
|
||||
|
||||
@blueprint.route("/api/ssh_devices/<hostname>", methods=["PUT", "DELETE"])
|
||||
def ssh_device_detail(hostname: str) -> object:
|
||||
return _remote_device_detail("ssh", hostname)
|
||||
|
||||
|
||||
@blueprint.route("/api/winrm_devices/<hostname>", methods=["PUT", "DELETE"])
|
||||
def winrm_device_detail(hostname: str) -> object:
|
||||
return _remote_device_detail("winrm", hostname)
|
||||
|
||||
|
||||
@blueprint.route("/api/agent/hash_list", methods=["GET"])
|
||||
def agent_hash_list() -> object:
|
||||
if not _is_internal_request(request):
|
||||
remote_addr = (request.remote_addr or "unknown").strip() or "unknown"
|
||||
current_app.logger.warning(
|
||||
"/api/agent/hash_list denied non-local request from %s", remote_addr
|
||||
)
|
||||
return jsonify({"error": "forbidden"}), 403
|
||||
try:
|
||||
records = _inventory().collect_agent_hash_records()
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
current_app.logger.exception("/api/agent/hash_list error: %s", exc)
|
||||
return jsonify({"error": "internal error"}), 500
|
||||
return jsonify({"agents": records})
|
||||
|
||||
|
||||
@blueprint.route("/api/device_list_views", methods=["GET"])
|
||||
def list_device_list_views() -> object:
|
||||
views = _views().list_views()
|
||||
return jsonify({"views": [view.to_dict() for view in views]})
|
||||
|
||||
|
||||
@blueprint.route("/api/device_list_views/<int:view_id>", methods=["GET"])
|
||||
def get_device_list_view(view_id: int) -> object:
|
||||
view = _views().get_view(view_id)
|
||||
if not view:
|
||||
return jsonify({"error": "not found"}), 404
|
||||
return jsonify(view.to_dict())
|
||||
|
||||
|
||||
@blueprint.route("/api/device_list_views", methods=["POST"])
|
||||
def create_device_list_view() -> object:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
name = (payload.get("name") or "").strip()
|
||||
columns = payload.get("columns") or []
|
||||
filters = payload.get("filters") or {}
|
||||
|
||||
if not name:
|
||||
return jsonify({"error": "name is required"}), 400
|
||||
if name.lower() == "default view":
|
||||
return jsonify({"error": "reserved name"}), 400
|
||||
if not isinstance(columns, list) or not all(isinstance(x, str) for x in columns):
|
||||
return jsonify({"error": "columns must be a list of strings"}), 400
|
||||
if not isinstance(filters, dict):
|
||||
return jsonify({"error": "filters must be an object"}), 400
|
||||
|
||||
try:
|
||||
view = _views().create_view(name, columns, filters)
|
||||
except ValueError as exc:
|
||||
if str(exc) == "duplicate":
|
||||
return jsonify({"error": "name already exists"}), 409
|
||||
raise
|
||||
response = jsonify(view.to_dict())
|
||||
response.status_code = 201
|
||||
return response
|
||||
|
||||
|
||||
@blueprint.route("/api/device_list_views/<int:view_id>", methods=["PUT"])
|
||||
def update_device_list_view(view_id: int) -> object:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
updates: dict = {}
|
||||
if "name" in payload:
|
||||
name_val = payload.get("name")
|
||||
if name_val is None:
|
||||
return jsonify({"error": "name cannot be empty"}), 400
|
||||
normalized = (str(name_val) or "").strip()
|
||||
if not normalized:
|
||||
return jsonify({"error": "name cannot be empty"}), 400
|
||||
if normalized.lower() == "default view":
|
||||
return jsonify({"error": "reserved name"}), 400
|
||||
updates["name"] = normalized
|
||||
if "columns" in payload:
|
||||
columns_val = payload.get("columns")
|
||||
if not isinstance(columns_val, list) or not all(isinstance(x, str) for x in columns_val):
|
||||
return jsonify({"error": "columns must be a list of strings"}), 400
|
||||
updates["columns"] = columns_val
|
||||
if "filters" in payload:
|
||||
filters_val = payload.get("filters")
|
||||
if filters_val is not None and not isinstance(filters_val, dict):
|
||||
return jsonify({"error": "filters must be an object"}), 400
|
||||
if filters_val is not None:
|
||||
updates["filters"] = filters_val
|
||||
if not updates:
|
||||
return jsonify({"error": "no fields to update"}), 400
|
||||
|
||||
try:
|
||||
view = _views().update_view(
|
||||
view_id,
|
||||
name=updates.get("name"),
|
||||
columns=updates.get("columns"),
|
||||
filters=updates.get("filters"),
|
||||
)
|
||||
except ValueError as exc:
|
||||
code = str(exc)
|
||||
if code == "duplicate":
|
||||
return jsonify({"error": "name already exists"}), 409
|
||||
if code == "missing_name":
|
||||
return jsonify({"error": "name cannot be empty"}), 400
|
||||
if code == "reserved":
|
||||
return jsonify({"error": "reserved name"}), 400
|
||||
return jsonify({"error": "invalid payload"}), 400
|
||||
except LookupError:
|
||||
return jsonify({"error": "not found"}), 404
|
||||
return jsonify(view.to_dict())
|
||||
|
||||
|
||||
@blueprint.route("/api/device_list_views/<int:view_id>", methods=["DELETE"])
|
||||
def delete_device_list_view(view_id: int) -> object:
|
||||
if not _views().delete_view(view_id):
|
||||
return jsonify({"error": "not found"}), 404
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
def _remote_devices_endpoint(connection_type: str) -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
if request.method == "GET":
|
||||
devices = _inventory().list_remote_devices(connection_type)
|
||||
return jsonify({"devices": devices})
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
hostname = (payload.get("hostname") or "").strip()
|
||||
address = (
|
||||
payload.get("address")
|
||||
or payload.get("connection_endpoint")
|
||||
or payload.get("endpoint")
|
||||
or payload.get("host")
|
||||
)
|
||||
description = payload.get("description")
|
||||
os_hint = payload.get("operating_system") or payload.get("os")
|
||||
|
||||
if not hostname:
|
||||
return jsonify({"error": "hostname is required"}), 400
|
||||
if not (address or "").strip():
|
||||
return jsonify({"error": "address is required"}), 400
|
||||
|
||||
try:
|
||||
device = _inventory().upsert_remote_device(
|
||||
connection_type,
|
||||
hostname,
|
||||
address,
|
||||
description,
|
||||
os_hint,
|
||||
ensure_existing_type=None,
|
||||
)
|
||||
except RemoteDeviceError as exc:
|
||||
status = 409 if exc.code in {"conflict", "address_required"} else 500
|
||||
if exc.code == "conflict":
|
||||
return jsonify({"error": str(exc)}), 409
|
||||
if exc.code == "address_required":
|
||||
return jsonify({"error": "address is required"}), 400
|
||||
return jsonify({"error": str(exc)}), status
|
||||
return jsonify({"device": device}), 201
|
||||
|
||||
|
||||
def _remote_device_detail(connection_type: str, hostname: str) -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
normalized_host = (hostname or "").strip()
|
||||
if not normalized_host:
|
||||
return jsonify({"error": "invalid hostname"}), 400
|
||||
|
||||
if request.method == "DELETE":
|
||||
try:
|
||||
_inventory().delete_remote_device(connection_type, normalized_host)
|
||||
except RemoteDeviceError as exc:
|
||||
if exc.code == "not_found":
|
||||
return jsonify({"error": "device not found"}), 404
|
||||
if exc.code == "invalid_hostname":
|
||||
return jsonify({"error": "invalid hostname"}), 400
|
||||
return jsonify({"error": str(exc)}), 500
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
address = (
|
||||
payload.get("address")
|
||||
or payload.get("connection_endpoint")
|
||||
or payload.get("endpoint")
|
||||
)
|
||||
description = payload.get("description")
|
||||
os_hint = payload.get("operating_system") or payload.get("os")
|
||||
|
||||
if address is None and description is None and os_hint is None:
|
||||
return jsonify({"error": "no fields to update"}), 400
|
||||
|
||||
try:
|
||||
device = _inventory().upsert_remote_device(
|
||||
connection_type,
|
||||
normalized_host,
|
||||
address if address is not None else "",
|
||||
description,
|
||||
os_hint,
|
||||
ensure_existing_type=connection_type,
|
||||
)
|
||||
except RemoteDeviceError as exc:
|
||||
if exc.code == "not_found":
|
||||
return jsonify({"error": "device not found"}), 404
|
||||
if exc.code == "address_required":
|
||||
return jsonify({"error": "address is required"}), 400
|
||||
return jsonify({"error": str(exc)}), 500
|
||||
return jsonify({"device": device})
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
@@ -1,111 +0,0 @@
|
||||
"""Enrollment HTTP interface placeholders for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request
|
||||
|
||||
from Data.Engine.builders.device_enrollment import EnrollmentRequestBuilder
|
||||
from Data.Engine.services import EnrollmentValidationError
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||
|
||||
|
||||
blueprint = Blueprint("engine_enrollment", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach enrollment routes to *app*."""
|
||||
|
||||
if "engine_enrollment" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
|
||||
def enrollment_request() -> object:
|
||||
services: EngineServiceContainer = current_app.extensions["engine_services"]
|
||||
payload = request.get_json(force=True, silent=True)
|
||||
builder = EnrollmentRequestBuilder().with_payload(payload).with_service_context(
|
||||
request.headers.get(AGENT_CONTEXT_HEADER)
|
||||
)
|
||||
try:
|
||||
normalized = builder.build()
|
||||
result = services.enrollment_service.request_enrollment(
|
||||
normalized,
|
||||
remote_addr=_remote_addr(),
|
||||
)
|
||||
except EnrollmentValidationError as exc:
|
||||
response = jsonify(exc.to_response())
|
||||
response.status_code = exc.http_status
|
||||
if exc.retry_after is not None:
|
||||
response.headers["Retry-After"] = f"{int(exc.retry_after)}"
|
||||
return response
|
||||
|
||||
response_payload = {
|
||||
"status": result.status,
|
||||
"approval_reference": result.approval_reference,
|
||||
"server_nonce": result.server_nonce,
|
||||
"poll_after_ms": result.poll_after_ms,
|
||||
"server_certificate": result.server_certificate,
|
||||
"signing_key": result.signing_key,
|
||||
}
|
||||
response = jsonify(response_payload)
|
||||
response.status_code = result.http_status
|
||||
if result.retry_after is not None:
|
||||
response.headers["Retry-After"] = f"{int(result.retry_after)}"
|
||||
return response
|
||||
|
||||
|
||||
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
|
||||
def enrollment_poll() -> object:
|
||||
services: EngineServiceContainer = current_app.extensions["engine_services"]
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
approval_reference = str(payload.get("approval_reference") or "").strip()
|
||||
client_nonce = str(payload.get("client_nonce") or "").strip()
|
||||
proof_sig = str(payload.get("proof_sig") or "").strip()
|
||||
|
||||
try:
|
||||
result = services.enrollment_service.poll_enrollment(
|
||||
approval_reference=approval_reference,
|
||||
client_nonce_b64=client_nonce,
|
||||
proof_signature_b64=proof_sig,
|
||||
)
|
||||
except EnrollmentValidationError as exc:
|
||||
return jsonify(exc.to_response()), exc.http_status
|
||||
|
||||
body = {"status": result.status}
|
||||
if result.poll_after_ms is not None:
|
||||
body["poll_after_ms"] = result.poll_after_ms
|
||||
if result.reason:
|
||||
body["reason"] = result.reason
|
||||
if result.detail:
|
||||
body["detail"] = result.detail
|
||||
if result.tokens:
|
||||
body.update(
|
||||
{
|
||||
"guid": result.tokens.guid.value,
|
||||
"access_token": result.tokens.access_token,
|
||||
"refresh_token": result.tokens.refresh_token,
|
||||
"token_type": result.tokens.token_type,
|
||||
"expires_in": result.tokens.expires_in,
|
||||
"server_certificate": result.server_certificate or "",
|
||||
"signing_key": result.signing_key or "",
|
||||
}
|
||||
)
|
||||
else:
|
||||
if result.server_certificate:
|
||||
body["server_certificate"] = result.server_certificate
|
||||
if result.signing_key:
|
||||
body["signing_key"] = result.signing_key
|
||||
|
||||
return jsonify(body), result.http_status
|
||||
|
||||
|
||||
def _remote_addr() -> str:
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return (request.remote_addr or "unknown").strip()
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint", "enrollment_request", "enrollment_poll"]
|
||||
@@ -1,60 +0,0 @@
|
||||
"""GitHub-related HTTP endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
blueprint = Blueprint("engine_github", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
if "engine_github" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@blueprint.route("/api/repo/current_hash", methods=["GET"])
|
||||
def repo_current_hash() -> object:
|
||||
services: EngineServiceContainer = current_app.extensions["engine_services"]
|
||||
github = services.github_service
|
||||
|
||||
repo = (request.args.get("repo") or "").strip() or None
|
||||
branch = (request.args.get("branch") or "").strip() or None
|
||||
refresh_flag = (request.args.get("refresh") or "").strip().lower()
|
||||
ttl_raw = request.args.get("ttl")
|
||||
try:
|
||||
ttl = int(ttl_raw) if ttl_raw else github.default_refresh_interval
|
||||
except ValueError:
|
||||
ttl = github.default_refresh_interval
|
||||
force_refresh = refresh_flag in {"1", "true", "yes", "force", "refresh"}
|
||||
|
||||
snapshot = github.get_repo_head(repo, branch, ttl_seconds=ttl, force_refresh=force_refresh)
|
||||
payload = snapshot.to_dict()
|
||||
if not snapshot.sha:
|
||||
return jsonify(payload), 503
|
||||
return jsonify(payload)
|
||||
|
||||
|
||||
@blueprint.route("/api/github/token", methods=["GET", "POST"])
|
||||
def github_token() -> object:
|
||||
services: EngineServiceContainer = current_app.extensions["engine_services"]
|
||||
github = services.github_service
|
||||
|
||||
if request.method == "GET":
|
||||
payload = github.get_token_status(force_refresh=True).to_dict()
|
||||
return jsonify(payload)
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
token = data.get("token")
|
||||
normalized = str(token).strip() if token is not None else ""
|
||||
try:
|
||||
payload = github.update_token(normalized).to_dict()
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
current_app.logger.exception("failed to store GitHub token: %s", exc)
|
||||
return jsonify({"error": f"Failed to store token: {exc}"}), 500
|
||||
return jsonify(payload)
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint", "repo_current_hash", "github_token"]
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
"""Health check HTTP interface for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask, jsonify
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
blueprint = Blueprint("engine_health", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach health-related routes to *app*."""
|
||||
|
||||
if "engine_health" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@blueprint.route("/health", methods=["GET"])
|
||||
def health() -> object:
|
||||
"""Return a basic liveness response."""
|
||||
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
@@ -1,108 +0,0 @@
|
||||
"""HTTP routes for Engine job management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Blueprint, Flask, jsonify, request
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
bp = Blueprint("engine_job_management", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, services: EngineServiceContainer) -> None:
|
||||
bp.services = services # type: ignore[attr-defined]
|
||||
app.register_blueprint(bp)
|
||||
|
||||
|
||||
def _services() -> EngineServiceContainer:
|
||||
svc = getattr(bp, "services", None)
|
||||
if svc is None: # pragma: no cover - guard
|
||||
raise RuntimeError("job management blueprint not initialized")
|
||||
return svc
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs", methods=["GET"])
|
||||
def list_jobs() -> Any:
|
||||
jobs = _services().scheduler_service.list_jobs()
|
||||
return jsonify({"jobs": jobs})
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs", methods=["POST"])
|
||||
def create_job() -> Any:
|
||||
payload = _json_body()
|
||||
try:
|
||||
job = _services().scheduler_service.create_job(payload)
|
||||
except ValueError as exc:
|
||||
return jsonify({"error": str(exc)}), 400
|
||||
return jsonify({"job": job})
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs/<int:job_id>", methods=["GET"])
|
||||
def get_job(job_id: int) -> Any:
|
||||
job = _services().scheduler_service.get_job(job_id)
|
||||
if not job:
|
||||
return jsonify({"error": "job not found"}), 404
|
||||
return jsonify({"job": job})
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs/<int:job_id>", methods=["PUT"])
|
||||
def update_job(job_id: int) -> Any:
|
||||
payload = _json_body()
|
||||
try:
|
||||
job = _services().scheduler_service.update_job(job_id, payload)
|
||||
except ValueError as exc:
|
||||
return jsonify({"error": str(exc)}), 400
|
||||
if not job:
|
||||
return jsonify({"error": "job not found"}), 404
|
||||
return jsonify({"job": job})
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs/<int:job_id>", methods=["DELETE"])
|
||||
def delete_job(job_id: int) -> Any:
|
||||
_services().scheduler_service.delete_job(job_id)
|
||||
return ("", 204)
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs/<int:job_id>/toggle", methods=["POST"])
|
||||
def toggle_job(job_id: int) -> Any:
|
||||
payload = _json_body()
|
||||
enabled = bool(payload.get("enabled", True))
|
||||
_services().scheduler_service.toggle_job(job_id, enabled)
|
||||
job = _services().scheduler_service.get_job(job_id)
|
||||
if not job:
|
||||
return jsonify({"error": "job not found"}), 404
|
||||
return jsonify({"job": job})
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs/<int:job_id>/runs", methods=["GET"])
|
||||
def list_runs(job_id: int) -> Any:
|
||||
days = request.args.get("days")
|
||||
days_int: Optional[int] = None
|
||||
if days is not None:
|
||||
try:
|
||||
days_int = max(0, int(days))
|
||||
except Exception:
|
||||
return jsonify({"error": "invalid days parameter"}), 400
|
||||
runs = _services().scheduler_service.list_runs(job_id, days=days_int)
|
||||
return jsonify({"runs": runs})
|
||||
|
||||
|
||||
@bp.route("/api/scheduled_jobs/<int:job_id>/runs", methods=["DELETE"])
|
||||
def purge_runs(job_id: int) -> Any:
|
||||
_services().scheduler_service.purge_runs(job_id)
|
||||
return ("", 204)
|
||||
|
||||
|
||||
def _json_body() -> dict[str, Any]:
|
||||
if not request.data:
|
||||
return {}
|
||||
try:
|
||||
data = request.get_json(force=True, silent=False) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
__all__ = ["register"]
|
||||
@@ -1,53 +0,0 @@
|
||||
"""Server metadata endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from flask import Blueprint, Flask, jsonify
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
blueprint = Blueprint("engine_server_info", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
if "engine_server_info" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@blueprint.route("/api/server/time", methods=["GET"])
|
||||
def server_time() -> object:
|
||||
now_local = datetime.now().astimezone()
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
tzinfo = now_local.tzinfo
|
||||
offset = tzinfo.utcoffset(now_local) if tzinfo else None
|
||||
|
||||
def _ordinal(n: int) -> str:
|
||||
if 11 <= (n % 100) <= 13:
|
||||
suffix = "th"
|
||||
else:
|
||||
suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
|
||||
return f"{n}{suffix}"
|
||||
|
||||
month = now_local.strftime("%B")
|
||||
day_disp = _ordinal(now_local.day)
|
||||
year = now_local.strftime("%Y")
|
||||
hour24 = now_local.hour
|
||||
hour12 = hour24 % 12 or 12
|
||||
minute = now_local.minute
|
||||
ampm = "AM" if hour24 < 12 else "PM"
|
||||
display = f"{month} {day_disp} {year} @ {hour12}:{minute:02d}{ampm}"
|
||||
|
||||
payload = {
|
||||
"epoch": int(now_local.timestamp()),
|
||||
"iso": now_local.isoformat(),
|
||||
"utc_iso": now_utc.isoformat().replace("+00:00", "Z"),
|
||||
"timezone": str(tzinfo) if tzinfo else "",
|
||||
"offset_seconds": int(offset.total_seconds()) if offset else 0,
|
||||
"display": display,
|
||||
}
|
||||
return jsonify(payload)
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
@@ -1,112 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
blueprint = Blueprint("engine_sites", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
if "engine_sites" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
def _services() -> EngineServiceContainer:
|
||||
services = current_app.extensions.get("engine_services")
|
||||
if services is None: # pragma: no cover - defensive
|
||||
raise RuntimeError("engine services not initialized")
|
||||
return services
|
||||
|
||||
|
||||
def _site_service():
|
||||
return _services().site_service
|
||||
|
||||
|
||||
@blueprint.route("/api/sites", methods=["GET"])
|
||||
def list_sites() -> object:
|
||||
records = _site_service().list_sites()
|
||||
return jsonify({"sites": [record.to_dict() for record in records]})
|
||||
|
||||
|
||||
@blueprint.route("/api/sites", methods=["POST"])
|
||||
def create_site() -> object:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
name = payload.get("name")
|
||||
description = payload.get("description")
|
||||
try:
|
||||
record = _site_service().create_site(name or "", description or "")
|
||||
except ValueError as exc:
|
||||
if str(exc) == "missing_name":
|
||||
return jsonify({"error": "name is required"}), 400
|
||||
if str(exc) == "duplicate":
|
||||
return jsonify({"error": "name already exists"}), 409
|
||||
raise
|
||||
response = jsonify(record.to_dict())
|
||||
response.status_code = 201
|
||||
return response
|
||||
|
||||
|
||||
@blueprint.route("/api/sites/delete", methods=["POST"])
|
||||
def delete_sites() -> object:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
ids = payload.get("ids") or []
|
||||
if not isinstance(ids, list):
|
||||
return jsonify({"error": "ids must be a list"}), 400
|
||||
deleted = _site_service().delete_sites(ids)
|
||||
return jsonify({"status": "ok", "deleted": deleted})
|
||||
|
||||
|
||||
@blueprint.route("/api/sites/device_map", methods=["GET"])
|
||||
def sites_device_map() -> object:
|
||||
host_param = (request.args.get("hostnames") or "").strip()
|
||||
filter_set = []
|
||||
if host_param:
|
||||
for part in host_param.split(","):
|
||||
normalized = part.strip()
|
||||
if normalized:
|
||||
filter_set.append(normalized)
|
||||
mapping = _site_service().map_devices(filter_set or None)
|
||||
return jsonify({"mapping": {hostname: entry.to_dict() for hostname, entry in mapping.items()}})
|
||||
|
||||
|
||||
@blueprint.route("/api/sites/assign", methods=["POST"])
|
||||
def assign_devices_to_site() -> object:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
site_id = payload.get("site_id")
|
||||
hostnames = payload.get("hostnames") or []
|
||||
if not isinstance(hostnames, list):
|
||||
return jsonify({"error": "hostnames must be a list of strings"}), 400
|
||||
try:
|
||||
_site_service().assign_devices(site_id, hostnames)
|
||||
except ValueError as exc:
|
||||
message = str(exc)
|
||||
if message == "invalid_site_id":
|
||||
return jsonify({"error": "invalid site_id"}), 400
|
||||
if message == "invalid_hostnames":
|
||||
return jsonify({"error": "hostnames must be a list of strings"}), 400
|
||||
raise
|
||||
except LookupError:
|
||||
return jsonify({"error": "site not found"}), 404
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
@blueprint.route("/api/sites/rename", methods=["POST"])
|
||||
def rename_site() -> object:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
site_id = payload.get("id")
|
||||
new_name = payload.get("new_name") or ""
|
||||
try:
|
||||
record = _site_service().rename_site(site_id, new_name)
|
||||
except ValueError as exc:
|
||||
if str(exc) == "missing_name":
|
||||
return jsonify({"error": "new_name is required"}), 400
|
||||
if str(exc) == "duplicate":
|
||||
return jsonify({"error": "name already exists"}), 409
|
||||
raise
|
||||
except LookupError:
|
||||
return jsonify({"error": "site not found"}), 404
|
||||
return jsonify(record.to_dict())
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
@@ -1,52 +0,0 @@
|
||||
"""Token management HTTP interface for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask, current_app, jsonify, request
|
||||
|
||||
from Data.Engine.builders.device_auth import RefreshTokenRequestBuilder
|
||||
from Data.Engine.domain.device_auth import DeviceAuthFailure
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
from Data.Engine.services import TokenRefreshError
|
||||
|
||||
blueprint = Blueprint("engine_tokens", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, _services: EngineServiceContainer) -> None:
|
||||
"""Attach token management routes to *app*."""
|
||||
|
||||
if "engine_tokens" not in app.blueprints:
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@blueprint.route("/api/agent/token/refresh", methods=["POST"])
|
||||
def refresh_token() -> object:
|
||||
services: EngineServiceContainer = current_app.extensions["engine_services"]
|
||||
builder = (
|
||||
RefreshTokenRequestBuilder()
|
||||
.with_payload(request.get_json(force=True, silent=True))
|
||||
.with_http_method(request.method)
|
||||
.with_htu(request.url)
|
||||
.with_dpop_proof(request.headers.get("DPoP"))
|
||||
)
|
||||
try:
|
||||
refresh_request = builder.build()
|
||||
except DeviceAuthFailure as exc:
|
||||
payload = exc.to_dict()
|
||||
return jsonify(payload), exc.http_status
|
||||
|
||||
try:
|
||||
response = services.token_service.refresh_access_token(refresh_request)
|
||||
except TokenRefreshError as exc:
|
||||
return jsonify(exc.to_dict()), exc.http_status
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"access_token": response.access_token,
|
||||
"expires_in": response.expires_in,
|
||||
"token_type": response.token_type,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint", "refresh_token"]
|
||||
@@ -1,185 +0,0 @@
|
||||
"""HTTP endpoints for operator account management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, Flask, jsonify, request, session
|
||||
|
||||
from Data.Engine.services.auth import (
|
||||
AccountNotFoundError,
|
||||
CannotModifySelfError,
|
||||
InvalidPasswordHashError,
|
||||
InvalidRoleError,
|
||||
LastAdminError,
|
||||
LastUserError,
|
||||
OperatorAccountService,
|
||||
UsernameAlreadyExistsError,
|
||||
)
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
blueprint = Blueprint("engine_users", __name__)
|
||||
|
||||
|
||||
def register(app: Flask, services: EngineServiceContainer) -> None:
|
||||
blueprint.services = services # type: ignore[attr-defined]
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
def _services() -> EngineServiceContainer:
|
||||
svc = getattr(blueprint, "services", None)
|
||||
if svc is None: # pragma: no cover - defensive
|
||||
raise RuntimeError("user blueprint not initialized")
|
||||
return svc
|
||||
|
||||
|
||||
def _accounts() -> OperatorAccountService:
|
||||
return _services().operator_account_service
|
||||
|
||||
|
||||
def _require_admin():
|
||||
username = session.get("username")
|
||||
role = (session.get("role") or "").strip().lower()
|
||||
if not isinstance(username, str) or not username:
|
||||
return jsonify({"error": "not_authenticated"}), 401
|
||||
if role != "admin":
|
||||
return jsonify({"error": "forbidden"}), 403
|
||||
return None
|
||||
|
||||
|
||||
def _format_user(record) -> dict[str, object]:
|
||||
return {
|
||||
"username": record.username,
|
||||
"display_name": record.display_name,
|
||||
"role": record.role,
|
||||
"last_login": record.last_login,
|
||||
"created_at": record.created_at,
|
||||
"updated_at": record.updated_at,
|
||||
"mfa_enabled": 1 if record.mfa_enabled else 0,
|
||||
}
|
||||
|
||||
|
||||
@blueprint.route("/api/users", methods=["GET"])
|
||||
def list_users() -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
|
||||
records = _accounts().list_accounts()
|
||||
return jsonify({"users": [_format_user(record) for record in records]})
|
||||
|
||||
|
||||
@blueprint.route("/api/users", methods=["POST"])
|
||||
def create_user() -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
username = str(payload.get("username") or "").strip()
|
||||
password_sha512 = str(payload.get("password_sha512") or "").strip()
|
||||
role = str(payload.get("role") or "User")
|
||||
display_name = str(payload.get("display_name") or username)
|
||||
|
||||
try:
|
||||
_accounts().create_account(
|
||||
username=username,
|
||||
password_sha512=password_sha512,
|
||||
role=role,
|
||||
display_name=display_name,
|
||||
)
|
||||
except UsernameAlreadyExistsError as exc:
|
||||
return jsonify({"error": str(exc)}), 409
|
||||
except (InvalidPasswordHashError, InvalidRoleError) as exc:
|
||||
return jsonify({"error": str(exc)}), 400
|
||||
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
@blueprint.route("/api/users/<username>", methods=["DELETE"])
|
||||
def delete_user(username: str) -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
|
||||
actor = session.get("username") if isinstance(session.get("username"), str) else None
|
||||
|
||||
try:
|
||||
_accounts().delete_account(username, actor=actor)
|
||||
except CannotModifySelfError as exc:
|
||||
return jsonify({"error": str(exc)}), 400
|
||||
except LastUserError as exc:
|
||||
return jsonify({"error": str(exc)}), 400
|
||||
except LastAdminError as exc:
|
||||
return jsonify({"error": str(exc)}), 400
|
||||
except AccountNotFoundError as exc:
|
||||
return jsonify({"error": str(exc)}), 404
|
||||
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
@blueprint.route("/api/users/<username>/reset_password", methods=["POST"])
|
||||
def reset_password(username: str) -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
password_sha512 = str(payload.get("password_sha512") or "").strip()
|
||||
|
||||
try:
|
||||
_accounts().reset_password(username, password_sha512)
|
||||
except InvalidPasswordHashError as exc:
|
||||
return jsonify({"error": str(exc)}), 400
|
||||
except AccountNotFoundError as exc:
|
||||
return jsonify({"error": str(exc)}), 404
|
||||
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
@blueprint.route("/api/users/<username>/role", methods=["POST"])
|
||||
def change_role(username: str) -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
role = str(payload.get("role") or "").strip()
|
||||
actor = session.get("username") if isinstance(session.get("username"), str) else None
|
||||
|
||||
try:
|
||||
record = _accounts().change_role(username, role, actor=actor)
|
||||
except InvalidRoleError as exc:
|
||||
return jsonify({"error": str(exc)}), 400
|
||||
except LastAdminError as exc:
|
||||
return jsonify({"error": str(exc)}), 400
|
||||
except AccountNotFoundError as exc:
|
||||
return jsonify({"error": str(exc)}), 404
|
||||
|
||||
if actor and actor.strip().lower() == username.strip().lower():
|
||||
session["role"] = record.role
|
||||
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
@blueprint.route("/api/users/<username>/mfa", methods=["POST"])
|
||||
def update_mfa(username: str) -> object:
|
||||
guard = _require_admin()
|
||||
if guard:
|
||||
return guard
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
enabled = bool(payload.get("enabled", False))
|
||||
reset_secret = bool(payload.get("reset_secret", False))
|
||||
|
||||
try:
|
||||
_accounts().update_mfa(username, enabled=enabled, reset_secret=reset_secret)
|
||||
except AccountNotFoundError as exc:
|
||||
return jsonify({"error": str(exc)}), 404
|
||||
|
||||
actor = session.get("username") if isinstance(session.get("username"), str) else None
|
||||
if actor and actor.strip().lower() == username.strip().lower() and not enabled:
|
||||
session.pop("mfa_pending", None)
|
||||
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
__all__ = ["register", "blueprint"]
|
||||
@@ -1,47 +0,0 @@
|
||||
"""WebSocket interface factory for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from ...config import SocketIOSettings
|
||||
from ...services.container import EngineServiceContainer
|
||||
from .agents import register as register_agent_events
|
||||
from .job_management import register as register_job_events
|
||||
|
||||
try: # pragma: no cover - import guard
|
||||
from flask_socketio import SocketIO
|
||||
except Exception: # pragma: no cover - optional dependency
|
||||
SocketIO = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def create_socket_server(app: Flask, settings: SocketIOSettings) -> Optional[SocketIO]:
|
||||
"""Create a Socket.IO server bound to *app* if dependencies are available."""
|
||||
|
||||
if SocketIO is None:
|
||||
return None
|
||||
|
||||
cors_allowed = settings.cors_allowed_origins or ("*",)
|
||||
socketio = SocketIO(
|
||||
app,
|
||||
cors_allowed_origins=cors_allowed,
|
||||
async_mode=None,
|
||||
logger=False,
|
||||
engineio_logger=False,
|
||||
)
|
||||
return socketio
|
||||
|
||||
|
||||
def register_ws_interfaces(socketio: Any, services: EngineServiceContainer) -> None:
|
||||
"""Attach namespaces for the Engine Socket.IO server."""
|
||||
|
||||
if socketio is None: # pragma: no cover - guard
|
||||
return
|
||||
|
||||
for registrar in (register_agent_events, register_job_events):
|
||||
registrar(socketio, services)
|
||||
|
||||
|
||||
__all__ = ["create_socket_server", "register_ws_interfaces"]
|
||||
@@ -1,16 +0,0 @@
|
||||
"""Agent WebSocket namespace wiring for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from . import events
|
||||
|
||||
|
||||
def register(socketio: Any, services) -> None:
|
||||
"""Register agent namespaces on the given Socket.IO *socketio* instance."""
|
||||
|
||||
events.register(socketio, services)
|
||||
|
||||
|
||||
__all__ = ["register"]
|
||||
@@ -1,261 +0,0 @@
|
||||
"""Agent WebSocket event handlers for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
from flask import request
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
try: # pragma: no cover - optional dependency guard
|
||||
from flask_socketio import emit, join_room
|
||||
except Exception: # pragma: no cover - optional dependency guard
|
||||
emit = None # type: ignore[assignment]
|
||||
join_room = None # type: ignore[assignment]
|
||||
|
||||
_AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||
|
||||
|
||||
def register(socketio: Any, services: EngineServiceContainer) -> None:
|
||||
if socketio is None: # pragma: no cover - guard
|
||||
return
|
||||
|
||||
handlers = _AgentEventHandlers(socketio, services)
|
||||
socketio.on_event("connect", handlers.on_connect)
|
||||
socketio.on_event("disconnect", handlers.on_disconnect)
|
||||
socketio.on_event("agent_screenshot_task", handlers.on_agent_screenshot_task)
|
||||
socketio.on_event("connect_agent", handlers.on_connect_agent)
|
||||
socketio.on_event("agent_heartbeat", handlers.on_agent_heartbeat)
|
||||
socketio.on_event("collector_status", handlers.on_collector_status)
|
||||
socketio.on_event("request_config", handlers.on_request_config)
|
||||
socketio.on_event("screenshot", handlers.on_screenshot)
|
||||
socketio.on_event("macro_status", handlers.on_macro_status)
|
||||
socketio.on_event("list_agent_windows", handlers.on_list_agent_windows)
|
||||
socketio.on_event("agent_window_list", handlers.on_agent_window_list)
|
||||
socketio.on_event("ansible_playbook_cancel", handlers.on_ansible_playbook_cancel)
|
||||
socketio.on_event("ansible_playbook_run", handlers.on_ansible_playbook_run)
|
||||
|
||||
|
||||
class _AgentEventHandlers:
|
||||
def __init__(self, socketio: Any, services: EngineServiceContainer) -> None:
|
||||
self._socketio = socketio
|
||||
self._services = services
|
||||
self._realtime = services.agent_realtime
|
||||
self._log = logging.getLogger("borealis.engine.ws.agents")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connection lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
def on_connect(self) -> None:
|
||||
sid = getattr(request, "sid", "<unknown>")
|
||||
remote_addr = getattr(request, "remote_addr", None)
|
||||
transport = None
|
||||
try:
|
||||
transport = request.args.get("transport") # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
transport = None
|
||||
query = self._render_query()
|
||||
headers = _summarize_socket_headers(getattr(request, "headers", {}))
|
||||
scope = _canonical_scope(getattr(request.headers, "get", lambda *_: None)(_AGENT_CONTEXT_HEADER))
|
||||
self._log.info(
|
||||
"socket-connect sid=%s ip=%s transport=%r query=%s headers=%s scope=%s",
|
||||
sid,
|
||||
remote_addr,
|
||||
transport,
|
||||
query,
|
||||
headers,
|
||||
scope or "<none>",
|
||||
)
|
||||
|
||||
def on_disconnect(self) -> None:
|
||||
sid = getattr(request, "sid", "<unknown>")
|
||||
remote_addr = getattr(request, "remote_addr", None)
|
||||
self._log.info("socket-disconnect sid=%s ip=%s", sid, remote_addr)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Agent coordination
|
||||
# ------------------------------------------------------------------
|
||||
def on_agent_screenshot_task(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
agent_id = payload.get("agent_id")
|
||||
node_id = payload.get("node_id")
|
||||
image = payload.get("image_base64", "")
|
||||
|
||||
if not agent_id or not node_id:
|
||||
self._log.warning("screenshot-task missing identifiers: %s", payload)
|
||||
return
|
||||
|
||||
if image:
|
||||
self._realtime.store_task_screenshot(agent_id, node_id, image)
|
||||
|
||||
try:
|
||||
self._socketio.emit("agent_screenshot_task", payload)
|
||||
except Exception as exc: # pragma: no cover - network guard
|
||||
self._log.warning("socket emit failed for agent_screenshot_task: %s", exc)
|
||||
|
||||
def on_connect_agent(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
agent_id = payload.get("agent_id")
|
||||
if not agent_id:
|
||||
return
|
||||
|
||||
service_mode = payload.get("service_mode")
|
||||
record = self._realtime.register_connection(agent_id, service_mode)
|
||||
|
||||
if join_room is not None: # pragma: no branch - optional dependency guard
|
||||
try:
|
||||
join_room(agent_id)
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.debug("join_room failed for %s: %s", agent_id, exc)
|
||||
|
||||
self._log.info(
|
||||
"agent-connected agent_id=%s mode=%s status=%s",
|
||||
agent_id,
|
||||
record.service_mode,
|
||||
record.status,
|
||||
)
|
||||
|
||||
def on_agent_heartbeat(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
record = self._realtime.heartbeat(payload)
|
||||
if record:
|
||||
self._log.debug(
|
||||
"agent-heartbeat agent_id=%s host=%s mode=%s", record.agent_id, record.hostname, record.service_mode
|
||||
)
|
||||
|
||||
def on_collector_status(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
self._realtime.collector_status(payload)
|
||||
|
||||
def on_request_config(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
agent_id = payload.get("agent_id")
|
||||
if not agent_id:
|
||||
return
|
||||
config = self._realtime.get_agent_config(agent_id)
|
||||
if config and emit is not None:
|
||||
try:
|
||||
emit("agent_config", {**config, "agent_id": agent_id})
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.debug("emit(agent_config) failed for %s: %s", agent_id, exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Media + relay events
|
||||
# ------------------------------------------------------------------
|
||||
def on_screenshot(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
agent_id = payload.get("agent_id")
|
||||
image = payload.get("image_base64")
|
||||
if agent_id and image:
|
||||
self._realtime.store_agent_screenshot(agent_id, image)
|
||||
try:
|
||||
self._socketio.emit("new_screenshot", {"agent_id": agent_id, "image_base64": image})
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.warning("socket emit failed for new_screenshot: %s", exc)
|
||||
|
||||
def on_macro_status(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
agent_id = payload.get("agent_id")
|
||||
node_id = payload.get("node_id")
|
||||
success = payload.get("success")
|
||||
message = payload.get("message")
|
||||
self._log.info(
|
||||
"macro-status agent=%s node=%s success=%s message=%s",
|
||||
agent_id,
|
||||
node_id,
|
||||
success,
|
||||
message,
|
||||
)
|
||||
try:
|
||||
self._socketio.emit("macro_status", payload)
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.warning("socket emit failed for macro_status: %s", exc)
|
||||
|
||||
def on_list_agent_windows(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
try:
|
||||
self._socketio.emit("list_agent_windows", payload)
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.warning("socket emit failed for list_agent_windows: %s", exc)
|
||||
|
||||
def on_agent_window_list(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
payload = data or {}
|
||||
try:
|
||||
self._socketio.emit("agent_window_list", payload)
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.warning("socket emit failed for agent_window_list: %s", exc)
|
||||
|
||||
def on_ansible_playbook_cancel(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
try:
|
||||
self._socketio.emit("ansible_playbook_cancel", data or {})
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.warning("socket emit failed for ansible_playbook_cancel: %s", exc)
|
||||
|
||||
def on_ansible_playbook_run(self, data: Optional[Dict[str, Any]]) -> None:
|
||||
try:
|
||||
self._socketio.emit("ansible_playbook_run", data or {})
|
||||
except Exception as exc: # pragma: no cover - dependency guard
|
||||
self._log.warning("socket emit failed for ansible_playbook_run: %s", exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _render_query(self) -> str:
|
||||
try:
|
||||
pairs = [f"{k}={v}" for k, v in request.args.items()] # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return "<unavailable>"
|
||||
return "&".join(pairs) if pairs else "<none>"
|
||||
|
||||
|
||||
def _canonical_scope(raw: Optional[str]) -> Optional[str]:
|
||||
if not raw:
|
||||
return None
|
||||
value = "".join(ch for ch in str(raw) if ch.isalnum() or ch in ("_", "-"))
|
||||
if not value:
|
||||
return None
|
||||
return value.upper()
|
||||
|
||||
|
||||
def _mask_value(value: str, *, prefix: int = 4, suffix: int = 4) -> str:
|
||||
try:
|
||||
if not value:
|
||||
return ""
|
||||
stripped = value.strip()
|
||||
if len(stripped) <= prefix + suffix:
|
||||
return "*" * len(stripped)
|
||||
return f"{stripped[:prefix]}***{stripped[-suffix:]}"
|
||||
except Exception:
|
||||
return "***"
|
||||
|
||||
|
||||
def _summarize_socket_headers(headers: Any) -> str:
|
||||
try:
|
||||
items: Iterable[tuple[str, Any]]
|
||||
if isinstance(headers, dict):
|
||||
items = headers.items()
|
||||
else:
|
||||
items = getattr(headers, "items", lambda: [])()
|
||||
except Exception:
|
||||
items = []
|
||||
|
||||
rendered = []
|
||||
for key, value in items:
|
||||
lowered = str(key).lower()
|
||||
display = value
|
||||
if lowered == "authorization":
|
||||
token = str(value or "")
|
||||
if token.lower().startswith("bearer "):
|
||||
display = f"Bearer {_mask_value(token.split(' ', 1)[1])}"
|
||||
else:
|
||||
display = _mask_value(token)
|
||||
elif lowered == "cookie":
|
||||
display = "<redacted>"
|
||||
rendered.append(f"{key}={display}")
|
||||
return ", ".join(rendered) if rendered else "<no-headers>"
|
||||
|
||||
|
||||
__all__ = ["register"]
|
||||
@@ -1,16 +0,0 @@
|
||||
"""Job management WebSocket namespace wiring for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from . import events
|
||||
|
||||
|
||||
def register(socketio: Any, services) -> None:
|
||||
"""Register job management namespaces on the given Socket.IO *socketio*."""
|
||||
|
||||
events.register(socketio, services)
|
||||
|
||||
|
||||
__all__ = ["register"]
|
||||
@@ -1,38 +0,0 @@
|
||||
"""Job management WebSocket event handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from Data.Engine.services.container import EngineServiceContainer
|
||||
|
||||
|
||||
def register(socketio: Any, services: EngineServiceContainer) -> None:
|
||||
if socketio is None: # pragma: no cover - guard
|
||||
return
|
||||
|
||||
handlers = _JobEventHandlers(socketio, services)
|
||||
socketio.on_event("quick_job_result", handlers.on_quick_job_result)
|
||||
socketio.on_event("job_status_request", handlers.on_job_status_request)
|
||||
|
||||
|
||||
class _JobEventHandlers:
|
||||
def __init__(self, socketio: Any, services: EngineServiceContainer) -> None:
|
||||
self._socketio = socketio
|
||||
self._services = services
|
||||
self._log = logging.getLogger("borealis.engine.ws.jobs")
|
||||
|
||||
def on_quick_job_result(self, data: Optional[dict]) -> None:
|
||||
self._log.info("quick-job-result received; scheduler migration pending")
|
||||
# Step 10 will introduce full persistence + broadcast logic.
|
||||
|
||||
def on_job_status_request(self, _: Optional[dict]) -> None:
|
||||
jobs = self._services.scheduler_service.list_jobs()
|
||||
try:
|
||||
self._socketio.emit("job_status", {"jobs": jobs})
|
||||
except Exception:
|
||||
self._log.debug("job-status emit failed")
|
||||
|
||||
|
||||
__all__ = ["register"]
|
||||
@@ -1,7 +0,0 @@
|
||||
"""Persistence adapters for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from . import sqlite
|
||||
|
||||
__all__ = ["sqlite"]
|
||||
@@ -1,62 +0,0 @@
|
||||
"""SQLite persistence helpers for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .connection import (
|
||||
SQLiteConnectionFactory,
|
||||
configure_connection,
|
||||
connect,
|
||||
connection_factory,
|
||||
connection_scope,
|
||||
)
|
||||
from .migrations import apply_all, ensure_default_admin
|
||||
|
||||
__all__ = [
|
||||
"SQLiteConnectionFactory",
|
||||
"configure_connection",
|
||||
"connect",
|
||||
"connection_factory",
|
||||
"connection_scope",
|
||||
"apply_all",
|
||||
"ensure_default_admin",
|
||||
]
|
||||
|
||||
try: # pragma: no cover - optional dependency shim
|
||||
from .device_repository import SQLiteDeviceRepository
|
||||
from .enrollment_repository import SQLiteEnrollmentRepository
|
||||
from .device_inventory_repository import SQLiteDeviceInventoryRepository
|
||||
from .device_view_repository import SQLiteDeviceViewRepository
|
||||
from .credential_repository import SQLiteCredentialRepository
|
||||
from .github_repository import SQLiteGitHubRepository
|
||||
from .job_repository import SQLiteJobRepository
|
||||
from .site_repository import SQLiteSiteRepository
|
||||
from .token_repository import SQLiteRefreshTokenRepository
|
||||
from .user_repository import SQLiteUserRepository
|
||||
except ModuleNotFoundError as exc: # pragma: no cover - triggered when auth deps missing
|
||||
def _missing_repo(*_args: object, **_kwargs: object) -> None:
|
||||
raise ModuleNotFoundError(
|
||||
"Engine SQLite repositories require optional authentication dependencies"
|
||||
) from exc
|
||||
|
||||
SQLiteDeviceRepository = _missing_repo # type: ignore[assignment]
|
||||
SQLiteEnrollmentRepository = _missing_repo # type: ignore[assignment]
|
||||
SQLiteDeviceInventoryRepository = _missing_repo # type: ignore[assignment]
|
||||
SQLiteDeviceViewRepository = _missing_repo # type: ignore[assignment]
|
||||
SQLiteCredentialRepository = _missing_repo # type: ignore[assignment]
|
||||
SQLiteGitHubRepository = _missing_repo # type: ignore[assignment]
|
||||
SQLiteJobRepository = _missing_repo # type: ignore[assignment]
|
||||
SQLiteSiteRepository = _missing_repo # type: ignore[assignment]
|
||||
SQLiteRefreshTokenRepository = _missing_repo # type: ignore[assignment]
|
||||
else:
|
||||
__all__ += [
|
||||
"SQLiteDeviceRepository",
|
||||
"SQLiteRefreshTokenRepository",
|
||||
"SQLiteJobRepository",
|
||||
"SQLiteEnrollmentRepository",
|
||||
"SQLiteDeviceInventoryRepository",
|
||||
"SQLiteDeviceViewRepository",
|
||||
"SQLiteCredentialRepository",
|
||||
"SQLiteGitHubRepository",
|
||||
"SQLiteUserRepository",
|
||||
"SQLiteSiteRepository",
|
||||
]
|
||||
@@ -1,67 +0,0 @@
|
||||
"""SQLite connection utilities for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Protocol
|
||||
|
||||
__all__ = [
|
||||
"SQLiteConnectionFactory",
|
||||
"configure_connection",
|
||||
"connect",
|
||||
"connection_factory",
|
||||
"connection_scope",
|
||||
]
|
||||
|
||||
|
||||
class SQLiteConnectionFactory(Protocol):
|
||||
"""Callable protocol for obtaining configured SQLite connections."""
|
||||
|
||||
def __call__(self) -> sqlite3.Connection:
|
||||
"""Return a new :class:`sqlite3.Connection`."""
|
||||
|
||||
|
||||
def configure_connection(conn: sqlite3.Connection) -> None:
|
||||
"""Apply the Borealis-standard pragmas to *conn*."""
|
||||
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute("PRAGMA journal_mode=WAL")
|
||||
cur.execute("PRAGMA busy_timeout=5000")
|
||||
cur.execute("PRAGMA synchronous=NORMAL")
|
||||
conn.commit()
|
||||
except Exception:
|
||||
# Pragmas are best-effort; failing to apply them should not block startup.
|
||||
conn.rollback()
|
||||
finally:
|
||||
cur.close()
|
||||
|
||||
|
||||
def connect(path: Path, *, timeout: float = 15.0) -> sqlite3.Connection:
|
||||
"""Create a new SQLite connection to *path* with Engine pragmas applied."""
|
||||
|
||||
conn = sqlite3.connect(str(path), timeout=timeout)
|
||||
configure_connection(conn)
|
||||
return conn
|
||||
|
||||
|
||||
def connection_factory(path: Path, *, timeout: float = 15.0) -> SQLiteConnectionFactory:
|
||||
"""Return a factory that opens connections to *path* when invoked."""
|
||||
|
||||
def factory() -> sqlite3.Connection:
|
||||
return connect(path, timeout=timeout)
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@contextmanager
|
||||
def connection_scope(path: Path, *, timeout: float = 15.0) -> Iterator[sqlite3.Connection]:
|
||||
"""Context manager yielding a configured connection to *path*."""
|
||||
|
||||
conn = connect(path, timeout=timeout)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
@@ -1,103 +0,0 @@
|
||||
"""SQLite access for operator credential metadata."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from contextlib import closing
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||
|
||||
__all__ = ["SQLiteCredentialRepository"]
|
||||
|
||||
|
||||
class SQLiteCredentialRepository:
|
||||
def __init__(
|
||||
self,
|
||||
connection_factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._connections = connection_factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.credentials")
|
||||
|
||||
def list_credentials(
|
||||
self,
|
||||
*,
|
||||
site_id: Optional[int] = None,
|
||||
connection_type: Optional[str] = None,
|
||||
) -> List[Dict[str, object]]:
|
||||
sql = """
|
||||
SELECT c.id,
|
||||
c.name,
|
||||
c.description,
|
||||
c.credential_type,
|
||||
c.connection_type,
|
||||
c.username,
|
||||
c.site_id,
|
||||
s.name AS site_name,
|
||||
c.become_method,
|
||||
c.become_username,
|
||||
c.metadata_json,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
c.password_encrypted,
|
||||
c.private_key_encrypted,
|
||||
c.private_key_passphrase_encrypted,
|
||||
c.become_password_encrypted
|
||||
FROM credentials c
|
||||
LEFT JOIN sites s ON s.id = c.site_id
|
||||
"""
|
||||
clauses: List[str] = []
|
||||
params: List[object] = []
|
||||
if site_id is not None:
|
||||
clauses.append("c.site_id = ?")
|
||||
params.append(site_id)
|
||||
if connection_type:
|
||||
clauses.append("LOWER(c.connection_type) = LOWER(?)")
|
||||
params.append(connection_type)
|
||||
if clauses:
|
||||
sql += " WHERE " + " AND ".join(clauses)
|
||||
sql += " ORDER BY LOWER(c.name) ASC"
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
conn.row_factory = sqlite3.Row # type: ignore[attr-defined]
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql, params)
|
||||
rows = cur.fetchall()
|
||||
|
||||
results: List[Dict[str, object]] = []
|
||||
for row in rows:
|
||||
metadata_json = row["metadata_json"] if "metadata_json" in row.keys() else None
|
||||
metadata = {}
|
||||
if metadata_json:
|
||||
try:
|
||||
candidate = json.loads(metadata_json)
|
||||
if isinstance(candidate, dict):
|
||||
metadata = candidate
|
||||
except Exception:
|
||||
metadata = {}
|
||||
results.append(
|
||||
{
|
||||
"id": row["id"],
|
||||
"name": row["name"],
|
||||
"description": row["description"] or "",
|
||||
"credential_type": row["credential_type"] or "machine",
|
||||
"connection_type": row["connection_type"] or "ssh",
|
||||
"site_id": row["site_id"],
|
||||
"site_name": row["site_name"],
|
||||
"username": row["username"] or "",
|
||||
"become_method": row["become_method"] or "",
|
||||
"become_username": row["become_username"] or "",
|
||||
"metadata": metadata,
|
||||
"created_at": int(row["created_at"] or 0),
|
||||
"updated_at": int(row["updated_at"] or 0),
|
||||
"has_password": bool(row["password_encrypted"]),
|
||||
"has_private_key": bool(row["private_key_encrypted"]),
|
||||
"has_private_key_passphrase": bool(row["private_key_passphrase_encrypted"]),
|
||||
"has_become_password": bool(row["become_password_encrypted"]),
|
||||
}
|
||||
)
|
||||
return results
|
||||
@@ -1,338 +0,0 @@
|
||||
"""Device inventory operations backed by SQLite."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import closing
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from Data.Engine.domain.devices import (
|
||||
DEVICE_TABLE,
|
||||
DEVICE_TABLE_COLUMNS,
|
||||
assemble_device_snapshot,
|
||||
clean_device_str,
|
||||
coerce_int,
|
||||
device_column_sql,
|
||||
row_to_device_dict,
|
||||
serialize_device_json,
|
||||
)
|
||||
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||
|
||||
__all__ = ["SQLiteDeviceInventoryRepository"]
|
||||
|
||||
|
||||
class SQLiteDeviceInventoryRepository:
|
||||
def __init__(
|
||||
self,
|
||||
connection_factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._connections = connection_factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.device_inventory")
|
||||
|
||||
def fetch_devices(
|
||||
self,
|
||||
*,
|
||||
connection_type: Optional[str] = None,
|
||||
hostname: Optional[str] = None,
|
||||
only_agents: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
sql = f"""
|
||||
SELECT {device_column_sql('d')}, s.id, s.name, s.description
|
||||
FROM {DEVICE_TABLE} d
|
||||
LEFT JOIN device_sites ds ON ds.device_hostname = d.hostname
|
||||
LEFT JOIN sites s ON s.id = ds.site_id
|
||||
"""
|
||||
clauses: List[str] = []
|
||||
params: List[Any] = []
|
||||
if connection_type:
|
||||
clauses.append("LOWER(d.connection_type) = LOWER(?)")
|
||||
params.append(connection_type)
|
||||
if hostname:
|
||||
clauses.append("LOWER(d.hostname) = LOWER(?)")
|
||||
params.append(hostname.lower())
|
||||
if only_agents:
|
||||
clauses.append("(d.connection_type IS NULL OR TRIM(d.connection_type) = '')")
|
||||
if clauses:
|
||||
sql += " WHERE " + " AND ".join(clauses)
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql, params)
|
||||
rows = cur.fetchall()
|
||||
|
||||
now = time.time()
|
||||
devices: List[Dict[str, Any]] = []
|
||||
for row in rows:
|
||||
core = row[: len(DEVICE_TABLE_COLUMNS)]
|
||||
site_id, site_name, site_description = row[len(DEVICE_TABLE_COLUMNS) :]
|
||||
record = row_to_device_dict(core, DEVICE_TABLE_COLUMNS)
|
||||
snapshot = assemble_device_snapshot(record)
|
||||
summary = snapshot.get("summary", {})
|
||||
last_seen = snapshot.get("last_seen") or 0
|
||||
status = "Offline"
|
||||
try:
|
||||
if last_seen and (now - float(last_seen)) <= 300:
|
||||
status = "Online"
|
||||
except Exception:
|
||||
pass
|
||||
devices.append(
|
||||
{
|
||||
**snapshot,
|
||||
"site_id": site_id,
|
||||
"site_name": site_name or "",
|
||||
"site_description": site_description or "",
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
return devices
|
||||
|
||||
def load_snapshot(self, *, hostname: Optional[str] = None, guid: Optional[str] = None) -> Optional[Dict[str, Any]]:
|
||||
if not hostname and not guid:
|
||||
return None
|
||||
sql = None
|
||||
params: Tuple[Any, ...]
|
||||
if hostname:
|
||||
sql = f"SELECT {device_column_sql()} FROM {DEVICE_TABLE} WHERE hostname = ?"
|
||||
params = (hostname,)
|
||||
else:
|
||||
sql = f"SELECT {device_column_sql()} FROM {DEVICE_TABLE} WHERE LOWER(guid) = LOWER(?)"
|
||||
params = (guid,)
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql, params)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
record = row_to_device_dict(row, DEVICE_TABLE_COLUMNS)
|
||||
return assemble_device_snapshot(record)
|
||||
|
||||
def upsert_device(
|
||||
self,
|
||||
hostname: str,
|
||||
description: Optional[str],
|
||||
merged_details: Dict[str, Any],
|
||||
created_at: Optional[int],
|
||||
*,
|
||||
agent_hash: Optional[str] = None,
|
||||
guid: Optional[str] = None,
|
||||
) -> None:
|
||||
if not hostname:
|
||||
return
|
||||
|
||||
column_values = self._extract_device_columns(merged_details or {})
|
||||
normalized_description = description if description is not None else ""
|
||||
try:
|
||||
normalized_description = str(normalized_description)
|
||||
except Exception:
|
||||
normalized_description = ""
|
||||
|
||||
normalized_hash = clean_device_str(agent_hash) or None
|
||||
normalized_guid = clean_device_str(guid) or None
|
||||
created_ts = coerce_int(created_at) or int(time.time())
|
||||
|
||||
sql = f"""
|
||||
INSERT INTO {DEVICE_TABLE}(
|
||||
hostname,
|
||||
description,
|
||||
created_at,
|
||||
agent_hash,
|
||||
guid,
|
||||
memory,
|
||||
network,
|
||||
software,
|
||||
storage,
|
||||
cpu,
|
||||
device_type,
|
||||
domain,
|
||||
external_ip,
|
||||
internal_ip,
|
||||
last_reboot,
|
||||
last_seen,
|
||||
last_user,
|
||||
operating_system,
|
||||
uptime,
|
||||
agent_id,
|
||||
ansible_ee_ver,
|
||||
connection_type,
|
||||
connection_endpoint,
|
||||
ssl_key_fingerprint,
|
||||
token_version,
|
||||
status,
|
||||
key_added_at
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
|
||||
ON CONFLICT(hostname) DO UPDATE SET
|
||||
description=excluded.description,
|
||||
created_at=COALESCE({DEVICE_TABLE}.created_at, excluded.created_at),
|
||||
agent_hash=COALESCE(NULLIF(excluded.agent_hash, ''), {DEVICE_TABLE}.agent_hash),
|
||||
guid=COALESCE(NULLIF(excluded.guid, ''), {DEVICE_TABLE}.guid),
|
||||
memory=excluded.memory,
|
||||
network=excluded.network,
|
||||
software=excluded.software,
|
||||
storage=excluded.storage,
|
||||
cpu=excluded.cpu,
|
||||
device_type=COALESCE(NULLIF(excluded.device_type, ''), {DEVICE_TABLE}.device_type),
|
||||
domain=COALESCE(NULLIF(excluded.domain, ''), {DEVICE_TABLE}.domain),
|
||||
external_ip=COALESCE(NULLIF(excluded.external_ip, ''), {DEVICE_TABLE}.external_ip),
|
||||
internal_ip=COALESCE(NULLIF(excluded.internal_ip, ''), {DEVICE_TABLE}.internal_ip),
|
||||
last_reboot=COALESCE(NULLIF(excluded.last_reboot, ''), {DEVICE_TABLE}.last_reboot),
|
||||
last_seen=COALESCE(NULLIF(excluded.last_seen, 0), {DEVICE_TABLE}.last_seen),
|
||||
last_user=COALESCE(NULLIF(excluded.last_user, ''), {DEVICE_TABLE}.last_user),
|
||||
operating_system=COALESCE(NULLIF(excluded.operating_system, ''), {DEVICE_TABLE}.operating_system),
|
||||
uptime=COALESCE(NULLIF(excluded.uptime, 0), {DEVICE_TABLE}.uptime),
|
||||
agent_id=COALESCE(NULLIF(excluded.agent_id, ''), {DEVICE_TABLE}.agent_id),
|
||||
ansible_ee_ver=COALESCE(NULLIF(excluded.ansible_ee_ver, ''), {DEVICE_TABLE}.ansible_ee_ver),
|
||||
connection_type=COALESCE(NULLIF(excluded.connection_type, ''), {DEVICE_TABLE}.connection_type),
|
||||
connection_endpoint=COALESCE(NULLIF(excluded.connection_endpoint, ''), {DEVICE_TABLE}.connection_endpoint),
|
||||
ssl_key_fingerprint=COALESCE(NULLIF(excluded.ssl_key_fingerprint, ''), {DEVICE_TABLE}.ssl_key_fingerprint),
|
||||
token_version=COALESCE(NULLIF(excluded.token_version, 0), {DEVICE_TABLE}.token_version),
|
||||
status=COALESCE(NULLIF(excluded.status, ''), {DEVICE_TABLE}.status),
|
||||
key_added_at=COALESCE(NULLIF(excluded.key_added_at, ''), {DEVICE_TABLE}.key_added_at)
|
||||
"""
|
||||
|
||||
params: List[Any] = [
|
||||
hostname,
|
||||
normalized_description,
|
||||
created_ts,
|
||||
normalized_hash,
|
||||
normalized_guid,
|
||||
column_values.get("memory"),
|
||||
column_values.get("network"),
|
||||
column_values.get("software"),
|
||||
column_values.get("storage"),
|
||||
column_values.get("cpu"),
|
||||
column_values.get("device_type"),
|
||||
column_values.get("domain"),
|
||||
column_values.get("external_ip"),
|
||||
column_values.get("internal_ip"),
|
||||
column_values.get("last_reboot"),
|
||||
column_values.get("last_seen"),
|
||||
column_values.get("last_user"),
|
||||
column_values.get("operating_system"),
|
||||
column_values.get("uptime"),
|
||||
column_values.get("agent_id"),
|
||||
column_values.get("ansible_ee_ver"),
|
||||
column_values.get("connection_type"),
|
||||
column_values.get("connection_endpoint"),
|
||||
column_values.get("ssl_key_fingerprint"),
|
||||
column_values.get("token_version"),
|
||||
column_values.get("status"),
|
||||
column_values.get("key_added_at"),
|
||||
]
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql, params)
|
||||
conn.commit()
|
||||
|
||||
def delete_device_by_hostname(self, hostname: str) -> None:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("DELETE FROM device_sites WHERE device_hostname = ?", (hostname,))
|
||||
cur.execute(f"DELETE FROM {DEVICE_TABLE} WHERE hostname = ?", (hostname,))
|
||||
conn.commit()
|
||||
|
||||
def record_device_fingerprint(self, guid: Optional[str], fingerprint: Optional[str], added_at: str) -> None:
|
||||
normalized_guid = clean_device_str(guid)
|
||||
normalized_fp = clean_device_str(fingerprint)
|
||||
if not normalized_guid or not normalized_fp:
|
||||
return
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(str(uuid.uuid4()), normalized_guid, normalized_fp.lower(), added_at),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_keys
|
||||
SET retired_at = ?
|
||||
WHERE guid = ?
|
||||
AND ssl_key_fingerprint != ?
|
||||
AND retired_at IS NULL
|
||||
""",
|
||||
(added_at, normalized_guid, normalized_fp.lower()),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE devices
|
||||
SET ssl_key_fingerprint = COALESCE(LOWER(?), ssl_key_fingerprint),
|
||||
key_added_at = COALESCE(key_added_at, ?)
|
||||
WHERE LOWER(guid) = LOWER(?)
|
||||
""",
|
||||
(normalized_fp, added_at, normalized_guid),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def _extract_device_columns(self, details: Dict[str, Any]) -> Dict[str, Any]:
|
||||
summary = details.get("summary") or {}
|
||||
payload: Dict[str, Any] = {}
|
||||
for field in ("memory", "network", "software", "storage"):
|
||||
payload[field] = serialize_device_json(details.get(field), [])
|
||||
payload["cpu"] = serialize_device_json(summary.get("cpu") or details.get("cpu"), {})
|
||||
payload["device_type"] = clean_device_str(
|
||||
summary.get("device_type")
|
||||
or summary.get("type")
|
||||
or summary.get("device_class")
|
||||
)
|
||||
payload["domain"] = clean_device_str(
|
||||
summary.get("domain") or summary.get("domain_name")
|
||||
)
|
||||
payload["external_ip"] = clean_device_str(
|
||||
summary.get("external_ip") or summary.get("public_ip")
|
||||
)
|
||||
payload["internal_ip"] = clean_device_str(
|
||||
summary.get("internal_ip") or summary.get("private_ip")
|
||||
)
|
||||
payload["last_reboot"] = clean_device_str(
|
||||
summary.get("last_reboot") or summary.get("last_boot")
|
||||
)
|
||||
payload["last_seen"] = coerce_int(
|
||||
summary.get("last_seen") or summary.get("last_seen_epoch")
|
||||
)
|
||||
payload["last_user"] = clean_device_str(
|
||||
summary.get("last_user")
|
||||
or summary.get("last_user_name")
|
||||
or summary.get("logged_in_user")
|
||||
or summary.get("username")
|
||||
or summary.get("user")
|
||||
)
|
||||
payload["operating_system"] = clean_device_str(
|
||||
summary.get("operating_system")
|
||||
or summary.get("agent_operating_system")
|
||||
or summary.get("os")
|
||||
)
|
||||
uptime_value = (
|
||||
summary.get("uptime_sec")
|
||||
or summary.get("uptime_seconds")
|
||||
or summary.get("uptime")
|
||||
)
|
||||
payload["uptime"] = coerce_int(uptime_value)
|
||||
payload["agent_id"] = clean_device_str(summary.get("agent_id"))
|
||||
payload["ansible_ee_ver"] = clean_device_str(summary.get("ansible_ee_ver"))
|
||||
payload["connection_type"] = clean_device_str(
|
||||
summary.get("connection_type") or summary.get("remote_type")
|
||||
)
|
||||
payload["connection_endpoint"] = clean_device_str(
|
||||
summary.get("connection_endpoint")
|
||||
or summary.get("endpoint")
|
||||
or summary.get("connection_address")
|
||||
or summary.get("address")
|
||||
or summary.get("external_ip")
|
||||
or summary.get("internal_ip")
|
||||
)
|
||||
payload["ssl_key_fingerprint"] = clean_device_str(
|
||||
summary.get("ssl_key_fingerprint")
|
||||
)
|
||||
payload["token_version"] = coerce_int(summary.get("token_version")) or 0
|
||||
payload["status"] = clean_device_str(summary.get("status"))
|
||||
payload["key_added_at"] = clean_device_str(summary.get("key_added_at"))
|
||||
return payload
|
||||
@@ -1,410 +0,0 @@
|
||||
"""SQLite-backed device repository for the Engine authentication services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import closing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from Data.Engine.domain.device_auth import (
|
||||
DeviceFingerprint,
|
||||
DeviceGuid,
|
||||
DeviceIdentity,
|
||||
DeviceStatus,
|
||||
)
|
||||
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||
from Data.Engine.services.auth.device_auth_service import DeviceRecord
|
||||
|
||||
__all__ = ["SQLiteDeviceRepository"]
|
||||
|
||||
|
||||
class SQLiteDeviceRepository:
|
||||
"""Persistence adapter that reads and recovers device rows."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._connections = connection_factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.devices")
|
||||
|
||||
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]:
|
||||
"""Fetch a device row by GUID, normalizing legacy case variance."""
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE UPPER(guid) = ?
|
||||
""",
|
||||
(guid.value.upper(),),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
for row in rows:
|
||||
record = self._row_to_record(row)
|
||||
if record and record.identity.guid.value == guid.value:
|
||||
return record
|
||||
|
||||
# Fall back to the first row if normalization failed to match exactly.
|
||||
return self._row_to_record(rows[0])
|
||||
|
||||
def recover_missing(
|
||||
self,
|
||||
guid: DeviceGuid,
|
||||
fingerprint: DeviceFingerprint,
|
||||
token_version: int,
|
||||
service_context: Optional[str],
|
||||
) -> Optional[DeviceRecord]:
|
||||
"""Attempt to recreate a missing device row for a valid token."""
|
||||
|
||||
now_ts = int(time.time())
|
||||
now_iso = datetime.now(tz=timezone.utc).isoformat()
|
||||
base_hostname = f"RECOVERED-{guid.value[:12]}"
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
for attempt in range(6):
|
||||
hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
|
||||
try:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO devices (
|
||||
guid,
|
||||
hostname,
|
||||
created_at,
|
||||
last_seen,
|
||||
ssl_key_fingerprint,
|
||||
token_version,
|
||||
status,
|
||||
key_added_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
|
||||
""",
|
||||
(
|
||||
guid.value,
|
||||
hostname,
|
||||
now_ts,
|
||||
now_ts,
|
||||
fingerprint.value,
|
||||
max(token_version or 1, 1),
|
||||
now_iso,
|
||||
),
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
message = str(exc).lower()
|
||||
if "hostname" in message and "unique" in message:
|
||||
continue
|
||||
self._log.warning(
|
||||
"device auth failed to recover guid=%s (context=%s): %s",
|
||||
guid.value,
|
||||
service_context or "none",
|
||||
exc,
|
||||
)
|
||||
conn.rollback()
|
||||
return None
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.exception(
|
||||
"device auth unexpected error recovering guid=%s (context=%s)",
|
||||
guid.value,
|
||||
service_context or "none",
|
||||
exc_info=exc,
|
||||
)
|
||||
conn.rollback()
|
||||
return None
|
||||
else:
|
||||
conn.commit()
|
||||
break
|
||||
else:
|
||||
self._log.warning(
|
||||
"device auth could not recover guid=%s; hostname collisions persisted",
|
||||
guid.value,
|
||||
)
|
||||
conn.rollback()
|
||||
return None
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(guid.value,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
self._log.warning(
|
||||
"device auth recovery committed but row missing for guid=%s",
|
||||
guid.value,
|
||||
)
|
||||
return None
|
||||
|
||||
return self._row_to_record(row)
|
||||
|
||||
def ensure_device_record(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
hostname: str,
|
||||
fingerprint: DeviceFingerprint,
|
||||
) -> DeviceRecord:
|
||||
now_iso = datetime.now(tz=timezone.utc).isoformat()
|
||||
now_ts = int(time.time())
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
|
||||
FROM devices
|
||||
WHERE UPPER(guid) = ?
|
||||
""",
|
||||
(guid.value.upper(),),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if row:
|
||||
stored_fp = (row[4] or "").strip().lower()
|
||||
new_fp = fingerprint.value
|
||||
if not stored_fp:
|
||||
cur.execute(
|
||||
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
|
||||
(new_fp, now_iso, row[0]),
|
||||
)
|
||||
elif stored_fp != new_fp:
|
||||
token_version = self._coerce_int(row[2], default=1) + 1
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE devices
|
||||
SET ssl_key_fingerprint = ?,
|
||||
key_added_at = ?,
|
||||
token_version = ?,
|
||||
status = 'active'
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(new_fp, now_iso, token_version, row[0]),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = ?
|
||||
WHERE guid = ?
|
||||
AND revoked_at IS NULL
|
||||
""",
|
||||
(now_iso, row[0]),
|
||||
)
|
||||
conn.commit()
|
||||
else:
|
||||
resolved_hostname = self._resolve_hostname(cur, hostname, guid)
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO devices (
|
||||
guid,
|
||||
hostname,
|
||||
created_at,
|
||||
last_seen,
|
||||
ssl_key_fingerprint,
|
||||
token_version,
|
||||
status,
|
||||
key_added_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
|
||||
""",
|
||||
(
|
||||
guid.value,
|
||||
resolved_hostname,
|
||||
now_ts,
|
||||
now_ts,
|
||||
fingerprint.value,
|
||||
now_iso,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE UPPER(guid) = ?
|
||||
""",
|
||||
(guid.value.upper(),),
|
||||
)
|
||||
latest = cur.fetchone()
|
||||
|
||||
if not latest:
|
||||
raise RuntimeError("device record could not be ensured")
|
||||
|
||||
record = self._row_to_record(latest)
|
||||
if record is None:
|
||||
raise RuntimeError("device record invalid after ensure")
|
||||
return record
|
||||
|
||||
def record_device_key(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
fingerprint: DeviceFingerprint,
|
||||
added_at: datetime,
|
||||
) -> None:
|
||||
added_iso = added_at.astimezone(timezone.utc).isoformat()
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(str(uuid.uuid4()), guid.value, fingerprint.value, added_iso),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_keys
|
||||
SET retired_at = ?
|
||||
WHERE guid = ?
|
||||
AND ssl_key_fingerprint != ?
|
||||
AND retired_at IS NULL
|
||||
""",
|
||||
(added_iso, guid.value, fingerprint.value),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def update_device_summary(
|
||||
self,
|
||||
*,
|
||||
hostname: Optional[str],
|
||||
last_seen: Optional[int] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
operating_system: Optional[str] = None,
|
||||
last_user: Optional[str] = None,
|
||||
) -> None:
|
||||
if not hostname:
|
||||
return
|
||||
|
||||
normalized_hostname = (hostname or "").strip()
|
||||
if not normalized_hostname:
|
||||
return
|
||||
|
||||
fields = []
|
||||
params = []
|
||||
|
||||
if last_seen is not None:
|
||||
try:
|
||||
fields.append("last_seen = ?")
|
||||
params.append(int(last_seen))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if agent_id:
|
||||
try:
|
||||
candidate = agent_id.strip()
|
||||
except Exception:
|
||||
candidate = agent_id
|
||||
if candidate:
|
||||
fields.append("agent_id = ?")
|
||||
params.append(candidate)
|
||||
|
||||
if operating_system:
|
||||
try:
|
||||
os_value = operating_system.strip()
|
||||
except Exception:
|
||||
os_value = operating_system
|
||||
if os_value:
|
||||
fields.append("operating_system = ?")
|
||||
params.append(os_value)
|
||||
|
||||
if last_user:
|
||||
try:
|
||||
user_value = last_user.strip()
|
||||
except Exception:
|
||||
user_value = last_user
|
||||
if user_value:
|
||||
fields.append("last_user = ?")
|
||||
params.append(user_value)
|
||||
|
||||
if not fields:
|
||||
return
|
||||
|
||||
params.append(normalized_hostname)
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
f"UPDATE devices SET {', '.join(fields)} WHERE LOWER(hostname) = LOWER(?)",
|
||||
params,
|
||||
)
|
||||
if cur.rowcount == 0 and agent_id:
|
||||
cur.execute(
|
||||
f"UPDATE devices SET {', '.join(fields)} WHERE agent_id = ?",
|
||||
params[:-1] + [agent_id],
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def _row_to_record(self, row: tuple) -> Optional[DeviceRecord]:
|
||||
try:
|
||||
guid = DeviceGuid(row[0])
|
||||
fingerprint_value = (row[1] or "").strip()
|
||||
if not fingerprint_value:
|
||||
self._log.warning(
|
||||
"device row %s missing TLS fingerprint; skipping",
|
||||
row[0],
|
||||
)
|
||||
return None
|
||||
fingerprint = DeviceFingerprint(fingerprint_value)
|
||||
except Exception as exc:
|
||||
self._log.warning("invalid device row for guid=%s: %s", row[0], exc)
|
||||
return None
|
||||
|
||||
token_version_raw = row[2]
|
||||
try:
|
||||
token_version = int(token_version_raw or 0)
|
||||
except Exception:
|
||||
token_version = 0
|
||||
|
||||
status = DeviceStatus.from_string(row[3])
|
||||
identity = DeviceIdentity(guid=guid, fingerprint=fingerprint)
|
||||
|
||||
return DeviceRecord(
|
||||
identity=identity,
|
||||
token_version=max(token_version, 1),
|
||||
status=status,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: object, *, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
def _resolve_hostname(self, cur: sqlite3.Cursor, hostname: str, guid: DeviceGuid) -> str:
|
||||
base = (hostname or "").strip() or guid.value
|
||||
base = base[:253]
|
||||
candidate = base
|
||||
suffix = 1
|
||||
while True:
|
||||
cur.execute(
|
||||
"SELECT guid FROM devices WHERE hostname = ?",
|
||||
(candidate,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return candidate
|
||||
existing = (row[0] or "").strip().upper()
|
||||
if existing == guid.value:
|
||||
return candidate
|
||||
candidate = f"{base}-{suffix}"
|
||||
suffix += 1
|
||||
if suffix > 50:
|
||||
return guid.value
|
||||
@@ -1,143 +0,0 @@
|
||||
"""SQLite persistence for device list views."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
from contextlib import closing
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
from Data.Engine.domain.device_views import DeviceListView
|
||||
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||
|
||||
__all__ = ["SQLiteDeviceViewRepository"]
|
||||
|
||||
|
||||
class SQLiteDeviceViewRepository:
|
||||
def __init__(
|
||||
self,
|
||||
connection_factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._connections = connection_factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.device_views")
|
||||
|
||||
def list_views(self) -> List[DeviceListView]:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"SELECT id, name, columns_json, filters_json, created_at, updated_at\n"
|
||||
" FROM device_list_views ORDER BY name COLLATE NOCASE ASC"
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [self._row_to_view(row) for row in rows]
|
||||
|
||||
def get_view(self, view_id: int) -> Optional[DeviceListView]:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"SELECT id, name, columns_json, filters_json, created_at, updated_at\n"
|
||||
" FROM device_list_views WHERE id = ?",
|
||||
(view_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
return self._row_to_view(row) if row else None
|
||||
|
||||
def create_view(self, name: str, columns: List[str], filters: Dict[str, object]) -> DeviceListView:
|
||||
now = int(time.time())
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"INSERT INTO device_list_views(name, columns_json, filters_json, created_at, updated_at)\n"
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
(name, json.dumps(columns), json.dumps(filters), now, now),
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
raise ValueError("duplicate") from exc
|
||||
view_id = cur.lastrowid
|
||||
conn.commit()
|
||||
cur.execute(
|
||||
"SELECT id, name, columns_json, filters_json, created_at, updated_at FROM device_list_views WHERE id = ?",
|
||||
(view_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
raise RuntimeError("view missing after insert")
|
||||
return self._row_to_view(row)
|
||||
|
||||
def update_view(
|
||||
self,
|
||||
view_id: int,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
columns: Optional[List[str]] = None,
|
||||
filters: Optional[Dict[str, object]] = None,
|
||||
) -> DeviceListView:
|
||||
fields: List[str] = []
|
||||
params: List[object] = []
|
||||
if name is not None:
|
||||
fields.append("name = ?")
|
||||
params.append(name)
|
||||
if columns is not None:
|
||||
fields.append("columns_json = ?")
|
||||
params.append(json.dumps(columns))
|
||||
if filters is not None:
|
||||
fields.append("filters_json = ?")
|
||||
params.append(json.dumps(filters))
|
||||
fields.append("updated_at = ?")
|
||||
params.append(int(time.time()))
|
||||
params.append(view_id)
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
f"UPDATE device_list_views SET {', '.join(fields)} WHERE id = ?",
|
||||
params,
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
raise ValueError("duplicate") from exc
|
||||
if cur.rowcount == 0:
|
||||
raise LookupError("not_found")
|
||||
conn.commit()
|
||||
cur.execute(
|
||||
"SELECT id, name, columns_json, filters_json, created_at, updated_at FROM device_list_views WHERE id = ?",
|
||||
(view_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
raise LookupError("not_found")
|
||||
return self._row_to_view(row)
|
||||
|
||||
def delete_view(self, view_id: int) -> bool:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("DELETE FROM device_list_views WHERE id = ?", (view_id,))
|
||||
deleted = cur.rowcount
|
||||
conn.commit()
|
||||
return bool(deleted)
|
||||
|
||||
def _row_to_view(self, row: Optional[Iterable[object]]) -> DeviceListView:
|
||||
if row is None:
|
||||
raise ValueError("row required")
|
||||
view_id, name, columns_json, filters_json, created_at, updated_at = row
|
||||
try:
|
||||
columns = json.loads(columns_json or "[]")
|
||||
except Exception:
|
||||
columns = []
|
||||
try:
|
||||
filters = json.loads(filters_json or "{}")
|
||||
except Exception:
|
||||
filters = {}
|
||||
return DeviceListView(
|
||||
id=int(view_id),
|
||||
name=str(name or ""),
|
||||
columns=list(columns) if isinstance(columns, list) else [],
|
||||
filters=dict(filters) if isinstance(filters, dict) else {},
|
||||
created_at=int(created_at or 0),
|
||||
updated_at=int(updated_at or 0),
|
||||
)
|
||||
@@ -1,726 +0,0 @@
|
||||
"""SQLite-backed enrollment repository for Engine services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import closing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid, normalize_guid
|
||||
from Data.Engine.domain.device_enrollment import (
|
||||
EnrollmentApproval,
|
||||
EnrollmentApprovalStatus,
|
||||
EnrollmentCode,
|
||||
)
|
||||
from Data.Engine.domain.enrollment_admin import (
|
||||
DeviceApprovalRecord,
|
||||
EnrollmentCodeRecord,
|
||||
HostnameConflict,
|
||||
)
|
||||
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||
|
||||
__all__ = ["SQLiteEnrollmentRepository"]
|
||||
|
||||
|
||||
class SQLiteEnrollmentRepository:
|
||||
"""Persistence adapter that manages enrollment codes and approvals."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._connections = connection_factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.enrollment")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Enrollment install codes
|
||||
# ------------------------------------------------------------------
|
||||
def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]:
|
||||
"""Load an enrollment install code by its public value."""
|
||||
|
||||
code_value = (code or "").strip()
|
||||
if not code_value:
|
||||
return None
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
WHERE code = ?
|
||||
""",
|
||||
(code_value,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
record = {
|
||||
"id": row[0],
|
||||
"code": row[1],
|
||||
"expires_at": row[2],
|
||||
"used_at": row[3],
|
||||
"used_by_guid": row[4],
|
||||
"max_uses": row[5],
|
||||
"use_count": row[6],
|
||||
"last_used_at": row[7],
|
||||
}
|
||||
try:
|
||||
return EnrollmentCode.from_mapping(record)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("invalid enrollment code record for code=%s: %s", code_value, exc)
|
||||
return None
|
||||
|
||||
def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]:
|
||||
record_value = (record_id or "").strip()
|
||||
if not record_value:
|
||||
return None
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
WHERE id = ?
|
||||
""",
|
||||
(record_value,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
record = {
|
||||
"id": row[0],
|
||||
"code": row[1],
|
||||
"expires_at": row[2],
|
||||
"used_at": row[3],
|
||||
"used_by_guid": row[4],
|
||||
"max_uses": row[5],
|
||||
"use_count": row[6],
|
||||
"last_used_at": row[7],
|
||||
}
|
||||
|
||||
try:
|
||||
return EnrollmentCode.from_mapping(record)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc)
|
||||
return None
|
||||
|
||||
def list_install_codes(
|
||||
self,
|
||||
*,
|
||||
status: Optional[str] = None,
|
||||
now: Optional[datetime] = None,
|
||||
) -> List[EnrollmentCodeRecord]:
|
||||
reference = now or datetime.now(tz=timezone.utc)
|
||||
status_filter = (status or "").strip().lower()
|
||||
params: List[str] = []
|
||||
|
||||
sql = """
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
created_by_user_id,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
"""
|
||||
|
||||
if status_filter in {"active", "expired", "used"}:
|
||||
sql += " WHERE "
|
||||
if status_filter == "active":
|
||||
sql += "use_count < max_uses AND expires_at > ?"
|
||||
params.append(self._isoformat(reference))
|
||||
elif status_filter == "expired":
|
||||
sql += "use_count < max_uses AND expires_at <= ?"
|
||||
params.append(self._isoformat(reference))
|
||||
else: # used
|
||||
sql += "use_count >= max_uses"
|
||||
|
||||
sql += " ORDER BY expires_at ASC"
|
||||
|
||||
rows: List[EnrollmentCodeRecord] = []
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql, params)
|
||||
for raw in cur.fetchall():
|
||||
record = {
|
||||
"id": raw[0],
|
||||
"code": raw[1],
|
||||
"expires_at": raw[2],
|
||||
"created_by_user_id": raw[3],
|
||||
"used_at": raw[4],
|
||||
"used_by_guid": raw[5],
|
||||
"max_uses": raw[6],
|
||||
"use_count": raw[7],
|
||||
"last_used_at": raw[8],
|
||||
}
|
||||
try:
|
||||
rows.append(EnrollmentCodeRecord.from_row(record))
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("invalid enrollment install code row id=%s: %s", record.get("id"), exc)
|
||||
return rows
|
||||
|
||||
def get_install_code_record(self, record_id: str) -> Optional[EnrollmentCodeRecord]:
|
||||
identifier = (record_id or "").strip()
|
||||
if not identifier:
|
||||
return None
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
created_by_user_id,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
WHERE id = ?
|
||||
""",
|
||||
(identifier,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
payload = {
|
||||
"id": row[0],
|
||||
"code": row[1],
|
||||
"expires_at": row[2],
|
||||
"created_by_user_id": row[3],
|
||||
"used_at": row[4],
|
||||
"used_by_guid": row[5],
|
||||
"max_uses": row[6],
|
||||
"use_count": row[7],
|
||||
"last_used_at": row[8],
|
||||
}
|
||||
|
||||
try:
|
||||
return EnrollmentCodeRecord.from_row(payload)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("invalid enrollment install code record id=%s: %s", identifier, exc)
|
||||
return None
|
||||
|
||||
def insert_install_code(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
code: str,
|
||||
expires_at: datetime,
|
||||
created_by: Optional[str],
|
||||
max_uses: int,
|
||||
) -> EnrollmentCodeRecord:
|
||||
expires_iso = self._isoformat(expires_at)
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO enrollment_install_codes (
|
||||
id,
|
||||
code,
|
||||
expires_at,
|
||||
created_by_user_id,
|
||||
max_uses,
|
||||
use_count
|
||||
) VALUES (?, ?, ?, ?, ?, 0)
|
||||
""",
|
||||
(record_id, code, expires_iso, created_by, max_uses),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
record = self.get_install_code_record(record_id)
|
||||
if record is None:
|
||||
raise RuntimeError("failed to load install code after insert")
|
||||
return record
|
||||
|
||||
def delete_install_code_if_unused(self, record_id: str) -> bool:
|
||||
identifier = (record_id or "").strip()
|
||||
if not identifier:
|
||||
return False
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
|
||||
(identifier,),
|
||||
)
|
||||
deleted = cur.rowcount > 0
|
||||
conn.commit()
|
||||
return deleted
|
||||
|
||||
def update_install_code_usage(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
use_count_increment: int,
|
||||
last_used_at: datetime,
|
||||
used_by_guid: Optional[DeviceGuid] = None,
|
||||
mark_first_use: bool = False,
|
||||
) -> None:
|
||||
"""Increment usage counters and usage metadata for an install code."""
|
||||
|
||||
if use_count_increment <= 0:
|
||||
raise ValueError("use_count_increment must be positive")
|
||||
|
||||
last_used_iso = self._isoformat(last_used_at)
|
||||
guid_value = used_by_guid.value if used_by_guid else ""
|
||||
mark_flag = 1 if mark_first_use else 0
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE enrollment_install_codes
|
||||
SET use_count = use_count + ?,
|
||||
last_used_at = ?,
|
||||
used_by_guid = COALESCE(NULLIF(?, ''), used_by_guid),
|
||||
used_at = CASE WHEN ? = 1 AND used_at IS NULL THEN ? ELSE used_at END
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
use_count_increment,
|
||||
last_used_iso,
|
||||
guid_value,
|
||||
mark_flag,
|
||||
last_used_iso,
|
||||
record_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Device approvals
|
||||
# ------------------------------------------------------------------
|
||||
def list_device_approvals(
|
||||
self,
|
||||
*,
|
||||
status: Optional[str] = None,
|
||||
) -> List[DeviceApprovalRecord]:
|
||||
status_filter = (status or "").strip().lower()
|
||||
params: List[str] = []
|
||||
|
||||
sql = """
|
||||
SELECT
|
||||
da.id,
|
||||
da.approval_reference,
|
||||
da.guid,
|
||||
da.hostname_claimed,
|
||||
da.ssl_key_fingerprint_claimed,
|
||||
da.enrollment_code_id,
|
||||
da.status,
|
||||
da.client_nonce,
|
||||
da.server_nonce,
|
||||
da.created_at,
|
||||
da.updated_at,
|
||||
da.approved_by_user_id,
|
||||
u.username AS approved_by_username
|
||||
FROM device_approvals AS da
|
||||
LEFT JOIN users AS u
|
||||
ON (
|
||||
CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT)
|
||||
OR LOWER(da.approved_by_user_id) = LOWER(u.username)
|
||||
)
|
||||
"""
|
||||
|
||||
if status_filter and status_filter not in {"all", "*"}:
|
||||
sql += " WHERE LOWER(da.status) = ?"
|
||||
params.append(status_filter)
|
||||
|
||||
sql += " ORDER BY da.created_at ASC"
|
||||
|
||||
approvals: List[DeviceApprovalRecord] = []
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql, params)
|
||||
rows = cur.fetchall()
|
||||
|
||||
for raw in rows:
|
||||
record = {
|
||||
"id": raw[0],
|
||||
"approval_reference": raw[1],
|
||||
"guid": raw[2],
|
||||
"hostname_claimed": raw[3],
|
||||
"ssl_key_fingerprint_claimed": raw[4],
|
||||
"enrollment_code_id": raw[5],
|
||||
"status": raw[6],
|
||||
"client_nonce": raw[7],
|
||||
"server_nonce": raw[8],
|
||||
"created_at": raw[9],
|
||||
"updated_at": raw[10],
|
||||
"approved_by_user_id": raw[11],
|
||||
"approved_by_username": raw[12],
|
||||
}
|
||||
|
||||
conflict, fingerprint_match, requires_prompt = self._compute_hostname_conflict(
|
||||
conn,
|
||||
record.get("hostname_claimed"),
|
||||
record.get("guid"),
|
||||
record.get("ssl_key_fingerprint_claimed") or "",
|
||||
)
|
||||
|
||||
alternate = None
|
||||
if conflict and requires_prompt:
|
||||
alternate = self._suggest_alternate_hostname(
|
||||
conn,
|
||||
record.get("hostname_claimed"),
|
||||
record.get("guid"),
|
||||
)
|
||||
|
||||
try:
|
||||
approvals.append(
|
||||
DeviceApprovalRecord.from_row(
|
||||
record,
|
||||
conflict=conflict,
|
||||
alternate_hostname=alternate,
|
||||
fingerprint_match=fingerprint_match,
|
||||
requires_prompt=requires_prompt,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning(
|
||||
"invalid device approval record id=%s: %s",
|
||||
record.get("id"),
|
||||
exc,
|
||||
)
|
||||
|
||||
return approvals
|
||||
|
||||
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:
|
||||
"""Load a device approval using its operator-visible reference."""
|
||||
|
||||
ref_value = (reference or "").strip()
|
||||
if not ref_value:
|
||||
return None
|
||||
return self._fetch_device_approval("approval_reference = ?", (ref_value,))
|
||||
|
||||
def fetch_device_approval(self, record_id: str) -> Optional[EnrollmentApproval]:
|
||||
record_value = (record_id or "").strip()
|
||||
if not record_value:
|
||||
return None
|
||||
return self._fetch_device_approval("id = ?", (record_value,))
|
||||
|
||||
def fetch_pending_approval_by_fingerprint(
|
||||
self, fingerprint: DeviceFingerprint
|
||||
) -> Optional[EnrollmentApproval]:
|
||||
return self._fetch_device_approval(
|
||||
"ssl_key_fingerprint_claimed = ? AND status = 'pending'",
|
||||
(fingerprint.value,),
|
||||
)
|
||||
|
||||
def update_pending_approval(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
hostname: str,
|
||||
guid: Optional[DeviceGuid],
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
updated_at: datetime,
|
||||
) -> None:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET hostname_claimed = ?,
|
||||
guid = ?,
|
||||
enrollment_code_id = ?,
|
||||
client_nonce = ?,
|
||||
server_nonce = ?,
|
||||
agent_pubkey_der = ?,
|
||||
updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
hostname,
|
||||
guid.value if guid else None,
|
||||
enrollment_code_id,
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
self._isoformat(updated_at),
|
||||
record_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def create_device_approval(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
reference: str,
|
||||
claimed_hostname: str,
|
||||
claimed_fingerprint: DeviceFingerprint,
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
created_at: datetime,
|
||||
status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> EnrollmentApproval:
|
||||
created_iso = self._isoformat(created_at)
|
||||
guid_value = guid.value if guid else None
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO device_approvals (
|
||||
id,
|
||||
approval_reference,
|
||||
guid,
|
||||
hostname_claimed,
|
||||
ssl_key_fingerprint_claimed,
|
||||
enrollment_code_id,
|
||||
status,
|
||||
created_at,
|
||||
updated_at,
|
||||
client_nonce,
|
||||
server_nonce,
|
||||
agent_pubkey_der
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
record_id,
|
||||
reference,
|
||||
guid_value,
|
||||
claimed_hostname,
|
||||
claimed_fingerprint.value,
|
||||
enrollment_code_id,
|
||||
status.value,
|
||||
created_iso,
|
||||
created_iso,
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
approval = self.fetch_device_approval(record_id)
|
||||
if approval is None:
|
||||
raise RuntimeError("failed to load device approval after insert")
|
||||
return approval
|
||||
|
||||
def update_device_approval_status(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
status: EnrollmentApprovalStatus,
|
||||
updated_at: datetime,
|
||||
approved_by: Optional[str] = None,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> None:
|
||||
"""Transition an approval to a new status."""
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET status = ?,
|
||||
updated_at = ?,
|
||||
guid = COALESCE(?, guid),
|
||||
approved_by_user_id = COALESCE(?, approved_by_user_id)
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
status.value,
|
||||
self._isoformat(updated_at),
|
||||
guid.value if guid else None,
|
||||
approved_by,
|
||||
record_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _fetch_device_approval(self, where: str, params: tuple) -> Optional[EnrollmentApproval]:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT id,
|
||||
approval_reference,
|
||||
guid,
|
||||
hostname_claimed,
|
||||
ssl_key_fingerprint_claimed,
|
||||
enrollment_code_id,
|
||||
created_at,
|
||||
updated_at,
|
||||
status,
|
||||
approved_by_user_id,
|
||||
client_nonce,
|
||||
server_nonce,
|
||||
agent_pubkey_der
|
||||
FROM device_approvals
|
||||
WHERE {where}
|
||||
""",
|
||||
params,
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
record = {
|
||||
"id": row[0],
|
||||
"approval_reference": row[1],
|
||||
"guid": row[2],
|
||||
"hostname_claimed": row[3],
|
||||
"ssl_key_fingerprint_claimed": row[4],
|
||||
"enrollment_code_id": row[5],
|
||||
"created_at": row[6],
|
||||
"updated_at": row[7],
|
||||
"status": row[8],
|
||||
"approved_by_user_id": row[9],
|
||||
"client_nonce": row[10],
|
||||
"server_nonce": row[11],
|
||||
"agent_pubkey_der": row[12],
|
||||
}
|
||||
|
||||
try:
|
||||
return EnrollmentApproval.from_mapping(record)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning(
|
||||
"invalid device approval record id=%s reference=%s: %s",
|
||||
row[0],
|
||||
row[1],
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
def _compute_hostname_conflict(
|
||||
self,
|
||||
conn,
|
||||
hostname: Optional[str],
|
||||
pending_guid: Optional[str],
|
||||
claimed_fp: str,
|
||||
) -> Tuple[Optional[HostnameConflict], bool, bool]:
|
||||
normalized_host = (hostname or "").strip()
|
||||
if not normalized_host:
|
||||
return None, False, False
|
||||
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT d.guid,
|
||||
d.ssl_key_fingerprint,
|
||||
ds.site_id,
|
||||
s.name
|
||||
FROM devices AS d
|
||||
LEFT JOIN device_sites AS ds ON ds.device_hostname = d.hostname
|
||||
LEFT JOIN sites AS s ON s.id = ds.site_id
|
||||
WHERE d.hostname = ?
|
||||
""",
|
||||
(normalized_host,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("failed to inspect hostname conflict for %s: %s", normalized_host, exc)
|
||||
return None, False, False
|
||||
|
||||
if not row:
|
||||
return None, False, False
|
||||
|
||||
existing_guid = normalize_guid(row[0])
|
||||
pending_norm = normalize_guid(pending_guid)
|
||||
if existing_guid and pending_norm and existing_guid == pending_norm:
|
||||
return None, False, False
|
||||
|
||||
stored_fp = (row[1] or "").strip().lower()
|
||||
claimed_fp_normalized = (claimed_fp or "").strip().lower()
|
||||
fingerprint_match = bool(stored_fp and claimed_fp_normalized and stored_fp == claimed_fp_normalized)
|
||||
|
||||
site_id = None
|
||||
if row[2] is not None:
|
||||
try:
|
||||
site_id = int(row[2])
|
||||
except (TypeError, ValueError): # pragma: no cover - defensive
|
||||
site_id = None
|
||||
|
||||
site_name = str(row[3] or "").strip()
|
||||
requires_prompt = not fingerprint_match
|
||||
|
||||
conflict = HostnameConflict(
|
||||
guid=existing_guid or None,
|
||||
ssl_key_fingerprint=stored_fp or None,
|
||||
site_id=site_id,
|
||||
site_name=site_name,
|
||||
fingerprint_match=fingerprint_match,
|
||||
requires_prompt=requires_prompt,
|
||||
)
|
||||
|
||||
return conflict, fingerprint_match, requires_prompt
|
||||
|
||||
def _suggest_alternate_hostname(
|
||||
self,
|
||||
conn,
|
||||
hostname: Optional[str],
|
||||
pending_guid: Optional[str],
|
||||
) -> Optional[str]:
|
||||
base = (hostname or "").strip()
|
||||
if not base:
|
||||
return None
|
||||
base = base[:253]
|
||||
candidate = base
|
||||
pending_norm = normalize_guid(pending_guid)
|
||||
suffix = 1
|
||||
|
||||
cur = conn.cursor()
|
||||
while True:
|
||||
cur.execute("SELECT guid FROM devices WHERE hostname = ?", (candidate,))
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return candidate
|
||||
existing_guid = normalize_guid(row[0])
|
||||
if pending_norm and existing_guid == pending_norm:
|
||||
return candidate
|
||||
candidate = f"{base}-{suffix}"
|
||||
suffix += 1
|
||||
if suffix > 50:
|
||||
return pending_norm or candidate
|
||||
|
||||
@staticmethod
|
||||
def _isoformat(value: datetime) -> str:
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
return value.astimezone(timezone.utc).isoformat()
|
||||
@@ -1,53 +0,0 @@
|
||||
"""SQLite-backed GitHub token persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import closing
|
||||
from typing import Optional
|
||||
|
||||
from .connection import SQLiteConnectionFactory
|
||||
|
||||
__all__ = ["SQLiteGitHubRepository"]
|
||||
|
||||
|
||||
class SQLiteGitHubRepository:
|
||||
"""Store and retrieve GitHub API tokens for the Engine."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._connections = connection_factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.github")
|
||||
|
||||
def load_token(self) -> Optional[str]:
|
||||
"""Return the stored GitHub token if one exists."""
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT token FROM github_token LIMIT 1")
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
token = (row[0] or "").strip()
|
||||
return token or None
|
||||
|
||||
def store_token(self, token: Optional[str]) -> None:
|
||||
"""Persist *token*, replacing any prior value."""
|
||||
|
||||
normalized = (token or "").strip()
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("DELETE FROM github_token")
|
||||
if normalized:
|
||||
cur.execute("INSERT INTO github_token (token) VALUES (?)", (normalized,))
|
||||
conn.commit()
|
||||
|
||||
self._log.info("stored-token has_token=%s", bool(normalized))
|
||||
|
||||
@@ -1,355 +0,0 @@
|
||||
"""SQLite-backed persistence for Engine job scheduling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterable, Optional, Sequence
|
||||
|
||||
import sqlite3
|
||||
|
||||
from .connection import SQLiteConnectionFactory
|
||||
|
||||
__all__ = [
|
||||
"ScheduledJobRecord",
|
||||
"ScheduledJobRunRecord",
|
||||
"SQLiteJobRepository",
|
||||
]
|
||||
|
||||
|
||||
def _now_ts() -> int:
|
||||
return int(time.time())
|
||||
|
||||
|
||||
def _json_dumps(value: Any) -> str:
|
||||
try:
|
||||
return json.dumps(value or [])
|
||||
except Exception:
|
||||
return "[]"
|
||||
|
||||
|
||||
def _json_loads(value: Optional[str]) -> list[Any]:
|
||||
if not value:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(value)
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
return []
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ScheduledJobRecord:
|
||||
id: int
|
||||
name: str
|
||||
components: list[dict[str, Any]]
|
||||
targets: list[str]
|
||||
schedule_type: str
|
||||
start_ts: Optional[int]
|
||||
duration_stop_enabled: bool
|
||||
expiration: Optional[str]
|
||||
execution_context: str
|
||||
credential_id: Optional[int]
|
||||
use_service_account: bool
|
||||
enabled: bool
|
||||
created_at: Optional[int]
|
||||
updated_at: Optional[int]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ScheduledJobRunRecord:
|
||||
id: int
|
||||
job_id: int
|
||||
scheduled_ts: Optional[int]
|
||||
started_ts: Optional[int]
|
||||
finished_ts: Optional[int]
|
||||
status: Optional[str]
|
||||
error: Optional[str]
|
||||
target_hostname: Optional[str]
|
||||
created_at: Optional[int]
|
||||
updated_at: Optional[int]
|
||||
|
||||
|
||||
class SQLiteJobRepository:
|
||||
"""Persistence adapter for Engine job scheduling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._factory = factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.jobs")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Job CRUD
|
||||
# ------------------------------------------------------------------
|
||||
def list_jobs(self) -> list[ScheduledJobRecord]:
|
||||
query = (
|
||||
"SELECT id, name, components_json, targets_json, schedule_type, start_ts, "
|
||||
"duration_stop_enabled, expiration, execution_context, credential_id, "
|
||||
"use_service_account, enabled, created_at, updated_at FROM scheduled_jobs "
|
||||
"ORDER BY id ASC"
|
||||
)
|
||||
return [self._row_to_job(row) for row in self._fetchall(query)]
|
||||
|
||||
def list_enabled_jobs(self) -> list[ScheduledJobRecord]:
|
||||
query = (
|
||||
"SELECT id, name, components_json, targets_json, schedule_type, start_ts, "
|
||||
"duration_stop_enabled, expiration, execution_context, credential_id, "
|
||||
"use_service_account, enabled, created_at, updated_at FROM scheduled_jobs "
|
||||
"WHERE enabled=1 ORDER BY id ASC"
|
||||
)
|
||||
return [self._row_to_job(row) for row in self._fetchall(query)]
|
||||
|
||||
def fetch_job(self, job_id: int) -> Optional[ScheduledJobRecord]:
|
||||
query = (
|
||||
"SELECT id, name, components_json, targets_json, schedule_type, start_ts, "
|
||||
"duration_stop_enabled, expiration, execution_context, credential_id, "
|
||||
"use_service_account, enabled, created_at, updated_at FROM scheduled_jobs "
|
||||
"WHERE id=?"
|
||||
)
|
||||
rows = self._fetchall(query, (job_id,))
|
||||
return self._row_to_job(rows[0]) if rows else None
|
||||
|
||||
def create_job(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
components: Sequence[dict[str, Any]],
|
||||
targets: Sequence[Any],
|
||||
schedule_type: str,
|
||||
start_ts: Optional[int],
|
||||
duration_stop_enabled: bool,
|
||||
expiration: Optional[str],
|
||||
execution_context: str,
|
||||
credential_id: Optional[int],
|
||||
use_service_account: bool,
|
||||
enabled: bool = True,
|
||||
) -> ScheduledJobRecord:
|
||||
now = _now_ts()
|
||||
payload = (
|
||||
name,
|
||||
_json_dumps(list(components)),
|
||||
_json_dumps(list(targets)),
|
||||
schedule_type,
|
||||
start_ts,
|
||||
1 if duration_stop_enabled else 0,
|
||||
expiration,
|
||||
execution_context,
|
||||
credential_id,
|
||||
1 if use_service_account else 0,
|
||||
1 if enabled else 0,
|
||||
now,
|
||||
now,
|
||||
)
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO scheduled_jobs
|
||||
(name, components_json, targets_json, schedule_type, start_ts,
|
||||
duration_stop_enabled, expiration, execution_context, credential_id,
|
||||
use_service_account, enabled, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
payload,
|
||||
)
|
||||
job_id = cur.lastrowid
|
||||
conn.commit()
|
||||
record = self.fetch_job(int(job_id))
|
||||
if record is None:
|
||||
raise RuntimeError("failed to create scheduled job")
|
||||
return record
|
||||
|
||||
def update_job(
|
||||
self,
|
||||
job_id: int,
|
||||
*,
|
||||
name: str,
|
||||
components: Sequence[dict[str, Any]],
|
||||
targets: Sequence[Any],
|
||||
schedule_type: str,
|
||||
start_ts: Optional[int],
|
||||
duration_stop_enabled: bool,
|
||||
expiration: Optional[str],
|
||||
execution_context: str,
|
||||
credential_id: Optional[int],
|
||||
use_service_account: bool,
|
||||
) -> Optional[ScheduledJobRecord]:
|
||||
now = _now_ts()
|
||||
payload = (
|
||||
name,
|
||||
_json_dumps(list(components)),
|
||||
_json_dumps(list(targets)),
|
||||
schedule_type,
|
||||
start_ts,
|
||||
1 if duration_stop_enabled else 0,
|
||||
expiration,
|
||||
execution_context,
|
||||
credential_id,
|
||||
1 if use_service_account else 0,
|
||||
now,
|
||||
job_id,
|
||||
)
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE scheduled_jobs
|
||||
SET name=?, components_json=?, targets_json=?, schedule_type=?,
|
||||
start_ts=?, duration_stop_enabled=?, expiration=?, execution_context=?,
|
||||
credential_id=?, use_service_account=?, updated_at=?
|
||||
WHERE id=?
|
||||
""",
|
||||
payload,
|
||||
)
|
||||
conn.commit()
|
||||
return self.fetch_job(job_id)
|
||||
|
||||
def set_enabled(self, job_id: int, enabled: bool) -> None:
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"UPDATE scheduled_jobs SET enabled=?, updated_at=? WHERE id=?",
|
||||
(1 if enabled else 0, _now_ts(), job_id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def delete_job(self, job_id: int) -> None:
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("DELETE FROM scheduled_jobs WHERE id=?", (job_id,))
|
||||
conn.commit()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Run history
|
||||
# ------------------------------------------------------------------
|
||||
def list_runs(self, job_id: int, *, days: Optional[int] = None) -> list[ScheduledJobRunRecord]:
|
||||
params: list[Any] = [job_id]
|
||||
where = "WHERE job_id=?"
|
||||
if days is not None and days > 0:
|
||||
cutoff = _now_ts() - (days * 86400)
|
||||
where += " AND COALESCE(finished_ts, scheduled_ts, started_ts, 0) >= ?"
|
||||
params.append(cutoff)
|
||||
|
||||
query = (
|
||||
"SELECT id, job_id, scheduled_ts, started_ts, finished_ts, status, error, "
|
||||
"target_hostname, created_at, updated_at FROM scheduled_job_runs "
|
||||
f"{where} ORDER BY COALESCE(scheduled_ts, created_at, id) DESC"
|
||||
)
|
||||
return [self._row_to_run(row) for row in self._fetchall(query, tuple(params))]
|
||||
|
||||
def fetch_last_run(self, job_id: int) -> Optional[ScheduledJobRunRecord]:
|
||||
query = (
|
||||
"SELECT id, job_id, scheduled_ts, started_ts, finished_ts, status, error, "
|
||||
"target_hostname, created_at, updated_at FROM scheduled_job_runs "
|
||||
"WHERE job_id=? ORDER BY COALESCE(started_ts, scheduled_ts, created_at, id) DESC LIMIT 1"
|
||||
)
|
||||
rows = self._fetchall(query, (job_id,))
|
||||
return self._row_to_run(rows[0]) if rows else None
|
||||
|
||||
def purge_runs(self, job_id: int) -> None:
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("DELETE FROM scheduled_job_run_activity WHERE run_id IN (SELECT id FROM scheduled_job_runs WHERE job_id=?)", (job_id,))
|
||||
cur.execute("DELETE FROM scheduled_job_runs WHERE job_id=?", (job_id,))
|
||||
conn.commit()
|
||||
|
||||
def create_run(self, job_id: int, scheduled_ts: int, *, target_hostname: Optional[str] = None) -> int:
|
||||
now = _now_ts()
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO scheduled_job_runs
|
||||
(job_id, scheduled_ts, created_at, updated_at, target_hostname, status)
|
||||
VALUES (?, ?, ?, ?, ?, 'Pending')
|
||||
""",
|
||||
(job_id, scheduled_ts, now, now, target_hostname),
|
||||
)
|
||||
run_id = int(cur.lastrowid)
|
||||
conn.commit()
|
||||
return run_id
|
||||
|
||||
def mark_run_started(self, run_id: int, *, started_ts: Optional[int] = None) -> None:
|
||||
started = started_ts or _now_ts()
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"UPDATE scheduled_job_runs SET started_ts=?, status='Running', updated_at=? WHERE id=?",
|
||||
(started, _now_ts(), run_id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def mark_run_finished(
|
||||
self,
|
||||
run_id: int,
|
||||
*,
|
||||
status: str,
|
||||
error: Optional[str] = None,
|
||||
finished_ts: Optional[int] = None,
|
||||
) -> None:
|
||||
finished = finished_ts or _now_ts()
|
||||
with self._connect() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"UPDATE scheduled_job_runs SET finished_ts=?, status=?, error=?, updated_at=? WHERE id=?",
|
||||
(finished, status, error, _now_ts(), run_id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
return self._factory()
|
||||
|
||||
def _fetchall(self, query: str, params: Optional[Iterable[Any]] = None) -> list[sqlite3.Row]:
|
||||
with self._connect() as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cur = conn.cursor()
|
||||
cur.execute(query, tuple(params or ()))
|
||||
rows = cur.fetchall()
|
||||
return rows
|
||||
|
||||
def _row_to_job(self, row: sqlite3.Row) -> ScheduledJobRecord:
|
||||
components = _json_loads(row["components_json"])
|
||||
targets_raw = _json_loads(row["targets_json"])
|
||||
targets = [str(t) for t in targets_raw if isinstance(t, (str, int))]
|
||||
credential_id = row["credential_id"]
|
||||
return ScheduledJobRecord(
|
||||
id=int(row["id"]),
|
||||
name=str(row["name"] or ""),
|
||||
components=[c for c in components if isinstance(c, dict)],
|
||||
targets=targets,
|
||||
schedule_type=str(row["schedule_type"] or "immediately"),
|
||||
start_ts=int(row["start_ts"]) if row["start_ts"] is not None else None,
|
||||
duration_stop_enabled=bool(row["duration_stop_enabled"]),
|
||||
expiration=str(row["expiration"]) if row["expiration"] else None,
|
||||
execution_context=str(row["execution_context"] or "system"),
|
||||
credential_id=int(credential_id) if credential_id is not None else None,
|
||||
use_service_account=bool(row["use_service_account"]),
|
||||
enabled=bool(row["enabled"]),
|
||||
created_at=int(row["created_at"]) if row["created_at"] is not None else None,
|
||||
updated_at=int(row["updated_at"]) if row["updated_at"] is not None else None,
|
||||
)
|
||||
|
||||
def _row_to_run(self, row: sqlite3.Row) -> ScheduledJobRunRecord:
|
||||
return ScheduledJobRunRecord(
|
||||
id=int(row["id"]),
|
||||
job_id=int(row["job_id"]),
|
||||
scheduled_ts=int(row["scheduled_ts"]) if row["scheduled_ts"] is not None else None,
|
||||
started_ts=int(row["started_ts"]) if row["started_ts"] is not None else None,
|
||||
finished_ts=int(row["finished_ts"]) if row["finished_ts"] is not None else None,
|
||||
status=str(row["status"]) if row["status"] else None,
|
||||
error=str(row["error"]) if row["error"] else None,
|
||||
target_hostname=str(row["target_hostname"]) if row["target_hostname"] else None,
|
||||
created_at=int(row["created_at"]) if row["created_at"] is not None else None,
|
||||
updated_at=int(row["updated_at"]) if row["updated_at"] is not None else None,
|
||||
)
|
||||
@@ -1,665 +0,0 @@
|
||||
"""SQLite schema migrations for the Borealis Engine.
|
||||
|
||||
This module centralises schema evolution so the Engine and its interfaces can stay
|
||||
focused on request handling. The migration functions are intentionally
|
||||
idempotent — they can run repeatedly without changing state once the schema
|
||||
matches the desired shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
|
||||
DEVICE_TABLE = "devices"
|
||||
_DEFAULT_ADMIN_USERNAME = "admin"
|
||||
_DEFAULT_ADMIN_PASSWORD_SHA512 = (
|
||||
"e6c83b282aeb2e022844595721cc00bbda47cb24537c1779f9bb84f04039e1676e6ba8573e588da1052510e3aa0a32a9e55879ae22b0c2d62136fc0a3e85f8bb"
|
||||
)
|
||||
|
||||
|
||||
def apply_all(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Run all known schema migrations against the provided sqlite3 connection.
|
||||
"""
|
||||
|
||||
_ensure_devices_table(conn)
|
||||
_ensure_device_aux_tables(conn)
|
||||
_ensure_refresh_token_table(conn)
|
||||
_ensure_install_code_table(conn)
|
||||
_ensure_device_approval_table(conn)
|
||||
_ensure_device_list_views_table(conn)
|
||||
_ensure_sites_tables(conn)
|
||||
_ensure_credentials_table(conn)
|
||||
_ensure_github_token_table(conn)
|
||||
_ensure_scheduled_jobs_table(conn)
|
||||
_ensure_scheduled_job_run_tables(conn)
|
||||
_ensure_users_table(conn)
|
||||
_ensure_default_admin(conn)
|
||||
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _ensure_devices_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
if not _table_exists(cur, DEVICE_TABLE):
|
||||
_create_devices_table(cur)
|
||||
return
|
||||
|
||||
column_info = _table_info(cur, DEVICE_TABLE)
|
||||
col_names = [c[1] for c in column_info]
|
||||
pk_cols = [c[1] for c in column_info if c[5]]
|
||||
|
||||
needs_rebuild = pk_cols != ["guid"]
|
||||
required_columns = {
|
||||
"guid": "TEXT",
|
||||
"hostname": "TEXT",
|
||||
"description": "TEXT",
|
||||
"created_at": "INTEGER",
|
||||
"agent_hash": "TEXT",
|
||||
"memory": "TEXT",
|
||||
"network": "TEXT",
|
||||
"software": "TEXT",
|
||||
"storage": "TEXT",
|
||||
"cpu": "TEXT",
|
||||
"device_type": "TEXT",
|
||||
"domain": "TEXT",
|
||||
"external_ip": "TEXT",
|
||||
"internal_ip": "TEXT",
|
||||
"last_reboot": "TEXT",
|
||||
"last_seen": "INTEGER",
|
||||
"last_user": "TEXT",
|
||||
"operating_system": "TEXT",
|
||||
"uptime": "INTEGER",
|
||||
"agent_id": "TEXT",
|
||||
"ansible_ee_ver": "TEXT",
|
||||
"connection_type": "TEXT",
|
||||
"connection_endpoint": "TEXT",
|
||||
"ssl_key_fingerprint": "TEXT",
|
||||
"token_version": "INTEGER",
|
||||
"status": "TEXT",
|
||||
"key_added_at": "TEXT",
|
||||
}
|
||||
|
||||
missing_columns = [col for col in required_columns if col not in col_names]
|
||||
if missing_columns:
|
||||
needs_rebuild = True
|
||||
|
||||
if needs_rebuild:
|
||||
_rebuild_devices_table(conn, column_info)
|
||||
else:
|
||||
_ensure_column_defaults(cur)
|
||||
|
||||
_ensure_device_indexes(cur)
|
||||
|
||||
|
||||
def _ensure_device_aux_tables(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS device_keys (
|
||||
id TEXT PRIMARY KEY,
|
||||
guid TEXT NOT NULL,
|
||||
ssl_key_fingerprint TEXT NOT NULL,
|
||||
added_at TEXT NOT NULL,
|
||||
retired_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uq_device_keys_guid_fingerprint
|
||||
ON device_keys(guid, ssl_key_fingerprint)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_device_keys_guid
|
||||
ON device_keys(guid)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_refresh_token_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
id TEXT PRIMARY KEY,
|
||||
guid TEXT NOT NULL,
|
||||
token_hash TEXT NOT NULL,
|
||||
dpop_jkt TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
revoked_at TEXT,
|
||||
last_used_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_guid
|
||||
ON refresh_tokens(guid)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at
|
||||
ON refresh_tokens(expires_at)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS enrollment_install_codes (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL UNIQUE,
|
||||
expires_at TEXT NOT NULL,
|
||||
created_by_user_id TEXT,
|
||||
used_at TEXT,
|
||||
used_by_guid TEXT,
|
||||
max_uses INTEGER NOT NULL DEFAULT 1,
|
||||
use_count INTEGER NOT NULL DEFAULT 0,
|
||||
last_used_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_eic_expires_at
|
||||
ON enrollment_install_codes(expires_at)
|
||||
"""
|
||||
)
|
||||
|
||||
columns = {row[1] for row in _table_info(cur, "enrollment_install_codes")}
|
||||
if "max_uses" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes
|
||||
ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 1
|
||||
"""
|
||||
)
|
||||
if "use_count" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes
|
||||
ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0
|
||||
"""
|
||||
)
|
||||
if "last_used_at" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes
|
||||
ADD COLUMN last_used_at TEXT
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS device_approvals (
|
||||
id TEXT PRIMARY KEY,
|
||||
approval_reference TEXT NOT NULL UNIQUE,
|
||||
guid TEXT,
|
||||
hostname_claimed TEXT NOT NULL,
|
||||
ssl_key_fingerprint_claimed TEXT NOT NULL,
|
||||
enrollment_code_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
client_nonce TEXT NOT NULL,
|
||||
server_nonce TEXT NOT NULL,
|
||||
agent_pubkey_der BLOB NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
approved_by_user_id TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_da_status
|
||||
ON device_approvals(status)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_da_fp_status
|
||||
ON device_approvals(ssl_key_fingerprint_claimed, status)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_device_list_views_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS device_list_views (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT UNIQUE NOT NULL,
|
||||
columns_json TEXT NOT NULL,
|
||||
filters_json TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_sites_tables(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS sites (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT UNIQUE NOT NULL,
|
||||
description TEXT,
|
||||
created_at INTEGER
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS device_sites (
|
||||
device_hostname TEXT UNIQUE NOT NULL,
|
||||
site_id INTEGER NOT NULL,
|
||||
assigned_at INTEGER,
|
||||
FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE CASCADE
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_credentials_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS credentials (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
description TEXT,
|
||||
site_id INTEGER,
|
||||
credential_type TEXT NOT NULL DEFAULT 'machine',
|
||||
connection_type TEXT NOT NULL DEFAULT 'ssh',
|
||||
username TEXT,
|
||||
password_encrypted BLOB,
|
||||
private_key_encrypted BLOB,
|
||||
private_key_passphrase_encrypted BLOB,
|
||||
become_method TEXT,
|
||||
become_username TEXT,
|
||||
become_password_encrypted BLOB,
|
||||
metadata_json TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE SET NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_github_token_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS github_token (
|
||||
token TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_scheduled_jobs_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS scheduled_jobs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
components_json TEXT NOT NULL,
|
||||
targets_json TEXT NOT NULL,
|
||||
schedule_type TEXT NOT NULL,
|
||||
start_ts INTEGER,
|
||||
duration_stop_enabled INTEGER DEFAULT 0,
|
||||
expiration TEXT,
|
||||
execution_context TEXT NOT NULL,
|
||||
credential_id INTEGER,
|
||||
use_service_account INTEGER NOT NULL DEFAULT 1,
|
||||
enabled INTEGER DEFAULT 1,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
try:
|
||||
columns = {row[1] for row in _table_info(cur, "scheduled_jobs")}
|
||||
if "credential_id" not in columns:
|
||||
cur.execute("ALTER TABLE scheduled_jobs ADD COLUMN credential_id INTEGER")
|
||||
if "use_service_account" not in columns:
|
||||
cur.execute(
|
||||
"ALTER TABLE scheduled_jobs ADD COLUMN use_service_account INTEGER NOT NULL DEFAULT 1"
|
||||
)
|
||||
except Exception:
|
||||
# Legacy deployments may fail the ALTER TABLE calls; ignore silently.
|
||||
pass
|
||||
|
||||
|
||||
def _ensure_scheduled_job_run_tables(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS scheduled_job_runs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
job_id INTEGER NOT NULL,
|
||||
scheduled_ts INTEGER,
|
||||
started_ts INTEGER,
|
||||
finished_ts INTEGER,
|
||||
status TEXT,
|
||||
error TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER,
|
||||
target_hostname TEXT,
|
||||
FOREIGN KEY(job_id) REFERENCES scheduled_jobs(id) ON DELETE CASCADE
|
||||
)
|
||||
"""
|
||||
)
|
||||
try:
|
||||
cur.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_runs_job_sched_target ON scheduled_job_runs(job_id, scheduled_ts, target_hostname)"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS scheduled_job_run_activity (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
run_id INTEGER NOT NULL,
|
||||
activity_id INTEGER NOT NULL,
|
||||
component_kind TEXT,
|
||||
script_type TEXT,
|
||||
component_path TEXT,
|
||||
component_name TEXT,
|
||||
created_at INTEGER,
|
||||
FOREIGN KEY(run_id) REFERENCES scheduled_job_runs(id) ON DELETE CASCADE
|
||||
)
|
||||
"""
|
||||
)
|
||||
try:
|
||||
cur.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_run_activity_run ON scheduled_job_run_activity(run_id)"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
cur.execute(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_run_activity_activity ON scheduled_job_run_activity(activity_id)"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _create_devices_table(cur: sqlite3.Cursor) -> None:
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE devices (
|
||||
guid TEXT PRIMARY KEY,
|
||||
hostname TEXT,
|
||||
description TEXT,
|
||||
created_at INTEGER,
|
||||
agent_hash TEXT,
|
||||
memory TEXT,
|
||||
network TEXT,
|
||||
software TEXT,
|
||||
storage TEXT,
|
||||
cpu TEXT,
|
||||
device_type TEXT,
|
||||
domain TEXT,
|
||||
external_ip TEXT,
|
||||
internal_ip TEXT,
|
||||
last_reboot TEXT,
|
||||
last_seen INTEGER,
|
||||
last_user TEXT,
|
||||
operating_system TEXT,
|
||||
uptime INTEGER,
|
||||
agent_id TEXT,
|
||||
ansible_ee_ver TEXT,
|
||||
connection_type TEXT,
|
||||
connection_endpoint TEXT,
|
||||
ssl_key_fingerprint TEXT,
|
||||
token_version INTEGER DEFAULT 1,
|
||||
status TEXT DEFAULT 'active',
|
||||
key_added_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
_ensure_device_indexes(cur)
|
||||
|
||||
|
||||
def _ensure_device_indexes(cur: sqlite3.Cursor) -> None:
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uq_devices_hostname
|
||||
ON devices(hostname)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_devices_ssl_key
|
||||
ON devices(ssl_key_fingerprint)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_devices_status
|
||||
ON devices(status)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_column_defaults(cur: sqlite3.Cursor) -> None:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE devices
|
||||
SET token_version = COALESCE(token_version, 1)
|
||||
WHERE token_version IS NULL
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE devices
|
||||
SET status = COALESCE(status, 'active')
|
||||
WHERE status IS NULL OR status = ''
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _rebuild_devices_table(conn: sqlite3.Connection, column_info: Sequence[Tuple]) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute("PRAGMA foreign_keys=OFF")
|
||||
cur.execute("BEGIN IMMEDIATE")
|
||||
|
||||
cur.execute("ALTER TABLE devices RENAME TO devices_legacy")
|
||||
_create_devices_table(cur)
|
||||
|
||||
legacy_columns = [c[1] for c in column_info]
|
||||
cur.execute(f"SELECT {', '.join(legacy_columns)} FROM devices_legacy")
|
||||
rows = cur.fetchall()
|
||||
|
||||
insert_sql = (
|
||||
"""
|
||||
INSERT OR REPLACE INTO devices (
|
||||
guid, hostname, description, created_at, agent_hash, memory,
|
||||
network, software, storage, cpu, device_type, domain, external_ip,
|
||||
internal_ip, last_reboot, last_seen, last_user, operating_system,
|
||||
uptime, agent_id, ansible_ee_ver, connection_type, connection_endpoint,
|
||||
ssl_key_fingerprint, token_version, status, key_added_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
record = dict(zip(legacy_columns, row))
|
||||
guid = _normalized_guid(record.get("guid"))
|
||||
if not guid:
|
||||
guid = str(uuid.uuid4())
|
||||
hostname = record.get("hostname")
|
||||
created_at = record.get("created_at")
|
||||
key_added_at = record.get("key_added_at")
|
||||
if key_added_at is None:
|
||||
key_added_at = _default_key_added_at(created_at)
|
||||
|
||||
params: Tuple = (
|
||||
guid,
|
||||
hostname,
|
||||
record.get("description"),
|
||||
created_at,
|
||||
record.get("agent_hash"),
|
||||
record.get("memory"),
|
||||
record.get("network"),
|
||||
record.get("software"),
|
||||
record.get("storage"),
|
||||
record.get("cpu"),
|
||||
record.get("device_type"),
|
||||
record.get("domain"),
|
||||
record.get("external_ip"),
|
||||
record.get("internal_ip"),
|
||||
record.get("last_reboot"),
|
||||
record.get("last_seen"),
|
||||
record.get("last_user"),
|
||||
record.get("operating_system"),
|
||||
record.get("uptime"),
|
||||
record.get("agent_id"),
|
||||
record.get("ansible_ee_ver"),
|
||||
record.get("connection_type"),
|
||||
record.get("connection_endpoint"),
|
||||
record.get("ssl_key_fingerprint"),
|
||||
record.get("token_version") or 1,
|
||||
record.get("status") or "active",
|
||||
key_added_at,
|
||||
)
|
||||
cur.execute(insert_sql, params)
|
||||
|
||||
cur.execute("DROP TABLE devices_legacy")
|
||||
cur.execute("COMMIT")
|
||||
cur.execute("PRAGMA foreign_keys=ON")
|
||||
|
||||
|
||||
def _default_key_added_at(created_at: Optional[int]) -> Optional[str]:
|
||||
if created_at:
|
||||
try:
|
||||
dt = datetime.fromtimestamp(int(created_at), tz=timezone.utc)
|
||||
return dt.isoformat()
|
||||
except Exception:
|
||||
pass
|
||||
return datetime.now(tz=timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _table_exists(cur: sqlite3.Cursor, name: str) -> bool:
|
||||
cur.execute(
|
||||
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
|
||||
(name,),
|
||||
)
|
||||
return cur.fetchone() is not None
|
||||
|
||||
|
||||
def _table_info(cur: sqlite3.Cursor, name: str) -> List[Tuple]:
|
||||
cur.execute(f"PRAGMA table_info({name})")
|
||||
return cur.fetchall()
|
||||
|
||||
|
||||
def _normalized_guid(value: Optional[str]) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
return str(value).strip()
|
||||
|
||||
def _ensure_users_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
display_name TEXT,
|
||||
password_sha512 TEXT NOT NULL,
|
||||
role TEXT NOT NULL DEFAULT 'Admin',
|
||||
last_login INTEGER,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER,
|
||||
mfa_enabled INTEGER NOT NULL DEFAULT 0,
|
||||
mfa_secret TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
try:
|
||||
cur.execute("PRAGMA table_info(users)")
|
||||
columns = [row[1] for row in cur.fetchall()]
|
||||
if "mfa_enabled" not in columns:
|
||||
cur.execute("ALTER TABLE users ADD COLUMN mfa_enabled INTEGER NOT NULL DEFAULT 0")
|
||||
if "mfa_secret" not in columns:
|
||||
cur.execute("ALTER TABLE users ADD COLUMN mfa_secret TEXT")
|
||||
except sqlite3.Error:
|
||||
# Aligning the schema is best-effort; older deployments may lack ALTER
|
||||
# TABLE privileges but can continue using existing columns.
|
||||
pass
|
||||
|
||||
|
||||
def _ensure_default_admin(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT COUNT(*) FROM users WHERE LOWER(role)='admin'")
|
||||
row = cur.fetchone()
|
||||
if row and (row[0] or 0):
|
||||
return
|
||||
|
||||
now = int(datetime.now(timezone.utc).timestamp())
|
||||
cur.execute(
|
||||
"SELECT COUNT(*) FROM users WHERE LOWER(username)=LOWER(?)",
|
||||
(_DEFAULT_ADMIN_USERNAME,),
|
||||
)
|
||||
existing = cur.fetchone()
|
||||
if not existing or not (existing[0] or 0):
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO users (
|
||||
username, display_name, password_sha512, role,
|
||||
last_login, created_at, updated_at, mfa_enabled, mfa_secret
|
||||
) VALUES (?, ?, ?, 'Admin', 0, ?, ?, 0, NULL)
|
||||
""",
|
||||
(
|
||||
_DEFAULT_ADMIN_USERNAME,
|
||||
"Administrator",
|
||||
_DEFAULT_ADMIN_PASSWORD_SHA512,
|
||||
now,
|
||||
now,
|
||||
),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET role='Admin',
|
||||
updated_at=?
|
||||
WHERE LOWER(username)=LOWER(?)
|
||||
AND LOWER(role)!='admin'
|
||||
""",
|
||||
(now, _DEFAULT_ADMIN_USERNAME),
|
||||
)
|
||||
|
||||
|
||||
def ensure_default_admin(conn: sqlite3.Connection) -> None:
|
||||
"""Guarantee that at least one admin account exists."""
|
||||
|
||||
_ensure_users_table(conn)
|
||||
_ensure_default_admin(conn)
|
||||
conn.commit()
|
||||
|
||||
|
||||
__all__ = ["apply_all", "ensure_default_admin"]
|
||||
@@ -1,189 +0,0 @@
|
||||
"""SQLite persistence for site management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
from contextlib import closing
|
||||
from typing import Dict, Iterable, List, Optional, Sequence
|
||||
|
||||
from Data.Engine.domain.sites import SiteDeviceMapping, SiteSummary
|
||||
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||
|
||||
__all__ = ["SQLiteSiteRepository"]
|
||||
|
||||
|
||||
class SQLiteSiteRepository:
|
||||
"""Repository exposing site CRUD and device assignment helpers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._connections = connection_factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.sites")
|
||||
|
||||
def list_sites(self) -> List[SiteSummary]:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT s.id, s.name, s.description, s.created_at,
|
||||
COALESCE(ds.cnt, 0) AS device_count
|
||||
FROM sites s
|
||||
LEFT JOIN (
|
||||
SELECT site_id, COUNT(*) AS cnt
|
||||
FROM device_sites
|
||||
GROUP BY site_id
|
||||
) ds
|
||||
ON ds.site_id = s.id
|
||||
ORDER BY LOWER(s.name) ASC
|
||||
"""
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [self._row_to_site(row) for row in rows]
|
||||
|
||||
def create_site(self, name: str, description: str) -> SiteSummary:
|
||||
now = int(time.time())
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"INSERT INTO sites(name, description, created_at) VALUES (?, ?, ?)",
|
||||
(name, description, now),
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
raise ValueError("duplicate") from exc
|
||||
site_id = cur.lastrowid
|
||||
conn.commit()
|
||||
|
||||
cur.execute(
|
||||
"SELECT id, name, description, created_at, 0 FROM sites WHERE id = ?",
|
||||
(site_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
raise RuntimeError("site not found after insert")
|
||||
return self._row_to_site(row)
|
||||
|
||||
def delete_sites(self, ids: Sequence[int]) -> int:
|
||||
if not ids:
|
||||
return 0
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
placeholders = ",".join("?" for _ in ids)
|
||||
try:
|
||||
cur.execute(
|
||||
f"DELETE FROM device_sites WHERE site_id IN ({placeholders})",
|
||||
tuple(ids),
|
||||
)
|
||||
cur.execute(
|
||||
f"DELETE FROM sites WHERE id IN ({placeholders})",
|
||||
tuple(ids),
|
||||
)
|
||||
except sqlite3.DatabaseError as exc:
|
||||
conn.rollback()
|
||||
raise
|
||||
deleted = cur.rowcount
|
||||
conn.commit()
|
||||
return deleted
|
||||
|
||||
def rename_site(self, site_id: int, new_name: str) -> SiteSummary:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute("UPDATE sites SET name = ? WHERE id = ?", (new_name, site_id))
|
||||
except sqlite3.IntegrityError as exc:
|
||||
raise ValueError("duplicate") from exc
|
||||
if cur.rowcount == 0:
|
||||
raise LookupError("not_found")
|
||||
conn.commit()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT s.id, s.name, s.description, s.created_at,
|
||||
COALESCE(ds.cnt, 0) AS device_count
|
||||
FROM sites s
|
||||
LEFT JOIN (
|
||||
SELECT site_id, COUNT(*) AS cnt
|
||||
FROM device_sites
|
||||
GROUP BY site_id
|
||||
) ds
|
||||
ON ds.site_id = s.id
|
||||
WHERE s.id = ?
|
||||
""",
|
||||
(site_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
raise LookupError("not_found")
|
||||
return self._row_to_site(row)
|
||||
|
||||
def map_devices(self, hostnames: Optional[Iterable[str]] = None) -> Dict[str, SiteDeviceMapping]:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
if hostnames:
|
||||
normalized = [hn.strip() for hn in hostnames if hn and hn.strip()]
|
||||
if not normalized:
|
||||
return {}
|
||||
placeholders = ",".join("?" for _ in normalized)
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT ds.device_hostname, s.id, s.name
|
||||
FROM device_sites ds
|
||||
INNER JOIN sites s ON s.id = ds.site_id
|
||||
WHERE ds.device_hostname IN ({placeholders})
|
||||
""",
|
||||
tuple(normalized),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT ds.device_hostname, s.id, s.name
|
||||
FROM device_sites ds
|
||||
INNER JOIN sites s ON s.id = ds.site_id
|
||||
"""
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
mapping: Dict[str, SiteDeviceMapping] = {}
|
||||
for hostname, site_id, site_name in rows:
|
||||
mapping[str(hostname)] = SiteDeviceMapping(
|
||||
hostname=str(hostname),
|
||||
site_id=int(site_id) if site_id is not None else None,
|
||||
site_name=str(site_name or ""),
|
||||
)
|
||||
return mapping
|
||||
|
||||
def assign_devices(self, site_id: int, hostnames: Sequence[str]) -> None:
|
||||
now = int(time.time())
|
||||
normalized = [hn.strip() for hn in hostnames if isinstance(hn, str) and hn.strip()]
|
||||
if not normalized:
|
||||
return
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT 1 FROM sites WHERE id = ?", (site_id,))
|
||||
if not cur.fetchone():
|
||||
raise LookupError("not_found")
|
||||
for hostname in normalized:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO device_sites(device_hostname, site_id, assigned_at)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(device_hostname)
|
||||
DO UPDATE SET site_id = excluded.site_id,
|
||||
assigned_at = excluded.assigned_at
|
||||
""",
|
||||
(hostname, site_id, now),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def _row_to_site(self, row: Sequence[object]) -> SiteSummary:
|
||||
return SiteSummary(
|
||||
id=int(row[0]),
|
||||
name=str(row[1] or ""),
|
||||
description=str(row[2] or ""),
|
||||
created_at=int(row[3] or 0),
|
||||
device_count=int(row[4] or 0),
|
||||
)
|
||||
@@ -1,153 +0,0 @@
|
||||
"""SQLite-backed refresh token repository for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import closing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from Data.Engine.domain.device_auth import DeviceGuid
|
||||
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
|
||||
from Data.Engine.services.auth.token_service import RefreshTokenRecord
|
||||
|
||||
__all__ = ["SQLiteRefreshTokenRepository"]
|
||||
|
||||
|
||||
class SQLiteRefreshTokenRepository:
|
||||
"""Persistence adapter for refresh token records."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._connections = connection_factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.tokens")
|
||||
|
||||
def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at
|
||||
FROM refresh_tokens
|
||||
WHERE guid = ?
|
||||
AND token_hash = ?
|
||||
""",
|
||||
(guid.value, token_hash),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return self._row_to_record(row)
|
||||
|
||||
def clear_dpop_binding(self, record_id: str) -> None:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?",
|
||||
(record_id,),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def touch(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
last_used_at: datetime,
|
||||
dpop_jkt: Optional[str],
|
||||
) -> None:
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET last_used_at = ?,
|
||||
dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt)
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
self._isoformat(last_used_at),
|
||||
(dpop_jkt or "").strip(),
|
||||
record_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
guid: DeviceGuid,
|
||||
token_hash: str,
|
||||
created_at: datetime,
|
||||
expires_at: Optional[datetime],
|
||||
) -> None:
|
||||
created_iso = self._isoformat(created_at)
|
||||
expires_iso = self._isoformat(expires_at) if expires_at else None
|
||||
|
||||
with closing(self._connections()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO refresh_tokens (
|
||||
id,
|
||||
guid,
|
||||
token_hash,
|
||||
created_at,
|
||||
expires_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(record_id, guid.value, token_hash, created_iso, expires_iso),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def _row_to_record(self, row: tuple) -> Optional[RefreshTokenRecord]:
|
||||
try:
|
||||
guid = DeviceGuid(row[1])
|
||||
except Exception as exc:
|
||||
self._log.warning("invalid refresh token row guid=%s: %s", row[1], exc)
|
||||
return None
|
||||
|
||||
created_at = self._parse_iso(row[4])
|
||||
expires_at = self._parse_iso(row[5])
|
||||
revoked_at = self._parse_iso(row[6])
|
||||
|
||||
if created_at is None:
|
||||
created_at = datetime.now(tz=timezone.utc)
|
||||
|
||||
return RefreshTokenRecord.from_row(
|
||||
record_id=str(row[0]),
|
||||
guid=guid,
|
||||
token_hash=str(row[2]),
|
||||
dpop_jkt=str(row[3]) if row[3] is not None else None,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
revoked_at=revoked_at,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_iso(value: Optional[str]) -> Optional[datetime]:
|
||||
if not value:
|
||||
return None
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
parsed = datetime.fromisoformat(raw)
|
||||
except Exception:
|
||||
return None
|
||||
if parsed.tzinfo is None:
|
||||
return parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed
|
||||
|
||||
@staticmethod
|
||||
def _isoformat(value: datetime) -> str:
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
return value.astimezone(timezone.utc).isoformat()
|
||||
@@ -1,340 +0,0 @@
|
||||
"""SQLite repository for operator accounts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from Data.Engine.domain import OperatorAccount
|
||||
|
||||
from .connection import SQLiteConnectionFactory
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _UserRow:
|
||||
id: str
|
||||
username: str
|
||||
display_name: str
|
||||
password_sha512: str
|
||||
role: str
|
||||
last_login: int
|
||||
created_at: int
|
||||
updated_at: int
|
||||
mfa_enabled: int
|
||||
mfa_secret: str
|
||||
|
||||
|
||||
class SQLiteUserRepository:
|
||||
"""Expose CRUD helpers for operator accounts stored in SQLite."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_factory: SQLiteConnectionFactory,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._connection_factory = connection_factory
|
||||
self._log = logger or logging.getLogger("borealis.engine.repositories.users")
|
||||
|
||||
def fetch_by_username(self, username: str) -> Optional[OperatorAccount]:
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT
|
||||
id,
|
||||
username,
|
||||
display_name,
|
||||
COALESCE(password_sha512, '') as password_sha512,
|
||||
COALESCE(role, 'User') as role,
|
||||
COALESCE(last_login, 0) as last_login,
|
||||
COALESCE(created_at, 0) as created_at,
|
||||
COALESCE(updated_at, 0) as updated_at,
|
||||
COALESCE(mfa_enabled, 0) as mfa_enabled,
|
||||
COALESCE(mfa_secret, '') as mfa_secret
|
||||
FROM users
|
||||
WHERE LOWER(username) = LOWER(?)
|
||||
""",
|
||||
(username,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
record = _UserRow(*row)
|
||||
return _row_to_account(record)
|
||||
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||
self._log.error("failed to load user %s: %s", username, exc)
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def resolve_identifier(self, username: str) -> Optional[str]:
|
||||
normalized = (username or "").strip()
|
||||
if not normalized:
|
||||
return None
|
||||
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"SELECT id FROM users WHERE LOWER(username) = LOWER(?)",
|
||||
(normalized,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
return str(row[0]) if row[0] is not None else None
|
||||
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||
self._log.error("failed to resolve identifier for %s: %s", username, exc)
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def username_for_identifier(self, identifier: str) -> Optional[str]:
|
||||
token = (identifier or "").strip()
|
||||
if not token:
|
||||
return None
|
||||
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT username
|
||||
FROM users
|
||||
WHERE CAST(id AS TEXT) = ?
|
||||
OR LOWER(username) = LOWER(?)
|
||||
LIMIT 1
|
||||
""",
|
||||
(token, token),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
username = str(row[0] or "").strip()
|
||||
return username or None
|
||||
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||
self._log.error("failed to resolve username for %s: %s", identifier, exc)
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def list_accounts(self) -> list[OperatorAccount]:
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT
|
||||
id,
|
||||
username,
|
||||
display_name,
|
||||
COALESCE(password_sha512, '') as password_sha512,
|
||||
COALESCE(role, 'User') as role,
|
||||
COALESCE(last_login, 0) as last_login,
|
||||
COALESCE(created_at, 0) as created_at,
|
||||
COALESCE(updated_at, 0) as updated_at,
|
||||
COALESCE(mfa_enabled, 0) as mfa_enabled,
|
||||
COALESCE(mfa_secret, '') as mfa_secret
|
||||
FROM users
|
||||
ORDER BY LOWER(username) ASC
|
||||
"""
|
||||
)
|
||||
rows = [_UserRow(*row) for row in cur.fetchall()]
|
||||
return [_row_to_account(row) for row in rows]
|
||||
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||
self._log.error("failed to enumerate users: %s", exc)
|
||||
return []
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def create_account(
|
||||
self,
|
||||
*,
|
||||
username: str,
|
||||
display_name: str,
|
||||
password_sha512: str,
|
||||
role: str,
|
||||
timestamp: int,
|
||||
) -> None:
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO users (
|
||||
username,
|
||||
display_name,
|
||||
password_sha512,
|
||||
role,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(username, display_name, password_sha512, role, timestamp, timestamp),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def delete_account(self, username: str) -> bool:
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute("DELETE FROM users WHERE LOWER(username) = LOWER(?)", (username,))
|
||||
deleted = cur.rowcount > 0
|
||||
conn.commit()
|
||||
return deleted
|
||||
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||
self._log.error("failed to delete user %s: %s", username, exc)
|
||||
return False
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def update_password(self, username: str, password_sha512: str, *, timestamp: int) -> bool:
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET password_sha512 = ?,
|
||||
updated_at = ?
|
||||
WHERE LOWER(username) = LOWER(?)
|
||||
""",
|
||||
(password_sha512, timestamp, username),
|
||||
)
|
||||
conn.commit()
|
||||
return cur.rowcount > 0
|
||||
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||
self._log.error("failed to update password for %s: %s", username, exc)
|
||||
return False
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def update_role(self, username: str, role: str, *, timestamp: int) -> bool:
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET role = ?,
|
||||
updated_at = ?
|
||||
WHERE LOWER(username) = LOWER(?)
|
||||
""",
|
||||
(role, timestamp, username),
|
||||
)
|
||||
conn.commit()
|
||||
return cur.rowcount > 0
|
||||
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||
self._log.error("failed to update role for %s: %s", username, exc)
|
||||
return False
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def update_mfa(
|
||||
self,
|
||||
username: str,
|
||||
*,
|
||||
enabled: bool,
|
||||
reset_secret: bool,
|
||||
timestamp: int,
|
||||
) -> bool:
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
secret_clause = "mfa_secret = NULL" if reset_secret else None
|
||||
assignments: list[str] = ["mfa_enabled = ?", "updated_at = ?"]
|
||||
params: list[object] = [1 if enabled else 0, timestamp]
|
||||
if secret_clause is not None:
|
||||
assignments.append(secret_clause)
|
||||
query = "UPDATE users SET " + ", ".join(assignments) + " WHERE LOWER(username) = LOWER(?)"
|
||||
params.append(username)
|
||||
cur.execute(query, tuple(params))
|
||||
conn.commit()
|
||||
return cur.rowcount > 0
|
||||
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||
self._log.error("failed to update MFA for %s: %s", username, exc)
|
||||
return False
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def count_accounts(self) -> int:
|
||||
return self._scalar("SELECT COUNT(*) FROM users", ())
|
||||
|
||||
def count_admins(self) -> int:
|
||||
return self._scalar("SELECT COUNT(*) FROM users WHERE LOWER(role) = 'admin'", ())
|
||||
|
||||
def _scalar(self, query: str, params: Iterable[object]) -> int:
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(query, tuple(params))
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return 0
|
||||
return int(row[0] or 0)
|
||||
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||
self._log.error("scalar query failed: %s", exc)
|
||||
return 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def update_last_login(self, username: str, timestamp: int) -> None:
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET last_login = ?,
|
||||
updated_at = ?
|
||||
WHERE LOWER(username) = LOWER(?)
|
||||
""",
|
||||
(timestamp, timestamp, username),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||
self._log.warning("failed to update last_login for %s: %s", username, exc)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def store_mfa_secret(self, username: str, secret: str, *, timestamp: int) -> None:
|
||||
conn = self._connection_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET mfa_secret = ?,
|
||||
updated_at = ?
|
||||
WHERE LOWER(username) = LOWER(?)
|
||||
""",
|
||||
(secret, timestamp, username),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.Error as exc: # pragma: no cover - defensive
|
||||
self._log.warning("failed to persist MFA secret for %s: %s", username, exc)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
__all__ = ["SQLiteUserRepository"]
|
||||
|
||||
|
||||
def _row_to_account(record: _UserRow) -> OperatorAccount:
|
||||
return OperatorAccount(
|
||||
username=record.username,
|
||||
display_name=record.display_name or record.username,
|
||||
password_sha512=(record.password_sha512 or "").lower(),
|
||||
role=record.role or "User",
|
||||
last_login=int(record.last_login or 0),
|
||||
created_at=int(record.created_at or 0),
|
||||
updated_at=int(record.updated_at or 0),
|
||||
mfa_enabled=bool(record.mfa_enabled),
|
||||
mfa_secret=(record.mfa_secret or "") or None,
|
||||
)
|
||||
@@ -1,13 +0,0 @@
|
||||
#////////// PROJECT FILE SEPARATION LINE ////////// CODE AFTER THIS LINE ARE FROM: <ProjectRoot>/Data/Engine/requirements.txt
|
||||
# Core web stack
|
||||
Flask
|
||||
flask_socketio
|
||||
flask-cors
|
||||
eventlet
|
||||
requests
|
||||
|
||||
# Auth & security
|
||||
PyJWT[crypto]
|
||||
cryptography
|
||||
pyotp
|
||||
qrcode
|
||||
@@ -1,139 +0,0 @@
|
||||
"""Runtime filesystem helpers for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
__all__ = [
|
||||
"agent_certificates_path",
|
||||
"ensure_agent_certificates_dir",
|
||||
"ensure_certificates_dir",
|
||||
"ensure_runtime_dir",
|
||||
"ensure_server_certificates_dir",
|
||||
"project_root",
|
||||
"runtime_path",
|
||||
"server_certificates_path",
|
||||
]
|
||||
|
||||
|
||||
def _env_path(name: str) -> Optional[Path]:
|
||||
value = os.environ.get(name)
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return Path(value).expanduser().resolve()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def project_root() -> Path:
|
||||
env = _env_path("BOREALIS_PROJECT_ROOT")
|
||||
if env:
|
||||
return env
|
||||
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "Borealis.ps1").exists() or (parent / ".git").is_dir():
|
||||
return parent
|
||||
|
||||
try:
|
||||
return current.parents[1]
|
||||
except IndexError:
|
||||
return current.parent
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def server_runtime_root() -> Path:
|
||||
env = _env_path("BOREALIS_ENGINE_ROOT") or _env_path("BOREALIS_SERVER_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = project_root() / "Engine"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
def runtime_path(*parts: str) -> Path:
|
||||
return server_runtime_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_runtime_dir(*parts: str) -> Path:
|
||||
path = runtime_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def certificates_root() -> Path:
|
||||
env = _env_path("BOREALIS_CERTIFICATES_ROOT") or _env_path("BOREALIS_CERT_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = project_root() / "Certificates"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
for name in ("Server", "Agent"):
|
||||
try:
|
||||
(root / name).mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
return root
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def server_certificates_root() -> Path:
|
||||
env = _env_path("BOREALIS_SERVER_CERT_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = certificates_root() / "Server"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def agent_certificates_root() -> Path:
|
||||
env = _env_path("BOREALIS_AGENT_CERT_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = certificates_root() / "Agent"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
def certificates_path(*parts: str) -> Path:
|
||||
return certificates_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_certificates_dir(*parts: str) -> Path:
|
||||
path = certificates_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def server_certificates_path(*parts: str) -> Path:
|
||||
return server_certificates_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_server_certificates_dir(*parts: str) -> Path:
|
||||
path = server_certificates_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def agent_certificates_path(*parts: str) -> Path:
|
||||
return agent_certificates_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_agent_certificates_dir(*parts: str) -> Path:
|
||||
path = agent_certificates_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
@@ -1,103 +0,0 @@
|
||||
"""Flask application factory for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask, request, send_from_directory
|
||||
from flask_cors import CORS
|
||||
from werkzeug.exceptions import NotFound
|
||||
from werkzeug.middleware.proxy_fix import ProxyFix
|
||||
|
||||
from .config import EngineSettings
|
||||
|
||||
|
||||
from .repositories.sqlite.connection import (
|
||||
SQLiteConnectionFactory,
|
||||
connection_factory as create_sqlite_connection_factory,
|
||||
)
|
||||
|
||||
|
||||
def _resolve_static_folder(static_root: Path) -> tuple[str, str]:
|
||||
return str(static_root), ""
|
||||
|
||||
|
||||
def _register_spa_routes(app: Flask, assets_root: Path) -> None:
|
||||
"""Serve the Borealis single-page application from *assets_root*.
|
||||
|
||||
The logic mirrors the legacy server by routing any unknown front-end paths
|
||||
back to ``index.html`` so the React router can take over.
|
||||
"""
|
||||
|
||||
static_folder = assets_root
|
||||
|
||||
@app.route("/", defaults={"path": ""})
|
||||
@app.route("/<path:path>")
|
||||
def serve_frontend(path: str) -> object:
|
||||
candidate = (static_folder / path).resolve()
|
||||
if path and candidate.is_file():
|
||||
return send_from_directory(str(static_folder), path)
|
||||
try:
|
||||
return send_from_directory(str(static_folder), "index.html")
|
||||
except Exception as exc: # pragma: no cover - passthrough
|
||||
raise NotFound() from exc
|
||||
|
||||
@app.errorhandler(404)
|
||||
def spa_fallback(error: Exception) -> object: # pragma: no cover - routing
|
||||
request_path = (request.path or "").strip()
|
||||
if request_path.startswith("/api") or request_path.startswith("/socket.io"):
|
||||
return error
|
||||
if "." in Path(request_path).name:
|
||||
return error
|
||||
if request.method not in {"GET", "HEAD"}:
|
||||
return error
|
||||
try:
|
||||
return send_from_directory(str(static_folder), "index.html")
|
||||
except Exception:
|
||||
return error
|
||||
|
||||
|
||||
def create_app(
|
||||
settings: EngineSettings,
|
||||
*,
|
||||
db_factory: Optional[SQLiteConnectionFactory] = None,
|
||||
) -> Flask:
|
||||
"""Create the Flask application instance for the Engine."""
|
||||
|
||||
if db_factory is None:
|
||||
db_factory = create_sqlite_connection_factory(settings.database_path)
|
||||
|
||||
static_folder, static_url_path = _resolve_static_folder(settings.flask.static_root)
|
||||
app = Flask(
|
||||
__name__,
|
||||
static_folder=static_folder,
|
||||
static_url_path=static_url_path,
|
||||
)
|
||||
|
||||
app.config.update(
|
||||
SECRET_KEY=settings.flask.secret_key,
|
||||
JSON_SORT_KEYS=False,
|
||||
SESSION_COOKIE_HTTPONLY=True,
|
||||
SESSION_COOKIE_SECURE=not settings.debug,
|
||||
SESSION_COOKIE_SAMESITE="Lax",
|
||||
ENGINE_DATABASE_PATH=str(settings.database_path),
|
||||
ENGINE_DB_CONN_FACTORY=db_factory,
|
||||
)
|
||||
app.config.setdefault("PREFERRED_URL_SCHEME", "https")
|
||||
|
||||
# Respect upstream proxy headers when Borealis is hosted behind a TLS terminator.
|
||||
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1) # type: ignore[assignment]
|
||||
|
||||
CORS(
|
||||
app,
|
||||
resources={r"/*": {"origins": list(settings.flask.cors_allowed_origins)}},
|
||||
supports_credentials=True,
|
||||
)
|
||||
|
||||
_register_spa_routes(app, Path(static_folder))
|
||||
|
||||
return app
|
||||
|
||||
|
||||
__all__ = ["create_app"]
|
||||
@@ -1,104 +0,0 @@
|
||||
"""Application services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthService",
|
||||
"DeviceRecord",
|
||||
"RefreshTokenRecord",
|
||||
"TokenRefreshError",
|
||||
"TokenRefreshErrorCode",
|
||||
"TokenService",
|
||||
"EnrollmentService",
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"EnrollmentValidationError",
|
||||
"PollingResult",
|
||||
"AgentRealtimeService",
|
||||
"AgentRecord",
|
||||
"SchedulerService",
|
||||
"GitHubService",
|
||||
"GitHubTokenPayload",
|
||||
"EnrollmentAdminService",
|
||||
"SiteService",
|
||||
"DeviceInventoryService",
|
||||
"DeviceViewService",
|
||||
"CredentialService",
|
||||
"AssemblyService",
|
||||
"AssemblyListing",
|
||||
"AssemblyLoadResult",
|
||||
"AssemblyMutationResult",
|
||||
]
|
||||
|
||||
_LAZY_TARGETS: Dict[str, Tuple[str, str]] = {
|
||||
"DeviceAuthService": ("Data.Engine.services.auth.device_auth_service", "DeviceAuthService"),
|
||||
"DeviceRecord": ("Data.Engine.services.auth.device_auth_service", "DeviceRecord"),
|
||||
"RefreshTokenRecord": ("Data.Engine.services.auth.device_auth_service", "RefreshTokenRecord"),
|
||||
"TokenService": ("Data.Engine.services.auth.token_service", "TokenService"),
|
||||
"TokenRefreshError": ("Data.Engine.services.auth.token_service", "TokenRefreshError"),
|
||||
"TokenRefreshErrorCode": ("Data.Engine.services.auth.token_service", "TokenRefreshErrorCode"),
|
||||
"EnrollmentService": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentService"),
|
||||
"EnrollmentRequestResult": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentRequestResult"),
|
||||
"EnrollmentStatus": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentStatus"),
|
||||
"EnrollmentTokenBundle": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentTokenBundle"),
|
||||
"PollingResult": ("Data.Engine.services.enrollment.enrollment_service", "PollingResult"),
|
||||
"EnrollmentValidationError": ("Data.Engine.domain.device_enrollment", "EnrollmentValidationError"),
|
||||
"AgentRealtimeService": ("Data.Engine.services.realtime.agent_registry", "AgentRealtimeService"),
|
||||
"AgentRecord": ("Data.Engine.services.realtime.agent_registry", "AgentRecord"),
|
||||
"SchedulerService": ("Data.Engine.services.jobs.scheduler_service", "SchedulerService"),
|
||||
"GitHubService": ("Data.Engine.services.github.github_service", "GitHubService"),
|
||||
"GitHubTokenPayload": ("Data.Engine.services.github.github_service", "GitHubTokenPayload"),
|
||||
"EnrollmentAdminService": (
|
||||
"Data.Engine.services.enrollment.admin_service",
|
||||
"EnrollmentAdminService",
|
||||
),
|
||||
"SiteService": ("Data.Engine.services.sites.site_service", "SiteService"),
|
||||
"DeviceInventoryService": (
|
||||
"Data.Engine.services.devices.device_inventory_service",
|
||||
"DeviceInventoryService",
|
||||
),
|
||||
"DeviceViewService": (
|
||||
"Data.Engine.services.devices.device_view_service",
|
||||
"DeviceViewService",
|
||||
),
|
||||
"CredentialService": (
|
||||
"Data.Engine.services.credentials.credential_service",
|
||||
"CredentialService",
|
||||
),
|
||||
"AssemblyService": (
|
||||
"Data.Engine.services.assemblies.assembly_service",
|
||||
"AssemblyService",
|
||||
),
|
||||
"AssemblyListing": (
|
||||
"Data.Engine.services.assemblies.assembly_service",
|
||||
"AssemblyListing",
|
||||
),
|
||||
"AssemblyLoadResult": (
|
||||
"Data.Engine.services.assemblies.assembly_service",
|
||||
"AssemblyLoadResult",
|
||||
),
|
||||
"AssemblyMutationResult": (
|
||||
"Data.Engine.services.assemblies.assembly_service",
|
||||
"AssemblyMutationResult",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
try:
|
||||
module_name, attribute = _LAZY_TARGETS[name]
|
||||
except KeyError as exc:
|
||||
raise AttributeError(name) from exc
|
||||
|
||||
module = import_module(module_name)
|
||||
value = getattr(module, attribute)
|
||||
globals()[name] = value
|
||||
return value
|
||||
|
||||
|
||||
def __dir__() -> Any: # pragma: no cover - interactive helper
|
||||
return sorted(set(__all__))
|
||||
@@ -1,10 +0,0 @@
|
||||
"""Assembly management services."""
|
||||
|
||||
from .assembly_service import AssemblyService, AssemblyMutationResult, AssemblyLoadResult, AssemblyListing
|
||||
|
||||
__all__ = [
|
||||
"AssemblyService",
|
||||
"AssemblyMutationResult",
|
||||
"AssemblyLoadResult",
|
||||
"AssemblyListing",
|
||||
]
|
||||
@@ -1,715 +0,0 @@
|
||||
"""Filesystem-backed assembly management service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
__all__ = [
|
||||
"AssemblyService",
|
||||
"AssemblyListing",
|
||||
"AssemblyLoadResult",
|
||||
"AssemblyMutationResult",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AssemblyListing:
|
||||
"""Listing payload for an assembly island."""
|
||||
|
||||
root: Path
|
||||
items: List[Dict[str, Any]]
|
||||
folders: List[str]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"root": str(self.root),
|
||||
"items": self.items,
|
||||
"folders": self.folders,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AssemblyLoadResult:
|
||||
"""Container describing a loaded assembly artifact."""
|
||||
|
||||
payload: Dict[str, Any]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return dict(self.payload)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AssemblyMutationResult:
|
||||
"""Mutation acknowledgement for create/edit/rename operations."""
|
||||
|
||||
status: str = "ok"
|
||||
rel_path: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {"status": self.status}
|
||||
if self.rel_path:
|
||||
payload["rel_path"] = self.rel_path
|
||||
return payload
|
||||
|
||||
|
||||
class AssemblyService:
|
||||
"""Provide CRUD helpers for workflow/script/ansible assemblies."""
|
||||
|
||||
_ISLAND_DIR_MAP = {
|
||||
"workflows": "Workflows",
|
||||
"workflow": "Workflows",
|
||||
"scripts": "Scripts",
|
||||
"script": "Scripts",
|
||||
"ansible": "Ansible_Playbooks",
|
||||
"ansible_playbooks": "Ansible_Playbooks",
|
||||
"ansible-playbooks": "Ansible_Playbooks",
|
||||
"playbooks": "Ansible_Playbooks",
|
||||
}
|
||||
|
||||
_SCRIPT_EXTENSIONS = (".json", ".ps1", ".bat", ".sh")
|
||||
_ANSIBLE_EXTENSIONS = (".json", ".yml")
|
||||
|
||||
def __init__(self, *, root: Path, logger: Optional[logging.Logger] = None) -> None:
|
||||
self._root = root.resolve()
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.assemblies")
|
||||
try:
|
||||
self._root.mkdir(parents=True, exist_ok=True)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("failed to ensure assemblies root %s: %s", self._root, exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
def list_items(self, island: str) -> AssemblyListing:
|
||||
root = self._resolve_island_root(island)
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
items: List[Dict[str, Any]] = []
|
||||
folders: List[str] = []
|
||||
|
||||
isl = (island or "").strip().lower()
|
||||
if isl in {"workflows", "workflow"}:
|
||||
for dirpath, dirnames, filenames in os.walk(root):
|
||||
rel_root = os.path.relpath(dirpath, root)
|
||||
if rel_root != ".":
|
||||
folders.append(rel_root.replace(os.sep, "/"))
|
||||
for fname in filenames:
|
||||
if not fname.lower().endswith(".json"):
|
||||
continue
|
||||
abs_path = Path(dirpath) / fname
|
||||
rel_path = abs_path.relative_to(root).as_posix()
|
||||
try:
|
||||
mtime = abs_path.stat().st_mtime
|
||||
except OSError:
|
||||
mtime = 0.0
|
||||
obj = self._safe_read_json(abs_path)
|
||||
tab = self._extract_tab_name(obj)
|
||||
items.append(
|
||||
{
|
||||
"file_name": fname,
|
||||
"rel_path": rel_path,
|
||||
"type": "workflow",
|
||||
"tab_name": tab,
|
||||
"last_edited": time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%S", time.localtime(mtime)
|
||||
),
|
||||
"last_edited_epoch": mtime,
|
||||
}
|
||||
)
|
||||
elif isl in {"scripts", "script"}:
|
||||
for dirpath, dirnames, filenames in os.walk(root):
|
||||
rel_root = os.path.relpath(dirpath, root)
|
||||
if rel_root != ".":
|
||||
folders.append(rel_root.replace(os.sep, "/"))
|
||||
for fname in filenames:
|
||||
if not fname.lower().endswith(self._SCRIPT_EXTENSIONS):
|
||||
continue
|
||||
abs_path = Path(dirpath) / fname
|
||||
rel_path = abs_path.relative_to(root).as_posix()
|
||||
try:
|
||||
mtime = abs_path.stat().st_mtime
|
||||
except OSError:
|
||||
mtime = 0.0
|
||||
script_type = self._detect_script_type(abs_path)
|
||||
doc = self._load_assembly_document(abs_path, "scripts", script_type)
|
||||
items.append(
|
||||
{
|
||||
"file_name": fname,
|
||||
"rel_path": rel_path,
|
||||
"type": doc.get("type", script_type),
|
||||
"name": doc.get("name"),
|
||||
"category": doc.get("category"),
|
||||
"description": doc.get("description"),
|
||||
"last_edited": time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%S", time.localtime(mtime)
|
||||
),
|
||||
"last_edited_epoch": mtime,
|
||||
}
|
||||
)
|
||||
elif isl in {
|
||||
"ansible",
|
||||
"ansible_playbooks",
|
||||
"ansible-playbooks",
|
||||
"playbooks",
|
||||
}:
|
||||
for dirpath, dirnames, filenames in os.walk(root):
|
||||
rel_root = os.path.relpath(dirpath, root)
|
||||
if rel_root != ".":
|
||||
folders.append(rel_root.replace(os.sep, "/"))
|
||||
for fname in filenames:
|
||||
if not fname.lower().endswith(self._ANSIBLE_EXTENSIONS):
|
||||
continue
|
||||
abs_path = Path(dirpath) / fname
|
||||
rel_path = abs_path.relative_to(root).as_posix()
|
||||
try:
|
||||
mtime = abs_path.stat().st_mtime
|
||||
except OSError:
|
||||
mtime = 0.0
|
||||
script_type = self._detect_script_type(abs_path)
|
||||
doc = self._load_assembly_document(abs_path, "ansible", script_type)
|
||||
items.append(
|
||||
{
|
||||
"file_name": fname,
|
||||
"rel_path": rel_path,
|
||||
"type": doc.get("type", "ansible"),
|
||||
"name": doc.get("name"),
|
||||
"category": doc.get("category"),
|
||||
"description": doc.get("description"),
|
||||
"last_edited": time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%S", time.localtime(mtime)
|
||||
),
|
||||
"last_edited_epoch": mtime,
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid_island")
|
||||
|
||||
items.sort(key=lambda entry: entry.get("last_edited_epoch", 0.0), reverse=True)
|
||||
return AssemblyListing(root=root, items=items, folders=folders)
|
||||
|
||||
def load_item(self, island: str, rel_path: str) -> AssemblyLoadResult:
|
||||
root, abs_path, _ = self._resolve_assembly_path(island, rel_path)
|
||||
if not abs_path.is_file():
|
||||
raise FileNotFoundError("file_not_found")
|
||||
|
||||
isl = (island or "").strip().lower()
|
||||
if isl in {"workflows", "workflow"}:
|
||||
payload = self._safe_read_json(abs_path)
|
||||
return AssemblyLoadResult(payload=payload)
|
||||
|
||||
doc = self._load_assembly_document(abs_path, island)
|
||||
rel = abs_path.relative_to(root).as_posix()
|
||||
payload = {
|
||||
"file_name": abs_path.name,
|
||||
"rel_path": rel,
|
||||
"type": doc.get("type"),
|
||||
"assembly": doc,
|
||||
"content": doc.get("script"),
|
||||
}
|
||||
return AssemblyLoadResult(payload=payload)
|
||||
|
||||
def create_item(
|
||||
self,
|
||||
island: str,
|
||||
*,
|
||||
kind: str,
|
||||
rel_path: str,
|
||||
content: Any,
|
||||
item_type: Optional[str] = None,
|
||||
) -> AssemblyMutationResult:
|
||||
root, abs_path, rel_norm = self._resolve_assembly_path(island, rel_path)
|
||||
if not rel_norm:
|
||||
raise ValueError("path_required")
|
||||
|
||||
normalized_kind = (kind or "").strip().lower()
|
||||
if normalized_kind == "folder":
|
||||
abs_path.mkdir(parents=True, exist_ok=True)
|
||||
return AssemblyMutationResult()
|
||||
if normalized_kind != "file":
|
||||
raise ValueError("invalid_kind")
|
||||
|
||||
target_path = abs_path
|
||||
if not target_path.suffix:
|
||||
target_path = target_path.with_suffix(
|
||||
self._default_ext_for_island(island, item_type or "")
|
||||
)
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
isl = (island or "").strip().lower()
|
||||
if isl in {"workflows", "workflow"}:
|
||||
payload = self._ensure_workflow_document(content)
|
||||
base_name = target_path.stem
|
||||
payload.setdefault("tab_name", base_name)
|
||||
self._write_json(target_path, payload)
|
||||
else:
|
||||
document = self._normalize_assembly_document(
|
||||
content,
|
||||
self._default_type_for_island(island, item_type or ""),
|
||||
target_path.stem,
|
||||
)
|
||||
self._write_json(target_path, self._prepare_assembly_storage(document))
|
||||
|
||||
rel_new = target_path.relative_to(root).as_posix()
|
||||
return AssemblyMutationResult(rel_path=rel_new)
|
||||
|
||||
def edit_item(
|
||||
self,
|
||||
island: str,
|
||||
*,
|
||||
rel_path: str,
|
||||
content: Any,
|
||||
item_type: Optional[str] = None,
|
||||
) -> AssemblyMutationResult:
|
||||
root, abs_path, _ = self._resolve_assembly_path(island, rel_path)
|
||||
if not abs_path.exists():
|
||||
raise FileNotFoundError("file_not_found")
|
||||
|
||||
target_path = abs_path
|
||||
if not target_path.suffix:
|
||||
target_path = target_path.with_suffix(
|
||||
self._default_ext_for_island(island, item_type or "")
|
||||
)
|
||||
|
||||
isl = (island or "").strip().lower()
|
||||
if isl in {"workflows", "workflow"}:
|
||||
payload = self._ensure_workflow_document(content)
|
||||
self._write_json(target_path, payload)
|
||||
else:
|
||||
document = self._normalize_assembly_document(
|
||||
content,
|
||||
self._default_type_for_island(island, item_type or ""),
|
||||
target_path.stem,
|
||||
)
|
||||
self._write_json(target_path, self._prepare_assembly_storage(document))
|
||||
|
||||
if target_path != abs_path and abs_path.exists():
|
||||
try:
|
||||
abs_path.unlink()
|
||||
except OSError: # pragma: no cover - best effort cleanup
|
||||
pass
|
||||
|
||||
rel_new = target_path.relative_to(root).as_posix()
|
||||
return AssemblyMutationResult(rel_path=rel_new)
|
||||
|
||||
def rename_item(
|
||||
self,
|
||||
island: str,
|
||||
*,
|
||||
kind: str,
|
||||
rel_path: str,
|
||||
new_name: str,
|
||||
item_type: Optional[str] = None,
|
||||
) -> AssemblyMutationResult:
|
||||
root, old_path, _ = self._resolve_assembly_path(island, rel_path)
|
||||
|
||||
normalized_kind = (kind or "").strip().lower()
|
||||
if normalized_kind not in {"file", "folder"}:
|
||||
raise ValueError("invalid_kind")
|
||||
|
||||
if normalized_kind == "folder":
|
||||
if not old_path.is_dir():
|
||||
raise FileNotFoundError("folder_not_found")
|
||||
destination = old_path.parent / new_name
|
||||
else:
|
||||
if not old_path.is_file():
|
||||
raise FileNotFoundError("file_not_found")
|
||||
candidate = Path(new_name)
|
||||
if not candidate.suffix:
|
||||
candidate = candidate.with_suffix(
|
||||
self._default_ext_for_island(island, item_type or "")
|
||||
)
|
||||
destination = old_path.parent / candidate.name
|
||||
|
||||
destination = destination.resolve()
|
||||
if not str(destination).startswith(str(root)):
|
||||
raise ValueError("invalid_destination")
|
||||
|
||||
old_path.rename(destination)
|
||||
|
||||
isl = (island or "").strip().lower()
|
||||
if normalized_kind == "file" and isl in {"workflows", "workflow"}:
|
||||
try:
|
||||
obj = self._safe_read_json(destination)
|
||||
base_name = destination.stem
|
||||
for key in ["tabName", "tab_name", "name", "title"]:
|
||||
if key in obj:
|
||||
obj[key] = base_name
|
||||
obj.setdefault("tab_name", base_name)
|
||||
self._write_json(destination, obj)
|
||||
except Exception: # pragma: no cover - best effort update
|
||||
self._log.debug("failed to normalize workflow metadata for %s", destination)
|
||||
|
||||
rel_new = destination.relative_to(root).as_posix()
|
||||
return AssemblyMutationResult(rel_path=rel_new)
|
||||
|
||||
def move_item(
|
||||
self,
|
||||
island: str,
|
||||
*,
|
||||
rel_path: str,
|
||||
new_path: str,
|
||||
kind: Optional[str] = None,
|
||||
) -> AssemblyMutationResult:
|
||||
root, old_path, _ = self._resolve_assembly_path(island, rel_path)
|
||||
_, dest_path, _ = self._resolve_assembly_path(island, new_path)
|
||||
|
||||
normalized_kind = (kind or "").strip().lower()
|
||||
if normalized_kind == "folder":
|
||||
if not old_path.is_dir():
|
||||
raise FileNotFoundError("folder_not_found")
|
||||
else:
|
||||
if not old_path.exists():
|
||||
raise FileNotFoundError("file_not_found")
|
||||
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(old_path), str(dest_path))
|
||||
return AssemblyMutationResult()
|
||||
|
||||
def delete_item(
|
||||
self,
|
||||
island: str,
|
||||
*,
|
||||
rel_path: str,
|
||||
kind: str,
|
||||
) -> AssemblyMutationResult:
|
||||
_, abs_path, rel_norm = self._resolve_assembly_path(island, rel_path)
|
||||
if not rel_norm:
|
||||
raise ValueError("cannot_delete_root")
|
||||
|
||||
normalized_kind = (kind or "").strip().lower()
|
||||
if normalized_kind == "folder":
|
||||
if not abs_path.is_dir():
|
||||
raise FileNotFoundError("folder_not_found")
|
||||
shutil.rmtree(abs_path)
|
||||
elif normalized_kind == "file":
|
||||
if not abs_path.is_file():
|
||||
raise FileNotFoundError("file_not_found")
|
||||
abs_path.unlink()
|
||||
else:
|
||||
raise ValueError("invalid_kind")
|
||||
|
||||
return AssemblyMutationResult()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _resolve_island_root(self, island: str) -> Path:
|
||||
key = (island or "").strip().lower()
|
||||
subdir = self._ISLAND_DIR_MAP.get(key)
|
||||
if not subdir:
|
||||
raise ValueError("invalid_island")
|
||||
root = (self._root / subdir).resolve()
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
def _resolve_assembly_path(self, island: str, rel_path: str) -> Tuple[Path, Path, str]:
|
||||
root = self._resolve_island_root(island)
|
||||
rel_norm = self._normalize_relpath(rel_path)
|
||||
abs_path = (root / rel_norm).resolve()
|
||||
if not str(abs_path).startswith(str(root)):
|
||||
raise ValueError("invalid_path")
|
||||
return root, abs_path, rel_norm
|
||||
|
||||
@staticmethod
|
||||
def _normalize_relpath(value: str) -> str:
|
||||
return (value or "").replace("\\", "/").strip("/")
|
||||
|
||||
@staticmethod
|
||||
def _default_ext_for_island(island: str, item_type: str) -> str:
|
||||
isl = (island or "").strip().lower()
|
||||
if isl in {"workflows", "workflow"}:
|
||||
return ".json"
|
||||
if isl in {"ansible", "ansible_playbooks", "ansible-playbooks", "playbooks"}:
|
||||
return ".json"
|
||||
if isl in {"scripts", "script"}:
|
||||
return ".json"
|
||||
typ = (item_type or "").strip().lower()
|
||||
if typ in {"bash", "batch", "powershell"}:
|
||||
return ".json"
|
||||
return ".json"
|
||||
|
||||
@staticmethod
|
||||
def _default_type_for_island(island: str, item_type: str) -> str:
|
||||
isl = (island or "").strip().lower()
|
||||
if isl in {"ansible", "ansible_playbooks", "ansible-playbooks", "playbooks"}:
|
||||
return "ansible"
|
||||
typ = (item_type or "").strip().lower()
|
||||
if typ in {"powershell", "batch", "bash", "ansible"}:
|
||||
return typ
|
||||
return "powershell"
|
||||
|
||||
@staticmethod
|
||||
def _empty_assembly_document(default_type: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"version": 1,
|
||||
"name": "",
|
||||
"description": "",
|
||||
"category": "application" if default_type.lower() == "ansible" else "script",
|
||||
"type": default_type or "powershell",
|
||||
"script": "",
|
||||
"timeout_seconds": 3600,
|
||||
"sites": {"mode": "all", "values": []},
|
||||
"variables": [],
|
||||
"files": [],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64_text(value: Any) -> Optional[str]:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
return ""
|
||||
try:
|
||||
cleaned = re.sub(r"\s+", "", stripped)
|
||||
except Exception:
|
||||
cleaned = stripped
|
||||
try:
|
||||
decoded = base64.b64decode(cleaned, validate=True)
|
||||
except Exception:
|
||||
return None
|
||||
try:
|
||||
return decoded.decode("utf-8")
|
||||
except Exception:
|
||||
return decoded.decode("utf-8", errors="replace")
|
||||
|
||||
def _decode_script_content(self, value: Any, encoding_hint: str = "") -> str:
|
||||
encoding = (encoding_hint or "").strip().lower()
|
||||
if isinstance(value, str):
|
||||
if encoding in {"base64", "b64", "base-64"}:
|
||||
decoded = self._decode_base64_text(value)
|
||||
if decoded is not None:
|
||||
return decoded.replace("\r\n", "\n")
|
||||
decoded = self._decode_base64_text(value)
|
||||
if decoded is not None:
|
||||
return decoded.replace("\r\n", "\n")
|
||||
return value.replace("\r\n", "\n")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _encode_script_content(script_text: Any) -> str:
|
||||
if not isinstance(script_text, str):
|
||||
if script_text is None:
|
||||
script_text = ""
|
||||
else:
|
||||
script_text = str(script_text)
|
||||
normalized = script_text.replace("\r\n", "\n")
|
||||
if not normalized:
|
||||
return ""
|
||||
encoded = base64.b64encode(normalized.encode("utf-8"))
|
||||
return encoded.decode("ascii")
|
||||
|
||||
def _prepare_assembly_storage(self, document: Dict[str, Any]) -> Dict[str, Any]:
|
||||
stored: Dict[str, Any] = {}
|
||||
for key, value in (document or {}).items():
|
||||
if key == "script":
|
||||
stored[key] = self._encode_script_content(value)
|
||||
else:
|
||||
stored[key] = value
|
||||
stored["script_encoding"] = "base64"
|
||||
return stored
|
||||
|
||||
def _normalize_assembly_document(
|
||||
self,
|
||||
obj: Any,
|
||||
default_type: str,
|
||||
base_name: str,
|
||||
) -> Dict[str, Any]:
|
||||
doc = self._empty_assembly_document(default_type)
|
||||
if not isinstance(obj, dict):
|
||||
obj = {}
|
||||
base = (base_name or "assembly").strip()
|
||||
doc["name"] = str(obj.get("name") or obj.get("display_name") or base)
|
||||
doc["description"] = str(obj.get("description") or "")
|
||||
category = str(obj.get("category") or doc["category"]).strip().lower()
|
||||
if category in {"script", "application"}:
|
||||
doc["category"] = category
|
||||
typ = str(obj.get("type") or obj.get("script_type") or default_type or "powershell").strip().lower()
|
||||
if typ in {"powershell", "batch", "bash", "ansible"}:
|
||||
doc["type"] = typ
|
||||
script_val = obj.get("script")
|
||||
content_val = obj.get("content")
|
||||
script_lines = obj.get("script_lines")
|
||||
if isinstance(script_lines, list):
|
||||
try:
|
||||
doc["script"] = "\n".join(str(line) for line in script_lines)
|
||||
except Exception:
|
||||
doc["script"] = ""
|
||||
elif isinstance(script_val, str):
|
||||
doc["script"] = script_val
|
||||
elif isinstance(content_val, str):
|
||||
doc["script"] = content_val
|
||||
encoding_hint = str(
|
||||
obj.get("script_encoding") or obj.get("scriptEncoding") or ""
|
||||
).strip().lower()
|
||||
doc["script"] = self._decode_script_content(doc.get("script"), encoding_hint)
|
||||
if encoding_hint in {"base64", "b64", "base-64"}:
|
||||
doc["script_encoding"] = "base64"
|
||||
else:
|
||||
probe_source = ""
|
||||
if isinstance(script_val, str) and script_val:
|
||||
probe_source = script_val
|
||||
elif isinstance(content_val, str) and content_val:
|
||||
probe_source = content_val
|
||||
decoded_probe = self._decode_base64_text(probe_source) if probe_source else None
|
||||
if decoded_probe is not None:
|
||||
doc["script_encoding"] = "base64"
|
||||
doc["script"] = decoded_probe.replace("\r\n", "\n")
|
||||
else:
|
||||
doc["script_encoding"] = "plain"
|
||||
timeout_val = obj.get("timeout_seconds", obj.get("timeout"))
|
||||
if timeout_val is not None:
|
||||
try:
|
||||
doc["timeout_seconds"] = max(0, int(timeout_val))
|
||||
except Exception:
|
||||
pass
|
||||
sites = obj.get("sites") if isinstance(obj.get("sites"), dict) else {}
|
||||
values = sites.get("values") if isinstance(sites.get("values"), list) else []
|
||||
mode = str(sites.get("mode") or ("specific" if values else "all")).strip().lower()
|
||||
if mode not in {"all", "specific"}:
|
||||
mode = "all"
|
||||
doc["sites"] = {
|
||||
"mode": mode,
|
||||
"values": [
|
||||
str(v).strip()
|
||||
for v in values
|
||||
if isinstance(v, (str, int, float)) and str(v).strip()
|
||||
],
|
||||
}
|
||||
vars_in = obj.get("variables") if isinstance(obj.get("variables"), list) else []
|
||||
doc_vars: List[Dict[str, Any]] = []
|
||||
for entry in vars_in:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
name = str(entry.get("name") or entry.get("key") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
vtype = str(entry.get("type") or "string").strip().lower()
|
||||
if vtype not in {"string", "number", "boolean", "credential"}:
|
||||
vtype = "string"
|
||||
default_val = entry.get("default", entry.get("default_value"))
|
||||
doc_vars.append(
|
||||
{
|
||||
"name": name,
|
||||
"label": str(entry.get("label") or ""),
|
||||
"type": vtype,
|
||||
"default": default_val,
|
||||
"required": bool(entry.get("required")),
|
||||
"description": str(entry.get("description") or ""),
|
||||
}
|
||||
)
|
||||
doc["variables"] = doc_vars
|
||||
files_in = obj.get("files") if isinstance(obj.get("files"), list) else []
|
||||
doc_files: List[Dict[str, Any]] = []
|
||||
for record in files_in:
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
fname = record.get("file_name") or record.get("name")
|
||||
data = record.get("data")
|
||||
if not fname or not isinstance(data, str):
|
||||
continue
|
||||
size_val = record.get("size")
|
||||
try:
|
||||
size_int = int(size_val)
|
||||
except Exception:
|
||||
size_int = 0
|
||||
doc_files.append(
|
||||
{
|
||||
"file_name": str(fname),
|
||||
"size": size_int,
|
||||
"mime_type": str(record.get("mime_type") or record.get("mimeType") or ""),
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
doc["files"] = doc_files
|
||||
try:
|
||||
doc["version"] = int(obj.get("version") or doc["version"])
|
||||
except Exception:
|
||||
pass
|
||||
return doc
|
||||
|
||||
def _load_assembly_document(
|
||||
self,
|
||||
abs_path: Path,
|
||||
island: str,
|
||||
type_hint: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
base_name = abs_path.stem
|
||||
default_type = self._default_type_for_island(island, type_hint)
|
||||
if abs_path.suffix.lower() == ".json":
|
||||
data = self._safe_read_json(abs_path)
|
||||
return self._normalize_assembly_document(data, default_type, base_name)
|
||||
try:
|
||||
content = abs_path.read_text(encoding="utf-8", errors="replace")
|
||||
except Exception:
|
||||
content = ""
|
||||
document = self._empty_assembly_document(default_type)
|
||||
document["name"] = base_name
|
||||
document["script"] = (content or "").replace("\r\n", "\n")
|
||||
if default_type == "ansible":
|
||||
document["category"] = "application"
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def _safe_read_json(path: Path) -> Dict[str, Any]:
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _extract_tab_name(obj: Dict[str, Any]) -> str:
|
||||
if not isinstance(obj, dict):
|
||||
return ""
|
||||
for key in ["tabName", "tab_name", "name", "title"]:
|
||||
value = obj.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return ""
|
||||
|
||||
def _detect_script_type(self, path: Path) -> str:
|
||||
lower = path.name.lower()
|
||||
if lower.endswith(".json") and path.is_file():
|
||||
obj = self._safe_read_json(path)
|
||||
if isinstance(obj, dict):
|
||||
typ = str(
|
||||
obj.get("type") or obj.get("script_type") or ""
|
||||
).strip().lower()
|
||||
if typ in {"powershell", "batch", "bash", "ansible"}:
|
||||
return typ
|
||||
return "powershell"
|
||||
if lower.endswith(".yml"):
|
||||
return "ansible"
|
||||
if lower.endswith(".ps1"):
|
||||
return "powershell"
|
||||
if lower.endswith(".bat"):
|
||||
return "batch"
|
||||
if lower.endswith(".sh"):
|
||||
return "bash"
|
||||
return "unknown"
|
||||
|
||||
@staticmethod
|
||||
def _ensure_workflow_document(content: Any) -> Dict[str, Any]:
|
||||
payload = content
|
||||
if isinstance(payload, str):
|
||||
try:
|
||||
payload = json.loads(payload)
|
||||
except Exception:
|
||||
payload = {}
|
||||
if not isinstance(payload, dict):
|
||||
payload = {}
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _write_json(path: Path, payload: Dict[str, Any]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||
@@ -1,63 +0,0 @@
|
||||
"""Authentication services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .device_auth_service import DeviceAuthService, DeviceRecord
|
||||
from .dpop import DPoPReplayError, DPoPVerificationError, DPoPValidator
|
||||
from .jwt_service import JWTService, load_service as load_jwt_service
|
||||
from .token_service import (
|
||||
RefreshTokenRecord,
|
||||
TokenRefreshError,
|
||||
TokenRefreshErrorCode,
|
||||
TokenService,
|
||||
)
|
||||
from .operator_account_service import (
|
||||
AccountNotFoundError,
|
||||
CannotModifySelfError,
|
||||
InvalidPasswordHashError,
|
||||
InvalidRoleError,
|
||||
LastAdminError,
|
||||
LastUserError,
|
||||
OperatorAccountError,
|
||||
OperatorAccountRecord,
|
||||
OperatorAccountService,
|
||||
UsernameAlreadyExistsError,
|
||||
)
|
||||
from .operator_auth_service import (
|
||||
InvalidCredentialsError,
|
||||
InvalidMFACodeError,
|
||||
MFAUnavailableError,
|
||||
MFASessionError,
|
||||
OperatorAuthError,
|
||||
OperatorAuthService,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthService",
|
||||
"DeviceRecord",
|
||||
"DPoPReplayError",
|
||||
"DPoPVerificationError",
|
||||
"DPoPValidator",
|
||||
"JWTService",
|
||||
"load_jwt_service",
|
||||
"RefreshTokenRecord",
|
||||
"TokenRefreshError",
|
||||
"TokenRefreshErrorCode",
|
||||
"TokenService",
|
||||
"OperatorAccountService",
|
||||
"OperatorAccountError",
|
||||
"OperatorAccountRecord",
|
||||
"UsernameAlreadyExistsError",
|
||||
"AccountNotFoundError",
|
||||
"LastAdminError",
|
||||
"LastUserError",
|
||||
"CannotModifySelfError",
|
||||
"InvalidRoleError",
|
||||
"InvalidPasswordHashError",
|
||||
"OperatorAuthService",
|
||||
"OperatorAuthError",
|
||||
"InvalidCredentialsError",
|
||||
"InvalidMFACodeError",
|
||||
"MFAUnavailableError",
|
||||
"MFASessionError",
|
||||
]
|
||||
@@ -1,237 +0,0 @@
|
||||
"""Device authentication service copied from the legacy server stack."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Mapping, Optional, Protocol
|
||||
import logging
|
||||
|
||||
from Data.Engine.builders.device_auth import DeviceAuthRequest
|
||||
from Data.Engine.domain.device_auth import (
|
||||
AccessTokenClaims,
|
||||
DeviceAuthContext,
|
||||
DeviceAuthErrorCode,
|
||||
DeviceAuthFailure,
|
||||
DeviceFingerprint,
|
||||
DeviceGuid,
|
||||
DeviceIdentity,
|
||||
DeviceStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthService",
|
||||
"DeviceRecord",
|
||||
"DPoPValidator",
|
||||
"DPoPVerificationError",
|
||||
"DPoPReplayError",
|
||||
"RateLimiter",
|
||||
"RateLimitDecision",
|
||||
"DeviceRepository",
|
||||
]
|
||||
|
||||
|
||||
class RateLimitDecision(Protocol):
|
||||
allowed: bool
|
||||
retry_after: Optional[float]
|
||||
|
||||
|
||||
class RateLimiter(Protocol):
|
||||
def check(self, key: str, max_requests: int, window_seconds: float) -> RateLimitDecision: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class JWTDecoder(Protocol):
|
||||
def decode(self, token: str) -> Mapping[str, object]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class DPoPValidator(Protocol):
|
||||
def verify(
|
||||
self,
|
||||
method: str,
|
||||
htu: str,
|
||||
proof: str,
|
||||
access_token: Optional[str] = None,
|
||||
) -> str: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class DPoPVerificationError(Exception):
|
||||
"""Raised when a DPoP proof fails validation."""
|
||||
|
||||
|
||||
class DPoPReplayError(DPoPVerificationError):
|
||||
"""Raised when a DPoP proof is replayed."""
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceRecord:
|
||||
"""Snapshot of a device record required for authentication."""
|
||||
|
||||
identity: DeviceIdentity
|
||||
token_version: int
|
||||
status: DeviceStatus
|
||||
|
||||
|
||||
class DeviceRepository(Protocol):
|
||||
"""Port that exposes the minimal device persistence operations."""
|
||||
|
||||
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def recover_missing(
|
||||
self,
|
||||
guid: DeviceGuid,
|
||||
fingerprint: DeviceFingerprint,
|
||||
token_version: int,
|
||||
service_context: Optional[str],
|
||||
) -> Optional[DeviceRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class DeviceAuthService:
|
||||
"""Authenticate devices using access tokens, repositories, and DPoP proofs."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
device_repository: DeviceRepository,
|
||||
jwt_service: JWTDecoder,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
rate_limiter: Optional[RateLimiter] = None,
|
||||
dpop_validator: Optional[DPoPValidator] = None,
|
||||
) -> None:
|
||||
self._repository = device_repository
|
||||
self._jwt = jwt_service
|
||||
self._log = logger or logging.getLogger("borealis.engine.auth")
|
||||
self._rate_limiter = rate_limiter
|
||||
self._dpop_validator = dpop_validator
|
||||
|
||||
def authenticate(self, request: DeviceAuthRequest, *, path: str) -> DeviceAuthContext:
|
||||
"""Authenticate an access token and return the resulting context."""
|
||||
|
||||
claims = self._decode_claims(request.access_token)
|
||||
rate_limit_key = f"fp:{claims.fingerprint.value}"
|
||||
if self._rate_limiter is not None:
|
||||
decision = self._rate_limiter.check(rate_limit_key, 60, 60.0)
|
||||
if not decision.allowed:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.RATE_LIMITED,
|
||||
http_status=429,
|
||||
retry_after=decision.retry_after,
|
||||
)
|
||||
|
||||
record = self._repository.fetch_by_guid(claims.guid)
|
||||
if record is None:
|
||||
record = self._repository.recover_missing(
|
||||
claims.guid,
|
||||
claims.fingerprint,
|
||||
claims.token_version,
|
||||
request.service_context,
|
||||
)
|
||||
|
||||
if record is None:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DEVICE_NOT_FOUND,
|
||||
http_status=403,
|
||||
)
|
||||
|
||||
self._validate_identity(record, claims)
|
||||
|
||||
dpop_jkt = self._validate_dpop(request, record, claims)
|
||||
|
||||
context = DeviceAuthContext(
|
||||
identity=record.identity,
|
||||
access_token=request.access_token,
|
||||
claims=claims,
|
||||
status=record.status,
|
||||
service_context=request.service_context,
|
||||
dpop_jkt=dpop_jkt,
|
||||
)
|
||||
|
||||
if context.is_quarantined:
|
||||
self._log.warning(
|
||||
"device %s is quarantined; limited access for %s",
|
||||
record.identity.guid,
|
||||
path,
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
def _decode_claims(self, token: str) -> AccessTokenClaims:
|
||||
try:
|
||||
raw_claims = self._jwt.decode(token)
|
||||
except Exception as exc: # pragma: no cover - defensive fallback
|
||||
if self._is_expired_signature(exc):
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.TOKEN_EXPIRED) from exc
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN) from exc
|
||||
|
||||
try:
|
||||
return AccessTokenClaims.from_mapping(raw_claims)
|
||||
except Exception as exc:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_CLAIMS) from exc
|
||||
|
||||
@staticmethod
|
||||
def _is_expired_signature(exc: Exception) -> bool:
|
||||
name = exc.__class__.__name__
|
||||
return name == "ExpiredSignatureError"
|
||||
|
||||
def _validate_identity(
|
||||
self,
|
||||
record: DeviceRecord,
|
||||
claims: AccessTokenClaims,
|
||||
) -> None:
|
||||
if record.identity.guid.value != claims.guid.value:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DEVICE_GUID_MISMATCH,
|
||||
http_status=403,
|
||||
)
|
||||
|
||||
if record.identity.fingerprint.value:
|
||||
if record.identity.fingerprint.value != claims.fingerprint.value:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.FINGERPRINT_MISMATCH,
|
||||
http_status=403,
|
||||
)
|
||||
|
||||
if record.token_version > claims.token_version:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.TOKEN_VERSION_REVOKED)
|
||||
|
||||
if not record.status.allows_access:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DEVICE_REVOKED,
|
||||
http_status=403,
|
||||
)
|
||||
|
||||
def _validate_dpop(
|
||||
self,
|
||||
request: DeviceAuthRequest,
|
||||
record: DeviceRecord,
|
||||
claims: AccessTokenClaims,
|
||||
) -> Optional[str]:
|
||||
if not request.dpop_proof:
|
||||
return None
|
||||
|
||||
if self._dpop_validator is None:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DPOP_NOT_SUPPORTED,
|
||||
http_status=400,
|
||||
)
|
||||
|
||||
try:
|
||||
return self._dpop_validator.verify(
|
||||
request.http_method,
|
||||
request.htu,
|
||||
request.dpop_proof,
|
||||
request.access_token,
|
||||
)
|
||||
except DPoPReplayError as exc:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DPOP_REPLAYED,
|
||||
http_status=400,
|
||||
) from exc
|
||||
except DPoPVerificationError as exc:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DPOP_INVALID,
|
||||
http_status=400,
|
||||
) from exc
|
||||
@@ -1,105 +0,0 @@
|
||||
"""DPoP proof validation for Engine services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict, Optional
|
||||
|
||||
import jwt
|
||||
|
||||
__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"]
|
||||
|
||||
|
||||
_DP0P_MAX_SKEW = 300.0
|
||||
|
||||
|
||||
class DPoPVerificationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DPoPReplayError(DPoPVerificationError):
|
||||
pass
|
||||
|
||||
|
||||
class DPoPValidator:
|
||||
def __init__(self) -> None:
|
||||
self._observed_jti: Dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def verify(
|
||||
self,
|
||||
method: str,
|
||||
htu: str,
|
||||
proof: str,
|
||||
access_token: Optional[str] = None,
|
||||
) -> str:
|
||||
if not proof:
|
||||
raise DPoPVerificationError("DPoP proof missing")
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(proof)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP header") from exc
|
||||
|
||||
jwk = header.get("jwk")
|
||||
alg = header.get("alg")
|
||||
if not jwk or not isinstance(jwk, dict):
|
||||
raise DPoPVerificationError("missing jwk in DPoP header")
|
||||
if alg not in ("EdDSA", "ES256", "ES384", "ES512"):
|
||||
raise DPoPVerificationError(f"unsupported DPoP alg {alg}")
|
||||
|
||||
try:
|
||||
key = jwt.PyJWK(jwk)
|
||||
public_key = key.key
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid jwk in DPoP header") from exc
|
||||
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
proof,
|
||||
public_key,
|
||||
algorithms=[alg],
|
||||
options={"require": ["htm", "htu", "jti", "iat"]},
|
||||
)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP signature") from exc
|
||||
|
||||
htm = claims.get("htm")
|
||||
proof_htu = claims.get("htu")
|
||||
jti = claims.get("jti")
|
||||
iat = claims.get("iat")
|
||||
ath = claims.get("ath")
|
||||
|
||||
if not isinstance(htm, str) or htm.lower() != method.lower():
|
||||
raise DPoPVerificationError("DPoP htm mismatch")
|
||||
if not isinstance(proof_htu, str) or proof_htu != htu:
|
||||
raise DPoPVerificationError("DPoP htu mismatch")
|
||||
if not isinstance(jti, str):
|
||||
raise DPoPVerificationError("DPoP jti missing")
|
||||
if not isinstance(iat, (int, float)):
|
||||
raise DPoPVerificationError("DPoP iat missing")
|
||||
|
||||
now = time.time()
|
||||
if abs(now - float(iat)) > _DP0P_MAX_SKEW:
|
||||
raise DPoPVerificationError("DPoP proof outside allowed skew")
|
||||
|
||||
if ath and access_token:
|
||||
expected_ath = jwt.utils.base64url_encode(
|
||||
hashlib.sha256(access_token.encode("utf-8")).digest()
|
||||
).decode("ascii")
|
||||
if expected_ath != ath:
|
||||
raise DPoPVerificationError("DPoP ath mismatch")
|
||||
|
||||
with self._lock:
|
||||
expiry = self._observed_jti.get(jti)
|
||||
if expiry and expiry > now:
|
||||
raise DPoPReplayError("DPoP proof replay detected")
|
||||
self._observed_jti[jti] = now + _DP0P_MAX_SKEW
|
||||
stale = [key for key, exp in self._observed_jti.items() if exp <= now]
|
||||
for key in stale:
|
||||
self._observed_jti.pop(key, None)
|
||||
|
||||
thumbprint = jwt.PyJWK(jwk).thumbprint()
|
||||
return thumbprint.decode("ascii")
|
||||
@@ -1,124 +0,0 @@
|
||||
"""JWT issuance utilities for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from Data.Engine.runtime import ensure_runtime_dir, runtime_path
|
||||
|
||||
__all__ = ["JWTService", "load_service"]
|
||||
|
||||
|
||||
_KEY_DIR = runtime_path("auth_keys")
|
||||
_KEY_FILE = _KEY_DIR / "engine-jwt-ed25519.key"
|
||||
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-jwt-ed25519.key"
|
||||
|
||||
|
||||
class JWTService:
|
||||
def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str) -> None:
|
||||
self._private_key = private_key
|
||||
self._public_key = private_key.public_key()
|
||||
self._key_id = key_id
|
||||
|
||||
@property
|
||||
def key_id(self) -> str:
|
||||
return self._key_id
|
||||
|
||||
def issue_access_token(
|
||||
self,
|
||||
guid: str,
|
||||
ssl_key_fingerprint: str,
|
||||
token_version: int,
|
||||
expires_in: int = 900,
|
||||
extra_claims: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
now = int(time.time())
|
||||
payload: Dict[str, Any] = {
|
||||
"sub": f"device:{guid}",
|
||||
"guid": guid,
|
||||
"ssl_key_fingerprint": ssl_key_fingerprint,
|
||||
"token_version": int(token_version),
|
||||
"iat": now,
|
||||
"nbf": now,
|
||||
"exp": now + int(expires_in),
|
||||
}
|
||||
if extra_claims:
|
||||
payload.update(extra_claims)
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
self._private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
),
|
||||
algorithm="EdDSA",
|
||||
headers={"kid": self._key_id},
|
||||
)
|
||||
return token
|
||||
|
||||
def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]:
|
||||
options = {"require": ["exp", "iat", "sub"]}
|
||||
public_pem = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
return jwt.decode(
|
||||
token,
|
||||
public_pem,
|
||||
algorithms=["EdDSA"],
|
||||
audience=audience,
|
||||
options=options,
|
||||
)
|
||||
|
||||
def public_jwk(self) -> Dict[str, Any]:
|
||||
public_bytes = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw,
|
||||
)
|
||||
jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii")
|
||||
return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x}
|
||||
|
||||
|
||||
def load_service() -> JWTService:
|
||||
private_key = _load_or_create_private_key()
|
||||
public_bytes = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
key_id = hashlib.sha256(public_bytes).hexdigest()[:16]
|
||||
return JWTService(private_key, key_id)
|
||||
|
||||
|
||||
def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey:
|
||||
ensure_runtime_dir("auth_keys")
|
||||
|
||||
if _KEY_FILE.exists():
|
||||
with _KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
if _LEGACY_KEY_FILE.exists():
|
||||
with _LEGACY_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
_KEY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
with _KEY_FILE.open("wb") as fh:
|
||||
fh.write(pem)
|
||||
try:
|
||||
if hasattr(_KEY_FILE, "chmod"):
|
||||
_KEY_FILE.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
return private_key
|
||||
@@ -1,211 +0,0 @@
|
||||
"""Operator account management service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from Data.Engine.domain import OperatorAccount
|
||||
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
|
||||
|
||||
|
||||
class OperatorAccountError(Exception):
|
||||
"""Base class for operator account management failures."""
|
||||
|
||||
|
||||
class UsernameAlreadyExistsError(OperatorAccountError):
|
||||
"""Raised when attempting to create an operator with a duplicate username."""
|
||||
|
||||
|
||||
class AccountNotFoundError(OperatorAccountError):
|
||||
"""Raised when the requested operator account cannot be located."""
|
||||
|
||||
|
||||
class LastAdminError(OperatorAccountError):
|
||||
"""Raised when attempting to demote or delete the last remaining admin."""
|
||||
|
||||
|
||||
class LastUserError(OperatorAccountError):
|
||||
"""Raised when attempting to delete the final operator account."""
|
||||
|
||||
|
||||
class CannotModifySelfError(OperatorAccountError):
|
||||
"""Raised when the caller attempts to delete themselves."""
|
||||
|
||||
|
||||
class InvalidRoleError(OperatorAccountError):
|
||||
"""Raised when a role value is invalid."""
|
||||
|
||||
|
||||
class InvalidPasswordHashError(OperatorAccountError):
|
||||
"""Raised when a password hash is malformed."""
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class OperatorAccountRecord:
|
||||
username: str
|
||||
display_name: str
|
||||
role: str
|
||||
last_login: int
|
||||
created_at: int
|
||||
updated_at: int
|
||||
mfa_enabled: bool
|
||||
|
||||
|
||||
class OperatorAccountService:
|
||||
"""High-level operations for managing operator accounts."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repository: SQLiteUserRepository,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.operator_accounts")
|
||||
|
||||
def list_accounts(self) -> list[OperatorAccountRecord]:
|
||||
return [_to_record(account) for account in self._repository.list_accounts()]
|
||||
|
||||
def create_account(
|
||||
self,
|
||||
*,
|
||||
username: str,
|
||||
password_sha512: str,
|
||||
role: str,
|
||||
display_name: Optional[str] = None,
|
||||
) -> OperatorAccountRecord:
|
||||
normalized_role = self._normalize_role(role)
|
||||
username = (username or "").strip()
|
||||
password_sha512 = (password_sha512 or "").strip().lower()
|
||||
display_name = (display_name or username or "").strip()
|
||||
|
||||
if not username or not password_sha512:
|
||||
raise InvalidPasswordHashError("username and password are required")
|
||||
if len(password_sha512) != 128:
|
||||
raise InvalidPasswordHashError("password hash must be 128 hex characters")
|
||||
|
||||
now = int(time.time())
|
||||
try:
|
||||
self._repository.create_account(
|
||||
username=username,
|
||||
display_name=display_name or username,
|
||||
password_sha512=password_sha512,
|
||||
role=normalized_role,
|
||||
timestamp=now,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - sqlite integrity errors are deterministic
|
||||
import sqlite3
|
||||
|
||||
if isinstance(exc, sqlite3.IntegrityError):
|
||||
raise UsernameAlreadyExistsError("username already exists") from exc
|
||||
raise
|
||||
|
||||
account = self._repository.fetch_by_username(username)
|
||||
if not account: # pragma: no cover - sanity guard
|
||||
raise AccountNotFoundError("account creation failed")
|
||||
return _to_record(account)
|
||||
|
||||
def delete_account(self, username: str, *, actor: Optional[str] = None) -> None:
|
||||
username = (username or "").strip()
|
||||
if not username:
|
||||
raise AccountNotFoundError("invalid username")
|
||||
|
||||
if actor and actor.strip().lower() == username.lower():
|
||||
raise CannotModifySelfError("cannot delete yourself")
|
||||
|
||||
total_accounts = self._repository.count_accounts()
|
||||
if total_accounts <= 1:
|
||||
raise LastUserError("cannot delete the last user")
|
||||
|
||||
target = self._repository.fetch_by_username(username)
|
||||
if not target:
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
if target.role.lower() == "admin" and self._repository.count_admins() <= 1:
|
||||
raise LastAdminError("cannot delete the last admin")
|
||||
|
||||
if not self._repository.delete_account(username):
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
def reset_password(self, username: str, password_sha512: str) -> None:
|
||||
username = (username or "").strip()
|
||||
password_sha512 = (password_sha512 or "").strip().lower()
|
||||
if len(password_sha512) != 128:
|
||||
raise InvalidPasswordHashError("invalid password hash")
|
||||
|
||||
now = int(time.time())
|
||||
if not self._repository.update_password(username, password_sha512, timestamp=now):
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
def change_role(self, username: str, role: str, *, actor: Optional[str] = None) -> OperatorAccountRecord:
|
||||
username = (username or "").strip()
|
||||
normalized_role = self._normalize_role(role)
|
||||
|
||||
account = self._repository.fetch_by_username(username)
|
||||
if not account:
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
if account.role.lower() == "admin" and normalized_role.lower() != "admin":
|
||||
if self._repository.count_admins() <= 1:
|
||||
raise LastAdminError("cannot demote the last admin")
|
||||
|
||||
now = int(time.time())
|
||||
if not self._repository.update_role(username, normalized_role, timestamp=now):
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
updated = self._repository.fetch_by_username(username)
|
||||
if not updated: # pragma: no cover - guard
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
record = _to_record(updated)
|
||||
if actor and actor.strip().lower() == username.lower():
|
||||
self._log.info("actor-role-updated", extra={"username": username, "role": record.role})
|
||||
return record
|
||||
|
||||
def update_mfa(self, username: str, *, enabled: bool, reset_secret: bool) -> None:
|
||||
username = (username or "").strip()
|
||||
if not username:
|
||||
raise AccountNotFoundError("invalid username")
|
||||
|
||||
now = int(time.time())
|
||||
if not self._repository.update_mfa(username, enabled=enabled, reset_secret=reset_secret, timestamp=now):
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
def fetch_account(self, username: str) -> Optional[OperatorAccountRecord]:
|
||||
account = self._repository.fetch_by_username(username)
|
||||
return _to_record(account) if account else None
|
||||
|
||||
def _normalize_role(self, role: str) -> str:
|
||||
normalized = (role or "").strip().title() or "User"
|
||||
if normalized not in {"User", "Admin"}:
|
||||
raise InvalidRoleError("invalid role")
|
||||
return normalized
|
||||
|
||||
|
||||
def _to_record(account: OperatorAccount) -> OperatorAccountRecord:
|
||||
return OperatorAccountRecord(
|
||||
username=account.username,
|
||||
display_name=account.display_name or account.username,
|
||||
role=account.role or "User",
|
||||
last_login=int(account.last_login or 0),
|
||||
created_at=int(account.created_at or 0),
|
||||
updated_at=int(account.updated_at or 0),
|
||||
mfa_enabled=bool(account.mfa_enabled),
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OperatorAccountService",
|
||||
"OperatorAccountError",
|
||||
"UsernameAlreadyExistsError",
|
||||
"AccountNotFoundError",
|
||||
"LastAdminError",
|
||||
"LastUserError",
|
||||
"CannotModifySelfError",
|
||||
"InvalidRoleError",
|
||||
"InvalidPasswordHashError",
|
||||
"OperatorAccountRecord",
|
||||
]
|
||||
@@ -1,236 +0,0 @@
|
||||
"""Operator authentication service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
try: # pragma: no cover - optional dependencies mirror legacy server behaviour
|
||||
import pyotp # type: ignore
|
||||
except Exception: # pragma: no cover - gracefully degrade when unavailable
|
||||
pyotp = None # type: ignore
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import qrcode # type: ignore
|
||||
except Exception: # pragma: no cover - gracefully degrade when unavailable
|
||||
qrcode = None # type: ignore
|
||||
|
||||
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
|
||||
|
||||
from Data.Engine.builders.operator_auth import (
|
||||
OperatorLoginRequest,
|
||||
OperatorMFAVerificationRequest,
|
||||
)
|
||||
from Data.Engine.domain import (
|
||||
OperatorAccount,
|
||||
OperatorLoginSuccess,
|
||||
OperatorMFAChallenge,
|
||||
)
|
||||
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
|
||||
|
||||
|
||||
class OperatorAuthError(Exception):
|
||||
"""Base class for operator authentication errors."""
|
||||
|
||||
|
||||
class InvalidCredentialsError(OperatorAuthError):
|
||||
"""Raised when username/password verification fails."""
|
||||
|
||||
|
||||
class MFAUnavailableError(OperatorAuthError):
|
||||
"""Raised when MFA functionality is requested but dependencies are missing."""
|
||||
|
||||
|
||||
class InvalidMFACodeError(OperatorAuthError):
|
||||
"""Raised when the submitted MFA code is invalid."""
|
||||
|
||||
|
||||
class MFASessionError(OperatorAuthError):
|
||||
"""Raised when the MFA session state cannot be validated."""
|
||||
|
||||
|
||||
class OperatorAuthService:
|
||||
"""Authenticate operator accounts and manage MFA challenges."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repository: SQLiteUserRepository,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.operator_auth")
|
||||
|
||||
def authenticate(
|
||||
self, request: OperatorLoginRequest
|
||||
) -> OperatorLoginSuccess | OperatorMFAChallenge:
|
||||
account = self._repository.fetch_by_username(request.username)
|
||||
if not account:
|
||||
raise InvalidCredentialsError("invalid username or password")
|
||||
|
||||
if not self._password_matches(account, request.password_sha512):
|
||||
raise InvalidCredentialsError("invalid username or password")
|
||||
|
||||
if not account.mfa_enabled:
|
||||
return self._finalize_login(account)
|
||||
|
||||
stage = "verify" if account.mfa_secret else "setup"
|
||||
return self._build_mfa_challenge(account, stage)
|
||||
|
||||
def verify_mfa(
|
||||
self,
|
||||
challenge: OperatorMFAChallenge,
|
||||
request: OperatorMFAVerificationRequest,
|
||||
) -> OperatorLoginSuccess:
|
||||
now = int(time.time())
|
||||
if challenge.pending_token != request.pending_token:
|
||||
raise MFASessionError("invalid_session")
|
||||
if challenge.expires_at < now:
|
||||
raise MFASessionError("expired")
|
||||
|
||||
if challenge.stage == "setup":
|
||||
secret = (challenge.secret or "").strip()
|
||||
if not secret:
|
||||
raise MFASessionError("mfa_not_configured")
|
||||
totp = self._totp_for_secret(secret)
|
||||
if not totp.verify(request.code, valid_window=1):
|
||||
raise InvalidMFACodeError("invalid_code")
|
||||
self._repository.store_mfa_secret(challenge.username, secret, timestamp=now)
|
||||
else:
|
||||
account = self._repository.fetch_by_username(challenge.username)
|
||||
if not account or not account.mfa_secret:
|
||||
raise MFASessionError("mfa_not_configured")
|
||||
totp = self._totp_for_secret(account.mfa_secret)
|
||||
if not totp.verify(request.code, valid_window=1):
|
||||
raise InvalidMFACodeError("invalid_code")
|
||||
|
||||
account = self._repository.fetch_by_username(challenge.username)
|
||||
if not account:
|
||||
raise InvalidCredentialsError("invalid username or password")
|
||||
return self._finalize_login(account)
|
||||
|
||||
def issue_token(self, username: str, role: str) -> str:
|
||||
serializer = self._token_serializer()
|
||||
payload = {"u": username, "r": role or "User", "ts": int(time.time())}
|
||||
return serializer.dumps(payload)
|
||||
|
||||
def resolve_token(self, token: str, *, max_age: int = 30 * 24 * 3600) -> Optional[OperatorAccount]:
|
||||
"""Return the account associated with *token* if it is valid."""
|
||||
|
||||
token = (token or "").strip()
|
||||
if not token:
|
||||
return None
|
||||
|
||||
serializer = self._token_serializer()
|
||||
try:
|
||||
payload = serializer.loads(token, max_age=max_age)
|
||||
except (BadSignature, SignatureExpired):
|
||||
return None
|
||||
|
||||
username = str(payload.get("u") or "").strip()
|
||||
if not username:
|
||||
return None
|
||||
|
||||
return self._repository.fetch_by_username(username)
|
||||
|
||||
def fetch_account(self, username: str) -> Optional[OperatorAccount]:
|
||||
"""Return the operator account for *username* if it exists."""
|
||||
|
||||
username = (username or "").strip()
|
||||
if not username:
|
||||
return None
|
||||
return self._repository.fetch_by_username(username)
|
||||
|
||||
def _finalize_login(self, account: OperatorAccount) -> OperatorLoginSuccess:
|
||||
now = int(time.time())
|
||||
self._repository.update_last_login(account.username, now)
|
||||
token = self.issue_token(account.username, account.role)
|
||||
return OperatorLoginSuccess(username=account.username, role=account.role, token=token)
|
||||
|
||||
def _password_matches(self, account: OperatorAccount, provided_hash: str) -> bool:
|
||||
expected = (account.password_sha512 or "").strip().lower()
|
||||
candidate = (provided_hash or "").strip().lower()
|
||||
return bool(expected and candidate and expected == candidate)
|
||||
|
||||
def _build_mfa_challenge(
|
||||
self,
|
||||
account: OperatorAccount,
|
||||
stage: str,
|
||||
) -> OperatorMFAChallenge:
|
||||
now = int(time.time())
|
||||
pending_token = uuid.uuid4().hex
|
||||
secret = None
|
||||
otpauth_url = None
|
||||
qr_image = None
|
||||
|
||||
if stage == "setup":
|
||||
secret = self._generate_totp_secret()
|
||||
otpauth_url = self._totp_provisioning_uri(secret, account.username)
|
||||
qr_image = self._totp_qr_data_uri(otpauth_url) if otpauth_url else None
|
||||
|
||||
return OperatorMFAChallenge(
|
||||
username=account.username,
|
||||
role=account.role,
|
||||
stage="verify" if stage == "verify" else "setup",
|
||||
pending_token=pending_token,
|
||||
expires_at=now + 300,
|
||||
secret=secret,
|
||||
otpauth_url=otpauth_url,
|
||||
qr_image=qr_image,
|
||||
)
|
||||
|
||||
def _token_serializer(self) -> URLSafeTimedSerializer:
|
||||
secret = os.getenv("BOREALIS_FLASK_SECRET_KEY") or "change-me"
|
||||
return URLSafeTimedSerializer(secret, salt="borealis-auth")
|
||||
|
||||
def _generate_totp_secret(self) -> str:
|
||||
if not pyotp:
|
||||
raise MFAUnavailableError("pyotp is not installed; MFA unavailable")
|
||||
return pyotp.random_base32() # type: ignore[no-any-return]
|
||||
|
||||
def _totp_for_secret(self, secret: str):
|
||||
if not pyotp:
|
||||
raise MFAUnavailableError("pyotp is not installed; MFA unavailable")
|
||||
normalized = secret.replace(" ", "").strip().upper()
|
||||
if not normalized:
|
||||
raise MFASessionError("mfa_not_configured")
|
||||
return pyotp.TOTP(normalized, digits=6, interval=30)
|
||||
|
||||
def _totp_provisioning_uri(self, secret: str, username: str) -> Optional[str]:
|
||||
try:
|
||||
totp = self._totp_for_secret(secret)
|
||||
except OperatorAuthError:
|
||||
return None
|
||||
issuer = os.getenv("BOREALIS_MFA_ISSUER", "Borealis")
|
||||
try:
|
||||
return totp.provisioning_uri(name=username, issuer_name=issuer)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return None
|
||||
|
||||
def _totp_qr_data_uri(self, payload: str) -> Optional[str]:
|
||||
if not payload or qrcode is None:
|
||||
return None
|
||||
try:
|
||||
img = qrcode.make(payload, box_size=6, border=4)
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="PNG")
|
||||
encoded = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
return f"data:image/png;base64,{encoded}"
|
||||
except Exception: # pragma: no cover - defensive
|
||||
self._log.warning("failed to generate MFA QR code", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OperatorAuthService",
|
||||
"OperatorAuthError",
|
||||
"InvalidCredentialsError",
|
||||
"MFAUnavailableError",
|
||||
"InvalidMFACodeError",
|
||||
"MFASessionError",
|
||||
]
|
||||
@@ -1,190 +0,0 @@
|
||||
"""Token refresh service extracted from the legacy blueprint."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Protocol
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from Data.Engine.builders.device_auth import RefreshTokenRequest
|
||||
from Data.Engine.domain.device_auth import DeviceGuid
|
||||
|
||||
from .device_auth_service import (
|
||||
DeviceRecord,
|
||||
DeviceRepository,
|
||||
DPoPReplayError,
|
||||
DPoPVerificationError,
|
||||
DPoPValidator,
|
||||
)
|
||||
|
||||
__all__ = ["RefreshTokenRecord", "TokenService", "TokenRefreshError", "TokenRefreshErrorCode"]
|
||||
|
||||
|
||||
class JWTIssuer(Protocol):
|
||||
def issue_access_token(self, guid: str, fingerprint: str, token_version: int) -> str: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class TokenRefreshErrorCode(str):
|
||||
INVALID_REFRESH_TOKEN = "invalid_refresh_token"
|
||||
REFRESH_TOKEN_REVOKED = "refresh_token_revoked"
|
||||
REFRESH_TOKEN_EXPIRED = "refresh_token_expired"
|
||||
DEVICE_NOT_FOUND = "device_not_found"
|
||||
DEVICE_REVOKED = "device_revoked"
|
||||
DPOP_REPLAYED = "dpop_replayed"
|
||||
DPOP_INVALID = "dpop_invalid"
|
||||
|
||||
|
||||
class TokenRefreshError(Exception):
|
||||
def __init__(self, code: str, *, http_status: int = 400) -> None:
|
||||
self.code = code
|
||||
self.http_status = http_status
|
||||
super().__init__(code)
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
return {"error": self.code}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RefreshTokenRecord:
|
||||
record_id: str
|
||||
guid: DeviceGuid
|
||||
token_hash: str
|
||||
dpop_jkt: Optional[str]
|
||||
created_at: datetime
|
||||
expires_at: Optional[datetime]
|
||||
revoked_at: Optional[datetime]
|
||||
|
||||
@classmethod
|
||||
def from_row(
|
||||
cls,
|
||||
*,
|
||||
record_id: str,
|
||||
guid: DeviceGuid,
|
||||
token_hash: str,
|
||||
dpop_jkt: Optional[str],
|
||||
created_at: datetime,
|
||||
expires_at: Optional[datetime],
|
||||
revoked_at: Optional[datetime],
|
||||
) -> "RefreshTokenRecord":
|
||||
return cls(
|
||||
record_id=record_id,
|
||||
guid=guid,
|
||||
token_hash=token_hash,
|
||||
dpop_jkt=dpop_jkt,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
revoked_at=revoked_at,
|
||||
)
|
||||
|
||||
|
||||
class RefreshTokenRepository(Protocol):
|
||||
def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def clear_dpop_binding(self, record_id: str) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def touch(self, record_id: str, *, last_used_at: datetime, dpop_jkt: Optional[str]) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AccessTokenResponse:
|
||||
access_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
refresh_token_repository: RefreshTokenRepository,
|
||||
device_repository: DeviceRepository,
|
||||
jwt_service: JWTIssuer,
|
||||
dpop_validator: Optional[DPoPValidator] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._refresh_tokens = refresh_token_repository
|
||||
self._devices = device_repository
|
||||
self._jwt = jwt_service
|
||||
self._dpop_validator = dpop_validator
|
||||
self._log = logger or logging.getLogger("borealis.engine.auth")
|
||||
|
||||
def refresh_access_token(
|
||||
self,
|
||||
request: RefreshTokenRequest,
|
||||
) -> AccessTokenResponse:
|
||||
record = self._refresh_tokens.fetch(
|
||||
request.guid,
|
||||
self._hash_token(request.refresh_token),
|
||||
)
|
||||
if record is None:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.INVALID_REFRESH_TOKEN, http_status=401)
|
||||
|
||||
if record.guid.value != request.guid.value:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.INVALID_REFRESH_TOKEN, http_status=401)
|
||||
|
||||
if record.revoked_at is not None:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.REFRESH_TOKEN_REVOKED, http_status=401)
|
||||
|
||||
if record.expires_at is not None and record.expires_at <= self._now():
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.REFRESH_TOKEN_EXPIRED, http_status=401)
|
||||
|
||||
device = self._devices.fetch_by_guid(request.guid)
|
||||
if device is None:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DEVICE_NOT_FOUND, http_status=404)
|
||||
|
||||
if not device.status.allows_access:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DEVICE_REVOKED, http_status=403)
|
||||
|
||||
dpop_jkt = record.dpop_jkt or ""
|
||||
if request.dpop_proof:
|
||||
if self._dpop_validator is None:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DPOP_INVALID)
|
||||
try:
|
||||
dpop_jkt = self._dpop_validator.verify(
|
||||
request.http_method,
|
||||
request.htu,
|
||||
request.dpop_proof,
|
||||
None,
|
||||
)
|
||||
except DPoPReplayError as exc:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DPOP_REPLAYED) from exc
|
||||
except DPoPVerificationError as exc:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DPOP_INVALID) from exc
|
||||
elif record.dpop_jkt:
|
||||
self._log.warning(
|
||||
"Clearing stored DPoP binding for guid=%s due to missing proof",
|
||||
request.guid.value,
|
||||
)
|
||||
self._refresh_tokens.clear_dpop_binding(record.record_id)
|
||||
|
||||
access_token = self._jwt.issue_access_token(
|
||||
request.guid.value,
|
||||
device.identity.fingerprint.value,
|
||||
max(device.token_version, 1),
|
||||
)
|
||||
|
||||
self._refresh_tokens.touch(
|
||||
record.record_id,
|
||||
last_used_at=self._now(),
|
||||
dpop_jkt=dpop_jkt or None,
|
||||
)
|
||||
|
||||
return AccessTokenResponse(
|
||||
access_token=access_token,
|
||||
expires_in=900,
|
||||
token_type="Bearer",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _hash_token(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
@@ -1,238 +0,0 @@
|
||||
"""Service container assembly for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from Data.Engine.config import EngineSettings
|
||||
from Data.Engine.integrations.github import GitHubArtifactProvider
|
||||
from Data.Engine.repositories.sqlite import (
|
||||
SQLiteConnectionFactory,
|
||||
SQLiteDeviceRepository,
|
||||
SQLiteDeviceInventoryRepository,
|
||||
SQLiteDeviceViewRepository,
|
||||
SQLiteCredentialRepository,
|
||||
SQLiteEnrollmentRepository,
|
||||
SQLiteGitHubRepository,
|
||||
SQLiteJobRepository,
|
||||
SQLiteRefreshTokenRepository,
|
||||
SQLiteSiteRepository,
|
||||
SQLiteUserRepository,
|
||||
)
|
||||
from Data.Engine.services.auth import (
|
||||
DeviceAuthService,
|
||||
DPoPValidator,
|
||||
OperatorAccountService,
|
||||
OperatorAuthService,
|
||||
JWTService,
|
||||
TokenService,
|
||||
load_jwt_service,
|
||||
)
|
||||
from Data.Engine.services.crypto.signing import ScriptSigner, load_signer
|
||||
from Data.Engine.services.enrollment import EnrollmentService
|
||||
from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService
|
||||
from Data.Engine.services.enrollment.nonce_cache import NonceCache
|
||||
from Data.Engine.services.devices import DeviceInventoryService
|
||||
from Data.Engine.services.devices import DeviceViewService
|
||||
from Data.Engine.services.credentials import CredentialService
|
||||
from Data.Engine.services.github import GitHubService
|
||||
from Data.Engine.services.jobs import SchedulerService
|
||||
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
|
||||
from Data.Engine.services.realtime import AgentRealtimeService
|
||||
from Data.Engine.services.sites import SiteService
|
||||
from Data.Engine.services.assemblies import AssemblyService
|
||||
|
||||
__all__ = ["EngineServiceContainer", "build_service_container"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EngineServiceContainer:
|
||||
device_auth: DeviceAuthService
|
||||
device_inventory: DeviceInventoryService
|
||||
device_view_service: DeviceViewService
|
||||
credential_service: CredentialService
|
||||
token_service: TokenService
|
||||
enrollment_service: EnrollmentService
|
||||
enrollment_admin_service: EnrollmentAdminService
|
||||
site_service: SiteService
|
||||
jwt_service: JWTService
|
||||
dpop_validator: DPoPValidator
|
||||
agent_realtime: AgentRealtimeService
|
||||
scheduler_service: SchedulerService
|
||||
github_service: GitHubService
|
||||
operator_auth_service: OperatorAuthService
|
||||
operator_account_service: OperatorAccountService
|
||||
assembly_service: AssemblyService
|
||||
script_signer: Optional[ScriptSigner]
|
||||
|
||||
|
||||
def build_service_container(
|
||||
settings: EngineSettings,
|
||||
*,
|
||||
db_factory: SQLiteConnectionFactory,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> EngineServiceContainer:
|
||||
log = logger or logging.getLogger("borealis.engine.services")
|
||||
|
||||
device_repo = SQLiteDeviceRepository(db_factory, logger=log.getChild("devices"))
|
||||
device_inventory_repo = SQLiteDeviceInventoryRepository(
|
||||
db_factory, logger=log.getChild("devices.inventory")
|
||||
)
|
||||
device_view_repo = SQLiteDeviceViewRepository(
|
||||
db_factory, logger=log.getChild("devices.views")
|
||||
)
|
||||
credential_repo = SQLiteCredentialRepository(
|
||||
db_factory, logger=log.getChild("credentials.repo")
|
||||
)
|
||||
token_repo = SQLiteRefreshTokenRepository(db_factory, logger=log.getChild("tokens"))
|
||||
enrollment_repo = SQLiteEnrollmentRepository(db_factory, logger=log.getChild("enrollment"))
|
||||
job_repo = SQLiteJobRepository(db_factory, logger=log.getChild("jobs"))
|
||||
github_repo = SQLiteGitHubRepository(db_factory, logger=log.getChild("github_repo"))
|
||||
site_repo = SQLiteSiteRepository(db_factory, logger=log.getChild("sites.repo"))
|
||||
user_repo = SQLiteUserRepository(db_factory, logger=log.getChild("users"))
|
||||
|
||||
jwt_service = load_jwt_service()
|
||||
dpop_validator = DPoPValidator()
|
||||
rate_limiter = SlidingWindowRateLimiter()
|
||||
|
||||
token_service = TokenService(
|
||||
refresh_token_repository=token_repo,
|
||||
device_repository=device_repo,
|
||||
jwt_service=jwt_service,
|
||||
dpop_validator=dpop_validator,
|
||||
logger=log.getChild("token_service"),
|
||||
)
|
||||
|
||||
script_signer = _load_script_signer(log)
|
||||
|
||||
enrollment_service = EnrollmentService(
|
||||
device_repository=device_repo,
|
||||
enrollment_repository=enrollment_repo,
|
||||
token_repository=token_repo,
|
||||
jwt_service=jwt_service,
|
||||
tls_bundle_loader=_tls_bundle_loader(settings),
|
||||
ip_rate_limiter=SlidingWindowRateLimiter(),
|
||||
fingerprint_rate_limiter=SlidingWindowRateLimiter(),
|
||||
nonce_cache=NonceCache(),
|
||||
script_signer=script_signer,
|
||||
logger=log.getChild("enrollment"),
|
||||
)
|
||||
|
||||
enrollment_admin_service = EnrollmentAdminService(
|
||||
repository=enrollment_repo,
|
||||
user_repository=user_repo,
|
||||
logger=log.getChild("enrollment_admin"),
|
||||
)
|
||||
|
||||
device_auth = DeviceAuthService(
|
||||
device_repository=device_repo,
|
||||
jwt_service=jwt_service,
|
||||
logger=log.getChild("device_auth"),
|
||||
rate_limiter=rate_limiter,
|
||||
dpop_validator=dpop_validator,
|
||||
)
|
||||
|
||||
agent_realtime = AgentRealtimeService(
|
||||
device_repository=device_repo,
|
||||
logger=log.getChild("agent_realtime"),
|
||||
)
|
||||
|
||||
scheduler_service = SchedulerService(
|
||||
job_repository=job_repo,
|
||||
assemblies_root=settings.project_root / "Assemblies",
|
||||
logger=log.getChild("scheduler"),
|
||||
)
|
||||
|
||||
operator_auth_service = OperatorAuthService(
|
||||
repository=user_repo,
|
||||
logger=log.getChild("operator_auth"),
|
||||
)
|
||||
operator_account_service = OperatorAccountService(
|
||||
repository=user_repo,
|
||||
logger=log.getChild("operator_accounts"),
|
||||
)
|
||||
device_inventory = DeviceInventoryService(
|
||||
repository=device_inventory_repo,
|
||||
logger=log.getChild("device_inventory"),
|
||||
)
|
||||
device_view_service = DeviceViewService(
|
||||
repository=device_view_repo,
|
||||
logger=log.getChild("device_views"),
|
||||
)
|
||||
credential_service = CredentialService(
|
||||
repository=credential_repo,
|
||||
logger=log.getChild("credentials"),
|
||||
)
|
||||
site_service = SiteService(
|
||||
repository=site_repo,
|
||||
logger=log.getChild("sites"),
|
||||
)
|
||||
|
||||
assembly_service = AssemblyService(
|
||||
root=settings.project_root / "Assemblies",
|
||||
logger=log.getChild("assemblies"),
|
||||
)
|
||||
|
||||
github_provider = GitHubArtifactProvider(
|
||||
cache_file=settings.github.cache_file,
|
||||
default_repo=settings.github.default_repo,
|
||||
default_branch=settings.github.default_branch,
|
||||
refresh_interval=settings.github.refresh_interval_seconds,
|
||||
logger=log.getChild("github.provider"),
|
||||
)
|
||||
github_service = GitHubService(
|
||||
repository=github_repo,
|
||||
provider=github_provider,
|
||||
logger=log.getChild("github"),
|
||||
)
|
||||
github_service.start_background_refresh()
|
||||
|
||||
return EngineServiceContainer(
|
||||
device_auth=device_auth,
|
||||
token_service=token_service,
|
||||
enrollment_service=enrollment_service,
|
||||
enrollment_admin_service=enrollment_admin_service,
|
||||
jwt_service=jwt_service,
|
||||
dpop_validator=dpop_validator,
|
||||
agent_realtime=agent_realtime,
|
||||
scheduler_service=scheduler_service,
|
||||
github_service=github_service,
|
||||
operator_auth_service=operator_auth_service,
|
||||
operator_account_service=operator_account_service,
|
||||
device_inventory=device_inventory,
|
||||
device_view_service=device_view_service,
|
||||
credential_service=credential_service,
|
||||
site_service=site_service,
|
||||
assembly_service=assembly_service,
|
||||
script_signer=script_signer,
|
||||
)
|
||||
|
||||
|
||||
def _tls_bundle_loader(settings: EngineSettings) -> Callable[[], str]:
|
||||
candidates = [
|
||||
Path(os.getenv("BOREALIS_TLS_BUNDLE", "")),
|
||||
settings.project_root / "Certificates" / "Server" / "borealis-server-bundle.pem",
|
||||
]
|
||||
|
||||
def loader() -> str:
|
||||
for candidate in candidates:
|
||||
if candidate and candidate.is_file():
|
||||
try:
|
||||
return candidate.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
continue
|
||||
return ""
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
def _load_script_signer(logger: logging.Logger) -> Optional[ScriptSigner]:
|
||||
try:
|
||||
return load_signer()
|
||||
except Exception as exc:
|
||||
logger.warning("script signer unavailable: %s", exc)
|
||||
return None
|
||||
@@ -1,3 +0,0 @@
|
||||
from .credential_service import CredentialService
|
||||
|
||||
__all__ = ["CredentialService"]
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Expose read access to stored credentials."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from Data.Engine.repositories.sqlite.credential_repository import SQLiteCredentialRepository
|
||||
|
||||
__all__ = ["CredentialService"]
|
||||
|
||||
|
||||
class CredentialService:
|
||||
def __init__(
|
||||
self,
|
||||
repository: SQLiteCredentialRepository,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._repo = repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.credentials")
|
||||
|
||||
def list_credentials(
|
||||
self,
|
||||
*,
|
||||
site_id: Optional[int] = None,
|
||||
connection_type: Optional[str] = None,
|
||||
) -> List[dict]:
|
||||
return self._repo.list_credentials(site_id=site_id, connection_type=connection_type)
|
||||
@@ -1,366 +0,0 @@
|
||||
"""Server TLS certificate management for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import ssl
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
|
||||
|
||||
from Data.Engine.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path
|
||||
|
||||
__all__ = [
|
||||
"build_ssl_context",
|
||||
"certificate_paths",
|
||||
"ensure_certificate",
|
||||
]
|
||||
|
||||
_CERT_DIR = server_certificates_path()
|
||||
_CERT_FILE = _CERT_DIR / "borealis-server-cert.pem"
|
||||
_KEY_FILE = _CERT_DIR / "borealis-server-key.pem"
|
||||
_BUNDLE_FILE = _CERT_DIR / "borealis-server-bundle.pem"
|
||||
_CA_KEY_FILE = _CERT_DIR / "borealis-root-ca-key.pem"
|
||||
_CA_CERT_FILE = _CERT_DIR / "borealis-root-ca.pem"
|
||||
|
||||
_LEGACY_CERT_DIR = runtime_path("certs")
|
||||
_LEGACY_CERT_FILE = _LEGACY_CERT_DIR / "borealis-server-cert.pem"
|
||||
_LEGACY_KEY_FILE = _LEGACY_CERT_DIR / "borealis-server-key.pem"
|
||||
_LEGACY_BUNDLE_FILE = _LEGACY_CERT_DIR / "borealis-server-bundle.pem"
|
||||
|
||||
_ROOT_COMMON_NAME = "Borealis Root CA"
|
||||
_ORG_NAME = "Borealis"
|
||||
_ROOT_VALIDITY = timedelta(days=365 * 100)
|
||||
_SERVER_VALIDITY = timedelta(days=365 * 5)
|
||||
|
||||
|
||||
def ensure_certificate(common_name: str = "Borealis Engine") -> Tuple[Path, Path, Path]:
|
||||
"""Ensure the root CA, server certificate, and bundle exist on disk."""
|
||||
|
||||
ensure_server_certificates_dir()
|
||||
_migrate_legacy_material_if_present()
|
||||
|
||||
ca_key, ca_cert, ca_regenerated = _ensure_root_ca()
|
||||
|
||||
server_cert = _load_certificate(_CERT_FILE)
|
||||
needs_regen = ca_regenerated or _server_certificate_needs_regeneration(server_cert, ca_cert)
|
||||
if needs_regen:
|
||||
server_cert = _generate_server_certificate(common_name, ca_key, ca_cert)
|
||||
|
||||
if server_cert is None:
|
||||
server_cert = _generate_server_certificate(common_name, ca_key, ca_cert)
|
||||
|
||||
_write_bundle(server_cert, ca_cert)
|
||||
|
||||
return _CERT_FILE, _KEY_FILE, _BUNDLE_FILE
|
||||
|
||||
|
||||
def _migrate_legacy_material_if_present() -> None:
|
||||
if not _CERT_FILE.exists() or not _KEY_FILE.exists():
|
||||
legacy_cert = _LEGACY_CERT_FILE
|
||||
legacy_key = _LEGACY_KEY_FILE
|
||||
if legacy_cert.exists() and legacy_key.exists():
|
||||
try:
|
||||
ensure_server_certificates_dir()
|
||||
if not _CERT_FILE.exists():
|
||||
_safe_copy(legacy_cert, _CERT_FILE)
|
||||
if not _KEY_FILE.exists():
|
||||
_safe_copy(legacy_key, _KEY_FILE)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _ensure_root_ca() -> Tuple[ec.EllipticCurvePrivateKey, x509.Certificate, bool]:
|
||||
regenerated = False
|
||||
|
||||
ca_key: Optional[ec.EllipticCurvePrivateKey] = None
|
||||
ca_cert: Optional[x509.Certificate] = None
|
||||
|
||||
if _CA_KEY_FILE.exists() and _CA_CERT_FILE.exists():
|
||||
try:
|
||||
ca_key = _load_private_key(_CA_KEY_FILE)
|
||||
ca_cert = _load_certificate(_CA_CERT_FILE)
|
||||
if ca_cert is not None and ca_key is not None:
|
||||
expiry = _cert_not_after(ca_cert)
|
||||
subject = ca_cert.subject
|
||||
subject_cn = ""
|
||||
try:
|
||||
subject_cn = subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value # type: ignore[index]
|
||||
except Exception:
|
||||
subject_cn = ""
|
||||
try:
|
||||
basic = ca_cert.extensions.get_extension_for_class(x509.BasicConstraints).value # type: ignore[attr-defined]
|
||||
is_ca = bool(basic.ca)
|
||||
except Exception:
|
||||
is_ca = False
|
||||
if (
|
||||
expiry <= datetime.now(tz=timezone.utc)
|
||||
or not is_ca
|
||||
or subject_cn != _ROOT_COMMON_NAME
|
||||
):
|
||||
regenerated = True
|
||||
else:
|
||||
regenerated = True
|
||||
except Exception:
|
||||
regenerated = True
|
||||
else:
|
||||
regenerated = True
|
||||
|
||||
if regenerated or ca_key is None or ca_cert is None:
|
||||
ca_key = ec.generate_private_key(ec.SECP384R1())
|
||||
public_key = ca_key.public_key()
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
builder = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, _ROOT_COMMON_NAME),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
|
||||
]
|
||||
)
|
||||
)
|
||||
.issuer_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, _ROOT_COMMON_NAME),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
|
||||
]
|
||||
)
|
||||
)
|
||||
.public_key(public_key)
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now - timedelta(minutes=5))
|
||||
.not_valid_after(now + _ROOT_VALIDITY)
|
||||
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
content_commitment=False,
|
||||
key_encipherment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
key_cert_sign=True,
|
||||
crl_sign=True,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.SubjectKeyIdentifier.from_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
)
|
||||
|
||||
builder = builder.add_extension(
|
||||
x509.AuthorityKeyIdentifier.from_issuer_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
|
||||
ca_cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA384())
|
||||
|
||||
_CA_KEY_FILE.write_bytes(
|
||||
ca_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
)
|
||||
_CA_CERT_FILE.write_bytes(ca_cert.public_bytes(serialization.Encoding.PEM))
|
||||
|
||||
_tighten_permissions(_CA_KEY_FILE)
|
||||
_tighten_permissions(_CA_CERT_FILE)
|
||||
else:
|
||||
regenerated = False
|
||||
|
||||
return ca_key, ca_cert, regenerated
|
||||
|
||||
|
||||
def _server_certificate_needs_regeneration(
|
||||
server_cert: Optional[x509.Certificate],
|
||||
ca_cert: x509.Certificate,
|
||||
) -> bool:
|
||||
if server_cert is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
if server_cert.issuer != ca_cert.subject:
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
try:
|
||||
expiry = _cert_not_after(server_cert)
|
||||
if expiry <= datetime.now(tz=timezone.utc):
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
try:
|
||||
basic = server_cert.extensions.get_extension_for_class(x509.BasicConstraints).value # type: ignore[attr-defined]
|
||||
if basic.ca:
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
try:
|
||||
eku = server_cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value # type: ignore[attr-defined]
|
||||
if ExtendedKeyUsageOID.SERVER_AUTH not in eku:
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _generate_server_certificate(
|
||||
common_name: str,
|
||||
ca_key: ec.EllipticCurvePrivateKey,
|
||||
ca_cert: x509.Certificate,
|
||||
) -> x509.Certificate:
|
||||
private_key = ec.generate_private_key(ec.SECP384R1())
|
||||
public_key = private_key.public_key()
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
ca_expiry = _cert_not_after(ca_cert)
|
||||
candidate_expiry = now + _SERVER_VALIDITY
|
||||
not_after = min(ca_expiry - timedelta(days=1), candidate_expiry)
|
||||
|
||||
builder = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
|
||||
]
|
||||
)
|
||||
)
|
||||
.issuer_name(ca_cert.subject)
|
||||
.public_key(public_key)
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now - timedelta(minutes=5))
|
||||
.not_valid_after(not_after)
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName(
|
||||
[
|
||||
x509.DNSName("localhost"),
|
||||
x509.DNSName("127.0.0.1"),
|
||||
x509.DNSName("::1"),
|
||||
]
|
||||
),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
content_commitment=False,
|
||||
key_encipherment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
key_cert_sign=False,
|
||||
crl_sign=False,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.SubjectKeyIdentifier.from_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_key.public_key()),
|
||||
critical=False,
|
||||
)
|
||||
)
|
||||
|
||||
certificate = builder.sign(private_key=ca_key, algorithm=hashes.SHA384())
|
||||
|
||||
_KEY_FILE.write_bytes(
|
||||
private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
)
|
||||
_CERT_FILE.write_bytes(certificate.public_bytes(serialization.Encoding.PEM))
|
||||
|
||||
_tighten_permissions(_KEY_FILE)
|
||||
_tighten_permissions(_CERT_FILE)
|
||||
|
||||
return certificate
|
||||
|
||||
|
||||
def _write_bundle(server_cert: x509.Certificate, ca_cert: x509.Certificate) -> None:
|
||||
try:
|
||||
server_pem = server_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8").strip()
|
||||
ca_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8").strip()
|
||||
except Exception:
|
||||
return
|
||||
|
||||
bundle = f"{server_pem}\n{ca_pem}\n"
|
||||
_BUNDLE_FILE.write_text(bundle, encoding="utf-8")
|
||||
_tighten_permissions(_BUNDLE_FILE)
|
||||
|
||||
|
||||
def _safe_copy(src: Path, dst: Path) -> None:
|
||||
try:
|
||||
dst.write_bytes(src.read_bytes())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _tighten_permissions(path: Path) -> None:
|
||||
try:
|
||||
if os.name == "posix":
|
||||
path.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _load_private_key(path: Path) -> ec.EllipticCurvePrivateKey:
|
||||
with path.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
|
||||
def _load_certificate(path: Path) -> Optional[x509.Certificate]:
|
||||
try:
|
||||
return x509.load_pem_x509_certificate(path.read_bytes())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _cert_not_after(cert: x509.Certificate) -> datetime:
|
||||
try:
|
||||
return cert.not_valid_after_utc # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
value = cert.not_valid_after
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=timezone.utc)
|
||||
return value
|
||||
|
||||
|
||||
def build_ssl_context() -> ssl.SSLContext:
|
||||
cert_path, key_path, bundle_path = ensure_certificate()
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.minimum_version = ssl.TLSVersion.TLSv1_3
|
||||
context.load_cert_chain(certfile=str(bundle_path), keyfile=str(key_path))
|
||||
return context
|
||||
|
||||
|
||||
def certificate_paths() -> Tuple[str, str, str]:
|
||||
cert_path, key_path, bundle_path = ensure_certificate()
|
||||
return str(cert_path), str(key_path), str(bundle_path)
|
||||
@@ -1,75 +0,0 @@
|
||||
"""Script signing utilities for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from Data.Engine.integrations.crypto.keys import base64_from_spki_der
|
||||
from Data.Engine.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path
|
||||
|
||||
__all__ = ["ScriptSigner", "load_signer"]
|
||||
|
||||
|
||||
_KEY_DIR = server_certificates_path("Code-Signing")
|
||||
_SIGNING_KEY_FILE = _KEY_DIR / "engine-script-ed25519.key"
|
||||
_SIGNING_PUB_FILE = _KEY_DIR / "engine-script-ed25519.pub"
|
||||
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-script-ed25519.key"
|
||||
_LEGACY_PUB_FILE = runtime_path("keys") / "borealis-script-ed25519.pub"
|
||||
|
||||
|
||||
class ScriptSigner:
|
||||
def __init__(self, private_key: ed25519.Ed25519PrivateKey) -> None:
|
||||
self._private = private_key
|
||||
self._public = private_key.public_key()
|
||||
|
||||
def sign(self, payload: bytes) -> bytes:
|
||||
return self._private.sign(payload)
|
||||
|
||||
def public_spki_der(self) -> bytes:
|
||||
return self._public.public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def public_base64_spki(self) -> str:
|
||||
return base64_from_spki_der(self.public_spki_der())
|
||||
|
||||
|
||||
def load_signer() -> ScriptSigner:
|
||||
private_key = _load_or_create()
|
||||
return ScriptSigner(private_key)
|
||||
|
||||
|
||||
def _load_or_create() -> ed25519.Ed25519PrivateKey:
|
||||
ensure_server_certificates_dir("Code-Signing")
|
||||
|
||||
if _SIGNING_KEY_FILE.exists():
|
||||
with _SIGNING_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
if _LEGACY_KEY_FILE.exists():
|
||||
with _LEGACY_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
_KEY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
_SIGNING_KEY_FILE.write_bytes(pem)
|
||||
try:
|
||||
if hasattr(_SIGNING_KEY_FILE, "chmod"):
|
||||
_SIGNING_KEY_FILE.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
pub_der = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
_SIGNING_PUB_FILE.write_bytes(pub_der)
|
||||
|
||||
return private_key
|
||||
@@ -1,15 +0,0 @@
|
||||
from .device_inventory_service import (
|
||||
DeviceDescriptionError,
|
||||
DeviceDetailsError,
|
||||
DeviceInventoryService,
|
||||
RemoteDeviceError,
|
||||
)
|
||||
from .device_view_service import DeviceViewService
|
||||
|
||||
__all__ = [
|
||||
"DeviceInventoryService",
|
||||
"RemoteDeviceError",
|
||||
"DeviceViewService",
|
||||
"DeviceDetailsError",
|
||||
"DeviceDescriptionError",
|
||||
]
|
||||
@@ -1,575 +0,0 @@
|
||||
"""Mirrors the legacy device inventory HTTP behaviour."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from Data.Engine.repositories.sqlite.device_inventory_repository import (
|
||||
SQLiteDeviceInventoryRepository,
|
||||
)
|
||||
from Data.Engine.domain.device_auth import DeviceAuthContext, normalize_guid
|
||||
from Data.Engine.domain.devices import clean_device_str, coerce_int, ts_to_human
|
||||
|
||||
__all__ = [
|
||||
"DeviceInventoryService",
|
||||
"RemoteDeviceError",
|
||||
"DeviceHeartbeatError",
|
||||
"DeviceDetailsError",
|
||||
"DeviceDescriptionError",
|
||||
]
|
||||
|
||||
|
||||
class RemoteDeviceError(Exception):
|
||||
def __init__(self, code: str, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or code)
|
||||
self.code = code
|
||||
|
||||
|
||||
class DeviceHeartbeatError(Exception):
|
||||
def __init__(self, code: str, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or code)
|
||||
self.code = code
|
||||
|
||||
|
||||
class DeviceDetailsError(Exception):
|
||||
def __init__(self, code: str, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or code)
|
||||
self.code = code
|
||||
|
||||
|
||||
class DeviceDescriptionError(Exception):
|
||||
def __init__(self, code: str, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or code)
|
||||
self.code = code
|
||||
|
||||
|
||||
class DeviceInventoryService:
|
||||
def __init__(
|
||||
self,
|
||||
repository: SQLiteDeviceInventoryRepository,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._repo = repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.devices")
|
||||
|
||||
def list_devices(self) -> List[Dict[str, object]]:
|
||||
return self._repo.fetch_devices()
|
||||
|
||||
def list_agent_devices(self) -> List[Dict[str, object]]:
|
||||
return self._repo.fetch_devices(only_agents=True)
|
||||
|
||||
def list_remote_devices(self, connection_type: str) -> List[Dict[str, object]]:
|
||||
return self._repo.fetch_devices(connection_type=connection_type)
|
||||
|
||||
def get_device_by_guid(self, guid: str) -> Optional[Dict[str, object]]:
|
||||
snapshot = self._repo.load_snapshot(guid=guid)
|
||||
if not snapshot:
|
||||
return None
|
||||
devices = self._repo.fetch_devices(hostname=snapshot.get("hostname"))
|
||||
return devices[0] if devices else None
|
||||
|
||||
def get_device_details(self, hostname: str) -> Dict[str, object]:
|
||||
normalized_host = clean_device_str(hostname)
|
||||
if not normalized_host:
|
||||
return {}
|
||||
|
||||
snapshot = self._repo.load_snapshot(hostname=normalized_host)
|
||||
if not snapshot:
|
||||
return {}
|
||||
|
||||
summary = dict(snapshot.get("summary") or {})
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"details": snapshot.get("details", {}),
|
||||
"summary": summary,
|
||||
"description": snapshot.get("description")
|
||||
or summary.get("description")
|
||||
or "",
|
||||
"created_at": snapshot.get("created_at") or 0,
|
||||
"agent_hash": snapshot.get("agent_hash")
|
||||
or summary.get("agent_hash")
|
||||
or "",
|
||||
"agent_guid": snapshot.get("agent_guid")
|
||||
or summary.get("agent_guid")
|
||||
or "",
|
||||
"memory": snapshot.get("memory", []),
|
||||
"network": snapshot.get("network", []),
|
||||
"software": snapshot.get("software", []),
|
||||
"storage": snapshot.get("storage", []),
|
||||
"cpu": snapshot.get("cpu", {}),
|
||||
"device_type": snapshot.get("device_type")
|
||||
or summary.get("device_type")
|
||||
or "",
|
||||
"domain": snapshot.get("domain")
|
||||
or summary.get("domain")
|
||||
or "",
|
||||
"external_ip": snapshot.get("external_ip")
|
||||
or summary.get("external_ip")
|
||||
or "",
|
||||
"internal_ip": snapshot.get("internal_ip")
|
||||
or summary.get("internal_ip")
|
||||
or "",
|
||||
"last_reboot": snapshot.get("last_reboot")
|
||||
or summary.get("last_reboot")
|
||||
or "",
|
||||
"last_seen": snapshot.get("last_seen")
|
||||
or summary.get("last_seen")
|
||||
or 0,
|
||||
"last_user": snapshot.get("last_user")
|
||||
or summary.get("last_user")
|
||||
or "",
|
||||
"operating_system": snapshot.get("operating_system")
|
||||
or summary.get("operating_system")
|
||||
or summary.get("agent_operating_system")
|
||||
or "",
|
||||
"uptime": snapshot.get("uptime")
|
||||
or summary.get("uptime")
|
||||
or 0,
|
||||
"agent_id": snapshot.get("agent_id")
|
||||
or summary.get("agent_id")
|
||||
or "",
|
||||
}
|
||||
|
||||
return payload
|
||||
|
||||
def collect_agent_hash_records(self) -> List[Dict[str, object]]:
|
||||
records: List[Dict[str, object]] = []
|
||||
key_to_index: Dict[str, int] = {}
|
||||
|
||||
for device in self._repo.fetch_devices():
|
||||
summary = device.get("summary", {}) if isinstance(device, dict) else {}
|
||||
agent_id = (summary.get("agent_id") or "").strip()
|
||||
agent_guid = (summary.get("agent_guid") or "").strip()
|
||||
hostname = (summary.get("hostname") or device.get("hostname") or "").strip()
|
||||
agent_hash = (summary.get("agent_hash") or device.get("agent_hash") or "").strip()
|
||||
|
||||
keys: List[str] = []
|
||||
if agent_id:
|
||||
keys.append(f"id:{agent_id.lower()}")
|
||||
if agent_guid:
|
||||
keys.append(f"guid:{agent_guid.lower()}")
|
||||
if hostname:
|
||||
keys.append(f"host:{hostname.lower()}")
|
||||
|
||||
payload = {
|
||||
"agent_id": agent_id or None,
|
||||
"agent_guid": agent_guid or None,
|
||||
"hostname": hostname or None,
|
||||
"agent_hash": agent_hash or None,
|
||||
"source": "database",
|
||||
}
|
||||
|
||||
if not keys:
|
||||
records.append(payload)
|
||||
continue
|
||||
|
||||
existing_index = None
|
||||
for key in keys:
|
||||
if key in key_to_index:
|
||||
existing_index = key_to_index[key]
|
||||
break
|
||||
|
||||
if existing_index is None:
|
||||
existing_index = len(records)
|
||||
records.append(payload)
|
||||
for key in keys:
|
||||
key_to_index[key] = existing_index
|
||||
continue
|
||||
|
||||
merged = records[existing_index]
|
||||
for key in ("agent_id", "agent_guid", "hostname", "agent_hash"):
|
||||
if not merged.get(key) and payload.get(key):
|
||||
merged[key] = payload[key]
|
||||
|
||||
return records
|
||||
|
||||
def upsert_remote_device(
|
||||
self,
|
||||
connection_type: str,
|
||||
hostname: str,
|
||||
address: Optional[str],
|
||||
description: Optional[str],
|
||||
os_hint: Optional[str],
|
||||
*,
|
||||
ensure_existing_type: Optional[str],
|
||||
) -> Dict[str, object]:
|
||||
normalized_type = (connection_type or "").strip().lower()
|
||||
if not normalized_type:
|
||||
raise RemoteDeviceError("invalid_type", "connection type required")
|
||||
normalized_host = (hostname or "").strip()
|
||||
if not normalized_host:
|
||||
raise RemoteDeviceError("invalid_hostname", "hostname is required")
|
||||
|
||||
existing = self._repo.load_snapshot(hostname=normalized_host)
|
||||
existing_type = (existing or {}).get("summary", {}).get("connection_type") or ""
|
||||
existing_type = existing_type.strip().lower()
|
||||
|
||||
if ensure_existing_type and existing_type != ensure_existing_type.lower():
|
||||
raise RemoteDeviceError("not_found", "device not found")
|
||||
if ensure_existing_type is None and existing_type and existing_type != normalized_type:
|
||||
raise RemoteDeviceError("conflict", "device already exists with different connection type")
|
||||
|
||||
created_ts = None
|
||||
if existing:
|
||||
created_ts = existing.get("summary", {}).get("created_at")
|
||||
|
||||
endpoint = (address or "").strip() or (existing or {}).get("summary", {}).get("connection_endpoint") or ""
|
||||
if not endpoint:
|
||||
raise RemoteDeviceError("address_required", "address is required")
|
||||
|
||||
description_val = description if description is not None else (existing or {}).get("summary", {}).get("description")
|
||||
os_value = os_hint or (existing or {}).get("summary", {}).get("operating_system")
|
||||
os_value = (os_value or "").strip()
|
||||
|
||||
device_type_label = "SSH Remote" if normalized_type == "ssh" else "WinRM Remote"
|
||||
|
||||
summary_payload = {
|
||||
"connection_type": normalized_type,
|
||||
"connection_endpoint": endpoint,
|
||||
"internal_ip": endpoint,
|
||||
"external_ip": endpoint,
|
||||
"device_type": device_type_label,
|
||||
"operating_system": os_value or "",
|
||||
"last_seen": 0,
|
||||
"description": (description_val or ""),
|
||||
}
|
||||
|
||||
try:
|
||||
self._repo.upsert_device(
|
||||
normalized_host,
|
||||
description_val,
|
||||
{"summary": summary_payload},
|
||||
created_ts,
|
||||
)
|
||||
except sqlite3.DatabaseError as exc: # type: ignore[name-defined]
|
||||
raise RemoteDeviceError("storage_error", str(exc)) from exc
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
raise RemoteDeviceError("storage_error", str(exc)) from exc
|
||||
|
||||
devices = self._repo.fetch_devices(hostname=normalized_host)
|
||||
if not devices:
|
||||
raise RemoteDeviceError("reload_failed", "failed to load device after upsert")
|
||||
return devices[0]
|
||||
|
||||
def delete_remote_device(self, connection_type: str, hostname: str) -> None:
|
||||
normalized_host = (hostname or "").strip()
|
||||
if not normalized_host:
|
||||
raise RemoteDeviceError("invalid_hostname", "invalid hostname")
|
||||
existing = self._repo.load_snapshot(hostname=normalized_host)
|
||||
if not existing:
|
||||
raise RemoteDeviceError("not_found", "device not found")
|
||||
existing_type = (existing.get("summary", {}) or {}).get("connection_type") or ""
|
||||
if (existing_type or "").strip().lower() != (connection_type or "").strip().lower():
|
||||
raise RemoteDeviceError("not_found", "device not found")
|
||||
self._repo.delete_device_by_hostname(normalized_host)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Agent heartbeats
|
||||
# ------------------------------------------------------------------
|
||||
def record_heartbeat(
|
||||
self,
|
||||
*,
|
||||
context: DeviceAuthContext,
|
||||
payload: Mapping[str, Any],
|
||||
) -> None:
|
||||
guid = context.identity.guid.value
|
||||
snapshot = self._repo.load_snapshot(guid=guid)
|
||||
if not snapshot:
|
||||
raise DeviceHeartbeatError("device_not_registered", "device not registered")
|
||||
|
||||
summary = dict(snapshot.get("summary") or {})
|
||||
details = dict(snapshot.get("details") or {})
|
||||
|
||||
now_ts = int(time.time())
|
||||
summary["last_seen"] = now_ts
|
||||
summary["agent_guid"] = guid
|
||||
|
||||
existing_hostname = clean_device_str(summary.get("hostname")) or clean_device_str(
|
||||
snapshot.get("hostname")
|
||||
)
|
||||
incoming_hostname = clean_device_str(payload.get("hostname"))
|
||||
raw_metrics = payload.get("metrics")
|
||||
metrics = raw_metrics if isinstance(raw_metrics, Mapping) else {}
|
||||
metrics_hostname = clean_device_str(metrics.get("hostname")) if metrics else None
|
||||
hostname = incoming_hostname or metrics_hostname or existing_hostname
|
||||
if not hostname:
|
||||
hostname = f"RECOVERED-{guid[:12]}"
|
||||
summary["hostname"] = hostname
|
||||
|
||||
if metrics:
|
||||
last_user = metrics.get("last_user") or metrics.get("username") or metrics.get("user")
|
||||
if last_user:
|
||||
cleaned_user = clean_device_str(last_user)
|
||||
if cleaned_user:
|
||||
summary["last_user"] = cleaned_user
|
||||
operating_system = metrics.get("operating_system")
|
||||
if operating_system:
|
||||
cleaned_os = clean_device_str(operating_system)
|
||||
if cleaned_os:
|
||||
summary["operating_system"] = cleaned_os
|
||||
uptime = metrics.get("uptime")
|
||||
if uptime is not None:
|
||||
coerced = coerce_int(uptime)
|
||||
if coerced is not None:
|
||||
summary["uptime"] = coerced
|
||||
agent_id = metrics.get("agent_id")
|
||||
if agent_id:
|
||||
cleaned_agent = clean_device_str(agent_id)
|
||||
if cleaned_agent:
|
||||
summary["agent_id"] = cleaned_agent
|
||||
|
||||
for field in ("external_ip", "internal_ip", "device_type"):
|
||||
value = payload.get(field)
|
||||
cleaned = clean_device_str(value)
|
||||
if cleaned:
|
||||
summary[field] = cleaned
|
||||
|
||||
summary.setdefault("description", summary.get("description") or "")
|
||||
created_at = coerce_int(summary.get("created_at"))
|
||||
if created_at is None:
|
||||
created_at = coerce_int(snapshot.get("created_at"))
|
||||
if created_at is None:
|
||||
created_at = now_ts
|
||||
summary["created_at"] = created_at
|
||||
|
||||
raw_inventory = payload.get("inventory")
|
||||
inventory = raw_inventory if isinstance(raw_inventory, Mapping) else {}
|
||||
memory = inventory.get("memory") if isinstance(inventory.get("memory"), list) else details.get("memory")
|
||||
network = inventory.get("network") if isinstance(inventory.get("network"), list) else details.get("network")
|
||||
software = (
|
||||
inventory.get("software") if isinstance(inventory.get("software"), list) else details.get("software")
|
||||
)
|
||||
storage = inventory.get("storage") if isinstance(inventory.get("storage"), list) else details.get("storage")
|
||||
cpu = inventory.get("cpu") if isinstance(inventory.get("cpu"), Mapping) else details.get("cpu")
|
||||
|
||||
merged_details: Dict[str, Any] = {
|
||||
"summary": summary,
|
||||
"memory": memory,
|
||||
"network": network,
|
||||
"software": software,
|
||||
"storage": storage,
|
||||
"cpu": cpu,
|
||||
}
|
||||
|
||||
try:
|
||||
self._repo.upsert_device(
|
||||
summary["hostname"],
|
||||
summary.get("description"),
|
||||
merged_details,
|
||||
summary.get("created_at"),
|
||||
agent_hash=clean_device_str(summary.get("agent_hash")),
|
||||
guid=guid,
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
self._log.warning(
|
||||
"device-heartbeat-conflict guid=%s hostname=%s error=%s",
|
||||
guid,
|
||||
summary["hostname"],
|
||||
exc,
|
||||
)
|
||||
raise DeviceHeartbeatError("storage_conflict", str(exc)) from exc
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
self._log.exception(
|
||||
"device-heartbeat-failure guid=%s hostname=%s",
|
||||
guid,
|
||||
summary["hostname"],
|
||||
exc_info=exc,
|
||||
)
|
||||
raise DeviceHeartbeatError("storage_error", "failed to persist heartbeat") from exc
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Agent details
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _is_empty(value: Any) -> bool:
|
||||
return value in (None, "", [], {})
|
||||
|
||||
@classmethod
|
||||
def _deep_merge_preserve(cls, prev: Dict[str, Any], incoming: Dict[str, Any]) -> Dict[str, Any]:
|
||||
merged: Dict[str, Any] = dict(prev or {})
|
||||
for key, value in (incoming or {}).items():
|
||||
if isinstance(value, Mapping):
|
||||
existing = merged.get(key)
|
||||
if not isinstance(existing, Mapping):
|
||||
existing = {}
|
||||
merged[key] = cls._deep_merge_preserve(dict(existing), dict(value))
|
||||
elif isinstance(value, list):
|
||||
if value:
|
||||
merged[key] = value
|
||||
else:
|
||||
if cls._is_empty(value):
|
||||
continue
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
def save_agent_details(
|
||||
self,
|
||||
*,
|
||||
context: DeviceAuthContext,
|
||||
payload: Mapping[str, Any],
|
||||
) -> None:
|
||||
hostname = clean_device_str(payload.get("hostname"))
|
||||
details_raw = payload.get("details")
|
||||
agent_id = clean_device_str(payload.get("agent_id"))
|
||||
agent_hash = clean_device_str(payload.get("agent_hash"))
|
||||
|
||||
if not isinstance(details_raw, Mapping):
|
||||
raise DeviceDetailsError("invalid_payload", "details object required")
|
||||
|
||||
details_dict: Dict[str, Any]
|
||||
try:
|
||||
details_dict = json.loads(json.dumps(details_raw))
|
||||
except Exception:
|
||||
details_dict = dict(details_raw)
|
||||
|
||||
incoming_summary = dict(details_dict.get("summary") or {})
|
||||
if not hostname:
|
||||
hostname = clean_device_str(incoming_summary.get("hostname"))
|
||||
if not hostname:
|
||||
raise DeviceDetailsError("invalid_payload", "hostname required")
|
||||
|
||||
snapshot = self._repo.load_snapshot(hostname=hostname)
|
||||
if not snapshot:
|
||||
snapshot = {}
|
||||
|
||||
previous_details = snapshot.get("details")
|
||||
if isinstance(previous_details, Mapping):
|
||||
try:
|
||||
prev_details = json.loads(json.dumps(previous_details))
|
||||
except Exception:
|
||||
prev_details = dict(previous_details)
|
||||
else:
|
||||
prev_details = {}
|
||||
|
||||
prev_summary = dict(prev_details.get("summary") or {})
|
||||
|
||||
existing_guid = clean_device_str(snapshot.get("guid") or snapshot.get("summary", {}).get("agent_guid"))
|
||||
normalized_existing_guid = normalize_guid(existing_guid)
|
||||
auth_guid = context.identity.guid.value
|
||||
|
||||
if normalized_existing_guid and normalized_existing_guid != auth_guid:
|
||||
raise DeviceDetailsError("guid_mismatch", "device guid mismatch")
|
||||
|
||||
fingerprint = context.identity.fingerprint.value.lower()
|
||||
stored_fp = clean_device_str(snapshot.get("summary", {}).get("ssl_key_fingerprint"))
|
||||
if stored_fp and stored_fp.lower() != fingerprint:
|
||||
raise DeviceDetailsError("fingerprint_mismatch", "device fingerprint mismatch")
|
||||
|
||||
incoming_summary.setdefault("hostname", hostname)
|
||||
if agent_id and not incoming_summary.get("agent_id"):
|
||||
incoming_summary["agent_id"] = agent_id
|
||||
if agent_hash:
|
||||
incoming_summary["agent_hash"] = agent_hash
|
||||
incoming_summary["agent_guid"] = auth_guid
|
||||
if fingerprint:
|
||||
incoming_summary["ssl_key_fingerprint"] = fingerprint
|
||||
if not incoming_summary.get("last_seen") and prev_summary.get("last_seen"):
|
||||
incoming_summary["last_seen"] = prev_summary.get("last_seen")
|
||||
|
||||
details_dict["summary"] = incoming_summary
|
||||
merged_details = self._deep_merge_preserve(prev_details, details_dict)
|
||||
merged_summary = merged_details.setdefault("summary", {})
|
||||
|
||||
if not merged_summary.get("last_user") and prev_summary.get("last_user"):
|
||||
merged_summary["last_user"] = prev_summary.get("last_user")
|
||||
|
||||
created_at = coerce_int(merged_summary.get("created_at"))
|
||||
if created_at is None:
|
||||
created_at = coerce_int(snapshot.get("created_at"))
|
||||
if created_at is None:
|
||||
created_at = int(time.time())
|
||||
merged_summary["created_at"] = created_at
|
||||
if not merged_summary.get("created"):
|
||||
merged_summary["created"] = ts_to_human(created_at)
|
||||
|
||||
if fingerprint:
|
||||
merged_summary["ssl_key_fingerprint"] = fingerprint
|
||||
if not merged_summary.get("key_added_at"):
|
||||
merged_summary["key_added_at"] = datetime.now(timezone.utc).isoformat()
|
||||
if merged_summary.get("token_version") is None:
|
||||
merged_summary["token_version"] = 1
|
||||
if not merged_summary.get("status") and snapshot.get("summary", {}).get("status"):
|
||||
merged_summary["status"] = snapshot.get("summary", {}).get("status")
|
||||
uptime_val = merged_summary.get("uptime")
|
||||
if merged_summary.get("uptime_sec") is None and uptime_val is not None:
|
||||
coerced = coerce_int(uptime_val)
|
||||
if coerced is not None:
|
||||
merged_summary["uptime_sec"] = coerced
|
||||
merged_summary.setdefault("uptime_seconds", coerced)
|
||||
if merged_summary.get("uptime_seconds") is None and merged_summary.get("uptime_sec") is not None:
|
||||
merged_summary["uptime_seconds"] = merged_summary.get("uptime_sec")
|
||||
|
||||
description = clean_device_str(merged_summary.get("description"))
|
||||
existing_description = snapshot.get("description") if snapshot else ""
|
||||
description_to_store = description if description is not None else (existing_description or "")
|
||||
|
||||
existing_hash = clean_device_str(snapshot.get("agent_hash") or snapshot.get("summary", {}).get("agent_hash"))
|
||||
effective_hash = agent_hash or existing_hash
|
||||
|
||||
try:
|
||||
self._repo.upsert_device(
|
||||
hostname,
|
||||
description_to_store,
|
||||
merged_details,
|
||||
created_at,
|
||||
agent_hash=effective_hash,
|
||||
guid=auth_guid,
|
||||
)
|
||||
except sqlite3.DatabaseError as exc:
|
||||
raise DeviceDetailsError("storage_error", str(exc)) from exc
|
||||
|
||||
added_at = merged_summary.get("key_added_at") or datetime.now(timezone.utc).isoformat()
|
||||
self._repo.record_device_fingerprint(auth_guid, fingerprint, added_at)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Description management
|
||||
# ------------------------------------------------------------------
|
||||
def update_device_description(self, hostname: str, description: Optional[str]) -> None:
|
||||
normalized_host = clean_device_str(hostname)
|
||||
if not normalized_host:
|
||||
raise DeviceDescriptionError("invalid_hostname", "invalid hostname")
|
||||
|
||||
snapshot = self._repo.load_snapshot(hostname=normalized_host)
|
||||
if not snapshot:
|
||||
raise DeviceDescriptionError("not_found", "device not found")
|
||||
|
||||
details = snapshot.get("details")
|
||||
if isinstance(details, Mapping):
|
||||
try:
|
||||
existing = json.loads(json.dumps(details))
|
||||
except Exception:
|
||||
existing = dict(details)
|
||||
else:
|
||||
existing = {}
|
||||
|
||||
summary = dict(existing.get("summary") or {})
|
||||
summary["description"] = description or ""
|
||||
existing["summary"] = summary
|
||||
|
||||
created_at = coerce_int(summary.get("created_at"))
|
||||
if created_at is None:
|
||||
created_at = coerce_int(snapshot.get("created_at"))
|
||||
if created_at is None:
|
||||
created_at = int(time.time())
|
||||
|
||||
agent_hash = clean_device_str(summary.get("agent_hash") or snapshot.get("agent_hash"))
|
||||
guid = clean_device_str(summary.get("agent_guid") or snapshot.get("guid"))
|
||||
|
||||
try:
|
||||
self._repo.upsert_device(
|
||||
normalized_host,
|
||||
description or (snapshot.get("description") or ""),
|
||||
existing,
|
||||
created_at,
|
||||
agent_hash=agent_hash,
|
||||
guid=guid,
|
||||
)
|
||||
except sqlite3.DatabaseError as exc:
|
||||
raise DeviceDescriptionError("storage_error", str(exc)) from exc
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Service exposing CRUD for saved device list views."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from Data.Engine.domain.device_views import DeviceListView
|
||||
from Data.Engine.repositories.sqlite.device_view_repository import SQLiteDeviceViewRepository
|
||||
|
||||
__all__ = ["DeviceViewService"]
|
||||
|
||||
|
||||
class DeviceViewService:
|
||||
def __init__(
|
||||
self,
|
||||
repository: SQLiteDeviceViewRepository,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._repo = repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.device_views")
|
||||
|
||||
def list_views(self) -> List[DeviceListView]:
|
||||
return self._repo.list_views()
|
||||
|
||||
def get_view(self, view_id: int) -> Optional[DeviceListView]:
|
||||
return self._repo.get_view(view_id)
|
||||
|
||||
def create_view(self, name: str, columns: List[str], filters: dict) -> DeviceListView:
|
||||
normalized_name = (name or "").strip()
|
||||
if not normalized_name:
|
||||
raise ValueError("missing_name")
|
||||
if normalized_name.lower() == "default view":
|
||||
raise ValueError("reserved")
|
||||
return self._repo.create_view(normalized_name, list(columns), dict(filters))
|
||||
|
||||
def update_view(
|
||||
self,
|
||||
view_id: int,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
columns: Optional[List[str]] = None,
|
||||
filters: Optional[dict] = None,
|
||||
) -> DeviceListView:
|
||||
updates: dict = {}
|
||||
if name is not None:
|
||||
normalized = (name or "").strip()
|
||||
if not normalized:
|
||||
raise ValueError("missing_name")
|
||||
if normalized.lower() == "default view":
|
||||
raise ValueError("reserved")
|
||||
updates["name"] = normalized
|
||||
if columns is not None:
|
||||
if not isinstance(columns, list) or not all(isinstance(col, str) for col in columns):
|
||||
raise ValueError("invalid_columns")
|
||||
updates["columns"] = list(columns)
|
||||
if filters is not None:
|
||||
if not isinstance(filters, dict):
|
||||
raise ValueError("invalid_filters")
|
||||
updates["filters"] = dict(filters)
|
||||
if not updates:
|
||||
raise ValueError("no_fields")
|
||||
return self._repo.update_view(
|
||||
view_id,
|
||||
name=updates.get("name"),
|
||||
columns=updates.get("columns"),
|
||||
filters=updates.get("filters"),
|
||||
)
|
||||
|
||||
def delete_view(self, view_id: int) -> bool:
|
||||
return self._repo.delete_view(view_id)
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
"""Enrollment services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentService",
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"PollingResult",
|
||||
"EnrollmentValidationError",
|
||||
"EnrollmentAdminService",
|
||||
]
|
||||
|
||||
_LAZY: dict[str, tuple[str, str]] = {
|
||||
"EnrollmentService": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentService"),
|
||||
"EnrollmentRequestResult": (
|
||||
"Data.Engine.services.enrollment.enrollment_service",
|
||||
"EnrollmentRequestResult",
|
||||
),
|
||||
"EnrollmentStatus": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentStatus"),
|
||||
"EnrollmentTokenBundle": (
|
||||
"Data.Engine.services.enrollment.enrollment_service",
|
||||
"EnrollmentTokenBundle",
|
||||
),
|
||||
"PollingResult": ("Data.Engine.services.enrollment.enrollment_service", "PollingResult"),
|
||||
"EnrollmentValidationError": (
|
||||
"Data.Engine.domain.device_enrollment",
|
||||
"EnrollmentValidationError",
|
||||
),
|
||||
"EnrollmentAdminService": (
|
||||
"Data.Engine.services.enrollment.admin_service",
|
||||
"EnrollmentAdminService",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
try:
|
||||
module_name, attribute = _LAZY[name]
|
||||
except KeyError as exc: # pragma: no cover - defensive
|
||||
raise AttributeError(name) from exc
|
||||
|
||||
module = import_module(module_name)
|
||||
value = getattr(module, attribute)
|
||||
globals()[name] = value
|
||||
return value
|
||||
|
||||
|
||||
def __dir__() -> list[str]: # pragma: no cover - interactive helper
|
||||
return sorted(set(__all__))
|
||||
|
||||
@@ -1,245 +0,0 @@
|
||||
"""Administrative helpers for enrollment workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from Data.Engine.domain.device_auth import DeviceGuid, normalize_guid
|
||||
from Data.Engine.domain.device_enrollment import EnrollmentApprovalStatus
|
||||
from Data.Engine.domain.enrollment_admin import DeviceApprovalRecord, EnrollmentCodeRecord
|
||||
from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository
|
||||
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
|
||||
|
||||
__all__ = ["EnrollmentAdminService", "DeviceApprovalActionResult"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceApprovalActionResult:
|
||||
"""Outcome metadata returned after mutating an approval."""
|
||||
|
||||
status: str
|
||||
conflict_resolution: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
payload = {"status": self.status}
|
||||
if self.conflict_resolution:
|
||||
payload["conflict_resolution"] = self.conflict_resolution
|
||||
return payload
|
||||
|
||||
|
||||
class EnrollmentAdminService:
|
||||
"""Expose administrative enrollment operations."""
|
||||
|
||||
_VALID_TTL_HOURS = {1, 3, 6, 12, 24}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: SQLiteEnrollmentRepository,
|
||||
user_repository: SQLiteUserRepository,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
clock: Optional[Callable[[], datetime]] = None,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._users = user_repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.enrollment_admin")
|
||||
self._clock = clock or (lambda: datetime.now(tz=timezone.utc))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Enrollment install codes
|
||||
# ------------------------------------------------------------------
|
||||
def list_install_codes(self, *, status: Optional[str] = None) -> List[EnrollmentCodeRecord]:
|
||||
return self._repository.list_install_codes(status=status, now=self._clock())
|
||||
|
||||
def create_install_code(
|
||||
self,
|
||||
*,
|
||||
ttl_hours: int,
|
||||
max_uses: int,
|
||||
created_by: Optional[str],
|
||||
) -> EnrollmentCodeRecord:
|
||||
if ttl_hours not in self._VALID_TTL_HOURS:
|
||||
raise ValueError("invalid_ttl")
|
||||
|
||||
normalized_max = self._normalize_max_uses(max_uses)
|
||||
|
||||
now = self._clock()
|
||||
expires_at = now + timedelta(hours=ttl_hours)
|
||||
record_id = str(uuid.uuid4())
|
||||
code = self._generate_install_code()
|
||||
|
||||
created_by_identifier = None
|
||||
if created_by:
|
||||
created_by_identifier = self._users.resolve_identifier(created_by)
|
||||
if not created_by_identifier:
|
||||
created_by_identifier = created_by.strip() or None
|
||||
|
||||
record = self._repository.insert_install_code(
|
||||
record_id=record_id,
|
||||
code=code,
|
||||
expires_at=expires_at,
|
||||
created_by=created_by_identifier,
|
||||
max_uses=normalized_max,
|
||||
)
|
||||
|
||||
self._log.info(
|
||||
"install code created id=%s ttl=%sh max_uses=%s",
|
||||
record.record_id,
|
||||
ttl_hours,
|
||||
normalized_max,
|
||||
)
|
||||
|
||||
return record
|
||||
|
||||
def delete_install_code(self, record_id: str) -> bool:
|
||||
deleted = self._repository.delete_install_code_if_unused(record_id)
|
||||
if deleted:
|
||||
self._log.info("install code deleted id=%s", record_id)
|
||||
return deleted
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Device approvals
|
||||
# ------------------------------------------------------------------
|
||||
def list_device_approvals(self, *, status: Optional[str] = None) -> List[DeviceApprovalRecord]:
|
||||
return self._repository.list_device_approvals(status=status)
|
||||
|
||||
def approve_device_approval(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
actor: Optional[str],
|
||||
guid: Optional[str] = None,
|
||||
conflict_resolution: Optional[str] = None,
|
||||
) -> DeviceApprovalActionResult:
|
||||
return self._set_device_approval_status(
|
||||
record_id,
|
||||
EnrollmentApprovalStatus.APPROVED,
|
||||
actor=actor,
|
||||
guid=guid,
|
||||
conflict_resolution=conflict_resolution,
|
||||
)
|
||||
|
||||
def deny_device_approval(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
actor: Optional[str],
|
||||
) -> DeviceApprovalActionResult:
|
||||
return self._set_device_approval_status(
|
||||
record_id,
|
||||
EnrollmentApprovalStatus.DENIED,
|
||||
actor=actor,
|
||||
guid=None,
|
||||
conflict_resolution=None,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _generate_install_code() -> str:
|
||||
raw = secrets.token_hex(16).upper()
|
||||
return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4))
|
||||
|
||||
@staticmethod
|
||||
def _normalize_max_uses(value: int) -> int:
|
||||
try:
|
||||
count = int(value)
|
||||
except Exception:
|
||||
count = 2
|
||||
if count < 1:
|
||||
return 1
|
||||
if count > 10:
|
||||
return 10
|
||||
return count
|
||||
|
||||
def _set_device_approval_status(
|
||||
self,
|
||||
record_id: str,
|
||||
status: EnrollmentApprovalStatus,
|
||||
*,
|
||||
actor: Optional[str],
|
||||
guid: Optional[str],
|
||||
conflict_resolution: Optional[str],
|
||||
) -> DeviceApprovalActionResult:
|
||||
approval = self._repository.fetch_device_approval(record_id)
|
||||
if approval is None:
|
||||
raise LookupError("not_found")
|
||||
|
||||
if approval.status is not EnrollmentApprovalStatus.PENDING:
|
||||
raise ValueError("approval_not_pending")
|
||||
|
||||
normalized_guid = normalize_guid(guid) or (approval.guid.value if approval.guid else "")
|
||||
resolution_normalized = (conflict_resolution or "").strip().lower() or None
|
||||
|
||||
fingerprint_match = False
|
||||
conflict_guid: Optional[str] = None
|
||||
|
||||
if status is EnrollmentApprovalStatus.APPROVED:
|
||||
pending_records = self._repository.list_device_approvals(status="pending")
|
||||
current_record = next(
|
||||
(record for record in pending_records if record.record_id == approval.record_id),
|
||||
None,
|
||||
)
|
||||
|
||||
conflict = current_record.hostname_conflict if current_record else None
|
||||
if conflict:
|
||||
conflict_guid = normalize_guid(conflict.guid)
|
||||
fingerprint_match = bool(conflict.fingerprint_match)
|
||||
|
||||
if fingerprint_match:
|
||||
normalized_guid = conflict_guid or normalized_guid or ""
|
||||
if resolution_normalized is None:
|
||||
resolution_normalized = "auto_merge_fingerprint"
|
||||
elif resolution_normalized == "overwrite":
|
||||
normalized_guid = conflict_guid or normalized_guid or ""
|
||||
elif resolution_normalized == "coexist":
|
||||
pass
|
||||
else:
|
||||
raise ValueError("conflict_resolution_required")
|
||||
|
||||
if normalized_guid:
|
||||
try:
|
||||
guid_value = DeviceGuid(normalized_guid)
|
||||
except ValueError as exc:
|
||||
raise ValueError("invalid_guid") from exc
|
||||
else:
|
||||
guid_value = None
|
||||
|
||||
actor_identifier = None
|
||||
if actor:
|
||||
actor_identifier = self._users.resolve_identifier(actor)
|
||||
if not actor_identifier:
|
||||
actor_identifier = actor.strip() or None
|
||||
if not actor_identifier:
|
||||
actor_identifier = "system"
|
||||
|
||||
self._repository.update_device_approval_status(
|
||||
approval.record_id,
|
||||
status=status,
|
||||
updated_at=self._clock(),
|
||||
approved_by=actor_identifier,
|
||||
guid=guid_value,
|
||||
)
|
||||
|
||||
if status is EnrollmentApprovalStatus.APPROVED:
|
||||
self._log.info(
|
||||
"device approval %s approved resolution=%s guid=%s",
|
||||
approval.record_id,
|
||||
resolution_normalized or "",
|
||||
guid_value.value if guid_value else normalized_guid or "",
|
||||
)
|
||||
else:
|
||||
self._log.info("device approval %s denied", approval.record_id)
|
||||
|
||||
return DeviceApprovalActionResult(
|
||||
status=status.value,
|
||||
conflict_resolution=resolution_normalized,
|
||||
)
|
||||
|
||||
@@ -1,487 +0,0 @@
|
||||
"""Enrollment workflow orchestration for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable, Optional, Protocol
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
from Data.Engine.builders.device_enrollment import EnrollmentRequestInput
|
||||
from Data.Engine.domain.device_auth import (
|
||||
DeviceFingerprint,
|
||||
DeviceGuid,
|
||||
sanitize_service_context,
|
||||
)
|
||||
from Data.Engine.domain.device_enrollment import (
|
||||
EnrollmentApproval,
|
||||
EnrollmentApprovalStatus,
|
||||
EnrollmentCode,
|
||||
EnrollmentValidationError,
|
||||
)
|
||||
from Data.Engine.services.auth.device_auth_service import DeviceRecord
|
||||
from Data.Engine.services.auth.token_service import JWTIssuer
|
||||
from Data.Engine.services.enrollment.nonce_cache import NonceCache
|
||||
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentService",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"PollingResult",
|
||||
]
|
||||
|
||||
|
||||
class DeviceRepository(Protocol):
|
||||
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def ensure_device_record(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
hostname: str,
|
||||
fingerprint: DeviceFingerprint,
|
||||
) -> DeviceRecord: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def record_device_key(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
fingerprint: DeviceFingerprint,
|
||||
added_at: datetime,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class EnrollmentRepository(Protocol):
|
||||
def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_install_code_usage(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
use_count_increment: int,
|
||||
last_used_at: datetime,
|
||||
used_by_guid: Optional[DeviceGuid] = None,
|
||||
mark_first_use: bool = False,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_pending_approval_by_fingerprint(
|
||||
self, fingerprint: DeviceFingerprint
|
||||
) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_pending_approval(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
hostname: str,
|
||||
guid: Optional[DeviceGuid],
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
updated_at: datetime,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def create_device_approval(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
reference: str,
|
||||
claimed_hostname: str,
|
||||
claimed_fingerprint: DeviceFingerprint,
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
created_at: datetime,
|
||||
status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> EnrollmentApproval: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_device_approval_status(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
status: EnrollmentApprovalStatus,
|
||||
updated_at: datetime,
|
||||
approved_by: Optional[str] = None,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class RefreshTokenRepository(Protocol):
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
guid: DeviceGuid,
|
||||
token_hash: str,
|
||||
created_at: datetime,
|
||||
expires_at: Optional[datetime],
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class ScriptSigner(Protocol):
|
||||
def public_base64_spki(self) -> str: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class EnrollmentStatus(str):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
EXPIRED = "expired"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentTokenBundle:
|
||||
guid: DeviceGuid
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_in: int
|
||||
token_type: str = "Bearer"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentRequestResult:
|
||||
status: EnrollmentStatus
|
||||
server_certificate: str
|
||||
signing_key: str
|
||||
approval_reference: Optional[str] = None
|
||||
server_nonce: Optional[str] = None
|
||||
poll_after_ms: Optional[int] = None
|
||||
http_status: int = 200
|
||||
retry_after: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PollingResult:
|
||||
status: EnrollmentStatus
|
||||
http_status: int
|
||||
poll_after_ms: Optional[int] = None
|
||||
reason: Optional[str] = None
|
||||
detail: Optional[str] = None
|
||||
tokens: Optional[EnrollmentTokenBundle] = None
|
||||
server_certificate: Optional[str] = None
|
||||
signing_key: Optional[str] = None
|
||||
|
||||
|
||||
class EnrollmentService:
|
||||
"""Coordinate the Borealis device enrollment handshake."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
device_repository: DeviceRepository,
|
||||
enrollment_repository: EnrollmentRepository,
|
||||
token_repository: RefreshTokenRepository,
|
||||
jwt_service: JWTIssuer,
|
||||
tls_bundle_loader: Callable[[], str],
|
||||
ip_rate_limiter: SlidingWindowRateLimiter,
|
||||
fingerprint_rate_limiter: SlidingWindowRateLimiter,
|
||||
nonce_cache: NonceCache,
|
||||
script_signer: Optional[ScriptSigner] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._devices = device_repository
|
||||
self._enrollment = enrollment_repository
|
||||
self._tokens = token_repository
|
||||
self._jwt = jwt_service
|
||||
self._load_tls_bundle = tls_bundle_loader
|
||||
self._ip_rate_limiter = ip_rate_limiter
|
||||
self._fp_rate_limiter = fingerprint_rate_limiter
|
||||
self._nonce_cache = nonce_cache
|
||||
self._signer = script_signer
|
||||
self._log = logger or logging.getLogger("borealis.engine.enrollment")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
def request_enrollment(
|
||||
self,
|
||||
payload: EnrollmentRequestInput,
|
||||
*,
|
||||
remote_addr: str,
|
||||
) -> EnrollmentRequestResult:
|
||||
context_hint = sanitize_service_context(payload.service_context)
|
||||
self._log.info(
|
||||
"enrollment-request ip=%s host=%s code_mask=%s", remote_addr, payload.hostname, self._mask_code(payload.enrollment_code)
|
||||
)
|
||||
|
||||
self._enforce_rate_limit(self._ip_rate_limiter, f"ip:{remote_addr}")
|
||||
self._enforce_rate_limit(self._fp_rate_limiter, f"fp:{payload.fingerprint.value}")
|
||||
|
||||
install_code = self._enrollment.fetch_install_code(payload.enrollment_code)
|
||||
reuse_guid = self._determine_reuse_guid(install_code, payload.fingerprint)
|
||||
|
||||
server_nonce_bytes = secrets.token_bytes(32)
|
||||
server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii")
|
||||
|
||||
now = self._now()
|
||||
approval = self._enrollment.fetch_pending_approval_by_fingerprint(payload.fingerprint)
|
||||
if approval:
|
||||
self._enrollment.update_pending_approval(
|
||||
approval.record_id,
|
||||
hostname=payload.hostname,
|
||||
guid=reuse_guid,
|
||||
enrollment_code_id=install_code.identifier if install_code else None,
|
||||
client_nonce_b64=payload.client_nonce_b64,
|
||||
server_nonce_b64=server_nonce_b64,
|
||||
agent_pubkey_der=payload.agent_public_key_der,
|
||||
updated_at=now,
|
||||
)
|
||||
approval_reference = approval.reference
|
||||
else:
|
||||
record_id = str(uuid.uuid4())
|
||||
approval_reference = str(uuid.uuid4())
|
||||
approval = self._enrollment.create_device_approval(
|
||||
record_id=record_id,
|
||||
reference=approval_reference,
|
||||
claimed_hostname=payload.hostname,
|
||||
claimed_fingerprint=payload.fingerprint,
|
||||
enrollment_code_id=install_code.identifier if install_code else None,
|
||||
client_nonce_b64=payload.client_nonce_b64,
|
||||
server_nonce_b64=server_nonce_b64,
|
||||
agent_pubkey_der=payload.agent_public_key_der,
|
||||
created_at=now,
|
||||
guid=reuse_guid,
|
||||
)
|
||||
|
||||
signing_key = self._signer.public_base64_spki() if self._signer else ""
|
||||
certificate = self._load_tls_bundle()
|
||||
|
||||
return EnrollmentRequestResult(
|
||||
status=EnrollmentStatus.PENDING,
|
||||
approval_reference=approval.reference,
|
||||
server_nonce=server_nonce_b64,
|
||||
poll_after_ms=3000,
|
||||
server_certificate=certificate,
|
||||
signing_key=signing_key,
|
||||
)
|
||||
|
||||
def poll_enrollment(
|
||||
self,
|
||||
*,
|
||||
approval_reference: str,
|
||||
client_nonce_b64: str,
|
||||
proof_signature_b64: str,
|
||||
) -> PollingResult:
|
||||
if not approval_reference:
|
||||
raise EnrollmentValidationError("approval_reference_required")
|
||||
if not client_nonce_b64:
|
||||
raise EnrollmentValidationError("client_nonce_required")
|
||||
if not proof_signature_b64:
|
||||
raise EnrollmentValidationError("proof_sig_required")
|
||||
|
||||
approval = self._enrollment.fetch_device_approval_by_reference(approval_reference)
|
||||
if approval is None:
|
||||
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=404)
|
||||
|
||||
client_nonce = self._decode_base64(client_nonce_b64, "invalid_client_nonce")
|
||||
server_nonce = self._decode_base64(approval.server_nonce_b64, "server_nonce_invalid")
|
||||
proof_sig = self._decode_base64(proof_signature_b64, "invalid_proof_sig")
|
||||
|
||||
if approval.client_nonce_b64 != client_nonce_b64:
|
||||
raise EnrollmentValidationError("nonce_mismatch")
|
||||
|
||||
self._verify_proof_signature(
|
||||
approval=approval,
|
||||
client_nonce=client_nonce,
|
||||
server_nonce=server_nonce,
|
||||
signature=proof_sig,
|
||||
)
|
||||
|
||||
status = approval.status
|
||||
if status is EnrollmentApprovalStatus.PENDING:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.PENDING,
|
||||
http_status=200,
|
||||
poll_after_ms=5000,
|
||||
)
|
||||
if status is EnrollmentApprovalStatus.DENIED:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.DENIED,
|
||||
http_status=200,
|
||||
reason="operator_denied",
|
||||
)
|
||||
if status is EnrollmentApprovalStatus.EXPIRED:
|
||||
return PollingResult(status=EnrollmentStatus.EXPIRED, http_status=200)
|
||||
if status is EnrollmentApprovalStatus.COMPLETED:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.APPROVED,
|
||||
http_status=200,
|
||||
detail="finalized",
|
||||
)
|
||||
if status is not EnrollmentApprovalStatus.APPROVED:
|
||||
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=400)
|
||||
|
||||
nonce_key = f"{approval.reference}:{proof_signature_b64}"
|
||||
if not self._nonce_cache.consume(nonce_key):
|
||||
raise EnrollmentValidationError("proof_replayed", http_status=409)
|
||||
|
||||
token_bundle = self._finalize_approval(approval)
|
||||
signing_key = self._signer.public_base64_spki() if self._signer else ""
|
||||
certificate = self._load_tls_bundle()
|
||||
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.APPROVED,
|
||||
http_status=200,
|
||||
tokens=token_bundle,
|
||||
server_certificate=certificate,
|
||||
signing_key=signing_key,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internals
|
||||
# ------------------------------------------------------------------
|
||||
def _enforce_rate_limit(
|
||||
self,
|
||||
limiter: SlidingWindowRateLimiter,
|
||||
key: str,
|
||||
*,
|
||||
limit: int = 60,
|
||||
window_seconds: float = 60.0,
|
||||
) -> None:
|
||||
decision = limiter.check(key, limit, window_seconds)
|
||||
if not decision.allowed:
|
||||
raise EnrollmentValidationError(
|
||||
"rate_limited", http_status=429, retry_after=max(decision.retry_after, 1.0)
|
||||
)
|
||||
|
||||
def _determine_reuse_guid(
|
||||
self,
|
||||
install_code: Optional[EnrollmentCode],
|
||||
fingerprint: DeviceFingerprint,
|
||||
) -> Optional[DeviceGuid]:
|
||||
if install_code is None:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
if install_code.is_expired:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
if not install_code.is_exhausted:
|
||||
return None
|
||||
if not install_code.used_by_guid:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
|
||||
existing = self._devices.fetch_by_guid(install_code.used_by_guid)
|
||||
if existing and existing.identity.fingerprint.value == fingerprint.value:
|
||||
return install_code.used_by_guid
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
|
||||
def _finalize_approval(self, approval: EnrollmentApproval) -> EnrollmentTokenBundle:
|
||||
now = self._now()
|
||||
effective_guid = approval.guid or DeviceGuid(str(uuid.uuid4()))
|
||||
device_record = self._devices.ensure_device_record(
|
||||
guid=effective_guid,
|
||||
hostname=approval.claimed_hostname,
|
||||
fingerprint=approval.claimed_fingerprint,
|
||||
)
|
||||
self._devices.record_device_key(
|
||||
guid=effective_guid,
|
||||
fingerprint=approval.claimed_fingerprint,
|
||||
added_at=now,
|
||||
)
|
||||
|
||||
if approval.enrollment_code_id:
|
||||
code = self._enrollment.fetch_install_code_by_id(approval.enrollment_code_id)
|
||||
if code is not None:
|
||||
mark_first = code.used_at is None
|
||||
self._enrollment.update_install_code_usage(
|
||||
approval.enrollment_code_id,
|
||||
use_count_increment=1,
|
||||
last_used_at=now,
|
||||
used_by_guid=effective_guid,
|
||||
mark_first_use=mark_first,
|
||||
)
|
||||
|
||||
refresh_token = secrets.token_urlsafe(48)
|
||||
refresh_id = str(uuid.uuid4())
|
||||
expires_at = now + timedelta(days=30)
|
||||
token_hash = hashlib.sha256(refresh_token.encode("utf-8")).hexdigest()
|
||||
self._tokens.create(
|
||||
record_id=refresh_id,
|
||||
guid=effective_guid,
|
||||
token_hash=token_hash,
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
access_token = self._jwt.issue_access_token(
|
||||
effective_guid.value,
|
||||
device_record.identity.fingerprint.value,
|
||||
max(device_record.token_version, 1),
|
||||
)
|
||||
|
||||
self._enrollment.update_device_approval_status(
|
||||
approval.record_id,
|
||||
status=EnrollmentApprovalStatus.COMPLETED,
|
||||
updated_at=now,
|
||||
guid=effective_guid,
|
||||
)
|
||||
|
||||
return EnrollmentTokenBundle(
|
||||
guid=effective_guid,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=900,
|
||||
)
|
||||
|
||||
def _verify_proof_signature(
|
||||
self,
|
||||
*,
|
||||
approval: EnrollmentApproval,
|
||||
client_nonce: bytes,
|
||||
server_nonce: bytes,
|
||||
signature: bytes,
|
||||
) -> None:
|
||||
message = server_nonce + approval.reference.encode("utf-8") + client_nonce
|
||||
try:
|
||||
public_key = serialization.load_der_public_key(approval.agent_pubkey_der)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError("agent_pubkey_invalid") from exc
|
||||
|
||||
try:
|
||||
public_key.verify(signature, message)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError("invalid_proof") from exc
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64(value: str, error_code: str) -> bytes:
|
||||
try:
|
||||
return base64.b64decode(value, validate=True)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError(error_code) from exc
|
||||
|
||||
@staticmethod
|
||||
def _mask_code(code: str) -> str:
|
||||
trimmed = (code or "").strip()
|
||||
if len(trimmed) <= 6:
|
||||
return "***"
|
||||
return f"{trimmed[:3]}***{trimmed[-3:]}"
|
||||
|
||||
@staticmethod
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
@@ -1,32 +0,0 @@
|
||||
"""Nonce replay protection for enrollment workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict
|
||||
|
||||
__all__ = ["NonceCache"]
|
||||
|
||||
|
||||
class NonceCache:
|
||||
"""Track recently observed nonces to prevent replay."""
|
||||
|
||||
def __init__(self, ttl_seconds: float = 300.0) -> None:
|
||||
self._ttl = ttl_seconds
|
||||
self._entries: Dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def consume(self, key: str) -> bool:
|
||||
"""Consume *key* if it has not been seen recently."""
|
||||
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
expiry = self._entries.get(key)
|
||||
if expiry and expiry > now:
|
||||
return False
|
||||
self._entries[key] = now + self._ttl
|
||||
stale = [nonce for nonce, ttl in self._entries.items() if ttl <= now]
|
||||
for nonce in stale:
|
||||
self._entries.pop(nonce, None)
|
||||
return True
|
||||
@@ -1,8 +0,0 @@
|
||||
"""GitHub-oriented services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .github_service import GitHubService, GitHubTokenPayload
|
||||
|
||||
__all__ = ["GitHubService", "GitHubTokenPayload"]
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
"""GitHub service layer bridging repositories and integrations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
from Data.Engine.domain.github import GitHubRepoRef, GitHubTokenStatus, RepoHeadSnapshot
|
||||
from Data.Engine.integrations.github import GitHubArtifactProvider
|
||||
from Data.Engine.repositories.sqlite.github_repository import SQLiteGitHubRepository
|
||||
|
||||
__all__ = ["GitHubService", "GitHubTokenPayload"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GitHubTokenPayload:
|
||||
token: Optional[str]
|
||||
status: GitHubTokenStatus
|
||||
checked_at: int
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
payload = self.status.to_dict()
|
||||
payload.update(
|
||||
{
|
||||
"token": self.token or "",
|
||||
"checked_at": self.checked_at,
|
||||
}
|
||||
)
|
||||
return payload
|
||||
|
||||
|
||||
class GitHubService:
|
||||
"""Coordinate GitHub caching, verification, and persistence."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: SQLiteGitHubRepository,
|
||||
provider: GitHubArtifactProvider,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
clock: Optional[Callable[[], float]] = None,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._provider = provider
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.github")
|
||||
self._clock = clock or time.time
|
||||
self._token_cache: Optional[str] = None
|
||||
self._token_loaded_at: float = 0.0
|
||||
|
||||
initial_token = self._repository.load_token()
|
||||
self._apply_token(initial_token)
|
||||
|
||||
def get_repo_head(
|
||||
self,
|
||||
owner_repo: Optional[str],
|
||||
branch: Optional[str],
|
||||
*,
|
||||
ttl_seconds: int,
|
||||
force_refresh: bool = False,
|
||||
) -> RepoHeadSnapshot:
|
||||
repo_str = (owner_repo or self._provider.default_repo).strip()
|
||||
branch_name = (branch or self._provider.default_branch).strip()
|
||||
repo = GitHubRepoRef.parse(repo_str, branch_name)
|
||||
ttl = max(30, min(ttl_seconds, 3600))
|
||||
return self._provider.fetch_repo_head(repo, ttl_seconds=ttl, force_refresh=force_refresh)
|
||||
|
||||
def refresh_default_repo(self, *, force: bool = False) -> RepoHeadSnapshot:
|
||||
return self._provider.refresh_default_repo_head(force=force)
|
||||
|
||||
def get_token_status(self, *, force_refresh: bool = False) -> GitHubTokenPayload:
|
||||
token = self._load_token(force_refresh=force_refresh)
|
||||
status = self._provider.verify_token(token)
|
||||
return GitHubTokenPayload(token=token, status=status, checked_at=int(self._clock()))
|
||||
|
||||
def update_token(self, token: Optional[str]) -> GitHubTokenPayload:
|
||||
normalized = (token or "").strip()
|
||||
self._repository.store_token(normalized)
|
||||
self._apply_token(normalized)
|
||||
status = self._provider.verify_token(normalized)
|
||||
self._provider.start_background_refresh()
|
||||
self._log.info("github-token updated valid=%s", status.valid)
|
||||
return GitHubTokenPayload(token=normalized or None, status=status, checked_at=int(self._clock()))
|
||||
|
||||
def start_background_refresh(self) -> None:
|
||||
self._provider.start_background_refresh()
|
||||
|
||||
@property
|
||||
def default_refresh_interval(self) -> int:
|
||||
return self._provider.refresh_interval
|
||||
|
||||
def _load_token(self, *, force_refresh: bool = False) -> Optional[str]:
|
||||
now = self._clock()
|
||||
if not force_refresh and self._token_cache is not None and (now - self._token_loaded_at) < 15.0:
|
||||
return self._token_cache
|
||||
|
||||
token = self._repository.load_token()
|
||||
self._apply_token(token)
|
||||
return token
|
||||
|
||||
def _apply_token(self, token: Optional[str]) -> None:
|
||||
self._token_cache = (token or "").strip() or None
|
||||
self._token_loaded_at = self._clock()
|
||||
self._provider.set_token(self._token_cache)
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Job-related services for the Borealis Engine."""
|
||||
|
||||
from .scheduler_service import SchedulerService
|
||||
|
||||
__all__ = ["SchedulerService"]
|
||||
@@ -1,373 +0,0 @@
|
||||
"""Background scheduler service for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import calendar
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Mapping, Optional
|
||||
|
||||
from Data.Engine.builders.job_fabricator import JobFabricator, JobManifest
|
||||
from Data.Engine.repositories.sqlite.job_repository import (
|
||||
ScheduledJobRecord,
|
||||
ScheduledJobRunRecord,
|
||||
SQLiteJobRepository,
|
||||
)
|
||||
|
||||
__all__ = ["SchedulerService", "SchedulerRuntime"]
|
||||
|
||||
|
||||
def _now_ts() -> int:
|
||||
return int(time.time())
|
||||
|
||||
|
||||
def _floor_minute(ts: int) -> int:
|
||||
ts = int(ts or 0)
|
||||
return ts - (ts % 60)
|
||||
|
||||
|
||||
def _parse_ts(val: Any) -> Optional[int]:
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, (int, float)):
|
||||
return int(val)
|
||||
try:
|
||||
s = str(val).strip().replace("Z", "+00:00")
|
||||
return int(datetime.fromisoformat(s).timestamp())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _parse_expiration(raw: Optional[str]) -> Optional[int]:
|
||||
if not raw or raw == "no_expire":
|
||||
return None
|
||||
try:
|
||||
s = raw.strip().lower()
|
||||
unit = s[-1]
|
||||
value = int(s[:-1])
|
||||
if unit == "m":
|
||||
return value * 60
|
||||
if unit == "h":
|
||||
return value * 3600
|
||||
if unit == "d":
|
||||
return value * 86400
|
||||
return int(s) * 60
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _add_months(dt: datetime, months: int) -> datetime:
|
||||
year = dt.year + (dt.month - 1 + months) // 12
|
||||
month = ((dt.month - 1 + months) % 12) + 1
|
||||
last_day = calendar.monthrange(year, month)[1]
|
||||
day = min(dt.day, last_day)
|
||||
return dt.replace(year=year, month=month, day=day)
|
||||
|
||||
|
||||
def _add_years(dt: datetime, years: int) -> datetime:
|
||||
year = dt.year + years
|
||||
month = dt.month
|
||||
last_day = calendar.monthrange(year, month)[1]
|
||||
day = min(dt.day, last_day)
|
||||
return dt.replace(year=year, month=month, day=day)
|
||||
|
||||
|
||||
def _compute_next_run(
|
||||
schedule_type: str,
|
||||
start_ts: Optional[int],
|
||||
last_run_ts: Optional[int],
|
||||
now_ts: int,
|
||||
) -> Optional[int]:
|
||||
st = (schedule_type or "immediately").strip().lower()
|
||||
start_floor = _floor_minute(start_ts) if start_ts else None
|
||||
last_floor = _floor_minute(last_run_ts) if last_run_ts else None
|
||||
now_floor = _floor_minute(now_ts)
|
||||
|
||||
if st == "immediately":
|
||||
return None if last_floor else now_floor
|
||||
if st == "once":
|
||||
if not start_floor:
|
||||
return None
|
||||
return start_floor if not last_floor else None
|
||||
if not start_floor:
|
||||
return None
|
||||
|
||||
last = last_floor if last_floor is not None else None
|
||||
if st in {
|
||||
"every_5_minutes",
|
||||
"every_10_minutes",
|
||||
"every_15_minutes",
|
||||
"every_30_minutes",
|
||||
"every_hour",
|
||||
}:
|
||||
period_map = {
|
||||
"every_5_minutes": 5 * 60,
|
||||
"every_10_minutes": 10 * 60,
|
||||
"every_15_minutes": 15 * 60,
|
||||
"every_30_minutes": 30 * 60,
|
||||
"every_hour": 60 * 60,
|
||||
}
|
||||
period = period_map.get(st)
|
||||
candidate = (last + period) if last else start_floor
|
||||
while candidate is not None and candidate <= now_floor - 1:
|
||||
candidate += period
|
||||
return candidate
|
||||
if st == "daily":
|
||||
period = 86400
|
||||
candidate = (last + period) if last else start_floor
|
||||
while candidate is not None and candidate <= now_floor - 1:
|
||||
candidate += period
|
||||
return candidate
|
||||
if st == "weekly":
|
||||
period = 7 * 86400
|
||||
candidate = (last + period) if last else start_floor
|
||||
while candidate is not None and candidate <= now_floor - 1:
|
||||
candidate += period
|
||||
return candidate
|
||||
if st == "monthly":
|
||||
base = datetime.utcfromtimestamp(last) if last else datetime.utcfromtimestamp(start_floor)
|
||||
candidate = _add_months(base, 1 if last else 0)
|
||||
while int(candidate.timestamp()) <= now_floor - 1:
|
||||
candidate = _add_months(candidate, 1)
|
||||
return int(candidate.timestamp())
|
||||
if st == "yearly":
|
||||
base = datetime.utcfromtimestamp(last) if last else datetime.utcfromtimestamp(start_floor)
|
||||
candidate = _add_years(base, 1 if last else 0)
|
||||
while int(candidate.timestamp()) <= now_floor - 1:
|
||||
candidate = _add_years(candidate, 1)
|
||||
return int(candidate.timestamp())
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerRuntime:
|
||||
thread: Optional[threading.Thread]
|
||||
stop_event: threading.Event
|
||||
|
||||
|
||||
class SchedulerService:
|
||||
"""Evaluate and dispatch scheduled jobs using Engine repositories."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
job_repository: SQLiteJobRepository,
|
||||
assemblies_root: Path,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
poll_interval: int = 30,
|
||||
) -> None:
|
||||
self._jobs = job_repository
|
||||
self._fabricator = JobFabricator(assemblies_root=assemblies_root, logger=logger)
|
||||
self._log = logger or logging.getLogger("borealis.engine.scheduler")
|
||||
self._poll_interval = max(5, poll_interval)
|
||||
self._socketio: Optional[Any] = None
|
||||
self._runtime = SchedulerRuntime(thread=None, stop_event=threading.Event())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
def start(self, socketio: Optional[Any] = None) -> None:
|
||||
self._socketio = socketio
|
||||
if self._runtime.thread and self._runtime.thread.is_alive():
|
||||
return
|
||||
self._runtime.stop_event.clear()
|
||||
thread = threading.Thread(target=self._run_loop, name="borealis-engine-scheduler", daemon=True)
|
||||
thread.start()
|
||||
self._runtime.thread = thread
|
||||
self._log.info("scheduler-started")
|
||||
|
||||
def stop(self) -> None:
|
||||
self._runtime.stop_event.set()
|
||||
thread = self._runtime.thread
|
||||
if thread and thread.is_alive():
|
||||
thread.join(timeout=5)
|
||||
self._log.info("scheduler-stopped")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HTTP orchestration helpers
|
||||
# ------------------------------------------------------------------
|
||||
def list_jobs(self) -> list[dict[str, Any]]:
|
||||
return [self._serialize_job(job) for job in self._jobs.list_jobs()]
|
||||
|
||||
def get_job(self, job_id: int) -> Optional[dict[str, Any]]:
|
||||
record = self._jobs.fetch_job(job_id)
|
||||
return self._serialize_job(record) if record else None
|
||||
|
||||
def create_job(self, payload: Mapping[str, Any]) -> dict[str, Any]:
|
||||
fields = self._normalize_payload(payload)
|
||||
record = self._jobs.create_job(**fields)
|
||||
return self._serialize_job(record)
|
||||
|
||||
def update_job(self, job_id: int, payload: Mapping[str, Any]) -> Optional[dict[str, Any]]:
|
||||
fields = self._normalize_payload(payload)
|
||||
record = self._jobs.update_job(job_id, **fields)
|
||||
return self._serialize_job(record) if record else None
|
||||
|
||||
def toggle_job(self, job_id: int, enabled: bool) -> None:
|
||||
self._jobs.set_enabled(job_id, enabled)
|
||||
|
||||
def delete_job(self, job_id: int) -> None:
|
||||
self._jobs.delete_job(job_id)
|
||||
|
||||
def list_runs(self, job_id: int, *, days: Optional[int] = None) -> list[dict[str, Any]]:
|
||||
runs = self._jobs.list_runs(job_id, days=days)
|
||||
return [self._serialize_run(run) for run in runs]
|
||||
|
||||
def purge_runs(self, job_id: int) -> None:
|
||||
self._jobs.purge_runs(job_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Scheduling loop
|
||||
# ------------------------------------------------------------------
|
||||
def tick(self, *, now_ts: Optional[int] = None) -> None:
|
||||
self._evaluate_jobs(now_ts=now_ts or _now_ts())
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
stop_event = self._runtime.stop_event
|
||||
while not stop_event.wait(timeout=self._poll_interval):
|
||||
try:
|
||||
self._evaluate_jobs(now_ts=_now_ts())
|
||||
except Exception as exc: # pragma: no cover - safety net
|
||||
self._log.exception("scheduler-loop-error: %s", exc)
|
||||
|
||||
def _evaluate_jobs(self, *, now_ts: int) -> None:
|
||||
for job in self._jobs.list_enabled_jobs():
|
||||
try:
|
||||
self._evaluate_job(job, now_ts=now_ts)
|
||||
except Exception as exc:
|
||||
self._log.exception("job-evaluation-error job_id=%s error=%s", job.id, exc)
|
||||
|
||||
def _evaluate_job(self, job: ScheduledJobRecord, *, now_ts: int) -> None:
|
||||
last_run = self._jobs.fetch_last_run(job.id)
|
||||
last_ts = None
|
||||
if last_run:
|
||||
last_ts = last_run.scheduled_ts or last_run.started_ts or last_run.created_at
|
||||
next_run = _compute_next_run(job.schedule_type, job.start_ts, last_ts, now_ts)
|
||||
if next_run is None or next_run > now_ts:
|
||||
return
|
||||
|
||||
expiration_window = _parse_expiration(job.expiration)
|
||||
if expiration_window and job.start_ts:
|
||||
if job.start_ts + expiration_window <= now_ts:
|
||||
self._log.info(
|
||||
"job-expired",
|
||||
extra={"job_id": job.id, "start_ts": job.start_ts, "expiration": job.expiration},
|
||||
)
|
||||
return
|
||||
|
||||
manifest = self._fabricator.build(job, occurrence_ts=next_run)
|
||||
targets = manifest.targets or ("<unassigned>",)
|
||||
for target in targets:
|
||||
run_id = self._jobs.create_run(job.id, next_run, target_hostname=None if target == "<unassigned>" else target)
|
||||
self._jobs.mark_run_started(run_id, started_ts=now_ts)
|
||||
self._emit_run_event("job_run_started", job, run_id, target, manifest)
|
||||
self._jobs.mark_run_finished(run_id, status="Success", finished_ts=now_ts)
|
||||
self._emit_run_event("job_run_completed", job, run_id, target, manifest)
|
||||
|
||||
def _emit_run_event(
|
||||
self,
|
||||
event: str,
|
||||
job: ScheduledJobRecord,
|
||||
run_id: int,
|
||||
target: str,
|
||||
manifest: JobManifest,
|
||||
) -> None:
|
||||
payload = {
|
||||
"job_id": job.id,
|
||||
"run_id": run_id,
|
||||
"target": target,
|
||||
"schedule_type": job.schedule_type,
|
||||
"occurrence_ts": manifest.occurrence_ts,
|
||||
}
|
||||
if self._socketio is not None:
|
||||
try:
|
||||
self._socketio.emit(event, payload) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
self._log.debug("socketio-emit-failed event=%s payload=%s", event, payload)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Serialization helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _serialize_job(self, job: ScheduledJobRecord) -> dict[str, Any]:
|
||||
return {
|
||||
"id": job.id,
|
||||
"name": job.name,
|
||||
"components": job.components,
|
||||
"targets": job.targets,
|
||||
"schedule": {
|
||||
"type": job.schedule_type,
|
||||
"start": job.start_ts,
|
||||
},
|
||||
"schedule_type": job.schedule_type,
|
||||
"start_ts": job.start_ts,
|
||||
"duration_stop_enabled": job.duration_stop_enabled,
|
||||
"expiration": job.expiration or "no_expire",
|
||||
"execution_context": job.execution_context,
|
||||
"credential_id": job.credential_id,
|
||||
"use_service_account": job.use_service_account,
|
||||
"enabled": job.enabled,
|
||||
"created_at": job.created_at,
|
||||
"updated_at": job.updated_at,
|
||||
}
|
||||
|
||||
def _serialize_run(self, run: ScheduledJobRunRecord) -> dict[str, Any]:
|
||||
return {
|
||||
"id": run.id,
|
||||
"job_id": run.job_id,
|
||||
"scheduled_ts": run.scheduled_ts,
|
||||
"started_ts": run.started_ts,
|
||||
"finished_ts": run.finished_ts,
|
||||
"status": run.status,
|
||||
"error": run.error,
|
||||
"target_hostname": run.target_hostname,
|
||||
"created_at": run.created_at,
|
||||
"updated_at": run.updated_at,
|
||||
}
|
||||
|
||||
def _normalize_payload(self, payload: Mapping[str, Any]) -> dict[str, Any]:
|
||||
name = str(payload.get("name") or "").strip()
|
||||
components = payload.get("components") or []
|
||||
targets = payload.get("targets") or []
|
||||
schedule_block = payload.get("schedule") if isinstance(payload.get("schedule"), Mapping) else {}
|
||||
schedule_type = str(schedule_block.get("type") or payload.get("schedule_type") or "immediately").strip().lower()
|
||||
start_value = schedule_block.get("start") if isinstance(schedule_block, Mapping) else None
|
||||
if start_value is None:
|
||||
start_value = payload.get("start")
|
||||
start_ts = _parse_ts(start_value)
|
||||
duration_block = payload.get("duration") if isinstance(payload.get("duration"), Mapping) else {}
|
||||
duration_stop = bool(duration_block.get("stopAfterEnabled") or payload.get("duration_stop_enabled"))
|
||||
expiration = str(duration_block.get("expiration") or payload.get("expiration") or "no_expire").strip()
|
||||
execution_context = str(payload.get("execution_context") or "system").strip().lower()
|
||||
credential_id = payload.get("credential_id")
|
||||
try:
|
||||
credential_id = int(credential_id) if credential_id is not None else None
|
||||
except Exception:
|
||||
credential_id = None
|
||||
use_service_account_raw = payload.get("use_service_account")
|
||||
use_service_account = bool(use_service_account_raw) if execution_context == "winrm" else False
|
||||
enabled = bool(payload.get("enabled", True))
|
||||
|
||||
if not name:
|
||||
raise ValueError("job name is required")
|
||||
if not isinstance(components, Iterable) or not list(components):
|
||||
raise ValueError("at least one component is required")
|
||||
if not isinstance(targets, Iterable) or not list(targets):
|
||||
raise ValueError("at least one target is required")
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"components": list(components),
|
||||
"targets": list(targets),
|
||||
"schedule_type": schedule_type,
|
||||
"start_ts": start_ts,
|
||||
"duration_stop_enabled": duration_stop,
|
||||
"expiration": expiration,
|
||||
"execution_context": execution_context,
|
||||
"credential_id": credential_id,
|
||||
"use_service_account": use_service_account,
|
||||
"enabled": enabled,
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
"""In-process rate limiting utilities for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import Deque, Dict
|
||||
|
||||
__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RateLimitDecision:
|
||||
"""Result of a rate limit check."""
|
||||
|
||||
allowed: bool
|
||||
retry_after: float
|
||||
|
||||
|
||||
class SlidingWindowRateLimiter:
|
||||
"""Tiny in-memory sliding window limiter suitable for single-process use."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buckets: Dict[str, Deque[float]] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
bucket = self._buckets.get(key)
|
||||
if bucket is None:
|
||||
bucket = deque()
|
||||
self._buckets[key] = bucket
|
||||
|
||||
while bucket and now - bucket[0] > window_seconds:
|
||||
bucket.popleft()
|
||||
|
||||
if len(bucket) >= limit:
|
||||
retry_after = max(0.0, window_seconds - (now - bucket[0]))
|
||||
return RateLimitDecision(False, retry_after)
|
||||
|
||||
bucket.append(now)
|
||||
return RateLimitDecision(True, 0.0)
|
||||
@@ -1,10 +0,0 @@
|
||||
"""Realtime coordination services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .agent_registry import AgentRealtimeService, AgentRecord
|
||||
|
||||
__all__ = [
|
||||
"AgentRealtimeService",
|
||||
"AgentRecord",
|
||||
]
|
||||
@@ -1,301 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Mapping, Optional, Tuple
|
||||
|
||||
from Data.Engine.repositories.sqlite import SQLiteDeviceRepository
|
||||
|
||||
__all__ = ["AgentRealtimeService", "AgentRecord"]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AgentRecord:
|
||||
"""In-memory representation of a connected agent."""
|
||||
|
||||
agent_id: str
|
||||
hostname: str = "unknown"
|
||||
agent_operating_system: str = "-"
|
||||
last_seen: int = 0
|
||||
status: str = "orphaned"
|
||||
service_mode: str = "currentuser"
|
||||
is_script_agent: bool = False
|
||||
collector_active_ts: Optional[float] = None
|
||||
|
||||
|
||||
class AgentRealtimeService:
|
||||
"""Track realtime agent presence and provide persistence hooks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
device_repository: SQLiteDeviceRepository,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._device_repository = device_repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.realtime.agents")
|
||||
self._agents: Dict[str, AgentRecord] = {}
|
||||
self._configs: Dict[str, Dict[str, Any]] = {}
|
||||
self._screenshots: Dict[str, Dict[str, Any]] = {}
|
||||
self._task_screenshots: Dict[Tuple[str, str], Dict[str, Any]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Agent presence management
|
||||
# ------------------------------------------------------------------
|
||||
def register_connection(self, agent_id: str, service_mode: Optional[str]) -> AgentRecord:
|
||||
record = self._agents.get(agent_id) or AgentRecord(agent_id=agent_id)
|
||||
mode = self.normalize_service_mode(service_mode, agent_id)
|
||||
now = int(time.time())
|
||||
|
||||
record.service_mode = mode
|
||||
record.is_script_agent = self._is_script_agent(agent_id)
|
||||
record.last_seen = now
|
||||
record.status = "provisioned" if agent_id in self._configs else "orphaned"
|
||||
self._agents[agent_id] = record
|
||||
|
||||
self._persist_activity(
|
||||
hostname=record.hostname,
|
||||
last_seen=record.last_seen,
|
||||
agent_id=agent_id,
|
||||
operating_system=record.agent_operating_system,
|
||||
)
|
||||
return record
|
||||
|
||||
def heartbeat(self, payload: Mapping[str, Any]) -> Optional[AgentRecord]:
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
agent_id = payload.get("agent_id")
|
||||
if not agent_id:
|
||||
return None
|
||||
|
||||
hostname = payload.get("hostname") or ""
|
||||
mode = self.normalize_service_mode(payload.get("service_mode"), agent_id)
|
||||
is_script_agent = self._is_script_agent(agent_id)
|
||||
last_seen = self._coerce_int(payload.get("last_seen"), default=int(time.time()))
|
||||
operating_system = (payload.get("agent_operating_system") or "-").strip() or "-"
|
||||
|
||||
if hostname:
|
||||
self._reconcile_hostname_collisions(
|
||||
hostname=hostname,
|
||||
agent_id=agent_id,
|
||||
incoming_mode=mode,
|
||||
is_script_agent=is_script_agent,
|
||||
last_seen=last_seen,
|
||||
)
|
||||
|
||||
record = self._agents.get(agent_id) or AgentRecord(agent_id=agent_id)
|
||||
if hostname:
|
||||
record.hostname = hostname
|
||||
record.agent_operating_system = operating_system
|
||||
record.last_seen = last_seen
|
||||
record.service_mode = mode
|
||||
record.is_script_agent = is_script_agent
|
||||
record.status = "provisioned" if agent_id in self._configs else record.status or "orphaned"
|
||||
self._agents[agent_id] = record
|
||||
|
||||
self._persist_activity(
|
||||
hostname=record.hostname or hostname,
|
||||
last_seen=record.last_seen,
|
||||
agent_id=agent_id,
|
||||
operating_system=record.agent_operating_system,
|
||||
)
|
||||
return record
|
||||
|
||||
def collector_status(self, payload: Mapping[str, Any]) -> None:
|
||||
if not payload:
|
||||
return
|
||||
|
||||
agent_id = payload.get("agent_id")
|
||||
if not agent_id:
|
||||
return
|
||||
|
||||
hostname = payload.get("hostname") or ""
|
||||
mode = self.normalize_service_mode(payload.get("service_mode"), agent_id)
|
||||
active = bool(payload.get("active"))
|
||||
last_user = (payload.get("last_user") or "").strip()
|
||||
|
||||
record = self._agents.get(agent_id) or AgentRecord(agent_id=agent_id)
|
||||
if hostname:
|
||||
record.hostname = hostname
|
||||
if mode:
|
||||
record.service_mode = mode
|
||||
record.is_script_agent = self._is_script_agent(agent_id) or record.is_script_agent
|
||||
if active:
|
||||
record.collector_active_ts = time.time()
|
||||
self._agents[agent_id] = record
|
||||
|
||||
if (
|
||||
last_user
|
||||
and hostname
|
||||
and self._is_valid_interactive_user(last_user)
|
||||
and not self._is_system_service_agent(agent_id, record.is_script_agent)
|
||||
):
|
||||
self._persist_activity(
|
||||
hostname=hostname,
|
||||
last_seen=int(time.time()),
|
||||
agent_id=agent_id,
|
||||
operating_system=record.agent_operating_system,
|
||||
last_user=last_user,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Configuration management
|
||||
# ------------------------------------------------------------------
|
||||
def set_agent_config(self, agent_id: str, config: Mapping[str, Any]) -> None:
|
||||
self._configs[agent_id] = dict(config)
|
||||
record = self._agents.get(agent_id)
|
||||
if record:
|
||||
record.status = "provisioned"
|
||||
|
||||
def get_agent_config(self, agent_id: str) -> Optional[Dict[str, Any]]:
|
||||
config = self._configs.get(agent_id)
|
||||
if config is None:
|
||||
return None
|
||||
return dict(config)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Screenshot caches
|
||||
# ------------------------------------------------------------------
|
||||
def store_agent_screenshot(self, agent_id: str, image_base64: str) -> None:
|
||||
self._screenshots[agent_id] = {
|
||||
"image_base64": image_base64,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
def store_task_screenshot(self, agent_id: str, node_id: str, image_base64: str) -> None:
|
||||
self._task_screenshots[(agent_id, node_id)] = {
|
||||
"image_base64": image_base64,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def normalize_service_mode(value: Optional[str], agent_id: Optional[str] = None) -> str:
|
||||
text = ""
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
text = value.strip().lower()
|
||||
except Exception:
|
||||
text = ""
|
||||
|
||||
if not text and agent_id:
|
||||
try:
|
||||
lowered = agent_id.lower()
|
||||
if "-svc-" in lowered or lowered.endswith("-svc"):
|
||||
return "system"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if text in {"system", "svc", "service", "system_service"}:
|
||||
return "system"
|
||||
if text in {"interactive", "currentuser", "user", "current_user"}:
|
||||
return "currentuser"
|
||||
return "currentuser"
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, *, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _is_script_agent(agent_id: Optional[str]) -> bool:
|
||||
try:
|
||||
return bool(isinstance(agent_id, str) and agent_id.lower().endswith("-script"))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_interactive_user(candidate: Optional[str]) -> bool:
|
||||
if not candidate:
|
||||
return False
|
||||
try:
|
||||
text = str(candidate).strip()
|
||||
except Exception:
|
||||
return False
|
||||
if not text:
|
||||
return False
|
||||
upper = text.upper()
|
||||
if text.endswith("$"):
|
||||
return False
|
||||
if "NT AUTHORITY\\" in upper or "NT SERVICE\\" in upper:
|
||||
return False
|
||||
if upper.endswith("\\SYSTEM"):
|
||||
return False
|
||||
if upper.endswith("\\LOCAL SERVICE"):
|
||||
return False
|
||||
if upper.endswith("\\NETWORK SERVICE"):
|
||||
return False
|
||||
if upper == "ANONYMOUS LOGON":
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _is_system_service_agent(agent_id: str, is_script_agent: bool) -> bool:
|
||||
try:
|
||||
lowered = agent_id.lower()
|
||||
except Exception:
|
||||
lowered = ""
|
||||
if is_script_agent:
|
||||
return False
|
||||
return "-svc-" in lowered or lowered.endswith("-svc")
|
||||
|
||||
def _reconcile_hostname_collisions(
|
||||
self,
|
||||
*,
|
||||
hostname: str,
|
||||
agent_id: str,
|
||||
incoming_mode: str,
|
||||
is_script_agent: bool,
|
||||
last_seen: int,
|
||||
) -> None:
|
||||
transferred_config = False
|
||||
for existing_id, info in list(self._agents.items()):
|
||||
if existing_id == agent_id:
|
||||
continue
|
||||
if info.hostname != hostname:
|
||||
continue
|
||||
existing_mode = self.normalize_service_mode(info.service_mode, existing_id)
|
||||
if existing_mode != incoming_mode:
|
||||
continue
|
||||
if is_script_agent and not info.is_script_agent:
|
||||
self._persist_activity(
|
||||
hostname=hostname,
|
||||
last_seen=last_seen,
|
||||
agent_id=existing_id,
|
||||
operating_system=info.agent_operating_system,
|
||||
)
|
||||
return
|
||||
if not transferred_config and existing_id in self._configs and agent_id not in self._configs:
|
||||
self._configs[agent_id] = dict(self._configs[existing_id])
|
||||
transferred_config = True
|
||||
self._agents.pop(existing_id, None)
|
||||
if existing_id != agent_id:
|
||||
self._configs.pop(existing_id, None)
|
||||
|
||||
def _persist_activity(
|
||||
self,
|
||||
*,
|
||||
hostname: Optional[str],
|
||||
last_seen: Optional[int],
|
||||
agent_id: Optional[str],
|
||||
operating_system: Optional[str],
|
||||
last_user: Optional[str] = None,
|
||||
) -> None:
|
||||
if not hostname:
|
||||
return
|
||||
try:
|
||||
self._device_repository.update_device_summary(
|
||||
hostname=hostname,
|
||||
last_seen=last_seen,
|
||||
agent_id=agent_id,
|
||||
operating_system=operating_system,
|
||||
last_user=last_user,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.debug("failed to persist device activity for %s: %s", hostname, exc)
|
||||
@@ -1,3 +0,0 @@
|
||||
from .site_service import SiteService
|
||||
|
||||
__all__ = ["SiteService"]
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Site management service that mirrors the legacy Flask behaviour."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
from Data.Engine.domain.sites import SiteDeviceMapping, SiteSummary
|
||||
from Data.Engine.repositories.sqlite.site_repository import SQLiteSiteRepository
|
||||
|
||||
__all__ = ["SiteService"]
|
||||
|
||||
|
||||
class SiteService:
|
||||
def __init__(self, repository: SQLiteSiteRepository, *, logger: Optional[logging.Logger] = None) -> None:
|
||||
self._repo = repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.sites")
|
||||
|
||||
def list_sites(self) -> List[SiteSummary]:
|
||||
return self._repo.list_sites()
|
||||
|
||||
def create_site(self, name: str, description: str) -> SiteSummary:
|
||||
normalized_name = (name or "").strip()
|
||||
normalized_description = (description or "").strip()
|
||||
if not normalized_name:
|
||||
raise ValueError("missing_name")
|
||||
try:
|
||||
return self._repo.create_site(normalized_name, normalized_description)
|
||||
except ValueError as exc:
|
||||
if str(exc) == "duplicate":
|
||||
raise ValueError("duplicate") from exc
|
||||
raise
|
||||
|
||||
def delete_sites(self, ids: Iterable[int]) -> int:
|
||||
normalized = []
|
||||
for value in ids:
|
||||
try:
|
||||
normalized.append(int(value))
|
||||
except Exception:
|
||||
continue
|
||||
if not normalized:
|
||||
return 0
|
||||
return self._repo.delete_sites(tuple(normalized))
|
||||
|
||||
def rename_site(self, site_id: int, new_name: str) -> SiteSummary:
|
||||
normalized_name = (new_name or "").strip()
|
||||
if not normalized_name:
|
||||
raise ValueError("missing_name")
|
||||
try:
|
||||
return self._repo.rename_site(int(site_id), normalized_name)
|
||||
except ValueError as exc:
|
||||
if str(exc) == "duplicate":
|
||||
raise ValueError("duplicate") from exc
|
||||
raise
|
||||
|
||||
def map_devices(self, hostnames: Optional[Iterable[str]] = None) -> Dict[str, SiteDeviceMapping]:
|
||||
return self._repo.map_devices(hostnames)
|
||||
|
||||
def assign_devices(self, site_id: int, hostnames: Iterable[str]) -> None:
|
||||
try:
|
||||
numeric_id = int(site_id)
|
||||
except Exception as exc:
|
||||
raise ValueError("invalid_site_id") from exc
|
||||
normalized = [hn for hn in hostnames if isinstance(hn, str) and hn.strip()]
|
||||
if not normalized:
|
||||
raise ValueError("invalid_hostnames")
|
||||
try:
|
||||
self._repo.assign_devices(numeric_id, normalized)
|
||||
except LookupError as exc:
|
||||
if str(exc) == "not_found":
|
||||
raise LookupError("not_found") from exc
|
||||
raise
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Test suite for the Borealis Engine."""
|
||||
@@ -1,72 +0,0 @@
|
||||
"""Shared pytest fixtures for Engine HTTP interface tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from Data.Engine.config.environment import (
|
||||
DatabaseSettings,
|
||||
EngineSettings,
|
||||
FlaskSettings,
|
||||
GitHubSettings,
|
||||
ServerSettings,
|
||||
SocketIOSettings,
|
||||
)
|
||||
from Data.Engine.interfaces.http import register_http_interfaces
|
||||
from Data.Engine.repositories.sqlite import connection as sqlite_connection
|
||||
from Data.Engine.repositories.sqlite import migrations as sqlite_migrations
|
||||
from Data.Engine.server import create_app
|
||||
from Data.Engine.services.container import build_service_container
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def engine_settings(tmp_path: Path) -> EngineSettings:
|
||||
"""Provision an EngineSettings instance backed by a temporary project root."""
|
||||
|
||||
project_root = tmp_path
|
||||
static_root = project_root / "static"
|
||||
static_root.mkdir()
|
||||
(static_root / "index.html").write_text("<html></html>", encoding="utf-8")
|
||||
|
||||
database_path = project_root / "database.db"
|
||||
|
||||
return EngineSettings(
|
||||
project_root=project_root,
|
||||
debug=False,
|
||||
database=DatabaseSettings(path=database_path, apply_migrations=False),
|
||||
flask=FlaskSettings(
|
||||
secret_key="test-key",
|
||||
static_root=static_root,
|
||||
cors_allowed_origins=("https://localhost",),
|
||||
),
|
||||
socketio=SocketIOSettings(cors_allowed_origins=("https://localhost",)),
|
||||
server=ServerSettings(host="127.0.0.1", port=5000),
|
||||
github=GitHubSettings(
|
||||
default_repo="owner/repo",
|
||||
default_branch="main",
|
||||
refresh_interval_seconds=60,
|
||||
cache_root=project_root / "cache",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def prepared_app(engine_settings: EngineSettings):
|
||||
"""Create a Flask app instance with registered Engine interfaces."""
|
||||
|
||||
settings = engine_settings
|
||||
settings.github.cache_root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
db_factory = sqlite_connection.connection_factory(settings.database.path)
|
||||
with sqlite_connection.connection_scope(settings.database.path) as conn:
|
||||
sqlite_migrations.apply_all(conn)
|
||||
|
||||
app = create_app(settings, db_factory=db_factory)
|
||||
services = build_service_container(settings, db_factory=db_factory)
|
||||
app.extensions["engine_services"] = services
|
||||
register_http_interfaces(app, services)
|
||||
app.config.update(TESTING=True)
|
||||
return app
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user