Removed Legacy Server Codebase

This commit is contained in:
2025-11-01 03:58:43 -06:00
parent bec43418c1
commit da37098d91
106 changed files with 6 additions and 36891 deletions

View File

@@ -1 +0,0 @@

View File

@@ -1 +0,0 @@

View File

@@ -1,496 +0,0 @@
from __future__ import annotations
import secrets
import sqlite3
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Dict, List, Optional
from flask import Blueprint, jsonify, request
from Modules.guid_utils import normalize_guid
# Installer enrollment codes may only be issued with one of these whole-hour TTLs.
VALID_TTL_HOURS = {1, 3, 6, 12, 24}
def register(
    app,
    *,
    db_conn_factory: Callable[[], sqlite3.Connection],
    require_admin: Callable[[], Optional[Any]],
    current_user: Callable[[], Optional[Dict[str, str]]],
    log: Callable[[str, str, Optional[str]], None],
) -> None:
    """Attach the admin blueprint (enrollment codes + device approvals) to *app*.

    Parameters:
        app: Flask application the blueprint is registered on.
        db_conn_factory: zero-arg callable returning a fresh sqlite3 connection.
        require_admin: returns None when the caller is authorized; otherwise a
            Flask response used to short-circuit the request.
        current_user: returns the authenticated user record (dict with at
            least "username") or None.
        log: server log sink (scope, message, optional context label).
    """
def _now() -> datetime:
    """Current moment as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)

def _iso(dt: datetime) -> str:
    """Serialize *dt* as an ISO-8601 timestamp string."""
    return dt.isoformat()
def _lookup_user_id(cur: sqlite3.Cursor, username: str) -> Optional[str]:
    """Resolve *username* (case-insensitively) to its user id, or None."""
    if not username:
        return None
    cur.execute(
        "SELECT id FROM users WHERE LOWER(username) = LOWER(?)",
        (username,),
    )
    match = cur.fetchone()
    return str(match[0]) if match else None
def _hostname_conflict(
    cur: sqlite3.Cursor,
    hostname: Optional[str],
    pending_guid: Optional[str],
) -> Optional[Dict[str, Any]]:
    """Describe the existing device that already owns *hostname*, if any.

    Returns None when there is no conflict: no hostname claimed, no device
    row uses it, or the existing row belongs to *pending_guid* itself.
    Otherwise returns a dict with the existing device's guid,
    ssl_key_fingerprint (lowercased), site_id and site_name.
    """
    if not hostname:
        return None
    cur.execute(
        """
        SELECT d.guid, d.ssl_key_fingerprint, ds.site_id, s.name
        FROM devices d
        LEFT JOIN device_sites ds ON ds.device_hostname = d.hostname
        LEFT JOIN sites s ON s.id = ds.site_id
        WHERE d.hostname = ?
        """,
        (hostname,),
    )
    row = cur.fetchone()
    if not row:
        return None
    existing_guid = normalize_guid(row[0])
    existing_fingerprint = (row[1] or "").strip().lower()
    pending_norm = normalize_guid(pending_guid)
    # The same device re-claiming its own hostname is not a conflict.
    if existing_guid and pending_norm and existing_guid == pending_norm:
        return None
    site_id_raw = row[2]
    site_id = None
    if site_id_raw is not None:
        try:
            site_id = int(site_id_raw)
        except (TypeError, ValueError):
            site_id = None
    site_name = row[3] or ""
    return {
        "guid": existing_guid or None,
        "ssl_key_fingerprint": existing_fingerprint or None,
        "site_id": site_id,
        "site_name": site_name,
    }
def _suggest_alternate_hostname(
    cur: sqlite3.Cursor,
    hostname: Optional[str],
    pending_guid: Optional[str],
) -> Optional[str]:
    """Suggest a hostname that does not collide with an existing device.

    Tries the claimed hostname first, then "<base>-1", "<base>-2", ... until
    a free name (or one already owned by *pending_guid*) is found. Gives up
    after 50 attempts and falls back to the normalized pending GUID (or the
    last candidate).

    Returns None when no hostname was claimed.
    """
    base = (hostname or "").strip()
    if not base:
        return None
    base = base[:253]  # DNS caps a full hostname at 253 characters.
    candidate = base
    pending_norm = normalize_guid(pending_guid)
    suffix = 1
    while True:
        cur.execute(
            "SELECT guid FROM devices WHERE hostname = ?",
            (candidate,),
        )
        row = cur.fetchone()
        if not row:
            return candidate
        existing_guid = normalize_guid(row[0])
        if pending_norm and existing_guid == pending_norm:
            # The claimant already owns this hostname; reuse it.
            return candidate
        # Bug fix: trim the base so the suffixed candidate stays within the
        # 253-character hostname limit instead of silently exceeding it.
        tail = f"-{suffix}"
        candidate = f"{base[: 253 - len(tail)]}{tail}"
        suffix += 1
        if suffix > 50:
            return pending_norm or candidate
@blueprint.before_request
def _check_admin():
    """Short-circuit every admin request with require_admin()'s response, if any."""
    # require_admin() yields None for authorized callers, which Flask
    # interprets as "continue processing the request".
    return require_admin()
@blueprint.route("/api/admin/enrollment-codes", methods=["GET"])
def list_enrollment_codes():
    """List installer enrollment codes, optionally filtered by status.

    Query string: status=active|expired|used (any other value returns all).
      active  -> uses remaining and not yet expired
      expired -> uses remaining but past expiry
      used    -> all allowed uses consumed
    """
    status_filter = request.args.get("status")
    conn = db_conn_factory()
    try:
        cur = conn.cursor()
        sql = """
            SELECT id,
                   code,
                   expires_at,
                   created_by_user_id,
                   used_at,
                   used_by_guid,
                   max_uses,
                   use_count,
                   last_used_at
            FROM enrollment_install_codes
        """
        params: List[str] = []
        now_iso = _iso(_now())
        if status_filter == "active":
            sql += " WHERE use_count < max_uses AND expires_at > ?"
            params.append(now_iso)
        elif status_filter == "expired":
            sql += " WHERE use_count < max_uses AND expires_at <= ?"
            params.append(now_iso)
        elif status_filter == "used":
            sql += " WHERE use_count >= max_uses"
        sql += " ORDER BY expires_at ASC"
        cur.execute(sql, params)
        rows = cur.fetchall()
    finally:
        conn.close()
    # Re-shape rows into JSON-friendly dicts after the connection is closed.
    records = []
    for row in rows:
        records.append(
            {
                "id": row[0],
                "code": row[1],
                "expires_at": row[2],
                "created_by_user_id": row[3],
                "used_at": row[4],
                "used_by_guid": row[5],
                "max_uses": row[6],
                "use_count": row[7],
                "last_used_at": row[8],
            }
        )
    return jsonify({"codes": records})
@blueprint.route("/api/admin/enrollment-codes", methods=["POST"])
def create_enrollment_code():
    """Create a new installer enrollment code.

    Body: {"ttl_hours": 1|3|6|12|24, "max_uses": 1..10}. "allowed_uses" is
    accepted as a legacy alias for "max_uses"; max_uses defaults to 2 when
    absent/invalid and is clamped to [1, 10].

    Inserts the code into the live table and upserts a mirror row in the
    persistent audit table, then returns the created record as JSON.
    """
    payload = request.get_json(force=True, silent=True) or {}
    # Bug fix: a non-numeric ttl_hours used to raise ValueError (HTTP 500);
    # treat it as an invalid TTL (HTTP 400) instead.
    try:
        ttl_hours = int(payload.get("ttl_hours") or 1)
    except (TypeError, ValueError):
        return jsonify({"error": "invalid_ttl"}), 400
    if ttl_hours not in VALID_TTL_HOURS:
        return jsonify({"error": "invalid_ttl"}), 400
    max_uses_value = payload.get("max_uses")
    if max_uses_value is None:
        max_uses_value = payload.get("allowed_uses")
    try:
        max_uses = int(max_uses_value)
    except Exception:
        max_uses = 2  # default when absent or unparsable
    # Clamp to the supported range.
    if max_uses < 1:
        max_uses = 1
    if max_uses > 10:
        max_uses = 10
    user = current_user() or {}
    username = user.get("username") or ""
    conn = db_conn_factory()
    try:
        cur = conn.cursor()
        created_by = _lookup_user_id(cur, username) or username or "system"
        code_value = _generate_install_code()
        issued_at = _now()
        expires_at = issued_at + timedelta(hours=ttl_hours)
        record_id = str(uuid.uuid4())
        cur.execute(
            """
            INSERT INTO enrollment_install_codes (
                id, code, expires_at, created_by_user_id, max_uses, use_count
            )
            VALUES (?, ?, ?, ?, ?, 0)
            """,
            (record_id, code_value, _iso(expires_at), created_by, max_uses),
        )
        # Mirror into the persistent audit table; the upsert resets usage
        # markers in case a stale row with the same id exists.
        cur.execute(
            """
            INSERT INTO enrollment_install_codes_persistent (
                id,
                code,
                created_at,
                expires_at,
                created_by_user_id,
                used_at,
                used_by_guid,
                max_uses,
                last_known_use_count,
                last_used_at,
                is_active,
                archived_at,
                consumed_at
            )
            VALUES (?, ?, ?, ?, ?, NULL, NULL, ?, 0, NULL, 1, NULL, NULL)
            ON CONFLICT(id) DO UPDATE
            SET code = excluded.code,
                created_at = excluded.created_at,
                expires_at = excluded.expires_at,
                created_by_user_id = excluded.created_by_user_id,
                max_uses = excluded.max_uses,
                last_known_use_count = 0,
                used_at = NULL,
                used_by_guid = NULL,
                last_used_at = NULL,
                is_active = 1,
                archived_at = NULL,
                consumed_at = NULL
            """,
            (record_id, code_value, _iso(issued_at), _iso(expires_at), created_by, max_uses),
        )
        conn.commit()
    finally:
        conn.close()
    log(
        "server",
        f"installer code created id={record_id} by={username} ttl={ttl_hours}h max_uses={max_uses}",
    )
    return jsonify(
        {
            "id": record_id,
            "code": code_value,
            "expires_at": _iso(expires_at),
            "max_uses": max_uses,
            "use_count": 0,
            "last_used_at": None,
        }
    )
@blueprint.route("/api/admin/enrollment-codes/<code_id>", methods=["DELETE"])
def delete_enrollment_code(code_id: str):
    """Delete an unused enrollment code and archive its persistent mirror.

    404 is returned both when the id does not exist and when the code has
    already been redeemed at least once (used codes cannot be deleted).
    """
    conn = db_conn_factory()
    try:
        cur = conn.cursor()
        # Only codes that were never redeemed may be removed.
        cur.execute(
            "DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
            (code_id,),
        )
        deleted = cur.rowcount
        if deleted:
            # Keep the audit copy, but mark it inactive and stamp the archive
            # time only if it was not archived before.
            archive_ts = _iso(_now())
            cur.execute(
                """
                UPDATE enrollment_install_codes_persistent
                SET is_active = 0,
                    archived_at = COALESCE(archived_at, ?)
                WHERE id = ?
                """,
                (archive_ts, code_id),
            )
        conn.commit()
    finally:
        conn.close()
    if not deleted:
        return jsonify({"error": "not_found"}), 404
    log("server", f"installer code deleted id={code_id}")
    return jsonify({"status": "deleted"})
@blueprint.route("/api/admin/device-approvals", methods=["GET"])
def list_device_approvals():
    """List device approval requests annotated with hostname-conflict info.

    Query string: status=<state>|all (default: all; matched case-insensitively).
    Each record carries the raw approval columns plus:
      hostname_conflict        details of an existing device using the same
                               hostname (or None)
      fingerprint_match        claimed key fingerprint matches the conflicting
                               device's (same physical device re-enrolling)
      conflict_requires_prompt operator must pick a conflict resolution
      alternate_hostname       suggested non-conflicting name, only populated
                               when a prompt is required
    """
    status_raw = request.args.get("status")
    status = (status_raw or "").strip().lower()
    approvals: List[Dict[str, Any]] = []
    conn = db_conn_factory()
    try:
        cur = conn.cursor()
        params: List[str] = []
        # approved_by_user_id historically stored either a user id or a raw
        # username, hence the double-sided join condition.
        sql = """
            SELECT
                da.id,
                da.approval_reference,
                da.guid,
                da.hostname_claimed,
                da.ssl_key_fingerprint_claimed,
                da.enrollment_code_id,
                da.status,
                da.client_nonce,
                da.server_nonce,
                da.created_at,
                da.updated_at,
                da.approved_by_user_id,
                u.username AS approved_by_username
            FROM device_approvals AS da
            LEFT JOIN users AS u
                ON (
                    CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT)
                    OR LOWER(da.approved_by_user_id) = LOWER(u.username)
                )
        """
        if status and status != "all":
            sql += " WHERE LOWER(da.status) = ?"
            params.append(status)
        sql += " ORDER BY da.created_at ASC"
        cur.execute(sql, params)
        rows = cur.fetchall()
        for row in rows:
            record_guid = row[2]
            hostname = row[3]
            fingerprint_claimed = row[4]
            claimed_fp_norm = (fingerprint_claimed or "").strip().lower()
            conflict_raw = _hostname_conflict(cur, hostname, record_guid)
            fingerprint_match = False
            requires_prompt = False
            conflict = None
            if conflict_raw:
                conflict_fp = (conflict_raw.get("ssl_key_fingerprint") or "").strip().lower()
                # Matching fingerprints mean the same device is re-enrolling,
                # so no operator prompt is necessary.
                fingerprint_match = bool(conflict_fp and claimed_fp_norm) and conflict_fp == claimed_fp_norm
                requires_prompt = not fingerprint_match
                conflict = {
                    **conflict_raw,
                    "fingerprint_match": fingerprint_match,
                    "requires_prompt": requires_prompt,
                }
            alternate_hostname = (
                _suggest_alternate_hostname(cur, hostname, record_guid)
                if conflict_raw and requires_prompt
                else None
            )
            approvals.append(
                {
                    "id": row[0],
                    "approval_reference": row[1],
                    "guid": record_guid,
                    "hostname_claimed": hostname,
                    "ssl_key_fingerprint_claimed": fingerprint_claimed,
                    "enrollment_code_id": row[5],
                    "status": row[6],
                    "client_nonce": row[7],
                    "server_nonce": row[8],
                    "created_at": row[9],
                    "updated_at": row[10],
                    "approved_by_user_id": row[11],
                    "hostname_conflict": conflict,
                    "alternate_hostname": alternate_hostname,
                    "conflict_requires_prompt": requires_prompt,
                    "fingerprint_match": fingerprint_match,
                    "approved_by_username": row[12],
                }
            )
    finally:
        conn.close()
    return jsonify({"approvals": approvals})
