Add agent REST endpoints and heartbeat handling

This commit is contained in:
2025-10-23 01:01:15 -06:00
parent 82210408ca
commit fddf0230e2
5 changed files with 480 additions and 3 deletions

View File

@@ -8,6 +8,7 @@ from Data.Engine.services.container import EngineServiceContainer
from . import ( from . import (
admin, admin,
agent,
agents, agents,
auth, auth,
enrollment, enrollment,
@@ -25,6 +26,7 @@ from . import (
_REGISTRARS = ( _REGISTRARS = (
health.register, health.register,
agent.register,
agents.register, agents.register,
enrollment.register, enrollment.register,
tokens.register, tokens.register,

View File

@@ -0,0 +1,113 @@
"""Agent REST endpoints for device communication."""
from __future__ import annotations
import math
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, cast
from flask import Blueprint, Flask, current_app, g, jsonify, request
from Data.Engine.builders.device_auth import DeviceAuthRequestBuilder
from Data.Engine.domain.device_auth import DeviceAuthContext, DeviceAuthFailure
from Data.Engine.services.container import EngineServiceContainer
from Data.Engine.services.devices.device_inventory_service import DeviceHeartbeatError
# Request header whose value is forwarded to the device-auth request builder
# as the agent's service context.
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
# Flask blueprint that carries the agent-facing routes defined in this module.
blueprint = Blueprint("engine_agent", __name__)
# Type variable used so decorators in this module return the callable type
# they were given.
F = TypeVar("F", bound=Callable[..., Any])
def _services() -> EngineServiceContainer:
    """Return the service container stored on the current Flask application.

    The container is read from ``current_app.extensions["engine_services"]``;
    a missing entry surfaces as ``KeyError``.
    """
    container = current_app.extensions["engine_services"]
    return cast(EngineServiceContainer, container)
def _auth_failure_response(failure: DeviceAuthFailure):
    """Translate a ``DeviceAuthFailure`` into a ``(response, status)`` pair."""
    response = jsonify(failure.to_dict())
    retry_after = failure.retry_after
    if retry_after is not None:
        # Retry-After is sent as whole seconds, rounded up.
        response.headers["Retry-After"] = str(int(math.ceil(retry_after)))
    return response, failure.http_status


def require_device_auth(func: F) -> F:
    """Decorate a Flask view so it only runs for an authenticated device.

    The wrapper assembles a device-auth request from the incoming HTTP request
    (``Authorization``, ``DPoP`` and agent-context headers, plus the method and
    URL) and passes it to ``services.device_auth.authenticate`` together with
    the request path.

    * On success the returned context is published as ``g.device_auth`` while
      the wrapped view runs and is removed again afterwards, even if the view
      raises.
    * On ``DeviceAuthFailure`` the view is not called; the failure's own
      payload and HTTP status are returned, with a ``Retry-After`` header when
      the failure carries a retry hint.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any):
        services = _services()
        headers = request.headers

        builder = DeviceAuthRequestBuilder()
        builder = builder.with_authorization(headers.get("Authorization"))
        builder = builder.with_http_method(request.method)
        builder = builder.with_htu(request.url)
        builder = builder.with_service_context(headers.get(AGENT_CONTEXT_HEADER))
        builder = builder.with_dpop_proof(headers.get("DPoP"))

        try:
            auth_request = builder.build()
            context = services.device_auth.authenticate(auth_request, path=request.path)
        except DeviceAuthFailure as failure:
            return _auth_failure_response(failure)

        g.device_auth = context
        try:
            return func(*args, **kwargs)
        finally:
            # Never leave the auth context behind on the request globals.
            g.pop("device_auth", None)

    return cast(F, wrapper)
def register(app: Flask, _services: EngineServiceContainer) -> None:
    """Attach the agent blueprint to *app* unless it is already registered.

    The ``_services`` argument is accepted but not used by this function.
    """
    if "engine_agent" in app.blueprints:
        return
    app.register_blueprint(blueprint)
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
@require_device_auth
def heartbeat() -> Any:
    """Record a heartbeat for the authenticated device.

    The JSON body is parsed leniently: a missing, malformed, or non-object
    body is treated as an empty payload so the heartbeat still refreshes the
    device record.

    Returns:
        ``{"status": "ok", "poll_after_ms": 15000}`` on success.  When the
        inventory service raises ``DeviceHeartbeatError`` the body is
        ``{"error": <code>}`` with status 404 for ``device_not_registered``,
        409 for ``storage_conflict`` and 500 (logged) for any other code.
    """
    services = _services()

    payload = request.get_json(force=True, silent=True)
    if not isinstance(payload, dict):
        # ``silent=True`` yields None for an absent/invalid body, and valid
        # JSON may still be a list, string or number.  The inventory service
        # calls ``payload.get``, so anything but an object would otherwise
        # escape as an unhandled AttributeError.
        payload = {}

    context = cast(DeviceAuthContext, g.device_auth)

    try:
        services.device_inventory.record_heartbeat(context=context, payload=payload)
    except DeviceHeartbeatError as exc:
        error_payload = {"error": exc.code}
        expected_statuses = {"device_not_registered": 404, "storage_conflict": 409}
        status = expected_statuses.get(exc.code)
        if status is None:
            # Only unexpected codes are logged with a traceback.
            current_app.logger.exception(
                "device-heartbeat-error guid=%s code=%s", context.identity.guid.value, exc.code
            )
            status = 500
        return jsonify(error_payload), status

    return jsonify({"status": "ok", "poll_after_ms": 15000})
@blueprint.route("/api/agent/script/request", methods=["POST"])
@require_device_auth
def script_request() -> Any:
    """Report the device's script status and polling interval.

    The response always carries ``status``, ``poll_after_ms`` and ``sig_alg``.
    A quarantined device gets ``"quarantined"`` with a 60000 ms interval, any
    other device gets ``"idle"`` with 30000 ms.  ``signing_key`` is added only
    when a script signer is configured and yields a non-empty public key; a
    signer failure is logged as a warning and otherwise ignored.
    """
    services = _services()
    context = cast(DeviceAuthContext, g.device_auth)

    signer = services.script_signer
    signing_key: Optional[str] = None
    if signer is not None:
        try:
            signing_key = signer.public_base64_spki()
        except Exception as exc:  # pragma: no cover - defensive logging
            current_app.logger.warning("script-signer-unavailable: %s", exc)

    if context.is_quarantined:
        status, poll_after = "quarantined", 60000
    else:
        status, poll_after = "idle", 30000

    response = {"status": status, "poll_after_ms": poll_after, "sig_alg": "ed25519"}
    if signing_key:
        response["signing_key"] = signing_key

    return jsonify(response)
# Names exported by ``from <this module> import *``.
__all__ = ["register", "blueprint", "heartbeat", "script_request", "require_device_auth"]

View File

@@ -67,6 +67,7 @@ class EngineServiceContainer:
operator_auth_service: OperatorAuthService operator_auth_service: OperatorAuthService
operator_account_service: OperatorAccountService operator_account_service: OperatorAccountService
assembly_service: AssemblyService assembly_service: AssemblyService
script_signer: Optional[ScriptSigner]
def build_service_container( def build_service_container(
@@ -106,6 +107,8 @@ def build_service_container(
logger=log.getChild("token_service"), logger=log.getChild("token_service"),
) )
script_signer = _load_script_signer(log)
enrollment_service = EnrollmentService( enrollment_service = EnrollmentService(
device_repository=device_repo, device_repository=device_repo,
enrollment_repository=enrollment_repo, enrollment_repository=enrollment_repo,
@@ -115,7 +118,7 @@ def build_service_container(
ip_rate_limiter=SlidingWindowRateLimiter(), ip_rate_limiter=SlidingWindowRateLimiter(),
fingerprint_rate_limiter=SlidingWindowRateLimiter(), fingerprint_rate_limiter=SlidingWindowRateLimiter(),
nonce_cache=NonceCache(), nonce_cache=NonceCache(),
script_signer=_load_script_signer(log), script_signer=script_signer,
logger=log.getChild("enrollment"), logger=log.getChild("enrollment"),
) )
@@ -205,6 +208,7 @@ def build_service_container(
credential_service=credential_service, credential_service=credential_service,
site_service=site_service, site_service=site_service,
assembly_service=assembly_service, assembly_service=assembly_service,
script_signer=script_signer,
) )

View File

@@ -4,13 +4,17 @@ from __future__ import annotations
import logging import logging
import sqlite3 import sqlite3
from typing import Dict, List, Optional import time
from collections.abc import Mapping
from typing import Any, Dict, List, Optional
from Data.Engine.repositories.sqlite.device_inventory_repository import ( from Data.Engine.repositories.sqlite.device_inventory_repository import (
SQLiteDeviceInventoryRepository, SQLiteDeviceInventoryRepository,
) )
from Data.Engine.domain.device_auth import DeviceAuthContext
from Data.Engine.domain.devices import clean_device_str, coerce_int
__all__ = ["DeviceInventoryService", "RemoteDeviceError"] __all__ = ["DeviceInventoryService", "RemoteDeviceError", "DeviceHeartbeatError"]
class RemoteDeviceError(Exception): class RemoteDeviceError(Exception):
@@ -19,6 +23,12 @@ class RemoteDeviceError(Exception):
self.code = code self.code = code
class DeviceHeartbeatError(Exception):
    """Raised when an agent heartbeat cannot be recorded.

    ``code`` holds a short machine-readable identifier.  The exception message
    is *message* when one is given and non-empty, otherwise the code itself.
    """

    def __init__(self, code: str, message: Optional[str] = None) -> None:
        text = message if message else code
        super().__init__(text)
        self.code = code
class DeviceInventoryService: class DeviceInventoryService:
def __init__( def __init__(
self, self,
@@ -176,3 +186,117 @@ class DeviceInventoryService:
raise RemoteDeviceError("not_found", "device not found") raise RemoteDeviceError("not_found", "device not found")
self._repo.delete_device_by_hostname(normalized_host) self._repo.delete_device_by_hostname(normalized_host)
# ------------------------------------------------------------------
# Agent heartbeats
# ------------------------------------------------------------------
def record_heartbeat(
    self,
    *,
    context: DeviceAuthContext,
    payload: Mapping[str, Any],
) -> None:
    """Merge an agent heartbeat into the stored device record.

    The snapshot for the authenticated device GUID is loaded, ``last_seen``
    is set to the current time, and whichever hostname, metrics, network
    fields and inventory sections the payload supplies are folded into the
    stored summary/details before the record is upserted.  Values missing
    from the payload fall back to what the snapshot already holds.

    Args:
        context: Authentication context; only ``identity.guid.value`` is read.
        payload: Heartbeat body.  Recognised keys are ``hostname``,
            ``metrics`` (mapping), ``external_ip``, ``internal_ip``,
            ``device_type`` and ``inventory`` (mapping).

    Raises:
        DeviceHeartbeatError: with code ``device_not_registered`` when no
            snapshot exists for the GUID, ``storage_conflict`` when the
            upsert raises ``sqlite3.IntegrityError``, and ``storage_error``
            for any other failure while persisting.
    """
    guid = context.identity.guid.value
    snapshot = self._repo.load_snapshot(guid=guid)
    if not snapshot:
        raise DeviceHeartbeatError("device_not_registered", "device not registered")

    summary = dict(snapshot.get("summary") or {})
    details = dict(snapshot.get("details") or {})

    now_ts = int(time.time())
    summary["last_seen"] = now_ts
    summary["agent_guid"] = guid

    raw_metrics = payload.get("metrics")
    metrics = raw_metrics if isinstance(raw_metrics, Mapping) else {}

    # Hostname precedence: payload, then metrics, then the stored value, and
    # finally a placeholder derived from the GUID.
    existing_hostname = clean_device_str(summary.get("hostname")) or clean_device_str(
        snapshot.get("hostname")
    )
    incoming_hostname = clean_device_str(payload.get("hostname"))
    metrics_hostname = clean_device_str(metrics.get("hostname")) if metrics else None
    hostname = incoming_hostname or metrics_hostname or existing_hostname
    if not hostname:
        hostname = f"RECOVERED-{guid[:12]}"
    summary["hostname"] = hostname

    if metrics:
        text_metrics = (
            ("last_user", metrics.get("last_user") or metrics.get("username")),
            ("operating_system", metrics.get("operating_system")),
            ("agent_id", metrics.get("agent_id")),
        )
        for key, raw_value in text_metrics:
            if not raw_value:
                continue
            cleaned_value = clean_device_str(raw_value)
            if cleaned_value:
                summary[key] = cleaned_value

        uptime = metrics.get("uptime")
        if uptime is not None:
            coerced = coerce_int(uptime)
            if coerced is not None:
                summary["uptime"] = coerced

    for field in ("external_ip", "internal_ip", "device_type"):
        cleaned = clean_device_str(payload.get(field))
        if cleaned:
            summary[field] = cleaned

    # Normalise a missing *or* None description to "".  The previous
    # ``setdefault`` form left an existing None untouched.
    summary["description"] = summary.get("description") or ""

    created_at = coerce_int(summary.get("created_at"))
    if created_at is None:
        created_at = coerce_int(snapshot.get("created_at"))
    if created_at is None:
        created_at = now_ts
    summary["created_at"] = created_at

    raw_inventory = payload.get("inventory")
    inventory = raw_inventory if isinstance(raw_inventory, Mapping) else {}

    # Each inventory section replaces the stored one only when it has the
    # expected container type; otherwise the stored section is kept.
    merged_details: Dict[str, Any] = {"summary": summary}
    section_types = (
        ("memory", list),
        ("network", list),
        ("software", list),
        ("storage", list),
        ("cpu", Mapping),
    )
    for section, expected_type in section_types:
        candidate = inventory.get(section)
        if isinstance(candidate, expected_type):
            merged_details[section] = candidate
        else:
            merged_details[section] = details.get(section)

    try:
        self._repo.upsert_device(
            summary["hostname"],
            summary.get("description"),
            merged_details,
            summary.get("created_at"),
            agent_hash=clean_device_str(summary.get("agent_hash")),
            guid=guid,
        )
    except sqlite3.IntegrityError as exc:
        self._log.warning(
            "device-heartbeat-conflict guid=%s hostname=%s error=%s",
            guid,
            summary["hostname"],
            exc,
        )
        raise DeviceHeartbeatError("storage_conflict", str(exc)) from exc
    except Exception as exc:  # pragma: no cover - defensive
        self._log.exception(
            "device-heartbeat-failure guid=%s hostname=%s",
            guid,
            summary["hostname"],
            exc_info=exc,
        )
        raise DeviceHeartbeatError("storage_error", "failed to persist heartbeat") from exc

View File

@@ -0,0 +1,234 @@
import pytest
pytest.importorskip("jwt")
import json
import sqlite3
import time
from datetime import datetime, timezone
from pathlib import Path
from Data.Engine.config.environment import (
DatabaseSettings,
EngineSettings,
FlaskSettings,
GitHubSettings,
ServerSettings,
SocketIOSettings,
)
from Data.Engine.domain.device_auth import (
AccessTokenClaims,
DeviceAuthContext,
DeviceFingerprint,
DeviceGuid,
DeviceIdentity,
DeviceStatus,
)
from Data.Engine.interfaces.http import register_http_interfaces
from Data.Engine.repositories.sqlite import connection as sqlite_connection
from Data.Engine.repositories.sqlite import migrations as sqlite_migrations
from Data.Engine.server import create_app
from Data.Engine.services.container import build_service_container
@pytest.fixture()
def engine_settings(tmp_path: Path) -> EngineSettings:
    """Build engine settings rooted in the per-test temporary directory.

    A ``static`` directory with a stub ``index.html`` is created, the database
    path points at ``database.db`` under the same root, and automatic
    migrations are switched off.
    """
    static_dir = tmp_path / "static"
    static_dir.mkdir()
    index_file = static_dir / "index.html"
    index_file.write_text("<html></html>", encoding="utf-8")

    allowed_origins = ("https://localhost",)

    database = DatabaseSettings(path=tmp_path / "database.db", apply_migrations=False)
    flask_settings = FlaskSettings(
        secret_key="test-key",
        static_root=static_dir,
        cors_allowed_origins=allowed_origins,
    )
    socketio = SocketIOSettings(cors_allowed_origins=allowed_origins)
    server = ServerSettings(host="127.0.0.1", port=5000)
    github = GitHubSettings(
        default_repo="owner/repo",
        default_branch="main",
        refresh_interval_seconds=60,
        cache_root=tmp_path / "cache",
    )

    return EngineSettings(
        project_root=tmp_path,
        debug=False,
        database=database,
        flask=flask_settings,
        socketio=socketio,
        server=server,
        github=github,
    )
@pytest.fixture()
def prepared_app(engine_settings: EngineSettings):
    """Return a Flask app backed by a migrated database and wired services.

    The fixture creates the GitHub cache directory, applies all SQLite
    migrations, builds the app and its service container, stores the container
    under ``app.extensions["engine_services"]``, registers the HTTP interfaces
    and enables ``TESTING``.
    """
    engine_settings.github.cache_root.mkdir(exist_ok=True, parents=True)

    database_path = engine_settings.database.path
    db_factory = sqlite_connection.connection_factory(database_path)
    with sqlite_connection.connection_scope(database_path) as conn:
        sqlite_migrations.apply_all(conn)

    app = create_app(engine_settings, db_factory=db_factory)
    container = build_service_container(engine_settings, db_factory=db_factory)
    app.extensions["engine_services"] = container
    register_http_interfaces(app, container)
    app.config.update(TESTING=True)
    return app
def _insert_device(app, guid: str, fingerprint: str, hostname: str) -> None:
    """Insert an active device row directly into the app's SQLite database.

    Args:
        app: Object whose ``config["ENGINE_DATABASE_PATH"]`` names the
            database file.
        guid: Device GUID stored as-is.
        fingerprint: Key fingerprint; stored lower-cased.
        hostname: Device hostname.

    ``created_at`` and ``last_seen`` are set to the current Unix time,
    ``token_version`` to 1 and ``key_added_at`` to the current UTC time in
    ISO-8601 form.
    """
    from contextlib import closing

    db_path = Path(app.config["ENGINE_DATABASE_PATH"])
    now = int(time.time())
    row = (
        guid,
        hostname,
        now,
        now,
        fingerprint.lower(),
        1,
        datetime.now(timezone.utc).isoformat(),
    )
    # A sqlite3 connection used as a context manager only commits or rolls
    # back; ``closing`` makes sure the handle is actually released.
    with closing(sqlite3.connect(db_path)) as conn:
        conn.execute(
            """
            INSERT INTO devices (
                guid,
                hostname,
                created_at,
                last_seen,
                ssl_key_fingerprint,
                token_version,
                status,
                key_added_at
            ) VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
            """,
            row,
        )
        conn.commit()
def _build_context(guid: str, fingerprint: str, *, status: DeviceStatus = DeviceStatus.ACTIVE) -> DeviceAuthContext:
    """Create a device auth context for *guid* / *fingerprint*.

    The access-token claims are issued "now", valid immediately and expire
    600 seconds later; the context uses the ``SYSTEM`` service context and the
    given device *status*.
    """
    issued = int(time.time())
    token_claims = AccessTokenClaims(
        subject="device",
        guid=DeviceGuid(guid),
        fingerprint=DeviceFingerprint(fingerprint),
        token_version=1,
        issued_at=issued,
        not_before=issued,
        expires_at=issued + 600,
        raw={"sub": "device"},
    )
    device_identity = DeviceIdentity(DeviceGuid(guid), DeviceFingerprint(fingerprint))
    return DeviceAuthContext(
        identity=device_identity,
        access_token="token",
        claims=token_claims,
        status=status,
        service_context="SYSTEM",
    )
def test_heartbeat_updates_device(prepared_app, monkeypatch):
    """A heartbeat for a registered device returns 200 and updates its row.

    Device authentication is stubbed out; the test then checks that
    ``last_seen``, ``external_ip`` and the ``memory``/``cpu`` inventory
    columns reflect the posted payload.
    """
    from contextlib import closing

    client = prepared_app.test_client()
    guid = "DE305D54-75B4-431B-ADB2-EB6B9E546014"
    fingerprint = "aa:bb:cc"
    hostname = "device-heartbeat"
    _insert_device(prepared_app, guid, fingerprint, hostname)

    services = prepared_app.extensions["engine_services"]
    context = _build_context(guid, fingerprint)
    monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context)

    payload = {
        "hostname": hostname,
        "inventory": {"memory": [{"total": "16GB"}], "cpu": {"cores": 8}},
        "metrics": {"operating_system": "Windows", "last_user": "Admin", "uptime": 120},
        "external_ip": "1.2.3.4",
    }

    start = int(time.time())
    resp = client.post(
        "/api/agent/heartbeat",
        json=payload,
        headers={"Authorization": "Bearer token"},
    )
    assert resp.status_code == 200
    body = resp.get_json()
    assert body == {"status": "ok", "poll_after_ms": 15000}

    db_path = Path(prepared_app.config["ENGINE_DATABASE_PATH"])
    # ``closing`` releases the handle; a bare sqlite3 connection context
    # manager would leave it open after the block.
    with closing(sqlite3.connect(db_path)) as conn:
        row = conn.execute(
            "SELECT last_seen, external_ip, memory, cpu FROM devices WHERE guid = ?",
            (guid,),
        ).fetchone()

    assert row is not None
    last_seen, external_ip, memory_json, cpu_json = row
    assert last_seen >= start
    assert external_ip == "1.2.3.4"
    assert json.loads(memory_json)[0]["total"] == "16GB"
    assert json.loads(cpu_json)["cores"] == 8
def test_heartbeat_returns_404_when_device_missing(prepared_app, monkeypatch):
    """A heartbeat from an authenticated but unregistered device yields 404."""
    guid = "9E295C27-8339-40C8-AD1A-6ED95C164A4A"
    fingerprint = "11:22:33"

    http = prepared_app.test_client()
    container = prepared_app.extensions["engine_services"]
    auth_context = _build_context(guid, fingerprint)

    def fake_authenticate(request, path):
        return auth_context

    monkeypatch.setattr(container.device_auth, "authenticate", fake_authenticate)

    response = http.post(
        "/api/agent/heartbeat",
        json={"hostname": "missing-device"},
        headers={"Authorization": "Bearer token"},
    )

    assert response.status_code == 404
    assert response.get_json() == {"error": "device_not_registered"}
def test_script_request_reports_status_and_signing_key(prepared_app, monkeypatch):
    """The script-request endpoint reports status, poll interval and key.

    An active device gets ``idle`` / 30000 ms together with the signer's
    public key; a quarantined device gets ``quarantined`` / 60000 ms.
    """
    guid = "2F8D76C0-38D4-4700-B247-3E90C03A67D7"
    fingerprint = "44:55:66"
    auth_headers = {"Authorization": "Bearer token"}

    http = prepared_app.test_client()
    _insert_device(prepared_app, guid, fingerprint, "device-script")
    container = prepared_app.extensions["engine_services"]

    def stub_authentication(auth_context):
        # Replace real device authentication with a fixed context.
        monkeypatch.setattr(
            container.device_auth, "authenticate", lambda request, path: auth_context
        )

    class DummySigner:
        def public_base64_spki(self) -> str:
            return "PUBKEY"

    stub_authentication(_build_context(guid, fingerprint))
    # Bypass the container's normal attribute assignment to swap the signer.
    object.__setattr__(container, "script_signer", DummySigner())

    active_response = http.post(
        "/api/agent/script/request",
        json={"guid": guid},
        headers=auth_headers,
    )
    assert active_response.status_code == 200
    assert active_response.get_json() == {
        "status": "idle",
        "poll_after_ms": 30000,
        "sig_alg": "ed25519",
        "signing_key": "PUBKEY",
    }

    stub_authentication(_build_context(guid, fingerprint, status=DeviceStatus.QUARANTINED))

    quarantined_response = http.post(
        "/api/agent/script/request",
        json={},
        headers=auth_headers,
    )
    assert quarantined_response.status_code == 200
    assert quarantined_response.get_json()["status"] == "quarantined"
    assert quarantined_response.get_json()["poll_after_ms"] == 60000