Files
Borealis-Github-Replica/Data/Server/Modules/enrollment/routes.py

608 lines
22 KiB
Python

from __future__ import annotations
import base64
import secrets
import sqlite3
import uuid
from datetime import datetime, timezone, timedelta
import time
from typing import Any, Callable, Dict, Optional
from flask import Blueprint, jsonify, request
from Modules.auth.rate_limit import SlidingWindowRateLimiter
from Modules.crypto import keys as crypto_keys
from Modules.enrollment.nonce_store import NonceCache
from cryptography.hazmat.primitives import serialization
def register(
    app,
    *,
    db_conn_factory: Callable[[], sqlite3.Connection],
    log: Callable[[str, str], None],
    jwt_service,
    tls_bundle_path: str,
    ip_rate_limiter: SlidingWindowRateLimiter,
    fp_rate_limiter: SlidingWindowRateLimiter,
    nonce_cache: NonceCache,
    script_signer,
) -> None:
    """Register the agent enrollment blueprint on ``app``.

    Adds two POST endpoints:

    * ``/api/agent/enroll/request`` -- validates an install code, the agent's
      public key and a client nonce, then records a *pending* device approval
      and returns an approval reference plus a server nonce.
    * ``/api/agent/enroll/poll`` -- verifies the agent's signature over the
      nonces and, once the approval status is ``approved``, creates/updates
      the device and returns an access token and a refresh token.

    Args:
        app: Flask application the blueprint is registered on.
        db_conn_factory: Returns a SQLite connection; each handler opens one
            and closes it before returning.
        log: ``log(channel, message)`` sink; this module always logs to the
            ``"server"`` channel.
        jwt_service: Object exposing ``issue_access_token(guid, fingerprint,
            token_version)`` (called positionally).
        tls_bundle_path: Path of the PEM bundle returned to agents as
            ``server_certificate``.
        ip_rate_limiter: Per-client-IP limiter for enrollment requests.
        fp_rate_limiter: Per-key-fingerprint limiter for enrollment requests.
        nonce_cache: Replay cache; ``consume(key)`` returning false is treated
            as a replayed proof.
        script_signer: Optional signer whose public key is returned as
            ``signing_key``; a falsy value yields an empty string.
    """
    blueprint = Blueprint("enrollment", __name__)
def _now() -> datetime:
return datetime.now(tz=timezone.utc)
def _iso(dt: datetime) -> str:
return dt.isoformat()
def _remote_addr() -> str:
    """Best-effort client address used for logging and the per-IP rate limit.

    Returns the first hop of ``X-Forwarded-For`` when that header is present,
    otherwise Flask's ``request.remote_addr`` (or ``"unknown"``), stripped.

    NOTE(review): the forwarded header is trusted unconditionally. A client
    that talks to the server directly can send any value here and rotate it
    to sidestep the per-IP limiter. Confirm a trusted reverse proxy always
    overwrites this header, or only honour it for known proxy addresses.
    """
    header = request.headers.get("X-Forwarded-For")
    if not header:
        return (request.remote_addr or "unknown").strip()
    first_hop, _sep, _rest = header.partition(",")
    return first_hop.strip()
def _signing_key_b64() -> str:
    """Return the script signer's public key (``public_base64_spki``) or "".

    An empty string is returned when no signer is configured or when the
    signer raises while exporting its key.
    """
    if not script_signer:
        return ""
    try:
        key_b64 = script_signer.public_base64_spki()
    except Exception:
        # Best effort: enrollment responses still go out without a signing key.
        key_b64 = ""
    return key_b64
