mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-27 05:21:57 -06:00
Support multi-use installer codes and reuse
This commit is contained in:
@@ -54,18 +54,27 @@ def register(
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
sql = """
|
||||
SELECT id, code, expires_at, created_by_user_id, used_at, used_by_guid
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
created_by_user_id,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
"""
|
||||
params: List[str] = []
|
||||
now_iso = _iso(_now())
|
||||
if status_filter == "active":
|
||||
sql += " WHERE used_at IS NULL AND expires_at > ?"
|
||||
params.append(_iso(_now()))
|
||||
sql += " WHERE use_count < max_uses AND expires_at > ?"
|
||||
params.append(now_iso)
|
||||
elif status_filter == "expired":
|
||||
sql += " WHERE used_at IS NULL AND expires_at <= ?"
|
||||
params.append(_iso(_now()))
|
||||
sql += " WHERE use_count < max_uses AND expires_at <= ?"
|
||||
params.append(now_iso)
|
||||
elif status_filter == "used":
|
||||
sql += " WHERE used_at IS NOT NULL"
|
||||
sql += " WHERE use_count >= max_uses"
|
||||
sql += " ORDER BY expires_at ASC"
|
||||
cur.execute(sql, params)
|
||||
rows = cur.fetchall()
|
||||
@@ -82,6 +91,9 @@ def register(
|
||||
"created_by_user_id": row[3],
|
||||
"used_at": row[4],
|
||||
"used_by_guid": row[5],
|
||||
"max_uses": row[6],
|
||||
"use_count": row[7],
|
||||
"last_used_at": row[8],
|
||||
}
|
||||
)
|
||||
return jsonify({"codes": records})
|
||||
@@ -93,6 +105,18 @@ def register(
|
||||
if ttl_hours not in VALID_TTL_HOURS:
|
||||
return jsonify({"error": "invalid_ttl"}), 400
|
||||
|
||||
max_uses_value = payload.get("max_uses")
|
||||
if max_uses_value is None:
|
||||
max_uses_value = payload.get("allowed_uses")
|
||||
try:
|
||||
max_uses = int(max_uses_value)
|
||||
except Exception:
|
||||
max_uses = 2
|
||||
if max_uses < 1:
|
||||
max_uses = 1
|
||||
if max_uses > 10:
|
||||
max_uses = 10
|
||||
|
||||
user = current_user() or {}
|
||||
username = user.get("username") or ""
|
||||
|
||||
@@ -106,22 +130,28 @@ def register(
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO enrollment_install_codes (
|
||||
id, code, expires_at, created_by_user_id
|
||||
id, code, expires_at, created_by_user_id, max_uses, use_count
|
||||
)
|
||||
VALUES (?, ?, ?, ?)
|
||||
VALUES (?, ?, ?, ?, ?, 0)
|
||||
""",
|
||||
(record_id, code_value, _iso(expires_at), created_by),
|
||||
(record_id, code_value, _iso(expires_at), created_by, max_uses),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
log("server", f"installer code created id={record_id} by={username} ttl={ttl_hours}h")
|
||||
log(
|
||||
"server",
|
||||
f"installer code created id={record_id} by={username} ttl={ttl_hours}h max_uses={max_uses}",
|
||||
)
|
||||
return jsonify(
|
||||
{
|
||||
"id": record_id,
|
||||
"code": code_value,
|
||||
"expires_at": _iso(expires_at),
|
||||
"max_uses": max_uses,
|
||||
"use_count": 0,
|
||||
"last_used_at": None,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -131,7 +161,7 @@ def register(
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"DELETE FROM enrollment_install_codes WHERE id = ? AND used_at IS NULL",
|
||||
"DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
|
||||
(code_id,),
|
||||
)
|
||||
deleted = cur.rowcount
|
||||
|
||||
@@ -152,7 +152,10 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
|
||||
expires_at TEXT NOT NULL,
|
||||
created_by_user_id TEXT,
|
||||
used_at TEXT,
|
||||
used_by_guid TEXT
|
||||
used_by_guid TEXT,
|
||||
max_uses INTEGER NOT NULL DEFAULT 1,
|
||||
use_count INTEGER NOT NULL DEFAULT 0,
|
||||
last_used_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
@@ -163,6 +166,29 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
)
|
||||
|
||||
columns = {row[1] for row in _table_info(cur, "enrollment_install_codes")}
|
||||
if "max_uses" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes
|
||||
ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 1
|
||||
"""
|
||||
)
|
||||
if "use_count" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes
|
||||
ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0
|
||||
"""
|
||||
)
|
||||
if "last_used_at" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes
|
||||
ADD COLUMN last_used_at TEXT
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
|
||||
@@ -6,7 +6,7 @@ import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
from flask import Blueprint, jsonify, request
|
||||
|
||||
@@ -66,31 +66,79 @@ def register(
|
||||
|
||||
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
|
||||
cur.execute(
|
||||
"SELECT id, code, expires_at, used_at FROM enrollment_install_codes WHERE code = ?",
|
||||
"""
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
WHERE code = ?
|
||||
""",
|
||||
(code_value,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
keys = ["id", "code", "expires_at", "used_at"]
|
||||
keys = [
|
||||
"id",
|
||||
"code",
|
||||
"expires_at",
|
||||
"used_at",
|
||||
"used_by_guid",
|
||||
"max_uses",
|
||||
"use_count",
|
||||
"last_used_at",
|
||||
]
|
||||
record = dict(zip(keys, row))
|
||||
return record
|
||||
|
||||
def _install_code_valid(record: Dict[str, Any]) -> bool:
|
||||
def _install_code_valid(
|
||||
record: Dict[str, Any], fingerprint: str, cur: sqlite3.Cursor
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
if not record:
|
||||
return False
|
||||
return False, None
|
||||
expires_at = record.get("expires_at")
|
||||
if not isinstance(expires_at, str):
|
||||
return False
|
||||
return False, None
|
||||
try:
|
||||
expiry = datetime.fromisoformat(expires_at)
|
||||
except Exception:
|
||||
return False
|
||||
return False, None
|
||||
if expiry <= _now():
|
||||
return False
|
||||
if record.get("used_at"):
|
||||
return False
|
||||
return True
|
||||
return False, None
|
||||
try:
|
||||
max_uses = int(record.get("max_uses") or 1)
|
||||
except Exception:
|
||||
max_uses = 1
|
||||
if max_uses < 1:
|
||||
max_uses = 1
|
||||
try:
|
||||
use_count = int(record.get("use_count") or 0)
|
||||
except Exception:
|
||||
use_count = 0
|
||||
if use_count < max_uses:
|
||||
return True, None
|
||||
|
||||
guid = str(record.get("used_by_guid") or "").strip()
|
||||
if not guid:
|
||||
return False, None
|
||||
cur.execute(
|
||||
"SELECT ssl_key_fingerprint FROM devices WHERE guid = ?",
|
||||
(guid,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return False, None
|
||||
stored_fp = (row[0] or "").strip().lower()
|
||||
if not stored_fp:
|
||||
return False, None
|
||||
if stored_fp == (fingerprint or "").strip().lower():
|
||||
return True, guid
|
||||
return False, None
|
||||
|
||||
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
|
||||
base = (hostname or "").strip() or guid
|
||||
@@ -305,7 +353,13 @@ def register(
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
install_code = _load_install_code(cur, enrollment_code)
|
||||
if not _install_code_valid(install_code):
|
||||
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
|
||||
if not valid_code:
|
||||
log(
|
||||
"server",
|
||||
"enrollment request invalid_code "
|
||||
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
|
||||
)
|
||||
return jsonify({"error": "invalid_enrollment_code"}), 400
|
||||
|
||||
approval_reference: str
|
||||
@@ -331,6 +385,7 @@ def register(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET hostname_claimed = ?,
|
||||
guid = ?,
|
||||
enrollment_code_id = ?,
|
||||
client_nonce = ?,
|
||||
server_nonce = ?,
|
||||
@@ -340,6 +395,7 @@ def register(
|
||||
""",
|
||||
(
|
||||
hostname,
|
||||
reuse_guid,
|
||||
install_code["id"],
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
@@ -359,11 +415,12 @@ def register(
|
||||
status, client_nonce, server_nonce, agent_pubkey_der,
|
||||
created_at, updated_at
|
||||
)
|
||||
VALUES (?, ?, NULL, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
record_id,
|
||||
approval_reference,
|
||||
reuse_guid,
|
||||
hostname,
|
||||
fingerprint,
|
||||
install_code["id"],
|
||||
@@ -537,14 +594,40 @@ def register(
|
||||
|
||||
# Mark install code used
|
||||
if enrollment_code_id:
|
||||
cur.execute(
|
||||
"SELECT use_count, max_uses FROM enrollment_install_codes WHERE id = ?",
|
||||
(enrollment_code_id,),
|
||||
)
|
||||
usage_row = cur.fetchone()
|
||||
try:
|
||||
prior_count = int(usage_row[0]) if usage_row else 0
|
||||
except Exception:
|
||||
prior_count = 0
|
||||
try:
|
||||
allowed_uses = int(usage_row[1]) if usage_row else 1
|
||||
except Exception:
|
||||
allowed_uses = 1
|
||||
if allowed_uses < 1:
|
||||
allowed_uses = 1
|
||||
new_count = prior_count + 1
|
||||
consumed = new_count >= allowed_uses
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE enrollment_install_codes
|
||||
SET used_at = ?, used_by_guid = ?
|
||||
SET use_count = ?,
|
||||
used_by_guid = ?,
|
||||
last_used_at = ?,
|
||||
used_at = CASE WHEN ? THEN ? ELSE used_at END
|
||||
WHERE id = ?
|
||||
AND used_at IS NULL
|
||||
""",
|
||||
(now_iso, effective_guid, enrollment_code_id),
|
||||
(
|
||||
new_count,
|
||||
effective_guid,
|
||||
now_iso,
|
||||
1 if consumed else 0,
|
||||
now_iso,
|
||||
enrollment_code_id,
|
||||
),
|
||||
)
|
||||
|
||||
# Update approval record with final state
|
||||
|
||||
@@ -34,7 +34,7 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
|
||||
cur.execute(
|
||||
"""
|
||||
DELETE FROM enrollment_install_codes
|
||||
WHERE used_at IS NULL
|
||||
WHERE use_count = 0
|
||||
AND expires_at < ?
|
||||
""",
|
||||
(now_iso,),
|
||||
@@ -52,7 +52,10 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
|
||||
SELECT 1
|
||||
FROM enrollment_install_codes c
|
||||
WHERE c.id = device_approvals.enrollment_code_id
|
||||
AND c.expires_at < ?
|
||||
AND (
|
||||
c.expires_at < ?
|
||||
OR c.use_count >= c.max_uses
|
||||
)
|
||||
)
|
||||
OR created_at < ?
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user