def _set_approval_status(
    approval_id: str,
    status: str,
    *,
    guid: Optional[str] = None,
    resolution: Optional[str] = None,
):
    """Transition a pending approval to *status* ("approved" or "denied").

    Returns a (payload, http_status) tuple: 404 when the approval does not
    exist, 409 when it is no longer pending or when approving a hostname
    conflict without a resolution. *resolution* may be "overwrite" (adopt the
    conflicting device's GUID) or "coexist" (keep both devices); matching key
    fingerprints auto-merge without an explicit resolution.
    """
    user = current_user() or {}
    username = user.get("username") or ""
    conn = db_conn_factory()
    try:
        cur = conn.cursor()
        cur.execute(
            """
            SELECT status,
                   guid,
                   hostname_claimed,
                   ssl_key_fingerprint_claimed
            FROM device_approvals
            WHERE id = ?
            """,
            (approval_id,),
        )
        row = cur.fetchone()
        if not row:
            return {"error": "not_found"}, 404
        existing_status = (row[0] or "").strip().lower()
        # Only pending approvals may transition; everything else is final.
        if existing_status != "pending":
            return {"error": "approval_not_pending"}, 409
        stored_guid = row[1]
        hostname_claimed = row[2]
        fingerprint_claimed = (row[3] or "").strip().lower()
        # An explicitly supplied GUID wins over the one stored on the approval.
        guid_effective = normalize_guid(guid) if guid else normalize_guid(stored_guid)
        resolution_effective = (resolution.strip().lower() if isinstance(resolution, str) else None)
        conflict = None
        if status == "approved":
            conflict = _hostname_conflict(cur, hostname_claimed, guid_effective)
            if conflict:
                conflict_fp = (conflict.get("ssl_key_fingerprint") or "").strip().lower()
                fingerprint_match = bool(conflict_fp and fingerprint_claimed) and conflict_fp == fingerprint_claimed
                if fingerprint_match:
                    # Same key, same hostname: silently merge onto the
                    # existing device record.
                    guid_effective = conflict.get("guid") or guid_effective
                    if not resolution_effective:
                        resolution_effective = "auto_merge_fingerprint"
                elif resolution_effective == "overwrite":
                    guid_effective = conflict.get("guid") or guid_effective
                elif resolution_effective == "coexist":
                    pass
                else:
                    # Conflicting hostname with a different key and no chosen
                    # resolution: force the operator to decide.
                    return {
                        "error": "conflict_resolution_required",
                        "hostname": hostname_claimed,
                    }, 409
        guid_to_store = guid_effective or normalize_guid(stored_guid) or None
        approved_by = _lookup_user_id(cur, username) or username or "system"
        cur.execute(
            """
            UPDATE device_approvals
            SET status = ?,
                guid = ?,
                approved_by_user_id = ?,
                updated_at = ?
            WHERE id = ?
            """,
            (
                status,
                guid_to_store,
                approved_by,
                _iso(_now()),
                approval_id,
            ),
        )
        conn.commit()
    finally:
        conn.close()
    resolution_note = f" ({resolution_effective})" if resolution_effective else ""
    log("server", f"device approval {approval_id} -> {status}{resolution_note} by {username}")
    payload: Dict[str, Any] = {"status": status}
    if resolution_effective:
        payload["conflict_resolution"] = resolution_effective
    return payload, 200
@blueprint.route("/api/admin/device-approvals/<approval_id>/approve", methods=["POST"])
def approve_device(approval_id: str):
    """Approve a pending device, optionally overriding its GUID and/or
    supplying a hostname conflict resolution."""
    payload = request.get_json(force=True, silent=True) or {}
    guid = payload.get("guid")
    if guid:
        guid = str(guid).strip()
    resolution = None
    raw_resolution = payload.get("conflict_resolution")
    if isinstance(raw_resolution, str) and raw_resolution.strip():
        resolution = raw_resolution.strip().lower()
    result, status_code = _set_approval_status(
        approval_id,
        "approved",
        guid=guid,
        resolution=resolution,
    )
    return jsonify(result), status_code
@blueprint.route("/api/admin/device-approvals/<approval_id>/deny", methods=["POST"])
def deny_device(approval_id: str):
    """Mark a pending device approval as denied."""
    result, status_code = _set_approval_status(approval_id, "denied")
    return jsonify(result), status_code
# All routes defined above become live once the blueprint is attached.
app.register_blueprint(blueprint)
def _generate_install_code() -> str:
    """Produce a random installer code: 32 uppercase hex chars grouped in
    dash-separated blocks of four (e.g. "A1B2-C3D4-...")."""
    digest = secrets.token_hex(16).upper()
    groups = [digest[pos : pos + 4] for pos in range(0, len(digest), 4)]
    return "-".join(groups)

View File

@@ -1 +0,0 @@

View File

@@ -1,218 +0,0 @@
from __future__ import annotations
import json
import time
import sqlite3
from typing import Any, Callable, Dict, Optional
from flask import Blueprint, jsonify, request, g
from Modules.auth.device_auth import DeviceAuthManager, require_device_auth
from Modules.crypto.signing import ScriptSigner
from Modules.guid_utils import normalize_guid
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
    """Sanitize a context label: keep only alphanumerics, "_" and "-",
    uppercase the result; return None when nothing usable remains."""
    if not value:
        return None
    kept = [ch for ch in str(value) if ch.isalnum() or ch in "_-"]
    if not kept:
        return None
    return "".join(kept).upper()
def register(
    app,
    *,
    db_conn_factory: Callable[[], Any],
    auth_manager: DeviceAuthManager,
    log: Callable[[str, str, Optional[str]], None],
    script_signer: ScriptSigner,
) -> None:
    """Attach the agent-facing blueprint (heartbeat + script polling) to *app*.

    All routes require device authentication via *auth_manager*; responses
    that reference scripts advertise *script_signer*'s public key so agents
    can verify signed payloads.
    """
    blueprint = Blueprint("agents", __name__)
def _json_or_none(value) -> Optional[str]:
    """Serialize *value* to a JSON string; None when value is None or
    not JSON-serializable."""
    if value is None:
        return None
    try:
        encoded = json.dumps(value)
    except Exception:
        return None
    return encoded
def _context_hint(ctx=None) -> Optional[str]:
    """Canonical context label: prefer the auth context's service mode,
    fall back to the request's context header."""
    mode = getattr(ctx, "service_mode", None) if ctx is not None else None
    if mode:
        return _canonical_context(mode)
    return _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
def _auth_context():
    """Fetch the DeviceAuthContext stashed on flask.g by the auth decorator;
    log and return None when it is unexpectedly absent."""
    context = getattr(g, "device_auth", None)
    if context is not None:
        return context
    log("server", f"device auth context missing for {request.path}", _context_hint())
    return None
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
@require_device_auth(auth_manager)
def heartbeat():
    """Persist an agent heartbeat: refresh last_seen plus any reported facts.

    Accepts optional hostname, inventory sections, metrics, and address
    fields from the JSON body and writes them onto the device row matched by
    the authenticated GUID. Hostname UNIQUE collisions are tolerated by
    retrying the update without the hostname column.

    Fixes vs. previous revision: removed the unused nested helper
    `_maybe_str` (dead code) and corrected the `updates` annotation — it
    holds int values (last_seen, uptime) as well as strings.
    """
    ctx = _auth_context()
    if ctx is None:
        return jsonify({"error": "auth_context_missing"}), 500
    payload = request.get_json(force=True, silent=True) or {}
    context_label = _context_hint(ctx)
    now_ts = int(time.time())
    # Column -> new value for the devices row (values are str/int).
    updates: Dict[str, Any] = {"last_seen": now_ts}
    hostname = payload.get("hostname")
    if isinstance(hostname, str) and hostname.strip():
        updates["hostname"] = hostname.strip()
    inventory = payload.get("inventory") if isinstance(payload.get("inventory"), dict) else {}
    for key in ("memory", "network", "software", "storage", "cpu"):
        if key in inventory and inventory[key] is not None:
            encoded = _json_or_none(inventory[key])
            if encoded is not None:
                updates[key] = encoded
    metrics = payload.get("metrics") if isinstance(payload.get("metrics"), dict) else {}
    if "last_user" in metrics and metrics["last_user"]:
        updates["last_user"] = str(metrics["last_user"])
    if "operating_system" in metrics and metrics["operating_system"]:
        updates["operating_system"] = str(metrics["operating_system"])
    if "uptime" in metrics and metrics["uptime"] is not None:
        try:
            updates["uptime"] = int(metrics["uptime"])
        except Exception:
            pass  # non-numeric uptime is simply dropped
    for field in ("external_ip", "internal_ip", "device_type"):
        if field in payload and payload[field]:
            updates[field] = str(payload[field])
    conn = db_conn_factory()
    try:
        cur = conn.cursor()

        def _apply_updates() -> int:
            # Apply the collected column updates to the row matching the
            # authenticated GUID; returns the affected row count.
            if not updates:
                return 0
            columns = ", ".join(f"{col} = ?" for col in updates.keys())
            values = list(updates.values())
            normalized_guid = normalize_guid(ctx.guid)
            selected_guid: Optional[str] = None
            if normalized_guid:
                # Stored GUIDs may differ in case; prefer an exact match,
                # otherwise take the first case-insensitive hit.
                cur.execute(
                    "SELECT guid FROM devices WHERE UPPER(guid) = ?",
                    (normalized_guid,),
                )
                rows = cur.fetchall()
                for (stored_guid,) in rows or []:
                    if stored_guid == ctx.guid:
                        selected_guid = stored_guid
                        break
                if not selected_guid and rows:
                    selected_guid = rows[0][0]
            target_guid = selected_guid or ctx.guid
            cur.execute(
                f"UPDATE devices SET {columns} WHERE guid = ?",
                values + [target_guid],
            )
            updated = cur.rowcount
            # Opportunistically canonicalize the stored GUID spelling.
            if updated > 0 and normalized_guid and target_guid != normalized_guid:
                try:
                    cur.execute(
                        "UPDATE devices SET guid = ? WHERE guid = ?",
                        (normalized_guid, target_guid),
                    )
                except sqlite3.IntegrityError:
                    pass
            return updated

        try:
            rowcount = _apply_updates()
        except sqlite3.IntegrityError as exc:
            if "devices.hostname" in str(exc) and "UNIQUE" in str(exc).upper():
                # Another device already claims this hostname; keep the existing
                # canonical hostname assigned during enrollment to avoid breaking
                # the unique constraint and continue updating the remaining fields.
                existing_guid_for_hostname: Optional[str] = None
                if "hostname" in updates:
                    try:
                        cur.execute(
                            "SELECT guid FROM devices WHERE hostname = ?",
                            (updates["hostname"],),
                        )
                        row = cur.fetchone()
                        if row and row[0]:
                            existing_guid_for_hostname = normalize_guid(row[0])
                    except Exception:
                        existing_guid_for_hostname = None
                if "hostname" in updates:
                    updates.pop("hostname", None)
                try:
                    rowcount = _apply_updates()
                except sqlite3.IntegrityError:
                    raise
                else:
                    try:
                        current_guid = normalize_guid(ctx.guid)
                    except Exception:
                        current_guid = ctx.guid
                    if (
                        existing_guid_for_hostname
                        and current_guid
                        and existing_guid_for_hostname == current_guid
                    ):
                        pass  # Same device contexts; no log needed.
                    else:
                        log(
                            "server",
                            "heartbeat hostname collision ignored for guid="
                            f"{ctx.guid}",
                            context_label,
                        )
            else:
                raise
        if rowcount == 0:
            # Valid token but no device row: the agent must re-enroll.
            log("server", f"heartbeat missing device record guid={ctx.guid}", context_label)
            return jsonify({"error": "device_not_registered"}), 404
        conn.commit()
    finally:
        conn.close()
    return jsonify({"status": "ok", "poll_after_ms": 15000})
@blueprint.route("/api/agent/script/request", methods=["POST"])
@require_device_auth(auth_manager)
def script_request():
    """Agent poll endpoint for script jobs.

    Quarantined devices receive a longer poll backoff; every response
    advertises the server's script-signing public key.
    """
    ctx = _auth_context()
    if ctx is None:
        return jsonify({"error": "auth_context_missing"}), 500
    signing_info = {
        "sig_alg": "ed25519",
        "signing_key": script_signer.public_base64_spki(),
    }
    if ctx.status != "active":
        return jsonify({"status": "quarantined", "poll_after_ms": 60000, **signing_info})
    # Placeholder: actual dispatch logic will integrate with job scheduler.
    return jsonify({"status": "idle", "poll_after_ms": 30000, **signing_info})

app.register_blueprint(blueprint)

View File

@@ -1 +0,0 @@

View File

@@ -1,310 +0,0 @@
from __future__ import annotations
import functools
import sqlite3
import time
from contextlib import closing
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional
import jwt
from flask import g, jsonify, request
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
from Modules.auth.rate_limit import SlidingWindowRateLimiter
from Modules.guid_utils import normalize_guid
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
    """Sanitize a context label: keep only alphanumerics, "_" and "-",
    uppercase the result; return None when nothing usable remains."""
    if not value:
        return None
    kept = [ch for ch in str(value) if ch.isalnum() or ch in "_-"]
    if not kept:
        return None
    return "".join(kept).upper()
@dataclass
class DeviceAuthContext:
    """Authenticated device identity resolved from a bearer token."""

    guid: str  # normalized device GUID from the token claims
    ssl_key_fingerprint: str  # lowercased key fingerprint claimed by the token
    token_version: int  # version claim, compared against the DB for revocation
    access_token: str  # raw bearer token presented by the agent
    claims: Dict[str, Any]  # full decoded JWT claim set
    dpop_jkt: Optional[str]  # DPoP JWK thumbprint when a proof was presented
    status: str  # normalized device status ("active" or "quarantined")
    service_mode: Optional[str]  # canonical agent context header value, if any
class DeviceAuthError(Exception):
    """Raised when device authentication fails.

    Carries an HTTP status code, a machine-readable message, and an
    optional retry-after hint used by rate-limited responses.
    """

    status_code = 401
    error_code = "unauthorized"

    def __init__(
        self,
        message: str = "unauthorized",
        *,
        status_code: Optional[int] = None,
        retry_after: Optional[float] = None,
    ):
        super().__init__(message)
        # Keep the class default (401) unless an explicit code was given.
        if status_code is not None:
            self.status_code = status_code
        self.message = message
        self.retry_after = retry_after
