Support multi-use installer codes and reuse

This commit is contained in:
2025-10-18 03:19:26 -06:00
parent 775d365512
commit 8177cc0892
6 changed files with 416 additions and 36 deletions

View File

@@ -54,18 +54,27 @@ def register(
try:
cur = conn.cursor()
sql = """
SELECT id, code, expires_at, created_by_user_id, used_at, used_by_guid
SELECT id,
code,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at
FROM enrollment_install_codes
"""
params: List[str] = []
now_iso = _iso(_now())
if status_filter == "active":
sql += " WHERE used_at IS NULL AND expires_at > ?"
params.append(_iso(_now()))
sql += " WHERE use_count < max_uses AND expires_at > ?"
params.append(now_iso)
elif status_filter == "expired":
sql += " WHERE used_at IS NULL AND expires_at <= ?"
params.append(_iso(_now()))
sql += " WHERE use_count < max_uses AND expires_at <= ?"
params.append(now_iso)
elif status_filter == "used":
sql += " WHERE used_at IS NOT NULL"
sql += " WHERE use_count >= max_uses"
sql += " ORDER BY expires_at ASC"
cur.execute(sql, params)
rows = cur.fetchall()
@@ -82,6 +91,9 @@ def register(
"created_by_user_id": row[3],
"used_at": row[4],
"used_by_guid": row[5],
"max_uses": row[6],
"use_count": row[7],
"last_used_at": row[8],
}
)
return jsonify({"codes": records})
@@ -93,6 +105,18 @@ def register(
if ttl_hours not in VALID_TTL_HOURS:
return jsonify({"error": "invalid_ttl"}), 400
max_uses_value = payload.get("max_uses")
if max_uses_value is None:
max_uses_value = payload.get("allowed_uses")
try:
max_uses = int(max_uses_value)
except Exception:
max_uses = 2
if max_uses < 1:
max_uses = 1
if max_uses > 10:
max_uses = 10
user = current_user() or {}
username = user.get("username") or ""
@@ -106,22 +130,28 @@ def register(
cur.execute(
"""
INSERT INTO enrollment_install_codes (
id, code, expires_at, created_by_user_id
id, code, expires_at, created_by_user_id, max_uses, use_count
)
VALUES (?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?, 0)
""",
(record_id, code_value, _iso(expires_at), created_by),
(record_id, code_value, _iso(expires_at), created_by, max_uses),
)
conn.commit()
finally:
conn.close()
log("server", f"installer code created id={record_id} by={username} ttl={ttl_hours}h")
log(
"server",
f"installer code created id={record_id} by={username} ttl={ttl_hours}h max_uses={max_uses}",
)
return jsonify(
{
"id": record_id,
"code": code_value,
"expires_at": _iso(expires_at),
"max_uses": max_uses,
"use_count": 0,
"last_used_at": None,
}
)
@@ -131,7 +161,7 @@ def register(
try:
cur = conn.cursor()
cur.execute(
"DELETE FROM enrollment_install_codes WHERE id = ? AND used_at IS NULL",
"DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
(code_id,),
)
deleted = cur.rowcount

View File

@@ -152,7 +152,10 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
expires_at TEXT NOT NULL,
created_by_user_id TEXT,
used_at TEXT,
used_by_guid TEXT
used_by_guid TEXT,
max_uses INTEGER NOT NULL DEFAULT 1,
use_count INTEGER NOT NULL DEFAULT 0,
last_used_at TEXT
)
"""
)
@@ -163,6 +166,29 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
"""
)
columns = {row[1] for row in _table_info(cur, "enrollment_install_codes")}
if "max_uses" not in columns:
cur.execute(
"""
ALTER TABLE enrollment_install_codes
ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 1
"""
)
if "use_count" not in columns:
cur.execute(
"""
ALTER TABLE enrollment_install_codes
ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0
"""
)
if "last_used_at" not in columns:
cur.execute(
"""
ALTER TABLE enrollment_install_codes
ADD COLUMN last_used_at TEXT
"""
)
def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor()

View File

@@ -6,7 +6,7 @@ import sqlite3
import uuid
from datetime import datetime, timezone, timedelta
import time
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Tuple
from flask import Blueprint, jsonify, request
@@ -66,31 +66,79 @@ def register(
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
cur.execute(
"SELECT id, code, expires_at, used_at FROM enrollment_install_codes WHERE code = ?",
"""
SELECT id,
code,
expires_at,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at
FROM enrollment_install_codes
WHERE code = ?
""",
(code_value,),
)
row = cur.fetchone()
if not row:
return None
keys = ["id", "code", "expires_at", "used_at"]
keys = [
"id",
"code",
"expires_at",
"used_at",
"used_by_guid",
"max_uses",
"use_count",
"last_used_at",
]
record = dict(zip(keys, row))
return record
def _install_code_valid(record: Dict[str, Any]) -> bool:
def _install_code_valid(
record: Dict[str, Any], fingerprint: str, cur: sqlite3.Cursor
) -> Tuple[bool, Optional[str]]:
if not record:
return False
return False, None
expires_at = record.get("expires_at")
if not isinstance(expires_at, str):
return False
return False, None
try:
expiry = datetime.fromisoformat(expires_at)
except Exception:
return False
return False, None
if expiry <= _now():
return False
if record.get("used_at"):
return False
return True
return False, None
try:
max_uses = int(record.get("max_uses") or 1)
except Exception:
max_uses = 1
if max_uses < 1:
max_uses = 1
try:
use_count = int(record.get("use_count") or 0)
except Exception:
use_count = 0
if use_count < max_uses:
return True, None
guid = str(record.get("used_by_guid") or "").strip()
if not guid:
return False, None
cur.execute(
"SELECT ssl_key_fingerprint FROM devices WHERE guid = ?",
(guid,),
)
row = cur.fetchone()
if not row:
return False, None
stored_fp = (row[0] or "").strip().lower()
if not stored_fp:
return False, None
if stored_fp == (fingerprint or "").strip().lower():
return True, guid
return False, None
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
base = (hostname or "").strip() or guid
@@ -305,7 +353,13 @@ def register(
try:
cur = conn.cursor()
install_code = _load_install_code(cur, enrollment_code)
if not _install_code_valid(install_code):
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
if not valid_code:
log(
"server",
"enrollment request invalid_code "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
)
return jsonify({"error": "invalid_enrollment_code"}), 400
approval_reference: str
@@ -331,6 +385,7 @@ def register(
"""
UPDATE device_approvals
SET hostname_claimed = ?,
guid = ?,
enrollment_code_id = ?,
client_nonce = ?,
server_nonce = ?,
@@ -340,6 +395,7 @@ def register(
""",
(
hostname,
reuse_guid,
install_code["id"],
client_nonce_b64,
server_nonce_b64,
@@ -359,11 +415,12 @@ def register(
status, client_nonce, server_nonce, agent_pubkey_der,
created_at, updated_at
)
VALUES (?, ?, NULL, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
""",
(
record_id,
approval_reference,
reuse_guid,
hostname,
fingerprint,
install_code["id"],
@@ -537,14 +594,40 @@ def register(
# Mark install code used
if enrollment_code_id:
cur.execute(
"SELECT use_count, max_uses FROM enrollment_install_codes WHERE id = ?",
(enrollment_code_id,),
)
usage_row = cur.fetchone()
try:
prior_count = int(usage_row[0]) if usage_row else 0
except Exception:
prior_count = 0
try:
allowed_uses = int(usage_row[1]) if usage_row else 1
except Exception:
allowed_uses = 1
if allowed_uses < 1:
allowed_uses = 1
new_count = prior_count + 1
consumed = new_count >= allowed_uses
cur.execute(
"""
UPDATE enrollment_install_codes
SET used_at = ?, used_by_guid = ?
SET use_count = ?,
used_by_guid = ?,
last_used_at = ?,
used_at = CASE WHEN ? THEN ? ELSE used_at END
WHERE id = ?
AND used_at IS NULL
""",
(now_iso, effective_guid, enrollment_code_id),
(
new_count,
effective_guid,
now_iso,
1 if consumed else 0,
now_iso,
enrollment_code_id,
),
)
# Update approval record with final state

View File

@@ -34,7 +34,7 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
cur.execute(
"""
DELETE FROM enrollment_install_codes
WHERE used_at IS NULL
WHERE use_count = 0
AND expires_at < ?
""",
(now_iso,),
@@ -52,7 +52,10 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
SELECT 1
FROM enrollment_install_codes c
WHERE c.id = device_approvals.enrollment_code_id
AND c.expires_at < ?
AND (
c.expires_at < ?
OR c.use_count >= c.max_uses
)
)
OR created_at < ?
)