mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-12-16 11:25:48 -07:00
Simplified & Reworked Enrollment Code System to be Site-Specific
This commit is contained in:
@@ -357,19 +357,23 @@ class AdminDeviceService:
|
||||
da.hostname_claimed,
|
||||
da.ssl_key_fingerprint_claimed,
|
||||
da.enrollment_code_id,
|
||||
da.site_id,
|
||||
da.status,
|
||||
da.client_nonce,
|
||||
da.server_nonce,
|
||||
da.created_at,
|
||||
da.updated_at,
|
||||
da.approved_by_user_id,
|
||||
u.username AS approved_by_username
|
||||
u.username AS approved_by_username,
|
||||
s.name AS site_name
|
||||
FROM device_approvals AS da
|
||||
LEFT JOIN users AS u
|
||||
ON (
|
||||
CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT)
|
||||
OR LOWER(da.approved_by_user_id) = LOWER(u.username)
|
||||
)
|
||||
LEFT JOIN sites AS s
|
||||
ON s.id = da.site_id
|
||||
"""
|
||||
status_norm = (status_filter or "").strip().lower()
|
||||
if status_norm and status_norm != "all":
|
||||
@@ -409,17 +413,19 @@ class AdminDeviceService:
|
||||
"hostname_claimed": hostname,
|
||||
"ssl_key_fingerprint_claimed": fingerprint_claimed,
|
||||
"enrollment_code_id": row[5],
|
||||
"status": row[6],
|
||||
"client_nonce": row[7],
|
||||
"server_nonce": row[8],
|
||||
"created_at": row[9],
|
||||
"updated_at": row[10],
|
||||
"approved_by_user_id": row[11],
|
||||
"site_id": row[6],
|
||||
"status": row[7],
|
||||
"client_nonce": row[8],
|
||||
"server_nonce": row[9],
|
||||
"created_at": row[10],
|
||||
"updated_at": row[11],
|
||||
"approved_by_user_id": row[12],
|
||||
"hostname_conflict": conflict,
|
||||
"alternate_hostname": alternate,
|
||||
"conflict_requires_prompt": requires_prompt,
|
||||
"fingerprint_match": fingerprint_match,
|
||||
"approved_by_username": row[12],
|
||||
"approved_by_username": row[13],
|
||||
"site_name": row[14],
|
||||
}
|
||||
)
|
||||
finally:
|
||||
@@ -578,4 +584,3 @@ def register_admin_endpoints(app, adapters: "EngineServiceAdapters") -> None:
|
||||
return jsonify(payload), status
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
# - GET /api/sites/device_map (Token Authenticated) - Provides hostname to site assignment mapping data.
|
||||
# - POST /api/sites/assign (Token Authenticated (Admin)) - Assigns a set of devices to a given site.
|
||||
# - POST /api/sites/rename (Token Authenticated (Admin)) - Renames an existing site record.
|
||||
# - POST /api/sites/rotate_code (Token Authenticated (Admin)) - Rotates the static enrollment code for a site.
|
||||
# - GET /api/repo/current_hash (Device or Token Authenticated) - Fetches the current agent repository hash (with caching).
|
||||
# - GET/POST /api/agent/hash (Device Authenticated) - Retrieves or updates an agent hash record bound to the authenticated device.
|
||||
# - GET /api/agent/hash_list (Token Authenticated (Admin + Loopback)) - Returns stored agent hash metadata for localhost diagnostics.
|
||||
@@ -31,10 +32,11 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -118,6 +120,11 @@ def _is_internal_request(remote_addr: Optional[str]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _generate_install_code() -> str:
|
||||
raw = secrets.token_hex(16).upper()
|
||||
return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4))
|
||||
|
||||
|
||||
def _row_to_site(row: Tuple[Any, ...]) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": row[0],
|
||||
@@ -125,6 +132,11 @@ def _row_to_site(row: Tuple[Any, ...]) -> Dict[str, Any]:
|
||||
"description": row[2] or "",
|
||||
"created_at": row[3] or 0,
|
||||
"device_count": row[4] or 0,
|
||||
"enrollment_code_id": row[5],
|
||||
"enrollment_code": row[6] or "",
|
||||
"enrollment_code_expires_at": row[7] or "",
|
||||
"enrollment_code_last_used_at": row[8] or "",
|
||||
"enrollment_code_use_count": row[9] or 0,
|
||||
}
|
||||
|
||||
|
||||
@@ -1103,26 +1115,83 @@ class DeviceManagementService:
|
||||
# Site management helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _site_select_sql(self) -> str:
|
||||
return """
|
||||
SELECT s.id,
|
||||
s.name,
|
||||
s.description,
|
||||
s.created_at,
|
||||
COALESCE(ds.cnt, 0) AS device_count,
|
||||
s.enrollment_code_id,
|
||||
ic.code,
|
||||
ic.expires_at,
|
||||
ic.last_used_at,
|
||||
ic.use_count
|
||||
FROM sites AS s
|
||||
LEFT JOIN (
|
||||
SELECT site_id, COUNT(*) AS cnt
|
||||
FROM device_sites
|
||||
GROUP BY site_id
|
||||
) AS ds ON ds.site_id = s.id
|
||||
LEFT JOIN enrollment_install_codes AS ic
|
||||
ON ic.id = s.enrollment_code_id
|
||||
"""
|
||||
|
||||
def _fetch_site_row(self, cur: sqlite3.Cursor, site_id: int) -> Optional[Tuple[Any, ...]]:
|
||||
cur.execute(self._site_select_sql() + " WHERE s.id = ?", (site_id,))
|
||||
return cur.fetchone()
|
||||
|
||||
def _issue_site_enrollment_code(self, cur: sqlite3.Cursor, site_id: int, *, creator: str) -> Dict[str, Any]:
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
issued_iso = now.isoformat()
|
||||
expires_iso = (now + timedelta(days=3650)).isoformat()
|
||||
code_id = str(uuid.uuid4())
|
||||
code_value = _generate_install_code()
|
||||
creator_value = creator or "system"
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO enrollment_install_codes (
|
||||
id, code, expires_at, created_by_user_id, used_at, used_by_guid,
|
||||
max_uses, use_count, last_used_at, site_id
|
||||
)
|
||||
VALUES (?, ?, ?, ?, NULL, NULL, 0, 0, NULL, ?)
|
||||
""",
|
||||
(code_id, code_value, expires_iso, creator_value, site_id),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO enrollment_install_codes_persistent (
|
||||
id,
|
||||
code,
|
||||
created_at,
|
||||
expires_at,
|
||||
created_by_user_id,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
last_known_use_count,
|
||||
last_used_at,
|
||||
is_active,
|
||||
archived_at,
|
||||
consumed_at,
|
||||
site_id
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, NULL, NULL, 0, 0, NULL, 1, NULL, NULL, ?)
|
||||
""",
|
||||
(code_id, code_value, issued_iso, expires_iso, creator_value, site_id),
|
||||
)
|
||||
return {
|
||||
"id": code_id,
|
||||
"code": code_value,
|
||||
"created_at": issued_iso,
|
||||
"expires_at": expires_iso,
|
||||
}
|
||||
|
||||
def list_sites(self) -> Tuple[Dict[str, Any], int]:
|
||||
conn = self._db_conn()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT s.id,
|
||||
s.name,
|
||||
s.description,
|
||||
s.created_at,
|
||||
COALESCE(ds.cnt, 0) AS device_count
|
||||
FROM sites AS s
|
||||
LEFT JOIN (
|
||||
SELECT site_id, COUNT(*) AS cnt
|
||||
FROM device_sites
|
||||
GROUP BY site_id
|
||||
) AS ds ON ds.site_id = s.id
|
||||
ORDER BY LOWER(s.name) ASC
|
||||
"""
|
||||
)
|
||||
cur.execute(self._site_select_sql() + " ORDER BY LOWER(s.name) ASC")
|
||||
rows = cur.fetchall()
|
||||
sites = [_row_to_site(row) for row in rows]
|
||||
return {"sites": sites}, 200
|
||||
@@ -1136,6 +1205,8 @@ class DeviceManagementService:
|
||||
if not name:
|
||||
return {"error": "name is required"}, 400
|
||||
now = int(time.time())
|
||||
user = self._current_user() or {}
|
||||
creator = user.get("username") or "system"
|
||||
conn = self._db_conn()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
@@ -1144,14 +1215,16 @@ class DeviceManagementService:
|
||||
(name, description, now),
|
||||
)
|
||||
site_id = cur.lastrowid
|
||||
code_info = self._issue_site_enrollment_code(cur, site_id, creator=creator)
|
||||
cur.execute("UPDATE sites SET enrollment_code_id = ? WHERE id = ?", (code_info["id"], site_id))
|
||||
conn.commit()
|
||||
cur.execute(
|
||||
"SELECT id, name, description, created_at, 0 FROM sites WHERE id = ?",
|
||||
(site_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
row = self._fetch_site_row(cur, site_id)
|
||||
if not row:
|
||||
return {"error": "creation_failed"}, 500
|
||||
self.service_log(
|
||||
"server",
|
||||
f"site created id={site_id} code_id={code_info['id']} by={creator}",
|
||||
)
|
||||
return _row_to_site(row), 201
|
||||
except sqlite3.IntegrityError:
|
||||
conn.rollback()
|
||||
@@ -1171,13 +1244,19 @@ class DeviceManagementService:
|
||||
try:
|
||||
norm_ids.append(int(value))
|
||||
except Exception:
|
||||
continue
|
||||
return {"error": "invalid id"}, 400
|
||||
if not norm_ids:
|
||||
return {"status": "ok", "deleted": 0}, 200
|
||||
conn = self._db_conn()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
placeholders = ",".join("?" * len(norm_ids))
|
||||
cur.execute(
|
||||
f"SELECT id FROM enrollment_install_codes WHERE site_id IN ({placeholders})",
|
||||
tuple(norm_ids),
|
||||
)
|
||||
code_ids = [row[0] for row in cur.fetchall() if row and row[0]]
|
||||
now_iso = datetime.now(tz=timezone.utc).isoformat()
|
||||
cur.execute(
|
||||
f"DELETE FROM device_sites WHERE site_id IN ({placeholders})",
|
||||
tuple(norm_ids),
|
||||
@@ -1187,6 +1266,30 @@ class DeviceManagementService:
|
||||
tuple(norm_ids),
|
||||
)
|
||||
deleted = cur.rowcount
|
||||
cur.execute(
|
||||
f"""
|
||||
UPDATE enrollment_install_codes_persistent
|
||||
SET is_active = 0,
|
||||
archived_at = COALESCE(archived_at, ?)
|
||||
WHERE site_id IN ({placeholders})
|
||||
""",
|
||||
(now_iso, *norm_ids),
|
||||
)
|
||||
if code_ids:
|
||||
code_placeholders = ",".join("?" * len(code_ids))
|
||||
cur.execute(
|
||||
f"DELETE FROM enrollment_install_codes WHERE id IN ({code_placeholders})",
|
||||
tuple(code_ids),
|
||||
)
|
||||
cur.execute(
|
||||
f"""
|
||||
UPDATE enrollment_install_codes_persistent
|
||||
SET is_active = 0,
|
||||
archived_at = COALESCE(archived_at, ?)
|
||||
WHERE id IN ({code_placeholders})
|
||||
""",
|
||||
(now_iso, *code_ids),
|
||||
)
|
||||
conn.commit()
|
||||
return {"status": "ok", "deleted": deleted}, 200
|
||||
except Exception as exc:
|
||||
@@ -1287,20 +1390,7 @@ class DeviceManagementService:
|
||||
return {"error": "site not found"}, 404
|
||||
conn.commit()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT s.id,
|
||||
s.name,
|
||||
s.description,
|
||||
s.created_at,
|
||||
COALESCE(ds.cnt, 0) AS device_count
|
||||
FROM sites AS s
|
||||
LEFT JOIN (
|
||||
SELECT site_id, COUNT(*) AS cnt
|
||||
FROM device_sites
|
||||
GROUP BY site_id
|
||||
) ds ON ds.site_id = s.id
|
||||
WHERE s.id = ?
|
||||
""",
|
||||
self._site_select_sql() + " WHERE s.id = ?",
|
||||
(site_id_int,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
@@ -1317,6 +1407,56 @@ class DeviceManagementService:
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def rotate_site_enrollment_code(self, site_id: Any) -> Tuple[Dict[str, Any], int]:
|
||||
try:
|
||||
site_id_int = int(site_id)
|
||||
except Exception:
|
||||
return {"error": "invalid site_id"}, 400
|
||||
|
||||
user = self._current_user() or {}
|
||||
creator = user.get("username") or "system"
|
||||
now_iso = datetime.now(tz=timezone.utc).isoformat()
|
||||
|
||||
conn = self._db_conn()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT enrollment_code_id FROM sites WHERE id = ?", (site_id_int,))
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return {"error": "site not found"}, 404
|
||||
existing_code_id = row[0]
|
||||
if existing_code_id:
|
||||
cur.execute("DELETE FROM enrollment_install_codes WHERE id = ?", (existing_code_id,))
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE enrollment_install_codes_persistent
|
||||
SET is_active = 0,
|
||||
archived_at = COALESCE(archived_at, ?)
|
||||
WHERE id = ?
|
||||
""",
|
||||
(now_iso, existing_code_id),
|
||||
)
|
||||
code_info = self._issue_site_enrollment_code(cur, site_id_int, creator=creator)
|
||||
cur.execute(
|
||||
"UPDATE sites SET enrollment_code_id = ? WHERE id = ?",
|
||||
(code_info["id"], site_id_int),
|
||||
)
|
||||
conn.commit()
|
||||
site_row = self._fetch_site_row(cur, site_id_int)
|
||||
if not site_row:
|
||||
return {"error": "site not found"}, 404
|
||||
self.service_log(
|
||||
"server",
|
||||
f"site enrollment code rotated site_id={site_id_int} code_id={code_info['id']} by={creator}",
|
||||
)
|
||||
return _row_to_site(site_row), 200
|
||||
except Exception as exc:
|
||||
conn.rollback()
|
||||
self.logger.debug("Failed to rotate site enrollment code", exc_info=True)
|
||||
return {"error": str(exc)}, 500
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def repo_current_hash(self) -> Tuple[Dict[str, Any], int]:
|
||||
refresh_flag = (request.args.get("refresh") or "").strip().lower()
|
||||
force_refresh = refresh_flag in {"1", "true", "yes", "force", "refresh"}
|
||||
@@ -1786,6 +1926,16 @@ def register_management(app, adapters: "EngineServiceAdapters") -> None:
|
||||
payload, status = service.rename_site(data.get("id"), (data.get("new_name") or "").strip())
|
||||
return jsonify(payload), status
|
||||
|
||||
@blueprint.route("/api/sites/rotate_code", methods=["POST"])
|
||||
def _sites_rotate_code():
|
||||
requirement = service._require_admin()
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
data = request.get_json(silent=True) or {}
|
||||
payload, status = service.rotate_site_enrollment_code(data.get("site_id"))
|
||||
return jsonify(payload), status
|
||||
|
||||
@blueprint.route("/api/repo/current_hash", methods=["GET"])
|
||||
def _repo_current_hash():
|
||||
requirement = service._require_device_or_login()
|
||||
|
||||
@@ -103,7 +103,8 @@ def register(
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
last_used_at,
|
||||
site_id
|
||||
FROM enrollment_install_codes
|
||||
WHERE code = ?
|
||||
""",
|
||||
@@ -121,6 +122,7 @@ def register(
|
||||
"max_uses",
|
||||
"use_count",
|
||||
"last_used_at",
|
||||
"site_id",
|
||||
]
|
||||
record = dict(zip(keys, row))
|
||||
return record
|
||||
@@ -140,16 +142,16 @@ def register(
|
||||
if expiry <= _now():
|
||||
return False, None
|
||||
try:
|
||||
max_uses = int(record.get("max_uses") or 1)
|
||||
max_uses_raw = record.get("max_uses")
|
||||
max_uses = int(max_uses_raw) if max_uses_raw is not None else 0
|
||||
except Exception:
|
||||
max_uses = 1
|
||||
if max_uses < 1:
|
||||
max_uses = 1
|
||||
max_uses = 0
|
||||
unlimited = max_uses <= 0
|
||||
try:
|
||||
use_count = int(record.get("use_count") or 0)
|
||||
except Exception:
|
||||
use_count = 0
|
||||
if use_count < max_uses:
|
||||
if unlimited or use_count < max_uses:
|
||||
return True, None
|
||||
|
||||
guid = normalize_guid(record.get("used_by_guid"))
|
||||
@@ -392,6 +394,24 @@ def register(
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
install_code = _load_install_code(cur, enrollment_code)
|
||||
site_id = install_code.get("site_id") if install_code else None
|
||||
if site_id is None:
|
||||
log(
|
||||
"server",
|
||||
"enrollment request rejected missing_site_binding "
|
||||
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"error": "invalid_enrollment_code"}), 400
|
||||
cur.execute("SELECT 1 FROM sites WHERE id = ?", (site_id,))
|
||||
if cur.fetchone() is None:
|
||||
log(
|
||||
"server",
|
||||
"enrollment request rejected missing_site_owner "
|
||||
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"error": "invalid_enrollment_code"}), 400
|
||||
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
|
||||
if not valid_code:
|
||||
log(
|
||||
@@ -427,6 +447,7 @@ def register(
|
||||
SET hostname_claimed = ?,
|
||||
guid = ?,
|
||||
enrollment_code_id = ?,
|
||||
site_id = ?,
|
||||
client_nonce = ?,
|
||||
server_nonce = ?,
|
||||
agent_pubkey_der = ?,
|
||||
@@ -437,6 +458,7 @@ def register(
|
||||
hostname,
|
||||
reuse_guid,
|
||||
install_code["id"],
|
||||
install_code.get("site_id"),
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
@@ -451,11 +473,11 @@ def register(
|
||||
"""
|
||||
INSERT INTO device_approvals (
|
||||
id, approval_reference, guid, hostname_claimed,
|
||||
ssl_key_fingerprint_claimed, enrollment_code_id,
|
||||
ssl_key_fingerprint_claimed, enrollment_code_id, site_id,
|
||||
status, client_nonce, server_nonce, agent_pubkey_der,
|
||||
created_at, updated_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
record_id,
|
||||
@@ -464,6 +486,7 @@ def register(
|
||||
hostname,
|
||||
fingerprint,
|
||||
install_code["id"],
|
||||
install_code.get("site_id"),
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
@@ -535,7 +558,7 @@ def register(
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, guid, hostname_claimed, ssl_key_fingerprint_claimed,
|
||||
enrollment_code_id, status, client_nonce, server_nonce,
|
||||
enrollment_code_id, site_id, status, client_nonce, server_nonce,
|
||||
agent_pubkey_der, created_at, updated_at, approved_by_user_id
|
||||
FROM device_approvals
|
||||
WHERE approval_reference = ?
|
||||
@@ -553,6 +576,7 @@ def register(
|
||||
hostname_claimed,
|
||||
fingerprint,
|
||||
enrollment_code_id,
|
||||
site_id,
|
||||
status,
|
||||
client_nonce_stored,
|
||||
server_nonce_b64,
|
||||
@@ -643,6 +667,17 @@ def register(
|
||||
|
||||
device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint)
|
||||
_store_device_key(cur, effective_guid, fingerprint)
|
||||
if site_id:
|
||||
assigned_at = int(time.time())
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO device_sites(device_hostname, site_id, assigned_at)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(device_hostname)
|
||||
DO UPDATE SET site_id=excluded.site_id, assigned_at=excluded.assigned_at
|
||||
""",
|
||||
(device_record.get("hostname"), site_id, assigned_at),
|
||||
)
|
||||
|
||||
# Mark install code used
|
||||
if enrollment_code_id:
|
||||
@@ -656,13 +691,14 @@ def register(
|
||||
except Exception:
|
||||
prior_count = 0
|
||||
try:
|
||||
allowed_uses = int(usage_row[1]) if usage_row else 1
|
||||
allowed_uses = int(usage_row[1]) if usage_row else 0
|
||||
except Exception:
|
||||
allowed_uses = 1
|
||||
allowed_uses = 0
|
||||
unlimited = allowed_uses <= 0
|
||||
if allowed_uses < 1:
|
||||
allowed_uses = 1
|
||||
new_count = prior_count + 1
|
||||
consumed = new_count >= allowed_uses
|
||||
consumed = False if unlimited else new_count >= allowed_uses
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE enrollment_install_codes
|
||||
@@ -767,4 +803,3 @@ def _mask_code(code: str) -> str:
|
||||
if len(trimmed) <= 6:
|
||||
return "***"
|
||||
return f"{trimmed[:3]}***{trimmed[-3:]}"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user