class DeviceAuthManager:
    """Authenticates agent requests carrying device JWTs (plus optional DPoP)."""

    def __init__(
        self,
        *,
        db_conn_factory: Callable[[], Any],
        jwt_service,
        dpop_validator: Optional[DPoPValidator],
        log: Callable[[str, str, Optional[str]], None],
        rate_limiter: Optional[SlidingWindowRateLimiter] = None,
    ) -> None:
        # All collaborators are injected; the manager keeps no state of its own.
        self._db_conn_factory = db_conn_factory
        self._jwt_service = jwt_service
        self._dpop_validator = dpop_validator
        self._log = log
        self._rate_limiter = rate_limiter
def authenticate(self) -> DeviceAuthContext:
    """Validate the request's bearer token and return the device context.

    Steps: parse the Authorization header, decode the JWT, sanity-check the
    claims, rate-limit per key fingerprint, match the token against the
    devices table (recreating a missing row when possible), enforce
    fingerprint / token-version / status policy, and finally verify an
    optional DPoP proof.

    Raises DeviceAuthError on any failure.
    """
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
        raise DeviceAuthError("missing_authorization")
    token = auth_header[len("Bearer ") :].strip()
    if not token:
        raise DeviceAuthError("missing_authorization")
    try:
        claims = self._jwt_service.decode(token)
    except jwt.ExpiredSignatureError:
        raise DeviceAuthError("token_expired")
    except Exception:
        raise DeviceAuthError("invalid_token")
    raw_guid = str(claims.get("guid") or "").strip()
    guid = normalize_guid(raw_guid)
    fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip()
    token_version = int(claims.get("token_version") or 0)
    if not guid or not fingerprint or token_version <= 0:
        raise DeviceAuthError("invalid_claims")
    # Rate limit per key fingerprint: 60 requests per rolling 60 seconds.
    if self._rate_limiter:
        decision = self._rate_limiter.check(f"fp:{fingerprint}", 60, 60.0)
        if not decision.allowed:
            raise DeviceAuthError(
                "rate_limited",
                status_code=429,
                retry_after=decision.retry_after,
            )
    context_label = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
    with closing(self._db_conn_factory()) as conn:
        cur = conn.cursor()
        cur.execute(
            """
            SELECT guid, ssl_key_fingerprint, token_version, status
            FROM devices
            WHERE UPPER(guid) = ?
            """,
            (guid,),
        )
        rows = cur.fetchall()
        # Prefer the row whose stored GUID equals the token's GUID verbatim;
        # otherwise fall back to the first case-insensitive match.
        row = None
        for candidate in rows or []:
            candidate_guid = normalize_guid(candidate[0])
            if candidate_guid == guid:
                row = candidate
                break
        if row is None and rows:
            row = rows[0]
        if not row:
            # The token verified but the device row is gone; try to rebuild it.
            row = self._recover_device_record(
                conn, guid, fingerprint, token_version, context_label
            )
        if not row:
            raise DeviceAuthError("device_not_found", status_code=403)
        db_guid, db_fp, db_token_version, status = row
        db_guid_normalized = normalize_guid(db_guid)
        if not db_guid_normalized or db_guid_normalized != guid:
            raise DeviceAuthError("device_guid_mismatch", status_code=403)
        db_fp = (db_fp or "").lower().strip()
        # An empty stored fingerprint is treated as "not yet pinned".
        if db_fp and db_fp != fingerprint:
            raise DeviceAuthError("fingerprint_mismatch", status_code=403)
        # A stored version newer than the token's means the token was revoked.
        if db_token_version and db_token_version > token_version:
            raise DeviceAuthError("token_version_revoked", status_code=401)
        status_normalized = (status or "active").strip().lower()
        allowed_statuses = {"active", "quarantined"}
        if status_normalized not in allowed_statuses:
            raise DeviceAuthError("device_revoked", status_code=403)
        if status_normalized == "quarantined":
            self._log(
                "server",
                f"device {guid} is quarantined; limited access for {request.path}",
                context_label,
            )
    # DPoP is optional; when a proof header is sent it must verify.
    dpop_jkt: Optional[str] = None
    dpop_proof = request.headers.get("DPoP")
    if dpop_proof:
        if not self._dpop_validator:
            raise DeviceAuthError("dpop_not_supported", status_code=400)
        try:
            htu = request.url
            dpop_jkt = self._dpop_validator.verify(request.method, htu, dpop_proof, token)
        except DPoPReplayError:
            raise DeviceAuthError("dpop_replayed", status_code=400)
        except DPoPVerificationError:
            raise DeviceAuthError("dpop_invalid", status_code=400)
    ctx = DeviceAuthContext(
        guid=guid,
        ssl_key_fingerprint=fingerprint,
        token_version=token_version,
        access_token=token,
        claims=claims,
        dpop_jkt=dpop_jkt,
        status=status_normalized,
        service_mode=context_label,
    )
    return ctx
def _recover_device_record(
    self,
    conn: sqlite3.Connection,
    guid: str,
    fingerprint: str,
    token_version: int,
    context_label: Optional[str],
) -> Optional[tuple]:
    """Attempt to recreate a missing device row for an authenticated token.

    Inserts a placeholder device named "RECOVERED-<guid prefix>" (suffixed
    on hostname collisions, up to 6 attempts) and returns the re-read row,
    or None when recovery failed.
    """
    guid = normalize_guid(guid)
    fingerprint = (fingerprint or "").strip()
    if not guid or not fingerprint:
        return None
    cur = conn.cursor()
    now_ts = int(time.time())
    try:
        now_iso = datetime.now(tz=timezone.utc).isoformat()
    except Exception:
        now_iso = datetime.utcnow().isoformat()  # pragma: no cover
    base_hostname = f"RECOVERED-{guid[:12].upper()}" if guid else "RECOVERED"
    for attempt in range(6):
        hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
        try:
            cur.execute(
                """
                INSERT INTO devices (
                    guid,
                    hostname,
                    created_at,
                    last_seen,
                    ssl_key_fingerprint,
                    token_version,
                    status,
                    key_added_at
                )
                VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
                """,
                (
                    guid,
                    hostname,
                    now_ts,
                    now_ts,
                    fingerprint,
                    max(token_version or 1, 1),
                    now_iso,
                ),
            )
        except sqlite3.IntegrityError as exc:
            # Hostname collision try again with a suffixed placeholder.
            message = str(exc).lower()
            if "hostname" in message and "unique" in message:
                continue
            # Any other constraint violation is unrecoverable here.
            self._log(
                "server",
                f"device auth failed to recover guid={guid} due to integrity error: {exc}",
                context_label,
            )
            conn.rollback()
            return None
        except Exception as exc:  # pragma: no cover - defensive logging
            self._log(
                "server",
                f"device auth unexpected error recovering guid={guid}: {exc}",
                context_label,
            )
            conn.rollback()
            return None
        else:
            conn.commit()
            break
    else:
        # Exhausted attempts because of hostname collisions.
        self._log(
            "server",
            f"device auth could not recover guid={guid}; hostname collisions persisted",
            context_label,
        )
        conn.rollback()
        return None
    # Re-read the freshly inserted row so the caller gets canonical values.
    cur.execute(
        """
        SELECT guid, ssl_key_fingerprint, token_version, status
        FROM devices
        WHERE guid = ?
        """,
        (guid,),
    )
    row = cur.fetchone()
    if not row:
        self._log(
            "server",
            f"device auth recovery for guid={guid} committed but row still missing",
            context_label,
        )
    return row
def require_device_auth(manager: DeviceAuthManager):
    """Decorator factory: authenticate the device before running the view.

    On failure, responds with the DeviceAuthError's status and message
    (plus a Retry-After header when rate limited); on success, stores the
    context on flask.g as ``device_auth`` and calls through.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                g.device_auth = manager.authenticate()
            except DeviceAuthError as exc:
                response = jsonify({"error": exc.message})
                response.status_code = exc.status_code
                retry_after = getattr(exc, "retry_after", None)
                if retry_after:
                    try:
                        response.headers["Retry-After"] = str(max(1, int(retry_after)))
                    except Exception:
                        response.headers["Retry-After"] = "1"
                return response
            return func(*args, **kwargs)
        return wrapper
    return decorator

View File

@@ -1,109 +0,0 @@
"""
DPoP proof verification helpers.
"""
from __future__ import annotations
import hashlib
import time
from threading import Lock
from typing import Dict, Optional
import jwt
_DP0P_MAX_SKEW = 300.0 # seconds
class DPoPVerificationError(Exception):
    """Raised when a DPoP proof fails structural or cryptographic checks."""
    pass

class DPoPReplayError(DPoPVerificationError):
    """Raised when a DPoP proof's jti was already observed (replay attempt)."""
    pass
class DPoPValidator:
    """Verifies DPoP proof JWTs and tracks jti values to reject replays.

    Replay state is purely in-memory, so protection only covers a single
    process lifetime.
    """

    def __init__(self) -> None:
        # jti -> expiry wall-clock time of recently observed proofs.
        self._observed_jti: Dict[str, float] = {}
        self._lock = Lock()

    def verify(
        self,
        method: str,
        htu: str,
        proof: str,
        access_token: Optional[str] = None,
    ) -> str:
        """
        Verify the presented DPoP proof. Returns the JWK thumbprint on success.

        Checks: the header embeds a JWK and an allowed asymmetric alg; the
        signature verifies against that JWK; htm/htu match the request;
        iat is within the allowed skew; optional ath matches the SHA-256 of
        *access_token*; jti has not been seen within the skew window.

        Raises DPoPVerificationError (DPoPReplayError for replays).
        """
        if not proof:
            raise DPoPVerificationError("DPoP proof missing")
        try:
            header = jwt.get_unverified_header(proof)
        except Exception as exc:
            raise DPoPVerificationError("invalid DPoP header") from exc
        jwk = header.get("jwk")
        alg = header.get("alg")
        if not jwk or not isinstance(jwk, dict):
            raise DPoPVerificationError("missing jwk in DPoP header")
        if alg not in ("EdDSA", "ES256", "ES384", "ES512"):
            raise DPoPVerificationError(f"unsupported DPoP alg {alg}")
        try:
            key = jwt.PyJWK(jwk)
            public_key = key.key
        except Exception as exc:
            raise DPoPVerificationError("invalid jwk in DPoP header") from exc
        try:
            claims = jwt.decode(
                proof,
                public_key,
                algorithms=[alg],
                options={"require": ["htm", "htu", "jti", "iat"]},
            )
        except Exception as exc:
            raise DPoPVerificationError("invalid DPoP signature") from exc
        htm = claims.get("htm")
        proof_htu = claims.get("htu")
        jti = claims.get("jti")
        iat = claims.get("iat")
        ath = claims.get("ath")
        if not isinstance(htm, str) or htm.lower() != method.lower():
            raise DPoPVerificationError("DPoP htm mismatch")
        if not isinstance(proof_htu, str) or proof_htu != htu:
            raise DPoPVerificationError("DPoP htu mismatch")
        if not isinstance(jti, str):
            raise DPoPVerificationError("DPoP jti missing")
        if not isinstance(iat, (int, float)):
            raise DPoPVerificationError("DPoP iat missing")
        now = time.time()
        if abs(now - float(iat)) > _DP0P_MAX_SKEW:
            raise DPoPVerificationError("DPoP proof outside allowed skew")
        # ath binds the proof to a specific access token when both are present.
        if ath and access_token:
            expected_ath = jwt.utils.base64url_encode(
                hashlib.sha256(access_token.encode("utf-8")).digest()
            ).decode("ascii")
            if expected_ath != ath:
                raise DPoPVerificationError("DPoP ath mismatch")
        with self._lock:
            # A jti seen within the skew window is a replay.
            expiry = self._observed_jti.get(jti)
            if expiry and expiry > now:
                raise DPoPReplayError("DPoP proof replay detected")
            self._observed_jti[jti] = now + _DP0P_MAX_SKEW
            # Opportunistic cleanup
            stale = [key for key, exp in self._observed_jti.items() if exp <= now]
            for key in stale:
                self._observed_jti.pop(key, None)
        # NOTE(review): assumes PyJWK.thumbprint() returns bytes here — verify
        # against the installed PyJWT version before relying on the decode().
        thumbprint = jwt.PyJWK(jwk).thumbprint()
        return thumbprint.decode("ascii")

View File

