diff --git a/Data/Engine/interfaces/http/__init__.py b/Data/Engine/interfaces/http/__init__.py index 43bfc9c..a428f1b 100644 --- a/Data/Engine/interfaces/http/__init__.py +++ b/Data/Engine/interfaces/http/__init__.py @@ -8,6 +8,7 @@ from Data.Engine.services.container import EngineServiceContainer from . import ( admin, + agent, agents, auth, enrollment, @@ -25,6 +26,7 @@ from . import ( _REGISTRARS = ( health.register, + agent.register, agents.register, enrollment.register, tokens.register, diff --git a/Data/Engine/interfaces/http/agent.py b/Data/Engine/interfaces/http/agent.py new file mode 100644 index 0000000..1d415db --- /dev/null +++ b/Data/Engine/interfaces/http/agent.py @@ -0,0 +1,113 @@ +"""Agent REST endpoints for device communication.""" + +from __future__ import annotations + +import math +from functools import wraps +from typing import Any, Callable, Optional, TypeVar, cast + +from flask import Blueprint, Flask, current_app, g, jsonify, request + +from Data.Engine.builders.device_auth import DeviceAuthRequestBuilder +from Data.Engine.domain.device_auth import DeviceAuthContext, DeviceAuthFailure +from Data.Engine.services.container import EngineServiceContainer +from Data.Engine.services.devices.device_inventory_service import DeviceHeartbeatError + +AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context" + +blueprint = Blueprint("engine_agent", __name__) + +F = TypeVar("F", bound=Callable[..., Any]) + + +def _services() -> EngineServiceContainer: + return cast(EngineServiceContainer, current_app.extensions["engine_services"]) + + +def require_device_auth(func: F) -> F: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any): + services = _services() + builder = ( + DeviceAuthRequestBuilder() + .with_authorization(request.headers.get("Authorization")) + .with_http_method(request.method) + .with_htu(request.url) + .with_service_context(request.headers.get(AGENT_CONTEXT_HEADER)) + .with_dpop_proof(request.headers.get("DPoP")) + ) + try: + auth_request = builder.build() + context = services.device_auth.authenticate(auth_request, path=request.path) + except DeviceAuthFailure as exc: + payload = exc.to_dict() + response = jsonify(payload) + if exc.retry_after is not None: + response.headers["Retry-After"] = str(int(math.ceil(exc.retry_after))) + return response, exc.http_status + + g.device_auth = context + try: + return func(*args, **kwargs) + finally: + g.pop("device_auth", None) + + return cast(F, wrapper) + + +def register(app: Flask, _services: EngineServiceContainer) -> None: + if "engine_agent" not in app.blueprints: + app.register_blueprint(blueprint) + + +@blueprint.route("/api/agent/heartbeat", methods=["POST"]) +@require_device_auth +def heartbeat() -> Any: + services = _services() + payload = request.get_json(force=True, silent=True) or {} + context = cast(DeviceAuthContext, g.device_auth) + + try: + services.device_inventory.record_heartbeat(context=context, payload=payload) + except DeviceHeartbeatError as exc: + error_payload = {"error": exc.code} + if exc.code == "device_not_registered": + return jsonify(error_payload), 404 + if exc.code == "storage_conflict": + return jsonify(error_payload), 409 + current_app.logger.exception( + "device-heartbeat-error guid=%s code=%s", context.identity.guid.value, exc.code + ) + return jsonify(error_payload), 500 + + return jsonify({"status": "ok", "poll_after_ms": 15000}) + + +@blueprint.route("/api/agent/script/request", methods=["POST"]) +@require_device_auth +def script_request() -> Any: + services = _services() + context = cast(DeviceAuthContext, g.device_auth) + + signing_key: Optional[str] = None + signer = services.script_signer + if signer is not None: + try: + signing_key = signer.public_base64_spki() + except Exception as exc: # pragma: no cover - defensive logging + current_app.logger.warning("script-signer-unavailable: %s", exc) + + status = "quarantined" if context.is_quarantined else "idle" + poll_after = 60000 if context.is_quarantined else 30000 + + response = { + "status": status, + "poll_after_ms": poll_after, + "sig_alg": "ed25519", + } + if signing_key: + response["signing_key"] = signing_key + return jsonify(response) + + +__all__ = ["register", "blueprint", "heartbeat", "script_request", "require_device_auth"] diff --git a/Data/Engine/services/container.py b/Data/Engine/services/container.py index 4b8df94..9a0d06b 100644 --- a/Data/Engine/services/container.py +++ b/Data/Engine/services/container.py @@ -67,6 +67,7 @@ class EngineServiceContainer: operator_auth_service: OperatorAuthService operator_account_service: OperatorAccountService assembly_service: AssemblyService + script_signer: Optional[ScriptSigner] def build_service_container( @@ -106,6 +107,8 @@ def build_service_container( logger=log.getChild("token_service"), ) + script_signer = _load_script_signer(log) + enrollment_service = EnrollmentService( device_repository=device_repo, enrollment_repository=enrollment_repo, @@ -115,7 +118,7 @@ def build_service_container( ip_rate_limiter=SlidingWindowRateLimiter(), fingerprint_rate_limiter=SlidingWindowRateLimiter(), nonce_cache=NonceCache(), - script_signer=_load_script_signer(log), + script_signer=script_signer, logger=log.getChild("enrollment"), ) @@ -205,6 +208,7 @@ def build_service_container( credential_service=credential_service, site_service=site_service, assembly_service=assembly_service, + script_signer=script_signer, ) diff --git a/Data/Engine/services/devices/device_inventory_service.py b/Data/Engine/services/devices/device_inventory_service.py index 031e789..ec146f0 100644 --- a/Data/Engine/services/devices/device_inventory_service.py +++ b/Data/Engine/services/devices/device_inventory_service.py @@ -4,13 +4,17 @@ from __future__ import annotations import logging import sqlite3 -from typing import Dict, List, Optional +import time +from collections.abc import Mapping +from typing import Any, Dict, List, Optional from Data.Engine.repositories.sqlite.device_inventory_repository import ( SQLiteDeviceInventoryRepository, ) +from Data.Engine.domain.device_auth import DeviceAuthContext +from Data.Engine.domain.devices import clean_device_str, coerce_int -__all__ = ["DeviceInventoryService", "RemoteDeviceError"] +__all__ = ["DeviceInventoryService", "RemoteDeviceError", "DeviceHeartbeatError"] class RemoteDeviceError(Exception): @@ -19,6 +23,12 @@ class RemoteDeviceError(Exception): self.code = code +class DeviceHeartbeatError(Exception): + def __init__(self, code: str, message: Optional[str] = None) -> None: + super().__init__(message or code) + self.code = code + + class DeviceInventoryService: def __init__( self, @@ -176,3 +186,117 @@ class DeviceInventoryService: raise RemoteDeviceError("not_found", "device not found") self._repo.delete_device_by_hostname(normalized_host) + # ------------------------------------------------------------------ + # Agent heartbeats + # ------------------------------------------------------------------ + def record_heartbeat( + self, + *, + context: DeviceAuthContext, + payload: Mapping[str, Any], + ) -> None: + guid = context.identity.guid.value + snapshot = self._repo.load_snapshot(guid=guid) + if not snapshot: + raise DeviceHeartbeatError("device_not_registered", "device not registered") + + summary = dict(snapshot.get("summary") or {}) + details = dict(snapshot.get("details") or {}) + + now_ts = int(time.time()) + summary["last_seen"] = now_ts + summary["agent_guid"] = guid + + existing_hostname = clean_device_str(summary.get("hostname")) or clean_device_str( + snapshot.get("hostname") + ) + incoming_hostname = clean_device_str(payload.get("hostname")) + raw_metrics = payload.get("metrics") + metrics = raw_metrics if isinstance(raw_metrics, Mapping) else {} + metrics_hostname = clean_device_str(metrics.get("hostname")) if metrics else None + hostname = incoming_hostname or metrics_hostname or existing_hostname + if not hostname: + hostname = f"RECOVERED-{guid[:12]}" + summary["hostname"] = hostname + + if metrics: + last_user = metrics.get("last_user") or metrics.get("username") + if last_user: + cleaned_user = clean_device_str(last_user) + if cleaned_user: + summary["last_user"] = cleaned_user + operating_system = metrics.get("operating_system") + if operating_system: + cleaned_os = clean_device_str(operating_system) + if cleaned_os: + summary["operating_system"] = cleaned_os + uptime = metrics.get("uptime") + if uptime is not None: + coerced = coerce_int(uptime) + if coerced is not None: + summary["uptime"] = coerced + agent_id = metrics.get("agent_id") + if agent_id: + cleaned_agent = clean_device_str(agent_id) + if cleaned_agent: + summary["agent_id"] = cleaned_agent + + for field in ("external_ip", "internal_ip", "device_type"): + value = payload.get(field) + cleaned = clean_device_str(value) + if cleaned: + summary[field] = cleaned + + summary.setdefault("description", summary.get("description") or "") + created_at = coerce_int(summary.get("created_at")) + if created_at is None: + created_at = coerce_int(snapshot.get("created_at")) + if created_at is None: + created_at = now_ts + summary["created_at"] = created_at + + raw_inventory = payload.get("inventory") + inventory = raw_inventory if isinstance(raw_inventory, Mapping) else {} + memory = inventory.get("memory") if isinstance(inventory.get("memory"), list) else details.get("memory") + network = inventory.get("network") if isinstance(inventory.get("network"), list) else details.get("network") + software = ( + inventory.get("software") if isinstance(inventory.get("software"), list) else details.get("software") + ) + storage = inventory.get("storage") if isinstance(inventory.get("storage"), list) else details.get("storage") + cpu = inventory.get("cpu") if isinstance(inventory.get("cpu"), Mapping) else details.get("cpu") + + merged_details: Dict[str, Any] = { + "summary": summary, + "memory": memory, + "network": network, + "software": software, + "storage": storage, + "cpu": cpu, + } + + try: + self._repo.upsert_device( + summary["hostname"], + summary.get("description"), + merged_details, + summary.get("created_at"), + agent_hash=clean_device_str(summary.get("agent_hash")), + guid=guid, + ) + except sqlite3.IntegrityError as exc: + self._log.warning( + "device-heartbeat-conflict guid=%s hostname=%s error=%s", + guid, + summary["hostname"], + exc, + ) + raise DeviceHeartbeatError("storage_conflict", str(exc)) from exc + except Exception as exc: # pragma: no cover - defensive + self._log.exception( + "device-heartbeat-failure guid=%s hostname=%s", + guid, + summary["hostname"], + exc_info=exc, + ) + raise DeviceHeartbeatError("storage_error", "failed to persist heartbeat") from exc + diff --git a/Data/Engine/tests/test_http_agent.py b/Data/Engine/tests/test_http_agent.py new file mode 100644 index 0000000..8ca499e --- /dev/null +++ b/Data/Engine/tests/test_http_agent.py @@ -0,0 +1,234 @@ +import pytest + +pytest.importorskip("jwt") + +import json +import sqlite3 +import time +from datetime import datetime, timezone +from pathlib import Path + +from Data.Engine.config.environment import ( + DatabaseSettings, + EngineSettings, + FlaskSettings, + GitHubSettings, + ServerSettings, + SocketIOSettings, +) +from Data.Engine.domain.device_auth import ( + AccessTokenClaims, + DeviceAuthContext, + DeviceFingerprint, + DeviceGuid, + DeviceIdentity, + DeviceStatus, +) +from Data.Engine.interfaces.http import register_http_interfaces +from Data.Engine.repositories.sqlite import connection as sqlite_connection +from Data.Engine.repositories.sqlite import migrations as sqlite_migrations +from Data.Engine.server import create_app +from Data.Engine.services.container import build_service_container + + +@pytest.fixture() +def engine_settings(tmp_path: Path) -> EngineSettings: + project_root = tmp_path + static_root = project_root / "static" + static_root.mkdir() + (static_root / "index.html").write_text("", encoding="utf-8") + + database_path = project_root / "database.db" + + return EngineSettings( + project_root=project_root, + debug=False, + database=DatabaseSettings(path=database_path, apply_migrations=False), + flask=FlaskSettings( + secret_key="test-key", + static_root=static_root, + cors_allowed_origins=("https://localhost",), + ), + socketio=SocketIOSettings(cors_allowed_origins=("https://localhost",)), + server=ServerSettings(host="127.0.0.1", port=5000), + github=GitHubSettings( + default_repo="owner/repo", + default_branch="main", + refresh_interval_seconds=60, + cache_root=project_root / "cache", + ), + ) + + +@pytest.fixture() +def prepared_app(engine_settings: EngineSettings): + settings = engine_settings + settings.github.cache_root.mkdir(exist_ok=True, parents=True) + + db_factory = sqlite_connection.connection_factory(settings.database.path) + with sqlite_connection.connection_scope(settings.database.path) as conn: + sqlite_migrations.apply_all(conn) + + app = create_app(settings, db_factory=db_factory) + services = build_service_container(settings, db_factory=db_factory) + app.extensions["engine_services"] = services + register_http_interfaces(app, services) + app.config.update(TESTING=True) + return app + + +def _insert_device(app, guid: str, fingerprint: str, hostname: str) -> None: + db_path = Path(app.config["ENGINE_DATABASE_PATH"]) + now = int(time.time()) + with sqlite3.connect(db_path) as conn: + conn.execute( + """ + INSERT INTO devices ( + guid, + hostname, + created_at, + last_seen, + ssl_key_fingerprint, + token_version, + status, + key_added_at + ) VALUES (?, ?, ?, ?, ?, ?, 'active', ?) + """, + ( + guid, + hostname, + now, + now, + fingerprint.lower(), + 1, + datetime.now(timezone.utc).isoformat(), + ), + ) + conn.commit() + + +def _build_context(guid: str, fingerprint: str, *, status: DeviceStatus = DeviceStatus.ACTIVE) -> DeviceAuthContext: + now = int(time.time()) + claims = AccessTokenClaims( + subject="device", + guid=DeviceGuid(guid), + fingerprint=DeviceFingerprint(fingerprint), + token_version=1, + issued_at=now, + not_before=now, + expires_at=now + 600, + raw={"sub": "device"}, + ) + identity = DeviceIdentity(DeviceGuid(guid), DeviceFingerprint(fingerprint)) + return DeviceAuthContext( + identity=identity, + access_token="token", + claims=claims, + status=status, + service_context="SYSTEM", + ) + + +def test_heartbeat_updates_device(prepared_app, monkeypatch): + client = prepared_app.test_client() + guid = "DE305D54-75B4-431B-ADB2-EB6B9E546014" + fingerprint = "aa:bb:cc" + hostname = "device-heartbeat" + _insert_device(prepared_app, guid, fingerprint, hostname) + + services = prepared_app.extensions["engine_services"] + context = _build_context(guid, fingerprint) + monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context) + + payload = { + "hostname": hostname, + "inventory": {"memory": [{"total": "16GB"}], "cpu": {"cores": 8}}, + "metrics": {"operating_system": "Windows", "last_user": "Admin", "uptime": 120}, + "external_ip": "1.2.3.4", + } + + start = int(time.time()) + resp = client.post( + "/api/agent/heartbeat", + json=payload, + headers={"Authorization": "Bearer token"}, + ) + assert resp.status_code == 200 + body = resp.get_json() + assert body == {"status": "ok", "poll_after_ms": 15000} + + db_path = Path(prepared_app.config["ENGINE_DATABASE_PATH"]) + with sqlite3.connect(db_path) as conn: + row = conn.execute( + "SELECT last_seen, external_ip, memory, cpu FROM devices WHERE guid = ?", + (guid,), + ).fetchone() + + assert row is not None + last_seen, external_ip, memory_json, cpu_json = row + assert last_seen >= start + assert external_ip == "1.2.3.4" + assert json.loads(memory_json)[0]["total"] == "16GB" + assert json.loads(cpu_json)["cores"] == 8 + + +def test_heartbeat_returns_404_when_device_missing(prepared_app, monkeypatch): + client = prepared_app.test_client() + guid = "9E295C27-8339-40C8-AD1A-6ED95C164A4A" + fingerprint = "11:22:33" + services = prepared_app.extensions["engine_services"] + context = _build_context(guid, fingerprint) + monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context) + + resp = client.post( + "/api/agent/heartbeat", + json={"hostname": "missing-device"}, + headers={"Authorization": "Bearer token"}, + ) + assert resp.status_code == 404 + assert resp.get_json() == {"error": "device_not_registered"} + + +def test_script_request_reports_status_and_signing_key(prepared_app, monkeypatch): + client = prepared_app.test_client() + guid = "2F8D76C0-38D4-4700-B247-3E90C03A67D7" + fingerprint = "44:55:66" + hostname = "device-script" + _insert_device(prepared_app, guid, fingerprint, hostname) + + services = prepared_app.extensions["engine_services"] + context = _build_context(guid, fingerprint) + monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context) + + class DummySigner: + def public_base64_spki(self) -> str: + return "PUBKEY" + + object.__setattr__(services, "script_signer", DummySigner()) + + resp = client.post( + "/api/agent/script/request", + json={"guid": guid}, + headers={"Authorization": "Bearer token"}, + ) + assert resp.status_code == 200 + body = resp.get_json() + assert body == { + "status": "idle", + "poll_after_ms": 30000, + "sig_alg": "ed25519", + "signing_key": "PUBKEY", + } + + quarantined_context = _build_context(guid, fingerprint, status=DeviceStatus.QUARANTINED) + monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: quarantined_context) + + resp = client.post( + "/api/agent/script/request", + json={}, + headers={"Authorization": "Bearer token"}, + ) + assert resp.status_code == 200 + assert resp.get_json()["status"] == "quarantined" + assert resp.get_json()["poll_after_ms"] == 60000 +