From 7b0e2f48e12abe564e202da11d5b89bc3b4f2c3d Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Fri, 17 Oct 2025 23:30:40 -0600 Subject: [PATCH] Handle device rekey during enrollment --- Data/Server/Modules/enrollment/routes.py | 51 ++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/Data/Server/Modules/enrollment/routes.py b/Data/Server/Modules/enrollment/routes.py index 7a883af..3d3570e 100644 --- a/Data/Server/Modules/enrollment/routes.py +++ b/Data/Server/Modules/enrollment/routes.py @@ -128,19 +128,63 @@ def register( def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]: cur.execute( - "SELECT guid, hostname, token_version, status, ssl_key_fingerprint FROM devices WHERE guid = ?", + """ + SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at + FROM devices + WHERE guid = ? + """, (guid,), ) row = cur.fetchone() if row: - keys = ["guid", "hostname", "token_version", "status", "ssl_key_fingerprint"] + keys = [ + "guid", + "hostname", + "token_version", + "status", + "ssl_key_fingerprint", + "key_added_at", + ] record = dict(zip(keys, row)) - if not record.get("ssl_key_fingerprint"): + stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower() + new_fp = (fingerprint or "").strip().lower() + if not stored_fp and new_fp: cur.execute( "UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?", (fingerprint, _iso(_now()), guid), ) record["ssl_key_fingerprint"] = fingerprint + elif new_fp and stored_fp != new_fp: + now_iso = _iso(_now()) + try: + current_version = int(record.get("token_version") or 1) + except Exception: + current_version = 1 + new_version = max(current_version + 1, 1) + cur.execute( + """ + UPDATE devices + SET ssl_key_fingerprint = ?, + key_added_at = ?, + token_version = ?, + status = 'active' + WHERE guid = ? + """, + (fingerprint, now_iso, new_version, guid), + ) + cur.execute( + """ + UPDATE refresh_tokens + SET revoked_at = ? + WHERE guid = ? + AND revoked_at IS NULL + """, + (now_iso, guid), + ) + record["ssl_key_fingerprint"] = fingerprint + record["token_version"] = new_version + record["status"] = "active" + record["key_added_at"] = now_iso return record resolved_hostname = _normalize_host(hostname, guid, cur) @@ -169,6 +213,7 @@ def register( "token_version": 1, "status": "active", "ssl_key_fingerprint": fingerprint, + "key_added_at": key_added_at, } def _hash_refresh_token(token: str) -> str: