mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-12-16 11:25:48 -07:00
Removed Legacy Server Codebase
This commit is contained in:
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,496 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from flask import Blueprint, jsonify, request
|
||||
|
||||
from Modules.guid_utils import normalize_guid
|
||||
|
||||
|
||||
VALID_TTL_HOURS = {1, 3, 6, 12, 24}
|
||||
|
||||
|
||||
def register(
|
||||
app,
|
||||
*,
|
||||
db_conn_factory: Callable[[], sqlite3.Connection],
|
||||
require_admin: Callable[[], Optional[Any]],
|
||||
current_user: Callable[[], Optional[Dict[str, str]]],
|
||||
log: Callable[[str, str, Optional[str]], None],
|
||||
) -> None:
|
||||
blueprint = Blueprint("admin", __name__)
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
|
||||
def _iso(dt: datetime) -> str:
|
||||
return dt.isoformat()
|
||||
|
||||
def _lookup_user_id(cur: sqlite3.Cursor, username: str) -> Optional[str]:
|
||||
if not username:
|
||||
return None
|
||||
cur.execute(
|
||||
"SELECT id FROM users WHERE LOWER(username) = LOWER(?)",
|
||||
(username,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
return str(row[0])
|
||||
return None
|
||||
|
||||
def _hostname_conflict(
|
||||
cur: sqlite3.Cursor,
|
||||
hostname: Optional[str],
|
||||
pending_guid: Optional[str],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if not hostname:
|
||||
return None
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT d.guid, d.ssl_key_fingerprint, ds.site_id, s.name
|
||||
FROM devices d
|
||||
LEFT JOIN device_sites ds ON ds.device_hostname = d.hostname
|
||||
LEFT JOIN sites s ON s.id = ds.site_id
|
||||
WHERE d.hostname = ?
|
||||
""",
|
||||
(hostname,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
existing_guid = normalize_guid(row[0])
|
||||
existing_fingerprint = (row[1] or "").strip().lower()
|
||||
pending_norm = normalize_guid(pending_guid)
|
||||
if existing_guid and pending_norm and existing_guid == pending_norm:
|
||||
return None
|
||||
site_id_raw = row[2]
|
||||
site_id = None
|
||||
if site_id_raw is not None:
|
||||
try:
|
||||
site_id = int(site_id_raw)
|
||||
except (TypeError, ValueError):
|
||||
site_id = None
|
||||
site_name = row[3] or ""
|
||||
return {
|
||||
"guid": existing_guid or None,
|
||||
"ssl_key_fingerprint": existing_fingerprint or None,
|
||||
"site_id": site_id,
|
||||
"site_name": site_name,
|
||||
}
|
||||
|
||||
def _suggest_alternate_hostname(
|
||||
cur: sqlite3.Cursor,
|
||||
hostname: Optional[str],
|
||||
pending_guid: Optional[str],
|
||||
) -> Optional[str]:
|
||||
base = (hostname or "").strip()
|
||||
if not base:
|
||||
return None
|
||||
base = base[:253]
|
||||
candidate = base
|
||||
pending_norm = normalize_guid(pending_guid)
|
||||
suffix = 1
|
||||
while True:
|
||||
cur.execute(
|
||||
"SELECT guid FROM devices WHERE hostname = ?",
|
||||
(candidate,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return candidate
|
||||
existing_guid = normalize_guid(row[0])
|
||||
if pending_norm and existing_guid == pending_norm:
|
||||
return candidate
|
||||
candidate = f"{base}-{suffix}"
|
||||
suffix += 1
|
||||
if suffix > 50:
|
||||
return pending_norm or candidate
|
||||
|
||||
@blueprint.before_request
|
||||
def _check_admin():
|
||||
result = require_admin()
|
||||
if result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
@blueprint.route("/api/admin/enrollment-codes", methods=["GET"])
|
||||
def list_enrollment_codes():
|
||||
status_filter = request.args.get("status")
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
sql = """
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
created_by_user_id,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
"""
|
||||
params: List[str] = []
|
||||
now_iso = _iso(_now())
|
||||
if status_filter == "active":
|
||||
sql += " WHERE use_count < max_uses AND expires_at > ?"
|
||||
params.append(now_iso)
|
||||
elif status_filter == "expired":
|
||||
sql += " WHERE use_count < max_uses AND expires_at <= ?"
|
||||
params.append(now_iso)
|
||||
elif status_filter == "used":
|
||||
sql += " WHERE use_count >= max_uses"
|
||||
sql += " ORDER BY expires_at ASC"
|
||||
cur.execute(sql, params)
|
||||
rows = cur.fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
records = []
|
||||
for row in rows:
|
||||
records.append(
|
||||
{
|
||||
"id": row[0],
|
||||
"code": row[1],
|
||||
"expires_at": row[2],
|
||||
"created_by_user_id": row[3],
|
||||
"used_at": row[4],
|
||||
"used_by_guid": row[5],
|
||||
"max_uses": row[6],
|
||||
"use_count": row[7],
|
||||
"last_used_at": row[8],
|
||||
}
|
||||
)
|
||||
return jsonify({"codes": records})
|
||||
|
||||
@blueprint.route("/api/admin/enrollment-codes", methods=["POST"])
|
||||
def create_enrollment_code():
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
ttl_hours = int(payload.get("ttl_hours") or 1)
|
||||
if ttl_hours not in VALID_TTL_HOURS:
|
||||
return jsonify({"error": "invalid_ttl"}), 400
|
||||
|
||||
max_uses_value = payload.get("max_uses")
|
||||
if max_uses_value is None:
|
||||
max_uses_value = payload.get("allowed_uses")
|
||||
try:
|
||||
max_uses = int(max_uses_value)
|
||||
except Exception:
|
||||
max_uses = 2
|
||||
if max_uses < 1:
|
||||
max_uses = 1
|
||||
if max_uses > 10:
|
||||
max_uses = 10
|
||||
|
||||
user = current_user() or {}
|
||||
username = user.get("username") or ""
|
||||
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
created_by = _lookup_user_id(cur, username) or username or "system"
|
||||
code_value = _generate_install_code()
|
||||
issued_at = _now()
|
||||
expires_at = issued_at + timedelta(hours=ttl_hours)
|
||||
record_id = str(uuid.uuid4())
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO enrollment_install_codes (
|
||||
id, code, expires_at, created_by_user_id, max_uses, use_count
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, 0)
|
||||
""",
|
||||
(record_id, code_value, _iso(expires_at), created_by, max_uses),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO enrollment_install_codes_persistent (
|
||||
id,
|
||||
code,
|
||||
created_at,
|
||||
expires_at,
|
||||
created_by_user_id,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
last_known_use_count,
|
||||
last_used_at,
|
||||
is_active,
|
||||
archived_at,
|
||||
consumed_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, NULL, NULL, ?, 0, NULL, 1, NULL, NULL)
|
||||
ON CONFLICT(id) DO UPDATE
|
||||
SET code = excluded.code,
|
||||
created_at = excluded.created_at,
|
||||
expires_at = excluded.expires_at,
|
||||
created_by_user_id = excluded.created_by_user_id,
|
||||
max_uses = excluded.max_uses,
|
||||
last_known_use_count = 0,
|
||||
used_at = NULL,
|
||||
used_by_guid = NULL,
|
||||
last_used_at = NULL,
|
||||
is_active = 1,
|
||||
archived_at = NULL,
|
||||
consumed_at = NULL
|
||||
""",
|
||||
(record_id, code_value, _iso(issued_at), _iso(expires_at), created_by, max_uses),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
log(
|
||||
"server",
|
||||
f"installer code created id={record_id} by={username} ttl={ttl_hours}h max_uses={max_uses}",
|
||||
)
|
||||
return jsonify(
|
||||
{
|
||||
"id": record_id,
|
||||
"code": code_value,
|
||||
"expires_at": _iso(expires_at),
|
||||
"max_uses": max_uses,
|
||||
"use_count": 0,
|
||||
"last_used_at": None,
|
||||
}
|
||||
)
|
||||
|
||||
@blueprint.route("/api/admin/enrollment-codes/<code_id>", methods=["DELETE"])
|
||||
def delete_enrollment_code(code_id: str):
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
|
||||
(code_id,),
|
||||
)
|
||||
deleted = cur.rowcount
|
||||
if deleted:
|
||||
archive_ts = _iso(_now())
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE enrollment_install_codes_persistent
|
||||
SET is_active = 0,
|
||||
archived_at = COALESCE(archived_at, ?)
|
||||
WHERE id = ?
|
||||
""",
|
||||
(archive_ts, code_id),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if not deleted:
|
||||
return jsonify({"error": "not_found"}), 404
|
||||
log("server", f"installer code deleted id={code_id}")
|
||||
return jsonify({"status": "deleted"})
|
||||
|
||||
@blueprint.route("/api/admin/device-approvals", methods=["GET"])
|
||||
def list_device_approvals():
|
||||
status_raw = request.args.get("status")
|
||||
status = (status_raw or "").strip().lower()
|
||||
approvals: List[Dict[str, Any]] = []
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
params: List[str] = []
|
||||
sql = """
|
||||
SELECT
|
||||
da.id,
|
||||
da.approval_reference,
|
||||
da.guid,
|
||||
da.hostname_claimed,
|
||||
da.ssl_key_fingerprint_claimed,
|
||||
da.enrollment_code_id,
|
||||
da.status,
|
||||
da.client_nonce,
|
||||
da.server_nonce,
|
||||
da.created_at,
|
||||
da.updated_at,
|
||||
da.approved_by_user_id,
|
||||
u.username AS approved_by_username
|
||||
FROM device_approvals AS da
|
||||
LEFT JOIN users AS u
|
||||
ON (
|
||||
CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT)
|
||||
OR LOWER(da.approved_by_user_id) = LOWER(u.username)
|
||||
)
|
||||
"""
|
||||
if status and status != "all":
|
||||
sql += " WHERE LOWER(da.status) = ?"
|
||||
params.append(status)
|
||||
sql += " ORDER BY da.created_at ASC"
|
||||
cur.execute(sql, params)
|
||||
rows = cur.fetchall()
|
||||
for row in rows:
|
||||
record_guid = row[2]
|
||||
hostname = row[3]
|
||||
fingerprint_claimed = row[4]
|
||||
claimed_fp_norm = (fingerprint_claimed or "").strip().lower()
|
||||
conflict_raw = _hostname_conflict(cur, hostname, record_guid)
|
||||
fingerprint_match = False
|
||||
requires_prompt = False
|
||||
conflict = None
|
||||
if conflict_raw:
|
||||
conflict_fp = (conflict_raw.get("ssl_key_fingerprint") or "").strip().lower()
|
||||
fingerprint_match = bool(conflict_fp and claimed_fp_norm) and conflict_fp == claimed_fp_norm
|
||||
requires_prompt = not fingerprint_match
|
||||
conflict = {
|
||||
**conflict_raw,
|
||||
"fingerprint_match": fingerprint_match,
|
||||
"requires_prompt": requires_prompt,
|
||||
}
|
||||
alternate_hostname = (
|
||||
_suggest_alternate_hostname(cur, hostname, record_guid)
|
||||
if conflict_raw and requires_prompt
|
||||
else None
|
||||
)
|
||||
approvals.append(
|
||||
{
|
||||
"id": row[0],
|
||||
"approval_reference": row[1],
|
||||
"guid": record_guid,
|
||||
"hostname_claimed": hostname,
|
||||
"ssl_key_fingerprint_claimed": fingerprint_claimed,
|
||||
"enrollment_code_id": row[5],
|
||||
"status": row[6],
|
||||
"client_nonce": row[7],
|
||||
"server_nonce": row[8],
|
||||
"created_at": row[9],
|
||||
"updated_at": row[10],
|
||||
"approved_by_user_id": row[11],
|
||||
"hostname_conflict": conflict,
|
||||
"alternate_hostname": alternate_hostname,
|
||||
"conflict_requires_prompt": requires_prompt,
|
||||
"fingerprint_match": fingerprint_match,
|
||||
"approved_by_username": row[12],
|
||||
}
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return jsonify({"approvals": approvals})
|
||||
|
||||
def _set_approval_status(
|
||||
approval_id: str,
|
||||
status: str,
|
||||
*,
|
||||
guid: Optional[str] = None,
|
||||
resolution: Optional[str] = None,
|
||||
):
|
||||
user = current_user() or {}
|
||||
username = user.get("username") or ""
|
||||
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT status,
|
||||
guid,
|
||||
hostname_claimed,
|
||||
ssl_key_fingerprint_claimed
|
||||
FROM device_approvals
|
||||
WHERE id = ?
|
||||
""",
|
||||
(approval_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return {"error": "not_found"}, 404
|
||||
existing_status = (row[0] or "").strip().lower()
|
||||
if existing_status != "pending":
|
||||
return {"error": "approval_not_pending"}, 409
|
||||
stored_guid = row[1]
|
||||
hostname_claimed = row[2]
|
||||
fingerprint_claimed = (row[3] or "").strip().lower()
|
||||
|
||||
guid_effective = normalize_guid(guid) if guid else normalize_guid(stored_guid)
|
||||
resolution_effective = (resolution.strip().lower() if isinstance(resolution, str) else None)
|
||||
|
||||
conflict = None
|
||||
if status == "approved":
|
||||
conflict = _hostname_conflict(cur, hostname_claimed, guid_effective)
|
||||
if conflict:
|
||||
conflict_fp = (conflict.get("ssl_key_fingerprint") or "").strip().lower()
|
||||
fingerprint_match = bool(conflict_fp and fingerprint_claimed) and conflict_fp == fingerprint_claimed
|
||||
if fingerprint_match:
|
||||
guid_effective = conflict.get("guid") or guid_effective
|
||||
if not resolution_effective:
|
||||
resolution_effective = "auto_merge_fingerprint"
|
||||
elif resolution_effective == "overwrite":
|
||||
guid_effective = conflict.get("guid") or guid_effective
|
||||
elif resolution_effective == "coexist":
|
||||
pass
|
||||
else:
|
||||
return {
|
||||
"error": "conflict_resolution_required",
|
||||
"hostname": hostname_claimed,
|
||||
}, 409
|
||||
|
||||
guid_to_store = guid_effective or normalize_guid(stored_guid) or None
|
||||
|
||||
approved_by = _lookup_user_id(cur, username) or username or "system"
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET status = ?,
|
||||
guid = ?,
|
||||
approved_by_user_id = ?,
|
||||
updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
status,
|
||||
guid_to_store,
|
||||
approved_by,
|
||||
_iso(_now()),
|
||||
approval_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
resolution_note = f" ({resolution_effective})" if resolution_effective else ""
|
||||
log("server", f"device approval {approval_id} -> {status}{resolution_note} by {username}")
|
||||
payload: Dict[str, Any] = {"status": status}
|
||||
if resolution_effective:
|
||||
payload["conflict_resolution"] = resolution_effective
|
||||
return payload, 200
|
||||
|
||||
@blueprint.route("/api/admin/device-approvals/<approval_id>/approve", methods=["POST"])
|
||||
def approve_device(approval_id: str):
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
guid = payload.get("guid")
|
||||
if guid:
|
||||
guid = str(guid).strip()
|
||||
resolution_val = payload.get("conflict_resolution")
|
||||
resolution = None
|
||||
if isinstance(resolution_val, str):
|
||||
cleaned = resolution_val.strip().lower()
|
||||
if cleaned:
|
||||
resolution = cleaned
|
||||
result, status_code = _set_approval_status(
|
||||
approval_id,
|
||||
"approved",
|
||||
guid=guid,
|
||||
resolution=resolution,
|
||||
)
|
||||
return jsonify(result), status_code
|
||||
|
||||
@blueprint.route("/api/admin/device-approvals/<approval_id>/deny", methods=["POST"])
|
||||
def deny_device(approval_id: str):
|
||||
result, status_code = _set_approval_status(approval_id, "denied")
|
||||
return jsonify(result), status_code
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
def _generate_install_code() -> str:
|
||||
raw = secrets.token_hex(16).upper()
|
||||
return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4))
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import sqlite3
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from flask import Blueprint, jsonify, request, g
|
||||
|
||||
from Modules.auth.device_auth import DeviceAuthManager, require_device_auth
|
||||
from Modules.crypto.signing import ScriptSigner
|
||||
from Modules.guid_utils import normalize_guid
|
||||
|
||||
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||
|
||||
|
||||
def _canonical_context(value: Optional[str]) -> Optional[str]:
|
||||
if not value:
|
||||
return None
|
||||
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
|
||||
if not cleaned:
|
||||
return None
|
||||
return cleaned.upper()
|
||||
|
||||
|
||||
def register(
|
||||
app,
|
||||
*,
|
||||
db_conn_factory: Callable[[], Any],
|
||||
auth_manager: DeviceAuthManager,
|
||||
log: Callable[[str, str, Optional[str]], None],
|
||||
script_signer: ScriptSigner,
|
||||
) -> None:
|
||||
blueprint = Blueprint("agents", __name__)
|
||||
|
||||
def _json_or_none(value) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return json.dumps(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _context_hint(ctx=None) -> Optional[str]:
|
||||
if ctx is not None and getattr(ctx, "service_mode", None):
|
||||
return _canonical_context(getattr(ctx, "service_mode", None))
|
||||
return _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
|
||||
|
||||
def _auth_context():
|
||||
ctx = getattr(g, "device_auth", None)
|
||||
if ctx is None:
|
||||
log("server", f"device auth context missing for {request.path}", _context_hint())
|
||||
return ctx
|
||||
|
||||
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
|
||||
@require_device_auth(auth_manager)
|
||||
def heartbeat():
|
||||
ctx = _auth_context()
|
||||
if ctx is None:
|
||||
return jsonify({"error": "auth_context_missing"}), 500
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
context_label = _context_hint(ctx)
|
||||
|
||||
now_ts = int(time.time())
|
||||
updates: Dict[str, Optional[str]] = {"last_seen": now_ts}
|
||||
|
||||
hostname = payload.get("hostname")
|
||||
if isinstance(hostname, str) and hostname.strip():
|
||||
updates["hostname"] = hostname.strip()
|
||||
|
||||
inventory = payload.get("inventory") if isinstance(payload.get("inventory"), dict) else {}
|
||||
for key in ("memory", "network", "software", "storage", "cpu"):
|
||||
if key in inventory and inventory[key] is not None:
|
||||
encoded = _json_or_none(inventory[key])
|
||||
if encoded is not None:
|
||||
updates[key] = encoded
|
||||
|
||||
metrics = payload.get("metrics") if isinstance(payload.get("metrics"), dict) else {}
|
||||
def _maybe_str(field: str) -> Optional[str]:
|
||||
val = metrics.get(field)
|
||||
if isinstance(val, str):
|
||||
return val.strip()
|
||||
return None
|
||||
|
||||
if "last_user" in metrics and metrics["last_user"]:
|
||||
updates["last_user"] = str(metrics["last_user"])
|
||||
if "operating_system" in metrics and metrics["operating_system"]:
|
||||
updates["operating_system"] = str(metrics["operating_system"])
|
||||
if "uptime" in metrics and metrics["uptime"] is not None:
|
||||
try:
|
||||
updates["uptime"] = int(metrics["uptime"])
|
||||
except Exception:
|
||||
pass
|
||||
for field in ("external_ip", "internal_ip", "device_type"):
|
||||
if field in payload and payload[field]:
|
||||
updates[field] = str(payload[field])
|
||||
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
|
||||
def _apply_updates() -> int:
|
||||
if not updates:
|
||||
return 0
|
||||
columns = ", ".join(f"{col} = ?" for col in updates.keys())
|
||||
values = list(updates.values())
|
||||
normalized_guid = normalize_guid(ctx.guid)
|
||||
selected_guid: Optional[str] = None
|
||||
if normalized_guid:
|
||||
cur.execute(
|
||||
"SELECT guid FROM devices WHERE UPPER(guid) = ?",
|
||||
(normalized_guid,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
for (stored_guid,) in rows or []:
|
||||
if stored_guid == ctx.guid:
|
||||
selected_guid = stored_guid
|
||||
break
|
||||
if not selected_guid and rows:
|
||||
selected_guid = rows[0][0]
|
||||
target_guid = selected_guid or ctx.guid
|
||||
cur.execute(
|
||||
f"UPDATE devices SET {columns} WHERE guid = ?",
|
||||
values + [target_guid],
|
||||
)
|
||||
updated = cur.rowcount
|
||||
if updated > 0 and normalized_guid and target_guid != normalized_guid:
|
||||
try:
|
||||
cur.execute(
|
||||
"UPDATE devices SET guid = ? WHERE guid = ?",
|
||||
(normalized_guid, target_guid),
|
||||
)
|
||||
except sqlite3.IntegrityError:
|
||||
pass
|
||||
return updated
|
||||
|
||||
try:
|
||||
rowcount = _apply_updates()
|
||||
except sqlite3.IntegrityError as exc:
|
||||
if "devices.hostname" in str(exc) and "UNIQUE" in str(exc).upper():
|
||||
# Another device already claims this hostname; keep the existing
|
||||
# canonical hostname assigned during enrollment to avoid breaking
|
||||
# the unique constraint and continue updating the remaining fields.
|
||||
existing_guid_for_hostname: Optional[str] = None
|
||||
if "hostname" in updates:
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT guid FROM devices WHERE hostname = ?",
|
||||
(updates["hostname"],),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row and row[0]:
|
||||
existing_guid_for_hostname = normalize_guid(row[0])
|
||||
except Exception:
|
||||
existing_guid_for_hostname = None
|
||||
if "hostname" in updates:
|
||||
updates.pop("hostname", None)
|
||||
try:
|
||||
rowcount = _apply_updates()
|
||||
except sqlite3.IntegrityError:
|
||||
raise
|
||||
else:
|
||||
try:
|
||||
current_guid = normalize_guid(ctx.guid)
|
||||
except Exception:
|
||||
current_guid = ctx.guid
|
||||
if (
|
||||
existing_guid_for_hostname
|
||||
and current_guid
|
||||
and existing_guid_for_hostname == current_guid
|
||||
):
|
||||
pass # Same device contexts; no log needed.
|
||||
else:
|
||||
log(
|
||||
"server",
|
||||
"heartbeat hostname collision ignored for guid="
|
||||
f"{ctx.guid}",
|
||||
context_label,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
if rowcount == 0:
|
||||
log("server", f"heartbeat missing device record guid={ctx.guid}", context_label)
|
||||
return jsonify({"error": "device_not_registered"}), 404
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return jsonify({"status": "ok", "poll_after_ms": 15000})
|
||||
|
||||
@blueprint.route("/api/agent/script/request", methods=["POST"])
|
||||
@require_device_auth(auth_manager)
|
||||
def script_request():
|
||||
ctx = _auth_context()
|
||||
if ctx is None:
|
||||
return jsonify({"error": "auth_context_missing"}), 500
|
||||
if ctx.status != "active":
|
||||
return jsonify(
|
||||
{
|
||||
"status": "quarantined",
|
||||
"poll_after_ms": 60000,
|
||||
"sig_alg": "ed25519",
|
||||
"signing_key": script_signer.public_base64_spki(),
|
||||
}
|
||||
)
|
||||
|
||||
# Placeholder: actual dispatch logic will integrate with job scheduler.
|
||||
return jsonify(
|
||||
{
|
||||
"status": "idle",
|
||||
"poll_after_ms": 30000,
|
||||
"sig_alg": "ed25519",
|
||||
"signing_key": script_signer.public_base64_spki(),
|
||||
}
|
||||
)
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,310 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import sqlite3
|
||||
import time
|
||||
from contextlib import closing
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from flask import g, jsonify, request
|
||||
|
||||
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
|
||||
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||
from Modules.guid_utils import normalize_guid
|
||||
|
||||
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||
|
||||
|
||||
def _canonical_context(value: Optional[str]) -> Optional[str]:
|
||||
if not value:
|
||||
return None
|
||||
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
|
||||
if not cleaned:
|
||||
return None
|
||||
return cleaned.upper()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceAuthContext:
|
||||
guid: str
|
||||
ssl_key_fingerprint: str
|
||||
token_version: int
|
||||
access_token: str
|
||||
claims: Dict[str, Any]
|
||||
dpop_jkt: Optional[str]
|
||||
status: str
|
||||
service_mode: Optional[str]
|
||||
|
||||
|
||||
class DeviceAuthError(Exception):
|
||||
status_code = 401
|
||||
error_code = "unauthorized"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "unauthorized",
|
||||
*,
|
||||
status_code: Optional[int] = None,
|
||||
retry_after: Optional[float] = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
if status_code is not None:
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.retry_after = retry_after
|
||||
|
||||
|
||||
class DeviceAuthManager:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
db_conn_factory: Callable[[], Any],
|
||||
jwt_service,
|
||||
dpop_validator: Optional[DPoPValidator],
|
||||
log: Callable[[str, str, Optional[str]], None],
|
||||
rate_limiter: Optional[SlidingWindowRateLimiter] = None,
|
||||
) -> None:
|
||||
self._db_conn_factory = db_conn_factory
|
||||
self._jwt_service = jwt_service
|
||||
self._dpop_validator = dpop_validator
|
||||
self._log = log
|
||||
self._rate_limiter = rate_limiter
|
||||
|
||||
def authenticate(self) -> DeviceAuthContext:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise DeviceAuthError("missing_authorization")
|
||||
token = auth_header[len("Bearer ") :].strip()
|
||||
if not token:
|
||||
raise DeviceAuthError("missing_authorization")
|
||||
|
||||
try:
|
||||
claims = self._jwt_service.decode(token)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise DeviceAuthError("token_expired")
|
||||
except Exception:
|
||||
raise DeviceAuthError("invalid_token")
|
||||
|
||||
raw_guid = str(claims.get("guid") or "").strip()
|
||||
guid = normalize_guid(raw_guid)
|
||||
fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip()
|
||||
token_version = int(claims.get("token_version") or 0)
|
||||
if not guid or not fingerprint or token_version <= 0:
|
||||
raise DeviceAuthError("invalid_claims")
|
||||
|
||||
if self._rate_limiter:
|
||||
decision = self._rate_limiter.check(f"fp:{fingerprint}", 60, 60.0)
|
||||
if not decision.allowed:
|
||||
raise DeviceAuthError(
|
||||
"rate_limited",
|
||||
status_code=429,
|
||||
retry_after=decision.retry_after,
|
||||
)
|
||||
|
||||
context_label = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
|
||||
|
||||
with closing(self._db_conn_factory()) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE UPPER(guid) = ?
|
||||
""",
|
||||
(guid,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
row = None
|
||||
for candidate in rows or []:
|
||||
candidate_guid = normalize_guid(candidate[0])
|
||||
if candidate_guid == guid:
|
||||
row = candidate
|
||||
break
|
||||
if row is None and rows:
|
||||
row = rows[0]
|
||||
|
||||
if not row:
|
||||
row = self._recover_device_record(
|
||||
conn, guid, fingerprint, token_version, context_label
|
||||
)
|
||||
|
||||
if not row:
|
||||
raise DeviceAuthError("device_not_found", status_code=403)
|
||||
|
||||
db_guid, db_fp, db_token_version, status = row
|
||||
db_guid_normalized = normalize_guid(db_guid)
|
||||
|
||||
if not db_guid_normalized or db_guid_normalized != guid:
|
||||
raise DeviceAuthError("device_guid_mismatch", status_code=403)
|
||||
|
||||
db_fp = (db_fp or "").lower().strip()
|
||||
if db_fp and db_fp != fingerprint:
|
||||
raise DeviceAuthError("fingerprint_mismatch", status_code=403)
|
||||
|
||||
if db_token_version and db_token_version > token_version:
|
||||
raise DeviceAuthError("token_version_revoked", status_code=401)
|
||||
|
||||
status_normalized = (status or "active").strip().lower()
|
||||
allowed_statuses = {"active", "quarantined"}
|
||||
if status_normalized not in allowed_statuses:
|
||||
raise DeviceAuthError("device_revoked", status_code=403)
|
||||
if status_normalized == "quarantined":
|
||||
self._log(
|
||||
"server",
|
||||
f"device {guid} is quarantined; limited access for {request.path}",
|
||||
context_label,
|
||||
)
|
||||
|
||||
dpop_jkt: Optional[str] = None
|
||||
dpop_proof = request.headers.get("DPoP")
|
||||
if dpop_proof:
|
||||
if not self._dpop_validator:
|
||||
raise DeviceAuthError("dpop_not_supported", status_code=400)
|
||||
try:
|
||||
htu = request.url
|
||||
dpop_jkt = self._dpop_validator.verify(request.method, htu, dpop_proof, token)
|
||||
except DPoPReplayError:
|
||||
raise DeviceAuthError("dpop_replayed", status_code=400)
|
||||
except DPoPVerificationError:
|
||||
raise DeviceAuthError("dpop_invalid", status_code=400)
|
||||
|
||||
ctx = DeviceAuthContext(
|
||||
guid=guid,
|
||||
ssl_key_fingerprint=fingerprint,
|
||||
token_version=token_version,
|
||||
access_token=token,
|
||||
claims=claims,
|
||||
dpop_jkt=dpop_jkt,
|
||||
status=status_normalized,
|
||||
service_mode=context_label,
|
||||
)
|
||||
return ctx
|
||||
|
||||
def _recover_device_record(
|
||||
self,
|
||||
conn: sqlite3.Connection,
|
||||
guid: str,
|
||||
fingerprint: str,
|
||||
token_version: int,
|
||||
context_label: Optional[str],
|
||||
) -> Optional[tuple]:
|
||||
"""Attempt to recreate a missing device row for an authenticated token."""
|
||||
|
||||
guid = normalize_guid(guid)
|
||||
fingerprint = (fingerprint or "").strip()
|
||||
if not guid or not fingerprint:
|
||||
return None
|
||||
|
||||
cur = conn.cursor()
|
||||
now_ts = int(time.time())
|
||||
try:
|
||||
now_iso = datetime.now(tz=timezone.utc).isoformat()
|
||||
except Exception:
|
||||
now_iso = datetime.utcnow().isoformat() # pragma: no cover
|
||||
|
||||
base_hostname = f"RECOVERED-{guid[:12].upper()}" if guid else "RECOVERED"
|
||||
|
||||
for attempt in range(6):
|
||||
hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
|
||||
try:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO devices (
|
||||
guid,
|
||||
hostname,
|
||||
created_at,
|
||||
last_seen,
|
||||
ssl_key_fingerprint,
|
||||
token_version,
|
||||
status,
|
||||
key_added_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
|
||||
""",
|
||||
(
|
||||
guid,
|
||||
hostname,
|
||||
now_ts,
|
||||
now_ts,
|
||||
fingerprint,
|
||||
max(token_version or 1, 1),
|
||||
now_iso,
|
||||
),
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
# Hostname collision – try again with a suffixed placeholder.
|
||||
message = str(exc).lower()
|
||||
if "hostname" in message and "unique" in message:
|
||||
continue
|
||||
self._log(
|
||||
"server",
|
||||
f"device auth failed to recover guid={guid} due to integrity error: {exc}",
|
||||
context_label,
|
||||
)
|
||||
conn.rollback()
|
||||
return None
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log(
|
||||
"server",
|
||||
f"device auth unexpected error recovering guid={guid}: {exc}",
|
||||
context_label,
|
||||
)
|
||||
conn.rollback()
|
||||
return None
|
||||
else:
|
||||
conn.commit()
|
||||
break
|
||||
else:
|
||||
# Exhausted attempts because of hostname collisions.
|
||||
self._log(
|
||||
"server",
|
||||
f"device auth could not recover guid={guid}; hostname collisions persisted",
|
||||
context_label,
|
||||
)
|
||||
conn.rollback()
|
||||
return None
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(guid,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
self._log(
|
||||
"server",
|
||||
f"device auth recovery for guid={guid} committed but row still missing",
|
||||
context_label,
|
||||
)
|
||||
return row
|
||||
|
||||
|
||||
def require_device_auth(manager: DeviceAuthManager):
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
ctx = manager.authenticate()
|
||||
except DeviceAuthError as exc:
|
||||
response = jsonify({"error": exc.message})
|
||||
response.status_code = exc.status_code
|
||||
retry_after = getattr(exc, "retry_after", None)
|
||||
if retry_after:
|
||||
try:
|
||||
response.headers["Retry-After"] = str(max(1, int(retry_after)))
|
||||
except Exception:
|
||||
response.headers["Retry-After"] = "1"
|
||||
return response
|
||||
|
||||
g.device_auth = ctx
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -1,109 +0,0 @@
|
||||
"""
|
||||
DPoP proof verification helpers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict, Optional
|
||||
|
||||
import jwt
|
||||
|
||||
_DP0P_MAX_SKEW = 300.0 # seconds
|
||||
|
||||
|
||||
class DPoPVerificationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DPoPReplayError(DPoPVerificationError):
|
||||
pass
|
||||
|
||||
|
||||
class DPoPValidator:
|
||||
def __init__(self) -> None:
|
||||
self._observed_jti: Dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def verify(
|
||||
self,
|
||||
method: str,
|
||||
htu: str,
|
||||
proof: str,
|
||||
access_token: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Verify the presented DPoP proof. Returns the JWK thumbprint on success.
|
||||
"""
|
||||
|
||||
if not proof:
|
||||
raise DPoPVerificationError("DPoP proof missing")
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(proof)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP header") from exc
|
||||
|
||||
jwk = header.get("jwk")
|
||||
alg = header.get("alg")
|
||||
if not jwk or not isinstance(jwk, dict):
|
||||
raise DPoPVerificationError("missing jwk in DPoP header")
|
||||
if alg not in ("EdDSA", "ES256", "ES384", "ES512"):
|
||||
raise DPoPVerificationError(f"unsupported DPoP alg {alg}")
|
||||
|
||||
try:
|
||||
key = jwt.PyJWK(jwk)
|
||||
public_key = key.key
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid jwk in DPoP header") from exc
|
||||
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
proof,
|
||||
public_key,
|
||||
algorithms=[alg],
|
||||
options={"require": ["htm", "htu", "jti", "iat"]},
|
||||
)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP signature") from exc
|
||||
|
||||
htm = claims.get("htm")
|
||||
proof_htu = claims.get("htu")
|
||||
jti = claims.get("jti")
|
||||
iat = claims.get("iat")
|
||||
ath = claims.get("ath")
|
||||
|
||||
if not isinstance(htm, str) or htm.lower() != method.lower():
|
||||
raise DPoPVerificationError("DPoP htm mismatch")
|
||||
if not isinstance(proof_htu, str) or proof_htu != htu:
|
||||
raise DPoPVerificationError("DPoP htu mismatch")
|
||||
if not isinstance(jti, str):
|
||||
raise DPoPVerificationError("DPoP jti missing")
|
||||
if not isinstance(iat, (int, float)):
|
||||
raise DPoPVerificationError("DPoP iat missing")
|
||||
|
||||
now = time.time()
|
||||
if abs(now - float(iat)) > _DP0P_MAX_SKEW:
|
||||
raise DPoPVerificationError("DPoP proof outside allowed skew")
|
||||
|
||||
if ath and access_token:
|
||||
expected_ath = jwt.utils.base64url_encode(
|
||||
hashlib.sha256(access_token.encode("utf-8")).digest()
|
||||
).decode("ascii")
|
||||
if expected_ath != ath:
|
||||
raise DPoPVerificationError("DPoP ath mismatch")
|
||||
|
||||
with self._lock:
|
||||
expiry = self._observed_jti.get(jti)
|
||||
if expiry and expiry > now:
|
||||
raise DPoPReplayError("DPoP proof replay detected")
|
||||
self._observed_jti[jti] = now + _DP0P_MAX_SKEW
|
||||
# Opportunistic cleanup
|
||||
stale = [key for key, exp in self._observed_jti.items() if exp <= now]
|
||||
for key in stale:
|
||||
self._observed_jti.pop(key, None)
|
||||
|
||||
thumbprint = jwt.PyJWK(jwk).thumbprint()
|
||||
return thumbprint.decode("ascii")
|
||||
@@ -1,140 +0,0 @@
|
||||
"""
|
||||
JWT access-token helpers backed by an Ed25519 signing key.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from Modules.runtime import ensure_runtime_dir, runtime_path
|
||||
|
||||
_KEY_DIR = runtime_path("auth_keys")
|
||||
_KEY_FILE = _KEY_DIR / "borealis-jwt-ed25519.key"
|
||||
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-jwt-ed25519.key"
|
||||
|
||||
|
||||
class JWTService:
|
||||
def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str):
|
||||
self._private_key = private_key
|
||||
self._public_key = private_key.public_key()
|
||||
self._key_id = key_id
|
||||
|
||||
@property
|
||||
def key_id(self) -> str:
|
||||
return self._key_id
|
||||
|
||||
def issue_access_token(
|
||||
self,
|
||||
guid: str,
|
||||
ssl_key_fingerprint: str,
|
||||
token_version: int,
|
||||
expires_in: int = 900,
|
||||
extra_claims: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
now = int(time.time())
|
||||
payload: Dict[str, Any] = {
|
||||
"sub": f"device:{guid}",
|
||||
"guid": guid,
|
||||
"ssl_key_fingerprint": ssl_key_fingerprint,
|
||||
"token_version": int(token_version),
|
||||
"iat": now,
|
||||
"nbf": now,
|
||||
"exp": now + int(expires_in),
|
||||
}
|
||||
if extra_claims:
|
||||
payload.update(extra_claims)
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
self._private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
),
|
||||
algorithm="EdDSA",
|
||||
headers={"kid": self._key_id},
|
||||
)
|
||||
return token
|
||||
|
||||
def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]:
|
||||
options = {"require": ["exp", "iat", "sub"]}
|
||||
public_pem = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
return jwt.decode(
|
||||
token,
|
||||
public_pem,
|
||||
algorithms=["EdDSA"],
|
||||
audience=audience,
|
||||
options=options,
|
||||
)
|
||||
|
||||
def public_jwk(self) -> Dict[str, Any]:
|
||||
public_bytes = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw,
|
||||
)
|
||||
# PyJWT expects base64url without padding.
|
||||
jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii")
|
||||
return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x}
|
||||
|
||||
|
||||
def load_service() -> JWTService:
|
||||
private_key = _load_or_create_private_key()
|
||||
public_bytes = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
key_id = hashlib.sha256(public_bytes).hexdigest()[:16]
|
||||
return JWTService(private_key, key_id)
|
||||
|
||||
|
||||
def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey:
|
||||
ensure_runtime_dir("auth_keys")
|
||||
_migrate_legacy_key_if_present()
|
||||
|
||||
if _KEY_FILE.exists():
|
||||
with _KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
if _LEGACY_KEY_FILE.exists():
|
||||
with _LEGACY_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
with _KEY_FILE.open("wb") as fh:
|
||||
fh.write(pem)
|
||||
try:
|
||||
if _KEY_FILE.exists() and hasattr(_KEY_FILE, "chmod"):
|
||||
_KEY_FILE.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
return private_key
|
||||
|
||||
|
||||
def _migrate_legacy_key_if_present() -> None:
|
||||
if not _LEGACY_KEY_FILE.exists() or _KEY_FILE.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
ensure_runtime_dir("auth_keys")
|
||||
try:
|
||||
_LEGACY_KEY_FILE.replace(_KEY_FILE)
|
||||
except Exception:
|
||||
_KEY_FILE.write_bytes(_LEGACY_KEY_FILE.read_bytes())
|
||||
except Exception:
|
||||
return
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
"""
|
||||
Tiny in-memory rate limiter suitable for single-process development servers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import Deque, Dict, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitDecision:
|
||||
allowed: bool
|
||||
retry_after: float
|
||||
|
||||
|
||||
class SlidingWindowRateLimiter:
|
||||
def __init__(self) -> None:
|
||||
self._buckets: Dict[str, Deque[float]] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
bucket = self._buckets.get(key)
|
||||
if bucket is None:
|
||||
bucket = deque()
|
||||
self._buckets[key] = bucket
|
||||
|
||||
while bucket and now - bucket[0] > window_seconds:
|
||||
bucket.popleft()
|
||||
|
||||
if len(bucket) >= limit:
|
||||
retry_after = max(0.0, window_seconds - (now - bucket[0]))
|
||||
return RateLimitDecision(False, retry_after)
|
||||
|
||||
bucket.append(now)
|
||||
return RateLimitDecision(True, 0.0)
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,372 +0,0 @@
|
||||
"""
|
||||
Server TLS certificate management.
|
||||
|
||||
Borealis now issues a dedicated root CA and a leaf server certificate so that
|
||||
agents can pin the CA without requiring a re-enrollment every time the server
|
||||
certificate is refreshed. The CA is persisted alongside the server key so that
|
||||
existing deployments can be upgraded in-place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import ssl
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
|
||||
|
||||
from Modules.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path
|
||||
|
||||
_CERT_DIR = server_certificates_path()
|
||||
_CERT_FILE = _CERT_DIR / "borealis-server-cert.pem"
|
||||
_KEY_FILE = _CERT_DIR / "borealis-server-key.pem"
|
||||
_BUNDLE_FILE = _CERT_DIR / "borealis-server-bundle.pem"
|
||||
_CA_KEY_FILE = _CERT_DIR / "borealis-root-ca-key.pem"
|
||||
_CA_CERT_FILE = _CERT_DIR / "borealis-root-ca.pem"
|
||||
|
||||
_LEGACY_CERT_DIR = runtime_path("certs")
|
||||
_LEGACY_CERT_FILE = _LEGACY_CERT_DIR / "borealis-server-cert.pem"
|
||||
_LEGACY_KEY_FILE = _LEGACY_CERT_DIR / "borealis-server-key.pem"
|
||||
_LEGACY_BUNDLE_FILE = _LEGACY_CERT_DIR / "borealis-server-bundle.pem"
|
||||
|
||||
_ROOT_COMMON_NAME = "Borealis Root CA"
|
||||
_ORG_NAME = "Borealis"
|
||||
_ROOT_VALIDITY = timedelta(days=365 * 100)
|
||||
_SERVER_VALIDITY = timedelta(days=365 * 5)
|
||||
|
||||
|
||||
def ensure_certificate(common_name: str = "Borealis Server") -> Tuple[Path, Path, Path]:
|
||||
"""
|
||||
Ensure the root CA, server certificate, and bundle exist on disk.
|
||||
|
||||
Returns (cert_path, key_path, bundle_path).
|
||||
"""
|
||||
|
||||
ensure_server_certificates_dir()
|
||||
_migrate_legacy_material_if_present()
|
||||
|
||||
ca_key, ca_cert, ca_regenerated = _ensure_root_ca()
|
||||
|
||||
server_cert = _load_certificate(_CERT_FILE)
|
||||
needs_regen = ca_regenerated or _server_certificate_needs_regeneration(server_cert, ca_cert)
|
||||
if needs_regen:
|
||||
server_cert = _generate_server_certificate(common_name, ca_key, ca_cert)
|
||||
|
||||
if server_cert is None:
|
||||
server_cert = _generate_server_certificate(common_name, ca_key, ca_cert)
|
||||
|
||||
_write_bundle(server_cert, ca_cert)
|
||||
|
||||
return _CERT_FILE, _KEY_FILE, _BUNDLE_FILE
|
||||
|
||||
|
||||
def _migrate_legacy_material_if_present() -> None:
|
||||
# Promote legacy runtime certificates (Server/Borealis/certs) into the new location.
|
||||
if not _CERT_FILE.exists() or not _KEY_FILE.exists():
|
||||
legacy_cert = _LEGACY_CERT_FILE
|
||||
legacy_key = _LEGACY_KEY_FILE
|
||||
if legacy_cert.exists() and legacy_key.exists():
|
||||
try:
|
||||
ensure_server_certificates_dir()
|
||||
if not _CERT_FILE.exists():
|
||||
_safe_copy(legacy_cert, _CERT_FILE)
|
||||
if not _KEY_FILE.exists():
|
||||
_safe_copy(legacy_key, _KEY_FILE)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _ensure_root_ca() -> Tuple[ec.EllipticCurvePrivateKey, x509.Certificate, bool]:
|
||||
regenerated = False
|
||||
|
||||
ca_key: Optional[ec.EllipticCurvePrivateKey] = None
|
||||
ca_cert: Optional[x509.Certificate] = None
|
||||
|
||||
if _CA_KEY_FILE.exists() and _CA_CERT_FILE.exists():
|
||||
try:
|
||||
ca_key = _load_private_key(_CA_KEY_FILE)
|
||||
ca_cert = _load_certificate(_CA_CERT_FILE)
|
||||
if ca_cert is not None and ca_key is not None:
|
||||
expiry = _cert_not_after(ca_cert)
|
||||
subject = ca_cert.subject
|
||||
subject_cn = ""
|
||||
try:
|
||||
subject_cn = subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value # type: ignore[index]
|
||||
except Exception:
|
||||
subject_cn = ""
|
||||
try:
|
||||
basic = ca_cert.extensions.get_extension_for_class(x509.BasicConstraints).value # type: ignore[attr-defined]
|
||||
is_ca = bool(basic.ca)
|
||||
except Exception:
|
||||
is_ca = False
|
||||
if (
|
||||
expiry <= datetime.now(tz=timezone.utc)
|
||||
or not is_ca
|
||||
or subject_cn != _ROOT_COMMON_NAME
|
||||
):
|
||||
regenerated = True
|
||||
else:
|
||||
regenerated = True
|
||||
except Exception:
|
||||
regenerated = True
|
||||
else:
|
||||
regenerated = True
|
||||
|
||||
if regenerated or ca_key is None or ca_cert is None:
|
||||
ca_key = ec.generate_private_key(ec.SECP384R1())
|
||||
public_key = ca_key.public_key()
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
builder = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, _ROOT_COMMON_NAME),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
|
||||
]
|
||||
)
|
||||
)
|
||||
.issuer_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, _ROOT_COMMON_NAME),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
|
||||
]
|
||||
)
|
||||
)
|
||||
.public_key(public_key)
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now - timedelta(minutes=5))
|
||||
.not_valid_after(now + _ROOT_VALIDITY)
|
||||
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
content_commitment=False,
|
||||
key_encipherment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
key_cert_sign=True,
|
||||
crl_sign=True,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.SubjectKeyIdentifier.from_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
)
|
||||
|
||||
builder = builder.add_extension(
|
||||
x509.AuthorityKeyIdentifier.from_issuer_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
|
||||
ca_cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA384())
|
||||
|
||||
_CA_KEY_FILE.write_bytes(
|
||||
ca_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
)
|
||||
_CA_CERT_FILE.write_bytes(ca_cert.public_bytes(serialization.Encoding.PEM))
|
||||
|
||||
_tighten_permissions(_CA_KEY_FILE)
|
||||
_tighten_permissions(_CA_CERT_FILE)
|
||||
else:
|
||||
regenerated = False
|
||||
|
||||
return ca_key, ca_cert, regenerated
|
||||
|
||||
|
||||
def _server_certificate_needs_regeneration(
|
||||
server_cert: Optional[x509.Certificate],
|
||||
ca_cert: x509.Certificate,
|
||||
) -> bool:
|
||||
if server_cert is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
if server_cert.issuer != ca_cert.subject:
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
try:
|
||||
expiry = _cert_not_after(server_cert)
|
||||
if expiry <= datetime.now(tz=timezone.utc):
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
try:
|
||||
basic = server_cert.extensions.get_extension_for_class(x509.BasicConstraints).value # type: ignore[attr-defined]
|
||||
if basic.ca:
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
try:
|
||||
eku = server_cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value # type: ignore[attr-defined]
|
||||
if ExtendedKeyUsageOID.SERVER_AUTH not in eku:
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _generate_server_certificate(
|
||||
common_name: str,
|
||||
ca_key: ec.EllipticCurvePrivateKey,
|
||||
ca_cert: x509.Certificate,
|
||||
) -> x509.Certificate:
|
||||
private_key = ec.generate_private_key(ec.SECP384R1())
|
||||
public_key = private_key.public_key()
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
ca_expiry = _cert_not_after(ca_cert)
|
||||
candidate_expiry = now + _SERVER_VALIDITY
|
||||
not_after = min(ca_expiry - timedelta(days=1), candidate_expiry)
|
||||
|
||||
builder = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
|
||||
]
|
||||
)
|
||||
)
|
||||
.issuer_name(ca_cert.subject)
|
||||
.public_key(public_key)
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now - timedelta(minutes=5))
|
||||
.not_valid_after(not_after)
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName(
|
||||
[
|
||||
x509.DNSName("localhost"),
|
||||
x509.DNSName("127.0.0.1"),
|
||||
x509.DNSName("::1"),
|
||||
]
|
||||
),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
content_commitment=False,
|
||||
key_encipherment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
key_cert_sign=False,
|
||||
crl_sign=False,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.SubjectKeyIdentifier.from_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_key.public_key()),
|
||||
critical=False,
|
||||
)
|
||||
)
|
||||
|
||||
certificate = builder.sign(private_key=ca_key, algorithm=hashes.SHA384())
|
||||
|
||||
_KEY_FILE.write_bytes(
|
||||
private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
)
|
||||
_CERT_FILE.write_bytes(certificate.public_bytes(serialization.Encoding.PEM))
|
||||
|
||||
_tighten_permissions(_KEY_FILE)
|
||||
_tighten_permissions(_CERT_FILE)
|
||||
|
||||
return certificate
|
||||
|
||||
|
||||
def _write_bundle(server_cert: x509.Certificate, ca_cert: x509.Certificate) -> None:
|
||||
try:
|
||||
server_pem = server_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8").strip()
|
||||
ca_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8").strip()
|
||||
except Exception:
|
||||
return
|
||||
|
||||
bundle = f"{server_pem}\n{ca_pem}\n"
|
||||
_BUNDLE_FILE.write_text(bundle, encoding="utf-8")
|
||||
_tighten_permissions(_BUNDLE_FILE)
|
||||
|
||||
|
||||
def _safe_copy(src: Path, dst: Path) -> None:
|
||||
try:
|
||||
dst.write_bytes(src.read_bytes())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _tighten_permissions(path: Path) -> None:
|
||||
try:
|
||||
if os.name == "posix":
|
||||
path.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _load_private_key(path: Path) -> ec.EllipticCurvePrivateKey:
|
||||
with path.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
|
||||
def _load_certificate(path: Path) -> Optional[x509.Certificate]:
|
||||
try:
|
||||
return x509.load_pem_x509_certificate(path.read_bytes())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _cert_not_after(cert: x509.Certificate) -> datetime:
|
||||
try:
|
||||
return cert.not_valid_after_utc # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
value = cert.not_valid_after
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=timezone.utc)
|
||||
return value
|
||||
|
||||
|
||||
def build_ssl_context() -> ssl.SSLContext:
|
||||
cert_path, key_path, bundle_path = ensure_certificate()
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.minimum_version = ssl.TLSVersion.TLSv1_3
|
||||
context.load_cert_chain(certfile=str(bundle_path), keyfile=str(key_path))
|
||||
return context
|
||||
|
||||
|
||||
def certificate_paths() -> Tuple[str, str, str]:
|
||||
cert_path, key_path, bundle_path = ensure_certificate()
|
||||
return str(cert_path), str(key_path), str(bundle_path)
|
||||
@@ -1,71 +0,0 @@
|
||||
"""
|
||||
Utility helpers for working with Ed25519 keys and fingerprints.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
|
||||
def generate_ed25519_keypair() -> Tuple[ed25519.Ed25519PrivateKey, bytes]:
|
||||
"""
|
||||
Generate a new Ed25519 keypair.
|
||||
|
||||
Returns the private key object and the public key encoded as SubjectPublicKeyInfo DER bytes.
|
||||
"""
|
||||
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
public_key = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
return private_key, public_key
|
||||
|
||||
|
||||
def normalize_base64(data: str) -> str:
|
||||
"""
|
||||
Collapse whitespace and normalise URL-safe encodings so we can reliably decode.
|
||||
"""
|
||||
|
||||
cleaned = re.sub(r"\\s+", "", data or "")
|
||||
return cleaned.replace("-", "+").replace("_", "/")
|
||||
|
||||
|
||||
def spki_der_from_base64(spki_b64: str) -> bytes:
|
||||
return base64.b64decode(normalize_base64(spki_b64), validate=True)
|
||||
|
||||
|
||||
def base64_from_spki_der(spki_der: bytes) -> str:
|
||||
return base64.b64encode(spki_der).decode("ascii")
|
||||
|
||||
|
||||
def fingerprint_from_spki_der(spki_der: bytes) -> str:
|
||||
digest = hashlib.sha256(spki_der).hexdigest()
|
||||
return digest.lower()
|
||||
|
||||
|
||||
def fingerprint_from_base64_spki(spki_b64: str) -> str:
|
||||
return fingerprint_from_spki_der(spki_der_from_base64(spki_b64))
|
||||
|
||||
|
||||
def private_key_to_pem(private_key: ed25519.Ed25519PrivateKey) -> bytes:
|
||||
return private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
|
||||
|
||||
def public_key_to_pem(public_spki_der: bytes) -> bytes:
|
||||
public_key = load_der_public_key(public_spki_der)
|
||||
return public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
@@ -1,125 +0,0 @@
|
||||
"""
|
||||
Code-signing helpers for delivering scripts to agents.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from Modules.runtime import (
|
||||
ensure_server_certificates_dir,
|
||||
server_certificates_path,
|
||||
runtime_path,
|
||||
)
|
||||
|
||||
from .keys import base64_from_spki_der
|
||||
|
||||
_KEY_DIR = server_certificates_path("Code-Signing")
|
||||
_SIGNING_KEY_FILE = _KEY_DIR / "borealis-script-ed25519.key"
|
||||
_SIGNING_PUB_FILE = _KEY_DIR / "borealis-script-ed25519.pub"
|
||||
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-script-ed25519.key"
|
||||
_LEGACY_PUB_FILE = runtime_path("keys") / "borealis-script-ed25519.pub"
|
||||
_OLD_RUNTIME_KEY_DIR = runtime_path("script_signing_keys")
|
||||
_OLD_RUNTIME_KEY_FILE = _OLD_RUNTIME_KEY_DIR / "borealis-script-ed25519.key"
|
||||
_OLD_RUNTIME_PUB_FILE = _OLD_RUNTIME_KEY_DIR / "borealis-script-ed25519.pub"
|
||||
|
||||
|
||||
class ScriptSigner:
|
||||
def __init__(self, private_key: ed25519.Ed25519PrivateKey):
|
||||
self._private = private_key
|
||||
self._public = private_key.public_key()
|
||||
|
||||
def sign(self, payload: bytes) -> bytes:
|
||||
return self._private.sign(payload)
|
||||
|
||||
def public_spki_der(self) -> bytes:
|
||||
return self._public.public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def public_base64_spki(self) -> str:
|
||||
return base64_from_spki_der(self.public_spki_der())
|
||||
|
||||
|
||||
def load_signer() -> ScriptSigner:
|
||||
private_key = _load_or_create()
|
||||
return ScriptSigner(private_key)
|
||||
|
||||
|
||||
def _load_or_create() -> ed25519.Ed25519PrivateKey:
|
||||
ensure_server_certificates_dir("Code-Signing")
|
||||
_migrate_legacy_material_if_present()
|
||||
|
||||
if _SIGNING_KEY_FILE.exists():
|
||||
with _SIGNING_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
if _LEGACY_KEY_FILE.exists():
|
||||
with _LEGACY_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
with _SIGNING_KEY_FILE.open("wb") as fh:
|
||||
fh.write(pem)
|
||||
try:
|
||||
if hasattr(_SIGNING_KEY_FILE, "chmod"):
|
||||
_SIGNING_KEY_FILE.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
pub_der = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
_SIGNING_PUB_FILE.write_bytes(pub_der)
|
||||
|
||||
return private_key
|
||||
|
||||
|
||||
def _migrate_legacy_material_if_present() -> None:
|
||||
if _SIGNING_KEY_FILE.exists():
|
||||
return
|
||||
|
||||
# First migrate from legacy runtime path embedded in Server runtime.
|
||||
try:
|
||||
if _OLD_RUNTIME_KEY_FILE.exists() and not _SIGNING_KEY_FILE.exists():
|
||||
ensure_server_certificates_dir("Code-Signing")
|
||||
try:
|
||||
_OLD_RUNTIME_KEY_FILE.replace(_SIGNING_KEY_FILE)
|
||||
except Exception:
|
||||
_SIGNING_KEY_FILE.write_bytes(_OLD_RUNTIME_KEY_FILE.read_bytes())
|
||||
if _OLD_RUNTIME_PUB_FILE.exists() and not _SIGNING_PUB_FILE.exists():
|
||||
try:
|
||||
_OLD_RUNTIME_PUB_FILE.replace(_SIGNING_PUB_FILE)
|
||||
except Exception:
|
||||
_SIGNING_PUB_FILE.write_bytes(_OLD_RUNTIME_PUB_FILE.read_bytes())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not _LEGACY_KEY_FILE.exists() or _SIGNING_KEY_FILE.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
ensure_server_certificates_dir("Code-Signing")
|
||||
try:
|
||||
_LEGACY_KEY_FILE.replace(_SIGNING_KEY_FILE)
|
||||
except Exception:
|
||||
_SIGNING_KEY_FILE.write_bytes(_LEGACY_KEY_FILE.read_bytes())
|
||||
|
||||
if _LEGACY_PUB_FILE.exists() and not _SIGNING_PUB_FILE.exists():
|
||||
try:
|
||||
_LEGACY_PUB_FILE.replace(_SIGNING_PUB_FILE)
|
||||
except Exception:
|
||||
_SIGNING_PUB_FILE.write_bytes(_LEGACY_PUB_FILE.read_bytes())
|
||||
except Exception:
|
||||
return
|
||||
@@ -1,488 +0,0 @@
|
||||
"""
|
||||
Database migration helpers for Borealis.
|
||||
|
||||
This module centralises schema evolution so the main server module can stay
|
||||
focused on request handling. The migration functions are intentionally
|
||||
idempotent — they can run repeatedly without changing state once the schema
|
||||
matches the desired shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
|
||||
DEVICE_TABLE = "devices"
|
||||
|
||||
|
||||
def apply_all(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Run all known schema migrations against the provided sqlite3 connection.
|
||||
"""
|
||||
|
||||
_ensure_devices_table(conn)
|
||||
_ensure_device_aux_tables(conn)
|
||||
_ensure_refresh_token_table(conn)
|
||||
_ensure_install_code_table(conn)
|
||||
_ensure_install_code_persistence_table(conn)
|
||||
_ensure_device_approval_table(conn)
|
||||
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _ensure_devices_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
if not _table_exists(cur, DEVICE_TABLE):
|
||||
_create_devices_table(cur)
|
||||
return
|
||||
|
||||
column_info = _table_info(cur, DEVICE_TABLE)
|
||||
col_names = [c[1] for c in column_info]
|
||||
pk_cols = [c[1] for c in column_info if c[5]]
|
||||
|
||||
needs_rebuild = pk_cols != ["guid"]
|
||||
required_columns = {
|
||||
"guid": "TEXT",
|
||||
"hostname": "TEXT",
|
||||
"description": "TEXT",
|
||||
"created_at": "INTEGER",
|
||||
"agent_hash": "TEXT",
|
||||
"memory": "TEXT",
|
||||
"network": "TEXT",
|
||||
"software": "TEXT",
|
||||
"storage": "TEXT",
|
||||
"cpu": "TEXT",
|
||||
"device_type": "TEXT",
|
||||
"domain": "TEXT",
|
||||
"external_ip": "TEXT",
|
||||
"internal_ip": "TEXT",
|
||||
"last_reboot": "TEXT",
|
||||
"last_seen": "INTEGER",
|
||||
"last_user": "TEXT",
|
||||
"operating_system": "TEXT",
|
||||
"uptime": "INTEGER",
|
||||
"agent_id": "TEXT",
|
||||
"ansible_ee_ver": "TEXT",
|
||||
"connection_type": "TEXT",
|
||||
"connection_endpoint": "TEXT",
|
||||
"ssl_key_fingerprint": "TEXT",
|
||||
"token_version": "INTEGER",
|
||||
"status": "TEXT",
|
||||
"key_added_at": "TEXT",
|
||||
}
|
||||
|
||||
missing_columns = [col for col in required_columns if col not in col_names]
|
||||
if missing_columns:
|
||||
needs_rebuild = True
|
||||
|
||||
if needs_rebuild:
|
||||
_rebuild_devices_table(conn, column_info)
|
||||
else:
|
||||
_ensure_column_defaults(cur)
|
||||
|
||||
_ensure_device_indexes(cur)
|
||||
|
||||
|
||||
def _ensure_device_aux_tables(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS device_keys (
|
||||
id TEXT PRIMARY KEY,
|
||||
guid TEXT NOT NULL,
|
||||
ssl_key_fingerprint TEXT NOT NULL,
|
||||
added_at TEXT NOT NULL,
|
||||
retired_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uq_device_keys_guid_fingerprint
|
||||
ON device_keys(guid, ssl_key_fingerprint)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_device_keys_guid
|
||||
ON device_keys(guid)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_refresh_token_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
id TEXT PRIMARY KEY,
|
||||
guid TEXT NOT NULL,
|
||||
token_hash TEXT NOT NULL,
|
||||
dpop_jkt TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
revoked_at TEXT,
|
||||
last_used_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_guid
|
||||
ON refresh_tokens(guid)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at
|
||||
ON refresh_tokens(expires_at)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS enrollment_install_codes (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL UNIQUE,
|
||||
expires_at TEXT NOT NULL,
|
||||
created_by_user_id TEXT,
|
||||
used_at TEXT,
|
||||
used_by_guid TEXT,
|
||||
max_uses INTEGER NOT NULL DEFAULT 1,
|
||||
use_count INTEGER NOT NULL DEFAULT 0,
|
||||
last_used_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_eic_expires_at
|
||||
ON enrollment_install_codes(expires_at)
|
||||
"""
|
||||
)
|
||||
|
||||
columns = {row[1] for row in _table_info(cur, "enrollment_install_codes")}
|
||||
if "max_uses" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes
|
||||
ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 1
|
||||
"""
|
||||
)
|
||||
if "use_count" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes
|
||||
ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0
|
||||
"""
|
||||
)
|
||||
if "last_used_at" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes
|
||||
ADD COLUMN last_used_at TEXT
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_install_code_persistence_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS enrollment_install_codes_persistent (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL UNIQUE,
|
||||
created_at TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
created_by_user_id TEXT,
|
||||
used_at TEXT,
|
||||
used_by_guid TEXT,
|
||||
max_uses INTEGER NOT NULL DEFAULT 1,
|
||||
last_known_use_count INTEGER NOT NULL DEFAULT 0,
|
||||
last_used_at TEXT,
|
||||
is_active INTEGER NOT NULL DEFAULT 1,
|
||||
archived_at TEXT,
|
||||
consumed_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_eicp_active
|
||||
ON enrollment_install_codes_persistent(is_active, expires_at)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uq_eicp_code
|
||||
ON enrollment_install_codes_persistent(code)
|
||||
"""
|
||||
)
|
||||
|
||||
columns = {row[1] for row in _table_info(cur, "enrollment_install_codes_persistent")}
|
||||
if "last_known_use_count" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes_persistent
|
||||
ADD COLUMN last_known_use_count INTEGER NOT NULL DEFAULT 0
|
||||
"""
|
||||
)
|
||||
if "archived_at" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes_persistent
|
||||
ADD COLUMN archived_at TEXT
|
||||
"""
|
||||
)
|
||||
if "consumed_at" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes_persistent
|
||||
ADD COLUMN consumed_at TEXT
|
||||
"""
|
||||
)
|
||||
if "is_active" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes_persistent
|
||||
ADD COLUMN is_active INTEGER NOT NULL DEFAULT 1
|
||||
"""
|
||||
)
|
||||
if "used_at" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes_persistent
|
||||
ADD COLUMN used_at TEXT
|
||||
"""
|
||||
)
|
||||
if "used_by_guid" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes_persistent
|
||||
ADD COLUMN used_by_guid TEXT
|
||||
"""
|
||||
)
|
||||
if "last_used_at" not in columns:
|
||||
cur.execute(
|
||||
"""
|
||||
ALTER TABLE enrollment_install_codes_persistent
|
||||
ADD COLUMN last_used_at TEXT
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS device_approvals (
|
||||
id TEXT PRIMARY KEY,
|
||||
approval_reference TEXT NOT NULL UNIQUE,
|
||||
guid TEXT,
|
||||
hostname_claimed TEXT NOT NULL,
|
||||
ssl_key_fingerprint_claimed TEXT NOT NULL,
|
||||
enrollment_code_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
client_nonce TEXT NOT NULL,
|
||||
server_nonce TEXT NOT NULL,
|
||||
agent_pubkey_der BLOB NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
approved_by_user_id TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_da_status
|
||||
ON device_approvals(status)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_da_fp_status
|
||||
ON device_approvals(ssl_key_fingerprint_claimed, status)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _create_devices_table(cur: sqlite3.Cursor) -> None:
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE devices (
|
||||
guid TEXT PRIMARY KEY,
|
||||
hostname TEXT,
|
||||
description TEXT,
|
||||
created_at INTEGER,
|
||||
agent_hash TEXT,
|
||||
memory TEXT,
|
||||
network TEXT,
|
||||
software TEXT,
|
||||
storage TEXT,
|
||||
cpu TEXT,
|
||||
device_type TEXT,
|
||||
domain TEXT,
|
||||
external_ip TEXT,
|
||||
internal_ip TEXT,
|
||||
last_reboot TEXT,
|
||||
last_seen INTEGER,
|
||||
last_user TEXT,
|
||||
operating_system TEXT,
|
||||
uptime INTEGER,
|
||||
agent_id TEXT,
|
||||
ansible_ee_ver TEXT,
|
||||
connection_type TEXT,
|
||||
connection_endpoint TEXT,
|
||||
ssl_key_fingerprint TEXT,
|
||||
token_version INTEGER DEFAULT 1,
|
||||
status TEXT DEFAULT 'active',
|
||||
key_added_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
_ensure_device_indexes(cur)
|
||||
|
||||
|
||||
def _ensure_device_indexes(cur: sqlite3.Cursor) -> None:
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uq_devices_hostname
|
||||
ON devices(hostname)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_devices_ssl_key
|
||||
ON devices(ssl_key_fingerprint)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_devices_status
|
||||
ON devices(status)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_column_defaults(cur: sqlite3.Cursor) -> None:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE devices
|
||||
SET token_version = COALESCE(token_version, 1)
|
||||
WHERE token_version IS NULL
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE devices
|
||||
SET status = COALESCE(status, 'active')
|
||||
WHERE status IS NULL OR status = ''
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _rebuild_devices_table(conn: sqlite3.Connection, column_info: Sequence[Tuple]) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute("PRAGMA foreign_keys=OFF")
|
||||
cur.execute("BEGIN IMMEDIATE")
|
||||
|
||||
cur.execute("ALTER TABLE devices RENAME TO devices_legacy")
|
||||
_create_devices_table(cur)
|
||||
|
||||
legacy_columns = [c[1] for c in column_info]
|
||||
cur.execute(f"SELECT {', '.join(legacy_columns)} FROM devices_legacy")
|
||||
rows = cur.fetchall()
|
||||
|
||||
insert_sql = (
|
||||
"""
|
||||
INSERT OR REPLACE INTO devices (
|
||||
guid, hostname, description, created_at, agent_hash, memory,
|
||||
network, software, storage, cpu, device_type, domain, external_ip,
|
||||
internal_ip, last_reboot, last_seen, last_user, operating_system,
|
||||
uptime, agent_id, ansible_ee_ver, connection_type, connection_endpoint,
|
||||
ssl_key_fingerprint, token_version, status, key_added_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
record = dict(zip(legacy_columns, row))
|
||||
guid = _normalized_guid(record.get("guid"))
|
||||
if not guid:
|
||||
guid = str(uuid.uuid4())
|
||||
hostname = record.get("hostname")
|
||||
created_at = record.get("created_at")
|
||||
key_added_at = record.get("key_added_at")
|
||||
if key_added_at is None:
|
||||
key_added_at = _default_key_added_at(created_at)
|
||||
|
||||
params: Tuple = (
|
||||
guid,
|
||||
hostname,
|
||||
record.get("description"),
|
||||
created_at,
|
||||
record.get("agent_hash"),
|
||||
record.get("memory"),
|
||||
record.get("network"),
|
||||
record.get("software"),
|
||||
record.get("storage"),
|
||||
record.get("cpu"),
|
||||
record.get("device_type"),
|
||||
record.get("domain"),
|
||||
record.get("external_ip"),
|
||||
record.get("internal_ip"),
|
||||
record.get("last_reboot"),
|
||||
record.get("last_seen"),
|
||||
record.get("last_user"),
|
||||
record.get("operating_system"),
|
||||
record.get("uptime"),
|
||||
record.get("agent_id"),
|
||||
record.get("ansible_ee_ver"),
|
||||
record.get("connection_type"),
|
||||
record.get("connection_endpoint"),
|
||||
record.get("ssl_key_fingerprint"),
|
||||
record.get("token_version") or 1,
|
||||
record.get("status") or "active",
|
||||
key_added_at,
|
||||
)
|
||||
cur.execute(insert_sql, params)
|
||||
|
||||
cur.execute("DROP TABLE devices_legacy")
|
||||
cur.execute("COMMIT")
|
||||
cur.execute("PRAGMA foreign_keys=ON")
|
||||
|
||||
|
||||
def _default_key_added_at(created_at: Optional[int]) -> Optional[str]:
|
||||
if created_at:
|
||||
try:
|
||||
dt = datetime.fromtimestamp(int(created_at), tz=timezone.utc)
|
||||
return dt.isoformat()
|
||||
except Exception:
|
||||
pass
|
||||
return datetime.now(tz=timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _table_exists(cur: sqlite3.Cursor, name: str) -> bool:
|
||||
cur.execute(
|
||||
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
|
||||
(name,),
|
||||
)
|
||||
return cur.fetchone() is not None
|
||||
|
||||
|
||||
def _table_info(cur: sqlite3.Cursor, name: str) -> List[Tuple]:
|
||||
cur.execute(f"PRAGMA table_info({name})")
|
||||
return cur.fetchall()
|
||||
|
||||
|
||||
def _normalized_guid(value: Optional[str]) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
return str(value).strip()
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
"""
|
||||
Short-lived nonce cache to defend against replay attacks during enrollment.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class NonceCache:
|
||||
def __init__(self, ttl_seconds: float = 300.0) -> None:
|
||||
self._ttl = ttl_seconds
|
||||
self._entries: Dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def consume(self, key: str) -> bool:
|
||||
"""
|
||||
Attempt to consume the nonce identified by `key`.
|
||||
|
||||
Returns True on first use within TTL, False if already consumed.
|
||||
"""
|
||||
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
expire_at = self._entries.get(key)
|
||||
if expire_at and expire_at > now:
|
||||
return False
|
||||
self._entries[key] = now + self._ttl
|
||||
# Opportunistic cleanup to keep the dict small
|
||||
stale = [nonce for nonce, expiry in self._entries.items() if expiry <= now]
|
||||
for nonce in stale:
|
||||
self._entries.pop(nonce, None)
|
||||
return True
|
||||
@@ -1,759 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import secrets
|
||||
import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||
|
||||
|
||||
def _canonical_context(value: Optional[str]) -> Optional[str]:
|
||||
if not value:
|
||||
return None
|
||||
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
|
||||
if not cleaned:
|
||||
return None
|
||||
return cleaned.upper()
|
||||
|
||||
from flask import Blueprint, jsonify, request
|
||||
|
||||
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||
from Modules.crypto import keys as crypto_keys
|
||||
from Modules.enrollment.nonce_store import NonceCache
|
||||
from Modules.guid_utils import normalize_guid
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
|
||||
def register(
|
||||
app,
|
||||
*,
|
||||
db_conn_factory: Callable[[], sqlite3.Connection],
|
||||
log: Callable[[str, str, Optional[str]], None],
|
||||
jwt_service,
|
||||
tls_bundle_path: str,
|
||||
ip_rate_limiter: SlidingWindowRateLimiter,
|
||||
fp_rate_limiter: SlidingWindowRateLimiter,
|
||||
nonce_cache: NonceCache,
|
||||
script_signer,
|
||||
) -> None:
|
||||
blueprint = Blueprint("enrollment", __name__)
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
|
||||
def _iso(dt: datetime) -> str:
|
||||
return dt.isoformat()
|
||||
|
||||
def _remote_addr() -> str:
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
addr = request.remote_addr or "unknown"
|
||||
return addr.strip()
|
||||
|
||||
def _signing_key_b64() -> str:
|
||||
if not script_signer:
|
||||
return ""
|
||||
try:
|
||||
return script_signer.public_base64_spki()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def _rate_limited(
|
||||
key: str,
|
||||
limiter: SlidingWindowRateLimiter,
|
||||
limit: int,
|
||||
window_s: float,
|
||||
context_hint: Optional[str],
|
||||
):
|
||||
decision = limiter.check(key, limit, window_s)
|
||||
if not decision.allowed:
|
||||
log(
|
||||
"server",
|
||||
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
|
||||
context_hint,
|
||||
)
|
||||
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
|
||||
response.status_code = 429
|
||||
response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}"
|
||||
return response
|
||||
return None
|
||||
|
||||
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id,
|
||||
code,
|
||||
expires_at,
|
||||
used_at,
|
||||
used_by_guid,
|
||||
max_uses,
|
||||
use_count,
|
||||
last_used_at
|
||||
FROM enrollment_install_codes
|
||||
WHERE code = ?
|
||||
""",
|
||||
(code_value,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
keys = [
|
||||
"id",
|
||||
"code",
|
||||
"expires_at",
|
||||
"used_at",
|
||||
"used_by_guid",
|
||||
"max_uses",
|
||||
"use_count",
|
||||
"last_used_at",
|
||||
]
|
||||
record = dict(zip(keys, row))
|
||||
return record
|
||||
|
||||
def _install_code_valid(
|
||||
record: Dict[str, Any], fingerprint: str, cur: sqlite3.Cursor
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
if not record:
|
||||
return False, None
|
||||
expires_at = record.get("expires_at")
|
||||
if not isinstance(expires_at, str):
|
||||
return False, None
|
||||
try:
|
||||
expiry = datetime.fromisoformat(expires_at)
|
||||
except Exception:
|
||||
return False, None
|
||||
if expiry <= _now():
|
||||
return False, None
|
||||
try:
|
||||
max_uses = int(record.get("max_uses") or 1)
|
||||
except Exception:
|
||||
max_uses = 1
|
||||
if max_uses < 1:
|
||||
max_uses = 1
|
||||
try:
|
||||
use_count = int(record.get("use_count") or 0)
|
||||
except Exception:
|
||||
use_count = 0
|
||||
if use_count < max_uses:
|
||||
return True, None
|
||||
|
||||
guid = normalize_guid(record.get("used_by_guid"))
|
||||
if not guid:
|
||||
return False, None
|
||||
cur.execute(
|
||||
"SELECT ssl_key_fingerprint FROM devices WHERE UPPER(guid) = ?",
|
||||
(guid,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return False, None
|
||||
stored_fp = (row[0] or "").strip().lower()
|
||||
if not stored_fp:
|
||||
return False, None
|
||||
if stored_fp == (fingerprint or "").strip().lower():
|
||||
return True, guid
|
||||
return False, None
|
||||
|
||||
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
|
||||
guid_norm = normalize_guid(guid)
|
||||
base = (hostname or "").strip() or guid_norm
|
||||
base = base[:253]
|
||||
candidate = base
|
||||
suffix = 1
|
||||
while True:
|
||||
cur.execute(
|
||||
"SELECT guid FROM devices WHERE hostname = ?",
|
||||
(candidate,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return candidate
|
||||
existing_guid = normalize_guid(row[0])
|
||||
if existing_guid == guid_norm:
|
||||
return candidate
|
||||
candidate = f"{base}-{suffix}"
|
||||
suffix += 1
|
||||
if suffix > 50:
|
||||
return guid_norm
|
||||
|
||||
def _store_device_key(cur: sqlite3.Cursor, guid: str, fingerprint: str) -> None:
|
||||
guid_norm = normalize_guid(guid)
|
||||
added_at = _iso(_now())
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(str(uuid.uuid4()), guid_norm, fingerprint, added_at),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_keys
|
||||
SET retired_at = ?
|
||||
WHERE guid = ?
|
||||
AND ssl_key_fingerprint != ?
|
||||
AND retired_at IS NULL
|
||||
""",
|
||||
(_iso(_now()), guid_norm, fingerprint),
|
||||
)
|
||||
|
||||
def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
|
||||
guid_norm = normalize_guid(guid)
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
|
||||
FROM devices
|
||||
WHERE UPPER(guid) = ?
|
||||
""",
|
||||
(guid_norm,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
keys = [
|
||||
"guid",
|
||||
"hostname",
|
||||
"token_version",
|
||||
"status",
|
||||
"ssl_key_fingerprint",
|
||||
"key_added_at",
|
||||
]
|
||||
record = dict(zip(keys, row))
|
||||
record["guid"] = normalize_guid(record.get("guid"))
|
||||
stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower()
|
||||
new_fp = (fingerprint or "").strip().lower()
|
||||
if not stored_fp and new_fp:
|
||||
cur.execute(
|
||||
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
|
||||
(fingerprint, _iso(_now()), record["guid"]),
|
||||
)
|
||||
record["ssl_key_fingerprint"] = fingerprint
|
||||
elif new_fp and stored_fp != new_fp:
|
||||
now_iso = _iso(_now())
|
||||
try:
|
||||
current_version = int(record.get("token_version") or 1)
|
||||
except Exception:
|
||||
current_version = 1
|
||||
new_version = max(current_version + 1, 1)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE devices
|
||||
SET ssl_key_fingerprint = ?,
|
||||
key_added_at = ?,
|
||||
token_version = ?,
|
||||
status = 'active'
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(fingerprint, now_iso, new_version, record["guid"]),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = ?
|
||||
WHERE guid = ?
|
||||
AND revoked_at IS NULL
|
||||
""",
|
||||
(now_iso, record["guid"]),
|
||||
)
|
||||
record["ssl_key_fingerprint"] = fingerprint
|
||||
record["token_version"] = new_version
|
||||
record["status"] = "active"
|
||||
record["key_added_at"] = now_iso
|
||||
return record
|
||||
|
||||
resolved_hostname = _normalize_host(hostname, guid_norm, cur)
|
||||
created_at = int(time.time())
|
||||
key_added_at = _iso(_now())
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO devices (
|
||||
guid, hostname, created_at, last_seen, ssl_key_fingerprint,
|
||||
token_version, status, key_added_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
|
||||
""",
|
||||
(
|
||||
guid_norm,
|
||||
resolved_hostname,
|
||||
created_at,
|
||||
created_at,
|
||||
fingerprint,
|
||||
key_added_at,
|
||||
),
|
||||
)
|
||||
return {
|
||||
"guid": guid_norm,
|
||||
"hostname": resolved_hostname,
|
||||
"token_version": 1,
|
||||
"status": "active",
|
||||
"ssl_key_fingerprint": fingerprint,
|
||||
"key_added_at": key_added_at,
|
||||
}
|
||||
|
||||
def _hash_refresh_token(token: str) -> str:
|
||||
import hashlib
|
||||
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
def _issue_refresh_token(cur: sqlite3.Cursor, guid: str) -> Dict[str, Any]:
|
||||
token = secrets.token_urlsafe(48)
|
||||
now = _now()
|
||||
expires_at = now.replace(microsecond=0) + timedelta(days=30)
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO refresh_tokens (id, guid, token_hash, created_at, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
str(uuid.uuid4()),
|
||||
guid,
|
||||
_hash_refresh_token(token),
|
||||
_iso(now),
|
||||
_iso(expires_at),
|
||||
),
|
||||
)
|
||||
return {"token": token, "expires_at": expires_at}
|
||||
|
||||
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
|
||||
def enrollment_request():
|
||||
remote = _remote_addr()
|
||||
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
|
||||
|
||||
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0, context_hint)
|
||||
if rate_error:
|
||||
return rate_error
|
||||
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
hostname = str(payload.get("hostname") or "").strip()
|
||||
enrollment_code = str(payload.get("enrollment_code") or "").strip()
|
||||
agent_pubkey_b64 = payload.get("agent_pubkey")
|
||||
client_nonce_b64 = payload.get("client_nonce")
|
||||
|
||||
log(
|
||||
"server",
|
||||
"enrollment request received "
|
||||
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
|
||||
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
|
||||
context_hint,
|
||||
)
|
||||
|
||||
if not hostname:
|
||||
log("server", f"enrollment rejected missing_hostname ip={remote}", context_hint)
|
||||
return jsonify({"error": "hostname_required"}), 400
|
||||
if not enrollment_code:
|
||||
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "enrollment_code_required"}), 400
|
||||
if not isinstance(agent_pubkey_b64, str):
|
||||
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "agent_pubkey_required"}), 400
|
||||
if not isinstance(client_nonce_b64, str):
|
||||
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "client_nonce_required"}), 400
|
||||
|
||||
try:
|
||||
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
|
||||
except Exception:
|
||||
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||
|
||||
if len(agent_pubkey_der) < 10:
|
||||
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||
|
||||
try:
|
||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
if len(client_nonce_bytes) < 16:
|
||||
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
|
||||
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
|
||||
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0, context_hint)
|
||||
if rate_error:
|
||||
return rate_error
|
||||
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
install_code = _load_install_code(cur, enrollment_code)
|
||||
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
|
||||
if not valid_code:
|
||||
log(
|
||||
"server",
|
||||
"enrollment request invalid_code "
|
||||
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"error": "invalid_enrollment_code"}), 400
|
||||
|
||||
approval_reference: str
|
||||
record_id: str
|
||||
server_nonce_bytes = secrets.token_bytes(32)
|
||||
server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii")
|
||||
now = _iso(_now())
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, approval_reference
|
||||
FROM device_approvals
|
||||
WHERE ssl_key_fingerprint_claimed = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(fingerprint,),
|
||||
)
|
||||
existing = cur.fetchone()
|
||||
if existing:
|
||||
record_id = existing[0]
|
||||
approval_reference = existing[1]
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET hostname_claimed = ?,
|
||||
guid = ?,
|
||||
enrollment_code_id = ?,
|
||||
client_nonce = ?,
|
||||
server_nonce = ?,
|
||||
agent_pubkey_der = ?,
|
||||
updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
hostname,
|
||||
reuse_guid,
|
||||
install_code["id"],
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
now,
|
||||
record_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
record_id = str(uuid.uuid4())
|
||||
approval_reference = str(uuid.uuid4())
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO device_approvals (
|
||||
id, approval_reference, guid, hostname_claimed,
|
||||
ssl_key_fingerprint_claimed, enrollment_code_id,
|
||||
status, client_nonce, server_nonce, agent_pubkey_der,
|
||||
created_at, updated_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
record_id,
|
||||
approval_reference,
|
||||
reuse_guid,
|
||||
hostname,
|
||||
fingerprint,
|
||||
install_code["id"],
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
now,
|
||||
now,
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
response = {
|
||||
"status": "pending",
|
||||
"approval_reference": approval_reference,
|
||||
"server_nonce": server_nonce_b64,
|
||||
"poll_after_ms": 3000,
|
||||
"server_certificate": _load_tls_bundle(tls_bundle_path),
|
||||
"signing_key": _signing_key_b64(),
|
||||
}
|
||||
log(
|
||||
"server",
|
||||
f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify(response)
|
||||
|
||||
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
|
||||
def enrollment_poll():
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
approval_reference = payload.get("approval_reference")
|
||||
client_nonce_b64 = payload.get("client_nonce")
|
||||
proof_sig_b64 = payload.get("proof_sig")
|
||||
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
|
||||
|
||||
log(
|
||||
"server",
|
||||
"enrollment poll received "
|
||||
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
|
||||
f" proof_sig_len={len(proof_sig_b64 or '')}",
|
||||
context_hint,
|
||||
)
|
||||
|
||||
if not isinstance(approval_reference, str) or not approval_reference:
|
||||
log("server", "enrollment poll rejected missing_reference", context_hint)
|
||||
return jsonify({"error": "approval_reference_required"}), 400
|
||||
if not isinstance(client_nonce_b64, str):
|
||||
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "client_nonce_required"}), 400
|
||||
if not isinstance(proof_sig_b64, str):
|
||||
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "proof_sig_required"}), 400
|
||||
|
||||
try:
|
||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
|
||||
try:
|
||||
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "invalid_proof_sig"}), 400
|
||||
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, guid, hostname_claimed, ssl_key_fingerprint_claimed,
|
||||
enrollment_code_id, status, client_nonce, server_nonce,
|
||||
agent_pubkey_der, created_at, updated_at, approved_by_user_id
|
||||
FROM device_approvals
|
||||
WHERE approval_reference = ?
|
||||
""",
|
||||
(approval_reference,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
log("server", f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
|
||||
return jsonify({"status": "unknown"}), 404
|
||||
|
||||
(
|
||||
record_id,
|
||||
guid,
|
||||
hostname_claimed,
|
||||
fingerprint,
|
||||
enrollment_code_id,
|
||||
status,
|
||||
client_nonce_stored,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
created_at,
|
||||
updated_at,
|
||||
approved_by,
|
||||
) = row
|
||||
|
||||
if client_nonce_stored != client_nonce_b64:
|
||||
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "nonce_mismatch"}), 400
|
||||
|
||||
try:
|
||||
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "server_nonce_invalid"}), 400
|
||||
|
||||
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
|
||||
|
||||
try:
|
||||
public_key = serialization.load_der_public_key(agent_pubkey_der)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
|
||||
public_key = None
|
||||
|
||||
if public_key is None:
|
||||
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "agent_pubkey_invalid"}), 400
|
||||
|
||||
try:
|
||||
public_key.verify(proof_sig, message)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
|
||||
return jsonify({"error": "invalid_proof"}), 400
|
||||
|
||||
if status == "pending":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
|
||||
f" fingerprint={fingerprint[:12]}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": "pending", "poll_after_ms": 5000})
|
||||
if status == "denied":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": "denied", "reason": "operator_denied"})
|
||||
if status == "expired":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": "expired"})
|
||||
if status == "completed":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": "approved", "detail": "finalized"})
|
||||
|
||||
if status != "approved":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll unexpected_status={status} ref={approval_reference}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"status": status or "unknown"}), 400
|
||||
|
||||
nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
|
||||
if not nonce_cache.consume(nonce_key):
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify({"error": "proof_replayed"}), 409
|
||||
|
||||
# Finalize enrollment
|
||||
effective_guid = normalize_guid(guid) if guid else normalize_guid(str(uuid.uuid4()))
|
||||
now_iso = _iso(_now())
|
||||
|
||||
device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint)
|
||||
_store_device_key(cur, effective_guid, fingerprint)
|
||||
|
||||
# Mark install code used
|
||||
if enrollment_code_id:
|
||||
cur.execute(
|
||||
"SELECT use_count, max_uses FROM enrollment_install_codes WHERE id = ?",
|
||||
(enrollment_code_id,),
|
||||
)
|
||||
usage_row = cur.fetchone()
|
||||
try:
|
||||
prior_count = int(usage_row[0]) if usage_row else 0
|
||||
except Exception:
|
||||
prior_count = 0
|
||||
try:
|
||||
allowed_uses = int(usage_row[1]) if usage_row else 1
|
||||
except Exception:
|
||||
allowed_uses = 1
|
||||
if allowed_uses < 1:
|
||||
allowed_uses = 1
|
||||
new_count = prior_count + 1
|
||||
consumed = new_count >= allowed_uses
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE enrollment_install_codes
|
||||
SET use_count = ?,
|
||||
used_by_guid = ?,
|
||||
last_used_at = ?,
|
||||
used_at = CASE WHEN ? THEN ? ELSE used_at END
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
new_count,
|
||||
effective_guid,
|
||||
now_iso,
|
||||
1 if consumed else 0,
|
||||
now_iso,
|
||||
enrollment_code_id,
|
||||
),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE enrollment_install_codes_persistent
|
||||
SET last_known_use_count = ?,
|
||||
used_by_guid = ?,
|
||||
last_used_at = ?,
|
||||
used_at = CASE WHEN ? THEN ? ELSE used_at END,
|
||||
is_active = CASE WHEN ? THEN 0 ELSE is_active END,
|
||||
consumed_at = CASE WHEN ? THEN COALESCE(consumed_at, ?) ELSE consumed_at END,
|
||||
archived_at = CASE WHEN ? THEN COALESCE(archived_at, ?) ELSE archived_at END
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
new_count,
|
||||
effective_guid,
|
||||
now_iso,
|
||||
1 if consumed else 0,
|
||||
now_iso,
|
||||
1 if consumed else 0,
|
||||
1 if consumed else 0,
|
||||
now_iso,
|
||||
1 if consumed else 0,
|
||||
now_iso,
|
||||
enrollment_code_id,
|
||||
),
|
||||
)
|
||||
|
||||
# Update approval record with final state
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET guid = ?,
|
||||
status = 'completed',
|
||||
updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(effective_guid, now_iso, record_id),
|
||||
)
|
||||
|
||||
refresh_info = _issue_refresh_token(cur, effective_guid)
|
||||
access_token = jwt_service.issue_access_token(
|
||||
effective_guid,
|
||||
fingerprint,
|
||||
device_record.get("token_version") or 1,
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
log(
|
||||
"server",
|
||||
f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
|
||||
context_hint,
|
||||
)
|
||||
return jsonify(
|
||||
{
|
||||
"status": "approved",
|
||||
"guid": effective_guid,
|
||||
"access_token": access_token,
|
||||
"expires_in": 900,
|
||||
"refresh_token": refresh_info["token"],
|
||||
"token_type": "Bearer",
|
||||
"server_certificate": _load_tls_bundle(tls_bundle_path),
|
||||
"signing_key": _signing_key_b64(),
|
||||
}
|
||||
)
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
def _load_tls_bundle(path: str) -> str:
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as fh:
|
||||
return fh.read()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _mask_code(code: str) -> str:
|
||||
if not code:
|
||||
return "<missing>"
|
||||
trimmed = str(code).strip()
|
||||
if len(trimmed) <= 6:
|
||||
return "***"
|
||||
return f"{trimmed[:3]}***{trimmed[-3:]}"
|
||||
@@ -1,26 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def normalize_guid(value: Optional[str]) -> str:
|
||||
"""
|
||||
Canonicalize GUID strings so the server treats different casings/formats uniformly.
|
||||
"""
|
||||
candidate = (value or "").strip()
|
||||
if not candidate:
|
||||
return ""
|
||||
candidate = candidate.strip("{}")
|
||||
try:
|
||||
return str(uuid.UUID(candidate)).upper()
|
||||
except Exception:
|
||||
cleaned = "".join(ch for ch in candidate if ch in string.hexdigits or ch == "-")
|
||||
cleaned = cleaned.strip("-")
|
||||
if cleaned:
|
||||
try:
|
||||
return str(uuid.UUID(cleaned)).upper()
|
||||
except Exception:
|
||||
pass
|
||||
return candidate.upper()
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import eventlet
|
||||
from flask_socketio import SocketIO
|
||||
|
||||
|
||||
def start_prune_job(
|
||||
socketio: SocketIO,
|
||||
*,
|
||||
db_conn_factory: Callable[[], any],
|
||||
log: Callable[[str, str, Optional[str]], None],
|
||||
) -> None:
|
||||
def _job_loop():
|
||||
while True:
|
||||
try:
|
||||
_run_once(db_conn_factory, log)
|
||||
except Exception as exc:
|
||||
log("server", f"prune job failure: {exc}")
|
||||
eventlet.sleep(24 * 60 * 60)
|
||||
|
||||
socketio.start_background_task(_job_loop)
|
||||
|
||||
|
||||
def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str, Optional[str]], None]) -> None:
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
now_iso = now.isoformat()
|
||||
stale_before = (now - timedelta(hours=24)).isoformat()
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
persistent_table_exists = False
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT 1 FROM sqlite_master WHERE type='table' AND name='enrollment_install_codes_persistent'"
|
||||
)
|
||||
persistent_table_exists = cur.fetchone() is not None
|
||||
except Exception:
|
||||
persistent_table_exists = False
|
||||
|
||||
expired_ids: List[str] = []
|
||||
if persistent_table_exists:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id
|
||||
FROM enrollment_install_codes
|
||||
WHERE use_count = 0
|
||||
AND expires_at < ?
|
||||
""",
|
||||
(now_iso,),
|
||||
)
|
||||
expired_ids = [str(row[0]) for row in cur.fetchall() if row and row[0]]
|
||||
cur.execute(
|
||||
"""
|
||||
DELETE FROM enrollment_install_codes
|
||||
WHERE use_count = 0
|
||||
AND expires_at < ?
|
||||
""",
|
||||
(now_iso,),
|
||||
)
|
||||
codes_pruned = cur.rowcount or 0
|
||||
if expired_ids:
|
||||
placeholders = ",".join("?" for _ in expired_ids)
|
||||
try:
|
||||
cur.execute(
|
||||
f"""
|
||||
UPDATE enrollment_install_codes_persistent
|
||||
SET is_active = 0,
|
||||
archived_at = COALESCE(archived_at, ?)
|
||||
WHERE id IN ({placeholders})
|
||||
""",
|
||||
(now_iso, *expired_ids),
|
||||
)
|
||||
except Exception:
|
||||
# Best-effort archival; continue if the persistence table is absent.
|
||||
pass
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET status = 'expired',
|
||||
updated_at = ?
|
||||
WHERE status = 'pending'
|
||||
AND (
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM enrollment_install_codes c
|
||||
WHERE c.id = device_approvals.enrollment_code_id
|
||||
AND (
|
||||
c.expires_at < ?
|
||||
OR c.use_count >= c.max_uses
|
||||
)
|
||||
)
|
||||
OR created_at < ?
|
||||
)
|
||||
""",
|
||||
(now_iso, now_iso, stale_before),
|
||||
)
|
||||
approvals_marked = cur.rowcount or 0
|
||||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if codes_pruned:
|
||||
log("server", f"prune job removed {codes_pruned} expired enrollment codes")
|
||||
if approvals_marked:
|
||||
log("server", f"prune job expired {approvals_marked} device approvals")
|
||||
@@ -1,168 +0,0 @@
|
||||
"""Utility helpers for locating runtime storage paths.
|
||||
|
||||
The Borealis repository keeps the authoritative source code under ``Data/``
|
||||
so that the bootstrap scripts can copy those assets into sibling ``Server/``
|
||||
and ``Agent/`` directories for execution. Runtime artefacts such as TLS
|
||||
certificates or signing keys must therefore live outside ``Data`` to avoid
|
||||
polluting the template tree. This module centralises the path selection so
|
||||
other modules can rely on a consistent location regardless of whether they
|
||||
are executed from the copied runtime directory or directly from ``Data``
|
||||
during development.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _env_path(name: str) -> Optional[Path]:
|
||||
"""Return a resolved ``Path`` for the given environment variable."""
|
||||
|
||||
value = os.environ.get(name)
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return Path(value).expanduser().resolve()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def project_root() -> Path:
|
||||
"""Best-effort detection of the repository root."""
|
||||
|
||||
env = _env_path("BOREALIS_PROJECT_ROOT")
|
||||
if env:
|
||||
return env
|
||||
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "Borealis.ps1").exists() or (parent / ".git").is_dir():
|
||||
return parent
|
||||
|
||||
# Fallback to the ancestor that corresponds to ``<repo>/`` when the module
|
||||
# lives under ``Data/Server/Modules``.
|
||||
try:
|
||||
return current.parents[4]
|
||||
except IndexError:
|
||||
return current.parent
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def server_runtime_root() -> Path:
|
||||
"""Location where the running server stores mutable artefacts."""
|
||||
|
||||
env = _env_path("BOREALIS_SERVER_ROOT")
|
||||
if env:
|
||||
return env
|
||||
|
||||
root = project_root()
|
||||
runtime = root / "Server" / "Borealis"
|
||||
return runtime
|
||||
|
||||
|
||||
def runtime_path(*parts: str) -> Path:
|
||||
"""Return a path relative to the server runtime root."""
|
||||
|
||||
return server_runtime_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_runtime_dir(*parts: str) -> Path:
|
||||
"""Create (if required) and return a runtime directory."""
|
||||
|
||||
path = runtime_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def certificates_root() -> Path:
|
||||
"""Base directory for persisted certificate material."""
|
||||
|
||||
env = _env_path("BOREALIS_CERTIFICATES_ROOT") or _env_path("BOREALIS_CERT_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = project_root() / "Certificates"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
# Ensure expected subdirectories exist for agent and server material.
|
||||
try:
|
||||
(root / "Server").mkdir(parents=True, exist_ok=True)
|
||||
(root / "Agent").mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
return root
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def server_certificates_root() -> Path:
|
||||
"""Base directory for server certificate material."""
|
||||
|
||||
env = _env_path("BOREALIS_SERVER_CERT_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = certificates_root() / "Server"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def agent_certificates_root() -> Path:
|
||||
"""Base directory for agent certificate material."""
|
||||
|
||||
env = _env_path("BOREALIS_AGENT_CERT_ROOT")
|
||||
if env:
|
||||
env.mkdir(parents=True, exist_ok=True)
|
||||
return env
|
||||
|
||||
root = certificates_root() / "Agent"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
def certificates_path(*parts: str) -> Path:
|
||||
"""Return a path under the certificates root."""
|
||||
|
||||
return certificates_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_certificates_dir(*parts: str) -> Path:
|
||||
"""Create (if required) and return a certificates subdirectory."""
|
||||
|
||||
path = certificates_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def server_certificates_path(*parts: str) -> Path:
|
||||
"""Return a path under the server certificates root."""
|
||||
|
||||
return server_certificates_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_server_certificates_dir(*parts: str) -> Path:
|
||||
"""Create (if required) and return a server certificates subdirectory."""
|
||||
|
||||
path = server_certificates_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def agent_certificates_path(*parts: str) -> Path:
|
||||
"""Return a path under the agent certificates root."""
|
||||
|
||||
return agent_certificates_root().joinpath(*parts)
|
||||
|
||||
|
||||
def ensure_agent_certificates_dir(*parts: str) -> Path:
|
||||
"""Create (if required) and return an agent certificates subdirectory."""
|
||||
|
||||
path = agent_certificates_path(*parts)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable
|
||||
|
||||
from flask import Blueprint, jsonify, request
|
||||
|
||||
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
|
||||
|
||||
|
||||
def register(
|
||||
app,
|
||||
*,
|
||||
db_conn_factory: Callable[[], sqlite3.Connection],
|
||||
jwt_service,
|
||||
dpop_validator: DPoPValidator,
|
||||
) -> None:
|
||||
blueprint = Blueprint("tokens", __name__)
|
||||
|
||||
def _hash_token(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
def _iso_now() -> str:
|
||||
return datetime.now(tz=timezone.utc).isoformat()
|
||||
|
||||
def _parse_iso(ts: str) -> datetime:
|
||||
return datetime.fromisoformat(ts)
|
||||
|
||||
@blueprint.route("/api/agent/token/refresh", methods=["POST"])
|
||||
def refresh():
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
guid = str(payload.get("guid") or "").strip()
|
||||
refresh_token = str(payload.get("refresh_token") or "").strip()
|
||||
|
||||
if not guid or not refresh_token:
|
||||
return jsonify({"error": "invalid_request"}), 400
|
||||
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at
|
||||
FROM refresh_tokens
|
||||
WHERE guid = ?
|
||||
AND token_hash = ?
|
||||
""",
|
||||
(guid, _hash_token(refresh_token)),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return jsonify({"error": "invalid_refresh_token"}), 401
|
||||
|
||||
record_id, row_guid, _token_hash, stored_jkt, created_at, expires_at, revoked_at = row
|
||||
if row_guid != guid:
|
||||
return jsonify({"error": "invalid_refresh_token"}), 401
|
||||
if revoked_at:
|
||||
return jsonify({"error": "refresh_token_revoked"}), 401
|
||||
if expires_at:
|
||||
try:
|
||||
if _parse_iso(expires_at) <= datetime.now(tz=timezone.utc):
|
||||
return jsonify({"error": "refresh_token_expired"}), 401
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(guid,),
|
||||
)
|
||||
device_row = cur.fetchone()
|
||||
if not device_row:
|
||||
return jsonify({"error": "device_not_found"}), 404
|
||||
|
||||
device_guid, fingerprint, token_version, status = device_row
|
||||
status_norm = (status or "active").strip().lower()
|
||||
if status_norm in {"revoked", "decommissioned"}:
|
||||
return jsonify({"error": "device_revoked"}), 403
|
||||
|
||||
dpop_proof = request.headers.get("DPoP")
|
||||
jkt = stored_jkt or ""
|
||||
if dpop_proof:
|
||||
try:
|
||||
jkt = dpop_validator.verify(request.method, request.url, dpop_proof, access_token=None)
|
||||
except DPoPReplayError:
|
||||
return jsonify({"error": "dpop_replayed"}), 400
|
||||
except DPoPVerificationError:
|
||||
return jsonify({"error": "dpop_invalid"}), 400
|
||||
elif stored_jkt:
|
||||
# The agent does not yet emit DPoP proofs; allow recovery by clearing
|
||||
# the stored binding so refreshes can succeed. This preserves
|
||||
# backward compatibility while the client gains full DPoP support.
|
||||
try:
|
||||
app.logger.warning(
|
||||
"Clearing stored DPoP binding for guid=%s due to missing proof",
|
||||
guid,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
cur.execute(
|
||||
"UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?",
|
||||
(record_id,),
|
||||
)
|
||||
|
||||
new_access_token = jwt_service.issue_access_token(
|
||||
guid,
|
||||
fingerprint or "",
|
||||
token_version or 1,
|
||||
)
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET last_used_at = ?,
|
||||
dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt)
|
||||
WHERE id = ?
|
||||
""",
|
||||
(_iso_now(), jkt, record_id),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"access_token": new_access_token,
|
||||
"expires_in": 900,
|
||||
"token_type": "Bearer",
|
||||
}
|
||||
)
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
Reference in New Issue
Block a user