@@ -1,140 +0,0 @@
"""
JWT access-token helpers backed by an Ed25519 signing key.
"""
from __future__ import annotations
import hashlib
import time
from datetime import datetime, timezone
from typing import Any, Dict, Optional
import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from Modules.runtime import ensure_runtime_dir, runtime_path
_KEY_DIR = runtime_path("auth_keys")
_KEY_FILE = _KEY_DIR / "borealis-jwt-ed25519.key"
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-jwt-ed25519.key"
class JWTService:
    """Issues and validates EdDSA (Ed25519) device access tokens."""

    def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str):
        self._private_key = private_key
        self._public_key = private_key.public_key()
        self._key_id = key_id

    @property
    def key_id(self) -> str:
        # Stable identifier for the signing key; embedded as the JWT "kid".
        return self._key_id

    def issue_access_token(
        self,
        guid: str,
        ssl_key_fingerprint: str,
        token_version: int,
        expires_in: int = 900,
        extra_claims: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Mint a signed access token for a device.

        Claims: sub ("device:<guid>"), guid, ssl_key_fingerprint,
        token_version, iat/nbf (now) and exp (now + *expires_in* seconds,
        default 15 minutes). *extra_claims* may add or override entries.
        """
        now = int(time.time())
        payload: Dict[str, Any] = {
            "sub": f"device:{guid}",
            "guid": guid,
            "ssl_key_fingerprint": ssl_key_fingerprint,
            "token_version": int(token_version),
            "iat": now,
            "nbf": now,
            "exp": now + int(expires_in),
        }
        if extra_claims:
            payload.update(extra_claims)
        # PyJWT accepts the PEM-encoded Ed25519 private key for EdDSA signing.
        token = jwt.encode(
            payload,
            self._private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            ),
            algorithm="EdDSA",
            headers={"kid": self._key_id},
        )
        return token

    def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]:
        """Verify *token* with this service's public key and return its claims.

        Requires exp/iat/sub claims; propagates PyJWT exceptions on failure.
        """
        options = {"require": ["exp", "iat", "sub"]}
        public_pem = self._public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        return jwt.decode(
            token,
            public_pem,
            algorithms=["EdDSA"],
            audience=audience,
            options=options,
        )

    def public_jwk(self) -> Dict[str, Any]:
        """Return the verification key as an OKP/Ed25519 JWK dict."""
        public_bytes = self._public_key.public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )
        # PyJWT expects base64url without padding.
        jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii")
        return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x}
def load_service() -> JWTService:
    """Load (or create) the Ed25519 signing key and build a JWTService.

    The key id is the first 16 hex chars of the SHA-256 digest of the
    DER-encoded public key.
    """
    signing_key = _load_or_create_private_key()
    der_public = signing_key.public_key().public_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    kid = hashlib.sha256(der_public).hexdigest()[:16]
    return JWTService(signing_key, kid)
def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey:
    """Load the persisted Ed25519 JWT signing key, creating it on first run.

    Order: migrate any legacy key into place, load the current file, fall
    back to reading the legacy path directly, and finally generate a fresh
    key (written with 0600 permissions, best effort).
    """
    ensure_runtime_dir("auth_keys")
    _migrate_legacy_key_if_present()
    if _KEY_FILE.exists():
        with _KEY_FILE.open("rb") as fh:
            return serialization.load_pem_private_key(fh.read(), password=None)
    if _LEGACY_KEY_FILE.exists():
        # Migration did not move the file (e.g. permissions); read in place.
        with _LEGACY_KEY_FILE.open("rb") as fh:
            return serialization.load_pem_private_key(fh.read(), password=None)
    private_key = ed25519.Ed25519PrivateKey.generate()
    pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    with _KEY_FILE.open("wb") as fh:
        fh.write(pem)
    try:
        # Best-effort: restrict the key file to the owning user where supported.
        if _KEY_FILE.exists() and hasattr(_KEY_FILE, "chmod"):
            _KEY_FILE.chmod(0o600)
    except Exception:
        pass
    return private_key
def _migrate_legacy_key_if_present() -> None:
    """Best-effort move of the legacy key file into the auth_keys directory.

    No-op when the new key already exists or there is no legacy key; every
    filesystem error is swallowed deliberately.
    """
    if _KEY_FILE.exists() or not _LEGACY_KEY_FILE.exists():
        return
    try:
        ensure_runtime_dir("auth_keys")
        try:
            _LEGACY_KEY_FILE.replace(_KEY_FILE)
        except Exception:
            # Rename can fail (e.g. across filesystems); copy the bytes instead.
            _KEY_FILE.write_bytes(_LEGACY_KEY_FILE.read_bytes())
    except Exception:
        return

View File

@@ -1,41 +0,0 @@
"""
Tiny in-memory rate limiter suitable for single-process development servers.
"""
from __future__ import annotations
import time
from collections import deque
from dataclasses import dataclass
from threading import Lock
from typing import Deque, Dict, Tuple
@dataclass
class RateLimitDecision:
    # Outcome of a rate-limit check: whether the call may proceed and, when
    # denied, how many seconds to wait before the oldest slot frees up.
    allowed: bool
    retry_after: float


class SlidingWindowRateLimiter:
    """Thread-safe sliding-window rate limiter keyed by arbitrary strings.

    Monotonic timestamps of recent attempts are kept per key; an attempt is
    allowed while fewer than ``limit`` of them fall inside the trailing
    ``window_seconds``.
    """

    def __init__(self) -> None:
        self._buckets: Dict[str, Deque[float]] = {}
        self._lock = Lock()

    def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
        """Record an attempt for ``key`` and decide whether it is allowed."""
        moment = time.monotonic()
        with self._lock:
            history = self._buckets.setdefault(key, deque())
            # Evict timestamps that have aged out of the window.
            while history and moment - history[0] > window_seconds:
                history.popleft()
            if len(history) < limit:
                history.append(moment)
                return RateLimitDecision(True, 0.0)
            wait = max(0.0, window_seconds - (moment - history[0]))
            return RateLimitDecision(False, wait)

View File

@@ -1 +0,0 @@

View File

@@ -1,372 +0,0 @@
"""
Server TLS certificate management.
Borealis now issues a dedicated root CA and a leaf server certificate so that
agents can pin the CA without requiring a re-enrollment every time the server
certificate is refreshed. The CA is persisted alongside the server key so that
existing deployments can be upgraded in-place.
"""
from __future__ import annotations
import os
import ssl
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional, Tuple
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
from Modules.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path
_CERT_DIR = server_certificates_path()
_CERT_FILE = _CERT_DIR / "borealis-server-cert.pem"
_KEY_FILE = _CERT_DIR / "borealis-server-key.pem"
_BUNDLE_FILE = _CERT_DIR / "borealis-server-bundle.pem"
_CA_KEY_FILE = _CERT_DIR / "borealis-root-ca-key.pem"
_CA_CERT_FILE = _CERT_DIR / "borealis-root-ca.pem"
_LEGACY_CERT_DIR = runtime_path("certs")
_LEGACY_CERT_FILE = _LEGACY_CERT_DIR / "borealis-server-cert.pem"
_LEGACY_KEY_FILE = _LEGACY_CERT_DIR / "borealis-server-key.pem"
_LEGACY_BUNDLE_FILE = _LEGACY_CERT_DIR / "borealis-server-bundle.pem"
_ROOT_COMMON_NAME = "Borealis Root CA"
_ORG_NAME = "Borealis"
_ROOT_VALIDITY = timedelta(days=365 * 100)
_SERVER_VALIDITY = timedelta(days=365 * 5)
def ensure_certificate(common_name: str = "Borealis Server") -> Tuple[Path, Path, Path]:
    """
    Ensure the root CA, server certificate, and bundle exist on disk.
    Returns (cert_path, key_path, bundle_path).

    A freshly (re)generated CA invalidates the existing leaf certificate,
    as do the expiry / issuer / extension problems detected by
    `_server_certificate_needs_regeneration`.
    """
    ensure_server_certificates_dir()
    _migrate_legacy_material_if_present()
    ca_key, ca_cert, ca_regenerated = _ensure_root_ca()
    server_cert = _load_certificate(_CERT_FILE)
    needs_regen = ca_regenerated or _server_certificate_needs_regeneration(server_cert, ca_cert)
    if needs_regen:
        server_cert = _generate_server_certificate(common_name, ca_key, ca_cert)
    if server_cert is None:
        # Defensive re-check so a leaf always exists before the bundle is written.
        server_cert = _generate_server_certificate(common_name, ca_key, ca_cert)
    _write_bundle(server_cert, ca_cert)
    return _CERT_FILE, _KEY_FILE, _BUNDLE_FILE
def _migrate_legacy_material_if_present() -> None:
    """Promote legacy runtime certificates (Server/Borealis/certs) into the new location.

    Copies only when both legacy files exist and at least one current file is
    missing; failures are ignored because fresh material can be generated later.
    """
    if not _CERT_FILE.exists() or not _KEY_FILE.exists():
        legacy_cert = _LEGACY_CERT_FILE
        legacy_key = _LEGACY_KEY_FILE
        if legacy_cert.exists() and legacy_key.exists():
            try:
                ensure_server_certificates_dir()
                if not _CERT_FILE.exists():
                    _safe_copy(legacy_cert, _CERT_FILE)
                if not _KEY_FILE.exists():
                    _safe_copy(legacy_key, _KEY_FILE)
            except Exception:
                pass
def _ensure_root_ca() -> Tuple[ec.EllipticCurvePrivateKey, x509.Certificate, bool]:
    """Load the persisted root CA, regenerating it when missing or invalid.

    Returns (ca_key, ca_cert, regenerated). The CA is regenerated when the
    files are absent, unreadable, expired, not marked as a CA, or carry an
    unexpected subject common name.
    """
    regenerated = False
    ca_key: Optional[ec.EllipticCurvePrivateKey] = None
    ca_cert: Optional[x509.Certificate] = None
    if _CA_KEY_FILE.exists() and _CA_CERT_FILE.exists():
        try:
            ca_key = _load_private_key(_CA_KEY_FILE)
            ca_cert = _load_certificate(_CA_CERT_FILE)
            if ca_cert is not None and ca_key is not None:
                expiry = _cert_not_after(ca_cert)
                subject = ca_cert.subject
                subject_cn = ""
                try:
                    subject_cn = subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value  # type: ignore[index]
                except Exception:
                    subject_cn = ""
                try:
                    basic = ca_cert.extensions.get_extension_for_class(x509.BasicConstraints).value  # type: ignore[attr-defined]
                    is_ca = bool(basic.ca)
                except Exception:
                    is_ca = False
                # Reject an expired, non-CA, or foreign certificate outright.
                if (
                    expiry <= datetime.now(tz=timezone.utc)
                    or not is_ca
                    or subject_cn != _ROOT_COMMON_NAME
                ):
                    regenerated = True
            else:
                regenerated = True
        except Exception:
            regenerated = True
    else:
        regenerated = True
    if regenerated or ca_key is None or ca_cert is None:
        # Self-signed root: subject == issuer, P-384 key, SHA-384 signature.
        ca_key = ec.generate_private_key(ec.SECP384R1())
        public_key = ca_key.public_key()
        now = datetime.now(tz=timezone.utc)
        builder = (
            x509.CertificateBuilder()
            .subject_name(
                x509.Name(
                    [
                        x509.NameAttribute(NameOID.COMMON_NAME, _ROOT_COMMON_NAME),
                        x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
                    ]
                )
            )
            .issuer_name(
                x509.Name(
                    [
                        x509.NameAttribute(NameOID.COMMON_NAME, _ROOT_COMMON_NAME),
                        x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
                    ]
                )
            )
            .public_key(public_key)
            .serial_number(x509.random_serial_number())
            .not_valid_before(now - timedelta(minutes=5))
            .not_valid_after(now + _ROOT_VALIDITY)
            .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    content_commitment=False,
                    key_encipherment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    key_cert_sign=True,
                    crl_sign=True,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(public_key),
                critical=False,
            )
        )
        builder = builder.add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_public_key(public_key),
            critical=False,
        )
        ca_cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA384())
        _CA_KEY_FILE.write_bytes(
            ca_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.TraditionalOpenSSL,
                encryption_algorithm=serialization.NoEncryption(),
            )
        )
        _CA_CERT_FILE.write_bytes(ca_cert.public_bytes(serialization.Encoding.PEM))
        _tighten_permissions(_CA_KEY_FILE)
        _tighten_permissions(_CA_CERT_FILE)
    else:
        regenerated = False
    return ca_key, ca_cert, regenerated
def _server_certificate_needs_regeneration(
    server_cert: Optional[x509.Certificate],
    ca_cert: x509.Certificate,
) -> bool:
    """Return True when the leaf certificate must be re-issued.

    Regeneration is required when the leaf is missing, not issued by our CA,
    expired, marked as a CA itself, or lacks the serverAuth EKU. Each probe
    treats an inspection failure as "regenerate" (fail closed).
    """
    if server_cert is None:
        return True
    try:
        if server_cert.issuer != ca_cert.subject:
            return True
    except Exception:
        return True
    try:
        expiry = _cert_not_after(server_cert)
        if expiry <= datetime.now(tz=timezone.utc):
            return True
    except Exception:
        return True
    try:
        basic = server_cert.extensions.get_extension_for_class(x509.BasicConstraints).value  # type: ignore[attr-defined]
        if basic.ca:
            return True
    except Exception:
        return True
    try:
        eku = server_cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value  # type: ignore[attr-defined]
        if ExtendedKeyUsageOID.SERVER_AUTH not in eku:
            return True
    except Exception:
        return True
    return False
def _generate_server_certificate(
    common_name: str,
    ca_key: ec.EllipticCurvePrivateKey,
    ca_cert: x509.Certificate,
) -> x509.Certificate:
    """Issue a new leaf server certificate signed by the root CA.

    Generates a fresh P-384 key, caps validity one day short of the CA's
    expiry, includes loopback SANs, and persists both key and certificate
    before returning.
    """
    private_key = ec.generate_private_key(ec.SECP384R1())
    public_key = private_key.public_key()
    now = datetime.now(tz=timezone.utc)
    ca_expiry = _cert_not_after(ca_cert)
    candidate_expiry = now + _SERVER_VALIDITY
    # Never let the leaf outlive the CA that signed it.
    not_after = min(ca_expiry - timedelta(days=1), candidate_expiry)
    builder = (
        x509.CertificateBuilder()
        .subject_name(
            x509.Name(
                [
                    x509.NameAttribute(NameOID.COMMON_NAME, common_name),
                    x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
                ]
            )
        )
        .issuer_name(ca_cert.subject)
        .public_key(public_key)
        .serial_number(x509.random_serial_number())
        .not_valid_before(now - timedelta(minutes=5))
        .not_valid_after(not_after)
        .add_extension(
            x509.SubjectAlternativeName(
                [
                    x509.DNSName("localhost"),
                    x509.DNSName("127.0.0.1"),
                    x509.DNSName("::1"),
                ]
            ),
            critical=False,
        )
        .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=False,
                crl_sign=False,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .add_extension(
            x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]),
            critical=False,
        )
        .add_extension(
            x509.SubjectKeyIdentifier.from_public_key(public_key),
            critical=False,
        )
        .add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_key.public_key()),
            critical=False,
        )
    )
    certificate = builder.sign(private_key=ca_key, algorithm=hashes.SHA384())
    _KEY_FILE.write_bytes(
        private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )
    )
    _CERT_FILE.write_bytes(certificate.public_bytes(serialization.Encoding.PEM))
    _tighten_permissions(_KEY_FILE)
    _tighten_permissions(_CERT_FILE)
    return certificate
def _write_bundle(server_cert: x509.Certificate, ca_cert: x509.Certificate) -> None:
    """Write the leaf + root-CA PEM bundle that agents use for CA pinning."""
    try:
        leaf_pem = server_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8").strip()
        root_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8").strip()
    except Exception:
        # If either certificate cannot be serialised there is nothing to write.
        return
    _BUNDLE_FILE.write_text(f"{leaf_pem}\n{root_pem}\n", encoding="utf-8")
    _tighten_permissions(_BUNDLE_FILE)
def _safe_copy(src: Path, dst: Path) -> None:
try:
dst.write_bytes(src.read_bytes())
except Exception:
pass
def _tighten_permissions(path: Path) -> None:
try:
if os.name == "posix":
path.chmod(0o600)
except Exception:
pass
def _load_private_key(path: Path) -> ec.EllipticCurvePrivateKey:
    """Load an unencrypted PEM private key from ``path``."""
    pem = path.read_bytes()
    return serialization.load_pem_private_key(pem, password=None)
def _load_certificate(path: Path) -> Optional[x509.Certificate]:
    """Parse a PEM certificate from ``path``; return None on any failure."""
    try:
        pem = path.read_bytes()
        return x509.load_pem_x509_certificate(pem)
    except Exception:
        return None
def _cert_not_after(cert: x509.Certificate) -> datetime:
try:
return cert.not_valid_after_utc # type: ignore[attr-defined]
except AttributeError:
value = cert.not_valid_after
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value
def build_ssl_context() -> ssl.SSLContext:
    """Build a TLS 1.3-only server-side SSL context from the managed bundle.

    The certificate chain is the leaf + root CA bundle so clients receive
    the pinnable CA alongside the server certificate.
    """
    cert_path, key_path, bundle_path = ensure_certificate()
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.minimum_version = ssl.TLSVersion.TLSv1_3
    context.load_cert_chain(certfile=str(bundle_path), keyfile=str(key_path))
    return context
def certificate_paths() -> Tuple[str, str, str]:
    """Return (cert, key, bundle) file paths as strings, creating material if needed."""
    cert_file, key_file, bundle_file = ensure_certificate()
    return str(cert_file), str(key_file), str(bundle_file)

View File

@@ -1,71 +0,0 @@
"""
Utility helpers for working with Ed25519 keys and fingerprints.
"""
from __future__ import annotations
import base64
import hashlib
import re
from typing import Tuple
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.serialization import load_der_public_key
from cryptography.hazmat.primitives.asymmetric import ed25519
def generate_ed25519_keypair() -> Tuple[ed25519.Ed25519PrivateKey, bytes]:
    """
    Generate a new Ed25519 keypair.
    Returns the private key object and the public key encoded as SubjectPublicKeyInfo DER bytes.
    """
    private_key = ed25519.Ed25519PrivateKey.generate()
    # SPKI DER is the canonical wire form used elsewhere for fingerprints.
    public_key = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return private_key, public_key
def normalize_base64(data: str) -> str:
    """
    Collapse whitespace and normalise URL-safe encodings so we can reliably decode.

    Accepts None/empty input and returns "". URL-safe alphabet characters
    ("-", "_") are mapped to their standard-base64 equivalents ("+", "/").
    """
    # Bug fix: the pattern must be \s+ (any whitespace run). The previous
    # double-escaped r"\\s+" matched a literal backslash followed by "s",
    # so embedded newlines/spaces were never stripped.
    cleaned = re.sub(r"\s+", "", data or "")
    return cleaned.replace("-", "+").replace("_", "/")
def spki_der_from_base64(spki_b64: str) -> bytes:
    """Decode a base64 SPKI blob (tolerating whitespace / URL-safe chars) to DER."""
    return base64.b64decode(normalize_base64(spki_b64), validate=True)
def base64_from_spki_der(spki_der: bytes) -> str:
    """Encode SPKI DER bytes as standard base64 ASCII text."""
    encoded = base64.b64encode(spki_der)
    return encoded.decode("ascii")
def fingerprint_from_spki_der(spki_der: bytes) -> str:
    """Return the lowercase hex SHA-256 fingerprint of SPKI DER bytes."""
    return hashlib.sha256(spki_der).hexdigest().lower()
def fingerprint_from_base64_spki(spki_b64: str) -> str:
    """Return the SHA-256 hex fingerprint of a base64-encoded SPKI blob."""
    return fingerprint_from_spki_der(spki_der_from_base64(spki_b64))
def private_key_to_pem(private_key: ed25519.Ed25519PrivateKey) -> bytes:
    """Serialise an Ed25519 private key to unencrypted PKCS#8 PEM bytes."""
    return private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
def public_key_to_pem(public_spki_der: bytes) -> bytes:
    """Convert SPKI DER public-key bytes to SubjectPublicKeyInfo PEM."""
    public_key = load_der_public_key(public_spki_der)
    return public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

View File

@@ -1,125 +0,0 @@
"""
Code-signing helpers for delivering scripts to agents.
"""
from __future__ import annotations
from pathlib import Path
from typing import Tuple
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from Modules.runtime import (
ensure_server_certificates_dir,
server_certificates_path,
runtime_path,
)
from .keys import base64_from_spki_der
_KEY_DIR = server_certificates_path("Code-Signing")
_SIGNING_KEY_FILE = _KEY_DIR / "borealis-script-ed25519.key"
_SIGNING_PUB_FILE = _KEY_DIR / "borealis-script-ed25519.pub"
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-script-ed25519.key"
_LEGACY_PUB_FILE = runtime_path("keys") / "borealis-script-ed25519.pub"
_OLD_RUNTIME_KEY_DIR = runtime_path("script_signing_keys")
_OLD_RUNTIME_KEY_FILE = _OLD_RUNTIME_KEY_DIR / "borealis-script-ed25519.key"
_OLD_RUNTIME_PUB_FILE = _OLD_RUNTIME_KEY_DIR / "borealis-script-ed25519.pub"
class ScriptSigner:
    """Signs script payloads with the server's Ed25519 code-signing key."""
    def __init__(self, private_key: ed25519.Ed25519PrivateKey):
        self._private = private_key
        self._public = private_key.public_key()
    def sign(self, payload: bytes) -> bytes:
        """Return the detached Ed25519 signature over ``payload``."""
        return self._private.sign(payload)
    def public_spki_der(self) -> bytes:
        """Return the public key as SubjectPublicKeyInfo DER bytes."""
        return self._public.public_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
    def public_base64_spki(self) -> str:
        """Return the SPKI public key as standard base64 text."""
        return base64_from_spki_der(self.public_spki_der())
def load_signer() -> ScriptSigner:
    """Build the process-wide ScriptSigner from the persisted signing key."""
    return ScriptSigner(_load_or_create())
def _load_or_create() -> ed25519.Ed25519PrivateKey:
    """Return the persistent script-signing key, creating it on first run.

    Resolution order: migrate legacy material into place, load the current
    key file, fall back to the legacy file, then generate a new keypair and
    persist both the PKCS#8 PEM private key and the SPKI DER public key.
    """
    ensure_server_certificates_dir("Code-Signing")
    _migrate_legacy_material_if_present()
    if _SIGNING_KEY_FILE.exists():
        with _SIGNING_KEY_FILE.open("rb") as fh:
            return serialization.load_pem_private_key(fh.read(), password=None)
    if _LEGACY_KEY_FILE.exists():
        # Fallback read: the migration above is best-effort and may have failed.
        with _LEGACY_KEY_FILE.open("rb") as fh:
            return serialization.load_pem_private_key(fh.read(), password=None)
    private_key = ed25519.Ed25519PrivateKey.generate()
    pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    with _SIGNING_KEY_FILE.open("wb") as fh:
        fh.write(pem)
    try:
        # Best-effort: restrict the key file to owner read/write where supported.
        if hasattr(_SIGNING_KEY_FILE, "chmod"):
            _SIGNING_KEY_FILE.chmod(0o600)
    except Exception:
        pass
    pub_der = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    _SIGNING_PUB_FILE.write_bytes(pub_der)
    return private_key
def _migrate_legacy_material_if_present() -> None:
    """Migrate signing keys from older storage locations into Code-Signing.

    Two legacy layouts are handled in order: the old Server-runtime directory
    and the legacy runtime `keys` directory. Each move falls back to a byte
    copy when rename fails; all errors are swallowed (best effort).
    """
    if _SIGNING_KEY_FILE.exists():
        return
    # First migrate from legacy runtime path embedded in Server runtime.
    try:
        if _OLD_RUNTIME_KEY_FILE.exists() and not _SIGNING_KEY_FILE.exists():
            ensure_server_certificates_dir("Code-Signing")
            try:
                _OLD_RUNTIME_KEY_FILE.replace(_SIGNING_KEY_FILE)
            except Exception:
                _SIGNING_KEY_FILE.write_bytes(_OLD_RUNTIME_KEY_FILE.read_bytes())
        if _OLD_RUNTIME_PUB_FILE.exists() and not _SIGNING_PUB_FILE.exists():
            try:
                _OLD_RUNTIME_PUB_FILE.replace(_SIGNING_PUB_FILE)
            except Exception:
                _SIGNING_PUB_FILE.write_bytes(_OLD_RUNTIME_PUB_FILE.read_bytes())
    except Exception:
        pass
    if not _LEGACY_KEY_FILE.exists() or _SIGNING_KEY_FILE.exists():
        return
    try:
        ensure_server_certificates_dir("Code-Signing")
        try:
            _LEGACY_KEY_FILE.replace(_SIGNING_KEY_FILE)
        except Exception:
            _SIGNING_KEY_FILE.write_bytes(_LEGACY_KEY_FILE.read_bytes())
        if _LEGACY_PUB_FILE.exists() and not _SIGNING_PUB_FILE.exists():
            try:
                _LEGACY_PUB_FILE.replace(_SIGNING_PUB_FILE)
            except Exception:
                _SIGNING_PUB_FILE.write_bytes(_LEGACY_PUB_FILE.read_bytes())
    except Exception:
        return

View File

@@ -1,488 +0,0 @@
"""
Database migration helpers for Borealis.
This module centralises schema evolution so the main server module can stay
focused on request handling. The migration functions are intentionally
idempotent — they can run repeatedly without changing state once the schema
matches the desired shape.
"""
from __future__ import annotations
import sqlite3
import uuid
from datetime import datetime, timezone
from typing import List, Optional, Sequence, Tuple
DEVICE_TABLE = "devices"
def apply_all(conn: sqlite3.Connection) -> None:
    """
    Run all known schema migrations against the provided sqlite3 connection.

    Each helper below is idempotent, so this can run at every startup; the
    final commit persists whatever the helpers changed.
    """
    _ensure_devices_table(conn)
    _ensure_device_aux_tables(conn)
    _ensure_refresh_token_table(conn)
    _ensure_install_code_table(conn)
    _ensure_install_code_persistence_table(conn)
    _ensure_device_approval_table(conn)
    conn.commit()
def _ensure_devices_table(conn: sqlite3.Connection) -> None:
    """Create or upgrade the `devices` table to the current schema.

    The table must be keyed by `guid` alone and contain every column listed
    in `required_columns`; a wrong primary key or any missing column triggers
    a full table rebuild that copies legacy rows across.
    """
    cur = conn.cursor()
    if not _table_exists(cur, DEVICE_TABLE):
        _create_devices_table(cur)
        return
    column_info = _table_info(cur, DEVICE_TABLE)
    col_names = [c[1] for c in column_info]
    # PRAGMA table_info index 5 is the primary-key ordinal (0 = not part of PK).
    pk_cols = [c[1] for c in column_info if c[5]]
    needs_rebuild = pk_cols != ["guid"]
    required_columns = {
        "guid": "TEXT",
        "hostname": "TEXT",
        "description": "TEXT",
        "created_at": "INTEGER",
        "agent_hash": "TEXT",
        "memory": "TEXT",
        "network": "TEXT",
        "software": "TEXT",
        "storage": "TEXT",
        "cpu": "TEXT",
        "device_type": "TEXT",
        "domain": "TEXT",
        "external_ip": "TEXT",
        "internal_ip": "TEXT",
        "last_reboot": "TEXT",
        "last_seen": "INTEGER",
        "last_user": "TEXT",
        "operating_system": "TEXT",
        "uptime": "INTEGER",
        "agent_id": "TEXT",
        "ansible_ee_ver": "TEXT",
        "connection_type": "TEXT",
        "connection_endpoint": "TEXT",
        "ssl_key_fingerprint": "TEXT",
        "token_version": "INTEGER",
        "status": "TEXT",
        "key_added_at": "TEXT",
    }
    missing_columns = [col for col in required_columns if col not in col_names]
    if missing_columns:
        needs_rebuild = True
    if needs_rebuild:
        _rebuild_devices_table(conn, column_info)
    else:
        _ensure_column_defaults(cur)
    _ensure_device_indexes(cur)
def _ensure_device_aux_tables(conn: sqlite3.Connection) -> None:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS device_keys (
id TEXT PRIMARY KEY,
guid TEXT NOT NULL,
ssl_key_fingerprint TEXT NOT NULL,
added_at TEXT NOT NULL,
retired_at TEXT
)
"""
)
cur.execute(
"""
CREATE UNIQUE INDEX IF NOT EXISTS uq_device_keys_guid_fingerprint
ON device_keys(guid, ssl_key_fingerprint)
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_device_keys_guid
ON device_keys(guid)
"""
)
def _ensure_refresh_token_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS refresh_tokens (
id TEXT PRIMARY KEY,
guid TEXT NOT NULL,
token_hash TEXT NOT NULL,
dpop_jkt TEXT,
created_at TEXT NOT NULL,
expires_at TEXT NOT NULL,
revoked_at TEXT,
last_used_at TEXT
)
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_guid
ON refresh_tokens(guid)
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at
ON refresh_tokens(expires_at)
"""
)
def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
    """Create the enrollment_install_codes table, adding newer columns in place.

    `max_uses`, `use_count`, and `last_used_at` arrived after the table was
    first shipped, so they are back-filled with ALTER TABLE on databases that
    predate them.
    """
    cur = conn.cursor()
    cur.execute(
        """
        CREATE TABLE IF NOT EXISTS enrollment_install_codes (
            id TEXT PRIMARY KEY,
            code TEXT NOT NULL UNIQUE,
            expires_at TEXT NOT NULL,
            created_by_user_id TEXT,
            used_at TEXT,
            used_by_guid TEXT,
            max_uses INTEGER NOT NULL DEFAULT 1,
            use_count INTEGER NOT NULL DEFAULT 0,
            last_used_at TEXT
        )
        """
    )
    cur.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_eic_expires_at
        ON enrollment_install_codes(expires_at)
        """
    )
    columns = {row[1] for row in _table_info(cur, "enrollment_install_codes")}
    if "max_uses" not in columns:
        cur.execute(
            """
            ALTER TABLE enrollment_install_codes
            ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 1
            """
        )
    if "use_count" not in columns:
        cur.execute(
            """
            ALTER TABLE enrollment_install_codes
            ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0
            """
        )
    if "last_used_at" not in columns:
        cur.execute(
            """
            ALTER TABLE enrollment_install_codes
            ADD COLUMN last_used_at TEXT
            """
        )
def _ensure_install_code_persistence_table(conn: sqlite3.Connection) -> None:
    """Create the persistent archive of install codes, adding newer columns.

    This table mirrors enrollment_install_codes but survives consumption so
    the UI can show history. Several columns were added over time and are
    back-filled via ALTER TABLE on databases that predate them.
    """
    cur = conn.cursor()
    cur.execute(
        """
        CREATE TABLE IF NOT EXISTS enrollment_install_codes_persistent (
            id TEXT PRIMARY KEY,
            code TEXT NOT NULL UNIQUE,
            created_at TEXT NOT NULL,
            expires_at TEXT NOT NULL,
            created_by_user_id TEXT,
            used_at TEXT,
            used_by_guid TEXT,
            max_uses INTEGER NOT NULL DEFAULT 1,
            last_known_use_count INTEGER NOT NULL DEFAULT 0,
            last_used_at TEXT,
            is_active INTEGER NOT NULL DEFAULT 1,
            archived_at TEXT,
            consumed_at TEXT
        )
        """
    )
    cur.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_eicp_active
        ON enrollment_install_codes_persistent(is_active, expires_at)
        """
    )
    cur.execute(
        """
        CREATE UNIQUE INDEX IF NOT EXISTS uq_eicp_code
        ON enrollment_install_codes_persistent(code)
        """
    )
    columns = {row[1] for row in _table_info(cur, "enrollment_install_codes_persistent")}
    if "last_known_use_count" not in columns:
        cur.execute(
            """
            ALTER TABLE enrollment_install_codes_persistent
            ADD COLUMN last_known_use_count INTEGER NOT NULL DEFAULT 0
            """
        )
    if "archived_at" not in columns:
        cur.execute(
            """
            ALTER TABLE enrollment_install_codes_persistent
            ADD COLUMN archived_at TEXT
            """
        )
    if "consumed_at" not in columns:
        cur.execute(
            """
            ALTER TABLE enrollment_install_codes_persistent
            ADD COLUMN consumed_at TEXT
            """
        )
    if "is_active" not in columns:
        cur.execute(
            """
            ALTER TABLE enrollment_install_codes_persistent
            ADD COLUMN is_active INTEGER NOT NULL DEFAULT 1
            """
        )
    if "used_at" not in columns:
        cur.execute(
            """
            ALTER TABLE enrollment_install_codes_persistent
            ADD COLUMN used_at TEXT
            """
        )
    if "used_by_guid" not in columns:
        cur.execute(
            """
            ALTER TABLE enrollment_install_codes_persistent
            ADD COLUMN used_by_guid TEXT
            """
        )
    if "last_used_at" not in columns:
        cur.execute(
            """
            ALTER TABLE enrollment_install_codes_persistent
            ADD COLUMN last_used_at TEXT
            """
        )
