mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 17:41:58 -06:00
479 lines
17 KiB
Python
479 lines
17 KiB
Python
from __future__ import annotations
|
|
|
|
import base64
|
|
import secrets
|
|
import sqlite3
|
|
import uuid
|
|
from datetime import datetime, timezone, timedelta
|
|
import time
|
|
from typing import Any, Callable, Dict, Optional
|
|
|
|
from flask import Blueprint, jsonify, request
|
|
|
|
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
|
from Modules.crypto import keys as crypto_keys
|
|
from Modules.enrollment.nonce_store import NonceCache
|
|
from cryptography.hazmat.primitives import serialization
|
|
|
|
|
|
def register(
    app,
    *,
    db_conn_factory: Callable[[], sqlite3.Connection],
    log: Callable[[str, str], None],
    jwt_service,
    tls_bundle_path: str,
    ip_rate_limiter: SlidingWindowRateLimiter,
    fp_rate_limiter: SlidingWindowRateLimiter,
    nonce_cache: NonceCache,
) -> None:
    """Register the agent enrollment endpoints on *app*.

    Two routes are added through an ``enrollment`` blueprint:

    * ``POST /api/agent/enroll/request`` validates an install code and the
      agent's public key. It then queues a ``device_approvals`` row with
      status ``pending``, or refreshes the existing pending row for that key,
      and returns a fresh server nonce.
    * ``POST /api/agent/enroll/poll`` verifies the agent's signature over
      ``server_nonce + approval_reference + client_nonce``. Once the approval
      row is ``approved``, it creates or updates the device record, marks the
      install code used and issues an access token and a refresh token.

    Args:
        app: Application the blueprint is registered on
            (``app.register_blueprint``).
        db_conn_factory: Returns a new SQLite connection. Every request opens
            and closes its own connection.
        log: Logging callback, called as ``log("server", message)``.
        jwt_service: Must provide
            ``issue_access_token(guid, fingerprint, token_version)``.
        tls_bundle_path: File whose text is returned to agents as
            ``server_certificate``.
        ip_rate_limiter: Limits enrollment requests per client address
            (10 per 60 s).
        fp_rate_limiter: Limits enrollment requests per key fingerprint
            (3 per 60 s).
        nonce_cache: Lets a given approved proof signature finalize only once.
    """
    blueprint = Blueprint("enrollment", __name__)

    def _now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        return datetime.now(tz=timezone.utc)

    def _iso(dt: datetime) -> str:
        """Serialize *dt* as an ISO-8601 string for storage."""
        return dt.isoformat()

    def _remote_addr() -> str:
        """Return the caller address used as the per-IP rate-limit key.

        The first ``X-Forwarded-For`` entry is preferred when present.
        Otherwise Flask's ``remote_addr`` is used, or ``"unknown"`` when that
        is empty.
        """
        # NOTE(review): the first X-Forwarded-For entry is client-supplied and
        # can be forged. A caller can therefore choose its own rate-limit
        # bucket. This is only safe behind a proxy that overwrites the header.
        # Confirm the deployment, or switch to the proxy-appended last hop.
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            return forwarded.split(",")[0].strip()
        addr = request.remote_addr or "unknown"
        return addr.strip()

    def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
        """Return a 429 response if ``limiter.check`` denies *key*, else None.

        The response body carries ``retry_after``. The ``Retry-After`` header
        is set to a whole number of seconds, never lower than 1.
        """
        decision = limiter.check(key, limit, window_s)
        if decision.allowed:
            return None
        response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
        response.status_code = 429
        # Sub-second waits would truncate to 0; advertise at least 1 second.
        response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}"
        return response

    def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
        """Fetch the install code row matching *code_value*.

        Returns a dict with keys ``id``, ``code``, ``expires_at`` and
        ``used_at``, or None if no row matches.
        """
        cur.execute(
            "SELECT id, code, expires_at, used_at FROM enrollment_install_codes WHERE code = ?",
            (code_value,),
        )
        row = cur.fetchone()
        if not row:
            return None
        keys = ["id", "code", "expires_at", "used_at"]
        return dict(zip(keys, row))

    def _install_code_valid(record: Optional[Dict[str, Any]]) -> bool:
        """Return True only for an install code that is unexpired and unused.

        ``expires_at`` must be an ISO-8601 string in the future, and
        ``used_at`` must be empty. Missing or malformed data counts as
        invalid.
        """
        if not record:
            return False
        expires_at = record.get("expires_at")
        if not isinstance(expires_at, str):
            return False
        try:
            expiry = datetime.fromisoformat(expires_at)
        except ValueError:
            return False
        # A timestamp stored without an offset parses as a naive datetime.
        # Comparing naive and aware datetimes raises TypeError, which would
        # surface as an HTTP 500. Treat such values as UTC.
        # NOTE(review): confirm the writer of enrollment_install_codes stores UTC.
        if expiry.tzinfo is None:
            expiry = expiry.replace(tzinfo=timezone.utc)
        if expiry <= _now():
            return False
        return not record.get("used_at")

    def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
        """Pick a hostname that no other device (by guid) is using.

        A blank *hostname* falls back to *guid*. If the name is taken, the
        suffixes ``-1`` through ``-50`` are tried in order. If all of those
        are also taken, *guid* is returned. Names are capped at 253
        characters, including any suffix.
        """
        max_len = 253
        base = ((hostname or "").strip() or guid)[:max_len]

        def _taken(name: str) -> bool:
            # True if a different device already uses *name*.
            cur.execute(
                "SELECT guid FROM devices WHERE hostname = ? AND guid != ?",
                (name, guid),
            )
            return cur.fetchone() is not None

        if not _taken(base):
            return base
        for suffix in range(1, 51):
            tail = f"-{suffix}"
            # Truncate the base so the suffixed name still fits the cap.
            candidate = f"{base[:max_len - len(tail)]}{tail}"
            if not _taken(candidate):
                return candidate
        return guid

    def _store_device_key(cur: sqlite3.Cursor, guid: str, fingerprint: str) -> None:
        """Add *fingerprint* as a key of *guid* and retire the device's other keys."""
        now_iso = _iso(_now())
        # NOTE(review): OR IGNORE only prevents duplicates if the schema has a
        # unique constraint (presumably on guid + fingerprint). Confirm this.
        cur.execute(
            """
            INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
            VALUES (?, ?, ?, ?)
            """,
            (str(uuid.uuid4()), guid, fingerprint, now_iso),
        )
        # Use one timestamp so that adding the new key and retiring the old
        # keys are recorded at the same instant.
        cur.execute(
            """
            UPDATE device_keys
            SET retired_at = ?
            WHERE guid = ?
              AND ssl_key_fingerprint != ?
              AND retired_at IS NULL
            """,
            (now_iso, guid, fingerprint),
        )

    def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
        """Return the ``devices`` row for *guid*, creating it if it is missing.

        An existing row that has no ``ssl_key_fingerprint`` gets *fingerprint*
        backfilled. A new row is inserted with ``token_version`` 1, status
        ``active`` and a de-duplicated hostname.

        Returns:
            A dict with keys ``guid``, ``hostname``, ``token_version``,
            ``status`` and ``ssl_key_fingerprint``.
        """
        cur.execute(
            "SELECT guid, hostname, token_version, status, ssl_key_fingerprint FROM devices WHERE guid = ?",
            (guid,),
        )
        row = cur.fetchone()
        if row:
            keys = ["guid", "hostname", "token_version", "status", "ssl_key_fingerprint"]
            record = dict(zip(keys, row))
            # NOTE(review): an existing row keeps its previous fingerprint even
            # when it re-enrolls with a different key. Meanwhile
            # _store_device_key retires that old key in device_keys. Confirm
            # this is intended.
            if not record.get("ssl_key_fingerprint"):
                cur.execute(
                    "UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
                    (fingerprint, _iso(_now()), guid),
                )
                record["ssl_key_fingerprint"] = fingerprint
            return record

        resolved_hostname = _normalize_host(hostname, guid, cur)
        # created_at / last_seen are stored as epoch seconds; key_added_at as ISO text.
        created_at = int(time.time())
        key_added_at = _iso(_now())
        cur.execute(
            """
            INSERT INTO devices (
                guid, hostname, created_at, last_seen, ssl_key_fingerprint,
                token_version, status, key_added_at
            )
            VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
            """,
            (
                guid,
                resolved_hostname,
                created_at,
                created_at,
                fingerprint,
                key_added_at,
            ),
        )
        return {
            "guid": guid,
            "hostname": resolved_hostname,
            "token_version": 1,
            "status": "active",
            "ssl_key_fingerprint": fingerprint,
        }

    def _hash_refresh_token(token: str) -> str:
        """Return the hex SHA-256 of *token*; only the hash is persisted."""
        import hashlib

        return hashlib.sha256(token.encode("utf-8")).hexdigest()

    def _issue_refresh_token(cur: sqlite3.Cursor, guid: str) -> Dict[str, Any]:
        """Create and store a 30-day refresh token for *guid*.

        Returns:
            ``{"token": <plaintext token>, "expires_at": <aware datetime>}``.
            Only the hash of the token is written to ``refresh_tokens``.
        """
        token = secrets.token_urlsafe(48)
        now = _now()
        expires_at = now.replace(microsecond=0) + timedelta(days=30)
        cur.execute(
            """
            INSERT INTO refresh_tokens (id, guid, token_hash, created_at, expires_at)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                str(uuid.uuid4()),
                guid,
                _hash_refresh_token(token),
                _iso(now),
                _iso(expires_at),
            ),
        )
        return {"token": token, "expires_at": expires_at}

    @blueprint.route("/api/agent/enroll/request", methods=["POST"])
    def enrollment_request():
        """Queue an enrollment request until an operator approves it.

        Expects a JSON object containing:

        * ``hostname``
        * ``enrollment_code``
        * ``agent_pubkey``: base64 SPKI DER
        * ``client_nonce``: base64, at least 16 bytes

        Rate-limited per client address and per key fingerprint. Returns the
        approval reference, a new server nonce, a poll interval and the TLS
        bundle text.
        """
        remote = _remote_addr()
        rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 10, 60.0)
        if rate_error:
            return rate_error

        payload = request.get_json(force=True, silent=True) or {}
        # A JSON array or scalar body would make .get() raise (HTTP 500).
        if not isinstance(payload, dict):
            payload = {}
        hostname = str(payload.get("hostname") or "").strip()
        enrollment_code = str(payload.get("enrollment_code") or "").strip()
        agent_pubkey_b64 = payload.get("agent_pubkey")
        client_nonce_b64 = payload.get("client_nonce")

        if not hostname:
            return jsonify({"error": "hostname_required"}), 400
        if not enrollment_code:
            return jsonify({"error": "enrollment_code_required"}), 400
        if not isinstance(agent_pubkey_b64, str):
            return jsonify({"error": "agent_pubkey_required"}), 400
        if not isinstance(client_nonce_b64, str):
            return jsonify({"error": "client_nonce_required"}), 400

        try:
            agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
        except Exception:
            return jsonify({"error": "invalid_agent_pubkey"}), 400
        # Reject blobs too short to be a real public key structure.
        if len(agent_pubkey_der) < 10:
            return jsonify({"error": "invalid_agent_pubkey"}), 400

        try:
            client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
        except ValueError:  # binascii.Error is a ValueError subclass
            return jsonify({"error": "invalid_client_nonce"}), 400
        if len(client_nonce_bytes) < 16:
            return jsonify({"error": "invalid_client_nonce"}), 400

        fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
        rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 3, 60.0)
        if rate_error:
            return rate_error

        conn = db_conn_factory()
        try:
            cur = conn.cursor()
            install_code = _load_install_code(cur, enrollment_code)
            if not _install_code_valid(install_code):
                return jsonify({"error": "invalid_enrollment_code"}), 400

            server_nonce_b64 = base64.b64encode(secrets.token_bytes(32)).decode("ascii")
            now = _iso(_now())

            # Reuse a pending approval for the same key. A retrying agent then
            # refreshes its nonces instead of creating another approval row.
            cur.execute(
                """
                SELECT id, approval_reference
                FROM device_approvals
                WHERE ssl_key_fingerprint_claimed = ?
                  AND status = 'pending'
                """,
                (fingerprint,),
            )
            existing = cur.fetchone()
            if existing:
                record_id, approval_reference = existing[0], existing[1]
                cur.execute(
                    """
                    UPDATE device_approvals
                    SET hostname_claimed = ?,
                        enrollment_code_id = ?,
                        client_nonce = ?,
                        server_nonce = ?,
                        agent_pubkey_der = ?,
                        updated_at = ?
                    WHERE id = ?
                    """,
                    (
                        hostname,
                        install_code["id"],
                        client_nonce_b64,
                        server_nonce_b64,
                        agent_pubkey_der,
                        now,
                        record_id,
                    ),
                )
            else:
                record_id = str(uuid.uuid4())
                approval_reference = str(uuid.uuid4())
                cur.execute(
                    """
                    INSERT INTO device_approvals (
                        id, approval_reference, guid, hostname_claimed,
                        ssl_key_fingerprint_claimed, enrollment_code_id,
                        status, client_nonce, server_nonce, agent_pubkey_der,
                        created_at, updated_at
                    )
                    VALUES (?, ?, NULL, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
                    """,
                    (
                        record_id,
                        approval_reference,
                        hostname,
                        fingerprint,
                        install_code["id"],
                        client_nonce_b64,
                        server_nonce_b64,
                        agent_pubkey_der,
                        now,
                        now,
                    ),
                )

            conn.commit()
        finally:
            conn.close()

        response = {
            "status": "pending",
            "approval_reference": approval_reference,
            "server_nonce": server_nonce_b64,
            "poll_after_ms": 3000,
            "server_certificate": _load_tls_bundle(tls_bundle_path),
        }
        log("server", f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}")
        return jsonify(response)

    @blueprint.route("/api/agent/enroll/poll", methods=["POST"])
    def enrollment_poll():
        """Report approval status and, once approved, finalize enrollment.

        Expects a JSON object containing:

        * ``approval_reference``
        * ``client_nonce``: must equal the value stored at request time
        * ``proof_sig``: base64 signature by the agent key over
          ``server_nonce + approval_reference + client_nonce``

        Responses:

        * 404 for an unknown reference.
        * 400 for nonce, signature or key errors.
        * 409 when an approved proof is replayed.
        * The approval status while it is pending, denied or expired.
        * On finalization, the device guid plus access and refresh tokens.
        """
        payload = request.get_json(force=True, silent=True) or {}
        # A JSON array or scalar body would make .get() raise (HTTP 500).
        if not isinstance(payload, dict):
            payload = {}
        approval_reference = payload.get("approval_reference")
        client_nonce_b64 = payload.get("client_nonce")
        proof_sig_b64 = payload.get("proof_sig")

        if not isinstance(approval_reference, str) or not approval_reference:
            return jsonify({"error": "approval_reference_required"}), 400
        if not isinstance(client_nonce_b64, str):
            return jsonify({"error": "client_nonce_required"}), 400
        if not isinstance(proof_sig_b64, str):
            return jsonify({"error": "proof_sig_required"}), 400

        try:
            client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
        except ValueError:  # binascii.Error is a ValueError subclass
            return jsonify({"error": "invalid_client_nonce"}), 400

        try:
            proof_sig = base64.b64decode(proof_sig_b64, validate=True)
        except ValueError:
            return jsonify({"error": "invalid_proof_sig"}), 400

        conn = db_conn_factory()
        try:
            cur = conn.cursor()
            cur.execute(
                """
                SELECT id, guid, hostname_claimed, ssl_key_fingerprint_claimed,
                       enrollment_code_id, status, client_nonce, server_nonce,
                       agent_pubkey_der, created_at, updated_at, approved_by_user_id
                FROM device_approvals
                WHERE approval_reference = ?
                """,
                (approval_reference,),
            )
            row = cur.fetchone()
            if not row:
                return jsonify({"status": "unknown"}), 404

            (
                record_id,
                guid,
                hostname_claimed,
                fingerprint,
                enrollment_code_id,
                status,
                client_nonce_stored,
                server_nonce_b64,
                agent_pubkey_der,
                _created_at,
                _updated_at,
                _approved_by,
            ) = row

            if client_nonce_stored != client_nonce_b64:
                return jsonify({"error": "nonce_mismatch"}), 400

            try:
                server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
            except (TypeError, ValueError):  # TypeError covers a NULL column
                return jsonify({"error": "server_nonce_invalid"}), 400

            # The agent proves possession of its key by signing this message.
            message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes

            try:
                public_key = serialization.load_der_public_key(agent_pubkey_der)
            except Exception:
                return jsonify({"error": "agent_pubkey_invalid"}), 400

            try:
                # NOTE(review): the two-argument verify(signature, data) form
                # suits Ed25519/Ed448 keys. RSA and EC keys need padding or
                # algorithm arguments, so they would be rejected here.
                public_key.verify(proof_sig, message)
            except Exception:
                return jsonify({"error": "invalid_proof"}), 400

            if status == "pending":
                return jsonify({"status": "pending", "poll_after_ms": 5000})
            if status == "denied":
                return jsonify({"status": "denied", "reason": "operator_denied"})
            if status == "expired":
                return jsonify({"status": "expired"})
            if status != "approved":
                return jsonify({"status": status or "unknown"}), 400

            # A given approved proof signature may finalize enrollment only once.
            nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
            if not nonce_cache.consume(nonce_key):
                return jsonify({"error": "proof_replayed"}), 409

            # Finalize enrollment.
            effective_guid = guid or str(uuid.uuid4())
            now_iso = _iso(_now())

            device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint)
            _store_device_key(cur, effective_guid, fingerprint)

            # Mark the install code used. `used_at IS NULL` keeps the first use.
            if enrollment_code_id:
                cur.execute(
                    """
                    UPDATE enrollment_install_codes
                    SET used_at = ?, used_by_guid = ?
                    WHERE id = ?
                      AND used_at IS NULL
                    """,
                    (now_iso, effective_guid, enrollment_code_id),
                )

            # Record the final state on the approval row.
            cur.execute(
                """
                UPDATE device_approvals
                SET guid = ?,
                    status = 'completed',
                    updated_at = ?
                WHERE id = ?
                """,
                (effective_guid, now_iso, record_id),
            )

            refresh_info = _issue_refresh_token(cur, effective_guid)
            access_token = jwt_service.issue_access_token(
                effective_guid,
                fingerprint,
                device_record.get("token_version") or 1,
            )

            conn.commit()
        finally:
            conn.close()

        log(
            "server",
            f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
        )
        return jsonify(
            {
                "status": "approved",
                "guid": effective_guid,
                "access_token": access_token,
                "expires_in": 900,
                "refresh_token": refresh_info["token"],
                "token_type": "Bearer",
                "server_certificate": _load_tls_bundle(tls_bundle_path),
            }
        )

    app.register_blueprint(blueprint)
|
|
|
|
|
|
def _load_tls_bundle(path: str) -> str:
|
|
try:
|
|
with open(path, "r", encoding="utf-8") as fh:
|
|
return fh.read()
|
|
except Exception:
|
|
return ""
|