def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
    """Apply a sliding-window rate limit to ``key``.

    Args:
        key: Limiter bucket, e.g. ``"ip:<addr>"`` or ``"fp:<fingerprint>"``.
        limiter: Limiter whose ``check(key, limit, window_s)`` decides.
        limit: Maximum requests allowed per window.
        window_s: Window length in seconds.

    Returns:
        ``None`` when the request may proceed. Otherwise, after logging, a
        Flask response with status 429, a JSON body carrying ``retry_after``,
        and a ``Retry-After`` header in whole seconds.
    """
    import math

    decision = limiter.check(key, limit, window_s)
    if decision.allowed:
        return None
    log(
        "server",
        f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
    )
    response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
    response.status_code = 429
    # Retry-After only carries whole seconds. Round *up* so a compliant client
    # never comes back before a slot has actually freed (int() turned 1.7s
    # into 1s). Always advertise at least one second.
    response.headers["Retry-After"] = str(max(1, math.ceil(decision.retry_after)))
    return response
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
cur.execute(
"SELECT id, code, expires_at, used_at FROM enrollment_install_codes WHERE code = ?",
(code_value,),
)
row = cur.fetchone()
if not row:
return None
keys = ["id", "code", "expires_at", "used_at"]
record = dict(zip(keys, row))
return record
def _install_code_valid(record: Optional[Dict[str, Any]]) -> bool:
    """Return True when ``record`` is an unused, unexpired install code.

    ``record`` is a dict as produced by ``_load_install_code`` (or ``None``).
    A code is rejected when it is missing, has a truthy ``used_at``, or has
    an ``expires_at`` that is not a parseable ISO-8601 string or is not in
    the future.

    Timestamps without an offset are treated as UTC, matching the UTC
    timestamps this module writes. Previously a naive value made the
    comparison with the aware ``_now()`` raise TypeError, which surfaced as a
    500. A trailing ``Z`` is accepted even on Pythons whose ``fromisoformat``
    does not understand it.
    """
    if not record:
        return False
    if record.get("used_at"):
        return False
    expires_at = record.get("expires_at")
    if not isinstance(expires_at, str):
        return False
    text = expires_at.strip()
    if text.endswith(("Z", "z")):
        text = text[:-1] + "+00:00"
    try:
        expiry = datetime.fromisoformat(text)
    except ValueError:
        return False
    if expiry.tzinfo is None:
        expiry = expiry.replace(tzinfo=timezone.utc)
    return expiry > _now()
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
base = (hostname or "").strip() or guid
base = base[:253]
candidate = base
suffix = 1
while True:
cur.execute(
"SELECT guid FROM devices WHERE hostname = ? AND guid != ?",
(candidate, guid),
)
row = cur.fetchone()
if not row:
return candidate
candidate = f"{base}-{suffix}"
suffix += 1
if suffix > 50:
return f"{guid}"
def _store_device_key(cur: sqlite3.Cursor, guid: str, fingerprint: str) -> None:
    """Record ``fingerprint`` as a key for ``guid`` and retire the GUID's other keys.

    The insert uses ``INSERT OR IGNORE``. Presumably a uniqueness constraint
    on ``device_keys`` makes re-registering a known key a no-op; the schema
    is not visible here, so confirm. Afterwards every other key row for the
    same GUID that has no ``retired_at`` is stamped as retired.
    """
    insert_params = (str(uuid.uuid4()), guid, fingerprint, _iso(_now()))
    cur.execute(
        """
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
VALUES (?, ?, ?, ?)
""",
        insert_params,
    )
    retire_params = (_iso(_now()), guid, fingerprint)
    cur.execute(
        """
UPDATE device_keys
SET retired_at = ?
WHERE guid = ?
AND ssl_key_fingerprint != ?
AND retired_at IS NULL
""",
        retire_params,
    )
