Simplified & Reworked Enrollment Code System to be Site-Specific

This commit is contained in:
2025-11-16 17:40:24 -07:00
parent 65bee703e9
commit b2120d7385
13 changed files with 649 additions and 492 deletions

View File

@@ -31,19 +31,29 @@ def _iso(dt: datetime) -> str:
return dt.astimezone(timezone.utc).isoformat()
def _seed_install_code(db_path: os.PathLike[str], code: str) -> str:
def _seed_install_code(db_path: os.PathLike[str], code: str, site_id: int = 1) -> str:
record_id = str(uuid.uuid4())
baseline = _now()
issued_at = _iso(baseline)
expires_at = _iso(baseline + timedelta(days=1))
with sqlite3.connect(str(db_path)) as conn:
columns = {row[1] for row in conn.execute("PRAGMA table_info(sites)")}
if "enrollment_code_id" not in columns:
conn.execute("ALTER TABLE sites ADD COLUMN enrollment_code_id TEXT")
conn.execute(
"""
INSERT OR IGNORE INTO sites (id, name, description, created_at, enrollment_code_id)
VALUES (?, ?, ?, ?, ?)
""",
(site_id, f"Test Site {site_id}", "Seeded site", int(baseline.timestamp()), record_id),
)
conn.execute(
"""
INSERT INTO enrollment_install_codes (
id, code, expires_at, used_at, used_by_guid, max_uses, use_count, last_used_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
id, code, expires_at, used_at, used_by_guid, max_uses, use_count, last_used_at, site_id
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(record_id, code, expires_at, None, None, 1, 0, None),
(record_id, code, expires_at, None, None, 1, 0, None, site_id),
)
conn.execute(
"""
@@ -60,8 +70,9 @@ def _seed_install_code(db_path: os.PathLike[str], code: str) -> str:
last_used_at,
is_active,
archived_at,
consumed_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
consumed_at,
site_id
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
record_id,
@@ -77,8 +88,17 @@ def _seed_install_code(db_path: os.PathLike[str], code: str) -> str:
1,
None,
None,
site_id,
),
)
conn.execute(
"""
UPDATE sites
SET enrollment_code_id = ?
WHERE id = ?
""",
(record_id, site_id),
)
conn.commit()
return record_id
@@ -124,7 +144,7 @@ def test_enrollment_request_creates_pending_approval(engine_harness: EngineTestH
cur = conn.cursor()
cur.execute(
"""
SELECT hostname_claimed, ssl_key_fingerprint_claimed, client_nonce, status, enrollment_code_id
SELECT hostname_claimed, ssl_key_fingerprint_claimed, client_nonce, status, enrollment_code_id, site_id
FROM device_approvals
WHERE approval_reference = ?
""",
@@ -133,11 +153,12 @@ def test_enrollment_request_creates_pending_approval(engine_harness: EngineTestH
row = cur.fetchone()
assert row is not None
hostname_claimed, fingerprint, stored_client_nonce, status, stored_code_id = row
hostname_claimed, fingerprint, stored_client_nonce, status, stored_code_id, stored_site_id = row
assert hostname_claimed == "agent-node-01"
assert stored_client_nonce == client_nonce_b64
assert status == "pending"
assert stored_code_id == install_code_id
assert stored_site_id == 1
expected_fingerprint = crypto_keys.fingerprint_from_spki_der(public_der)
assert fingerprint == expected_fingerprint
@@ -206,7 +227,7 @@ def test_enrollment_poll_finalizes_when_approved(engine_harness: EngineTestHarne
with sqlite3.connect(str(harness.db_path)) as conn:
cur = conn.cursor()
cur.execute(
"SELECT guid, status FROM device_approvals WHERE approval_reference = ?",
"SELECT guid, status, site_id FROM device_approvals WHERE approval_reference = ?",
(approval_reference,),
)
approval_row = cur.fetchone()
@@ -215,6 +236,11 @@ def test_enrollment_poll_finalizes_when_approved(engine_harness: EngineTestHarne
(final_guid,),
)
device_row = cur.fetchone()
cur.execute(
"SELECT site_id FROM device_sites WHERE device_hostname = ?",
(device_row[0] if device_row else None,),
)
site_row = cur.fetchone()
cur.execute(
"SELECT COUNT(*) FROM refresh_tokens WHERE guid = ?",
(final_guid,),
@@ -241,15 +267,18 @@ def test_enrollment_poll_finalizes_when_approved(engine_harness: EngineTestHarne
persistent_row = cur.fetchone()
assert approval_row is not None
approval_guid, approval_status = approval_row
approval_guid, approval_status, approval_site_id = approval_row
assert approval_status == "completed"
assert approval_guid == final_guid
assert approval_site_id == 1
assert device_row is not None
hostname, fingerprint, token_version = device_row
assert hostname == "agent-node-02"
assert fingerprint == crypto_keys.fingerprint_from_spki_der(public_der)
assert token_version >= 1
assert site_row is not None
assert site_row[0] == 1
assert refresh_count == 1
assert install_row is not None