Files
Borealis-Github-Replica/Data/Server/Modules/enrollment/routes.py

479 lines
17 KiB
Python

from __future__ import annotations
import base64
import secrets
import sqlite3
import uuid
from datetime import datetime, timezone, timedelta
import time
from typing import Any, Callable, Dict, Optional
from flask import Blueprint, jsonify, request
from Modules.auth.rate_limit import SlidingWindowRateLimiter
from Modules.crypto import keys as crypto_keys
from Modules.enrollment.nonce_store import NonceCache
from cryptography.hazmat.primitives import serialization
# Request-independent helpers used by the enrollment routes. They capture no
# per-registration state, so they live at module level.

# NOTE(review): presumably the DNS name length limit; the original code used it
# as the cap for stored hostnames.
_HOSTNAME_MAX_LEN = 253
# Highest numeric suffix tried when de-duplicating hostnames.
_HOSTNAME_MAX_SUFFIX = 50
# Lifetime of refresh tokens issued at enrollment.
_REFRESH_TOKEN_TTL = timedelta(days=30)


def _now() -> datetime:
    """Return the current time as a timezone-aware UTC ``datetime``."""
    return datetime.now(tz=timezone.utc)


def _iso(dt: datetime) -> str:
    """Serialize *dt* with ``isoformat()``; this is the text form written to the DB."""
    return dt.isoformat()


def _parse_iso_utc(value: Any) -> Optional[datetime]:
    """Parse an ISO-8601 string into an aware datetime, reading naive values as UTC.

    Returns ``None`` when *value* is not a string or is not valid ISO-8601.
    Naive values are coerced to UTC because comparing a naive datetime with the
    aware ``_now()`` raises ``TypeError``.
    """
    if not isinstance(value, str):
        return None
    try:
        parsed = datetime.fromisoformat(value)
    except ValueError:
        return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
    """Fetch the ``enrollment_install_codes`` row whose ``code`` equals *code_value*.

    Returns a dict with keys ``id``, ``code``, ``expires_at`` and ``used_at``,
    or ``None`` when no row matches.
    """
    cur.execute(
        "SELECT id, code, expires_at, used_at FROM enrollment_install_codes WHERE code = ?",
        (code_value,),
    )
    row = cur.fetchone()
    if not row:
        return None
    return dict(zip(("id", "code", "expires_at", "used_at"), row))


def _install_code_valid(record: Optional[Dict[str, Any]]) -> bool:
    """Return True if *record* is an unexpired, unused install code.

    ``expires_at`` must be an ISO-8601 string later than now (naive values are
    read as UTC). Any truthy ``used_at`` marks the code as already consumed.
    """
    if not record:
        return False
    expiry = _parse_iso_utc(record.get("expires_at"))
    if expiry is None or expiry <= _now():
        return False
    return not record.get("used_at")


def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
    """Choose a hostname for *guid* that no other device in ``devices`` uses.

    Starts from the stripped *hostname* (or *guid* when blank), truncated to
    ``_HOSTNAME_MAX_LEN`` characters. If another device already owns it, tries
    ``<name>-1`` through ``<name>-50``. The base is shortened so each candidate
    still fits in ``_HOSTNAME_MAX_LEN``. Falls back to *guid* when every
    candidate is taken.
    """
    base = ((hostname or "").strip() or guid)[:_HOSTNAME_MAX_LEN]

    def _taken(candidate: str) -> bool:
        # A name owned by *guid* itself is not a collision.
        cur.execute(
            "SELECT guid FROM devices WHERE hostname = ? AND guid != ?",
            (candidate, guid),
        )
        return cur.fetchone() is not None

    if not _taken(base):
        return base
    for suffix in range(1, _HOSTNAME_MAX_SUFFIX + 1):
        tail = f"-{suffix}"
        candidate = base[: _HOSTNAME_MAX_LEN - len(tail)] + tail
        if not _taken(candidate):
            return candidate
    return guid


