# ====================================================== # Data\Engine\Unit_Tests\test_enrollment_api.py # Description: Covers device enrollment request and poll flows including cryptographic proof handling. # # API Endpoints (if applicable): None # ====================================================== from __future__ import annotations import base64 import os import sqlite3 import uuid from datetime import datetime, timedelta, timezone from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ed25519 from flask.testing import FlaskClient from Data.Engine.crypto import keys as crypto_keys from Data.Engine.database import initialise_engine_database from .conftest import EngineTestHarness def _now() -> datetime: return datetime.now(tz=timezone.utc) def _iso(dt: datetime) -> str: return dt.astimezone(timezone.utc).isoformat() def _seed_install_code(db_path: os.PathLike[str], code: str, site_id: int = 1) -> str: record_id = str(uuid.uuid4()) baseline = _now() issued_at = _iso(baseline) expires_at = _iso(baseline + timedelta(days=1)) with sqlite3.connect(str(db_path)) as conn: columns = {row[1] for row in conn.execute("PRAGMA table_info(sites)")} if "enrollment_code_id" not in columns: conn.execute("ALTER TABLE sites ADD COLUMN enrollment_code_id TEXT") conn.execute( """ INSERT OR IGNORE INTO sites (id, name, description, created_at, enrollment_code_id) VALUES (?, ?, ?, ?, ?) """, (site_id, f"Test Site {site_id}", "Seeded site", int(baseline.timestamp()), record_id), ) conn.execute( """ INSERT INTO enrollment_install_codes ( id, code, expires_at, used_at, used_by_guid, max_uses, use_count, last_used_at, site_id ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, (record_id, code, expires_at, None, None, 1, 0, None, site_id), ) conn.execute( """ INSERT INTO enrollment_install_codes_persistent ( id, code, created_at, expires_at, created_by_user_id, used_at, used_by_guid, max_uses, last_known_use_count, last_used_at, is_active, archived_at, consumed_at, site_id ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( record_id, code, issued_at, expires_at, "test-suite", None, None, 1, 0, None, 1, None, None, site_id, ), ) conn.execute( """ UPDATE sites SET enrollment_code_id = ? WHERE id = ? """, (record_id, site_id), ) conn.commit() return record_id def _generate_agent_material() -> tuple[ed25519.Ed25519PrivateKey, bytes, str]: private_key = ed25519.Ed25519PrivateKey.generate() public_der = private_key.public_key().public_bytes( encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo, ) public_b64 = base64.b64encode(public_der).decode("ascii") return private_key, public_der, public_b64 def test_enrollment_request_creates_pending_approval(engine_harness: EngineTestHarness) -> None: harness = engine_harness client: FlaskClient = harness.app.test_client() install_code = "INSTALL-CODE-001" install_code_id = _seed_install_code(harness.db_path, install_code) private_key, public_der, public_b64 = _generate_agent_material() client_nonce_bytes = os.urandom(32) client_nonce_b64 = base64.b64encode(client_nonce_bytes).decode("ascii") response = client.post( "/api/agent/enroll/request", json={ "hostname": "agent-node-01", "enrollment_code": install_code, "agent_pubkey": public_b64, "client_nonce": client_nonce_b64, }, headers={"X-Borealis-Agent-Context": "interactive"}, ) assert response.status_code == 200 payload = response.get_json() assert payload["status"] == "pending" assert payload["server_certificate"] == harness.bundle_contents approval_reference = payload["approval_reference"] with sqlite3.connect(str(harness.db_path)) as conn: cur = conn.cursor() cur.execute( """ SELECT hostname_claimed, ssl_key_fingerprint_claimed, client_nonce, status, enrollment_code_id, site_id FROM device_approvals WHERE approval_reference = ? """, (approval_reference,), ) row = cur.fetchone() assert row is not None hostname_claimed, fingerprint, stored_client_nonce, status, stored_code_id, stored_site_id = row assert hostname_claimed == "agent-node-01" assert stored_client_nonce == client_nonce_b64 assert status == "pending" assert stored_code_id == install_code_id assert stored_site_id == 1 expected_fingerprint = crypto_keys.fingerprint_from_spki_der(public_der) assert fingerprint == expected_fingerprint def test_enrollment_poll_finalizes_when_approved(engine_harness: EngineTestHarness) -> None: harness = engine_harness client: FlaskClient = harness.app.test_client() install_code = "INSTALL-CODE-002" install_code_id = _seed_install_code(harness.db_path, install_code) private_key, public_der, public_b64 = _generate_agent_material() client_nonce_bytes = os.urandom(32) client_nonce_b64 = base64.b64encode(client_nonce_bytes).decode("ascii") request_response = client.post( "/api/agent/enroll/request", json={ "hostname": "agent-node-02", "enrollment_code": install_code, "agent_pubkey": public_b64, "client_nonce": client_nonce_b64, }, headers={"X-Borealis-Agent-Context": "system"}, ) assert request_response.status_code == 200 request_payload = request_response.get_json() approval_reference = request_payload["approval_reference"] server_nonce_b64 = request_payload["server_nonce"] approved_at = _iso(_now()) with sqlite3.connect(str(harness.db_path)) as conn: conn.execute( """ UPDATE device_approvals SET status = 'approved', updated_at = ?, approved_by_user_id = 'operator' WHERE approval_reference = ? """, (approved_at, approval_reference), ) conn.commit() message = base64.b64decode(server_nonce_b64, validate=True) + approval_reference.encode("utf-8") + client_nonce_bytes proof_sig = private_key.sign(message) proof_sig_b64 = base64.b64encode(proof_sig).decode("ascii") poll_response = client.post( "/api/agent/enroll/poll", json={ "approval_reference": approval_reference, "client_nonce": client_nonce_b64, "proof_sig": proof_sig_b64, }, ) assert poll_response.status_code == 200 poll_payload = poll_response.get_json() assert poll_payload["status"] == "approved" assert poll_payload["token_type"] == "Bearer" assert poll_payload["server_certificate"] == harness.bundle_contents final_guid = poll_payload["guid"] assert isinstance(final_guid, str) and len(final_guid) == 36 with sqlite3.connect(str(harness.db_path)) as conn: cur = conn.cursor() cur.execute( "SELECT guid, status, site_id FROM device_approvals WHERE approval_reference = ?", (approval_reference,), ) approval_row = cur.fetchone() cur.execute( "SELECT hostname, ssl_key_fingerprint, token_version FROM devices WHERE guid = ?", (final_guid,), ) device_row = cur.fetchone() cur.execute( "SELECT site_id FROM device_sites WHERE device_hostname = ?", (device_row[0] if device_row else None,), ) site_row = cur.fetchone() cur.execute( "SELECT COUNT(*) FROM refresh_tokens WHERE guid = ?", (final_guid,), ) refresh_count = cur.fetchone()[0] cur.execute( "SELECT use_count, used_by_guid FROM enrollment_install_codes WHERE id = ?", (install_code_id,), ) install_row = cur.fetchone() cur.execute( "SELECT COUNT(*) FROM device_keys WHERE guid = ?", (final_guid,), ) key_count = cur.fetchone()[0] cur.execute( """ SELECT is_active, last_known_use_count, used_by_guid, consumed_at FROM enrollment_install_codes_persistent WHERE id = ? """, (install_code_id,), ) persistent_row = cur.fetchone() assert approval_row is not None approval_guid, approval_status, approval_site_id = approval_row assert approval_status == "completed" assert approval_guid == final_guid assert approval_site_id == 1 assert device_row is not None hostname, fingerprint, token_version = device_row assert hostname == "agent-node-02" assert fingerprint == crypto_keys.fingerprint_from_spki_der(public_der) assert token_version >= 1 assert site_row is not None assert site_row[0] == 1 assert refresh_count == 1 assert install_row is not None use_count, used_by_guid = install_row assert use_count == 1 assert used_by_guid == final_guid assert key_count == 1 assert persistent_row is not None is_active, last_known_use_count, persistent_guid, consumed_at = persistent_row assert is_active == 0 assert last_known_use_count == 1 assert persistent_guid == final_guid assert consumed_at is not None def test_persistent_enrollment_codes_restore_active_table(engine_harness: EngineTestHarness) -> None: harness = engine_harness code_id = str(uuid.uuid4()) baseline = _now() issued_iso = _iso(baseline) expires_iso = _iso(baseline + timedelta(hours=4)) with sqlite3.connect(str(harness.db_path)) as conn: conn.execute("DELETE FROM enrollment_install_codes") conn.execute( """ INSERT OR REPLACE INTO enrollment_install_codes_persistent ( id, code, created_at, expires_at, created_by_user_id, used_at, used_by_guid, max_uses, last_known_use_count, last_used_at, is_active, archived_at, consumed_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( code_id, "RESTORE-CODE-001", issued_iso, expires_iso, "restorer", None, None, 3, 0, None, 1, None, None, ), ) conn.commit() initialise_engine_database(str(harness.db_path)) with sqlite3.connect(str(harness.db_path)) as conn: cur = conn.execute( """ SELECT code, expires_at, max_uses, use_count FROM enrollment_install_codes WHERE id = ? """, (code_id,), ) row = cur.fetchone() assert row is not None restored_code, restored_expires, restored_max_uses, restored_use_count = row assert restored_code == "RESTORE-CODE-001" assert restored_expires == expires_iso assert restored_max_uses == 3 assert restored_use_count == 0