def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
    """Load the ``devices`` row for ``guid``, creating it or syncing its key fingerprint.

    Behaviour by case:

    * Row exists, no fingerprint stored yet, non-empty ``fingerprint`` given:
      the fingerprint and a new ``key_added_at`` are written.
    * Row exists and the stored fingerprint differs from ``fingerprint``
      (both compared trimmed and lower-cased): treated as a key rotation.
      ``token_version`` is incremented, ``status`` is forced to ``'active'``,
      and every not-yet-revoked refresh token for the GUID is revoked.
    * Row exists and the fingerprint matches (or ``fingerprint`` is empty):
      nothing is written.
    * No row: a device is inserted with a hostname chosen by
      ``_normalize_host``, ``token_version`` 1 and ``status`` ``'active'``.

    ``hostname`` is only consulted when a new row is inserted.

    NOTE(review): an existing row with a matching fingerprint is returned
    with whatever ``status`` it already has. The poll handler only reads
    ``token_version`` from the result before issuing tokens. Confirm that is
    intended for non-active devices.

    Returns:
        Dict with keys ``guid``, ``hostname``, ``token_version``, ``status``,
        ``ssl_key_fingerprint`` and ``key_added_at``, reflecting the state
        after any write made here.
    """
    cur.execute(
        """
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
FROM devices
WHERE guid = ?
""",
        (guid,),
    )
    row = cur.fetchone()
    if row:
        keys = [
            "guid",
            "hostname",
            "token_version",
            "status",
            "ssl_key_fingerprint",
            "key_added_at",
        ]
        record = dict(zip(keys, row))
        # Normalise both sides so case/whitespace differences are not mistaken for a rotation.
        stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower()
        new_fp = (fingerprint or "").strip().lower()
        if not stored_fp and new_fp:
            # First key recorded for an existing device: store it (un-normalised).
            cur.execute(
                "UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
                (fingerprint, _iso(_now()), guid),
            )
            record["ssl_key_fingerprint"] = fingerprint
        elif new_fp and stored_fp != new_fp:
            # Key rotation: bump token_version and revoke outstanding refresh tokens.
            now_iso = _iso(_now())
            try:
                # Missing/NULL versions count as 1; unparseable values also fall back to 1.
                current_version = int(record.get("token_version") or 1)
            except Exception:
                current_version = 1
            # max(..., 1) keeps the version positive even if a bad value was stored.
            new_version = max(current_version + 1, 1)
            cur.execute(
                """
UPDATE devices
SET ssl_key_fingerprint = ?,
key_added_at = ?,
token_version = ?,
status = 'active'
WHERE guid = ?
""",
                (fingerprint, now_iso, new_version, guid),
            )
            cur.execute(
                """
UPDATE refresh_tokens
SET revoked_at = ?
WHERE guid = ?
AND revoked_at IS NULL
""",
                (now_iso, guid),
            )
            # Mirror the writes in the returned record.
            record["ssl_key_fingerprint"] = fingerprint
            record["token_version"] = new_version
            record["status"] = "active"
            record["key_added_at"] = now_iso
        return record
    # New device: pick a hostname no other GUID is using.
    resolved_hostname = _normalize_host(hostname, guid, cur)
    # created_at/last_seen are stored as integer epoch seconds, unlike the
    # ISO-8601 strings used for key_added_at.
    created_at = int(time.time())
    key_added_at = _iso(_now())
    cur.execute(
        """
INSERT INTO devices (
guid, hostname, created_at, last_seen, ssl_key_fingerprint,
token_version, status, key_added_at
)
VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
""",
        (
            guid,
            resolved_hostname,
            created_at,
            created_at,
            fingerprint,
            key_added_at,
        ),
    )
    return {
        "guid": guid,
        "hostname": resolved_hostname,
        "token_version": 1,
        "status": "active",
        "ssl_key_fingerprint": fingerprint,
        "key_added_at": key_added_at,
    }