def _store_device_key(cur: sqlite3.Cursor, guid: str, fingerprint: str) -> None:
    """Record *fingerprint* as the current key of *guid* and retire all others.

    The insert is ``INSERT OR IGNORE``. Skipping an already-recorded key relies
    on a uniqueness constraint in the ``device_keys`` schema, which is not
    visible here. Every other non-retired key of *guid* then gets
    ``retired_at`` set. Both statements share one timestamp, so the new key's
    ``added_at`` equals the ``retired_at`` of the keys it replaces.
    """
    stamp = _iso(_now())
    cur.execute(
        """
        INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
        VALUES (?, ?, ?, ?)
        """,
        (str(uuid.uuid4()), guid, fingerprint, stamp),
    )
    cur.execute(
        """
        UPDATE device_keys
        SET retired_at = ?
        WHERE guid = ?
        AND ssl_key_fingerprint != ?
        AND retired_at IS NULL
        """,
        (stamp, guid, fingerprint),
    )


def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
    """Return the ``devices`` row for *guid*, inserting one if it does not exist.

    For an existing row with no ``ssl_key_fingerprint`` yet, *fingerprint* and a
    ``key_added_at`` timestamp are stored. A new row gets a collision-free
    hostname, ``token_version`` 1 and status ``'active'``.

    Returns:
        Dict with keys ``guid``, ``hostname``, ``token_version``, ``status`` and
        ``ssl_key_fingerprint``.
    """
    cur.execute(
        "SELECT guid, hostname, token_version, status, ssl_key_fingerprint FROM devices WHERE guid = ?",
        (guid,),
    )
    row = cur.fetchone()
    if row:
        record = dict(zip(("guid", "hostname", "token_version", "status", "ssl_key_fingerprint"), row))
        # NOTE(review): an existing *different* fingerprint is left in place and
        # the row's status is not checked. Confirm that is intended when a known
        # guid re-enrolls with a new key.
        if not record.get("ssl_key_fingerprint"):
            cur.execute(
                "UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
                (fingerprint, _iso(_now()), guid),
            )
            record["ssl_key_fingerprint"] = fingerprint
        return record

    resolved_hostname = _normalize_host(hostname, guid, cur)
    # created_at / last_seen are stored as epoch seconds; key_added_at as ISO text.
    created_at = int(time.time())
    key_added_at = _iso(_now())
    cur.execute(
        """
        INSERT INTO devices (
            guid, hostname, created_at, last_seen, ssl_key_fingerprint,
            token_version, status, key_added_at
        )
        VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
        """,
        (
            guid,
            resolved_hostname,
            created_at,
            created_at,
            fingerprint,
            key_added_at,
        ),
    )
    return {
        "guid": guid,
        "hostname": resolved_hostname,
        "token_version": 1,
        "status": "active",
        "ssl_key_fingerprint": fingerprint,
    }


def _hash_refresh_token(token: str) -> str:
    """Return the hex SHA-256 digest of *token* (only the digest is persisted)."""
    import hashlib

    return hashlib.sha256(token.encode("utf-8")).hexdigest()


def _issue_refresh_token(cur: sqlite3.Cursor, guid: str) -> Dict[str, Any]:
    """Create a refresh token for *guid* and store its hash in ``refresh_tokens``.

    Returns:
        ``{"token": <plaintext token>, "expires_at": <aware datetime>}``. The
        expiry is ``_REFRESH_TOKEN_TTL`` after now, truncated to whole seconds.
    """
    token = secrets.token_urlsafe(48)
    now = _now()
    expires_at = now.replace(microsecond=0) + _REFRESH_TOKEN_TTL
    cur.execute(
        """
        INSERT INTO refresh_tokens (id, guid, token_hash, created_at, expires_at)
        VALUES (?, ?, ?, ?, ?)
        """,
        (
            str(uuid.uuid4()),
            guid,
            _hash_refresh_token(token),
            _iso(now),
            _iso(expires_at),
        ),
    )
    return {"token": token, "expires_at": expires_at}


