Simplified & Reworked Enrollment Code System to be Site-Specific

This commit is contained in:
2025-11-16 17:40:24 -07:00
parent 65bee703e9
commit b2120d7385
13 changed files with 649 additions and 492 deletions

View File

@@ -20,6 +20,7 @@
# - GET /api/sites/device_map (Token Authenticated) - Provides hostname to site assignment mapping data.
# - POST /api/sites/assign (Token Authenticated (Admin)) - Assigns a set of devices to a given site.
# - POST /api/sites/rename (Token Authenticated (Admin)) - Renames an existing site record.
# - POST /api/sites/rotate_code (Token Authenticated (Admin)) - Rotates the static enrollment code for a site.
# - GET /api/repo/current_hash (Device or Token Authenticated) - Fetches the current agent repository hash (with caching).
# - GET/POST /api/agent/hash (Device Authenticated) - Retrieves or updates an agent hash record bound to the authenticated device.
# - GET /api/agent/hash_list (Token Authenticated (Admin + Loopback)) - Returns stored agent hash metadata for localhost diagnostics.
@@ -31,10 +32,11 @@ from __future__ import annotations
import json
import logging
import os
import secrets
import sqlite3
import time
import uuid
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -118,6 +120,11 @@ def _is_internal_request(remote_addr: Optional[str]) -> bool:
return False
def _generate_install_code() -> str:
raw = secrets.token_hex(16).upper()
return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4))
def _row_to_site(row: Tuple[Any, ...]) -> Dict[str, Any]:
return {
"id": row[0],
@@ -125,6 +132,11 @@ def _row_to_site(row: Tuple[Any, ...]) -> Dict[str, Any]:
"description": row[2] or "",
"created_at": row[3] or 0,
"device_count": row[4] or 0,
"enrollment_code_id": row[5],
"enrollment_code": row[6] or "",
"enrollment_code_expires_at": row[7] or "",
"enrollment_code_last_used_at": row[8] or "",
"enrollment_code_use_count": row[9] or 0,
}
@@ -1103,26 +1115,83 @@ class DeviceManagementService:
# Site management helpers
# ------------------------------------------------------------------
def _site_select_sql(self) -> str:
return """
SELECT s.id,
s.name,
s.description,
s.created_at,
COALESCE(ds.cnt, 0) AS device_count,
s.enrollment_code_id,
ic.code,
ic.expires_at,
ic.last_used_at,
ic.use_count
FROM sites AS s
LEFT JOIN (
SELECT site_id, COUNT(*) AS cnt
FROM device_sites
GROUP BY site_id
) AS ds ON ds.site_id = s.id
LEFT JOIN enrollment_install_codes AS ic
ON ic.id = s.enrollment_code_id
"""
def _fetch_site_row(self, cur: sqlite3.Cursor, site_id: int) -> Optional[Tuple[Any, ...]]:
cur.execute(self._site_select_sql() + " WHERE s.id = ?", (site_id,))
return cur.fetchone()
def _issue_site_enrollment_code(self, cur: sqlite3.Cursor, site_id: int, *, creator: str) -> Dict[str, Any]:
now = datetime.now(tz=timezone.utc)
issued_iso = now.isoformat()
expires_iso = (now + timedelta(days=3650)).isoformat()
code_id = str(uuid.uuid4())
code_value = _generate_install_code()
creator_value = creator or "system"
cur.execute(
"""
INSERT INTO enrollment_install_codes (
id, code, expires_at, created_by_user_id, used_at, used_by_guid,
max_uses, use_count, last_used_at, site_id
)
VALUES (?, ?, ?, ?, NULL, NULL, 0, 0, NULL, ?)
""",
(code_id, code_value, expires_iso, creator_value, site_id),
)
cur.execute(
"""
INSERT OR REPLACE INTO enrollment_install_codes_persistent (
id,
code,
created_at,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
last_known_use_count,
last_used_at,
is_active,
archived_at,
consumed_at,
site_id
)
VALUES (?, ?, ?, ?, ?, NULL, NULL, 0, 0, NULL, 1, NULL, NULL, ?)
""",
(code_id, code_value, issued_iso, expires_iso, creator_value, site_id),
)
return {
"id": code_id,
"code": code_value,
"created_at": issued_iso,
"expires_at": expires_iso,
}
def list_sites(self) -> Tuple[Dict[str, Any], int]:
conn = self._db_conn()
try:
cur = conn.cursor()
cur.execute(
"""
SELECT s.id,
s.name,
s.description,
s.created_at,
COALESCE(ds.cnt, 0) AS device_count
FROM sites AS s
LEFT JOIN (
SELECT site_id, COUNT(*) AS cnt
FROM device_sites
GROUP BY site_id
) AS ds ON ds.site_id = s.id
ORDER BY LOWER(s.name) ASC
"""
)
cur.execute(self._site_select_sql() + " ORDER BY LOWER(s.name) ASC")
rows = cur.fetchall()
sites = [_row_to_site(row) for row in rows]
return {"sites": sites}, 200
@@ -1136,6 +1205,8 @@ class DeviceManagementService:
if not name:
return {"error": "name is required"}, 400
now = int(time.time())
user = self._current_user() or {}
creator = user.get("username") or "system"
conn = self._db_conn()
try:
cur = conn.cursor()
@@ -1144,14 +1215,16 @@ class DeviceManagementService:
(name, description, now),
)
site_id = cur.lastrowid
code_info = self._issue_site_enrollment_code(cur, site_id, creator=creator)
cur.execute("UPDATE sites SET enrollment_code_id = ? WHERE id = ?", (code_info["id"], site_id))
conn.commit()
cur.execute(
"SELECT id, name, description, created_at, 0 FROM sites WHERE id = ?",
(site_id,),
)
row = cur.fetchone()
row = self._fetch_site_row(cur, site_id)
if not row:
return {"error": "creation_failed"}, 500
self.service_log(
"server",
f"site created id={site_id} code_id={code_info['id']} by={creator}",
)
return _row_to_site(row), 201
except sqlite3.IntegrityError:
conn.rollback()
@@ -1171,13 +1244,19 @@ class DeviceManagementService:
try:
norm_ids.append(int(value))
except Exception:
continue
return {"error": "invalid id"}, 400
if not norm_ids:
return {"status": "ok", "deleted": 0}, 200
conn = self._db_conn()
try:
cur = conn.cursor()
placeholders = ",".join("?" * len(norm_ids))
cur.execute(
f"SELECT id FROM enrollment_install_codes WHERE site_id IN ({placeholders})",
tuple(norm_ids),
)
code_ids = [row[0] for row in cur.fetchall() if row and row[0]]
now_iso = datetime.now(tz=timezone.utc).isoformat()
cur.execute(
f"DELETE FROM device_sites WHERE site_id IN ({placeholders})",
tuple(norm_ids),
@@ -1187,6 +1266,30 @@ class DeviceManagementService:
tuple(norm_ids),
)
deleted = cur.rowcount
cur.execute(
f"""
UPDATE enrollment_install_codes_persistent
SET is_active = 0,
archived_at = COALESCE(archived_at, ?)
WHERE site_id IN ({placeholders})
""",
(now_iso, *norm_ids),
)
if code_ids:
code_placeholders = ",".join("?" * len(code_ids))
cur.execute(
f"DELETE FROM enrollment_install_codes WHERE id IN ({code_placeholders})",
tuple(code_ids),
)
cur.execute(
f"""
UPDATE enrollment_install_codes_persistent
SET is_active = 0,
archived_at = COALESCE(archived_at, ?)
WHERE id IN ({code_placeholders})
""",
(now_iso, *code_ids),
)
conn.commit()
return {"status": "ok", "deleted": deleted}, 200
except Exception as exc:
@@ -1287,20 +1390,7 @@ class DeviceManagementService:
return {"error": "site not found"}, 404
conn.commit()
cur.execute(
"""
SELECT s.id,
s.name,
s.description,
s.created_at,
COALESCE(ds.cnt, 0) AS device_count
FROM sites AS s
LEFT JOIN (
SELECT site_id, COUNT(*) AS cnt
FROM device_sites
GROUP BY site_id
) ds ON ds.site_id = s.id
WHERE s.id = ?
""",
self._site_select_sql() + " WHERE s.id = ?",
(site_id_int,),
)
row = cur.fetchone()
@@ -1317,6 +1407,56 @@ class DeviceManagementService:
finally:
conn.close()
def rotate_site_enrollment_code(self, site_id: Any) -> Tuple[Dict[str, Any], int]:
try:
site_id_int = int(site_id)
except Exception:
return {"error": "invalid site_id"}, 400
user = self._current_user() or {}
creator = user.get("username") or "system"
now_iso = datetime.now(tz=timezone.utc).isoformat()
conn = self._db_conn()
try:
cur = conn.cursor()
cur.execute("SELECT enrollment_code_id FROM sites WHERE id = ?", (site_id_int,))
row = cur.fetchone()
if not row:
return {"error": "site not found"}, 404
existing_code_id = row[0]
if existing_code_id:
cur.execute("DELETE FROM enrollment_install_codes WHERE id = ?", (existing_code_id,))
cur.execute(
"""
UPDATE enrollment_install_codes_persistent
SET is_active = 0,
archived_at = COALESCE(archived_at, ?)
WHERE id = ?
""",
(now_iso, existing_code_id),
)
code_info = self._issue_site_enrollment_code(cur, site_id_int, creator=creator)
cur.execute(
"UPDATE sites SET enrollment_code_id = ? WHERE id = ?",
(code_info["id"], site_id_int),
)
conn.commit()
site_row = self._fetch_site_row(cur, site_id_int)
if not site_row:
return {"error": "site not found"}, 404
self.service_log(
"server",
f"site enrollment code rotated site_id={site_id_int} code_id={code_info['id']} by={creator}",
)
return _row_to_site(site_row), 200
except Exception as exc:
conn.rollback()
self.logger.debug("Failed to rotate site enrollment code", exc_info=True)
return {"error": str(exc)}, 500
finally:
conn.close()
def repo_current_hash(self) -> Tuple[Dict[str, Any], int]:
refresh_flag = (request.args.get("refresh") or "").strip().lower()
force_refresh = refresh_flag in {"1", "true", "yes", "force", "refresh"}
@@ -1786,6 +1926,16 @@ def register_management(app, adapters: "EngineServiceAdapters") -> None:
payload, status = service.rename_site(data.get("id"), (data.get("new_name") or "").strip())
return jsonify(payload), status
@blueprint.route("/api/sites/rotate_code", methods=["POST"])
def _sites_rotate_code():
requirement = service._require_admin()
if requirement:
payload, status = requirement
return jsonify(payload), status
data = request.get_json(silent=True) or {}
payload, status = service.rotate_site_enrollment_code(data.get("site_id"))
return jsonify(payload), status
@blueprint.route("/api/repo/current_hash", methods=["GET"])
def _repo_current_hash():
requirement = service._require_device_or_login()