Clarify agent and server log context labeling

This commit is contained in:
2025-10-18 04:04:02 -06:00
parent 64e0c05d66
commit afa429db3f
7 changed files with 227 additions and 44 deletions

View File

@@ -18,7 +18,7 @@ def register(
db_conn_factory: Callable[[], sqlite3.Connection],
require_admin: Callable[[], Optional[Any]],
current_user: Callable[[], Optional[Dict[str, str]]],
log: Callable[[str, str], None],
log: Callable[[str, str, Optional[str]], None],
) -> None:
blueprint = Blueprint("admin", __name__)

View File

@@ -10,13 +10,24 @@ from flask import Blueprint, jsonify, request, g
from Modules.auth.device_auth import DeviceAuthManager, require_device_auth
from Modules.crypto.signing import ScriptSigner
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
if not value:
return None
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
if not cleaned:
return None
return cleaned.upper()
def register(
app,
*,
db_conn_factory: Callable[[], Any],
auth_manager: DeviceAuthManager,
log: Callable[[str, str], None],
log: Callable[[str, str, Optional[str]], None],
script_signer: ScriptSigner,
) -> None:
blueprint = Blueprint("agents", __name__)
@@ -29,10 +40,15 @@ def register(
except Exception:
return None
def _context_hint(ctx=None) -> Optional[str]:
if ctx is not None and getattr(ctx, "service_mode", None):
return _canonical_context(getattr(ctx, "service_mode", None))
return _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
def _auth_context():
ctx = getattr(g, "device_auth", None)
if ctx is None:
log("server", f"device auth context missing for {request.path}")
log("server", f"device auth context missing for {request.path}", _context_hint())
return ctx
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
@@ -42,6 +58,7 @@ def register(
if ctx is None:
return jsonify({"error": "auth_context_missing"}), 500
payload = request.get_json(force=True, silent=True) or {}
context_label = _context_hint(ctx)
now_ts = int(time.time())
updates: Dict[str, Optional[str]] = {"last_seen": now_ts}
@@ -111,12 +128,13 @@ def register(
"server",
"heartbeat hostname collision ignored for guid="
f"{ctx.guid}",
context_label,
)
else:
raise
if rowcount == 0:
log("server", f"heartbeat missing device record guid={ctx.guid}")
log("server", f"heartbeat missing device record guid={ctx.guid}", context_label)
return jsonify({"error": "device_not_registered"}), 404
conn.commit()
finally:

View File

@@ -13,6 +13,17 @@ from flask import g, jsonify, request
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
from Modules.auth.rate_limit import SlidingWindowRateLimiter
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
if not value:
return None
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
if not cleaned:
return None
return cleaned.upper()
@dataclass
class DeviceAuthContext:
@@ -23,6 +34,7 @@ class DeviceAuthContext:
claims: Dict[str, Any]
dpop_jkt: Optional[str]
status: str
service_mode: Optional[str]
class DeviceAuthError(Exception):
@@ -50,7 +62,7 @@ class DeviceAuthManager:
db_conn_factory: Callable[[], Any],
jwt_service,
dpop_validator: Optional[DPoPValidator],
log: Callable[[str, str], None],
log: Callable[[str, str, Optional[str]], None],
rate_limiter: Optional[SlidingWindowRateLimiter] = None,
) -> None:
self._db_conn_factory = db_conn_factory
@@ -89,6 +101,8 @@ class DeviceAuthManager:
retry_after=decision.retry_after,
)
context_label = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
conn = self._db_conn_factory()
try:
cur = conn.cursor()
@@ -100,10 +114,10 @@ class DeviceAuthManager:
""",
(guid,),
)
row = cur.fetchone()
row = cur.fetchone()
if not row:
row = self._recover_device_record(conn, guid, fingerprint, token_version)
if not row:
row = self._recover_device_record(conn, guid, fingerprint, token_version, context_label)
finally:
conn.close()
@@ -127,7 +141,11 @@ class DeviceAuthManager:
if status_normalized not in allowed_statuses:
raise DeviceAuthError("device_revoked", status_code=403)
if status_normalized == "quarantined":
self._log("server", f"device {guid} is quarantined; limited access for {request.path}")
self._log(
"server",
f"device {guid} is quarantined; limited access for {request.path}",
context_label,
)
dpop_jkt: Optional[str] = None
dpop_proof = request.headers.get("DPoP")
@@ -150,6 +168,7 @@ class DeviceAuthManager:
claims=claims,
dpop_jkt=dpop_jkt,
status=status_normalized,
service_mode=context_label,
)
return ctx
@@ -159,6 +178,7 @@ class DeviceAuthManager:
guid: str,
fingerprint: str,
token_version: int,
context_label: Optional[str],
) -> Optional[tuple]:
"""Attempt to recreate a missing device row for an authenticated token."""
@@ -211,6 +231,7 @@ class DeviceAuthManager:
self._log(
"server",
f"device auth failed to recover guid={guid} due to integrity error: {exc}",
context_label,
)
conn.rollback()
return None
@@ -218,6 +239,7 @@ class DeviceAuthManager:
self._log(
"server",
f"device auth unexpected error recovering guid={guid}: {exc}",
context_label,
)
conn.rollback()
return None
@@ -229,6 +251,7 @@ class DeviceAuthManager:
self._log(
"server",
f"device auth could not recover guid={guid}; hostname collisions persisted",
context_label,
)
conn.rollback()
return None
@@ -246,6 +269,7 @@ class DeviceAuthManager:
self._log(
"server",
f"device auth recovery for guid={guid} committed but row still missing",
context_label,
)
return row

View File

@@ -8,6 +8,17 @@ from datetime import datetime, timezone, timedelta
import time
from typing import Any, Callable, Dict, Optional, Tuple
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
if not value:
return None
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
if not cleaned:
return None
return cleaned.upper()
from flask import Blueprint, jsonify, request
from Modules.auth.rate_limit import SlidingWindowRateLimiter
@@ -20,7 +31,7 @@ def register(
app,
*,
db_conn_factory: Callable[[], sqlite3.Connection],
log: Callable[[str, str], None],
log: Callable[[str, str, Optional[str]], None],
jwt_service,
tls_bundle_path: str,
ip_rate_limiter: SlidingWindowRateLimiter,
@@ -51,12 +62,19 @@ def register(
except Exception:
return ""
def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
def _rate_limited(
key: str,
limiter: SlidingWindowRateLimiter,
limit: int,
window_s: float,
context_hint: Optional[str],
):
decision = limiter.check(key, limit, window_s)
if not decision.allowed:
log(
"server",
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
context_hint,
)
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
response.status_code = 429
@@ -295,7 +313,9 @@ def register(
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
def enrollment_request():
remote = _remote_addr()
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0)
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0, context_hint)
if rate_error:
return rate_error
@@ -310,42 +330,43 @@ def register(
"enrollment request received "
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
context_hint,
)
if not hostname:
log("server", f"enrollment rejected missing_hostname ip={remote}")
log("server", f"enrollment rejected missing_hostname ip={remote}", context_hint)
return jsonify({"error": "hostname_required"}), 400
if not enrollment_code:
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}")
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
return jsonify({"error": "enrollment_code_required"}), 400
if not isinstance(agent_pubkey_b64, str):
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}")
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "agent_pubkey_required"}), 400
if not isinstance(client_nonce_b64, str):
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}")
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "client_nonce_required"}), 400
try:
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
except Exception:
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}")
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_agent_pubkey"}), 400
if len(agent_pubkey_der) < 10:
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}")
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_agent_pubkey"}), 400
try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}")
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
if len(client_nonce_bytes) < 16:
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}")
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0)
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0, context_hint)
if rate_error:
return rate_error
@@ -359,6 +380,7 @@ def register(
"server",
"enrollment request invalid_code "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
)
return jsonify({"error": "invalid_enrollment_code"}), 400
@@ -444,7 +466,11 @@ def register(
"server_certificate": _load_tls_bundle(tls_bundle_path),
"signing_key": _signing_key_b64(),
}
log("server", f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}")
log(
"server",
f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}",
context_hint,
)
return jsonify(response)
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
@@ -453,34 +479,36 @@ def register(
approval_reference = payload.get("approval_reference")
client_nonce_b64 = payload.get("client_nonce")
proof_sig_b64 = payload.get("proof_sig")
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
log(
"server",
"enrollment poll received "
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
f" proof_sig_len={len(proof_sig_b64 or '')}",
context_hint,
)
if not isinstance(approval_reference, str) or not approval_reference:
log("server", "enrollment poll rejected missing_reference")
log("server", "enrollment poll rejected missing_reference", context_hint)
return jsonify({"error": "approval_reference_required"}), 400
if not isinstance(client_nonce_b64, str):
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}")
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "client_nonce_required"}), 400
if not isinstance(proof_sig_b64, str):
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}")
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
return jsonify({"error": "proof_sig_required"}), 400
try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}")
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
try:
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_sig ref={approval_reference}")
log("server", f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_proof_sig"}), 400
conn = db_conn_factory()
@@ -498,7 +526,7 @@ def register(
)
row = cur.fetchone()
if not row:
log("server", f"enrollment poll unknown_reference ref={approval_reference}")
log("server", f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
return jsonify({"status": "unknown"}), 404
(
@@ -517,13 +545,13 @@ def register(
) = row
if client_nonce_stored != client_nonce_b64:
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}")
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
return jsonify({"error": "nonce_mismatch"}), 400
try:
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}")
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "server_nonce_invalid"}), 400
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
@@ -531,17 +559,17 @@ def register(
try:
public_key = serialization.load_der_public_key(agent_pubkey_der)
except Exception:
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}")
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
public_key = None
if public_key is None:
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}")
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
return jsonify({"error": "agent_pubkey_invalid"}), 400
try:
public_key.verify(proof_sig, message)
except Exception:
log("server", f"enrollment poll invalid_proof ref={approval_reference}")
log("server", f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_proof"}), 400
if status == "pending":
@@ -549,24 +577,28 @@ def register(
"server",
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
f" fingerprint={fingerprint[:12]}",
context_hint,
)
return jsonify({"status": "pending", "poll_after_ms": 5000})
if status == "denied":
log(
"server",
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "denied", "reason": "operator_denied"})
if status == "expired":
log(
"server",
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "expired"})
if status == "completed":
log(
"server",
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "approved", "detail": "finalized"})
@@ -574,6 +606,7 @@ def register(
log(
"server",
f"enrollment poll unexpected_status={status} ref={approval_reference}",
context_hint,
)
return jsonify({"status": status or "unknown"}), 400
@@ -582,6 +615,7 @@ def register(
log(
"server",
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
context_hint,
)
return jsonify({"error": "proof_replayed"}), 409
@@ -656,6 +690,7 @@ def register(
log(
"server",
f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
context_hint,
)
return jsonify(
{

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Callable
from typing import Callable, Optional
import eventlet
from flask_socketio import SocketIO
@@ -11,7 +11,7 @@ def start_prune_job(
socketio: SocketIO,
*,
db_conn_factory: Callable[[], any],
log: Callable[[str, str], None],
log: Callable[[str, str, Optional[str]], None],
) -> None:
def _job_loop():
while True:
@@ -24,7 +24,7 @@ def start_prune_job(
socketio.start_background_task(_job_loop)
def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None]) -> None:
def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str, Optional[str]], None]) -> None:
now = datetime.now(tz=timezone.utc)
now_iso = now.isoformat()
stale_before = (now - timedelta(hours=24)).isoformat()