def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS device_approvals (
id TEXT PRIMARY KEY,
approval_reference TEXT NOT NULL UNIQUE,
guid TEXT,
hostname_claimed TEXT NOT NULL,
ssl_key_fingerprint_claimed TEXT NOT NULL,
enrollment_code_id TEXT NOT NULL,
status TEXT NOT NULL,
client_nonce TEXT NOT NULL,
server_nonce TEXT NOT NULL,
agent_pubkey_der BLOB NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
approved_by_user_id TEXT
)
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_da_status
ON device_approvals(status)
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_da_fp_status
ON device_approvals(ssl_key_fingerprint_claimed, status)
"""
)
def _create_devices_table(cur: sqlite3.Cursor) -> None:
    """Create the canonical `devices` table (keyed by guid) plus its indexes."""
    cur.execute(
        """
        CREATE TABLE devices (
            guid TEXT PRIMARY KEY,
            hostname TEXT,
            description TEXT,
            created_at INTEGER,
            agent_hash TEXT,
            memory TEXT,
            network TEXT,
            software TEXT,
            storage TEXT,
            cpu TEXT,
            device_type TEXT,
            domain TEXT,
            external_ip TEXT,
            internal_ip TEXT,
            last_reboot TEXT,
            last_seen INTEGER,
            last_user TEXT,
            operating_system TEXT,
            uptime INTEGER,
            agent_id TEXT,
            ansible_ee_ver TEXT,
            connection_type TEXT,
            connection_endpoint TEXT,
            ssl_key_fingerprint TEXT,
            token_version INTEGER DEFAULT 1,
            status TEXT DEFAULT 'active',
            key_added_at TEXT
        )
        """
    )
    _ensure_device_indexes(cur)
def _ensure_device_indexes(cur: sqlite3.Cursor) -> None:
cur.execute(
"""
CREATE UNIQUE INDEX IF NOT EXISTS uq_devices_hostname
ON devices(hostname)
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_devices_ssl_key
ON devices(ssl_key_fingerprint)
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_devices_status
ON devices(status)
"""
)
def _ensure_column_defaults(cur: sqlite3.Cursor) -> None:
cur.execute(
"""
UPDATE devices
SET token_version = COALESCE(token_version, 1)
WHERE token_version IS NULL
"""
)
cur.execute(
"""
UPDATE devices
SET status = COALESCE(status, 'active')
WHERE status IS NULL OR status = ''
"""
)
def _rebuild_devices_table(conn: sqlite3.Connection, column_info: Sequence[Tuple]) -> None:
    """Rebuild `devices` in place: rename, recreate, copy rows, drop the old table.

    Runs inside an explicit IMMEDIATE transaction with foreign keys disabled.
    Rows missing a GUID receive a fresh UUID; token_version/status/key_added_at
    are defaulted so the copied rows satisfy the new schema.
    """
    cur = conn.cursor()
    cur.execute("PRAGMA foreign_keys=OFF")
    cur.execute("BEGIN IMMEDIATE")
    cur.execute("ALTER TABLE devices RENAME TO devices_legacy")
    _create_devices_table(cur)
    legacy_columns = [c[1] for c in column_info]
    cur.execute(f"SELECT {', '.join(legacy_columns)} FROM devices_legacy")
    rows = cur.fetchall()
    insert_sql = (
        """
        INSERT OR REPLACE INTO devices (
            guid, hostname, description, created_at, agent_hash, memory,
            network, software, storage, cpu, device_type, domain, external_ip,
            internal_ip, last_reboot, last_seen, last_user, operating_system,
            uptime, agent_id, ansible_ee_ver, connection_type, connection_endpoint,
            ssl_key_fingerprint, token_version, status, key_added_at
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
    )
    for row in rows:
        record = dict(zip(legacy_columns, row))
        guid = _normalized_guid(record.get("guid"))
        if not guid:
            # Legacy rows without a GUID get a freshly minted identifier.
            guid = str(uuid.uuid4())
        hostname = record.get("hostname")
        created_at = record.get("created_at")
        key_added_at = record.get("key_added_at")
        if key_added_at is None:
            key_added_at = _default_key_added_at(created_at)
        params: Tuple = (
            guid,
            hostname,
            record.get("description"),
            created_at,
            record.get("agent_hash"),
            record.get("memory"),
            record.get("network"),
            record.get("software"),
            record.get("storage"),
            record.get("cpu"),
            record.get("device_type"),
            record.get("domain"),
            record.get("external_ip"),
            record.get("internal_ip"),
            record.get("last_reboot"),
            record.get("last_seen"),
            record.get("last_user"),
            record.get("operating_system"),
            record.get("uptime"),
            record.get("agent_id"),
            record.get("ansible_ee_ver"),
            record.get("connection_type"),
            record.get("connection_endpoint"),
            record.get("ssl_key_fingerprint"),
            record.get("token_version") or 1,
            record.get("status") or "active",
            key_added_at,
        )
        cur.execute(insert_sql, params)
    cur.execute("DROP TABLE devices_legacy")
    cur.execute("COMMIT")
    cur.execute("PRAGMA foreign_keys=ON")
