From 775d3655128e98aaefeba3785ba2cf3dd857dd4a Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Sat, 18 Oct 2025 02:52:15 -0600 Subject: [PATCH 01/14] Handle missing devices and relax agent auth retries --- Data/Agent/agent.py | 76 +++++++++++++----- Data/Server/Modules/auth/device_auth.py | 102 ++++++++++++++++++++++++ Data/Server/Modules/tokens/routes.py | 15 +++- 3 files changed, 173 insertions(+), 20 deletions(-) diff --git a/Data/Agent/agent.py b/Data/Agent/agent.py index 4bd933f..b54c5f8 100644 --- a/Data/Agent/agent.py +++ b/Data/Agent/agent.py @@ -23,7 +23,7 @@ import ssl import threading import contextlib import errno -from typing import Any, Dict, Optional, List, Callable +from typing import Any, Dict, Optional, List, Callable, Tuple import requests try: @@ -1007,10 +1007,22 @@ class AgentHttpClient: timeout=20, ) if resp.status_code in (401, 403): - _log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log") - self._clear_tokens_locked() - self._perform_enrollment_locked() - return + error_code, snippet = self._error_details(resp) + if resp.status_code == 401 and self._should_retry_auth(resp.status_code, error_code): + _log_agent( + "Refresh token rejected; attempting re-enrollment" + f" error={error_code or ''}", + fname="agent.error.log", + ) + self._clear_tokens_locked() + self._perform_enrollment_locked() + return + _log_agent( + "Refresh token request forbidden " + f"status={resp.status_code} error={error_code or ''}" + f" body_snippet={snippet}", + fname="agent.error.log", + ) resp.raise_for_status() data = resp.json() access_token = data.get("access_token") @@ -1036,6 +1048,33 @@ class AgentHttpClient: self.guid = self.key_store.load_guid() self.session.headers.pop("Authorization", None) + def _error_details(self, response: requests.Response) -> Tuple[Optional[str], str]: + error_code: Optional[str] = None + snippet = "" + try: + snippet = response.text[:256] + except Exception: + snippet = "" + try: + data = response.json() + except Exception: + data = None + if isinstance(data, dict): + for key in ("error", "code", "status"): + value = data.get(key) + if isinstance(value, str) and value.strip(): + error_code = value.strip() + break + return error_code, snippet + + def _should_retry_auth(self, status_code: int, error_code: Optional[str]) -> bool: + if status_code == 401: + return True + retryable_forbidden = {"fingerprint_mismatch"} + if status_code == 403 and error_code in retryable_forbidden: + return True + return False + def _resolve_installer_code(self) -> str: if INSTALLER_CODE_OVERRIDE: return INSTALLER_CODE_OVERRIDE @@ -1068,20 +1107,19 @@ class AgentHttpClient: headers = self.auth_headers() response = self.session.post(url, json=payload, headers=headers, timeout=30) if response.status_code in (401, 403) and require_auth: - snippet = "" - try: - snippet = response.text[:256] - except Exception: - snippet = "" - _log_agent( - "Authenticated request rejected " - f"path={path} status={response.status_code} body_snippet={snippet}", - fname="agent.error.log", - ) - self.clear_tokens() - self.ensure_authenticated() - headers = self.auth_headers() - response = self.session.post(url, json=payload, headers=headers, timeout=30) + error_code, snippet = self._error_details(response) + if self._should_retry_auth(response.status_code, error_code): + self.clear_tokens() + self.ensure_authenticated() + headers = self.auth_headers() + response = self.session.post(url, json=payload, headers=headers, timeout=30) + else: + _log_agent( + "Authenticated request rejected " + f"path={path} status={response.status_code} error={error_code or ''}" + f" body_snippet={snippet}", + fname="agent.error.log", + ) response.raise_for_status() if response.headers.get("Content-Type", "").lower().startswith("application/json"): return response.json() diff --git a/Data/Server/Modules/auth/device_auth.py b/Data/Server/Modules/auth/device_auth.py index 4d1c716..177d4b5 100644 --- a/Data/Server/Modules/auth/device_auth.py +++ b/Data/Server/Modules/auth/device_auth.py @@ -1,7 +1,10 @@ from __future__ import annotations import functools +import sqlite3 +import time from dataclasses import dataclass +from datetime import datetime, timezone from typing import Any, Callable, Dict, Optional import jwt @@ -98,6 +101,9 @@ class DeviceAuthManager: (guid,), ) row = cur.fetchone() + + if not row: + row = self._recover_device_record(conn, guid, fingerprint, token_version) finally: conn.close() @@ -147,6 +153,102 @@ class DeviceAuthManager: ) return ctx + def _recover_device_record( + self, + conn: sqlite3.Connection, + guid: str, + fingerprint: str, + token_version: int, + ) -> Optional[tuple]: + """Attempt to recreate a missing device row for an authenticated token.""" + + guid = (guid or "").strip() + fingerprint = (fingerprint or "").strip() + if not guid or not fingerprint: + return None + + cur = conn.cursor() + now_ts = int(time.time()) + try: + now_iso = datetime.now(tz=timezone.utc).isoformat() + except Exception: + now_iso = datetime.utcnow().isoformat() # pragma: no cover + + base_hostname = f"RECOVERED-{guid[:12].upper()}" if guid else "RECOVERED" + + for attempt in range(6): + hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}" + try: + cur.execute( + """ + INSERT INTO devices ( + guid, + hostname, + created_at, + last_seen, + ssl_key_fingerprint, + token_version, + status, + key_added_at + ) + VALUES (?, ?, ?, ?, ?, ?, 'active', ?) + """, + ( + guid, + hostname, + now_ts, + now_ts, + fingerprint, + max(token_version or 1, 1), + now_iso, + ), + ) + except sqlite3.IntegrityError as exc: + # Hostname collision – try again with a suffixed placeholder. + message = str(exc).lower() + if "hostname" in message and "unique" in message: + continue + self._log( + "server", + f"device auth failed to recover guid={guid} due to integrity error: {exc}", + ) + conn.rollback() + return None + except Exception as exc: # pragma: no cover - defensive logging + self._log( + "server", + f"device auth unexpected error recovering guid={guid}: {exc}", + ) + conn.rollback() + return None + else: + conn.commit() + break + else: + # Exhausted attempts because of hostname collisions. + self._log( + "server", + f"device auth could not recover guid={guid}; hostname collisions persisted", + ) + conn.rollback() + return None + + cur.execute( + """ + SELECT guid, ssl_key_fingerprint, token_version, status + FROM devices + WHERE guid = ? + """, + (guid,), + ) + row = cur.fetchone() + if not row: + self._log( + "server", + f"device auth recovery for guid={guid} committed but row still missing", + ) + return row + def require_device_auth(manager: DeviceAuthManager): def decorator(func): diff --git a/Data/Server/Modules/tokens/routes.py b/Data/Server/Modules/tokens/routes.py index 1e69d9d..8005836 100644 --- a/Data/Server/Modules/tokens/routes.py +++ b/Data/Server/Modules/tokens/routes.py @@ -93,7 +93,20 @@ def register( except DPoPVerificationError: return jsonify({"error": "dpop_invalid"}), 400 elif stored_jkt: - return jsonify({"error": "dpop_required"}), 400 + # The agent does not yet emit DPoP proofs; allow recovery by clearing + # the stored binding so refreshes can succeed. This preserves + # backward compatibility while the client gains full DPoP support. + try: + app.logger.warning( + "Clearing stored DPoP binding for guid=%s due to missing proof", + guid, + ) + except Exception: + pass + cur.execute( + "UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?", + (record_id,), + ) new_access_token = jwt_service.issue_access_token( guid, From 8177cc0892ef3946c2c46c6e7b11e9600f01877f Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Sat, 18 Oct 2025 03:19:26 -0600 Subject: [PATCH 02/14] Support multi-use installer codes and reuse --- Data/Server/Modules/admin/routes.py | 52 ++++- Data/Server/Modules/db_migrations.py | 28 ++- Data/Server/Modules/enrollment/routes.py | 115 ++++++++-- Data/Server/Modules/jobs/prune.py | 7 +- .../WebUI/src/Admin/Enrollment_Codes.jsx | 37 ++- tests/test_enrollment_install_codes.py | 213 ++++++++++++++++++ 6 files changed, 416 insertions(+), 36 deletions(-) create mode 100644 tests/test_enrollment_install_codes.py diff --git a/Data/Server/Modules/admin/routes.py b/Data/Server/Modules/admin/routes.py index 97625d1..72c9d81 100644 --- a/Data/Server/Modules/admin/routes.py +++ b/Data/Server/Modules/admin/routes.py @@ -54,18 +54,27 @@ def register( try: cur = conn.cursor() sql = """ - SELECT id, code, expires_at, created_by_user_id, used_at, used_by_guid + SELECT id, + code, + expires_at, + created_by_user_id, + used_at, + used_by_guid, + max_uses, + use_count, + last_used_at FROM enrollment_install_codes """ params: List[str] = [] + now_iso = _iso(_now()) if status_filter == "active": - sql += " WHERE used_at IS NULL AND expires_at > ?" - params.append(_iso(_now())) + sql += " WHERE use_count < max_uses AND expires_at > ?" + params.append(now_iso) elif status_filter == "expired": - sql += " WHERE used_at IS NULL AND expires_at <= ?" - params.append(_iso(_now())) + sql += " WHERE use_count < max_uses AND expires_at <= ?" + params.append(now_iso) elif status_filter == "used": - sql += " WHERE used_at IS NOT NULL" + sql += " WHERE use_count >= max_uses" sql += " ORDER BY expires_at ASC" cur.execute(sql, params) rows = cur.fetchall() @@ -82,6 +91,9 @@ def register( "created_by_user_id": row[3], "used_at": row[4], "used_by_guid": row[5], + "max_uses": row[6], + "use_count": row[7], + "last_used_at": row[8], } ) return jsonify({"codes": records}) @@ -93,6 +105,18 @@ def register( if ttl_hours not in VALID_TTL_HOURS: return jsonify({"error": "invalid_ttl"}), 400 + max_uses_value = payload.get("max_uses") + if max_uses_value is None: + max_uses_value = payload.get("allowed_uses") + try: + max_uses = int(max_uses_value) + except Exception: + max_uses = 2 + if max_uses < 1: + max_uses = 1 + if max_uses > 10: + max_uses = 10 + user = current_user() or {} username = user.get("username") or "" @@ -106,22 +130,28 @@ def register( cur.execute( """ INSERT INTO enrollment_install_codes ( - id, code, expires_at, created_by_user_id + id, code, expires_at, created_by_user_id, max_uses, use_count ) - VALUES (?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, 0) """, - (record_id, code_value, _iso(expires_at), created_by), + (record_id, code_value, _iso(expires_at), created_by, max_uses), ) conn.commit() finally: conn.close() - log("server", f"installer code created id={record_id} by={username} ttl={ttl_hours}h") + log( + "server", + f"installer code created id={record_id} by={username} ttl={ttl_hours}h max_uses={max_uses}", + ) return jsonify( { "id": record_id, "code": code_value, "expires_at": _iso(expires_at), + "max_uses": max_uses, + "use_count": 0, + "last_used_at": None, } ) @@ -131,7 +161,7 @@ def register( try: cur = conn.cursor() cur.execute( - "DELETE FROM enrollment_install_codes WHERE id = ? AND used_at IS NULL", + "DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0", (code_id,), ) deleted = cur.rowcount diff --git a/Data/Server/Modules/db_migrations.py b/Data/Server/Modules/db_migrations.py index 5edd6c9..b6d1aa1 100644 --- a/Data/Server/Modules/db_migrations.py +++ b/Data/Server/Modules/db_migrations.py @@ -152,7 +152,10 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None: expires_at TEXT NOT NULL, created_by_user_id TEXT, used_at TEXT, - used_by_guid TEXT + used_by_guid TEXT, + max_uses INTEGER NOT NULL DEFAULT 1, + use_count INTEGER NOT NULL DEFAULT 0, + last_used_at TEXT ) """ ) @@ -163,6 +166,29 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None: """ ) + columns = {row[1] for row in _table_info(cur, "enrollment_install_codes")} + if "max_uses" not in columns: + cur.execute( + """ + ALTER TABLE enrollment_install_codes + ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 1 + """ + ) + if "use_count" not in columns: + cur.execute( + """ + ALTER TABLE enrollment_install_codes + ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0 + """ + ) + if "last_used_at" not in columns: + cur.execute( + """ + ALTER TABLE enrollment_install_codes + ADD COLUMN last_used_at TEXT + """ + ) + def _ensure_device_approval_table(conn: sqlite3.Connection) -> None: cur = conn.cursor() diff --git a/Data/Server/Modules/enrollment/routes.py b/Data/Server/Modules/enrollment/routes.py index c408bcd..3392cab 100644 --- a/Data/Server/Modules/enrollment/routes.py +++ b/Data/Server/Modules/enrollment/routes.py @@ -6,7 +6,7 @@ import sqlite3 import uuid from datetime import datetime, timezone, timedelta import time -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple from flask import Blueprint, jsonify, request @@ -66,31 +66,79 @@ def register( def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]: cur.execute( - "SELECT id, code, expires_at, used_at FROM enrollment_install_codes WHERE code = ?", + """ + SELECT id, + code, + expires_at, + used_at, + used_by_guid, + max_uses, + use_count, + last_used_at + FROM enrollment_install_codes + WHERE code = ? + """, (code_value,), ) row = cur.fetchone() if not row: return None - keys = ["id", "code", "expires_at", "used_at"] + keys = [ + "id", + "code", + "expires_at", + "used_at", + "used_by_guid", + "max_uses", + "use_count", + "last_used_at", + ] record = dict(zip(keys, row)) return record - def _install_code_valid(record: Dict[str, Any]) -> bool: + def _install_code_valid( + record: Dict[str, Any], fingerprint: str, cur: sqlite3.Cursor + ) -> Tuple[bool, Optional[str]]: if not record: - return False + return False, None expires_at = record.get("expires_at") if not isinstance(expires_at, str): - return False + return False, None try: expiry = datetime.fromisoformat(expires_at) except Exception: - return False + return False, None if expiry <= _now(): - return False - if record.get("used_at"): - return False - return True + return False, None + try: + max_uses = int(record.get("max_uses") or 1) + except Exception: + max_uses = 1 + if max_uses < 1: + max_uses = 1 + try: + use_count = int(record.get("use_count") or 0) + except Exception: + use_count = 0 + if use_count < max_uses: + return True, None + + guid = str(record.get("used_by_guid") or "").strip() + if not guid: + return False, None + cur.execute( + "SELECT ssl_key_fingerprint FROM devices WHERE guid = ?", + (guid,), + ) + row = cur.fetchone() + if not row: + return False, None + stored_fp = (row[0] or "").strip().lower() + if not stored_fp: + return False, None + if stored_fp == (fingerprint or "").strip().lower(): + return True, guid + return False, None def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str: base = (hostname or "").strip() or guid @@ -305,7 +353,13 @@ def register( try: cur = conn.cursor() install_code = _load_install_code(cur, enrollment_code) - if not _install_code_valid(install_code): + valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur) + if not valid_code: + log( + "server", + "enrollment request invalid_code " + f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}", + ) return jsonify({"error": "invalid_enrollment_code"}), 400 approval_reference: str @@ -331,6 +385,7 @@ def register( """ UPDATE device_approvals SET hostname_claimed = ?, + guid = ?, enrollment_code_id = ?, client_nonce = ?, server_nonce = ?, @@ -340,6 +395,7 @@ def register( """, ( hostname, + reuse_guid, install_code["id"], client_nonce_b64, server_nonce_b64, @@ -359,11 +415,12 @@ def register( status, client_nonce, server_nonce, agent_pubkey_der, created_at, updated_at ) - VALUES (?, ?, NULL, ?, ?, ?, 'pending', ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?) """, ( record_id, approval_reference, + reuse_guid, hostname, fingerprint, install_code["id"], @@ -537,14 +594,40 @@ def register( # Mark install code used if enrollment_code_id: + cur.execute( + "SELECT use_count, max_uses FROM enrollment_install_codes WHERE id = ?", + (enrollment_code_id,), + ) + usage_row = cur.fetchone() + try: + prior_count = int(usage_row[0]) if usage_row else 0 + except Exception: + prior_count = 0 + try: + allowed_uses = int(usage_row[1]) if usage_row else 1 + except Exception: + allowed_uses = 1 + if allowed_uses < 1: + allowed_uses = 1 + new_count = prior_count + 1 + consumed = new_count >= allowed_uses cur.execute( """ UPDATE enrollment_install_codes - SET used_at = ?, used_by_guid = ? + SET use_count = ?, + used_by_guid = ?, + last_used_at = ?, + used_at = CASE WHEN ? THEN ? ELSE used_at END WHERE id = ? - AND used_at IS NULL """, - (now_iso, effective_guid, enrollment_code_id), + ( + new_count, + effective_guid, + now_iso, + 1 if consumed else 0, + now_iso, + enrollment_code_id, + ), ) # Update approval record with final state diff --git a/Data/Server/Modules/jobs/prune.py b/Data/Server/Modules/jobs/prune.py index 9c3ac56..2e6e21e 100644 --- a/Data/Server/Modules/jobs/prune.py +++ b/Data/Server/Modules/jobs/prune.py @@ -34,7 +34,7 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None cur.execute( """ DELETE FROM enrollment_install_codes - WHERE used_at IS NULL + WHERE use_count = 0 AND expires_at < ? """, (now_iso,), @@ -52,7 +52,10 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None SELECT 1 FROM enrollment_install_codes c WHERE c.id = device_approvals.enrollment_code_id - AND c.expires_at < ? + AND ( + c.expires_at < ? + OR c.use_count >= c.max_uses + ) ) OR created_at < ? ) diff --git a/Data/Server/WebUI/src/Admin/Enrollment_Codes.jsx b/Data/Server/WebUI/src/Admin/Enrollment_Codes.jsx index 236fbd7..db3e387 100644 --- a/Data/Server/WebUI/src/Admin/Enrollment_Codes.jsx +++ b/Data/Server/WebUI/src/Admin/Enrollment_Codes.jsx @@ -65,7 +65,9 @@ const formatDateTime = (value) => { const determineStatus = (record) => { if (!record) return "expired"; - if (record.used_at) return "used"; + const maxUses = Number.isFinite(record?.max_uses) ? record.max_uses : 1; + const useCount = Number.isFinite(record?.use_count) ? record.use_count : 0; + if (useCount >= Math.max(1, maxUses || 1)) return "used"; if (!record.expires_at) return "expired"; const expires = new Date(record.expires_at); if (Number.isNaN(expires.getTime())) return "expired"; @@ -80,6 +82,7 @@ function EnrollmentCodes() { const [statusFilter, setStatusFilter] = useState("all"); const [ttlHours, setTtlHours] = useState(6); const [generating, setGenerating] = useState(false); + const [maxUses, setMaxUses] = useState(2); const filteredCodes = useMemo(() => { if (statusFilter === "all") return codes; @@ -119,7 +122,7 @@ function EnrollmentCodes() { method: "POST", credentials: "include", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ ttl_hours: ttlHours }), + body: JSON.stringify({ ttl_hours: ttlHours, max_uses: maxUses }), }); if (!resp.ok) { const body = await resp.json().catch(() => ({})); @@ -133,7 +136,7 @@ function EnrollmentCodes() { } finally { setGenerating(false); } - }, [fetchCodes, ttlHours]); + }, [fetchCodes, ttlHours, maxUses]); const handleDelete = useCallback( async (id) => { @@ -216,7 +219,7 @@ function EnrollmentCodes() { labelId="ttl-select-label" label="Duration" value={ttlHours} - onChange={(event) => setTtlHours(event.target.value)} + onChange={(event) => setTtlHours(Number(event.target.value))} > {TTL_PRESETS.map((preset) => ( @@ -226,6 +229,22 @@ function EnrollmentCodes() { + + Allowed Uses + + +