def register(
    app,
    *,
    db_conn_factory: Callable[[], sqlite3.Connection],
    log: Callable[[str, str], None],
    jwt_service,
    tls_bundle_path: str,
    ip_rate_limiter: SlidingWindowRateLimiter,
    fp_rate_limiter: SlidingWindowRateLimiter,
    nonce_cache: NonceCache,
) -> None:
    """Register the agent enrollment endpoints on *app*.

    Adds an ``enrollment`` blueprint with two routes:

    * ``POST /api/agent/enroll/request``: validates the install code, agent
      public key and client nonce, then creates or refreshes a *pending*
      ``device_approvals`` row. Returns an approval reference and a server nonce.
    * ``POST /api/agent/enroll/poll``: verifies the agent's signature over
      ``server_nonce || approval_reference || client_nonce``. Once the approval
      is ``approved``, it creates the device, issues access and refresh tokens,
      and marks the approval ``completed``.

    Args:
        app: Flask application the blueprint is registered on.
        db_conn_factory: Returns a new SQLite connection. Each request opens and
            closes its own.
        log: Logging callback, invoked here as ``log("server", message)``.
        jwt_service: Object providing
            ``issue_access_token(guid, fingerprint, token_version)``.
        tls_bundle_path: File whose text is returned to agents as
            ``server_certificate``.
        ip_rate_limiter: Limiter keyed by client address for enroll requests.
        fp_rate_limiter: Limiter keyed by agent key fingerprint.
        nonce_cache: Used to reject replayed finalization proofs.
    """
    blueprint = Blueprint("enrollment", __name__)

    def _remote_addr() -> str:
        """Return the client address used as the per-IP rate-limit key."""
        # NOTE(review): the first X-Forwarded-For entry is client-controlled
        # unless a trusted reverse proxy overwrites the header. If the server is
        # reachable directly, a caller can rotate it to sidestep the per-IP
        # limiter. Confirm the deployment strips or overwrites this header.
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            return forwarded.split(",")[0].strip()
        return (request.remote_addr or "unknown").strip()

    def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
        """Ask *limiter* whether *key* may proceed; return a 429 response if not, else None."""
        import math

        decision = limiter.check(key, limit, window_s)
        if decision.allowed:
            return None
        response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
        response.status_code = 429
        # Round up and never advertise less than one second. Truncating
        # (2.9 -> "2") told clients to come back before the window had freed up.
        response.headers["Retry-After"] = str(max(1, math.ceil(decision.retry_after)))
        return response

    def _json_payload() -> Dict[str, Any]:
        """Return the request's JSON body if it is an object, otherwise ``{}``.

        A non-object body (list, string, number) used to reach ``.get`` and
        raise ``AttributeError``. It is now treated like a missing body.
        """
        payload = request.get_json(force=True, silent=True)
        return payload if isinstance(payload, dict) else {}

    @blueprint.route("/api/agent/enroll/request", methods=["POST"])
    def enrollment_request():
        """Validate an enrollment request and queue it for operator approval.

        Responds 429 when a rate limit trips, and 400 on missing or malformed
        fields or an invalid install code. Otherwise it returns ``status:
        pending`` with the ``approval_reference`` and ``server_nonce`` the agent
        must sign when polling.
        """
        remote = _remote_addr()
        rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 10, 60.0)
        if rate_error:
            return rate_error

        payload = _json_payload()
        hostname = str(payload.get("hostname") or "").strip()
        enrollment_code = str(payload.get("enrollment_code") or "").strip()
        agent_pubkey_b64 = payload.get("agent_pubkey")
        client_nonce_b64 = payload.get("client_nonce")
        if not hostname:
            return jsonify({"error": "hostname_required"}), 400
        if not enrollment_code:
            return jsonify({"error": "enrollment_code_required"}), 400
        if not isinstance(agent_pubkey_b64, str):
            return jsonify({"error": "agent_pubkey_required"}), 400
        if not isinstance(client_nonce_b64, str):
            return jsonify({"error": "client_nonce_required"}), 400

        try:
            agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
        except Exception:
            return jsonify({"error": "invalid_agent_pubkey"}), 400
        if len(agent_pubkey_der) < 10:
            return jsonify({"error": "invalid_agent_pubkey"}), 400

        try:
            client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
        except ValueError:  # binascii.Error subclasses ValueError
            return jsonify({"error": "invalid_client_nonce"}), 400
        if len(client_nonce_bytes) < 16:
            return jsonify({"error": "invalid_client_nonce"}), 400

        fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
        rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 3, 60.0)
        if rate_error:
            return rate_error

        conn = db_conn_factory()
        try:
            cur = conn.cursor()
            install_code = _load_install_code(cur, enrollment_code)
            if not _install_code_valid(install_code):
                return jsonify({"error": "invalid_enrollment_code"}), 400

            server_nonce_b64 = base64.b64encode(secrets.token_bytes(32)).decode("ascii")
            now = _iso(_now())
            # Re-use an open pending approval for this key instead of stacking
            # duplicates. Its nonces are replaced, so older polls stop verifying.
            cur.execute(
                """
                SELECT id, approval_reference
                FROM device_approvals
                WHERE ssl_key_fingerprint_claimed = ?
                AND status = 'pending'
                """,
                (fingerprint,),
            )
            existing = cur.fetchone()
            if existing:
                record_id, approval_reference = existing[0], existing[1]
                cur.execute(
                    """
                    UPDATE device_approvals
                    SET hostname_claimed = ?,
                        enrollment_code_id = ?,
                        client_nonce = ?,
                        server_nonce = ?,
                        agent_pubkey_der = ?,
                        updated_at = ?
                    WHERE id = ?
                    """,
                    (
                        hostname,
                        install_code["id"],
                        client_nonce_b64,
                        server_nonce_b64,
                        agent_pubkey_der,
                        now,
                        record_id,
                    ),
                )
            else:
                record_id = str(uuid.uuid4())
                approval_reference = str(uuid.uuid4())
                cur.execute(
                    """
                    INSERT INTO device_approvals (
                        id, approval_reference, guid, hostname_claimed,
                        ssl_key_fingerprint_claimed, enrollment_code_id,
                        status, client_nonce, server_nonce, agent_pubkey_der,
                        created_at, updated_at
                    )
                    VALUES (?, ?, NULL, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
                    """,
                    (
                        record_id,
                        approval_reference,
                        hostname,
                        fingerprint,
                        install_code["id"],
                        client_nonce_b64,
                        server_nonce_b64,
                        agent_pubkey_der,
                        now,
                        now,
                    ),
                )
            conn.commit()
        finally:
            conn.close()

        response = {
            "status": "pending",
            "approval_reference": approval_reference,
            "server_nonce": server_nonce_b64,
            "poll_after_ms": 3000,
            "server_certificate": _load_tls_bundle(tls_bundle_path),
        }
        log("server", f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}")
        return jsonify(response)

    @blueprint.route("/api/agent/enroll/poll", methods=["POST"])
    def enrollment_poll():
        """Report an approval's status and, once approved, finalize enrollment.

        The agent proves key possession by signing
        ``server_nonce || approval_reference || client_nonce``. Pending, denied
        and expired approvals only report their status. An approved one creates
        or updates the device record, rotates its key, marks the install code
        used, sets the approval to ``completed`` and returns fresh tokens.
        """
        payload = _json_payload()
        approval_reference = payload.get("approval_reference")
        client_nonce_b64 = payload.get("client_nonce")
        proof_sig_b64 = payload.get("proof_sig")
        if not isinstance(approval_reference, str) or not approval_reference:
            return jsonify({"error": "approval_reference_required"}), 400
        if not isinstance(client_nonce_b64, str):
            return jsonify({"error": "client_nonce_required"}), 400
        if not isinstance(proof_sig_b64, str):
            return jsonify({"error": "proof_sig_required"}), 400
        try:
            client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
        except ValueError:
            return jsonify({"error": "invalid_client_nonce"}), 400
        try:
            proof_sig = base64.b64decode(proof_sig_b64, validate=True)
        except ValueError:
            return jsonify({"error": "invalid_proof_sig"}), 400

        conn = db_conn_factory()
        try:
            cur = conn.cursor()
            cur.execute(
                """
                SELECT id, guid, hostname_claimed, ssl_key_fingerprint_claimed,
                       enrollment_code_id, status, client_nonce, server_nonce,
                       agent_pubkey_der, created_at, updated_at, approved_by_user_id
                FROM device_approvals
                WHERE approval_reference = ?
                """,
                (approval_reference,),
            )
            row = cur.fetchone()
            if not row:
                return jsonify({"status": "unknown"}), 404
            (
                record_id,
                guid,
                hostname_claimed,
                fingerprint,
                enrollment_code_id,
                status,
                client_nonce_stored,
                server_nonce_b64,
                agent_pubkey_der,
                _created_at,
                _updated_at,
                _approved_by,
            ) = row

            if client_nonce_stored != client_nonce_b64:
                return jsonify({"error": "nonce_mismatch"}), 400
            try:
                server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
            except (TypeError, ValueError):  # stored value may be NULL or corrupt
                return jsonify({"error": "server_nonce_invalid"}), 400
            message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes

            try:
                public_key = serialization.load_der_public_key(agent_pubkey_der)
            except Exception:
                return jsonify({"error": "agent_pubkey_invalid"}), 400
            try:
                # The two-argument verify() matches Ed25519 public keys. RSA/EC
                # keys require padding/algorithm arguments and would raise here,
                # which is reported as invalid_proof.
                # NOTE(review): confirm agents only enroll Ed25519 keys.
                public_key.verify(proof_sig, message)
            except Exception:
                return jsonify({"error": "invalid_proof"}), 400

            if status == "pending":
                return jsonify({"status": "pending", "poll_after_ms": 5000})
            if status == "denied":
                return jsonify({"status": "denied", "reason": "operator_denied"})
            if status == "expired":
                return jsonify({"status": "expired"})
            if status != "approved":
                return jsonify({"status": status or "unknown"}), 400

            # Each approval is finalized with a proof accepted only once.
            # NOTE(review): the proof is consumed before the writes below. If
            # finalization then fails and the agent retries with the same
            # signature (deterministic for Ed25519), it gets proof_replayed.
            nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
            if not nonce_cache.consume(nonce_key):
                return jsonify({"error": "proof_replayed"}), 409

            # Finalize enrollment.
            effective_guid = guid or str(uuid.uuid4())
            now_iso = _iso(_now())
            device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint)
            _store_device_key(cur, effective_guid, fingerprint)

            # Mark the install code used (only the first use is recorded).
            if enrollment_code_id:
                cur.execute(
                    """
                    UPDATE enrollment_install_codes
                    SET used_at = ?, used_by_guid = ?
                    WHERE id = ?
                    AND used_at IS NULL
                    """,
                    (now_iso, effective_guid, enrollment_code_id),
                )

            cur.execute(
                """
                UPDATE device_approvals
                SET guid = ?,
                    status = 'completed',
                    updated_at = ?
                WHERE id = ?
                """,
                (effective_guid, now_iso, record_id),
            )

            refresh_info = _issue_refresh_token(cur, effective_guid)
            access_token = jwt_service.issue_access_token(
                effective_guid,
                fingerprint,
                device_record.get("token_version") or 1,
            )
            conn.commit()
        finally:
            conn.close()

        log(
            "server",
            f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
        )
        return jsonify(
            {
                "status": "approved",
                "guid": effective_guid,
                "access_token": access_token,
                "expires_in": 900,
                "refresh_token": refresh_info["token"],
                "token_type": "Bearer",
                "server_certificate": _load_tls_bundle(tls_bundle_path),
            }
        )

    app.register_blueprint(blueprint)
def _load_tls_bundle(path: str) -> str:
    """Return the text of the TLS bundle file at *path*, or ``""`` if unreadable.

    The file is read as UTF-8 text. A missing, unreadable or non-UTF-8 file
    degrades to an empty string instead of failing the caller. Only those I/O
    and decoding failures are swallowed; unexpected errors propagate.
    """
    if not path:
        return ""
    try:
        with open(path, "r", encoding="utf-8") as fh:
            return fh.read()
    except (OSError, UnicodeDecodeError):
        return ""