def _hash_refresh_token(token: str) -> str:
import hashlib
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def _issue_refresh_token(cur: sqlite3.Cursor, guid: str) -> Dict[str, Any]:
    """Mint a 30-day refresh token for ``guid`` and persist only its hash.

    Returns ``{"token": <plaintext>, "expires_at": <datetime>}``. The
    plaintext is never written to the database. ``created_at`` keeps
    microseconds, while ``expires_at`` is truncated to whole seconds before
    the 30 days are added.
    """
    plaintext = secrets.token_urlsafe(48)
    issued_at = _now()
    expires_at = issued_at.replace(microsecond=0) + timedelta(days=30)
    row = (
        str(uuid.uuid4()),
        guid,
        _hash_refresh_token(plaintext),
        _iso(issued_at),
        _iso(expires_at),
    )
    cur.execute(
        """
INSERT INTO refresh_tokens (id, guid, token_hash, created_at, expires_at)
VALUES (?, ?, ?, ?, ?)
""",
        row,
    )
    return {"token": plaintext, "expires_at": expires_at}
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
def enrollment_request():
    """Queue, or refresh, a pending enrollment approval for an agent key.

    Expects a JSON object with ``hostname``, ``enrollment_code``,
    ``agent_pubkey`` (base64 SPKI DER) and ``client_nonce`` (base64, at least
    16 bytes). Requests are rate limited per client IP and per key
    fingerprint. When the install code is valid, the existing *pending*
    ``device_approvals`` row for the same fingerprint is updated, or a new
    one is inserted. The response carries the approval reference and a fresh
    server nonce, which the agent must sign when polling.

    The body is untrusted. Non-object JSON and non-string fields are now
    rejected with a 400 instead of crashing with a 500.
    """
    remote = _remote_addr()
    rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0)
    if rate_error is not None:
        return rate_error

    payload = request.get_json(force=True, silent=True)
    if not isinstance(payload, dict):
        # A JSON array or scalar has no .get(); treat it as an empty object so
        # the field checks below answer with a 400.
        payload = {}
    hostname = str(payload.get("hostname") or "").strip()
    enrollment_code = str(payload.get("enrollment_code") or "").strip()
    agent_pubkey_b64 = payload.get("agent_pubkey")
    client_nonce_b64 = payload.get("client_nonce")

    def _str_len(value: Any) -> int:
        # len() on a number would raise before the isinstance checks run.
        return len(value) if isinstance(value, str) else 0

    log(
        "server",
        "enrollment request received "
        f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
        f"pubkey_len={_str_len(agent_pubkey_b64)} nonce_len={_str_len(client_nonce_b64)}",
    )
    if not hostname:
        log("server", f"enrollment rejected missing_hostname ip={remote}")
        return jsonify({"error": "hostname_required"}), 400
    if not enrollment_code:
        log("server", f"enrollment rejected missing_code ip={remote} host={hostname}")
        return jsonify({"error": "enrollment_code_required"}), 400
    if not isinstance(agent_pubkey_b64, str):
        log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}")
        return jsonify({"error": "agent_pubkey_required"}), 400
    if not isinstance(client_nonce_b64, str):
        log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}")
        return jsonify({"error": "client_nonce_required"}), 400

    try:
        agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
    except Exception:
        log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}")
        return jsonify({"error": "invalid_agent_pubkey"}), 400
    if len(agent_pubkey_der) < 10:
        log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}")
        return jsonify({"error": "invalid_agent_pubkey"}), 400

    try:
        client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
    except Exception:
        log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}")
        return jsonify({"error": "invalid_client_nonce"}), 400
    if len(client_nonce_bytes) < 16:
        log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}")
        return jsonify({"error": "invalid_client_nonce"}), 400

    fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
    rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0)
    if rate_error is not None:
        return rate_error

    conn = db_conn_factory()
    try:
        cur = conn.cursor()
        install_code = _load_install_code(cur, enrollment_code)
        if not _install_code_valid(install_code):
            log(
                "server",
                f"enrollment rejected invalid_code ip={remote} host={hostname} "
                f"code_mask={_mask_code(enrollment_code)}",
            )
            return jsonify({"error": "invalid_enrollment_code"}), 400

        server_nonce_b64 = base64.b64encode(secrets.token_bytes(32)).decode("ascii")
        now = _iso(_now())
        cur.execute(
            """
            SELECT id, approval_reference
            FROM device_approvals
            WHERE ssl_key_fingerprint_claimed = ?
              AND status = 'pending'
            """,
            (fingerprint,),
        )
        existing = cur.fetchone()
        if existing:
            # Re-requesting with the same key refreshes the pending approval
            # (new nonces, possibly a new hostname/code) instead of duplicating it.
            record_id, approval_reference = existing[0], existing[1]
            cur.execute(
                """
                UPDATE device_approvals
                SET hostname_claimed = ?,
                    enrollment_code_id = ?,
                    client_nonce = ?,
                    server_nonce = ?,
                    agent_pubkey_der = ?,
                    updated_at = ?
                WHERE id = ?
                """,
                (
                    hostname,
                    install_code["id"],
                    client_nonce_b64,
                    server_nonce_b64,
                    agent_pubkey_der,
                    now,
                    record_id,
                ),
            )
        else:
            record_id = str(uuid.uuid4())
            approval_reference = str(uuid.uuid4())
            cur.execute(
                """
                INSERT INTO device_approvals (
                    id, approval_reference, guid, hostname_claimed,
                    ssl_key_fingerprint_claimed, enrollment_code_id,
                    status, client_nonce, server_nonce, agent_pubkey_der,
                    created_at, updated_at
                )
                VALUES (?, ?, NULL, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
                """,
                (
                    record_id,
                    approval_reference,
                    hostname,
                    fingerprint,
                    install_code["id"],
                    client_nonce_b64,
                    server_nonce_b64,
                    agent_pubkey_der,
                    now,
                    now,
                ),
            )
        conn.commit()
    finally:
        conn.close()

    response = {
        "status": "pending",
        "approval_reference": approval_reference,
        "server_nonce": server_nonce_b64,
        "poll_after_ms": 3000,
        "server_certificate": _load_tls_bundle(tls_bundle_path),
        "signing_key": _signing_key_b64(),
    }
    log("server", f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}")
    return jsonify(response)
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
def enrollment_poll():
    """Report an enrollment's approval state and finalize it once approved.

    The agent sends ``approval_reference``, the ``client_nonce`` it used in
    the original request, and ``proof_sig``: a signature by its enrollment
    key over ``server_nonce || approval_reference || client_nonce``.

    Pending, denied, expired and already-completed approvals are reported
    without writing to the database. For an ``approved`` approval:

    1. The proof is checked against the replay cache.
    2. The approval is atomically claimed (``approved`` -> ``completed``).
    3. The device row and key are created or updated.
    4. The install code is marked used.
    5. An access token and a refresh token are issued.

    Claiming the approval with ``AND status = 'approved'`` and checking
    ``rowcount`` ensures that two concurrent polls cannot both finalize it.
    Previously both could, which minted two token sets and, for an approval
    without a GUID, two separate devices.

    NOTE(review): this endpoint is not rate limited. Each call performs a
    signature verification, so consider applying ``ip_rate_limiter`` here
    as well.
    """
    payload = request.get_json(force=True, silent=True)
    if not isinstance(payload, dict):
        # A JSON array or scalar has no .get(); treat it as an empty object.
        payload = {}
    approval_reference = payload.get("approval_reference")
    client_nonce_b64 = payload.get("client_nonce")
    proof_sig_b64 = payload.get("proof_sig")

    def _str_len(value: Any) -> int:
        # len() on a number would raise before the isinstance checks run.
        return len(value) if isinstance(value, str) else 0

    log(
        "server",
        "enrollment poll received "
        f"ref={approval_reference} client_nonce_len={_str_len(client_nonce_b64)}"
        f" proof_sig_len={_str_len(proof_sig_b64)}",
    )
    if not isinstance(approval_reference, str) or not approval_reference:
        log("server", "enrollment poll rejected missing_reference")
        return jsonify({"error": "approval_reference_required"}), 400
    if not isinstance(client_nonce_b64, str):
        log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}")
        return jsonify({"error": "client_nonce_required"}), 400
    if not isinstance(proof_sig_b64, str):
        log("server", f"enrollment poll rejected missing_sig ref={approval_reference}")
        return jsonify({"error": "proof_sig_required"}), 400
    try:
        client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
    except Exception:
        log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}")
        return jsonify({"error": "invalid_client_nonce"}), 400
    try:
        proof_sig = base64.b64decode(proof_sig_b64, validate=True)
    except Exception:
        log("server", f"enrollment poll invalid_sig ref={approval_reference}")
        return jsonify({"error": "invalid_proof_sig"}), 400

    conn = db_conn_factory()
    try:
        cur = conn.cursor()
        cur.execute(
            """
            SELECT id, guid, hostname_claimed, ssl_key_fingerprint_claimed,
                   enrollment_code_id, status, client_nonce, server_nonce,
                   agent_pubkey_der, created_at, updated_at, approved_by_user_id
            FROM device_approvals
            WHERE approval_reference = ?
            """,
            (approval_reference,),
        )
        row = cur.fetchone()
        if not row:
            log("server", f"enrollment poll unknown_reference ref={approval_reference}")
            return jsonify({"status": "unknown"}), 404
        (
            record_id,
            guid,
            hostname_claimed,
            fingerprint,
            enrollment_code_id,
            status,
            client_nonce_stored,
            server_nonce_b64,
            agent_pubkey_der,
            _created_at,
            _updated_at,
            _approved_by,
        ) = row

        if client_nonce_stored != client_nonce_b64:
            log("server", f"enrollment poll nonce_mismatch ref={approval_reference}")
            return jsonify({"error": "nonce_mismatch"}), 400
        try:
            server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
        except Exception:
            log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}")
            return jsonify({"error": "server_nonce_invalid"}), 400

        # Proof-of-possession: the agent signs both nonces bound to this reference.
        message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
        try:
            public_key = serialization.load_der_public_key(agent_pubkey_der)
        except Exception:
            log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}")
            public_key = None
        if public_key is None:
            log("server", f"enrollment poll invalid_pubkey ref={approval_reference}")
            return jsonify({"error": "agent_pubkey_invalid"}), 400
        # NOTE(review): verify(signature, data) with no padding/hash arguments
        # matches the Ed25519/Ed448 key API. RSA/EC keys would raise here and be
        # reported as invalid_proof, so confirm agents always enroll with Ed25519.
        try:
            public_key.verify(proof_sig, message)
        except Exception:
            log("server", f"enrollment poll invalid_proof ref={approval_reference}")
            return jsonify({"error": "invalid_proof"}), 400

        if status == "pending":
            log(
                "server",
                f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
                f" fingerprint={fingerprint[:12]}",
            )
            return jsonify({"status": "pending", "poll_after_ms": 5000})
        if status == "denied":
            log(
                "server",
                f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
            )
            return jsonify({"status": "denied", "reason": "operator_denied"})
        if status == "expired":
            log(
                "server",
                f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
            )
            return jsonify({"status": "expired"})
        if status == "completed":
            log(
                "server",
                f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
            )
            return jsonify({"status": "approved", "detail": "finalized"})
        if status != "approved":
            log(
                "server",
                f"enrollment poll unexpected_status={status} ref={approval_reference}",
            )
            return jsonify({"status": status or "unknown"}), 400

        nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
        if not nonce_cache.consume(nonce_key):
            log(
                "server",
                f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
            )
            return jsonify({"error": "proof_replayed"}), 409

        # Finalize enrollment.
        effective_guid = guid or str(uuid.uuid4())
        now_iso = _iso(_now())
        # Claim the approval first, conditioned on it still being 'approved'.
        # The status check above ran on an earlier read. A concurrent poll
        # that already finalized this approval leaves rowcount at 0 here, so
        # we stop before creating devices or minting tokens a second time.
        cur.execute(
            """
            UPDATE device_approvals
            SET guid = ?,
                status = 'completed',
                updated_at = ?
            WHERE id = ?
              AND status = 'approved'
            """,
            (effective_guid, now_iso, record_id),
        )
        if cur.rowcount != 1:
            conn.rollback()
            log(
                "server",
                f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
            )
            return jsonify({"status": "approved", "detail": "finalized"})

        device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint)
        _store_device_key(cur, effective_guid, fingerprint)
        # Mark install code used. NOTE(review): rowcount is not checked, so a
        # code already consumed by another approval does not block this one.
        if enrollment_code_id:
            cur.execute(
                """
                UPDATE enrollment_install_codes
                SET used_at = ?, used_by_guid = ?
                WHERE id = ?
                  AND used_at IS NULL
                """,
                (now_iso, effective_guid, enrollment_code_id),
            )
        refresh_info = _issue_refresh_token(cur, effective_guid)
        access_token = jwt_service.issue_access_token(
            effective_guid,
            fingerprint,
            device_record.get("token_version") or 1,
        )
        conn.commit()
    finally:
        # Closing without commit discards any partial finalization.
        conn.close()

    log(
        "server",
        f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
    )
    return jsonify(
        {
            "status": "approved",
            "guid": effective_guid,
            "access_token": access_token,
            "expires_in": 900,
            "refresh_token": refresh_info["token"],
            "token_type": "Bearer",
            "server_certificate": _load_tls_bundle(tls_bundle_path),
            "signing_key": _signing_key_b64(),
        }
    )
# Final step of register(): attach the enrollment routes declared on `blueprint` to the app.
app.register_blueprint(blueprint)
def _load_tls_bundle(path: str) -> str:
try:
with open(path, "r", encoding="utf-8") as fh:
return fh.read()
except Exception:
return ""
def _mask_code(code: str) -> str:
if not code:
return "<missing>"
trimmed = str(code).strip()
if len(trimmed) <= 6:
return "***"
return f"{trimmed[:3]}***{trimmed[-3:]}"