def _default_key_added_at(created_at: Optional[int]) -> Optional[str]:
if created_at:
try:
dt = datetime.fromtimestamp(int(created_at), tz=timezone.utc)
return dt.isoformat()
except Exception:
pass
return datetime.now(tz=timezone.utc).isoformat()
def _table_exists(cur: sqlite3.Cursor, name: str) -> bool:
cur.execute(
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
(name,),
)
return cur.fetchone() is not None
def _table_info(cur: sqlite3.Cursor, name: str) -> List[Tuple]:
cur.execute(f"PRAGMA table_info({name})")
return cur.fetchall()
def _normalized_guid(value: Optional[str]) -> str:
if not value:
return ""
return str(value).strip()

View File

@@ -1 +0,0 @@

View File

@@ -1,35 +0,0 @@
"""
Short-lived nonce cache to defend against replay attacks during enrollment.
"""
from __future__ import annotations
import time
from threading import Lock
from typing import Dict
class NonceCache:
    """Single-use nonce registry with a fixed TTL.

    ``consume`` succeeds only the first time a key is seen within the TTL;
    replays inside the window are rejected. Expired entries are purged
    opportunistically on each call.
    """

    def __init__(self, ttl_seconds: float = 300.0) -> None:
        self._ttl = ttl_seconds
        self._entries: Dict[str, float] = {}
        self._lock = Lock()

    def consume(self, key: str) -> bool:
        """
        Attempt to consume the nonce identified by `key`.
        Returns True on first use within TTL, False if already consumed.
        """
        moment = time.monotonic()
        with self._lock:
            deadline = self._entries.get(key)
            if deadline and deadline > moment:
                return False
            self._entries[key] = moment + self._ttl
            # Opportunistic cleanup keeps the dict from growing without bound.
            expired = [nonce for nonce, expiry in self._entries.items() if expiry <= moment]
            for nonce in expired:
                self._entries.pop(nonce, None)
            return True

View File

@@ -1,759 +0,0 @@
from __future__ import annotations
import base64
import secrets
import sqlite3
import uuid
from datetime import datetime, timezone, timedelta
import time
from typing import Any, Callable, Dict, Optional, Tuple
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
if not value:
return None
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
if not cleaned:
return None
return cleaned.upper()
from flask import Blueprint, jsonify, request
from Modules.auth.rate_limit import SlidingWindowRateLimiter
from Modules.crypto import keys as crypto_keys
from Modules.enrollment.nonce_store import NonceCache
from Modules.guid_utils import normalize_guid
from cryptography.hazmat.primitives import serialization
def register(
app,
*,
db_conn_factory: Callable[[], sqlite3.Connection],
log: Callable[[str, str, Optional[str]], None],
jwt_service,
tls_bundle_path: str,
ip_rate_limiter: SlidingWindowRateLimiter,
fp_rate_limiter: SlidingWindowRateLimiter,
nonce_cache: NonceCache,
script_signer,
) -> None:
blueprint = Blueprint("enrollment", __name__)
def _now() -> datetime:
return datetime.now(tz=timezone.utc)
def _iso(dt: datetime) -> str:
return dt.isoformat()
def _remote_addr() -> str:
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
addr = request.remote_addr or "unknown"
return addr.strip()
def _signing_key_b64() -> str:
if not script_signer:
return ""
try:
return script_signer.public_base64_spki()
except Exception:
return ""
def _rate_limited(
key: str,
limiter: SlidingWindowRateLimiter,
limit: int,
window_s: float,
context_hint: Optional[str],
):
decision = limiter.check(key, limit, window_s)
if not decision.allowed:
log(
"server",
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
context_hint,
)
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
response.status_code = 429
response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}"
return response
return None
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
cur.execute(
"""
SELECT id,
code,
expires_at,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at
FROM enrollment_install_codes
WHERE code = ?
""",
(code_value,),
)
row = cur.fetchone()
if not row:
return None
keys = [
"id",
"code",
"expires_at",
"used_at",
"used_by_guid",
"max_uses",
"use_count",
"last_used_at",
]
record = dict(zip(keys, row))
return record
def _install_code_valid(
record: Dict[str, Any], fingerprint: str, cur: sqlite3.Cursor
) -> Tuple[bool, Optional[str]]:
if not record:
return False, None
expires_at = record.get("expires_at")
if not isinstance(expires_at, str):
return False, None
try:
expiry = datetime.fromisoformat(expires_at)
except Exception:
return False, None
if expiry <= _now():
return False, None
try:
max_uses = int(record.get("max_uses") or 1)
except Exception:
max_uses = 1
if max_uses < 1:
max_uses = 1
try:
use_count = int(record.get("use_count") or 0)
except Exception:
use_count = 0
if use_count < max_uses:
return True, None
guid = normalize_guid(record.get("used_by_guid"))
if not guid:
return False, None
cur.execute(
"SELECT ssl_key_fingerprint FROM devices WHERE UPPER(guid) = ?",
(guid,),
)
row = cur.fetchone()
if not row:
return False, None
stored_fp = (row[0] or "").strip().lower()
if not stored_fp:
return False, None
if stored_fp == (fingerprint or "").strip().lower():
return True, guid
return False, None
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
guid_norm = normalize_guid(guid)
base = (hostname or "").strip() or guid_norm
base = base[:253]
candidate = base
suffix = 1
while True:
cur.execute(
"SELECT guid FROM devices WHERE hostname = ?",
(candidate,),
)
row = cur.fetchone()
if not row:
return candidate
existing_guid = normalize_guid(row[0])
if existing_guid == guid_norm:
return candidate
candidate = f"{base}-{suffix}"
suffix += 1
if suffix > 50:
return guid_norm
def _store_device_key(cur: sqlite3.Cursor, guid: str, fingerprint: str) -> None:
guid_norm = normalize_guid(guid)
added_at = _iso(_now())
cur.execute(
"""
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
VALUES (?, ?, ?, ?)
""",
(str(uuid.uuid4()), guid_norm, fingerprint, added_at),
)
cur.execute(
"""
UPDATE device_keys
SET retired_at = ?
WHERE guid = ?
AND ssl_key_fingerprint != ?
AND retired_at IS NULL
""",
(_iso(_now()), guid_norm, fingerprint),
)
def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
    """Fetch or create the ``devices`` row for ``guid`` and reconcile its key.

    Behavior on an existing row:
      * no stored fingerprint -> adopt the presented one;
      * different stored fingerprint -> key rotation: bump ``token_version``,
        reactivate the device, and revoke every outstanding refresh token.
    When the device is unknown, a fresh active row is inserted with a
    collision-free hostname.
    """
    guid_norm = normalize_guid(guid)
    cur.execute(
        """
        SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
        FROM devices
        WHERE UPPER(guid) = ?
        """,
        (guid_norm,),
    )
    row = cur.fetchone()
    if row:
        keys = [
            "guid",
            "hostname",
            "token_version",
            "status",
            "ssl_key_fingerprint",
            "key_added_at",
        ]
        record = dict(zip(keys, row))
        record["guid"] = normalize_guid(record.get("guid"))
        # Compare fingerprints case-insensitively; store the caller's casing.
        stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower()
        new_fp = (fingerprint or "").strip().lower()
        if not stored_fp and new_fp:
            # First key ever seen for this device: just record it.
            cur.execute(
                "UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
                (fingerprint, _iso(_now()), record["guid"]),
            )
            record["ssl_key_fingerprint"] = fingerprint
        elif new_fp and stored_fp != new_fp:
            # Key rotation: invalidate existing access tokens by bumping
            # token_version, and revoke all live refresh tokens.
            now_iso = _iso(_now())
            try:
                current_version = int(record.get("token_version") or 1)
            except Exception:
                current_version = 1
            new_version = max(current_version + 1, 1)
            cur.execute(
                """
                UPDATE devices
                SET ssl_key_fingerprint = ?,
                    key_added_at = ?,
                    token_version = ?,
                    status = 'active'
                WHERE guid = ?
                """,
                (fingerprint, now_iso, new_version, record["guid"]),
            )
            cur.execute(
                """
                UPDATE refresh_tokens
                SET revoked_at = ?
                WHERE guid = ?
                  AND revoked_at IS NULL
                """,
                (now_iso, record["guid"]),
            )
            record["ssl_key_fingerprint"] = fingerprint
            record["token_version"] = new_version
            record["status"] = "active"
            record["key_added_at"] = now_iso
        return record
    # Unknown device: insert a new active row.
    resolved_hostname = _normalize_host(hostname, guid_norm, cur)
    # NOTE(review): relies on a module-level ``import time`` that is not
    # visible in this chunk — confirm the module header imports it.
    created_at = int(time.time())
    key_added_at = _iso(_now())
    cur.execute(
        """
        INSERT INTO devices (
            guid, hostname, created_at, last_seen, ssl_key_fingerprint,
            token_version, status, key_added_at
        )
        VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
        """,
        (
            guid_norm,
            resolved_hostname,
            created_at,
            created_at,
            fingerprint,
            key_added_at,
        ),
    )
    return {
        "guid": guid_norm,
        "hostname": resolved_hostname,
        "token_version": 1,
        "status": "active",
        "ssl_key_fingerprint": fingerprint,
        "key_added_at": key_added_at,
    }
