From a72bff5e8eba00aa60b7b25323e243e7d6fd886d Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Fri, 17 Oct 2025 17:15:02 -0600 Subject: [PATCH] feat: secure agent auth and heartbeat endpoints --- Data/Server/Modules/agents/__init__.py | 1 + Data/Server/Modules/agents/routes.py | 114 ++++++++++++++++++ Data/Server/Modules/auth/device_auth.py | 148 ++++++++++++++++++++++++ Data/Server/Modules/tokens/__init__.py | 1 + Data/Server/Modules/tokens/routes.py | 125 ++++++++++++++++++++ Data/Server/server.py | 32 ++++- 6 files changed, 420 insertions(+), 1 deletion(-) create mode 100644 Data/Server/Modules/agents/__init__.py create mode 100644 Data/Server/Modules/agents/routes.py create mode 100644 Data/Server/Modules/auth/device_auth.py create mode 100644 Data/Server/Modules/tokens/__init__.py create mode 100644 Data/Server/Modules/tokens/routes.py diff --git a/Data/Server/Modules/agents/__init__.py b/Data/Server/Modules/agents/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/Data/Server/Modules/agents/__init__.py @@ -0,0 +1 @@ + diff --git a/Data/Server/Modules/agents/routes.py b/Data/Server/Modules/agents/routes.py new file mode 100644 index 0000000..e312d6c --- /dev/null +++ b/Data/Server/Modules/agents/routes.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import json +import time +from typing import Any, Callable, Dict, Optional + +from flask import Blueprint, jsonify, request, g + +from Modules.auth.device_auth import DeviceAuthManager, require_device_auth +from Modules.crypto.signing import ScriptSigner + + +def register( + app, + *, + db_conn_factory: Callable[[], Any], + auth_manager: DeviceAuthManager, + log: Callable[[str, str], None], + script_signer: ScriptSigner, +) -> None: + blueprint = Blueprint("agents", __name__) + + def _json_or_none(value) -> Optional[str]: + if value is None: + return None + try: + return json.dumps(value) + except Exception: + return None + + @blueprint.route("/api/agent/heartbeat", methods=["POST"]) + @require_device_auth(auth_manager) + def heartbeat(): + ctx = getattr(g, "device_auth") + payload = request.get_json(force=True, silent=True) or {} + + now_ts = int(time.time()) + updates: Dict[str, Optional[str]] = {"last_seen": now_ts} + + hostname = payload.get("hostname") + if isinstance(hostname, str) and hostname.strip(): + updates["hostname"] = hostname.strip() + + inventory = payload.get("inventory") if isinstance(payload.get("inventory"), dict) else {} + for key in ("memory", "network", "software", "storage", "cpu"): + if key in inventory and inventory[key] is not None: + encoded = _json_or_none(inventory[key]) + if encoded is not None: + updates[key] = encoded + + metrics = payload.get("metrics") if isinstance(payload.get("metrics"), dict) else {} + def _maybe_str(field: str) -> Optional[str]: + val = metrics.get(field) + if isinstance(val, str): + return val.strip() + return None + + if "last_user" in metrics and metrics["last_user"]: + updates["last_user"] = str(metrics["last_user"]) + if "operating_system" in metrics and metrics["operating_system"]: + updates["operating_system"] = str(metrics["operating_system"]) + if "uptime" in metrics and metrics["uptime"] is not None: + try: + updates["uptime"] = int(metrics["uptime"]) + except Exception: + pass + for field in ("external_ip", "internal_ip", "device_type"): + if field in payload and payload[field]: + updates[field] = str(payload[field]) + + conn = db_conn_factory() + try: + cur = conn.cursor() + columns = ", ".join(f"{col} = ?" for col in updates.keys()) + params = list(updates.values()) + params.append(ctx.guid) + cur.execute( + f"UPDATE devices SET {columns} WHERE guid = ?", + params, + ) + if cur.rowcount == 0: + log("server", f"heartbeat missing device record guid={ctx.guid}") + return jsonify({"error": "device_not_registered"}), 404 + conn.commit() + finally: + conn.close() + + return jsonify({"status": "ok", "poll_after_ms": 15000}) + + @blueprint.route("/api/agent/script/request", methods=["POST"]) + @require_device_auth(auth_manager) + def script_request(): + ctx = getattr(g, "device_auth") + if ctx.status != "active": + return jsonify( + { + "status": "quarantined", + "poll_after_ms": 60000, + "sig_alg": "ed25519", + "signing_key": script_signer.public_base64_spki(), + } + ) + + # Placeholder: actual dispatch logic will integrate with job scheduler. + return jsonify( + { + "status": "idle", + "poll_after_ms": 30000, + "sig_alg": "ed25519", + "signing_key": script_signer.public_base64_spki(), + } + ) + + app.register_blueprint(blueprint) diff --git a/Data/Server/Modules/auth/device_auth.py b/Data/Server/Modules/auth/device_auth.py new file mode 100644 index 0000000..ce45660 --- /dev/null +++ b/Data/Server/Modules/auth/device_auth.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import functools +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional + +import jwt +from flask import g, jsonify, request + +from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError + + +@dataclass +class DeviceAuthContext: + guid: str + ssl_key_fingerprint: str + token_version: int + access_token: str + claims: Dict[str, Any] + dpop_jkt: Optional[str] + status: str + + +class DeviceAuthError(Exception): + status_code = 401 + error_code = "unauthorized" + + def __init__(self, message: str = "unauthorized", *, status_code: Optional[int] = None): + super().__init__(message) + if status_code is not None: + self.status_code = status_code + self.message = message + + +class DeviceAuthManager: + def __init__( + self, + *, + db_conn_factory: Callable[[], Any], + jwt_service, + dpop_validator: Optional[DPoPValidator], + log: Callable[[str, str], None], + ) -> None: + self._db_conn_factory = db_conn_factory + self._jwt_service = jwt_service + self._dpop_validator = dpop_validator + self._log = log + + def authenticate(self) -> DeviceAuthContext: + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + raise DeviceAuthError("missing_authorization") + token = auth_header[len("Bearer ") :].strip() + if not token: + raise DeviceAuthError("missing_authorization") + + try: + claims = self._jwt_service.decode(token) + except jwt.ExpiredSignatureError: + raise DeviceAuthError("token_expired") + except Exception: + raise DeviceAuthError("invalid_token") + + guid = str(claims.get("guid") or "").strip() + fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip() + token_version = int(claims.get("token_version") or 0) + if not guid or not fingerprint or token_version <= 0: + raise DeviceAuthError("invalid_claims") + + conn = self._db_conn_factory() + try: + cur = conn.cursor() + cur.execute( + """ + SELECT guid, ssl_key_fingerprint, token_version, status + FROM devices + WHERE guid = ? + """, + (guid,), + ) + row = cur.fetchone() + finally: + conn.close() + + if not row: + raise DeviceAuthError("device_not_found", status_code=403) + + db_guid, db_fp, db_token_version, status = row + + if str(db_guid or "").lower() != guid.lower(): + raise DeviceAuthError("device_guid_mismatch", status_code=403) + + db_fp = (db_fp or "").lower().strip() + if db_fp and db_fp != fingerprint: + raise DeviceAuthError("fingerprint_mismatch", status_code=403) + + if db_token_version and db_token_version > token_version: + raise DeviceAuthError("token_version_revoked", status_code=401) + + status_normalized = (status or "active").strip().lower() + allowed_statuses = {"active", "quarantined"} + if status_normalized not in allowed_statuses: + raise DeviceAuthError("device_revoked", status_code=403) + if status_normalized == "quarantined": + self._log("server", f"device {guid} is quarantined; limited access for {request.path}") + + dpop_jkt: Optional[str] = None + dpop_proof = request.headers.get("DPoP") + if dpop_proof: + if not self._dpop_validator: + raise DeviceAuthError("dpop_not_supported", status_code=400) + try: + htu = request.url + dpop_jkt = self._dpop_validator.verify(request.method, htu, dpop_proof, token) + except DPoPReplayError: + raise DeviceAuthError("dpop_replayed", status_code=400) + except DPoPVerificationError: + raise DeviceAuthError("dpop_invalid", status_code=400) + + ctx = DeviceAuthContext( + guid=guid, + ssl_key_fingerprint=fingerprint, + token_version=token_version, + access_token=token, + claims=claims, + dpop_jkt=dpop_jkt, + status=status_normalized, + ) + return ctx + + +def require_device_auth(manager: DeviceAuthManager): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + ctx = manager.authenticate() + except DeviceAuthError as exc: + response = jsonify({"error": exc.message}) + response.status_code = exc.status_code + return response + + g.device_auth = ctx + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/Data/Server/Modules/tokens/__init__.py b/Data/Server/Modules/tokens/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/Data/Server/Modules/tokens/__init__.py @@ -0,0 +1 @@ + diff --git a/Data/Server/Modules/tokens/routes.py b/Data/Server/Modules/tokens/routes.py new file mode 100644 index 0000000..1e69d9d --- /dev/null +++ b/Data/Server/Modules/tokens/routes.py @@ -0,0 +1,125 @@ + +from __future__ import annotations + +import hashlib +import sqlite3 +from datetime import datetime, timezone +from typing import Callable + +from flask import Blueprint, jsonify, request + +from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError + + +def register( + app, + *, + db_conn_factory: Callable[[], sqlite3.Connection], + jwt_service, + dpop_validator: DPoPValidator, +) -> None: + blueprint = Blueprint("tokens", __name__) + + def _hash_token(token: str) -> str: + return hashlib.sha256(token.encode("utf-8")).hexdigest() + + def _iso_now() -> str: + return datetime.now(tz=timezone.utc).isoformat() + + def _parse_iso(ts: str) -> datetime: + return datetime.fromisoformat(ts) + + @blueprint.route("/api/agent/token/refresh", methods=["POST"]) + def refresh(): + payload = request.get_json(force=True, silent=True) or {} + guid = str(payload.get("guid") or "").strip() + refresh_token = str(payload.get("refresh_token") or "").strip() + + if not guid or not refresh_token: + return jsonify({"error": "invalid_request"}), 400 + + conn = db_conn_factory() + try: + cur = conn.cursor() + cur.execute( + """ + SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at + FROM refresh_tokens + WHERE guid = ? + AND token_hash = ? + """, + (guid, _hash_token(refresh_token)), + ) + row = cur.fetchone() + if not row: + return jsonify({"error": "invalid_refresh_token"}), 401 + + record_id, row_guid, _token_hash, stored_jkt, created_at, expires_at, revoked_at = row + if row_guid != guid: + return jsonify({"error": "invalid_refresh_token"}), 401 + if revoked_at: + return jsonify({"error": "refresh_token_revoked"}), 401 + if expires_at: + try: + if _parse_iso(expires_at) <= datetime.now(tz=timezone.utc): + return jsonify({"error": "refresh_token_expired"}), 401 + except Exception: + pass + + cur.execute( + """ + SELECT guid, ssl_key_fingerprint, token_version, status + FROM devices + WHERE guid = ? + """, + (guid,), + ) + device_row = cur.fetchone() + if not device_row: + return jsonify({"error": "device_not_found"}), 404 + + device_guid, fingerprint, token_version, status = device_row + status_norm = (status or "active").strip().lower() + if status_norm in {"revoked", "decommissioned"}: + return jsonify({"error": "device_revoked"}), 403 + + dpop_proof = request.headers.get("DPoP") + jkt = stored_jkt or "" + if dpop_proof: + try: + jkt = dpop_validator.verify(request.method, request.url, dpop_proof, access_token=None) + except DPoPReplayError: + return jsonify({"error": "dpop_replayed"}), 400 + except DPoPVerificationError: + return jsonify({"error": "dpop_invalid"}), 400 + elif stored_jkt: + return jsonify({"error": "dpop_required"}), 400 + + new_access_token = jwt_service.issue_access_token( + guid, + fingerprint or "", + token_version or 1, + ) + + cur.execute( + """ + UPDATE refresh_tokens + SET last_used_at = ?, + dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt) + WHERE id = ? + """, + (_iso_now(), jkt, record_id), + ) + conn.commit() + finally: + conn.close() + + return jsonify( + { + "access_token": new_access_token, + "expires_in": 900, + "token_type": "Bearer", + } + ) + + app.register_blueprint(blueprint) diff --git a/Data/Server/server.py b/Data/Server/server.py index 690c5db..98c2d88 100644 --- a/Data/Server/server.py +++ b/Data/Server/server.py @@ -50,10 +50,14 @@ from datetime import datetime, timezone from Modules import db_migrations from Modules.auth import jwt_service as jwt_service_module +from Modules.auth.dpop import DPoPValidator +from Modules.auth.device_auth import DeviceAuthManager from Modules.auth.rate_limit import SlidingWindowRateLimiter -from Modules.crypto import certificates +from Modules.agents import routes as agent_routes +from Modules.crypto import certificates, signing from Modules.enrollment import routes as enrollment_routes from Modules.enrollment.nonce_store import NonceCache +from Modules.tokens import routes as token_routes try: from cryptography.fernet import Fernet # type: ignore @@ -149,9 +153,12 @@ os.environ.setdefault("BOREALIS_TLS_KEY", TLS_KEY_PATH) os.environ.setdefault("BOREALIS_TLS_BUNDLE", TLS_BUNDLE_PATH) JWT_SERVICE = jwt_service_module.load_service() +SCRIPT_SIGNER = signing.load_signer() IP_RATE_LIMITER = SlidingWindowRateLimiter() FP_RATE_LIMITER = SlidingWindowRateLimiter() ENROLLMENT_NONCE_CACHE = NonceCache() +DPOP_VALIDATOR = DPoPValidator() +DEVICE_AUTH_MANAGER: Optional[DeviceAuthManager] = None def _set_cached_github_token(token: Optional[str]) -> None: @@ -1248,6 +1255,14 @@ def _db_conn(): return conn +if DEVICE_AUTH_MANAGER is None: + DEVICE_AUTH_MANAGER = DeviceAuthManager( + db_conn_factory=_db_conn, + jwt_service=JWT_SERVICE, + dpop_validator=DPOP_VALIDATOR, + log=_write_service_log, + ) + def _update_last_login(username: str) -> None: if not username: return @@ -4836,6 +4851,21 @@ enrollment_routes.register( nonce_cache=ENROLLMENT_NONCE_CACHE, ) +token_routes.register( + app, + db_conn_factory=_db_conn, + jwt_service=JWT_SERVICE, + dpop_validator=DPOP_VALIDATOR, +) + +agent_routes.register( + app, + db_conn_factory=_db_conn, + auth_manager=DEVICE_AUTH_MANAGER, + log=_write_service_log, + script_signer=SCRIPT_SIGNER, +) + def ensure_default_admin(): """Ensure at least one admin user exists.