Serialise agent enrollment and guard heartbeat auth context

This commit is contained in:
2025-10-18 00:03:13 -06:00
parent 7b0e2f48e1
commit 91e7a6de88
2 changed files with 38 additions and 8 deletions

View File

@@ -20,6 +20,7 @@ import datetime
import shutil import shutil
import string import string
import ssl import ssl
import threading
from typing import Any, Dict, Optional, List from typing import Any, Dict, Optional, List
import requests import requests
@@ -494,6 +495,7 @@ class AgentHttpClient:
self.access_token: Optional[str] = self.key_store.load_access_token() self.access_token: Optional[str] = self.key_store.load_access_token()
self.refresh_token: Optional[str] = self.key_store.load_refresh_token() self.refresh_token: Optional[str] = self.key_store.load_refresh_token()
self.access_expires_at: Optional[int] = self.key_store.get_access_expiry() self.access_expires_at: Optional[int] = self.key_store.get_access_expiry()
self._auth_lock = threading.RLock()
self.refresh_base_url() self.refresh_base_url()
self._configure_verify() self._configure_verify()
if self.access_token: if self.access_token:
@@ -570,11 +572,15 @@ class AgentHttpClient:
# Enrollment & token management # Enrollment & token management
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def ensure_authenticated(self) -> None: def ensure_authenticated(self) -> None:
with self._auth_lock:
self._ensure_authenticated_locked()
def _ensure_authenticated_locked(self) -> None:
self.refresh_base_url() self.refresh_base_url()
if not self.guid or not self.refresh_token: if not self.guid or not self.refresh_token:
self.perform_enrollment() self._perform_enrollment_locked()
if not self.access_token or self._token_expiring_soon(): if not self.access_token or self._token_expiring_soon():
self.refresh_access_token() self._refresh_access_token_locked()
def _token_expiring_soon(self) -> bool: def _token_expiring_soon(self) -> bool:
if not self.access_token: if not self.access_token:
@@ -584,6 +590,12 @@ class AgentHttpClient:
return (self.access_expires_at - time.time()) < 60 return (self.access_expires_at - time.time()) < 60
def perform_enrollment(self) -> None: def perform_enrollment(self) -> None:
with self._auth_lock:
self._perform_enrollment_locked()
def _perform_enrollment_locked(self) -> None:
if self.guid and self.refresh_token:
return
code = self._resolve_installer_code() code = self._resolve_installer_code()
if not code: if not code:
raise RuntimeError( raise RuntimeError(
@@ -682,9 +694,13 @@ class AgentHttpClient:
_log_agent(f"Enrollment finalized for guid={self.guid}", fname="agent.log") _log_agent(f"Enrollment finalized for guid={self.guid}", fname="agent.log")
def refresh_access_token(self) -> None: def refresh_access_token(self) -> None:
with self._auth_lock:
self._refresh_access_token_locked()
def _refresh_access_token_locked(self) -> None:
if not self.refresh_token or not self.guid: if not self.refresh_token or not self.guid:
self.clear_tokens() self._clear_tokens_locked()
self.perform_enrollment() self._perform_enrollment_locked()
return return
payload = {"guid": self.guid, "refresh_token": self.refresh_token} payload = {"guid": self.guid, "refresh_token": self.refresh_token}
resp = self.session.post( resp = self.session.post(
@@ -695,8 +711,8 @@ class AgentHttpClient:
) )
if resp.status_code in (401, 403): if resp.status_code in (401, 403):
_log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log") _log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log")
self.clear_tokens() self._clear_tokens_locked()
self.perform_enrollment() self._perform_enrollment_locked()
return return
resp.raise_for_status() resp.raise_for_status()
data = resp.json() data = resp.json()
@@ -712,6 +728,10 @@ class AgentHttpClient:
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"}) self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
def clear_tokens(self) -> None: def clear_tokens(self) -> None:
with self._auth_lock:
self._clear_tokens_locked()
def _clear_tokens_locked(self) -> None:
self.key_store.clear_tokens() self.key_store.clear_tokens()
self.access_token = None self.access_token = None
self.refresh_token = None self.refresh_token = None

View File

@@ -28,10 +28,18 @@ def register(
except Exception: except Exception:
return None return None
def _auth_context():
ctx = getattr(g, "device_auth", None)
if ctx is None:
log("server", f"device auth context missing for {request.path}")
return ctx
@blueprint.route("/api/agent/heartbeat", methods=["POST"]) @blueprint.route("/api/agent/heartbeat", methods=["POST"])
@require_device_auth(auth_manager) @require_device_auth(auth_manager)
def heartbeat(): def heartbeat():
ctx = getattr(g, "device_auth") ctx = _auth_context()
if ctx is None:
return jsonify({"error": "auth_context_missing"}), 500
payload = request.get_json(force=True, silent=True) or {} payload = request.get_json(force=True, silent=True) or {}
now_ts = int(time.time()) now_ts = int(time.time())
@@ -90,7 +98,9 @@ def register(
@blueprint.route("/api/agent/script/request", methods=["POST"]) @blueprint.route("/api/agent/script/request", methods=["POST"])
@require_device_auth(auth_manager) @require_device_auth(auth_manager)
def script_request(): def script_request():
ctx = getattr(g, "device_auth") ctx = _auth_context()
if ctx is None:
return jsonify({"error": "auth_context_missing"}), 500
if ctx.status != "active": if ctx.status != "active":
return jsonify( return jsonify(
{ {