Serialise agent enrollment and guard heartbeat auth context

This commit is contained in:
2025-10-18 00:03:13 -06:00
parent 7b0e2f48e1
commit 91e7a6de88
2 changed files with 38 additions and 8 deletions

View File

@@ -20,6 +20,7 @@ import datetime
import shutil
import string
import ssl
import threading
from typing import Any, Dict, Optional, List
import requests
@@ -494,6 +495,7 @@ class AgentHttpClient:
self.access_token: Optional[str] = self.key_store.load_access_token()
self.refresh_token: Optional[str] = self.key_store.load_refresh_token()
self.access_expires_at: Optional[int] = self.key_store.get_access_expiry()
self._auth_lock = threading.RLock()
self.refresh_base_url()
self._configure_verify()
if self.access_token:
@@ -570,11 +572,15 @@ class AgentHttpClient:
# Enrollment & token management
# ------------------------------------------------------------------
def ensure_authenticated(self) -> None:
with self._auth_lock:
self._ensure_authenticated_locked()
def _ensure_authenticated_locked(self) -> None:
self.refresh_base_url()
if not self.guid or not self.refresh_token:
self.perform_enrollment()
self._perform_enrollment_locked()
if not self.access_token or self._token_expiring_soon():
self.refresh_access_token()
self._refresh_access_token_locked()
def _token_expiring_soon(self) -> bool:
if not self.access_token:
@@ -584,6 +590,12 @@ class AgentHttpClient:
return (self.access_expires_at - time.time()) < 60
def perform_enrollment(self) -> None:
with self._auth_lock:
self._perform_enrollment_locked()
def _perform_enrollment_locked(self) -> None:
if self.guid and self.refresh_token:
return
code = self._resolve_installer_code()
if not code:
raise RuntimeError(
@@ -682,9 +694,13 @@ class AgentHttpClient:
_log_agent(f"Enrollment finalized for guid={self.guid}", fname="agent.log")
def refresh_access_token(self) -> None:
with self._auth_lock:
self._refresh_access_token_locked()
def _refresh_access_token_locked(self) -> None:
if not self.refresh_token or not self.guid:
self.clear_tokens()
self.perform_enrollment()
self._clear_tokens_locked()
self._perform_enrollment_locked()
return
payload = {"guid": self.guid, "refresh_token": self.refresh_token}
resp = self.session.post(
@@ -695,8 +711,8 @@ class AgentHttpClient:
)
if resp.status_code in (401, 403):
_log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log")
self.clear_tokens()
self.perform_enrollment()
self._clear_tokens_locked()
self._perform_enrollment_locked()
return
resp.raise_for_status()
data = resp.json()
@@ -712,6 +728,10 @@ class AgentHttpClient:
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
def clear_tokens(self) -> None:
with self._auth_lock:
self._clear_tokens_locked()
def _clear_tokens_locked(self) -> None:
self.key_store.clear_tokens()
self.access_token = None
self.refresh_token = None

View File

@@ -28,10 +28,18 @@ def register(
except Exception:
return None
def _auth_context():
ctx = getattr(g, "device_auth", None)
if ctx is None:
log("server", f"device auth context missing for {request.path}")
return ctx
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
@require_device_auth(auth_manager)
def heartbeat():
ctx = getattr(g, "device_auth")
ctx = _auth_context()
if ctx is None:
return jsonify({"error": "auth_context_missing"}), 500
payload = request.get_json(force=True, silent=True) or {}
now_ts = int(time.time())
@@ -90,7 +98,9 @@ def register(
@blueprint.route("/api/agent/script/request", methods=["POST"])
@require_device_auth(auth_manager)
def script_request():
ctx = getattr(g, "device_auth")
ctx = _auth_context()
if ctx is None:
return jsonify({"error": "auth_context_missing"}), 500
if ctx.status != "active":
return jsonify(
{