Add agent REST endpoints and heartbeat handling

This commit is contained in:
2025-10-23 01:01:15 -06:00
parent 82210408ca
commit fddf0230e2
5 changed files with 480 additions and 3 deletions

View File

@@ -8,6 +8,7 @@ from Data.Engine.services.container import EngineServiceContainer
from . import (
admin,
agent,
agents,
auth,
enrollment,
@@ -25,6 +26,7 @@ from . import (
_REGISTRARS = (
health.register,
agent.register,
agents.register,
enrollment.register,
tokens.register,

View File

@@ -0,0 +1,113 @@
"""Agent REST endpoints for device communication."""
from __future__ import annotations
import math
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, cast
from flask import Blueprint, Flask, current_app, g, jsonify, request
from Data.Engine.builders.device_auth import DeviceAuthRequestBuilder
from Data.Engine.domain.device_auth import DeviceAuthContext, DeviceAuthFailure
from Data.Engine.services.container import EngineServiceContainer
from Data.Engine.services.devices.device_inventory_service import DeviceHeartbeatError
# Request header whose value is forwarded to the device-auth request builder
# as the service context.
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
# Blueprint on which the agent routes in this module are declared.
blueprint = Blueprint("engine_agent", __name__)
# Type variable that lets decorators return the same callable type they receive.
F = TypeVar("F", bound=Callable[..., Any])
def _services() -> EngineServiceContainer:
    """Return the service container stored on the current Flask application."""
    container = current_app.extensions["engine_services"]
    return cast(EngineServiceContainer, container)
def require_device_auth(func: F) -> F:
    """Wrap a view so it only runs after the device has been authenticated.

    The wrapper builds a device-auth request from the incoming HTTP request
    (Authorization, method, URL, agent-context header and DPoP proof) and
    passes it to the ``device_auth`` service. When authentication fails, the
    failure is returned as a JSON response with the failure's HTTP status,
    plus a ``Retry-After`` header if the failure carries one. When it
    succeeds, the resulting context is exposed as ``g.device_auth`` for the
    duration of the wrapped call and removed afterwards.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any):
        services = _services()

        # Assemble the auth request step by step from the incoming request.
        auth_builder = DeviceAuthRequestBuilder()
        auth_builder = auth_builder.with_authorization(request.headers.get("Authorization"))
        auth_builder = auth_builder.with_http_method(request.method)
        auth_builder = auth_builder.with_htu(request.url)
        auth_builder = auth_builder.with_service_context(
            request.headers.get(AGENT_CONTEXT_HEADER)
        )
        auth_builder = auth_builder.with_dpop_proof(request.headers.get("DPoP"))

        try:
            context = services.device_auth.authenticate(
                auth_builder.build(), path=request.path
            )
        except DeviceAuthFailure as failure:
            response = jsonify(failure.to_dict())
            retry_after = failure.retry_after
            if retry_after is not None:
                # Round up to whole seconds for the header value.
                response.headers["Retry-After"] = str(int(math.ceil(retry_after)))
            return response, failure.http_status

        g.device_auth = context
        try:
            return func(*args, **kwargs)
        finally:
            # Always drop the context, even when the view raises.
            g.pop("device_auth", None)

    return cast(F, wrapper)
def register(app: Flask, _services: EngineServiceContainer) -> None:
    """Attach the agent blueprint to ``app`` unless it is already registered.

    The ``_services`` argument is accepted but not used here.
    """
    if "engine_agent" in app.blueprints:
        return
    app.register_blueprint(blueprint)
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
@require_device_auth
def heartbeat() -> Any:
    """Record a heartbeat for the authenticated device.

    The request body is parsed as JSON regardless of content type. A body
    that is missing, unparsable, or not a JSON object is treated as an empty
    payload.

    Returns:
        ``{"status": "ok", "poll_after_ms": 15000}`` on success. When the
        inventory service raises ``DeviceHeartbeatError`` the response is
        ``{"error": <code>}`` with status 404 for ``device_not_registered``,
        409 for ``storage_conflict`` and 500 for any other code.
    """
    services = _services()

    payload = request.get_json(force=True, silent=True)
    if not isinstance(payload, dict):
        # get_json() can return any JSON type (list, string, number, ...).
        # The inventory service reads the payload with ``.get``, so a
        # non-object body must not be passed through as-is.
        payload = {}

    context = cast(DeviceAuthContext, g.device_auth)
    try:
        services.device_inventory.record_heartbeat(context=context, payload=payload)
    except DeviceHeartbeatError as exc:
        error_payload = {"error": exc.code}
        if exc.code == "device_not_registered":
            return jsonify(error_payload), 404
        if exc.code == "storage_conflict":
            return jsonify(error_payload), 409
        # Any other code is unexpected: log with traceback and report 500.
        current_app.logger.exception(
            "device-heartbeat-error guid=%s code=%s", context.identity.guid.value, exc.code
        )
        return jsonify(error_payload), 500
    return jsonify({"status": "ok", "poll_after_ms": 15000})
@blueprint.route("/api/agent/script/request", methods=["POST"])
@require_device_auth
def script_request() -> Any:
    """Report the polling status for the authenticated device.

    The response carries ``status`` (``"quarantined"`` or ``"idle"``),
    ``poll_after_ms`` (60000 when quarantined, otherwise 30000) and
    ``sig_alg``. ``signing_key`` is included only when a script signer is
    configured and yields a non-empty public key.
    """
    services = _services()
    context = cast(DeviceAuthContext, g.device_auth)

    signing_key: Optional[str] = None
    signer = services.script_signer
    if signer is not None:
        try:
            signing_key = signer.public_base64_spki()
        except Exception as exc:  # pragma: no cover - defensive logging
            # Best effort: a signer failure must not fail the request.
            current_app.logger.warning("script-signer-unavailable: %s", exc)

    if context.is_quarantined:
        status, poll_after = "quarantined", 60000
    else:
        status, poll_after = "idle", 30000

    body = {
        "status": status,
        "poll_after_ms": poll_after,
        "sig_alg": "ed25519",
    }
    if signing_key:
        body["signing_key"] = signing_key
    return jsonify(body)
# Public names of this module: the registrar, the blueprint, both views and
# the device-auth decorator.
__all__ = ["register", "blueprint", "heartbeat", "script_request", "require_device_auth"]

View File

@@ -67,6 +67,7 @@ class EngineServiceContainer:
operator_auth_service: OperatorAuthService
operator_account_service: OperatorAccountService
assembly_service: AssemblyService
script_signer: Optional[ScriptSigner]
def build_service_container(
@@ -106,6 +107,8 @@ def build_service_container(
logger=log.getChild("token_service"),
)
script_signer = _load_script_signer(log)
enrollment_service = EnrollmentService(
device_repository=device_repo,
enrollment_repository=enrollment_repo,
@@ -115,7 +118,7 @@ def build_service_container(
ip_rate_limiter=SlidingWindowRateLimiter(),
fingerprint_rate_limiter=SlidingWindowRateLimiter(),
nonce_cache=NonceCache(),
script_signer=_load_script_signer(log),
script_signer=script_signer,
logger=log.getChild("enrollment"),
)
@@ -205,6 +208,7 @@ def build_service_container(
credential_service=credential_service,
site_service=site_service,
assembly_service=assembly_service,
script_signer=script_signer,
)

View File

@@ -4,13 +4,17 @@ from __future__ import annotations
import logging
import sqlite3
from typing import Dict, List, Optional
import time
from collections.abc import Mapping
from typing import Any, Dict, List, Optional
from Data.Engine.repositories.sqlite.device_inventory_repository import (
SQLiteDeviceInventoryRepository,
)
from Data.Engine.domain.device_auth import DeviceAuthContext
from Data.Engine.domain.devices import clean_device_str, coerce_int
__all__ = ["DeviceInventoryService", "RemoteDeviceError"]
__all__ = ["DeviceInventoryService", "RemoteDeviceError", "DeviceHeartbeatError"]
class RemoteDeviceError(Exception):
@@ -19,6 +23,12 @@ class RemoteDeviceError(Exception):
self.code = code
class DeviceHeartbeatError(Exception):
    """Error raised when a device heartbeat cannot be recorded.

    ``code`` is kept on the instance; the exception text is ``message`` when
    one is given and non-empty, otherwise the code itself.
    """

    def __init__(self, code: str, message: Optional[str] = None) -> None:
        text = message if message else code
        super().__init__(text)
        self.code = code
class DeviceInventoryService:
def __init__(
self,
@@ -176,3 +186,117 @@ class DeviceInventoryService:
raise RemoteDeviceError("not_found", "device not found")
self._repo.delete_device_by_hostname(normalized_host)
# ------------------------------------------------------------------
# Agent heartbeats
# ------------------------------------------------------------------
def record_heartbeat(
    self,
    *,
    context: DeviceAuthContext,
    payload: Mapping[str, Any],
) -> None:
    """Merge an agent heartbeat into the stored device record.

    Loads the snapshot for the authenticated GUID, refreshes ``last_seen``,
    overlays any hostname, metrics, address and inventory values found in
    ``payload``, and writes the result back through the repository.

    Args:
        context: Authenticated device context; only
            ``context.identity.guid.value`` is read here.
        payload: Decoded heartbeat body. Anything that is not a mapping is
            treated as an empty payload.

    Raises:
        DeviceHeartbeatError: with code ``device_not_registered`` when no
            snapshot exists for the GUID, ``storage_conflict`` when the
            repository raises ``sqlite3.IntegrityError``, and
            ``storage_error`` for any other failure while persisting.
    """
    guid = context.identity.guid.value
    snapshot = self._repo.load_snapshot(guid=guid)
    if not snapshot:
        raise DeviceHeartbeatError("device_not_registered", "device not registered")

    # Callers may hand over decoded JSON that is a list or scalar rather
    # than an object; the ``.get`` calls below would raise AttributeError
    # on such values, so fall back to an empty payload.
    if not isinstance(payload, Mapping):
        payload = {}

    summary = dict(snapshot.get("summary") or {})
    details = dict(snapshot.get("details") or {})

    now_ts = int(time.time())
    summary["last_seen"] = now_ts
    summary["agent_guid"] = guid

    existing_hostname = clean_device_str(summary.get("hostname")) or clean_device_str(
        snapshot.get("hostname")
    )
    incoming_hostname = clean_device_str(payload.get("hostname"))
    raw_metrics = payload.get("metrics")
    metrics = raw_metrics if isinstance(raw_metrics, Mapping) else {}
    metrics_hostname = clean_device_str(metrics.get("hostname")) if metrics else None

    # Preference order: payload hostname, metrics hostname, stored hostname,
    # and finally a placeholder derived from the GUID.
    hostname = incoming_hostname or metrics_hostname or existing_hostname
    if not hostname:
        hostname = f"RECOVERED-{guid[:12]}"
    summary["hostname"] = hostname

    if metrics:
        last_user = metrics.get("last_user") or metrics.get("username")
        if last_user:
            cleaned_user = clean_device_str(last_user)
            if cleaned_user:
                summary["last_user"] = cleaned_user

        operating_system = metrics.get("operating_system")
        if operating_system:
            cleaned_os = clean_device_str(operating_system)
            if cleaned_os:
                summary["operating_system"] = cleaned_os

        uptime = metrics.get("uptime")
        if uptime is not None:
            coerced = coerce_int(uptime)
            if coerced is not None:
                summary["uptime"] = coerced

        agent_id = metrics.get("agent_id")
        if agent_id:
            cleaned_agent = clean_device_str(agent_id)
            if cleaned_agent:
                summary["agent_id"] = cleaned_agent

    # These fields are read from the top level of the payload, not metrics.
    for field in ("external_ip", "internal_ip", "device_type"):
        cleaned = clean_device_str(payload.get(field))
        if cleaned:
            summary[field] = cleaned

    # Only fills in a missing key; an existing value is left untouched.
    summary.setdefault("description", "")

    # created_at: stored summary first, then the snapshot, then "now".
    created_at = coerce_int(summary.get("created_at"))
    if created_at is None:
        created_at = coerce_int(snapshot.get("created_at"))
    if created_at is None:
        created_at = now_ts
    summary["created_at"] = created_at

    raw_inventory = payload.get("inventory")
    inventory = raw_inventory if isinstance(raw_inventory, Mapping) else {}

    # An inventory section replaces the stored one only when it has the
    # expected container type; otherwise the stored value is carried over.
    merged_details: Dict[str, Any] = {"summary": summary}
    for section in ("memory", "network", "software", "storage"):
        candidate = inventory.get(section)
        merged_details[section] = (
            candidate if isinstance(candidate, list) else details.get(section)
        )
    cpu_candidate = inventory.get("cpu")
    merged_details["cpu"] = (
        cpu_candidate if isinstance(cpu_candidate, Mapping) else details.get("cpu")
    )

    try:
        self._repo.upsert_device(
            summary["hostname"],
            summary.get("description"),
            merged_details,
            summary.get("created_at"),
            agent_hash=clean_device_str(summary.get("agent_hash")),
            guid=guid,
        )
    except sqlite3.IntegrityError as exc:
        self._log.warning(
            "device-heartbeat-conflict guid=%s hostname=%s error=%s",
            guid,
            summary["hostname"],
            exc,
        )
        raise DeviceHeartbeatError("storage_conflict", str(exc)) from exc
    except Exception as exc:  # pragma: no cover - defensive
        self._log.exception(
            "device-heartbeat-failure guid=%s hostname=%s",
            guid,
            summary["hostname"],
            exc_info=exc,
        )
        raise DeviceHeartbeatError("storage_error", "failed to persist heartbeat") from exc

View File

@@ -0,0 +1,234 @@
import pytest
pytest.importorskip("jwt")
import json
import sqlite3
import time
from datetime import datetime, timezone
from pathlib import Path
from Data.Engine.config.environment import (
DatabaseSettings,
EngineSettings,
FlaskSettings,
GitHubSettings,
ServerSettings,
SocketIOSettings,
)
from Data.Engine.domain.device_auth import (
AccessTokenClaims,
DeviceAuthContext,
DeviceFingerprint,
DeviceGuid,
DeviceIdentity,
DeviceStatus,
)
from Data.Engine.interfaces.http import register_http_interfaces
from Data.Engine.repositories.sqlite import connection as sqlite_connection
from Data.Engine.repositories.sqlite import migrations as sqlite_migrations
from Data.Engine.server import create_app
from Data.Engine.services.container import build_service_container
@pytest.fixture()
def engine_settings(tmp_path: Path) -> EngineSettings:
    """Build ``EngineSettings`` rooted in pytest's temporary directory.

    Creates a ``static`` directory with a stub ``index.html`` and points the
    database at ``database.db`` and the GitHub cache at ``cache`` under the
    same root. Migrations are not applied by the settings themselves.
    """
    static_dir = tmp_path / "static"
    static_dir.mkdir()
    index_page = static_dir / "index.html"
    index_page.write_text("<html></html>", encoding="utf-8")

    database = DatabaseSettings(path=tmp_path / "database.db", apply_migrations=False)
    flask_settings = FlaskSettings(
        secret_key="test-key",
        static_root=static_dir,
        cors_allowed_origins=("https://localhost",),
    )
    socketio_settings = SocketIOSettings(cors_allowed_origins=("https://localhost",))
    server_settings = ServerSettings(host="127.0.0.1", port=5000)
    github_settings = GitHubSettings(
        default_repo="owner/repo",
        default_branch="main",
        refresh_interval_seconds=60,
        cache_root=tmp_path / "cache",
    )

    return EngineSettings(
        project_root=tmp_path,
        debug=False,
        database=database,
        flask=flask_settings,
        socketio=socketio_settings,
        server=server_settings,
        github=github_settings,
    )
@pytest.fixture()
def prepared_app(engine_settings: EngineSettings):
    """Return a Flask app with a migrated database and HTTP interfaces registered.

    The service container is stored under ``app.extensions["engine_services"]``
    and the app is switched to testing mode.
    """
    engine_settings.github.cache_root.mkdir(exist_ok=True, parents=True)

    database_path = engine_settings.database.path
    db_factory = sqlite_connection.connection_factory(database_path)
    with sqlite_connection.connection_scope(database_path) as conn:
        sqlite_migrations.apply_all(conn)

    app = create_app(engine_settings, db_factory=db_factory)
    container = build_service_container(engine_settings, db_factory=db_factory)
    app.extensions["engine_services"] = container
    register_http_interfaces(app, container)
    app.config.update(TESTING=True)
    return app
def _insert_device(app, guid: str, fingerprint: str, hostname: str) -> None:
db_path = Path(app.config["ENGINE_DATABASE_PATH"])
now = int(time.time())
with sqlite3.connect(db_path) as conn:
conn.execute(
"""
INSERT INTO devices (
guid,
hostname,
created_at,
last_seen,
ssl_key_fingerprint,
token_version,
status,
key_added_at
) VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
""",
(
guid,
hostname,
now,
now,
fingerprint.lower(),
1,
datetime.now(timezone.utc).isoformat(),
),
)
conn.commit()
def _build_context(guid: str, fingerprint: str, *, status: DeviceStatus = DeviceStatus.ACTIVE) -> DeviceAuthContext:
    """Build a ``DeviceAuthContext`` for tests.

    The claims are issued at the current time and expire 600 seconds later;
    the access token is the literal ``"token"`` and the service context is
    ``"SYSTEM"``.
    """
    issued = int(time.time())
    token_claims = AccessTokenClaims(
        subject="device",
        guid=DeviceGuid(guid),
        fingerprint=DeviceFingerprint(fingerprint),
        token_version=1,
        issued_at=issued,
        not_before=issued,
        expires_at=issued + 600,
        raw={"sub": "device"},
    )
    device_identity = DeviceIdentity(DeviceGuid(guid), DeviceFingerprint(fingerprint))
    context = DeviceAuthContext(
        identity=device_identity,
        access_token="token",
        claims=token_claims,
        status=status,
        service_context="SYSTEM",
    )
    return context
def test_heartbeat_updates_device(prepared_app, monkeypatch):
    """A heartbeat for a known device returns 200 and updates its stored row.

    Device authentication is replaced with a stub returning a prepared
    context, so only the heartbeat handling itself is exercised.
    """
    client = prepared_app.test_client()
    guid = "DE305D54-75B4-431B-ADB2-EB6B9E546014"
    fingerprint = "aa:bb:cc"
    hostname = "device-heartbeat"
    _insert_device(prepared_app, guid, fingerprint, hostname)

    services = prepared_app.extensions["engine_services"]
    context = _build_context(guid, fingerprint)
    monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context)

    payload = {
        "hostname": hostname,
        "inventory": {"memory": [{"total": "16GB"}], "cpu": {"cores": 8}},
        "metrics": {"operating_system": "Windows", "last_user": "Admin", "uptime": 120},
        "external_ip": "1.2.3.4",
    }

    # Captured before the request so last_seen can be checked against it.
    start = int(time.time())
    resp = client.post(
        "/api/agent/heartbeat",
        json=payload,
        headers={"Authorization": "Bearer token"},
    )
    assert resp.status_code == 200
    body = resp.get_json()
    assert body == {"status": "ok", "poll_after_ms": 15000}

    db_path = Path(prepared_app.config["ENGINE_DATABASE_PATH"])
    # sqlite3's connection context manager does not close the connection,
    # so close it explicitly once the row has been read.
    conn = sqlite3.connect(db_path)
    try:
        row = conn.execute(
            "SELECT last_seen, external_ip, memory, cpu FROM devices WHERE guid = ?",
            (guid,),
        ).fetchone()
    finally:
        conn.close()

    assert row is not None
    last_seen, external_ip, memory_json, cpu_json = row
    assert last_seen >= start
    assert external_ip == "1.2.3.4"
    assert json.loads(memory_json)[0]["total"] == "16GB"
    assert json.loads(cpu_json)["cores"] == 8
def test_heartbeat_returns_404_when_device_missing(prepared_app, monkeypatch):
    """A heartbeat for a GUID with no device row is answered with 404."""
    client = prepared_app.test_client()
    unknown_guid = "9E295C27-8339-40C8-AD1A-6ED95C164A4A"
    engine_services = prepared_app.extensions["engine_services"]
    auth_context = _build_context(unknown_guid, "11:22:33")

    def fake_authenticate(request, path):
        # Stand-in for real device authentication.
        return auth_context

    monkeypatch.setattr(engine_services.device_auth, "authenticate", fake_authenticate)

    response = client.post(
        "/api/agent/heartbeat",
        json={"hostname": "missing-device"},
        headers={"Authorization": "Bearer token"},
    )

    assert response.status_code == 404
    assert response.get_json() == {"error": "device_not_registered"}
def test_script_request_reports_status_and_signing_key(prepared_app, monkeypatch):
    """The script-request endpoint reports status, poll interval and signing key.

    An active device gets ``idle`` with a 30000 ms poll interval and the
    signer's public key; a quarantined device gets ``quarantined`` with a
    60000 ms interval.
    """
    client = prepared_app.test_client()
    guid = "2F8D76C0-38D4-4700-B247-3E90C03A67D7"
    fingerprint = "44:55:66"
    _insert_device(prepared_app, guid, fingerprint, "device-script")

    services = prepared_app.extensions["engine_services"]
    active_context = _build_context(guid, fingerprint)
    monkeypatch.setattr(
        services.device_auth, "authenticate", lambda request, path: active_context
    )

    class DummySigner:
        def public_base64_spki(self) -> str:
            return "PUBKEY"

    # Set the attribute through object.__setattr__ rather than plain assignment.
    object.__setattr__(services, "script_signer", DummySigner())

    auth_headers = {"Authorization": "Bearer token"}
    idle_response = client.post(
        "/api/agent/script/request",
        json={"guid": guid},
        headers=auth_headers,
    )
    assert idle_response.status_code == 200
    expected_idle = {
        "status": "idle",
        "poll_after_ms": 30000,
        "sig_alg": "ed25519",
        "signing_key": "PUBKEY",
    }
    assert idle_response.get_json() == expected_idle

    quarantined_context = _build_context(guid, fingerprint, status=DeviceStatus.QUARANTINED)
    monkeypatch.setattr(
        services.device_auth, "authenticate", lambda request, path: quarantined_context
    )

    quarantined_response = client.post(
        "/api/agent/script/request",
        json={},
        headers=auth_headers,
    )
    assert quarantined_response.status_code == 200
    quarantined_body = quarantined_response.get_json()
    assert quarantined_body["status"] == "quarantined"
    assert quarantined_body["poll_after_ms"] == 60000