From 78a5d3d7f9dc83e690f68860d4113a1a8b1475e0 Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Fri, 17 Oct 2025 17:00:11 -0600 Subject: [PATCH] feat: add agent enrollment endpoints and nonce protections --- Data/Server/Modules/enrollment/nonce_store.py | 35 ++ Data/Server/Modules/enrollment/routes.py | 478 ++++++++++++++++++ Data/Server/server.py | 17 + 3 files changed, 530 insertions(+) create mode 100644 Data/Server/Modules/enrollment/nonce_store.py create mode 100644 Data/Server/Modules/enrollment/routes.py diff --git a/Data/Server/Modules/enrollment/nonce_store.py b/Data/Server/Modules/enrollment/nonce_store.py new file mode 100644 index 0000000..bcdb962 --- /dev/null +++ b/Data/Server/Modules/enrollment/nonce_store.py @@ -0,0 +1,35 @@ +""" +Short-lived nonce cache to defend against replay attacks during enrollment. +""" + +from __future__ import annotations + +import time +from threading import Lock +from typing import Dict + + +class NonceCache: + def __init__(self, ttl_seconds: float = 300.0) -> None: + self._ttl = ttl_seconds + self._entries: Dict[str, float] = {} + self._lock = Lock() + + def consume(self, key: str) -> bool: + """ + Attempt to consume the nonce identified by `key`. + + Returns True on first use within TTL, False if already consumed. + """ + + now = time.monotonic() + with self._lock: + expire_at = self._entries.get(key) + if expire_at and expire_at > now: + return False + self._entries[key] = now + self._ttl + # Opportunistic cleanup to keep the dict small + stale = [nonce for nonce, expiry in self._entries.items() if expiry <= now] + for nonce in stale: + self._entries.pop(nonce, None) + return True diff --git a/Data/Server/Modules/enrollment/routes.py b/Data/Server/Modules/enrollment/routes.py new file mode 100644 index 0000000..359c5ea --- /dev/null +++ b/Data/Server/Modules/enrollment/routes.py @@ -0,0 +1,478 @@ +from __future__ import annotations + +import base64 +import secrets +import sqlite3 +import uuid +from datetime import datetime, timezone, timedelta +import time +from typing import Any, Callable, Dict, Optional + +from flask import Blueprint, jsonify, request + +from Modules.auth.rate_limit import SlidingWindowRateLimiter +from Modules.crypto import keys as crypto_keys +from Modules.enrollment.nonce_store import NonceCache +from cryptography.hazmat.primitives import serialization + + +def register( + app, + *, + db_conn_factory: Callable[[], sqlite3.Connection], + log: Callable[[str, str], None], + jwt_service, + tls_bundle_path: str, + ip_rate_limiter: SlidingWindowRateLimiter, + fp_rate_limiter: SlidingWindowRateLimiter, + nonce_cache: NonceCache, +) -> None: + blueprint = Blueprint("enrollment", __name__) + + def _now() -> datetime: + return datetime.now(tz=timezone.utc) + + def _iso(dt: datetime) -> str: + return dt.isoformat() + + def _remote_addr() -> str: + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + return forwarded.split(",")[0].strip() + addr = request.remote_addr or "unknown" + return addr.strip() + + def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float): + decision = limiter.check(key, limit, window_s) + if not decision.allowed: + response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after}) + response.status_code = 429 + response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}" + return response + return None + + def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]: + cur.execute( + "SELECT id, code, expires_at, used_at FROM enrollment_install_codes WHERE code = ?", + (code_value,), + ) + row = cur.fetchone() + if not row: + return None + keys = ["id", "code", "expires_at", "used_at"] + record = dict(zip(keys, row)) + return record + + def _install_code_valid(record: Dict[str, Any]) -> bool: + if not record: + return False + expires_at = record.get("expires_at") + if not isinstance(expires_at, str): + return False + try: + expiry = datetime.fromisoformat(expires_at) + except Exception: + return False + if expiry <= _now(): + return False + if record.get("used_at"): + return False + return True + + def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str: + base = (hostname or "").strip() or guid + base = base[:253] + candidate = base + suffix = 1 + while True: + cur.execute( + "SELECT guid FROM devices WHERE hostname = ? AND guid != ?", + (candidate, guid), + ) + row = cur.fetchone() + if not row: + return candidate + candidate = f"{base}-{suffix}" + suffix += 1 + if suffix > 50: + return f"{guid}" + + def _store_device_key(cur: sqlite3.Cursor, guid: str, fingerprint: str) -> None: + added_at = _iso(_now()) + cur.execute( + """ + INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at) + VALUES (?, ?, ?, ?) + """, + (str(uuid.uuid4()), guid, fingerprint, added_at), + ) + cur.execute( + """ + UPDATE device_keys + SET retired_at = ? + WHERE guid = ? + AND ssl_key_fingerprint != ? + AND retired_at IS NULL + """, + (_iso(_now()), guid, fingerprint), + ) + + def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]: + cur.execute( + "SELECT guid, hostname, token_version, status, ssl_key_fingerprint FROM devices WHERE guid = ?", + (guid,), + ) + row = cur.fetchone() + if row: + keys = ["guid", "hostname", "token_version", "status", "ssl_key_fingerprint"] + record = dict(zip(keys, row)) + if not record.get("ssl_key_fingerprint"): + cur.execute( + "UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?", + (fingerprint, _iso(_now()), guid), + ) + record["ssl_key_fingerprint"] = fingerprint + return record + + resolved_hostname = _normalize_host(hostname, guid, cur) + created_at = int(time.time()) + key_added_at = _iso(_now()) + cur.execute( + """ + INSERT INTO devices ( + guid, hostname, created_at, last_seen, ssl_key_fingerprint, + token_version, status, key_added_at + ) + VALUES (?, ?, ?, ?, ?, 1, 'active', ?) + """, + ( + guid, + resolved_hostname, + created_at, + created_at, + fingerprint, + key_added_at, + ), + ) + return { + "guid": guid, + "hostname": resolved_hostname, + "token_version": 1, + "status": "active", + "ssl_key_fingerprint": fingerprint, + } + + def _hash_refresh_token(token: str) -> str: + import hashlib + + return hashlib.sha256(token.encode("utf-8")).hexdigest() + + def _issue_refresh_token(cur: sqlite3.Cursor, guid: str) -> Dict[str, Any]: + token = secrets.token_urlsafe(48) + now = _now() + expires_at = now.replace(microsecond=0) + timedelta(days=30) + cur.execute( + """ + INSERT INTO refresh_tokens (id, guid, token_hash, created_at, expires_at) + VALUES (?, ?, ?, ?, ?) + """, + ( + str(uuid.uuid4()), + guid, + _hash_refresh_token(token), + _iso(now), + _iso(expires_at), + ), + ) + return {"token": token, "expires_at": expires_at} + + @blueprint.route("/api/agent/enroll/request", methods=["POST"]) + def enrollment_request(): + remote = _remote_addr() + rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 10, 60.0) + if rate_error: + return rate_error + + payload = request.get_json(force=True, silent=True) or {} + hostname = str(payload.get("hostname") or "").strip() + enrollment_code = str(payload.get("enrollment_code") or "").strip() + agent_pubkey_b64 = payload.get("agent_pubkey") + client_nonce_b64 = payload.get("client_nonce") + + if not hostname: + return jsonify({"error": "hostname_required"}), 400 + if not enrollment_code: + return jsonify({"error": "enrollment_code_required"}), 400 + if not isinstance(agent_pubkey_b64, str): + return jsonify({"error": "agent_pubkey_required"}), 400 + if not isinstance(client_nonce_b64, str): + return jsonify({"error": "client_nonce_required"}), 400 + + try: + agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64) + except Exception: + return jsonify({"error": "invalid_agent_pubkey"}), 400 + + if len(agent_pubkey_der) < 10: + return jsonify({"error": "invalid_agent_pubkey"}), 400 + + try: + client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True) + except Exception: + return jsonify({"error": "invalid_client_nonce"}), 400 + if len(client_nonce_bytes) < 16: + return jsonify({"error": "invalid_client_nonce"}), 400 + + fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der) + rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 3, 60.0) + if rate_error: + return rate_error + + conn = db_conn_factory() + try: + cur = conn.cursor() + install_code = _load_install_code(cur, enrollment_code) + if not _install_code_valid(install_code): + return jsonify({"error": "invalid_enrollment_code"}), 400 + + approval_reference: str + record_id: str + server_nonce_bytes = secrets.token_bytes(32) + server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii") + now = _iso(_now()) + + cur.execute( + """ + SELECT id, approval_reference + FROM device_approvals + WHERE ssl_key_fingerprint_claimed = ? + AND status = 'pending' + """, + (fingerprint,), + ) + existing = cur.fetchone() + if existing: + record_id = existing[0] + approval_reference = existing[1] + cur.execute( + """ + UPDATE device_approvals + SET hostname_claimed = ?, + enrollment_code_id = ?, + client_nonce = ?, + server_nonce = ?, + agent_pubkey_der = ?, + updated_at = ? + WHERE id = ? + """, + ( + hostname, + install_code["id"], + client_nonce_b64, + server_nonce_b64, + agent_pubkey_der, + now, + record_id, + ), + ) + else: + record_id = str(uuid.uuid4()) + approval_reference = str(uuid.uuid4()) + cur.execute( + """ + INSERT INTO device_approvals ( + id, approval_reference, guid, hostname_claimed, + ssl_key_fingerprint_claimed, enrollment_code_id, + status, client_nonce, server_nonce, agent_pubkey_der, + created_at, updated_at + ) + VALUES (?, ?, NULL, ?, ?, ?, 'pending', ?, ?, ?, ?, ?) + """, + ( + record_id, + approval_reference, + hostname, + fingerprint, + install_code["id"], + client_nonce_b64, + server_nonce_b64, + agent_pubkey_der, + now, + now, + ), + ) + + conn.commit() + finally: + conn.close() + + response = { + "status": "pending", + "approval_reference": approval_reference, + "server_nonce": server_nonce_b64, + "poll_after_ms": 3000, + "server_certificate": _load_tls_bundle(tls_bundle_path), + } + log("server", f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}") + return jsonify(response) + + @blueprint.route("/api/agent/enroll/poll", methods=["POST"]) + def enrollment_poll(): + payload = request.get_json(force=True, silent=True) or {} + approval_reference = payload.get("approval_reference") + client_nonce_b64 = payload.get("client_nonce") + proof_sig_b64 = payload.get("proof_sig") + + if not isinstance(approval_reference, str) or not approval_reference: + return jsonify({"error": "approval_reference_required"}), 400 + if not isinstance(client_nonce_b64, str): + return jsonify({"error": "client_nonce_required"}), 400 + if not isinstance(proof_sig_b64, str): + return jsonify({"error": "proof_sig_required"}), 400 + + try: + client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True) + except Exception: + return jsonify({"error": "invalid_client_nonce"}), 400 + + try: + proof_sig = base64.b64decode(proof_sig_b64, validate=True) + except Exception: + return jsonify({"error": "invalid_proof_sig"}), 400 + + conn = db_conn_factory() + try: + cur = conn.cursor() + cur.execute( + """ + SELECT id, guid, hostname_claimed, ssl_key_fingerprint_claimed, + enrollment_code_id, status, client_nonce, server_nonce, + agent_pubkey_der, created_at, updated_at, approved_by_user_id + FROM device_approvals + WHERE approval_reference = ? + """, + (approval_reference,), + ) + row = cur.fetchone() + if not row: + return jsonify({"status": "unknown"}), 404 + + ( + record_id, + guid, + hostname_claimed, + fingerprint, + enrollment_code_id, + status, + client_nonce_stored, + server_nonce_b64, + agent_pubkey_der, + created_at, + updated_at, + approved_by, + ) = row + + if client_nonce_stored != client_nonce_b64: + return jsonify({"error": "nonce_mismatch"}), 400 + + try: + server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True) + except Exception: + return jsonify({"error": "server_nonce_invalid"}), 400 + + message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes + + try: + public_key = serialization.load_der_public_key(agent_pubkey_der) + except Exception: + public_key = None + + if public_key is None: + return jsonify({"error": "agent_pubkey_invalid"}), 400 + + try: + public_key.verify(proof_sig, message) + except Exception: + return jsonify({"error": "invalid_proof"}), 400 + + if status == "pending": + return jsonify({"status": "pending", "poll_after_ms": 5000}) + if status == "denied": + return jsonify({"status": "denied", "reason": "operator_denied"}) + if status == "expired": + return jsonify({"status": "expired"}) + + if status != "approved": + return jsonify({"status": status or "unknown"}), 400 + + nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}" + if not nonce_cache.consume(nonce_key): + return jsonify({"error": "proof_replayed"}), 409 + + # Finalize enrollment + effective_guid = guid or str(uuid.uuid4()) + now_iso = _iso(_now()) + + device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint) + _store_device_key(cur, effective_guid, fingerprint) + + # Mark install code used + if enrollment_code_id: + cur.execute( + """ + UPDATE enrollment_install_codes + SET used_at = ?, used_by_guid = ? + WHERE id = ? + AND used_at IS NULL + """, + (now_iso, effective_guid, enrollment_code_id), + ) + + # Update approval record with final state + cur.execute( + """ + UPDATE device_approvals + SET guid = ?, + status = 'completed', + updated_at = ? + WHERE id = ? + """, + (effective_guid, now_iso, record_id), + ) + + refresh_info = _issue_refresh_token(cur, effective_guid) + access_token = jwt_service.issue_access_token( + effective_guid, + fingerprint, + device_record.get("token_version") or 1, + ) + + conn.commit() + finally: + conn.close() + + log( + "server", + f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}", + ) + return jsonify( + { + "status": "approved", + "guid": effective_guid, + "access_token": access_token, + "expires_in": 900, + "refresh_token": refresh_info["token"], + "token_type": "Bearer", + "server_certificate": _load_tls_bundle(tls_bundle_path), + } + ) + + app.register_blueprint(blueprint) + + +def _load_tls_bundle(path: str) -> str: + try: + with open(path, "r", encoding="utf-8") as fh: + return fh.read() + except Exception: + return "" diff --git a/Data/Server/server.py b/Data/Server/server.py index d8ef38b..690c5db 100644 --- a/Data/Server/server.py +++ b/Data/Server/server.py @@ -50,7 +50,10 @@ from datetime import datetime, timezone from Modules import db_migrations from Modules.auth import jwt_service as jwt_service_module +from Modules.auth.rate_limit import SlidingWindowRateLimiter from Modules.crypto import certificates +from Modules.enrollment import routes as enrollment_routes +from Modules.enrollment.nonce_store import NonceCache try: from cryptography.fernet import Fernet # type: ignore @@ -146,6 +149,9 @@ os.environ.setdefault("BOREALIS_TLS_KEY", TLS_KEY_PATH) os.environ.setdefault("BOREALIS_TLS_BUNDLE", TLS_BUNDLE_PATH) JWT_SERVICE = jwt_service_module.load_service() +IP_RATE_LIMITER = SlidingWindowRateLimiter() +FP_RATE_LIMITER = SlidingWindowRateLimiter() +ENROLLMENT_NONCE_CACHE = NonceCache() def _set_cached_github_token(token: Optional[str]) -> None: @@ -4819,6 +4825,17 @@ def init_db(): init_db() +enrollment_routes.register( + app, + db_conn_factory=_db_conn, + log=_write_service_log, + jwt_service=JWT_SERVICE, + tls_bundle_path=TLS_BUNDLE_PATH, + ip_rate_limiter=IP_RATE_LIMITER, + fp_rate_limiter=FP_RATE_LIMITER, + nonce_cache=ENROLLMENT_NONCE_CACHE, +) + def ensure_default_admin(): """Ensure at least one admin user exists.