Simplified & Reworked Enrollment Code System to be Site-Specific

This commit is contained in:
2025-11-16 17:40:24 -07:00
parent 65bee703e9
commit b2120d7385
13 changed files with 649 additions and 492 deletions

View File

@@ -94,7 +94,8 @@ CREATE TABLE IF NOT EXISTS enrollment_install_codes (
used_by_guid TEXT,
max_uses INTEGER,
use_count INTEGER,
last_used_at TEXT
last_used_at TEXT,
site_id INTEGER
);
CREATE TABLE IF NOT EXISTS enrollment_install_codes_persistent (
id TEXT PRIMARY KEY,
@@ -109,7 +110,8 @@ CREATE TABLE IF NOT EXISTS enrollment_install_codes_persistent (
last_used_at TEXT,
is_active INTEGER NOT NULL DEFAULT 1,
archived_at TEXT,
consumed_at TEXT
consumed_at TEXT,
site_id INTEGER
);
CREATE TABLE IF NOT EXISTS device_approvals (
id TEXT PRIMARY KEY,
@@ -118,6 +120,7 @@ CREATE TABLE IF NOT EXISTS device_approvals (
hostname_claimed TEXT,
ssl_key_fingerprint_claimed TEXT,
enrollment_code_id TEXT,
site_id INTEGER,
status TEXT,
client_nonce TEXT,
server_nonce TEXT,
@@ -145,7 +148,8 @@ CREATE TABLE IF NOT EXISTS sites (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
description TEXT,
created_at INTEGER
created_at INTEGER,
enrollment_code_id TEXT
);
CREATE TABLE IF NOT EXISTS device_sites (
device_hostname TEXT PRIMARY KEY,
@@ -270,9 +274,51 @@ def engine_harness(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator[
"2025-10-01T00:00:00Z",
),
)
site_code_id = "SITE-CODE-0001"
site_code_value = "SITE-MAIN-CODE"
site_code_created = "2025-01-01T00:00:00Z"
site_code_expires = "2030-01-01T00:00:00Z"
cur.execute(
"INSERT INTO sites (id, name, description, created_at) VALUES (?, ?, ?, ?)",
(1, "Main Lab", "Primary integration site", 1_700_000_000),
"""
INSERT INTO enrollment_install_codes (
id,
code,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at,
site_id
) VALUES (?, ?, ?, ?, NULL, NULL, 0, 0, NULL, ?)
""",
(site_code_id, site_code_value, site_code_expires, "admin", 1),
)
cur.execute(
"""
INSERT INTO enrollment_install_codes_persistent (
id,
code,
created_at,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
last_known_use_count,
last_used_at,
is_active,
archived_at,
consumed_at,
site_id
) VALUES (?, ?, ?, ?, ?, NULL, NULL, 0, 0, NULL, 1, NULL, NULL, ?)
""",
(site_code_id, site_code_value, site_code_created, site_code_expires, "admin", 1),
)
cur.execute(
"INSERT INTO sites (id, name, description, created_at, enrollment_code_id) VALUES (?, ?, ?, ?, ?)",
(1, "Main Lab", "Primary integration site", 1_700_000_000, site_code_id),
)
cur.execute(
"INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)",
@@ -294,6 +340,7 @@ def engine_harness(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator[
hostname_claimed,
ssl_key_fingerprint_claimed,
enrollment_code_id,
site_id,
status,
client_nonce,
server_nonce,
@@ -302,7 +349,7 @@ def engine_harness(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator[
updated_at,
approved_by_user_id
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
"approval-1",
@@ -310,7 +357,8 @@ def engine_harness(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator[
None,
"pending-device",
"aa:bb:cc:dd",
None,
site_code_id,
1,
"pending",
"client-nonce",
"server-nonce",

View File

@@ -57,7 +57,7 @@ def _patch_repo_call(monkeypatch: pytest.MonkeyPatch, calls: dict) -> None:
def test_list_devices(engine_harness: EngineTestHarness) -> None:
client = engine_harness.app.test_client()
client = _client_with_admin_session(engine_harness)
response = client.get("/api/devices")
assert response.status_code == 200
payload = response.get_json()
@@ -70,7 +70,7 @@ def test_list_devices(engine_harness: EngineTestHarness) -> None:
def test_list_agents(engine_harness: EngineTestHarness) -> None:
client = engine_harness.app.test_client()
client = _client_with_admin_session(engine_harness)
response = client.get("/api/agents")
assert response.status_code == 200
payload = response.get_json()
@@ -82,7 +82,7 @@ def test_list_agents(engine_harness: EngineTestHarness) -> None:
def test_device_details(engine_harness: EngineTestHarness) -> None:
client = engine_harness.app.test_client()
client = _client_with_admin_session(engine_harness)
response = client.get("/api/device/details/test-device")
assert response.status_code == 200
payload = response.get_json()
@@ -165,7 +165,7 @@ def test_repo_current_hash_allows_device_token(engine_harness: EngineTestHarness
def test_agent_hash_list_permissions(engine_harness: EngineTestHarness) -> None:
client = engine_harness.app.test_client()
client = _client_with_admin_session(engine_harness)
forbidden = client.get("/api/agent/hash_list", environ_base={"REMOTE_ADDR": "192.0.2.10"})
assert forbidden.status_code == 403
allowed = client.get("/api/agent/hash_list", environ_base={"REMOTE_ADDR": "127.0.0.1"})
@@ -208,21 +208,20 @@ def test_sites_lifecycle(engine_harness: EngineTestHarness) -> None:
assert delete_resp.status_code == 200
def test_admin_enrollment_code_flow(engine_harness: EngineTestHarness) -> None:
def test_site_enrollment_code_rotation(engine_harness: EngineTestHarness) -> None:
client = _client_with_admin_session(engine_harness)
create_resp = client.post(
"/api/admin/enrollment-codes",
json={"ttl_hours": 1, "max_uses": 2},
)
assert create_resp.status_code == 201
code_id = create_resp.get_json()["id"]
sites_resp = client.get("/api/sites")
assert sites_resp.status_code == 200
sites = sites_resp.get_json()["sites"]
assert sites and sites[0]["enrollment_code"]
site_id = sites[0]["id"]
original_code = sites[0]["enrollment_code"]
list_resp = client.get("/api/admin/enrollment-codes")
codes = list_resp.get_json()["codes"]
assert any(code["id"] == code_id for code in codes)
delete_resp = client.delete(f"/api/admin/enrollment-codes/{code_id}")
assert delete_resp.status_code == 200
rotate_resp = client.post("/api/sites/rotate_code", json={"site_id": site_id})
assert rotate_resp.status_code == 200
rotated = rotate_resp.get_json()
assert rotated["id"] == site_id
assert rotated["enrollment_code"] and rotated["enrollment_code"] != original_code
def test_admin_device_approvals(engine_harness: EngineTestHarness) -> None:

View File

@@ -31,19 +31,29 @@ def _iso(dt: datetime) -> str:
return dt.astimezone(timezone.utc).isoformat()
def _seed_install_code(db_path: os.PathLike[str], code: str) -> str:
def _seed_install_code(db_path: os.PathLike[str], code: str, site_id: int = 1) -> str:
record_id = str(uuid.uuid4())
baseline = _now()
issued_at = _iso(baseline)
expires_at = _iso(baseline + timedelta(days=1))
with sqlite3.connect(str(db_path)) as conn:
columns = {row[1] for row in conn.execute("PRAGMA table_info(sites)")}
if "enrollment_code_id" not in columns:
conn.execute("ALTER TABLE sites ADD COLUMN enrollment_code_id TEXT")
conn.execute(
"""
INSERT OR IGNORE INTO sites (id, name, description, created_at, enrollment_code_id)
VALUES (?, ?, ?, ?, ?)
""",
(site_id, f"Test Site {site_id}", "Seeded site", int(baseline.timestamp()), record_id),
)
conn.execute(
"""
INSERT INTO enrollment_install_codes (
id, code, expires_at, used_at, used_by_guid, max_uses, use_count, last_used_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
id, code, expires_at, used_at, used_by_guid, max_uses, use_count, last_used_at, site_id
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(record_id, code, expires_at, None, None, 1, 0, None),
(record_id, code, expires_at, None, None, 1, 0, None, site_id),
)
conn.execute(
"""
@@ -60,8 +70,9 @@ def _seed_install_code(db_path: os.PathLike[str], code: str) -> str:
last_used_at,
is_active,
archived_at,
consumed_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
consumed_at,
site_id
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
record_id,
@@ -77,8 +88,17 @@ def _seed_install_code(db_path: os.PathLike[str], code: str) -> str:
1,
None,
None,
site_id,
),
)
conn.execute(
"""
UPDATE sites
SET enrollment_code_id = ?
WHERE id = ?
""",
(record_id, site_id),
)
conn.commit()
return record_id
@@ -124,7 +144,7 @@ def test_enrollment_request_creates_pending_approval(engine_harness: EngineTestH
cur = conn.cursor()
cur.execute(
"""
SELECT hostname_claimed, ssl_key_fingerprint_claimed, client_nonce, status, enrollment_code_id
SELECT hostname_claimed, ssl_key_fingerprint_claimed, client_nonce, status, enrollment_code_id, site_id
FROM device_approvals
WHERE approval_reference = ?
""",
@@ -133,11 +153,12 @@ def test_enrollment_request_creates_pending_approval(engine_harness: EngineTestH
row = cur.fetchone()
assert row is not None
hostname_claimed, fingerprint, stored_client_nonce, status, stored_code_id = row
hostname_claimed, fingerprint, stored_client_nonce, status, stored_code_id, stored_site_id = row
assert hostname_claimed == "agent-node-01"
assert stored_client_nonce == client_nonce_b64
assert status == "pending"
assert stored_code_id == install_code_id
assert stored_site_id == 1
expected_fingerprint = crypto_keys.fingerprint_from_spki_der(public_der)
assert fingerprint == expected_fingerprint
@@ -206,7 +227,7 @@ def test_enrollment_poll_finalizes_when_approved(engine_harness: EngineTestHarne
with sqlite3.connect(str(harness.db_path)) as conn:
cur = conn.cursor()
cur.execute(
"SELECT guid, status FROM device_approvals WHERE approval_reference = ?",
"SELECT guid, status, site_id FROM device_approvals WHERE approval_reference = ?",
(approval_reference,),
)
approval_row = cur.fetchone()
@@ -215,6 +236,11 @@ def test_enrollment_poll_finalizes_when_approved(engine_harness: EngineTestHarne
(final_guid,),
)
device_row = cur.fetchone()
cur.execute(
"SELECT site_id FROM device_sites WHERE device_hostname = ?",
(device_row[0] if device_row else None,),
)
site_row = cur.fetchone()
cur.execute(
"SELECT COUNT(*) FROM refresh_tokens WHERE guid = ?",
(final_guid,),
@@ -241,15 +267,18 @@ def test_enrollment_poll_finalizes_when_approved(engine_harness: EngineTestHarne
persistent_row = cur.fetchone()
assert approval_row is not None
approval_guid, approval_status = approval_row
approval_guid, approval_status, approval_site_id = approval_row
assert approval_status == "completed"
assert approval_guid == final_guid
assert approval_site_id == 1
assert device_row is not None
hostname, fingerprint, token_version = device_row
assert hostname == "agent-node-02"
assert fingerprint == crypto_keys.fingerprint_from_spki_der(public_der)
assert token_version >= 1
assert site_row is not None
assert site_row[0] == 1
assert refresh_count == 1
assert install_row is not None

View File

@@ -10,8 +10,11 @@
from __future__ import annotations
import logging
import secrets
import sqlite3
import time
import uuid
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional, Sequence
@@ -24,6 +27,15 @@ _DEFAULT_ADMIN_HASH = (
)
def _iso(dt: datetime) -> str:
return dt.astimezone(timezone.utc).isoformat()
def _generate_install_code() -> str:
raw = secrets.token_hex(16).upper()
return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4))
def initialise_engine_database(database_path: str, *, logger: Optional[logging.Logger] = None) -> None:
"""Ensure the Engine database has the required schema and default admin account."""
@@ -46,6 +58,7 @@ def initialise_engine_database(database_path: str, *, logger: Optional[logging.L
_ensure_activity_history(conn, logger=logger)
_ensure_device_list_views(conn, logger=logger)
_ensure_sites(conn, logger=logger)
_ensure_site_enrollment_codes(conn, logger=logger)
_ensure_users_table(conn, logger=logger)
_ensure_default_admin(conn, logger=logger)
_ensure_ansible_recaps(conn, logger=logger)
@@ -92,7 +105,8 @@ def _restore_persisted_enrollment_codes(conn: sqlite3.Connection, *, logger: Opt
used_by_guid,
max_uses,
use_count,
last_used_at
last_used_at,
site_id
)
SELECT
p.id,
@@ -103,7 +117,8 @@ def _restore_persisted_enrollment_codes(conn: sqlite3.Connection, *, logger: Opt
p.used_by_guid,
p.max_uses,
p.last_known_use_count,
p.last_used_at
p.last_used_at,
p.site_id
FROM enrollment_install_codes_persistent AS p
WHERE p.is_active = 1
ON CONFLICT(id) DO UPDATE
@@ -114,7 +129,8 @@ def _restore_persisted_enrollment_codes(conn: sqlite3.Connection, *, logger: Opt
used_by_guid = excluded.used_by_guid,
max_uses = excluded.max_uses,
use_count = excluded.use_count,
last_used_at = excluded.last_used_at
last_used_at = excluded.last_used_at,
site_id = excluded.site_id
"""
)
conn.commit()
@@ -185,7 +201,8 @@ def _ensure_sites(conn: sqlite3.Connection, *, logger: Optional[logging.Logger])
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
description TEXT,
created_at INTEGER
created_at INTEGER,
enrollment_code_id TEXT
)
"""
)
@@ -199,6 +216,10 @@ def _ensure_sites(conn: sqlite3.Connection, *, logger: Optional[logging.Logger])
)
"""
)
cur.execute("PRAGMA table_info(sites)")
columns = {row[1] for row in cur.fetchall()}
if "enrollment_code_id" not in columns:
cur.execute("ALTER TABLE sites ADD COLUMN enrollment_code_id TEXT")
except Exception as exc:
if logger:
logger.error("Failed to ensure site tables: %s", exc, exc_info=True)
@@ -208,6 +229,147 @@ def _ensure_sites(conn: sqlite3.Connection, *, logger: Optional[logging.Logger])
cur.close()
def _ensure_site_enrollment_codes(conn: sqlite3.Connection, *, logger: Optional[logging.Logger]) -> None:
cur = conn.cursor()
try:
cur.execute("SELECT id, enrollment_code_id FROM sites")
sites = cur.fetchall()
if not sites:
return
now = datetime.now(tz=timezone.utc)
long_expiry = _iso(now + timedelta(days=3650))
for site_id, current_code_id in sites:
active_code_id: Optional[str] = None
if current_code_id:
cur.execute(
"SELECT id, site_id FROM enrollment_install_codes WHERE id = ?",
(current_code_id,),
)
existing = cur.fetchone()
if existing:
active_code_id = current_code_id
if existing[1] is None:
cur.execute(
"UPDATE enrollment_install_codes SET site_id = ? WHERE id = ?",
(site_id, current_code_id),
)
cur.execute(
"UPDATE enrollment_install_codes_persistent SET site_id = COALESCE(site_id, ?) WHERE id = ?",
(site_id, current_code_id),
)
if not active_code_id:
cur.execute(
"""
SELECT id, code, created_at, expires_at, max_uses, last_known_use_count, last_used_at, site_id
FROM enrollment_install_codes_persistent
WHERE site_id = ? AND is_active = 1
ORDER BY datetime(created_at) DESC
LIMIT 1
""",
(site_id,),
)
row = cur.fetchone()
if row:
active_code_id = row[0]
if row[7] is None:
cur.execute(
"UPDATE enrollment_install_codes_persistent SET site_id = ? WHERE id = ?",
(site_id, active_code_id),
)
cur.execute(
"""
INSERT OR REPLACE INTO enrollment_install_codes (
id,
code,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at,
site_id
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
row[0],
row[1],
row[3] or long_expiry,
"system",
None,
None,
row[4] or 0,
row[5] or 0,
row[6],
site_id,
),
)
if not active_code_id:
new_id = str(uuid.uuid4())
code_value = _generate_install_code()
issued_at = _iso(now)
cur.execute(
"""
INSERT OR REPLACE INTO enrollment_install_codes (
id,
code,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at,
site_id
)
VALUES (?, ?, ?, 'system', NULL, NULL, 0, 0, NULL, ?)
""",
(new_id, code_value, long_expiry, site_id),
)
cur.execute(
"""
INSERT OR REPLACE INTO enrollment_install_codes_persistent (
id,
code,
created_at,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
last_known_use_count,
last_used_at,
is_active,
archived_at,
consumed_at,
site_id
)
VALUES (?, ?, ?, ?, 'system', NULL, NULL, 0, 0, NULL, 1, NULL, NULL, ?)
""",
(new_id, code_value, issued_at, long_expiry, site_id),
)
active_code_id = new_id
if active_code_id and active_code_id != current_code_id:
cur.execute(
"UPDATE sites SET enrollment_code_id = ? WHERE id = ?",
(active_code_id, site_id),
)
conn.commit()
except Exception as exc:
conn.rollback()
if logger:
logger.error("Failed to ensure site enrollment codes: %s", exc, exc_info=True)
else:
raise
finally:
cur.close()
def _ensure_users_table(conn: sqlite3.Connection, *, logger: Optional[logging.Logger]) -> None:
cur = conn.cursor()
try:

View File

@@ -155,7 +155,8 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
used_by_guid TEXT,
max_uses INTEGER NOT NULL DEFAULT 1,
use_count INTEGER NOT NULL DEFAULT 0,
last_used_at TEXT
last_used_at TEXT,
site_id INTEGER
)
"""
)
@@ -188,6 +189,13 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
ADD COLUMN last_used_at TEXT
"""
)
if "site_id" not in columns:
cur.execute(
"""
ALTER TABLE enrollment_install_codes
ADD COLUMN site_id INTEGER
"""
)
def _ensure_install_code_persistence_table(conn: sqlite3.Connection) -> None:
@@ -207,7 +215,8 @@ def _ensure_install_code_persistence_table(conn: sqlite3.Connection) -> None:
last_used_at TEXT,
is_active INTEGER NOT NULL DEFAULT 1,
archived_at TEXT,
consumed_at TEXT
consumed_at TEXT,
site_id INTEGER
)
"""
)
@@ -274,6 +283,13 @@ def _ensure_install_code_persistence_table(conn: sqlite3.Connection) -> None:
ADD COLUMN last_used_at TEXT
"""
)
if "site_id" not in columns:
cur.execute(
"""
ALTER TABLE enrollment_install_codes_persistent
ADD COLUMN site_id INTEGER
"""
)
def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
@@ -287,6 +303,7 @@ def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
hostname_claimed TEXT NOT NULL,
ssl_key_fingerprint_claimed TEXT NOT NULL,
enrollment_code_id TEXT NOT NULL,
site_id INTEGER,
status TEXT NOT NULL,
client_nonce TEXT NOT NULL,
server_nonce TEXT NOT NULL,
@@ -297,6 +314,16 @@ def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
)
"""
)
cur.execute("PRAGMA table_info(device_approvals)")
columns = {row[1] for row in cur.fetchall()}
if "site_id" not in columns:
cur.execute(
"""
ALTER TABLE device_approvals
ADD COLUMN site_id INTEGER
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_da_status
@@ -309,6 +336,12 @@ def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
ON device_approvals(ssl_key_fingerprint_claimed, status)
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_da_site
ON device_approvals(site_id)
"""
)
def _create_devices_table(cur: sqlite3.Cursor) -> None:

View File

@@ -357,19 +357,23 @@ class AdminDeviceService:
da.hostname_claimed,
da.ssl_key_fingerprint_claimed,
da.enrollment_code_id,
da.site_id,
da.status,
da.client_nonce,
da.server_nonce,
da.created_at,
da.updated_at,
da.approved_by_user_id,
u.username AS approved_by_username
u.username AS approved_by_username,
s.name AS site_name
FROM device_approvals AS da
LEFT JOIN users AS u
ON (
CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT)
OR LOWER(da.approved_by_user_id) = LOWER(u.username)
)
LEFT JOIN sites AS s
ON s.id = da.site_id
"""
status_norm = (status_filter or "").strip().lower()
if status_norm and status_norm != "all":
@@ -409,17 +413,19 @@ class AdminDeviceService:
"hostname_claimed": hostname,
"ssl_key_fingerprint_claimed": fingerprint_claimed,
"enrollment_code_id": row[5],
"status": row[6],
"client_nonce": row[7],
"server_nonce": row[8],
"created_at": row[9],
"updated_at": row[10],
"approved_by_user_id": row[11],
"site_id": row[6],
"status": row[7],
"client_nonce": row[8],
"server_nonce": row[9],
"created_at": row[10],
"updated_at": row[11],
"approved_by_user_id": row[12],
"hostname_conflict": conflict,
"alternate_hostname": alternate,
"conflict_requires_prompt": requires_prompt,
"fingerprint_match": fingerprint_match,
"approved_by_username": row[12],
"approved_by_username": row[13],
"site_name": row[14],
}
)
finally:
@@ -578,4 +584,3 @@ def register_admin_endpoints(app, adapters: "EngineServiceAdapters") -> None:
return jsonify(payload), status
app.register_blueprint(blueprint)

View File

@@ -20,6 +20,7 @@
# - GET /api/sites/device_map (Token Authenticated) - Provides hostname to site assignment mapping data.
# - POST /api/sites/assign (Token Authenticated (Admin)) - Assigns a set of devices to a given site.
# - POST /api/sites/rename (Token Authenticated (Admin)) - Renames an existing site record.
# - POST /api/sites/rotate_code (Token Authenticated (Admin)) - Rotates the static enrollment code for a site.
# - GET /api/repo/current_hash (Device or Token Authenticated) - Fetches the current agent repository hash (with caching).
# - GET/POST /api/agent/hash (Device Authenticated) - Retrieves or updates an agent hash record bound to the authenticated device.
# - GET /api/agent/hash_list (Token Authenticated (Admin + Loopback)) - Returns stored agent hash metadata for localhost diagnostics.
@@ -31,10 +32,11 @@ from __future__ import annotations
import json
import logging
import os
import secrets
import sqlite3
import time
import uuid
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -118,6 +120,11 @@ def _is_internal_request(remote_addr: Optional[str]) -> bool:
return False
def _generate_install_code() -> str:
raw = secrets.token_hex(16).upper()
return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4))
def _row_to_site(row: Tuple[Any, ...]) -> Dict[str, Any]:
return {
"id": row[0],
@@ -125,6 +132,11 @@ def _row_to_site(row: Tuple[Any, ...]) -> Dict[str, Any]:
"description": row[2] or "",
"created_at": row[3] or 0,
"device_count": row[4] or 0,
"enrollment_code_id": row[5],
"enrollment_code": row[6] or "",
"enrollment_code_expires_at": row[7] or "",
"enrollment_code_last_used_at": row[8] or "",
"enrollment_code_use_count": row[9] or 0,
}
@@ -1103,26 +1115,83 @@ class DeviceManagementService:
# Site management helpers
# ------------------------------------------------------------------
def _site_select_sql(self) -> str:
return """
SELECT s.id,
s.name,
s.description,
s.created_at,
COALESCE(ds.cnt, 0) AS device_count,
s.enrollment_code_id,
ic.code,
ic.expires_at,
ic.last_used_at,
ic.use_count
FROM sites AS s
LEFT JOIN (
SELECT site_id, COUNT(*) AS cnt
FROM device_sites
GROUP BY site_id
) AS ds ON ds.site_id = s.id
LEFT JOIN enrollment_install_codes AS ic
ON ic.id = s.enrollment_code_id
"""
def _fetch_site_row(self, cur: sqlite3.Cursor, site_id: int) -> Optional[Tuple[Any, ...]]:
cur.execute(self._site_select_sql() + " WHERE s.id = ?", (site_id,))
return cur.fetchone()
def _issue_site_enrollment_code(self, cur: sqlite3.Cursor, site_id: int, *, creator: str) -> Dict[str, Any]:
now = datetime.now(tz=timezone.utc)
issued_iso = now.isoformat()
expires_iso = (now + timedelta(days=3650)).isoformat()
code_id = str(uuid.uuid4())
code_value = _generate_install_code()
creator_value = creator or "system"
cur.execute(
"""
INSERT INTO enrollment_install_codes (
id, code, expires_at, created_by_user_id, used_at, used_by_guid,
max_uses, use_count, last_used_at, site_id
)
VALUES (?, ?, ?, ?, NULL, NULL, 0, 0, NULL, ?)
""",
(code_id, code_value, expires_iso, creator_value, site_id),
)
cur.execute(
"""
INSERT OR REPLACE INTO enrollment_install_codes_persistent (
id,
code,
created_at,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
last_known_use_count,
last_used_at,
is_active,
archived_at,
consumed_at,
site_id
)
VALUES (?, ?, ?, ?, ?, NULL, NULL, 0, 0, NULL, 1, NULL, NULL, ?)
""",
(code_id, code_value, issued_iso, expires_iso, creator_value, site_id),
)
return {
"id": code_id,
"code": code_value,
"created_at": issued_iso,
"expires_at": expires_iso,
}
def list_sites(self) -> Tuple[Dict[str, Any], int]:
conn = self._db_conn()
try:
cur = conn.cursor()
cur.execute(
"""
SELECT s.id,
s.name,
s.description,
s.created_at,
COALESCE(ds.cnt, 0) AS device_count
FROM sites AS s
LEFT JOIN (
SELECT site_id, COUNT(*) AS cnt
FROM device_sites
GROUP BY site_id
) AS ds ON ds.site_id = s.id
ORDER BY LOWER(s.name) ASC
"""
)
cur.execute(self._site_select_sql() + " ORDER BY LOWER(s.name) ASC")
rows = cur.fetchall()
sites = [_row_to_site(row) for row in rows]
return {"sites": sites}, 200
@@ -1136,6 +1205,8 @@ class DeviceManagementService:
if not name:
return {"error": "name is required"}, 400
now = int(time.time())
user = self._current_user() or {}
creator = user.get("username") or "system"
conn = self._db_conn()
try:
cur = conn.cursor()
@@ -1144,14 +1215,16 @@ class DeviceManagementService:
(name, description, now),
)
site_id = cur.lastrowid
code_info = self._issue_site_enrollment_code(cur, site_id, creator=creator)
cur.execute("UPDATE sites SET enrollment_code_id = ? WHERE id = ?", (code_info["id"], site_id))
conn.commit()
cur.execute(
"SELECT id, name, description, created_at, 0 FROM sites WHERE id = ?",
(site_id,),
)
row = cur.fetchone()
row = self._fetch_site_row(cur, site_id)
if not row:
return {"error": "creation_failed"}, 500
self.service_log(
"server",
f"site created id={site_id} code_id={code_info['id']} by={creator}",
)
return _row_to_site(row), 201
except sqlite3.IntegrityError:
conn.rollback()
@@ -1171,13 +1244,19 @@ class DeviceManagementService:
try:
norm_ids.append(int(value))
except Exception:
continue
return {"error": "invalid id"}, 400
if not norm_ids:
return {"status": "ok", "deleted": 0}, 200
conn = self._db_conn()
try:
cur = conn.cursor()
placeholders = ",".join("?" * len(norm_ids))
cur.execute(
f"SELECT id FROM enrollment_install_codes WHERE site_id IN ({placeholders})",
tuple(norm_ids),
)
code_ids = [row[0] for row in cur.fetchall() if row and row[0]]
now_iso = datetime.now(tz=timezone.utc).isoformat()
cur.execute(
f"DELETE FROM device_sites WHERE site_id IN ({placeholders})",
tuple(norm_ids),
@@ -1187,6 +1266,30 @@ class DeviceManagementService:
tuple(norm_ids),
)
deleted = cur.rowcount
cur.execute(
f"""
UPDATE enrollment_install_codes_persistent
SET is_active = 0,
archived_at = COALESCE(archived_at, ?)
WHERE site_id IN ({placeholders})
""",
(now_iso, *norm_ids),
)
if code_ids:
code_placeholders = ",".join("?" * len(code_ids))
cur.execute(
f"DELETE FROM enrollment_install_codes WHERE id IN ({code_placeholders})",
tuple(code_ids),
)
cur.execute(
f"""
UPDATE enrollment_install_codes_persistent
SET is_active = 0,
archived_at = COALESCE(archived_at, ?)
WHERE id IN ({code_placeholders})
""",
(now_iso, *code_ids),
)
conn.commit()
return {"status": "ok", "deleted": deleted}, 200
except Exception as exc:
@@ -1287,20 +1390,7 @@ class DeviceManagementService:
return {"error": "site not found"}, 404
conn.commit()
cur.execute(
"""
SELECT s.id,
s.name,
s.description,
s.created_at,
COALESCE(ds.cnt, 0) AS device_count
FROM sites AS s
LEFT JOIN (
SELECT site_id, COUNT(*) AS cnt
FROM device_sites
GROUP BY site_id
) ds ON ds.site_id = s.id
WHERE s.id = ?
""",
self._site_select_sql() + " WHERE s.id = ?",
(site_id_int,),
)
row = cur.fetchone()
@@ -1317,6 +1407,56 @@ class DeviceManagementService:
finally:
conn.close()
def rotate_site_enrollment_code(self, site_id: Any) -> Tuple[Dict[str, Any], int]:
try:
site_id_int = int(site_id)
except Exception:
return {"error": "invalid site_id"}, 400
user = self._current_user() or {}
creator = user.get("username") or "system"
now_iso = datetime.now(tz=timezone.utc).isoformat()
conn = self._db_conn()
try:
cur = conn.cursor()
cur.execute("SELECT enrollment_code_id FROM sites WHERE id = ?", (site_id_int,))
row = cur.fetchone()
if not row:
return {"error": "site not found"}, 404
existing_code_id = row[0]
if existing_code_id:
cur.execute("DELETE FROM enrollment_install_codes WHERE id = ?", (existing_code_id,))
cur.execute(
"""
UPDATE enrollment_install_codes_persistent
SET is_active = 0,
archived_at = COALESCE(archived_at, ?)
WHERE id = ?
""",
(now_iso, existing_code_id),
)
code_info = self._issue_site_enrollment_code(cur, site_id_int, creator=creator)
cur.execute(
"UPDATE sites SET enrollment_code_id = ? WHERE id = ?",
(code_info["id"], site_id_int),
)
conn.commit()
site_row = self._fetch_site_row(cur, site_id_int)
if not site_row:
return {"error": "site not found"}, 404
self.service_log(
"server",
f"site enrollment code rotated site_id={site_id_int} code_id={code_info['id']} by={creator}",
)
return _row_to_site(site_row), 200
except Exception as exc:
conn.rollback()
self.logger.debug("Failed to rotate site enrollment code", exc_info=True)
return {"error": str(exc)}, 500
finally:
conn.close()
def repo_current_hash(self) -> Tuple[Dict[str, Any], int]:
refresh_flag = (request.args.get("refresh") or "").strip().lower()
force_refresh = refresh_flag in {"1", "true", "yes", "force", "refresh"}
@@ -1786,6 +1926,16 @@ def register_management(app, adapters: "EngineServiceAdapters") -> None:
payload, status = service.rename_site(data.get("id"), (data.get("new_name") or "").strip())
return jsonify(payload), status
@blueprint.route("/api/sites/rotate_code", methods=["POST"])
def _sites_rotate_code():
requirement = service._require_admin()
if requirement:
payload, status = requirement
return jsonify(payload), status
data = request.get_json(silent=True) or {}
payload, status = service.rotate_site_enrollment_code(data.get("site_id"))
return jsonify(payload), status
@blueprint.route("/api/repo/current_hash", methods=["GET"])
def _repo_current_hash():
requirement = service._require_device_or_login()

View File

@@ -103,7 +103,8 @@ def register(
used_by_guid,
max_uses,
use_count,
last_used_at
last_used_at,
site_id
FROM enrollment_install_codes
WHERE code = ?
""",
@@ -121,6 +122,7 @@ def register(
"max_uses",
"use_count",
"last_used_at",
"site_id",
]
record = dict(zip(keys, row))
return record
@@ -140,16 +142,16 @@ def register(
if expiry <= _now():
return False, None
try:
max_uses = int(record.get("max_uses") or 1)
max_uses_raw = record.get("max_uses")
max_uses = int(max_uses_raw) if max_uses_raw is not None else 0
except Exception:
max_uses = 1
if max_uses < 1:
max_uses = 1
max_uses = 0
unlimited = max_uses <= 0
try:
use_count = int(record.get("use_count") or 0)
except Exception:
use_count = 0
if use_count < max_uses:
if unlimited or use_count < max_uses:
return True, None
guid = normalize_guid(record.get("used_by_guid"))
@@ -392,6 +394,24 @@ def register(
try:
cur = conn.cursor()
install_code = _load_install_code(cur, enrollment_code)
site_id = install_code.get("site_id") if install_code else None
if site_id is None:
log(
"server",
"enrollment request rejected missing_site_binding "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
)
return jsonify({"error": "invalid_enrollment_code"}), 400
cur.execute("SELECT 1 FROM sites WHERE id = ?", (site_id,))
if cur.fetchone() is None:
log(
"server",
"enrollment request rejected missing_site_owner "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
)
return jsonify({"error": "invalid_enrollment_code"}), 400
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
if not valid_code:
log(
@@ -427,6 +447,7 @@ def register(
SET hostname_claimed = ?,
guid = ?,
enrollment_code_id = ?,
site_id = ?,
client_nonce = ?,
server_nonce = ?,
agent_pubkey_der = ?,
@@ -437,6 +458,7 @@ def register(
hostname,
reuse_guid,
install_code["id"],
install_code.get("site_id"),
client_nonce_b64,
server_nonce_b64,
agent_pubkey_der,
@@ -451,11 +473,11 @@ def register(
"""
INSERT INTO device_approvals (
id, approval_reference, guid, hostname_claimed,
ssl_key_fingerprint_claimed, enrollment_code_id,
ssl_key_fingerprint_claimed, enrollment_code_id, site_id,
status, client_nonce, server_nonce, agent_pubkey_der,
created_at, updated_at
)
VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
""",
(
record_id,
@@ -464,6 +486,7 @@ def register(
hostname,
fingerprint,
install_code["id"],
install_code.get("site_id"),
client_nonce_b64,
server_nonce_b64,
agent_pubkey_der,
@@ -535,7 +558,7 @@ def register(
cur.execute(
"""
SELECT id, guid, hostname_claimed, ssl_key_fingerprint_claimed,
enrollment_code_id, status, client_nonce, server_nonce,
enrollment_code_id, site_id, status, client_nonce, server_nonce,
agent_pubkey_der, created_at, updated_at, approved_by_user_id
FROM device_approvals
WHERE approval_reference = ?
@@ -553,6 +576,7 @@ def register(
hostname_claimed,
fingerprint,
enrollment_code_id,
site_id,
status,
client_nonce_stored,
server_nonce_b64,
@@ -643,6 +667,17 @@ def register(
device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint)
_store_device_key(cur, effective_guid, fingerprint)
if site_id:
assigned_at = int(time.time())
cur.execute(
"""
INSERT INTO device_sites(device_hostname, site_id, assigned_at)
VALUES (?, ?, ?)
ON CONFLICT(device_hostname)
DO UPDATE SET site_id=excluded.site_id, assigned_at=excluded.assigned_at
""",
(device_record.get("hostname"), site_id, assigned_at),
)
# Mark install code used
if enrollment_code_id:
@@ -656,13 +691,14 @@ def register(
except Exception:
prior_count = 0
try:
allowed_uses = int(usage_row[1]) if usage_row else 1
allowed_uses = int(usage_row[1]) if usage_row else 0
except Exception:
allowed_uses = 1
allowed_uses = 0
unlimited = allowed_uses <= 0
if allowed_uses < 1:
allowed_uses = 1
new_count = prior_count + 1
consumed = new_count >= allowed_uses
consumed = False if unlimited else new_count >= allowed_uses
cur.execute(
"""
UPDATE enrollment_install_codes
@@ -767,4 +803,3 @@ def _mask_code(code: str) -> str:
if len(trimmed) <= 6:
return "***"
return f"{trimmed[:3]}***{trimmed[-3:]}"

View File

@@ -48,7 +48,6 @@ import GithubAPIToken from "./Access_Management/Github_API_Token.jsx";
import ServerInfo from "./Admin/Server_Info.jsx";
import PageTemplate from "./Admin/Page_Template.jsx";
import LogManagement from "./Admin/Log_Management.jsx";
import EnrollmentCodes from "./Devices/Enrollment_Codes.jsx";
import DeviceApprovals from "./Devices/Device_Approvals.jsx";
// Networking Imports
@@ -230,8 +229,6 @@ const LOCAL_STORAGE_KEY = "borealis_persistent_state";
return "/admin/server_info";
case "page_template":
return "/admin/page_template";
case "admin_enrollment_codes":
return "/admin/enrollment-codes";
case "admin_device_approvals":
return "/admin/device-approvals";
default:
@@ -286,7 +283,6 @@ const LOCAL_STORAGE_KEY = "borealis_persistent_state";
if (path === "/access_management/credentials") return { page: "access_credentials", options: {} };
if (path === "/admin/server_info") return { page: "server_info", options: {} };
if (path === "/admin/page_template") return { page: "page_template", options: {} };
if (path === "/admin/enrollment-codes") return { page: "admin_enrollment_codes", options: {} };
if (path === "/admin/device-approvals") return { page: "admin_device_approvals", options: {} };
return { page: "devices", options: {} };
} catch {
@@ -512,10 +508,6 @@ const LOCAL_STORAGE_KEY = "borealis_persistent_state";
items.push({ label: "Developer Tools" });
items.push({ label: "Page Template", page: "page_template" });
break;
case "admin_enrollment_codes":
items.push({ label: "Admin Settings", page: "server_info" });
items.push({ label: "Installer Codes", page: "admin_enrollment_codes" });
break;
case "admin_device_approvals":
items.push({ label: "Admin Settings", page: "server_info" });
items.push({ label: "Device Approvals", page: "admin_device_approvals" });
@@ -1045,7 +1037,6 @@ const LOCAL_STORAGE_KEY = "borealis_persistent_state";
useEffect(() => {
const requiresAdmin = currentPage === 'server_info'
|| currentPage === 'admin_enrollment_codes'
|| currentPage === 'admin_device_approvals'
|| currentPage === 'access_credentials'
|| currentPage === 'access_github_token'
@@ -1199,9 +1190,6 @@ const LOCAL_STORAGE_KEY = "borealis_persistent_state";
case "page_template":
return <PageTemplate isAdmin={isAdmin} />;
case "admin_enrollment_codes":
return <EnrollmentCodes />;
case "admin_device_approvals":
return <DeviceApprovals />;

View File

@@ -294,6 +294,12 @@ export default function DeviceApprovals() {
minWidth: 100,
Width: 100,
},
{
headerName: "Site",
field: "site_name",
valueGetter: (p) => p.data?.site_name || (p.data?.site_id ? `Site ${p.data.site_id}` : "—"),
minWidth: 160,
},
{ headerName: "Created", field: "created_at", valueFormatter: (p) => formatDateTime(p.value), minWidth: 160 },
{ headerName: "Updated", field: "updated_at", valueFormatter: (p) => formatDateTime(p.value), minWidth: 160 },
{

View File

@@ -1,372 +0,0 @@
import React, { useCallback, useEffect, useMemo, useState, useRef } from "react";
import {
Box,
Paper,
Typography,
Button,
Stack,
Alert,
FormControl,
InputLabel,
MenuItem,
Select,
CircularProgress,
Tooltip
} from "@mui/material";
import {
ContentCopy as CopyIcon,
DeleteOutline as DeleteIcon,
Refresh as RefreshIcon,
Key as KeyIcon,
} from "@mui/icons-material";
import { AgGridReact } from "ag-grid-react";
import { ModuleRegistry, AllCommunityModule, themeQuartz } from "ag-grid-community";
// IMPORTANT: Do NOT import global AG Grid CSS here to avoid overriding other pages.
// We rely on the project's existing CSS and themeQuartz class name like other MagicUI pages.
ModuleRegistry.registerModules([AllCommunityModule]);
// Match the palette used on other pages (see Site_List / Device_List)
const MAGIC_UI = {
shellBg:
"radial-gradient(120% 120% at 0% 0%, rgba(76, 186, 255, 0.16), transparent 55%), " +
"radial-gradient(120% 120% at 100% 0%, rgba(214, 130, 255, 0.18), transparent 60%), #040711",
panelBg:
"linear-gradient(135deg, rgba(10, 16, 31, 0.98) 0%, rgba(6, 10, 24, 0.94) 60%, rgba(15, 6, 26, 0.96) 100%)",
panelBorder: "rgba(148, 163, 184, 0.35)",
textBright: "#e2e8f0",
textMuted: "#94a3b8",
accentA: "#7dd3fc",
accentB: "#c084fc",
};
// Generate a scoped Quartz theme class (same pattern as other pages)
const gridTheme = themeQuartz.withParams({
accentColor: "#8b5cf6",
backgroundColor: "#070b1a",
browserColorScheme: "dark",
fontFamily: { googleFont: "IBM Plex Sans" },
foregroundColor: "#f4f7ff",
headerFontSize: 13,
});
const themeClassName = gridTheme.themeName || "ag-theme-quartz";
const TTL_PRESETS = [
{ value: 1, label: "1 hour" },
{ value: 3, label: "3 hours" },
{ value: 6, label: "6 hours" },
{ value: 12, label: "12 hours" },
{ value: 24, label: "24 hours" },
];
const determineStatus = (record) => {
if (!record) return "expired";
const maxUses = Number.isFinite(record?.max_uses) ? record.max_uses : 1;
const useCount = Number.isFinite(record?.use_count) ? record.use_count : 0;
if (useCount >= Math.max(1, maxUses || 1)) return "used";
if (!record.expires_at) return "expired";
const expires = new Date(record.expires_at);
if (Number.isNaN(expires.getTime())) return "expired";
return expires.getTime() > Date.now() ? "active" : "expired";
};
const formatDateTime = (value) => {
if (!value) return "—";
const date = new Date(value);
if (Number.isNaN(date.getTime())) return value;
return date.toLocaleString();
};
const maskCode = (code) => {
if (!code) return "—";
const parts = code.split("-");
if (parts.length <= 1) {
const prefix = code.slice(0, 4);
return `${prefix}${"•".repeat(Math.max(0, code.length - prefix.length))}`;
}
return parts
.map((part, idx) => (idx === 0 || idx === parts.length - 1 ? part : "•".repeat(part.length)))
.join("-");
};
export default function EnrollmentCodes() {
const [codes, setCodes] = useState([]);
const [loading, setLoading] = useState(false);
const [error, setError] = useState("");
const [feedback, setFeedback] = useState(null);
const [statusFilter, setStatusFilter] = useState("all");
const [ttlHours, setTtlHours] = useState(6);
const [generating, setGenerating] = useState(false);
const [maxUses, setMaxUses] = useState(2);
const gridRef = useRef(null);
const fetchCodes = useCallback(async () => {
setLoading(true);
setError("");
try {
const query = statusFilter === "all" ? "" : `?status=${encodeURIComponent(statusFilter)}`;
const resp = await fetch(`/api/admin/enrollment-codes${query}`, { credentials: "include" });
if (!resp.ok) {
const body = await resp.json().catch(() => ({}));
throw new Error(body.error || `Request failed (${resp.status})`);
}
const data = await resp.json();
setCodes(Array.isArray(data.codes) ? data.codes : []);
} catch (err) {
setError(err.message || "Unable to load codes");
} finally {
setLoading(false);
}
}, [statusFilter]);
useEffect(() => { fetchCodes(); }, [fetchCodes]);
const handleGenerate = useCallback(async () => {
setGenerating(true);
try {
const resp = await fetch("/api/admin/enrollment-codes", {
method: "POST",
credentials: "include",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ ttl_hours: ttlHours, max_uses: maxUses }),
});
if (!resp.ok) {
const body = await resp.json().catch(() => ({}));
throw new Error(body.error || `Request failed (${resp.status})`);
}
await fetchCodes();
setFeedback({ type: "success", message: "New installer code created" });
} catch (err) {
setFeedback({ type: "error", message: err.message });
} finally {
setGenerating(false);
}
}, [ttlHours, maxUses, fetchCodes]);
const handleCopy = (code) => {
if (!code) return;
try {
if (navigator.clipboard?.writeText) {
navigator.clipboard.writeText(code);
setFeedback({ type: "success", message: "Code copied to clipboard" });
}
} catch (_) {}
};
const handleDelete = async (id) => {
if (!id) return;
if (!window.confirm("Delete this installer code?")) return;
try {
const resp = await fetch(`/api/admin/enrollment-codes/${id}`, {
method: "DELETE",
credentials: "include",
});
if (!resp.ok) {
const body = await resp.json().catch(() => ({}));
throw new Error(body.error || `Request failed (${resp.status})`);
}
await fetchCodes();
setFeedback({ type: "success", message: "Code deleted" });
} catch (err) {
setFeedback({ type: "error", message: err.message });
}
};
const columns = useMemo(() => [
{
headerName: "Status",
field: "status",
cellRenderer: (params) => {
const status = determineStatus(params.data);
const color =
status === "active" ? "#34d399" :
status === "used" ? "#7dd3fc" :
"#fbbf24";
return <span style={{ color, fontWeight: 600 }}>{status}</span>;
},
minWidth: 100
},
{
headerName: "Installer Code",
field: "code",
cellRenderer: (params) => (
<span style={{ fontFamily: "monospace", color: "#7dd3fc" }}>{maskCode(params.value)}</span>
),
minWidth: 340
},
{ headerName: "Expires At",
field: "expires_at",
valueFormatter: p => formatDateTime(p.value)
},
{ headerName: "Created By", field: "created_by_user_id" },
{
headerName: "Usage",
valueGetter: (p) => `${p.data.use_count || 0} / ${p.data.max_uses || 1}`,
cellStyle: { fontFamily: "monospace" },
width: 120
},
{ headerName: "Last Used", field: "last_used_at", valueFormatter: p => formatDateTime(p.value) },
{ headerName: "Used By GUID", field: "used_by_guid" },
{
headerName: "Actions",
cellRenderer: (params) => {
const record = params.data;
const disableDelete = (record.use_count || 0) !== 0;
return (
<Stack direction="row" spacing={1} justifyContent="flex-end">
<Tooltip title="Copy code">
<span>
<Button size="small" onClick={() => handleCopy(record.code)}>
<CopyIcon fontSize="small" />
</Button>
</span>
</Tooltip>
<Tooltip title={disableDelete ? "Only unused codes can be deleted" : "Delete code"}>
<span>
<Button size="small" disabled={disableDelete} onClick={() => handleDelete(record.id)}>
<DeleteIcon fontSize="small" />
</Button>
</span>
</Tooltip>
</Stack>
);
},
width: 160
}
], []);
const defaultColDef = useMemo(() => ({
sortable: true,
filter: true,
resizable: true,
flex: 1,
minWidth: 140,
}), []);
return (
<Paper
sx={{
m: 0,
p: 0,
display: "flex",
flexDirection: "column",
flexGrow: 1,
minWidth: 0,
height: "100%",
borderRadius: 0,
border: `1px solid ${MAGIC_UI.panelBorder}`,
background: MAGIC_UI.shellBg,
boxShadow: "0 25px 80px rgba(6, 12, 30, 0.8)",
overflow: "hidden",
}}
elevation={0}
>
{/* Hero header */}
<Box sx={{ p: 3 }}>
<Stack direction="row" spacing={2} alignItems="center" justifyContent="space-between">
<Stack direction="row" spacing={1} alignItems="center">
<KeyIcon sx={{ color: MAGIC_UI.accentA }} />
<Typography variant="h6" sx={{ color: MAGIC_UI.textBright, fontWeight: 700 }}>
Enrollment Installer Codes
</Typography>
</Stack>
<Stack direction="row" spacing={1}>
<Button
variant="contained"
disabled={generating}
startIcon={generating ? <CircularProgress size={16} color="inherit" /> : null}
onClick={handleGenerate}
sx={{ background: "linear-gradient(135deg,#7dd3fc,#c084fc)", borderRadius: 999 }}
>
{generating ? "Generating…" : "Generate Code"}
</Button>
<Button variant="outlined" startIcon={<RefreshIcon />} onClick={fetchCodes} disabled={loading}>
Refresh
</Button>
</Stack>
</Stack>
</Box>
{/* Controls */}
<Box sx={{ p: 2, display: "flex", gap: 2, alignItems: "center", flexWrap: "wrap" }}>
<FormControl size="small" sx={{ minWidth: 140 }}>
<InputLabel>Status</InputLabel>
<Select value={statusFilter} label="Status" onChange={(e) => setStatusFilter(e.target.value)}>
<MenuItem value="all">All</MenuItem>
<MenuItem value="active">Active</MenuItem>
<MenuItem value="used">Used</MenuItem>
<MenuItem value="expired">Expired</MenuItem>
</Select>
</FormControl>
<FormControl size="small" sx={{ minWidth: 160 }}>
<InputLabel>Duration</InputLabel>
<Select value={ttlHours} label="Duration" onChange={(e) => setTtlHours(Number(e.target.value))}>
{TTL_PRESETS.map((p) => (
<MenuItem key={p.value} value={p.value}>
{p.label}
</MenuItem>
))}
</Select>
</FormControl>
<FormControl size="small" sx={{ minWidth: 160 }}>
<InputLabel>Allowed Uses</InputLabel>
<Select value={maxUses} label="Allowed Uses" onChange={(e) => setMaxUses(Number(e.target.value))}>
{[1, 2, 3, 5].map((n) => (
<MenuItem key={n} value={n}>
{n === 1 ? "Single use" : `${n} uses`}
</MenuItem>
))}
</Select>
</FormControl>
</Box>
{feedback && (
<Box sx={{ px: 3 }}>
<Alert severity={feedback.type} onClose={() => setFeedback(null)}>
{feedback.message}
</Alert>
</Box>
)}
{error && (
<Box sx={{ px: 3 }}>
<Alert severity="error">{error}</Alert>
</Box>
)}
{/* Grid wrapper — all overrides are SCOPED to this instance via inline CSS vars */}
<Box
className={themeClassName}
sx={{
flex: 1,
p: 2,
overflow: "hidden",
}}
// Inline style ensures the CSS variables only affect THIS grid instance
style={{
"--ag-background-color": "#070b1a",
"--ag-foreground-color": "#f4f7ff",
"--ag-header-background-color": "#0f172a",
"--ag-header-foreground-color": "#cfe0ff",
"--ag-odd-row-background-color": "rgba(255,255,255,0.02)",
"--ag-row-hover-color": "rgba(125,183,255,0.08)",
"--ag-selected-row-background-color": "rgba(64,164,255,0.18)",
"--ag-font-family": "'IBM Plex Sans', 'Helvetica Neue', Arial, sans-serif",
"--ag-border-color": "rgba(125,183,255,0.18)",
"--ag-row-border-color": "rgba(125,183,255,0.14)",
"--ag-border-radius": "8px",
}}
>
<AgGridReact
ref={gridRef}
rowData={codes}
columnDefs={columns}
defaultColDef={defaultColDef}
animateRows
pagination
paginationPageSize={20}
/>
</Box>
</Paper>
);
}

View File

@@ -24,7 +24,6 @@ import {
VpnKey as CredentialIcon,
PersonOutline as UserIcon,
GitHub as GitHubIcon,
Key as KeyIcon,
Dashboard as PageTemplateIcon,
AdminPanelSettings as AdminPanelSettingsIcon,
ReceiptLong as LogsIcon,
@@ -61,7 +60,6 @@ function NavigationSidebar({ currentPage, onNavigate, isAdmin = false }) {
"winrm_devices",
"agent_devices",
"admin_device_approvals",
"admin_enrollment_codes",
].includes(currentPage),
automation: ["jobs", "assemblies", "community"].includes(currentPage),
filters: ["filters", "groups"].includes(currentPage),
@@ -194,12 +192,6 @@ function NavigationSidebar({ currentPage, onNavigate, isAdmin = false }) {
label="Device Approvals"
pageKey="admin_device_approvals"
/>
<NavItem
icon={<KeyIcon fontSize="small" />}
label="Enrollment Codes"
pageKey="admin_enrollment_codes"
indent
/>
<NavItem
icon={<DevicesIcon fontSize="small" />}
label="Devices"

View File

@@ -6,12 +6,15 @@ import {
Button,
IconButton,
Tooltip,
CircularProgress,
} from "@mui/material";
import AddIcon from "@mui/icons-material/Add";
import LocationCityIcon from "@mui/icons-material/LocationCity";
import DeleteIcon from "@mui/icons-material/DeleteOutline";
import EditIcon from "@mui/icons-material/Edit";
import RefreshIcon from "@mui/icons-material/Refresh";
import ContentCopyIcon from "@mui/icons-material/ContentCopy";
import { AgGridReact } from "ag-grid-react";
import { ModuleRegistry, AllCommunityModule, themeQuartz } from "ag-grid-community";
import { CreateSiteDialog, ConfirmDeleteDialog, RenameSiteDialog } from "../Dialogs.jsx";
@@ -69,6 +72,7 @@ export default function SiteList({ onOpenDevicesForSite }) {
const [deleteOpen, setDeleteOpen] = useState(false);
const [renameOpen, setRenameOpen] = useState(false);
const [renameValue, setRenameValue] = useState("");
const [rotatingId, setRotatingId] = useState(null);
const gridRef = useRef(null);
const fetchSites = useCallback(async () => {
@@ -83,6 +87,42 @@ export default function SiteList({ onOpenDevicesForSite }) {
useEffect(() => { fetchSites(); }, [fetchSites]);
const handleCopy = useCallback(async (code) => {
const value = (code || "").trim();
if (!value) return;
try {
await navigator.clipboard.writeText(value);
} catch {
window.prompt("Copy enrollment code", value);
}
}, []);
const handleRotate = useCallback(async (site) => {
if (!site?.id) return;
const confirmRotate = window.confirm(
"Are you sure you want to rotate the enrollment code associated with this site? "
+ "If there are automations that deploy agents to endpoints, the enrollment code associated with them will need to also be updated."
);
if (!confirmRotate) return;
setRotatingId(site.id);
try {
const resp = await fetch("/api/sites/rotate_code", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ site_id: site.id }),
});
if (resp.ok) {
const updated = await resp.json();
setRows((prev) => prev.map((row) => (row.id === site.id ? { ...row, ...updated } : row)));
}
} catch {
// Silently fail the rotate if the request errors; grid will refresh on next fetch.
} finally {
setRotatingId(null);
fetchSites();
}
}, [fetchSites]);
const columnDefs = useMemo(() => [
{
headerName: "",
@@ -105,9 +145,51 @@ export default function SiteList({ onOpenDevicesForSite }) {
</span>
),
},
{
headerName: "Agent Enrollment Code",
field: "enrollment_code",
minWidth: 320,
flex: 1.2,
cellRenderer: (params) => {
const code = params.value || "—";
const site = params.data || {};
const busy = rotatingId === site.id;
return (
<Box sx={{ display: "flex", alignItems: "center", gap: 1 }}>
<Tooltip title="Rotate Code">
<span>
<IconButton
size="small"
onClick={() => handleRotate(site)}
disabled={busy}
sx={{ color: MAGIC_UI.accentA, border: "1px solid rgba(148,163,184,0.35)" }}
>
{busy ? <CircularProgress size={16} color="inherit" /> : <RefreshIcon fontSize="small" />}
</IconButton>
</span>
</Tooltip>
<Typography variant="body2" sx={{ fontFamily: "monospace", color: MAGIC_UI.textBright }}>
{code}
</Typography>
<Tooltip title="Copy">
<span>
<IconButton
size="small"
onClick={() => handleCopy(code)}
disabled={!code || code === "—"}
sx={{ color: MAGIC_UI.textMuted }}
>
<ContentCopyIcon fontSize="small" />
</IconButton>
</span>
</Tooltip>
</Box>
);
},
},
{ headerName: "Description", field: "description", minWidth: 220 },
{ headerName: "Devices", field: "device_count", minWidth: 120 },
], [onOpenDevicesForSite]);
], [onOpenDevicesForSite, handleRotate, handleCopy, rotatingId]);
const defaultColDef = useMemo(() => ({
sortable: true,