def _hash_refresh_token(token: str) -> str:
import hashlib
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def _issue_refresh_token(cur: sqlite3.Cursor, guid: str) -> Dict[str, Any]:
    """Mint a 30-day refresh token for ``guid``; only its hash is persisted."""
    raw_token = secrets.token_urlsafe(48)
    issued_at = _now()
    valid_until = issued_at.replace(microsecond=0) + timedelta(days=30)
    row = (
        str(uuid.uuid4()),
        guid,
        _hash_refresh_token(raw_token),
        _iso(issued_at),
        _iso(valid_until),
    )
    cur.execute(
        "INSERT INTO refresh_tokens (id, guid, token_hash, created_at, expires_at) "
        "VALUES (?, ?, ?, ?, ?)",
        row,
    )
    # The plaintext token is returned exactly once, to the enrolling agent.
    return {"token": raw_token, "expires_at": valid_until}
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
def enrollment_request():
    """Accept an agent enrollment request and queue it for operator review.

    Validates hostname, enrollment code, agent public key and client nonce;
    rate-limits by source IP and by key fingerprint; then inserts (or
    refreshes) a pending ``device_approvals`` row keyed by the fingerprint.
    Responds with the approval reference and a fresh server nonce the agent
    must sign when polling.
    """
    remote = _remote_addr()
    context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
    # Coarse per-IP throttle: 40 requests per 60 seconds.
    rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0, context_hint)
    if rate_error:
        return rate_error
    payload = request.get_json(force=True, silent=True) or {}
    hostname = str(payload.get("hostname") or "").strip()
    enrollment_code = str(payload.get("enrollment_code") or "").strip()
    agent_pubkey_b64 = payload.get("agent_pubkey")
    client_nonce_b64 = payload.get("client_nonce")
    log(
        "server",
        "enrollment request received "
        f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
        f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
        context_hint,
    )
    # --- Field validation: each failure is logged and answered with 400. ---
    if not hostname:
        log("server", f"enrollment rejected missing_hostname ip={remote}", context_hint)
        return jsonify({"error": "hostname_required"}), 400
    if not enrollment_code:
        log("server", f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
        return jsonify({"error": "enrollment_code_required"}), 400
    if not isinstance(agent_pubkey_b64, str):
        log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
        return jsonify({"error": "agent_pubkey_required"}), 400
    if not isinstance(client_nonce_b64, str):
        log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
        return jsonify({"error": "client_nonce_required"}), 400
    try:
        agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
    except Exception:
        log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
        return jsonify({"error": "invalid_agent_pubkey"}), 400
    # Sanity floor on DER length; a real SPKI key is far larger than 10 bytes.
    if len(agent_pubkey_der) < 10:
        log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
        return jsonify({"error": "invalid_agent_pubkey"}), 400
    try:
        client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
    except Exception:
        log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
        return jsonify({"error": "invalid_client_nonce"}), 400
    # Require at least 128 bits of client nonce entropy.
    if len(client_nonce_bytes) < 16:
        log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
        return jsonify({"error": "invalid_client_nonce"}), 400
    fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
    # Tighter per-fingerprint throttle: 12 requests per 60 seconds.
    rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0, context_hint)
    if rate_error:
        return rate_error
    conn = db_conn_factory()
    try:
        cur = conn.cursor()
        install_code = _load_install_code(cur, enrollment_code)
        valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
        if not valid_code:
            log(
                "server",
                "enrollment request invalid_code "
                f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
                context_hint,
            )
            return jsonify({"error": "invalid_enrollment_code"}), 400
        approval_reference: str
        record_id: str
        # Fresh server nonce for the proof-of-possession challenge.
        server_nonce_bytes = secrets.token_bytes(32)
        server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii")
        now = _iso(_now())
        # Reuse an existing pending approval for this fingerprint so repeated
        # requests do not pile up duplicate rows awaiting the operator.
        cur.execute(
            """
            SELECT id, approval_reference
            FROM device_approvals
            WHERE ssl_key_fingerprint_claimed = ?
              AND status = 'pending'
            """,
            (fingerprint,),
        )
        existing = cur.fetchone()
        if existing:
            record_id = existing[0]
            approval_reference = existing[1]
            cur.execute(
                """
                UPDATE device_approvals
                SET hostname_claimed = ?,
                    guid = ?,
                    enrollment_code_id = ?,
                    client_nonce = ?,
                    server_nonce = ?,
                    agent_pubkey_der = ?,
                    updated_at = ?
                WHERE id = ?
                """,
                (
                    hostname,
                    reuse_guid,
                    install_code["id"],
                    client_nonce_b64,
                    server_nonce_b64,
                    agent_pubkey_der,
                    now,
                    record_id,
                ),
            )
        else:
            record_id = str(uuid.uuid4())
            approval_reference = str(uuid.uuid4())
            cur.execute(
                """
                INSERT INTO device_approvals (
                    id, approval_reference, guid, hostname_claimed,
                    ssl_key_fingerprint_claimed, enrollment_code_id,
                    status, client_nonce, server_nonce, agent_pubkey_der,
                    created_at, updated_at
                )
                VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
                """,
                (
                    record_id,
                    approval_reference,
                    reuse_guid,
                    hostname,
                    fingerprint,
                    install_code["id"],
                    client_nonce_b64,
                    server_nonce_b64,
                    agent_pubkey_der,
                    now,
                    now,
                ),
            )
        conn.commit()
    finally:
        conn.close()
    # The agent polls with this reference and signs the server nonce.
    response = {
        "status": "pending",
        "approval_reference": approval_reference,
        "server_nonce": server_nonce_b64,
        "poll_after_ms": 3000,
        "server_certificate": _load_tls_bundle(tls_bundle_path),
        "signing_key": _signing_key_b64(),
    }
    log(
        "server",
        f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}",
        context_hint,
    )
    return jsonify(response)
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
def enrollment_poll():
    """Poll an enrollment approval and, once approved, finalize the device.

    The agent proves possession of its private key by signing
    ``server_nonce || approval_reference || client_nonce``. Pending / denied /
    expired / completed states are reported as-is; on first 'approved' poll
    the device row is created or rotated, the install code is accounted,
    and access + refresh tokens are issued.
    """
    payload = request.get_json(force=True, silent=True) or {}
    approval_reference = payload.get("approval_reference")
    client_nonce_b64 = payload.get("client_nonce")
    proof_sig_b64 = payload.get("proof_sig")
    context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
    log(
        "server",
        "enrollment poll received "
        f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
        f" proof_sig_len={len(proof_sig_b64 or '')}",
        context_hint,
    )
    # --- Field validation ---
    if not isinstance(approval_reference, str) or not approval_reference:
        log("server", "enrollment poll rejected missing_reference", context_hint)
        return jsonify({"error": "approval_reference_required"}), 400
    if not isinstance(client_nonce_b64, str):
        log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
        return jsonify({"error": "client_nonce_required"}), 400
    if not isinstance(proof_sig_b64, str):
        log("server", f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
        return jsonify({"error": "proof_sig_required"}), 400
    try:
        client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
    except Exception:
        log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
        return jsonify({"error": "invalid_client_nonce"}), 400
    try:
        proof_sig = base64.b64decode(proof_sig_b64, validate=True)
    except Exception:
        log("server", f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
        return jsonify({"error": "invalid_proof_sig"}), 400
    conn = db_conn_factory()
    try:
        cur = conn.cursor()
        cur.execute(
            """
            SELECT id, guid, hostname_claimed, ssl_key_fingerprint_claimed,
                   enrollment_code_id, status, client_nonce, server_nonce,
                   agent_pubkey_der, created_at, updated_at, approved_by_user_id
            FROM device_approvals
            WHERE approval_reference = ?
            """,
            (approval_reference,),
        )
        row = cur.fetchone()
        if not row:
            log("server", f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
            return jsonify({"status": "unknown"}), 404
        (
            record_id,
            guid,
            hostname_claimed,
            fingerprint,
            enrollment_code_id,
            status,
            client_nonce_stored,
            server_nonce_b64,
            agent_pubkey_der,
            created_at,
            updated_at,
            approved_by,
        ) = row
        # The poll must present the exact nonce submitted with the request.
        if client_nonce_stored != client_nonce_b64:
            log("server", f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
            return jsonify({"error": "nonce_mismatch"}), 400
        try:
            server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
        except Exception:
            log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
            return jsonify({"error": "server_nonce_invalid"}), 400
        # Proof-of-possession message: server nonce + reference + client nonce.
        message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
        try:
            public_key = serialization.load_der_public_key(agent_pubkey_der)
        except Exception:
            log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
            public_key = None
        if public_key is None:
            log("server", f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
            return jsonify({"error": "agent_pubkey_invalid"}), 400
        # NOTE(review): two-argument verify() implies an Ed25519 key here —
        # RSA/EC keys would need padding/hash arguments; confirm key type.
        try:
            public_key.verify(proof_sig, message)
        except Exception:
            log("server", f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
            return jsonify({"error": "invalid_proof"}), 400
        # --- Status dispatch: only 'approved' proceeds to finalization. ---
        if status == "pending":
            log(
                "server",
                f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
                f" fingerprint={fingerprint[:12]}",
                context_hint,
            )
            return jsonify({"status": "pending", "poll_after_ms": 5000})
        if status == "denied":
            log(
                "server",
                f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
                context_hint,
            )
            return jsonify({"status": "denied", "reason": "operator_denied"})
        if status == "expired":
            log(
                "server",
                f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
                context_hint,
            )
            return jsonify({"status": "expired"})
        if status == "completed":
            log(
                "server",
                f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
                context_hint,
            )
            return jsonify({"status": "approved", "detail": "finalized"})
        if status != "approved":
            log(
                "server",
                f"enrollment poll unexpected_status={status} ref={approval_reference}",
                context_hint,
            )
            return jsonify({"status": status or "unknown"}), 400
        # Replay guard: each (reference, signature) pair may finalize once.
        nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
        if not nonce_cache.consume(nonce_key):
            log(
                "server",
                f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
                context_hint,
            )
            return jsonify({"error": "proof_replayed"}), 409
        # Finalize enrollment
        effective_guid = normalize_guid(guid) if guid else normalize_guid(str(uuid.uuid4()))
        now_iso = _iso(_now())
        device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint)
        _store_device_key(cur, effective_guid, fingerprint)
        # Mark install code used
        if enrollment_code_id:
            cur.execute(
                "SELECT use_count, max_uses FROM enrollment_install_codes WHERE id = ?",
                (enrollment_code_id,),
            )
            usage_row = cur.fetchone()
            try:
                prior_count = int(usage_row[0]) if usage_row else 0
            except Exception:
                prior_count = 0
            try:
                allowed_uses = int(usage_row[1]) if usage_row else 1
            except Exception:
                allowed_uses = 1
            if allowed_uses < 1:
                allowed_uses = 1
            new_count = prior_count + 1
            # 'consumed' flips the code inactive once its uses are exhausted.
            consumed = new_count >= allowed_uses
            cur.execute(
                """
                UPDATE enrollment_install_codes
                SET use_count = ?,
                    used_by_guid = ?,
                    last_used_at = ?,
                    used_at = CASE WHEN ? THEN ? ELSE used_at END
                WHERE id = ?
                """,
                (
                    new_count,
                    effective_guid,
                    now_iso,
                    1 if consumed else 0,
                    now_iso,
                    enrollment_code_id,
                ),
            )
            # Mirror the accounting into the persistent audit table.
            cur.execute(
                """
                UPDATE enrollment_install_codes_persistent
                SET last_known_use_count = ?,
                    used_by_guid = ?,
                    last_used_at = ?,
                    used_at = CASE WHEN ? THEN ? ELSE used_at END,
                    is_active = CASE WHEN ? THEN 0 ELSE is_active END,
                    consumed_at = CASE WHEN ? THEN COALESCE(consumed_at, ?) ELSE consumed_at END,
                    archived_at = CASE WHEN ? THEN COALESCE(archived_at, ?) ELSE archived_at END
                WHERE id = ?
                """,
                (
                    new_count,
                    effective_guid,
                    now_iso,
                    1 if consumed else 0,
                    now_iso,
                    1 if consumed else 0,
                    1 if consumed else 0,
                    now_iso,
                    1 if consumed else 0,
                    now_iso,
                    enrollment_code_id,
                ),
            )
        # Update approval record with final state
        cur.execute(
            """
            UPDATE device_approvals
            SET guid = ?,
                status = 'completed',
                updated_at = ?
            WHERE id = ?
            """,
            (effective_guid, now_iso, record_id),
        )
        refresh_info = _issue_refresh_token(cur, effective_guid)
        access_token = jwt_service.issue_access_token(
            effective_guid,
            fingerprint,
            device_record.get("token_version") or 1,
        )
        conn.commit()
    finally:
        conn.close()
    log(
        "server",
        f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
        context_hint,
    )
    return jsonify(
        {
            "status": "approved",
            "guid": effective_guid,
            "access_token": access_token,
            "expires_in": 900,
            "refresh_token": refresh_info["token"],
            "token_type": "Bearer",
            "server_certificate": _load_tls_bundle(tls_bundle_path),
            "signing_key": _signing_key_b64(),
        }
    )
# Attach the enrollment routes defined above to the Flask application.
app.register_blueprint(blueprint)
def _load_tls_bundle(path: str) -> str:
try:
with open(path, "r", encoding="utf-8") as fh:
return fh.read()
except Exception:
return ""
def _mask_code(code: str) -> str:
if not code:
return "<missing>"
trimmed = str(code).strip()
if len(trimmed) <= 6:
return "***"
return f"{trimmed[:3]}***{trimmed[-3:]}"

View File

@@ -1,26 +0,0 @@
from __future__ import annotations
import string
import uuid
from typing import Optional
def normalize_guid(value: Optional[str]) -> str:
    """Canonicalize a GUID string to upper-case dashed UUID form.

    Braces and mixed case are tolerated; if the raw text cannot parse as a
    UUID, non-hex characters are stripped and parsing is retried. When no
    UUID can be recovered, the raw text is returned upper-cased.
    """
    raw = (value or "").strip()
    if not raw:
        return ""
    raw = raw.strip("{}")
    try:
        return str(uuid.UUID(raw)).upper()
    except Exception:
        pass
    allowed = set(string.hexdigits) | {"-"}
    stripped = "".join(ch for ch in raw if ch in allowed).strip("-")
    if stripped:
        try:
            return str(uuid.UUID(stripped)).upper()
        except Exception:
            pass
    return raw.upper()

View File

@@ -1 +0,0 @@

View File

@@ -1,110 +0,0 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Callable, List, Optional
import eventlet
from flask_socketio import SocketIO
def start_prune_job(
    socketio: SocketIO,
    *,
    db_conn_factory: Callable[[], any],
    log: Callable[[str, str, Optional[str]], None],
) -> None:
    """Spawn a SocketIO background task that prunes enrollment state daily."""
    def _job_loop():
        # Run forever; a failed pass is logged and retried on the next cycle.
        while True:
            try:
                _run_once(db_conn_factory, log)
            except Exception as exc:
                # NOTE(review): ``log`` is typed with three parameters but
                # called with two here — works only if the real logger
                # defaults the context argument; confirm.
                log("server", f"prune job failure: {exc}")
            # Cooperative 24h sleep so other greenlets keep running.
            eventlet.sleep(24 * 60 * 60)
    socketio.start_background_task(_job_loop)
def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str, Optional[str]], None]) -> None:
    """Perform one prune pass over enrollment codes and device approvals.

    Deletes never-used, expired install codes; archives their persistent
    mirrors (best effort); and expires pending approvals whose code is gone,
    exhausted, or older than 24 hours.
    """
    now = datetime.now(tz=timezone.utc)
    now_iso = now.isoformat()
    stale_before = (now - timedelta(hours=24)).isoformat()
    conn = db_conn_factory()
    try:
        cur = conn.cursor()
        # The persistent mirror table may not exist on older schemas.
        persistent_table_exists = False
        try:
            cur.execute(
                "SELECT 1 FROM sqlite_master WHERE type='table' AND name='enrollment_install_codes_persistent'"
            )
            persistent_table_exists = cur.fetchone() is not None
        except Exception:
            persistent_table_exists = False
        expired_ids: List[str] = []
        if persistent_table_exists:
            # Capture the ids about to be deleted so the mirror rows can be
            # archived after the delete below.
            cur.execute(
                """
                SELECT id
                FROM enrollment_install_codes
                WHERE use_count = 0
                  AND expires_at < ?
                """,
                (now_iso,),
            )
            expired_ids = [str(row[0]) for row in cur.fetchall() if row and row[0]]
        cur.execute(
            """
            DELETE FROM enrollment_install_codes
            WHERE use_count = 0
              AND expires_at < ?
            """,
            (now_iso,),
        )
        codes_pruned = cur.rowcount or 0
        if expired_ids:
            placeholders = ",".join("?" for _ in expired_ids)
            try:
                cur.execute(
                    f"""
                    UPDATE enrollment_install_codes_persistent
                    SET is_active = 0,
                        archived_at = COALESCE(archived_at, ?)
                    WHERE id IN ({placeholders})
                    """,
                    (now_iso, *expired_ids),
                )
            except Exception:
                # Best-effort archival; continue if the persistence table is absent.
                pass
        # Expire pending approvals whose code vanished/expired/exhausted,
        # or which have simply been waiting longer than 24 hours.
        cur.execute(
            """
            UPDATE device_approvals
            SET status = 'expired',
                updated_at = ?
            WHERE status = 'pending'
              AND (
                    EXISTS (
                        SELECT 1
                        FROM enrollment_install_codes c
                        WHERE c.id = device_approvals.enrollment_code_id
                          AND (
                                c.expires_at < ?
                             OR c.use_count >= c.max_uses
                          )
                    )
                 OR created_at < ?
              )
            """,
            (now_iso, now_iso, stale_before),
        )
        approvals_marked = cur.rowcount or 0
        conn.commit()
    finally:
        conn.close()
    # NOTE(review): two-argument ``log`` calls; see start_prune_job.
    if codes_pruned:
        log("server", f"prune job removed {codes_pruned} expired enrollment codes")
    if approvals_marked:
        log("server", f"prune job expired {approvals_marked} device approvals")

View File

@@ -1,168 +0,0 @@
"""Utility helpers for locating runtime storage paths.
The Borealis repository keeps the authoritative source code under ``Data/``
so that the bootstrap scripts can copy those assets into sibling ``Server/``
and ``Agent/`` directories for execution. Runtime artefacts such as TLS
certificates or signing keys must therefore live outside ``Data`` to avoid
polluting the template tree. This module centralises the path selection so
other modules can rely on a consistent location regardless of whether they
are executed from the copied runtime directory or directly from ``Data``
during development.
"""
from __future__ import annotations
import os
from functools import lru_cache
from pathlib import Path
from typing import Optional
def _env_path(name: str) -> Optional[Path]:
"""Return a resolved ``Path`` for the given environment variable."""
value = os.environ.get(name)
if not value:
return None
try:
return Path(value).expanduser().resolve()
except Exception:
return None
@lru_cache(maxsize=None)
def project_root() -> Path:
    """Best-effort detection of the repository root (cached)."""
    override = _env_path("BOREALIS_PROJECT_ROOT")
    if override:
        return override
    here = Path(__file__).resolve()
    # Walk upward until a repo marker is found.
    for ancestor in here.parents:
        if (ancestor / "Borealis.ps1").exists() or (ancestor / ".git").is_dir():
            return ancestor
    # Fallback to the ancestor that corresponds to ``<repo>/`` when the module
    # lives under ``Data/Server/Modules``.
    try:
        return here.parents[4]
    except IndexError:
        return here.parent
@lru_cache(maxsize=None)
def server_runtime_root() -> Path:
    """Location where the running server stores mutable artefacts (cached)."""
    override = _env_path("BOREALIS_SERVER_ROOT")
    if override:
        return override
    return project_root() / "Server" / "Borealis"
def runtime_path(*parts: str) -> Path:
    """Join ``parts`` onto the server runtime root."""
    base = server_runtime_root()
    return base.joinpath(*parts)
def ensure_runtime_dir(*parts: str) -> Path:
    """Return a runtime directory, creating it (and parents) if needed."""
    target = runtime_path(*parts)
    target.mkdir(parents=True, exist_ok=True)
    return target
@lru_cache(maxsize=None)
def certificates_root() -> Path:
    """Base directory for persisted certificate material (cached).

    Environment overrides win; otherwise ``<repo>/Certificates`` is created
    along with its ``Server``/``Agent`` subdirectories (best effort).
    """
    override = _env_path("BOREALIS_CERTIFICATES_ROOT") or _env_path("BOREALIS_CERT_ROOT")
    if override:
        override.mkdir(parents=True, exist_ok=True)
        return override
    base = project_root() / "Certificates"
    base.mkdir(parents=True, exist_ok=True)
    try:
        (base / "Server").mkdir(parents=True, exist_ok=True)
        (base / "Agent").mkdir(parents=True, exist_ok=True)
    except Exception:
        pass
    return base
@lru_cache(maxsize=None)
def server_certificates_root() -> Path:
    """Base directory for server certificate material (cached)."""
    override = _env_path("BOREALIS_SERVER_CERT_ROOT")
    if override:
        override.mkdir(parents=True, exist_ok=True)
        return override
    base = certificates_root() / "Server"
    base.mkdir(parents=True, exist_ok=True)
    return base
@lru_cache(maxsize=None)
def agent_certificates_root() -> Path:
    """Base directory for agent certificate material (cached)."""
    override = _env_path("BOREALIS_AGENT_CERT_ROOT")
    if override:
        override.mkdir(parents=True, exist_ok=True)
        return override
    base = certificates_root() / "Agent"
    base.mkdir(parents=True, exist_ok=True)
    return base
def certificates_path(*parts: str) -> Path:
    """Join ``parts`` onto the certificates root."""
    base = certificates_root()
    return base.joinpath(*parts)
def ensure_certificates_dir(*parts: str) -> Path:
    """Return a certificates subdirectory, creating it if needed."""
    target = certificates_path(*parts)
    target.mkdir(parents=True, exist_ok=True)
    return target
def server_certificates_path(*parts: str) -> Path:
    """Join ``parts`` onto the server certificates root."""
    base = server_certificates_root()
    return base.joinpath(*parts)
def ensure_server_certificates_dir(*parts: str) -> Path:
    """Return a server certificates subdirectory, creating it if needed."""
    target = server_certificates_path(*parts)
    target.mkdir(parents=True, exist_ok=True)
    return target
def agent_certificates_path(*parts: str) -> Path:
    """Join ``parts`` onto the agent certificates root."""
    base = agent_certificates_root()
    return base.joinpath(*parts)
def ensure_agent_certificates_dir(*parts: str) -> Path:
    """Return an agent certificates subdirectory, creating it if needed."""
    target = agent_certificates_path(*parts)
    target.mkdir(parents=True, exist_ok=True)
    return target

View File

@@ -1 +0,0 @@

View File

@@ -1,138 +0,0 @@
from __future__ import annotations
import hashlib
import sqlite3
from datetime import datetime, timezone
from typing import Callable
from flask import Blueprint, jsonify, request
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
def register(
    app,
    *,
    db_conn_factory: Callable[[], sqlite3.Connection],
    jwt_service,
    dpop_validator: DPoPValidator,
) -> None:
    """Register the agent token-refresh endpoint on ``app``.

    The single route exchanges a valid refresh token (stored hashed) for a
    short-lived access token, optionally binding/validating a DPoP proof.
    """
    blueprint = Blueprint("tokens", __name__)

    def _hash_token(token: str) -> str:
        # Refresh tokens are stored only as SHA-256 hex digests.
        return hashlib.sha256(token.encode("utf-8")).hexdigest()

    def _iso_now() -> str:
        return datetime.now(tz=timezone.utc).isoformat()

    def _parse_iso(ts: str) -> datetime:
        return datetime.fromisoformat(ts)

    @blueprint.route("/api/agent/token/refresh", methods=["POST"])
    def refresh():
        """Issue a new access token for a device presenting a refresh token.

        Checks: token hash match, not revoked, not expired, device exists and
        is not revoked/decommissioned. If a DPoP proof is supplied it is
        verified and its thumbprint stored; a stored thumbprint with no proof
        is cleared for backward compatibility.
        """
        payload = request.get_json(force=True, silent=True) or {}
        guid = str(payload.get("guid") or "").strip()
        refresh_token = str(payload.get("refresh_token") or "").strip()
        if not guid or not refresh_token:
            return jsonify({"error": "invalid_request"}), 400
        conn = db_conn_factory()
        try:
            cur = conn.cursor()
            cur.execute(
                """
                SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at
                FROM refresh_tokens
                WHERE guid = ?
                  AND token_hash = ?
                """,
                (guid, _hash_token(refresh_token)),
            )
            row = cur.fetchone()
            if not row:
                return jsonify({"error": "invalid_refresh_token"}), 401
            record_id, row_guid, _token_hash, stored_jkt, created_at, expires_at, revoked_at = row
            if row_guid != guid:
                return jsonify({"error": "invalid_refresh_token"}), 401
            if revoked_at:
                return jsonify({"error": "refresh_token_revoked"}), 401
            if expires_at:
                # Malformed expiry timestamps are tolerated (treated as valid).
                try:
                    if _parse_iso(expires_at) <= datetime.now(tz=timezone.utc):
                        return jsonify({"error": "refresh_token_expired"}), 401
                except Exception:
                    pass
            cur.execute(
                """
                SELECT guid, ssl_key_fingerprint, token_version, status
                FROM devices
                WHERE guid = ?
                """,
                (guid,),
            )
            device_row = cur.fetchone()
            if not device_row:
                return jsonify({"error": "device_not_found"}), 404
            device_guid, fingerprint, token_version, status = device_row
            status_norm = (status or "active").strip().lower()
            if status_norm in {"revoked", "decommissioned"}:
                return jsonify({"error": "device_revoked"}), 403
            dpop_proof = request.headers.get("DPoP")
            jkt = stored_jkt or ""
            if dpop_proof:
                try:
                    jkt = dpop_validator.verify(request.method, request.url, dpop_proof, access_token=None)
                except DPoPReplayError:
                    return jsonify({"error": "dpop_replayed"}), 400
                except DPoPVerificationError:
                    return jsonify({"error": "dpop_invalid"}), 400
            elif stored_jkt:
                # The agent does not yet emit DPoP proofs; allow recovery by clearing
                # the stored binding so refreshes can succeed. This preserves
                # backward compatibility while the client gains full DPoP support.
                try:
                    app.logger.warning(
                        "Clearing stored DPoP binding for guid=%s due to missing proof",
                        guid,
                    )
                except Exception:
                    pass
                cur.execute(
                    "UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?",
                    (record_id,),
                )
            new_access_token = jwt_service.issue_access_token(
                guid,
                fingerprint or "",
                token_version or 1,
            )
            # Record usage and (re)bind the DPoP thumbprint when present.
            cur.execute(
                """
                UPDATE refresh_tokens
                SET last_used_at = ?,
                    dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt)
                WHERE id = ?
                """,
                (_iso_now(), jkt, record_id),
            )
            conn.commit()
        finally:
            conn.close()
        return jsonify(
            {
                "access_token": new_access_token,
                "expires_in": 900,
                "token_type": "Bearer",
            }
        )

    app.register_blueprint(blueprint)