Simplified & Reworked Enrollment Code System to be Site-Specific

This commit is contained in:
2025-11-16 17:40:24 -07:00
parent 65bee703e9
commit b2120d7385
13 changed files with 649 additions and 492 deletions

View File

@@ -103,7 +103,8 @@ def register(
used_by_guid,
max_uses,
use_count,
last_used_at
last_used_at,
site_id
FROM enrollment_install_codes
WHERE code = ?
""",
@@ -121,6 +122,7 @@ def register(
"max_uses",
"use_count",
"last_used_at",
"site_id",
]
record = dict(zip(keys, row))
return record
@@ -140,16 +142,16 @@ def register(
if expiry <= _now():
return False, None
try:
max_uses = int(record.get("max_uses") or 1)
max_uses_raw = record.get("max_uses")
max_uses = int(max_uses_raw) if max_uses_raw is not None else 0
except Exception:
max_uses = 1
if max_uses < 1:
max_uses = 1
max_uses = 0
unlimited = max_uses <= 0
try:
use_count = int(record.get("use_count") or 0)
except Exception:
use_count = 0
if use_count < max_uses:
if unlimited or use_count < max_uses:
return True, None
guid = normalize_guid(record.get("used_by_guid"))
@@ -392,6 +394,24 @@ def register(
try:
cur = conn.cursor()
install_code = _load_install_code(cur, enrollment_code)
site_id = install_code.get("site_id") if install_code else None
if site_id is None:
log(
"server",
"enrollment request rejected missing_site_binding "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
)
return jsonify({"error": "invalid_enrollment_code"}), 400
cur.execute("SELECT 1 FROM sites WHERE id = ?", (site_id,))
if cur.fetchone() is None:
log(
"server",
"enrollment request rejected missing_site_owner "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
)
return jsonify({"error": "invalid_enrollment_code"}), 400
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
if not valid_code:
log(
@@ -427,6 +447,7 @@ def register(
SET hostname_claimed = ?,
guid = ?,
enrollment_code_id = ?,
site_id = ?,
client_nonce = ?,
server_nonce = ?,
agent_pubkey_der = ?,
@@ -437,6 +458,7 @@ def register(
hostname,
reuse_guid,
install_code["id"],
install_code.get("site_id"),
client_nonce_b64,
server_nonce_b64,
agent_pubkey_der,
@@ -451,11 +473,11 @@ def register(
"""
INSERT INTO device_approvals (
id, approval_reference, guid, hostname_claimed,
ssl_key_fingerprint_claimed, enrollment_code_id,
ssl_key_fingerprint_claimed, enrollment_code_id, site_id,
status, client_nonce, server_nonce, agent_pubkey_der,
created_at, updated_at
)
VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
""",
(
record_id,
@@ -464,6 +486,7 @@ def register(
hostname,
fingerprint,
install_code["id"],
install_code.get("site_id"),
client_nonce_b64,
server_nonce_b64,
agent_pubkey_der,
@@ -535,7 +558,7 @@ def register(
cur.execute(
"""
SELECT id, guid, hostname_claimed, ssl_key_fingerprint_claimed,
enrollment_code_id, status, client_nonce, server_nonce,
enrollment_code_id, site_id, status, client_nonce, server_nonce,
agent_pubkey_der, created_at, updated_at, approved_by_user_id
FROM device_approvals
WHERE approval_reference = ?
@@ -553,6 +576,7 @@ def register(
hostname_claimed,
fingerprint,
enrollment_code_id,
site_id,
status,
client_nonce_stored,
server_nonce_b64,
@@ -643,6 +667,17 @@ def register(
device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint)
_store_device_key(cur, effective_guid, fingerprint)
if site_id:
assigned_at = int(time.time())
cur.execute(
"""
INSERT INTO device_sites(device_hostname, site_id, assigned_at)
VALUES (?, ?, ?)
ON CONFLICT(device_hostname)
DO UPDATE SET site_id=excluded.site_id, assigned_at=excluded.assigned_at
""",
(device_record.get("hostname"), site_id, assigned_at),
)
# Mark install code used
if enrollment_code_id:
@@ -656,13 +691,14 @@ def register(
except Exception:
prior_count = 0
try:
allowed_uses = int(usage_row[1]) if usage_row else 1
allowed_uses = int(usage_row[1]) if usage_row else 0
except Exception:
allowed_uses = 1
allowed_uses = 0
unlimited = allowed_uses <= 0
if allowed_uses < 1:
allowed_uses = 1
new_count = prior_count + 1
consumed = new_count >= allowed_uses
consumed = False if unlimited else new_count >= allowed_uses
cur.execute(
"""
UPDATE enrollment_install_codes
@@ -767,4 +803,3 @@ def _mask_code(code: str) -> str:
if len(trimmed) <= 6:
return "***"
return f"{trimmed[:3]}***{trimmed[-3:]}"