Simplified & Reworked Enrollment Code System to be Site-Specific

This commit is contained in:
2025-11-16 17:40:24 -07:00
parent 65bee703e9
commit b2120d7385
13 changed files with 649 additions and 492 deletions

View File

@@ -94,7 +94,8 @@ CREATE TABLE IF NOT EXISTS enrollment_install_codes (
used_by_guid TEXT,
max_uses INTEGER,
use_count INTEGER,
last_used_at TEXT
last_used_at TEXT,
site_id INTEGER
);
CREATE TABLE IF NOT EXISTS enrollment_install_codes_persistent (
id TEXT PRIMARY KEY,
@@ -109,7 +110,8 @@ CREATE TABLE IF NOT EXISTS enrollment_install_codes_persistent (
last_used_at TEXT,
is_active INTEGER NOT NULL DEFAULT 1,
archived_at TEXT,
consumed_at TEXT
consumed_at TEXT,
site_id INTEGER
);
CREATE TABLE IF NOT EXISTS device_approvals (
id TEXT PRIMARY KEY,
@@ -118,6 +120,7 @@ CREATE TABLE IF NOT EXISTS device_approvals (
hostname_claimed TEXT,
ssl_key_fingerprint_claimed TEXT,
enrollment_code_id TEXT,
site_id INTEGER,
status TEXT,
client_nonce TEXT,
server_nonce TEXT,
@@ -145,7 +148,8 @@ CREATE TABLE IF NOT EXISTS sites (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
description TEXT,
created_at INTEGER
created_at INTEGER,
enrollment_code_id TEXT
);
CREATE TABLE IF NOT EXISTS device_sites (
device_hostname TEXT PRIMARY KEY,
@@ -270,9 +274,51 @@ def engine_harness(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator[
"2025-10-01T00:00:00Z",
),
)
site_code_id = "SITE-CODE-0001"
site_code_value = "SITE-MAIN-CODE"
site_code_created = "2025-01-01T00:00:00Z"
site_code_expires = "2030-01-01T00:00:00Z"
cur.execute(
"INSERT INTO sites (id, name, description, created_at) VALUES (?, ?, ?, ?)",
(1, "Main Lab", "Primary integration site", 1_700_000_000),
"""
INSERT INTO enrollment_install_codes (
id,
code,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at,
site_id
) VALUES (?, ?, ?, ?, NULL, NULL, 0, 0, NULL, ?)
""",
(site_code_id, site_code_value, site_code_expires, "admin", 1),
)
cur.execute(
"""
INSERT INTO enrollment_install_codes_persistent (
id,
code,
created_at,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
last_known_use_count,
last_used_at,
is_active,
archived_at,
consumed_at,
site_id
) VALUES (?, ?, ?, ?, ?, NULL, NULL, 0, 0, NULL, 1, NULL, NULL, ?)
""",
(site_code_id, site_code_value, site_code_created, site_code_expires, "admin", 1),
)
cur.execute(
"INSERT INTO sites (id, name, description, created_at, enrollment_code_id) VALUES (?, ?, ?, ?, ?)",
(1, "Main Lab", "Primary integration site", 1_700_000_000, site_code_id),
)
cur.execute(
"INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)",
@@ -294,6 +340,7 @@ def engine_harness(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator[
hostname_claimed,
ssl_key_fingerprint_claimed,
enrollment_code_id,
site_id,
status,
client_nonce,
server_nonce,
@@ -302,7 +349,7 @@ def engine_harness(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator[
updated_at,
approved_by_user_id
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
"approval-1",
@@ -310,7 +357,8 @@ def engine_harness(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator[
None,
"pending-device",
"aa:bb:cc:dd",
None,
site_code_id,
1,
"pending",
"client-nonce",
"server-nonce",

View File

@@ -57,7 +57,7 @@ def _patch_repo_call(monkeypatch: pytest.MonkeyPatch, calls: dict) -> None:
def test_list_devices(engine_harness: EngineTestHarness) -> None:
client = engine_harness.app.test_client()
client = _client_with_admin_session(engine_harness)
response = client.get("/api/devices")
assert response.status_code == 200
payload = response.get_json()
@@ -70,7 +70,7 @@ def test_list_devices(engine_harness: EngineTestHarness) -> None:
def test_list_agents(engine_harness: EngineTestHarness) -> None:
client = engine_harness.app.test_client()
client = _client_with_admin_session(engine_harness)
response = client.get("/api/agents")
assert response.status_code == 200
payload = response.get_json()
@@ -82,7 +82,7 @@ def test_list_agents(engine_harness: EngineTestHarness) -> None:
def test_device_details(engine_harness: EngineTestHarness) -> None:
client = engine_harness.app.test_client()
client = _client_with_admin_session(engine_harness)
response = client.get("/api/device/details/test-device")
assert response.status_code == 200
payload = response.get_json()
@@ -165,7 +165,7 @@ def test_repo_current_hash_allows_device_token(engine_harness: EngineTestHarness
def test_agent_hash_list_permissions(engine_harness: EngineTestHarness) -> None:
client = engine_harness.app.test_client()
client = _client_with_admin_session(engine_harness)
forbidden = client.get("/api/agent/hash_list", environ_base={"REMOTE_ADDR": "192.0.2.10"})
assert forbidden.status_code == 403
allowed = client.get("/api/agent/hash_list", environ_base={"REMOTE_ADDR": "127.0.0.1"})
@@ -208,21 +208,20 @@ def test_sites_lifecycle(engine_harness: EngineTestHarness) -> None:
assert delete_resp.status_code == 200
def test_admin_enrollment_code_flow(engine_harness: EngineTestHarness) -> None:
def test_site_enrollment_code_rotation(engine_harness: EngineTestHarness) -> None:
client = _client_with_admin_session(engine_harness)
create_resp = client.post(
"/api/admin/enrollment-codes",
json={"ttl_hours": 1, "max_uses": 2},
)
assert create_resp.status_code == 201
code_id = create_resp.get_json()["id"]
sites_resp = client.get("/api/sites")
assert sites_resp.status_code == 200
sites = sites_resp.get_json()["sites"]
assert sites and sites[0]["enrollment_code"]
site_id = sites[0]["id"]
original_code = sites[0]["enrollment_code"]
list_resp = client.get("/api/admin/enrollment-codes")
codes = list_resp.get_json()["codes"]
assert any(code["id"] == code_id for code in codes)
delete_resp = client.delete(f"/api/admin/enrollment-codes/{code_id}")
assert delete_resp.status_code == 200
rotate_resp = client.post("/api/sites/rotate_code", json={"site_id": site_id})
assert rotate_resp.status_code == 200
rotated = rotate_resp.get_json()
assert rotated["id"] == site_id
assert rotated["enrollment_code"] and rotated["enrollment_code"] != original_code
def test_admin_device_approvals(engine_harness: EngineTestHarness) -> None:

View File

@@ -31,19 +31,29 @@ def _iso(dt: datetime) -> str:
return dt.astimezone(timezone.utc).isoformat()
def _seed_install_code(db_path: os.PathLike[str], code: str) -> str:
def _seed_install_code(db_path: os.PathLike[str], code: str, site_id: int = 1) -> str:
record_id = str(uuid.uuid4())
baseline = _now()
issued_at = _iso(baseline)
expires_at = _iso(baseline + timedelta(days=1))
with sqlite3.connect(str(db_path)) as conn:
columns = {row[1] for row in conn.execute("PRAGMA table_info(sites)")}
if "enrollment_code_id" not in columns:
conn.execute("ALTER TABLE sites ADD COLUMN enrollment_code_id TEXT")
conn.execute(
"""
INSERT OR IGNORE INTO sites (id, name, description, created_at, enrollment_code_id)
VALUES (?, ?, ?, ?, ?)
""",
(site_id, f"Test Site {site_id}", "Seeded site", int(baseline.timestamp()), record_id),
)
conn.execute(
"""
INSERT INTO enrollment_install_codes (
id, code, expires_at, used_at, used_by_guid, max_uses, use_count, last_used_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
id, code, expires_at, used_at, used_by_guid, max_uses, use_count, last_used_at, site_id
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(record_id, code, expires_at, None, None, 1, 0, None),
(record_id, code, expires_at, None, None, 1, 0, None, site_id),
)
conn.execute(
"""
@@ -60,8 +70,9 @@ def _seed_install_code(db_path: os.PathLike[str], code: str) -> str:
last_used_at,
is_active,
archived_at,
consumed_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
consumed_at,
site_id
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
record_id,
@@ -77,8 +88,17 @@ def _seed_install_code(db_path: os.PathLike[str], code: str) -> str:
1,
None,
None,
site_id,
),
)
conn.execute(
"""
UPDATE sites
SET enrollment_code_id = ?
WHERE id = ?
""",
(record_id, site_id),
)
conn.commit()
return record_id
@@ -124,7 +144,7 @@ def test_enrollment_request_creates_pending_approval(engine_harness: EngineTestH
cur = conn.cursor()
cur.execute(
"""
SELECT hostname_claimed, ssl_key_fingerprint_claimed, client_nonce, status, enrollment_code_id
SELECT hostname_claimed, ssl_key_fingerprint_claimed, client_nonce, status, enrollment_code_id, site_id
FROM device_approvals
WHERE approval_reference = ?
""",
@@ -133,11 +153,12 @@ def test_enrollment_request_creates_pending_approval(engine_harness: EngineTestH
row = cur.fetchone()
assert row is not None
hostname_claimed, fingerprint, stored_client_nonce, status, stored_code_id = row
hostname_claimed, fingerprint, stored_client_nonce, status, stored_code_id, stored_site_id = row
assert hostname_claimed == "agent-node-01"
assert stored_client_nonce == client_nonce_b64
assert status == "pending"
assert stored_code_id == install_code_id
assert stored_site_id == 1
expected_fingerprint = crypto_keys.fingerprint_from_spki_der(public_der)
assert fingerprint == expected_fingerprint
@@ -206,7 +227,7 @@ def test_enrollment_poll_finalizes_when_approved(engine_harness: EngineTestHarne
with sqlite3.connect(str(harness.db_path)) as conn:
cur = conn.cursor()
cur.execute(
"SELECT guid, status FROM device_approvals WHERE approval_reference = ?",
"SELECT guid, status, site_id FROM device_approvals WHERE approval_reference = ?",
(approval_reference,),
)
approval_row = cur.fetchone()
@@ -215,6 +236,11 @@ def test_enrollment_poll_finalizes_when_approved(engine_harness: EngineTestHarne
(final_guid,),
)
device_row = cur.fetchone()
cur.execute(
"SELECT site_id FROM device_sites WHERE device_hostname = ?",
(device_row[0] if device_row else None,),
)
site_row = cur.fetchone()
cur.execute(
"SELECT COUNT(*) FROM refresh_tokens WHERE guid = ?",
(final_guid,),
@@ -241,15 +267,18 @@ def test_enrollment_poll_finalizes_when_approved(engine_harness: EngineTestHarne
persistent_row = cur.fetchone()
assert approval_row is not None
approval_guid, approval_status = approval_row
approval_guid, approval_status, approval_site_id = approval_row
assert approval_status == "completed"
assert approval_guid == final_guid
assert approval_site_id == 1
assert device_row is not None
hostname, fingerprint, token_version = device_row
assert hostname == "agent-node-02"
assert fingerprint == crypto_keys.fingerprint_from_spki_der(public_der)
assert token_version >= 1
assert site_row is not None
assert site_row[0] == 1
assert refresh_count == 1
assert install_row is not None