mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 15:21:57 -06:00
feat: add agent enrollment endpoints and nonce protections
This commit is contained in:
35
Data/Server/Modules/enrollment/nonce_store.py
Normal file
35
Data/Server/Modules/enrollment/nonce_store.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Short-lived nonce cache to defend against replay attacks during enrollment.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class NonceCache:
|
||||
def __init__(self, ttl_seconds: float = 300.0) -> None:
|
||||
self._ttl = ttl_seconds
|
||||
self._entries: Dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def consume(self, key: str) -> bool:
|
||||
"""
|
||||
Attempt to consume the nonce identified by `key`.
|
||||
|
||||
Returns True on first use within TTL, False if already consumed.
|
||||
"""
|
||||
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
expire_at = self._entries.get(key)
|
||||
if expire_at and expire_at > now:
|
||||
return False
|
||||
self._entries[key] = now + self._ttl
|
||||
# Opportunistic cleanup to keep the dict small
|
||||
stale = [nonce for nonce, expiry in self._entries.items() if expiry <= now]
|
||||
for nonce in stale:
|
||||
self._entries.pop(nonce, None)
|
||||
return True
|
||||
478
Data/Server/Modules/enrollment/routes.py
Normal file
478
Data/Server/Modules/enrollment/routes.py
Normal file
@@ -0,0 +1,478 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import secrets
|
||||
import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from flask import Blueprint, jsonify, request
|
||||
|
||||
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||
from Modules.crypto import keys as crypto_keys
|
||||
from Modules.enrollment.nonce_store import NonceCache
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
|
||||
def register(
|
||||
app,
|
||||
*,
|
||||
db_conn_factory: Callable[[], sqlite3.Connection],
|
||||
log: Callable[[str, str], None],
|
||||
jwt_service,
|
||||
tls_bundle_path: str,
|
||||
ip_rate_limiter: SlidingWindowRateLimiter,
|
||||
fp_rate_limiter: SlidingWindowRateLimiter,
|
||||
nonce_cache: NonceCache,
|
||||
) -> None:
|
||||
blueprint = Blueprint("enrollment", __name__)
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
|
||||
def _iso(dt: datetime) -> str:
|
||||
return dt.isoformat()
|
||||
|
||||
def _remote_addr() -> str:
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
addr = request.remote_addr or "unknown"
|
||||
return addr.strip()
|
||||
|
||||
def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
|
||||
decision = limiter.check(key, limit, window_s)
|
||||
if not decision.allowed:
|
||||
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
|
||||
response.status_code = 429
|
||||
response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}"
|
||||
return response
|
||||
return None
|
||||
|
||||
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
|
||||
cur.execute(
|
||||
"SELECT id, code, expires_at, used_at FROM enrollment_install_codes WHERE code = ?",
|
||||
(code_value,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
keys = ["id", "code", "expires_at", "used_at"]
|
||||
record = dict(zip(keys, row))
|
||||
return record
|
||||
|
||||
def _install_code_valid(record: Dict[str, Any]) -> bool:
|
||||
if not record:
|
||||
return False
|
||||
expires_at = record.get("expires_at")
|
||||
if not isinstance(expires_at, str):
|
||||
return False
|
||||
try:
|
||||
expiry = datetime.fromisoformat(expires_at)
|
||||
except Exception:
|
||||
return False
|
||||
if expiry <= _now():
|
||||
return False
|
||||
if record.get("used_at"):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
|
||||
base = (hostname or "").strip() or guid
|
||||
base = base[:253]
|
||||
candidate = base
|
||||
suffix = 1
|
||||
while True:
|
||||
cur.execute(
|
||||
"SELECT guid FROM devices WHERE hostname = ? AND guid != ?",
|
||||
(candidate, guid),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return candidate
|
||||
candidate = f"{base}-{suffix}"
|
||||
suffix += 1
|
||||
if suffix > 50:
|
||||
return f"{guid}"
|
||||
|
||||
def _store_device_key(cur: sqlite3.Cursor, guid: str, fingerprint: str) -> None:
|
||||
added_at = _iso(_now())
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(str(uuid.uuid4()), guid, fingerprint, added_at),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_keys
|
||||
SET retired_at = ?
|
||||
WHERE guid = ?
|
||||
AND ssl_key_fingerprint != ?
|
||||
AND retired_at IS NULL
|
||||
""",
|
||||
(_iso(_now()), guid, fingerprint),
|
||||
)
|
||||
|
||||
def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
|
||||
cur.execute(
|
||||
"SELECT guid, hostname, token_version, status, ssl_key_fingerprint FROM devices WHERE guid = ?",
|
||||
(guid,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
keys = ["guid", "hostname", "token_version", "status", "ssl_key_fingerprint"]
|
||||
record = dict(zip(keys, row))
|
||||
if not record.get("ssl_key_fingerprint"):
|
||||
cur.execute(
|
||||
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
|
||||
(fingerprint, _iso(_now()), guid),
|
||||
)
|
||||
record["ssl_key_fingerprint"] = fingerprint
|
||||
return record
|
||||
|
||||
resolved_hostname = _normalize_host(hostname, guid, cur)
|
||||
created_at = int(time.time())
|
||||
key_added_at = _iso(_now())
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO devices (
|
||||
guid, hostname, created_at, last_seen, ssl_key_fingerprint,
|
||||
token_version, status, key_added_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
|
||||
""",
|
||||
(
|
||||
guid,
|
||||
resolved_hostname,
|
||||
created_at,
|
||||
created_at,
|
||||
fingerprint,
|
||||
key_added_at,
|
||||
),
|
||||
)
|
||||
return {
|
||||
"guid": guid,
|
||||
"hostname": resolved_hostname,
|
||||
"token_version": 1,
|
||||
"status": "active",
|
||||
"ssl_key_fingerprint": fingerprint,
|
||||
}
|
||||
|
||||
def _hash_refresh_token(token: str) -> str:
|
||||
import hashlib
|
||||
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
def _issue_refresh_token(cur: sqlite3.Cursor, guid: str) -> Dict[str, Any]:
|
||||
token = secrets.token_urlsafe(48)
|
||||
now = _now()
|
||||
expires_at = now.replace(microsecond=0) + timedelta(days=30)
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO refresh_tokens (id, guid, token_hash, created_at, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
str(uuid.uuid4()),
|
||||
guid,
|
||||
_hash_refresh_token(token),
|
||||
_iso(now),
|
||||
_iso(expires_at),
|
||||
),
|
||||
)
|
||||
return {"token": token, "expires_at": expires_at}
|
||||
|
||||
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
|
||||
def enrollment_request():
|
||||
remote = _remote_addr()
|
||||
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 10, 60.0)
|
||||
if rate_error:
|
||||
return rate_error
|
||||
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
hostname = str(payload.get("hostname") or "").strip()
|
||||
enrollment_code = str(payload.get("enrollment_code") or "").strip()
|
||||
agent_pubkey_b64 = payload.get("agent_pubkey")
|
||||
client_nonce_b64 = payload.get("client_nonce")
|
||||
|
||||
if not hostname:
|
||||
return jsonify({"error": "hostname_required"}), 400
|
||||
if not enrollment_code:
|
||||
return jsonify({"error": "enrollment_code_required"}), 400
|
||||
if not isinstance(agent_pubkey_b64, str):
|
||||
return jsonify({"error": "agent_pubkey_required"}), 400
|
||||
if not isinstance(client_nonce_b64, str):
|
||||
return jsonify({"error": "client_nonce_required"}), 400
|
||||
|
||||
try:
|
||||
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
|
||||
except Exception:
|
||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||
|
||||
if len(agent_pubkey_der) < 10:
|
||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||
|
||||
try:
|
||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
if len(client_nonce_bytes) < 16:
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
|
||||
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
|
||||
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 3, 60.0)
|
||||
if rate_error:
|
||||
return rate_error
|
||||
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
install_code = _load_install_code(cur, enrollment_code)
|
||||
if not _install_code_valid(install_code):
|
||||
return jsonify({"error": "invalid_enrollment_code"}), 400
|
||||
|
||||
approval_reference: str
|
||||
record_id: str
|
||||
server_nonce_bytes = secrets.token_bytes(32)
|
||||
server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii")
|
||||
now = _iso(_now())
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, approval_reference
|
||||
FROM device_approvals
|
||||
WHERE ssl_key_fingerprint_claimed = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(fingerprint,),
|
||||
)
|
||||
existing = cur.fetchone()
|
||||
if existing:
|
||||
record_id = existing[0]
|
||||
approval_reference = existing[1]
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET hostname_claimed = ?,
|
||||
enrollment_code_id = ?,
|
||||
client_nonce = ?,
|
||||
server_nonce = ?,
|
||||
agent_pubkey_der = ?,
|
||||
updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
hostname,
|
||||
install_code["id"],
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
now,
|
||||
record_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
record_id = str(uuid.uuid4())
|
||||
approval_reference = str(uuid.uuid4())
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO device_approvals (
|
||||
id, approval_reference, guid, hostname_claimed,
|
||||
ssl_key_fingerprint_claimed, enrollment_code_id,
|
||||
status, client_nonce, server_nonce, agent_pubkey_der,
|
||||
created_at, updated_at
|
||||
)
|
||||
VALUES (?, ?, NULL, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
record_id,
|
||||
approval_reference,
|
||||
hostname,
|
||||
fingerprint,
|
||||
install_code["id"],
|
||||
client_nonce_b64,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
now,
|
||||
now,
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
response = {
|
||||
"status": "pending",
|
||||
"approval_reference": approval_reference,
|
||||
"server_nonce": server_nonce_b64,
|
||||
"poll_after_ms": 3000,
|
||||
"server_certificate": _load_tls_bundle(tls_bundle_path),
|
||||
}
|
||||
log("server", f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}")
|
||||
return jsonify(response)
|
||||
|
||||
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
|
||||
def enrollment_poll():
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
approval_reference = payload.get("approval_reference")
|
||||
client_nonce_b64 = payload.get("client_nonce")
|
||||
proof_sig_b64 = payload.get("proof_sig")
|
||||
|
||||
if not isinstance(approval_reference, str) or not approval_reference:
|
||||
return jsonify({"error": "approval_reference_required"}), 400
|
||||
if not isinstance(client_nonce_b64, str):
|
||||
return jsonify({"error": "client_nonce_required"}), 400
|
||||
if not isinstance(proof_sig_b64, str):
|
||||
return jsonify({"error": "proof_sig_required"}), 400
|
||||
|
||||
try:
|
||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
|
||||
try:
|
||||
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
|
||||
except Exception:
|
||||
return jsonify({"error": "invalid_proof_sig"}), 400
|
||||
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, guid, hostname_claimed, ssl_key_fingerprint_claimed,
|
||||
enrollment_code_id, status, client_nonce, server_nonce,
|
||||
agent_pubkey_der, created_at, updated_at, approved_by_user_id
|
||||
FROM device_approvals
|
||||
WHERE approval_reference = ?
|
||||
""",
|
||||
(approval_reference,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return jsonify({"status": "unknown"}), 404
|
||||
|
||||
(
|
||||
record_id,
|
||||
guid,
|
||||
hostname_claimed,
|
||||
fingerprint,
|
||||
enrollment_code_id,
|
||||
status,
|
||||
client_nonce_stored,
|
||||
server_nonce_b64,
|
||||
agent_pubkey_der,
|
||||
created_at,
|
||||
updated_at,
|
||||
approved_by,
|
||||
) = row
|
||||
|
||||
if client_nonce_stored != client_nonce_b64:
|
||||
return jsonify({"error": "nonce_mismatch"}), 400
|
||||
|
||||
try:
|
||||
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
return jsonify({"error": "server_nonce_invalid"}), 400
|
||||
|
||||
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
|
||||
|
||||
try:
|
||||
public_key = serialization.load_der_public_key(agent_pubkey_der)
|
||||
except Exception:
|
||||
public_key = None
|
||||
|
||||
if public_key is None:
|
||||
return jsonify({"error": "agent_pubkey_invalid"}), 400
|
||||
|
||||
try:
|
||||
public_key.verify(proof_sig, message)
|
||||
except Exception:
|
||||
return jsonify({"error": "invalid_proof"}), 400
|
||||
|
||||
if status == "pending":
|
||||
return jsonify({"status": "pending", "poll_after_ms": 5000})
|
||||
if status == "denied":
|
||||
return jsonify({"status": "denied", "reason": "operator_denied"})
|
||||
if status == "expired":
|
||||
return jsonify({"status": "expired"})
|
||||
|
||||
if status != "approved":
|
||||
return jsonify({"status": status or "unknown"}), 400
|
||||
|
||||
nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
|
||||
if not nonce_cache.consume(nonce_key):
|
||||
return jsonify({"error": "proof_replayed"}), 409
|
||||
|
||||
# Finalize enrollment
|
||||
effective_guid = guid or str(uuid.uuid4())
|
||||
now_iso = _iso(_now())
|
||||
|
||||
device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint)
|
||||
_store_device_key(cur, effective_guid, fingerprint)
|
||||
|
||||
# Mark install code used
|
||||
if enrollment_code_id:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE enrollment_install_codes
|
||||
SET used_at = ?, used_by_guid = ?
|
||||
WHERE id = ?
|
||||
AND used_at IS NULL
|
||||
""",
|
||||
(now_iso, effective_guid, enrollment_code_id),
|
||||
)
|
||||
|
||||
# Update approval record with final state
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE device_approvals
|
||||
SET guid = ?,
|
||||
status = 'completed',
|
||||
updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(effective_guid, now_iso, record_id),
|
||||
)
|
||||
|
||||
refresh_info = _issue_refresh_token(cur, effective_guid)
|
||||
access_token = jwt_service.issue_access_token(
|
||||
effective_guid,
|
||||
fingerprint,
|
||||
device_record.get("token_version") or 1,
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
log(
|
||||
"server",
|
||||
f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
|
||||
)
|
||||
return jsonify(
|
||||
{
|
||||
"status": "approved",
|
||||
"guid": effective_guid,
|
||||
"access_token": access_token,
|
||||
"expires_in": 900,
|
||||
"refresh_token": refresh_info["token"],
|
||||
"token_type": "Bearer",
|
||||
"server_certificate": _load_tls_bundle(tls_bundle_path),
|
||||
}
|
||||
)
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
|
||||
def _load_tls_bundle(path: str) -> str:
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as fh:
|
||||
return fh.read()
|
||||
except Exception:
|
||||
return ""
|
||||
@@ -50,7 +50,10 @@ from datetime import datetime, timezone
|
||||
|
||||
from Modules import db_migrations
|
||||
from Modules.auth import jwt_service as jwt_service_module
|
||||
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||
from Modules.crypto import certificates
|
||||
from Modules.enrollment import routes as enrollment_routes
|
||||
from Modules.enrollment.nonce_store import NonceCache
|
||||
|
||||
try:
|
||||
from cryptography.fernet import Fernet # type: ignore
|
||||
@@ -146,6 +149,9 @@ os.environ.setdefault("BOREALIS_TLS_KEY", TLS_KEY_PATH)
|
||||
os.environ.setdefault("BOREALIS_TLS_BUNDLE", TLS_BUNDLE_PATH)
|
||||
|
||||
JWT_SERVICE = jwt_service_module.load_service()
|
||||
IP_RATE_LIMITER = SlidingWindowRateLimiter()
|
||||
FP_RATE_LIMITER = SlidingWindowRateLimiter()
|
||||
ENROLLMENT_NONCE_CACHE = NonceCache()
|
||||
|
||||
|
||||
def _set_cached_github_token(token: Optional[str]) -> None:
|
||||
@@ -4819,6 +4825,17 @@ def init_db():
|
||||
|
||||
init_db()
|
||||
|
||||
enrollment_routes.register(
|
||||
app,
|
||||
db_conn_factory=_db_conn,
|
||||
log=_write_service_log,
|
||||
jwt_service=JWT_SERVICE,
|
||||
tls_bundle_path=TLS_BUNDLE_PATH,
|
||||
ip_rate_limiter=IP_RATE_LIMITER,
|
||||
fp_rate_limiter=FP_RATE_LIMITER,
|
||||
nonce_cache=ENROLLMENT_NONCE_CACHE,
|
||||
)
|
||||
|
||||
|
||||
def ensure_default_admin():
|
||||
"""Ensure at least one admin user exists.
|
||||
|
||||
Reference in New Issue
Block a user