diff --git a/Data/Server/Modules/admin/routes.py b/Data/Server/Modules/admin/routes.py index 97625d1..72c9d81 100644 --- a/Data/Server/Modules/admin/routes.py +++ b/Data/Server/Modules/admin/routes.py @@ -54,18 +54,27 @@ def register( try: cur = conn.cursor() sql = """ - SELECT id, code, expires_at, created_by_user_id, used_at, used_by_guid + SELECT id, + code, + expires_at, + created_by_user_id, + used_at, + used_by_guid, + max_uses, + use_count, + last_used_at FROM enrollment_install_codes """ params: List[str] = [] + now_iso = _iso(_now()) if status_filter == "active": - sql += " WHERE used_at IS NULL AND expires_at > ?" - params.append(_iso(_now())) + sql += " WHERE use_count < max_uses AND expires_at > ?" + params.append(now_iso) elif status_filter == "expired": - sql += " WHERE used_at IS NULL AND expires_at <= ?" - params.append(_iso(_now())) + sql += " WHERE use_count < max_uses AND expires_at <= ?" + params.append(now_iso) elif status_filter == "used": - sql += " WHERE used_at IS NOT NULL" + sql += " WHERE use_count >= max_uses" sql += " ORDER BY expires_at ASC" cur.execute(sql, params) rows = cur.fetchall() @@ -82,6 +91,9 @@ def register( "created_by_user_id": row[3], "used_at": row[4], "used_by_guid": row[5], + "max_uses": row[6], + "use_count": row[7], + "last_used_at": row[8], } ) return jsonify({"codes": records}) @@ -93,6 +105,18 @@ def register( if ttl_hours not in VALID_TTL_HOURS: return jsonify({"error": "invalid_ttl"}), 400 + max_uses_value = payload.get("max_uses") + if max_uses_value is None: + max_uses_value = payload.get("allowed_uses") + try: + max_uses = int(max_uses_value) + except Exception: + max_uses = 2 + if max_uses < 1: + max_uses = 1 + if max_uses > 10: + max_uses = 10 + user = current_user() or {} username = user.get("username") or "" @@ -106,22 +130,28 @@ def register( cur.execute( """ INSERT INTO enrollment_install_codes ( - id, code, expires_at, created_by_user_id + id, code, expires_at, created_by_user_id, max_uses, use_count ) - VALUES (?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, 0) """, - (record_id, code_value, _iso(expires_at), created_by), + (record_id, code_value, _iso(expires_at), created_by, max_uses), ) conn.commit() finally: conn.close() - log("server", f"installer code created id={record_id} by={username} ttl={ttl_hours}h") + log( + "server", + f"installer code created id={record_id} by={username} ttl={ttl_hours}h max_uses={max_uses}", + ) return jsonify( { "id": record_id, "code": code_value, "expires_at": _iso(expires_at), + "max_uses": max_uses, + "use_count": 0, + "last_used_at": None, } ) @@ -131,7 +161,7 @@ def register( try: cur = conn.cursor() cur.execute( - "DELETE FROM enrollment_install_codes WHERE id = ? AND used_at IS NULL", + "DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0", (code_id,), ) deleted = cur.rowcount diff --git a/Data/Server/Modules/db_migrations.py b/Data/Server/Modules/db_migrations.py index 5edd6c9..b6d1aa1 100644 --- a/Data/Server/Modules/db_migrations.py +++ b/Data/Server/Modules/db_migrations.py @@ -152,7 +152,10 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None: expires_at TEXT NOT NULL, created_by_user_id TEXT, used_at TEXT, - used_by_guid TEXT + used_by_guid TEXT, + max_uses INTEGER NOT NULL DEFAULT 1, + use_count INTEGER NOT NULL DEFAULT 0, + last_used_at TEXT ) """ ) @@ -163,6 +166,29 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None: """ ) + columns = {row[1] for row in _table_info(cur, "enrollment_install_codes")} + if "max_uses" not in columns: + cur.execute( + """ + ALTER TABLE enrollment_install_codes + ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 1 + """ + ) + if "use_count" not in columns: + cur.execute( + """ + ALTER TABLE enrollment_install_codes + ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0 + """ + ) + if "last_used_at" not in columns: + cur.execute( + """ + ALTER TABLE enrollment_install_codes + ADD COLUMN last_used_at TEXT + """ + ) + def _ensure_device_approval_table(conn: sqlite3.Connection) -> None: cur = conn.cursor() diff --git a/Data/Server/Modules/enrollment/routes.py b/Data/Server/Modules/enrollment/routes.py index c408bcd..3392cab 100644 --- a/Data/Server/Modules/enrollment/routes.py +++ b/Data/Server/Modules/enrollment/routes.py @@ -6,7 +6,7 @@ import sqlite3 import uuid from datetime import datetime, timezone, timedelta import time -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple from flask import Blueprint, jsonify, request @@ -66,31 +66,79 @@ def register( def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]: cur.execute( - "SELECT id, code, expires_at, used_at FROM enrollment_install_codes WHERE code = ?", + """ + SELECT id, + code, + expires_at, + used_at, + used_by_guid, + max_uses, + use_count, + last_used_at + FROM enrollment_install_codes + WHERE code = ? + """, (code_value,), ) row = cur.fetchone() if not row: return None - keys = ["id", "code", "expires_at", "used_at"] + keys = [ + "id", + "code", + "expires_at", + "used_at", + "used_by_guid", + "max_uses", + "use_count", + "last_used_at", + ] record = dict(zip(keys, row)) return record - def _install_code_valid(record: Dict[str, Any]) -> bool: + def _install_code_valid( + record: Dict[str, Any], fingerprint: str, cur: sqlite3.Cursor + ) -> Tuple[bool, Optional[str]]: if not record: - return False + return False, None expires_at = record.get("expires_at") if not isinstance(expires_at, str): - return False + return False, None try: expiry = datetime.fromisoformat(expires_at) except Exception: - return False + return False, None if expiry <= _now(): - return False - if record.get("used_at"): - return False - return True + return False, None + try: + max_uses = int(record.get("max_uses") or 1) + except Exception: + max_uses = 1 + if max_uses < 1: + max_uses = 1 + try: + use_count = int(record.get("use_count") or 0) + except Exception: + use_count = 0 + if use_count < max_uses: + return True, None + + guid = str(record.get("used_by_guid") or "").strip() + if not guid: + return False, None + cur.execute( + "SELECT ssl_key_fingerprint FROM devices WHERE guid = ?", + (guid,), + ) + row = cur.fetchone() + if not row: + return False, None + stored_fp = (row[0] or "").strip().lower() + if not stored_fp: + return False, None + if stored_fp == (fingerprint or "").strip().lower(): + return True, guid + return False, None def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str: base = (hostname or "").strip() or guid @@ -305,7 +353,13 @@ def register( try: cur = conn.cursor() install_code = _load_install_code(cur, enrollment_code) - if not _install_code_valid(install_code): + valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur) + if not valid_code: + log( + "server", + "enrollment request invalid_code " + f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}", + ) return jsonify({"error": "invalid_enrollment_code"}), 400 approval_reference: str @@ -331,6 +385,7 @@ def register( """ UPDATE device_approvals SET hostname_claimed = ?, + guid = ?, enrollment_code_id = ?, client_nonce = ?, server_nonce = ?, @@ -340,6 +395,7 @@ def register( """, ( hostname, + reuse_guid, install_code["id"], client_nonce_b64, server_nonce_b64, @@ -359,11 +415,12 @@ def register( status, client_nonce, server_nonce, agent_pubkey_der, created_at, updated_at ) - VALUES (?, ?, NULL, ?, ?, ?, 'pending', ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?) """, ( record_id, approval_reference, + reuse_guid, hostname, fingerprint, install_code["id"], @@ -537,14 +594,40 @@ def register( # Mark install code used if enrollment_code_id: + cur.execute( + "SELECT use_count, max_uses FROM enrollment_install_codes WHERE id = ?", + (enrollment_code_id,), + ) + usage_row = cur.fetchone() + try: + prior_count = int(usage_row[0]) if usage_row else 0 + except Exception: + prior_count = 0 + try: + allowed_uses = int(usage_row[1]) if usage_row else 1 + except Exception: + allowed_uses = 1 + if allowed_uses < 1: + allowed_uses = 1 + new_count = prior_count + 1 + consumed = new_count >= allowed_uses cur.execute( """ UPDATE enrollment_install_codes - SET used_at = ?, used_by_guid = ? + SET use_count = ?, + used_by_guid = ?, + last_used_at = ?, + used_at = CASE WHEN ? THEN ? ELSE used_at END WHERE id = ? - AND used_at IS NULL """, - (now_iso, effective_guid, enrollment_code_id), + ( + new_count, + effective_guid, + now_iso, + 1 if consumed else 0, + now_iso, + enrollment_code_id, + ), ) # Update approval record with final state diff --git a/Data/Server/Modules/jobs/prune.py b/Data/Server/Modules/jobs/prune.py index 9c3ac56..2e6e21e 100644 --- a/Data/Server/Modules/jobs/prune.py +++ b/Data/Server/Modules/jobs/prune.py @@ -34,7 +34,7 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None cur.execute( """ DELETE FROM enrollment_install_codes - WHERE used_at IS NULL + WHERE use_count = 0 AND expires_at < ? """, (now_iso,), @@ -52,7 +52,10 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None SELECT 1 FROM enrollment_install_codes c WHERE c.id = device_approvals.enrollment_code_id - AND c.expires_at < ? + AND ( + c.expires_at < ? + OR c.use_count >= c.max_uses + ) ) OR created_at < ? ) diff --git a/Data/Server/WebUI/src/Admin/Enrollment_Codes.jsx b/Data/Server/WebUI/src/Admin/Enrollment_Codes.jsx index 236fbd7..db3e387 100644 --- a/Data/Server/WebUI/src/Admin/Enrollment_Codes.jsx +++ b/Data/Server/WebUI/src/Admin/Enrollment_Codes.jsx @@ -65,7 +65,9 @@ const formatDateTime = (value) => { const determineStatus = (record) => { if (!record) return "expired"; - if (record.used_at) return "used"; + const maxUses = Number.isFinite(record?.max_uses) ? record.max_uses : 1; + const useCount = Number.isFinite(record?.use_count) ? record.use_count : 0; + if (useCount >= Math.max(1, maxUses || 1)) return "used"; if (!record.expires_at) return "expired"; const expires = new Date(record.expires_at); if (Number.isNaN(expires.getTime())) return "expired"; @@ -80,6 +82,7 @@ function EnrollmentCodes() { const [statusFilter, setStatusFilter] = useState("all"); const [ttlHours, setTtlHours] = useState(6); const [generating, setGenerating] = useState(false); + const [maxUses, setMaxUses] = useState(2); const filteredCodes = useMemo(() => { if (statusFilter === "all") return codes; @@ -119,7 +122,7 @@ function EnrollmentCodes() { method: "POST", credentials: "include", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ ttl_hours: ttlHours }), + body: JSON.stringify({ ttl_hours: ttlHours, max_uses: maxUses }), }); if (!resp.ok) { const body = await resp.json().catch(() => ({})); @@ -133,7 +136,7 @@ function EnrollmentCodes() { } finally { setGenerating(false); } - }, [fetchCodes, ttlHours]); + }, [fetchCodes, ttlHours, maxUses]); const handleDelete = useCallback( async (id) => { @@ -216,7 +219,7 @@ function EnrollmentCodes() { labelId="ttl-select-label" label="Duration" value={ttlHours} - onChange={(event) => setTtlHours(event.target.value)} + onChange={(event) => setTtlHours(Number(event.target.value))} > {TTL_PRESETS.map((preset) => ( @@ -226,6 +229,22 @@ function EnrollmentCodes() { + + Allowed Uses + + +