mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 23:41:58 -06:00
Support multi-use installer codes and reuse
This commit is contained in:
@@ -54,18 +54,27 @@ def register(
|
|||||||
try:
|
try:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
sql = """
|
sql = """
|
||||||
SELECT id, code, expires_at, created_by_user_id, used_at, used_by_guid
|
SELECT id,
|
||||||
|
code,
|
||||||
|
expires_at,
|
||||||
|
created_by_user_id,
|
||||||
|
used_at,
|
||||||
|
used_by_guid,
|
||||||
|
max_uses,
|
||||||
|
use_count,
|
||||||
|
last_used_at
|
||||||
FROM enrollment_install_codes
|
FROM enrollment_install_codes
|
||||||
"""
|
"""
|
||||||
params: List[str] = []
|
params: List[str] = []
|
||||||
|
now_iso = _iso(_now())
|
||||||
if status_filter == "active":
|
if status_filter == "active":
|
||||||
sql += " WHERE used_at IS NULL AND expires_at > ?"
|
sql += " WHERE use_count < max_uses AND expires_at > ?"
|
||||||
params.append(_iso(_now()))
|
params.append(now_iso)
|
||||||
elif status_filter == "expired":
|
elif status_filter == "expired":
|
||||||
sql += " WHERE used_at IS NULL AND expires_at <= ?"
|
sql += " WHERE use_count < max_uses AND expires_at <= ?"
|
||||||
params.append(_iso(_now()))
|
params.append(now_iso)
|
||||||
elif status_filter == "used":
|
elif status_filter == "used":
|
||||||
sql += " WHERE used_at IS NOT NULL"
|
sql += " WHERE use_count >= max_uses"
|
||||||
sql += " ORDER BY expires_at ASC"
|
sql += " ORDER BY expires_at ASC"
|
||||||
cur.execute(sql, params)
|
cur.execute(sql, params)
|
||||||
rows = cur.fetchall()
|
rows = cur.fetchall()
|
||||||
@@ -82,6 +91,9 @@ def register(
|
|||||||
"created_by_user_id": row[3],
|
"created_by_user_id": row[3],
|
||||||
"used_at": row[4],
|
"used_at": row[4],
|
||||||
"used_by_guid": row[5],
|
"used_by_guid": row[5],
|
||||||
|
"max_uses": row[6],
|
||||||
|
"use_count": row[7],
|
||||||
|
"last_used_at": row[8],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return jsonify({"codes": records})
|
return jsonify({"codes": records})
|
||||||
@@ -93,6 +105,18 @@ def register(
|
|||||||
if ttl_hours not in VALID_TTL_HOURS:
|
if ttl_hours not in VALID_TTL_HOURS:
|
||||||
return jsonify({"error": "invalid_ttl"}), 400
|
return jsonify({"error": "invalid_ttl"}), 400
|
||||||
|
|
||||||
|
max_uses_value = payload.get("max_uses")
|
||||||
|
if max_uses_value is None:
|
||||||
|
max_uses_value = payload.get("allowed_uses")
|
||||||
|
try:
|
||||||
|
max_uses = int(max_uses_value)
|
||||||
|
except Exception:
|
||||||
|
max_uses = 2
|
||||||
|
if max_uses < 1:
|
||||||
|
max_uses = 1
|
||||||
|
if max_uses > 10:
|
||||||
|
max_uses = 10
|
||||||
|
|
||||||
user = current_user() or {}
|
user = current_user() or {}
|
||||||
username = user.get("username") or ""
|
username = user.get("username") or ""
|
||||||
|
|
||||||
@@ -106,22 +130,28 @@ def register(
|
|||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO enrollment_install_codes (
|
INSERT INTO enrollment_install_codes (
|
||||||
id, code, expires_at, created_by_user_id
|
id, code, expires_at, created_by_user_id, max_uses, use_count
|
||||||
)
|
)
|
||||||
VALUES (?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, 0)
|
||||||
""",
|
""",
|
||||||
(record_id, code_value, _iso(expires_at), created_by),
|
(record_id, code_value, _iso(expires_at), created_by, max_uses),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
log("server", f"installer code created id={record_id} by={username} ttl={ttl_hours}h")
|
log(
|
||||||
|
"server",
|
||||||
|
f"installer code created id={record_id} by={username} ttl={ttl_hours}h max_uses={max_uses}",
|
||||||
|
)
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
"id": record_id,
|
"id": record_id,
|
||||||
"code": code_value,
|
"code": code_value,
|
||||||
"expires_at": _iso(expires_at),
|
"expires_at": _iso(expires_at),
|
||||||
|
"max_uses": max_uses,
|
||||||
|
"use_count": 0,
|
||||||
|
"last_used_at": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -131,7 +161,7 @@ def register(
|
|||||||
try:
|
try:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"DELETE FROM enrollment_install_codes WHERE id = ? AND used_at IS NULL",
|
"DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
|
||||||
(code_id,),
|
(code_id,),
|
||||||
)
|
)
|
||||||
deleted = cur.rowcount
|
deleted = cur.rowcount
|
||||||
|
|||||||
@@ -152,7 +152,10 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
|
|||||||
expires_at TEXT NOT NULL,
|
expires_at TEXT NOT NULL,
|
||||||
created_by_user_id TEXT,
|
created_by_user_id TEXT,
|
||||||
used_at TEXT,
|
used_at TEXT,
|
||||||
used_by_guid TEXT
|
used_by_guid TEXT,
|
||||||
|
max_uses INTEGER NOT NULL DEFAULT 1,
|
||||||
|
use_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
last_used_at TEXT
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
@@ -163,6 +166,29 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
columns = {row[1] for row in _table_info(cur, "enrollment_install_codes")}
|
||||||
|
if "max_uses" not in columns:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
ALTER TABLE enrollment_install_codes
|
||||||
|
ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
if "use_count" not in columns:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
ALTER TABLE enrollment_install_codes
|
||||||
|
ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
if "last_used_at" not in columns:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
ALTER TABLE enrollment_install_codes
|
||||||
|
ADD COLUMN last_used_at TEXT
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
|
def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import sqlite3
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
from flask import Blueprint, jsonify, request
|
from flask import Blueprint, jsonify, request
|
||||||
|
|
||||||
@@ -66,31 +66,79 @@ def register(
|
|||||||
|
|
||||||
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
|
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"SELECT id, code, expires_at, used_at FROM enrollment_install_codes WHERE code = ?",
|
"""
|
||||||
|
SELECT id,
|
||||||
|
code,
|
||||||
|
expires_at,
|
||||||
|
used_at,
|
||||||
|
used_by_guid,
|
||||||
|
max_uses,
|
||||||
|
use_count,
|
||||||
|
last_used_at
|
||||||
|
FROM enrollment_install_codes
|
||||||
|
WHERE code = ?
|
||||||
|
""",
|
||||||
(code_value,),
|
(code_value,),
|
||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
keys = ["id", "code", "expires_at", "used_at"]
|
keys = [
|
||||||
|
"id",
|
||||||
|
"code",
|
||||||
|
"expires_at",
|
||||||
|
"used_at",
|
||||||
|
"used_by_guid",
|
||||||
|
"max_uses",
|
||||||
|
"use_count",
|
||||||
|
"last_used_at",
|
||||||
|
]
|
||||||
record = dict(zip(keys, row))
|
record = dict(zip(keys, row))
|
||||||
return record
|
return record
|
||||||
|
|
||||||
def _install_code_valid(record: Dict[str, Any]) -> bool:
|
def _install_code_valid(
|
||||||
|
record: Dict[str, Any], fingerprint: str, cur: sqlite3.Cursor
|
||||||
|
) -> Tuple[bool, Optional[str]]:
|
||||||
if not record:
|
if not record:
|
||||||
return False
|
return False, None
|
||||||
expires_at = record.get("expires_at")
|
expires_at = record.get("expires_at")
|
||||||
if not isinstance(expires_at, str):
|
if not isinstance(expires_at, str):
|
||||||
return False
|
return False, None
|
||||||
try:
|
try:
|
||||||
expiry = datetime.fromisoformat(expires_at)
|
expiry = datetime.fromisoformat(expires_at)
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False, None
|
||||||
if expiry <= _now():
|
if expiry <= _now():
|
||||||
return False
|
return False, None
|
||||||
if record.get("used_at"):
|
try:
|
||||||
return False
|
max_uses = int(record.get("max_uses") or 1)
|
||||||
return True
|
except Exception:
|
||||||
|
max_uses = 1
|
||||||
|
if max_uses < 1:
|
||||||
|
max_uses = 1
|
||||||
|
try:
|
||||||
|
use_count = int(record.get("use_count") or 0)
|
||||||
|
except Exception:
|
||||||
|
use_count = 0
|
||||||
|
if use_count < max_uses:
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
guid = str(record.get("used_by_guid") or "").strip()
|
||||||
|
if not guid:
|
||||||
|
return False, None
|
||||||
|
cur.execute(
|
||||||
|
"SELECT ssl_key_fingerprint FROM devices WHERE guid = ?",
|
||||||
|
(guid,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
if not row:
|
||||||
|
return False, None
|
||||||
|
stored_fp = (row[0] or "").strip().lower()
|
||||||
|
if not stored_fp:
|
||||||
|
return False, None
|
||||||
|
if stored_fp == (fingerprint or "").strip().lower():
|
||||||
|
return True, guid
|
||||||
|
return False, None
|
||||||
|
|
||||||
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
|
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
|
||||||
base = (hostname or "").strip() or guid
|
base = (hostname or "").strip() or guid
|
||||||
@@ -305,7 +353,13 @@ def register(
|
|||||||
try:
|
try:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
install_code = _load_install_code(cur, enrollment_code)
|
install_code = _load_install_code(cur, enrollment_code)
|
||||||
if not _install_code_valid(install_code):
|
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
|
||||||
|
if not valid_code:
|
||||||
|
log(
|
||||||
|
"server",
|
||||||
|
"enrollment request invalid_code "
|
||||||
|
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
|
||||||
|
)
|
||||||
return jsonify({"error": "invalid_enrollment_code"}), 400
|
return jsonify({"error": "invalid_enrollment_code"}), 400
|
||||||
|
|
||||||
approval_reference: str
|
approval_reference: str
|
||||||
@@ -331,6 +385,7 @@ def register(
|
|||||||
"""
|
"""
|
||||||
UPDATE device_approvals
|
UPDATE device_approvals
|
||||||
SET hostname_claimed = ?,
|
SET hostname_claimed = ?,
|
||||||
|
guid = ?,
|
||||||
enrollment_code_id = ?,
|
enrollment_code_id = ?,
|
||||||
client_nonce = ?,
|
client_nonce = ?,
|
||||||
server_nonce = ?,
|
server_nonce = ?,
|
||||||
@@ -340,6 +395,7 @@ def register(
|
|||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
hostname,
|
hostname,
|
||||||
|
reuse_guid,
|
||||||
install_code["id"],
|
install_code["id"],
|
||||||
client_nonce_b64,
|
client_nonce_b64,
|
||||||
server_nonce_b64,
|
server_nonce_b64,
|
||||||
@@ -359,11 +415,12 @@ def register(
|
|||||||
status, client_nonce, server_nonce, agent_pubkey_der,
|
status, client_nonce, server_nonce, agent_pubkey_der,
|
||||||
created_at, updated_at
|
created_at, updated_at
|
||||||
)
|
)
|
||||||
VALUES (?, ?, NULL, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
record_id,
|
record_id,
|
||||||
approval_reference,
|
approval_reference,
|
||||||
|
reuse_guid,
|
||||||
hostname,
|
hostname,
|
||||||
fingerprint,
|
fingerprint,
|
||||||
install_code["id"],
|
install_code["id"],
|
||||||
@@ -537,14 +594,40 @@ def register(
|
|||||||
|
|
||||||
# Mark install code used
|
# Mark install code used
|
||||||
if enrollment_code_id:
|
if enrollment_code_id:
|
||||||
|
cur.execute(
|
||||||
|
"SELECT use_count, max_uses FROM enrollment_install_codes WHERE id = ?",
|
||||||
|
(enrollment_code_id,),
|
||||||
|
)
|
||||||
|
usage_row = cur.fetchone()
|
||||||
|
try:
|
||||||
|
prior_count = int(usage_row[0]) if usage_row else 0
|
||||||
|
except Exception:
|
||||||
|
prior_count = 0
|
||||||
|
try:
|
||||||
|
allowed_uses = int(usage_row[1]) if usage_row else 1
|
||||||
|
except Exception:
|
||||||
|
allowed_uses = 1
|
||||||
|
if allowed_uses < 1:
|
||||||
|
allowed_uses = 1
|
||||||
|
new_count = prior_count + 1
|
||||||
|
consumed = new_count >= allowed_uses
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE enrollment_install_codes
|
UPDATE enrollment_install_codes
|
||||||
SET used_at = ?, used_by_guid = ?
|
SET use_count = ?,
|
||||||
|
used_by_guid = ?,
|
||||||
|
last_used_at = ?,
|
||||||
|
used_at = CASE WHEN ? THEN ? ELSE used_at END
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
AND used_at IS NULL
|
|
||||||
""",
|
""",
|
||||||
(now_iso, effective_guid, enrollment_code_id),
|
(
|
||||||
|
new_count,
|
||||||
|
effective_guid,
|
||||||
|
now_iso,
|
||||||
|
1 if consumed else 0,
|
||||||
|
now_iso,
|
||||||
|
enrollment_code_id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update approval record with final state
|
# Update approval record with final state
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
|
|||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
DELETE FROM enrollment_install_codes
|
DELETE FROM enrollment_install_codes
|
||||||
WHERE used_at IS NULL
|
WHERE use_count = 0
|
||||||
AND expires_at < ?
|
AND expires_at < ?
|
||||||
""",
|
""",
|
||||||
(now_iso,),
|
(now_iso,),
|
||||||
@@ -52,7 +52,10 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
|
|||||||
SELECT 1
|
SELECT 1
|
||||||
FROM enrollment_install_codes c
|
FROM enrollment_install_codes c
|
||||||
WHERE c.id = device_approvals.enrollment_code_id
|
WHERE c.id = device_approvals.enrollment_code_id
|
||||||
AND c.expires_at < ?
|
AND (
|
||||||
|
c.expires_at < ?
|
||||||
|
OR c.use_count >= c.max_uses
|
||||||
|
)
|
||||||
)
|
)
|
||||||
OR created_at < ?
|
OR created_at < ?
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -65,7 +65,9 @@ const formatDateTime = (value) => {
|
|||||||
|
|
||||||
const determineStatus = (record) => {
|
const determineStatus = (record) => {
|
||||||
if (!record) return "expired";
|
if (!record) return "expired";
|
||||||
if (record.used_at) return "used";
|
const maxUses = Number.isFinite(record?.max_uses) ? record.max_uses : 1;
|
||||||
|
const useCount = Number.isFinite(record?.use_count) ? record.use_count : 0;
|
||||||
|
if (useCount >= Math.max(1, maxUses || 1)) return "used";
|
||||||
if (!record.expires_at) return "expired";
|
if (!record.expires_at) return "expired";
|
||||||
const expires = new Date(record.expires_at);
|
const expires = new Date(record.expires_at);
|
||||||
if (Number.isNaN(expires.getTime())) return "expired";
|
if (Number.isNaN(expires.getTime())) return "expired";
|
||||||
@@ -80,6 +82,7 @@ function EnrollmentCodes() {
|
|||||||
const [statusFilter, setStatusFilter] = useState("all");
|
const [statusFilter, setStatusFilter] = useState("all");
|
||||||
const [ttlHours, setTtlHours] = useState(6);
|
const [ttlHours, setTtlHours] = useState(6);
|
||||||
const [generating, setGenerating] = useState(false);
|
const [generating, setGenerating] = useState(false);
|
||||||
|
const [maxUses, setMaxUses] = useState(2);
|
||||||
|
|
||||||
const filteredCodes = useMemo(() => {
|
const filteredCodes = useMemo(() => {
|
||||||
if (statusFilter === "all") return codes;
|
if (statusFilter === "all") return codes;
|
||||||
@@ -119,7 +122,7 @@ function EnrollmentCodes() {
|
|||||||
method: "POST",
|
method: "POST",
|
||||||
credentials: "include",
|
credentials: "include",
|
||||||
headers: { "Content-Type": "application/json" },
|
headers: { "Content-Type": "application/json" },
|
||||||
body: JSON.stringify({ ttl_hours: ttlHours }),
|
body: JSON.stringify({ ttl_hours: ttlHours, max_uses: maxUses }),
|
||||||
});
|
});
|
||||||
if (!resp.ok) {
|
if (!resp.ok) {
|
||||||
const body = await resp.json().catch(() => ({}));
|
const body = await resp.json().catch(() => ({}));
|
||||||
@@ -133,7 +136,7 @@ function EnrollmentCodes() {
|
|||||||
} finally {
|
} finally {
|
||||||
setGenerating(false);
|
setGenerating(false);
|
||||||
}
|
}
|
||||||
}, [fetchCodes, ttlHours]);
|
}, [fetchCodes, ttlHours, maxUses]);
|
||||||
|
|
||||||
const handleDelete = useCallback(
|
const handleDelete = useCallback(
|
||||||
async (id) => {
|
async (id) => {
|
||||||
@@ -216,7 +219,7 @@ function EnrollmentCodes() {
|
|||||||
labelId="ttl-select-label"
|
labelId="ttl-select-label"
|
||||||
label="Duration"
|
label="Duration"
|
||||||
value={ttlHours}
|
value={ttlHours}
|
||||||
onChange={(event) => setTtlHours(event.target.value)}
|
onChange={(event) => setTtlHours(Number(event.target.value))}
|
||||||
>
|
>
|
||||||
{TTL_PRESETS.map((preset) => (
|
{TTL_PRESETS.map((preset) => (
|
||||||
<MenuItem key={preset.value} value={preset.value}>
|
<MenuItem key={preset.value} value={preset.value}>
|
||||||
@@ -226,6 +229,22 @@ function EnrollmentCodes() {
|
|||||||
</Select>
|
</Select>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
|
|
||||||
|
<FormControl size="small" sx={{ minWidth: 160 }}>
|
||||||
|
<InputLabel id="uses-select-label">Allowed Uses</InputLabel>
|
||||||
|
<Select
|
||||||
|
labelId="uses-select-label"
|
||||||
|
label="Allowed Uses"
|
||||||
|
value={maxUses}
|
||||||
|
onChange={(event) => setMaxUses(Number(event.target.value))}
|
||||||
|
>
|
||||||
|
{[1, 2, 3, 5].map((uses) => (
|
||||||
|
<MenuItem key={uses} value={uses}>
|
||||||
|
{uses === 1 ? "Single use" : `${uses} uses`}
|
||||||
|
</MenuItem>
|
||||||
|
))}
|
||||||
|
</Select>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
<Button
|
<Button
|
||||||
variant="contained"
|
variant="contained"
|
||||||
color="primary"
|
color="primary"
|
||||||
@@ -270,7 +289,9 @@ function EnrollmentCodes() {
|
|||||||
<TableCell>Installer Code</TableCell>
|
<TableCell>Installer Code</TableCell>
|
||||||
<TableCell>Expires At</TableCell>
|
<TableCell>Expires At</TableCell>
|
||||||
<TableCell>Created By</TableCell>
|
<TableCell>Created By</TableCell>
|
||||||
<TableCell>Used At</TableCell>
|
<TableCell>Usage</TableCell>
|
||||||
|
<TableCell>Last Used</TableCell>
|
||||||
|
<TableCell>Consumed At</TableCell>
|
||||||
<TableCell>Used By GUID</TableCell>
|
<TableCell>Used By GUID</TableCell>
|
||||||
<TableCell align="right">Actions</TableCell>
|
<TableCell align="right">Actions</TableCell>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
@@ -296,13 +317,17 @@ function EnrollmentCodes() {
|
|||||||
) : (
|
) : (
|
||||||
filteredCodes.map((record) => {
|
filteredCodes.map((record) => {
|
||||||
const status = determineStatus(record);
|
const status = determineStatus(record);
|
||||||
const disableDelete = status !== "active";
|
const maxAllowed = Math.max(1, Number.isFinite(record?.max_uses) ? record.max_uses : 1);
|
||||||
|
const usageCount = Math.max(0, Number.isFinite(record?.use_count) ? record.use_count : 0);
|
||||||
|
const disableDelete = usageCount !== 0;
|
||||||
return (
|
return (
|
||||||
<TableRow hover key={record.id}>
|
<TableRow hover key={record.id}>
|
||||||
<TableCell>{renderStatusChip(record)}</TableCell>
|
<TableCell>{renderStatusChip(record)}</TableCell>
|
||||||
<TableCell sx={{ fontFamily: "monospace" }}>{maskCode(record.code)}</TableCell>
|
<TableCell sx={{ fontFamily: "monospace" }}>{maskCode(record.code)}</TableCell>
|
||||||
<TableCell>{formatDateTime(record.expires_at)}</TableCell>
|
<TableCell>{formatDateTime(record.expires_at)}</TableCell>
|
||||||
<TableCell>{record.created_by_user_id || "—"}</TableCell>
|
<TableCell>{record.created_by_user_id || "—"}</TableCell>
|
||||||
|
<TableCell sx={{ fontFamily: "monospace" }}>{`${usageCount} / ${maxAllowed}`}</TableCell>
|
||||||
|
<TableCell>{formatDateTime(record.last_used_at)}</TableCell>
|
||||||
<TableCell>{formatDateTime(record.used_at)}</TableCell>
|
<TableCell>{formatDateTime(record.used_at)}</TableCell>
|
||||||
<TableCell sx={{ fontFamily: "monospace" }}>
|
<TableCell sx={{ fontFamily: "monospace" }}>
|
||||||
{record.used_by_guid || "—"}
|
{record.used_by_guid || "—"}
|
||||||
|
|||||||
213
tests/test_enrollment_install_codes.py
Normal file
213
tests/test_enrollment_install_codes.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
import base64
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import sqlite3
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
try: # pragma: no cover - optional dependency
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||||
|
_CRYPTO_IMPORT_ERROR: Exception | None = None
|
||||||
|
except Exception as exc: # pragma: no cover - dependency unavailable
|
||||||
|
serialization = None # type: ignore
|
||||||
|
ed25519 = None # type: ignore
|
||||||
|
_CRYPTO_IMPORT_ERROR = exc
|
||||||
|
try: # pragma: no cover - optional dependency
|
||||||
|
from flask import Flask
|
||||||
|
_FLASK_IMPORT_ERROR: Exception | None = None
|
||||||
|
except Exception as exc: # pragma: no cover - dependency unavailable
|
||||||
|
Flask = None # type: ignore
|
||||||
|
_FLASK_IMPORT_ERROR = exc
|
||||||
|
|
||||||
|
ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||||
|
if str(ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(ROOT))
|
||||||
|
|
||||||
|
from Data.Server.Modules import db_migrations
|
||||||
|
from Data.Server.Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||||
|
from Data.Server.Modules.enrollment.nonce_store import NonceCache
|
||||||
|
|
||||||
|
if Flask is not None: # pragma: no cover - dependency unavailable
|
||||||
|
from Data.Server.Modules.enrollment import routes as enrollment_routes
|
||||||
|
else: # pragma: no cover - dependency unavailable
|
||||||
|
enrollment_routes = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyJWTService:
|
||||||
|
def issue_access_token(self, guid: str, fingerprint: str, token_version: int, expires_in: int = 900, extra_claims=None):
|
||||||
|
return f"token-{guid}"
|
||||||
|
|
||||||
|
|
||||||
|
class _DummySigner:
|
||||||
|
def public_base64_spki(self) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app(db_path: str, tls_path: str):
|
||||||
|
if Flask is None or enrollment_routes is None: # pragma: no cover - dependency unavailable
|
||||||
|
pytest.skip(f"flask unavailable: {_FLASK_IMPORT_ERROR}")
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
def _factory():
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
enrollment_routes.register(
|
||||||
|
app,
|
||||||
|
db_conn_factory=_factory,
|
||||||
|
log=lambda channel, message: None,
|
||||||
|
jwt_service=_DummyJWTService(),
|
||||||
|
tls_bundle_path=tls_path,
|
||||||
|
ip_rate_limiter=SlidingWindowRateLimiter(),
|
||||||
|
fp_rate_limiter=SlidingWindowRateLimiter(),
|
||||||
|
nonce_cache=NonceCache(ttl_seconds=30.0),
|
||||||
|
script_signer=_DummySigner(),
|
||||||
|
)
|
||||||
|
return app, _factory
|
||||||
|
|
||||||
|
|
||||||
|
def _create_install_code(conn: sqlite3.Connection, code: str, *, max_uses: int = 2):
|
||||||
|
cur = conn.cursor()
|
||||||
|
record_id = str(uuid.uuid4())
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO enrollment_install_codes (
|
||||||
|
id, code, expires_at, created_by_user_id, max_uses, use_count
|
||||||
|
)
|
||||||
|
VALUES (?, ?, datetime('now', '+6 hours'), 'test-user', ?, 0)
|
||||||
|
""",
|
||||||
|
(record_id, code, max_uses),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return record_id
|
||||||
|
|
||||||
|
|
||||||
|
def _perform_enrollment_cycle(app, factory, code: str, private_key):
|
||||||
|
client = app.test_client()
|
||||||
|
public_der = private_key.public_key().public_bytes(
|
||||||
|
serialization.Encoding.DER,
|
||||||
|
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
public_b64 = base64.b64encode(public_der).decode("ascii")
|
||||||
|
client_nonce = os.urandom(32)
|
||||||
|
payload = {
|
||||||
|
"hostname": "unit-test-host",
|
||||||
|
"enrollment_code": code,
|
||||||
|
"agent_pubkey": public_b64,
|
||||||
|
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
|
||||||
|
}
|
||||||
|
request_resp = client.post("/api/agent/enroll/request", json=payload)
|
||||||
|
assert request_resp.status_code == 200
|
||||||
|
request_data = request_resp.get_json()
|
||||||
|
approval_reference = request_data["approval_reference"]
|
||||||
|
|
||||||
|
with factory() as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
UPDATE device_approvals
|
||||||
|
SET status = 'approved',
|
||||||
|
approved_by_user_id = 'tester'
|
||||||
|
WHERE approval_reference = ?
|
||||||
|
""",
|
||||||
|
(approval_reference,),
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT server_nonce, client_nonce
|
||||||
|
FROM device_approvals
|
||||||
|
WHERE approval_reference = ?
|
||||||
|
""",
|
||||||
|
(approval_reference,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
server_nonce_b64 = row["server_nonce"]
|
||||||
|
|
||||||
|
server_nonce = base64.b64decode(server_nonce_b64)
|
||||||
|
proof_message = server_nonce + approval_reference.encode("utf-8") + client_nonce
|
||||||
|
proof_sig = private_key.sign(proof_message)
|
||||||
|
poll_payload = {
|
||||||
|
"approval_reference": approval_reference,
|
||||||
|
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
|
||||||
|
"proof_sig": base64.b64encode(proof_sig).decode("ascii"),
|
||||||
|
}
|
||||||
|
poll_resp = client.post("/api/agent/enroll/poll", json=poll_payload)
|
||||||
|
assert poll_resp.status_code == 200
|
||||||
|
return poll_resp.get_json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("max_uses", [2])
|
||||||
|
@pytest.mark.skipif(ed25519 is None, reason=f"cryptography unavailable: {_CRYPTO_IMPORT_ERROR}")
|
||||||
|
@pytest.mark.skipif(Flask is None, reason=f"flask unavailable: {_FLASK_IMPORT_ERROR}")
|
||||||
|
def test_install_code_allows_multiple_and_reuse(tmp_path, max_uses):
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
db_migrations.apply_all(conn)
|
||||||
|
_create_install_code(conn, "TEST-CODE-1234", max_uses=max_uses)
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
tls_path = tmp_path / "tls.pem"
|
||||||
|
tls_path.write_text("TEST CERT")
|
||||||
|
|
||||||
|
app, factory = _make_app(str(db_path), str(tls_path))
|
||||||
|
|
||||||
|
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||||
|
|
||||||
|
# First enrollment consumes one use but keeps the code active.
|
||||||
|
first = _perform_enrollment_cycle(app, factory, "TEST-CODE-1234", private_key)
|
||||||
|
assert first["status"] == "approved"
|
||||||
|
|
||||||
|
with factory() as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"SELECT use_count, max_uses, used_at, last_used_at FROM enrollment_install_codes WHERE code = ?",
|
||||||
|
("TEST-CODE-1234",),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row["use_count"] == 1
|
||||||
|
assert row["max_uses"] == max_uses
|
||||||
|
assert row["used_at"] is None
|
||||||
|
assert row["last_used_at"] is not None
|
||||||
|
|
||||||
|
# Second enrollment hits the configured max uses.
|
||||||
|
second = _perform_enrollment_cycle(app, factory, "TEST-CODE-1234", private_key)
|
||||||
|
assert second["status"] == "approved"
|
||||||
|
|
||||||
|
with factory() as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"SELECT use_count, used_at, last_used_at, used_by_guid FROM enrollment_install_codes WHERE code = ?",
|
||||||
|
("TEST-CODE-1234",),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row["use_count"] == max_uses
|
||||||
|
assert row["used_at"] is not None
|
||||||
|
assert row["last_used_at"] is not None
|
||||||
|
consumed_guid = row["used_by_guid"]
|
||||||
|
assert consumed_guid
|
||||||
|
|
||||||
|
# Additional enrollments from the same identity reuse the stored GUID even after consumption.
|
||||||
|
third = _perform_enrollment_cycle(app, factory, "TEST-CODE-1234", private_key)
|
||||||
|
assert third["status"] == "approved"
|
||||||
|
|
||||||
|
with factory() as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"SELECT use_count, used_at, last_used_at, used_by_guid FROM enrollment_install_codes WHERE code = ?",
|
||||||
|
("TEST-CODE-1234",),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row["use_count"] == max_uses + 1
|
||||||
|
assert row["used_by_guid"] == consumed_guid
|
||||||
|
assert row["used_at"] is not None
|
||||||
|
assert row["last_used_at"] is not None
|
||||||
|
|
||||||
|
cur.execute("SELECT COUNT(*) FROM devices WHERE guid = ?", (consumed_guid,))
|
||||||
|
assert cur.fetchone()[0] == 1
|
||||||
Reference in New Issue
Block a user