mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 19:21:58 -06:00
Serialise agent enrollment and guard heartbeat auth context
This commit is contained in:
@@ -20,6 +20,7 @@ import datetime
|
|||||||
import shutil
|
import shutil
|
||||||
import string
|
import string
|
||||||
import ssl
|
import ssl
|
||||||
|
import threading
|
||||||
from typing import Any, Dict, Optional, List
|
from typing import Any, Dict, Optional, List
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@@ -494,6 +495,7 @@ class AgentHttpClient:
|
|||||||
self.access_token: Optional[str] = self.key_store.load_access_token()
|
self.access_token: Optional[str] = self.key_store.load_access_token()
|
||||||
self.refresh_token: Optional[str] = self.key_store.load_refresh_token()
|
self.refresh_token: Optional[str] = self.key_store.load_refresh_token()
|
||||||
self.access_expires_at: Optional[int] = self.key_store.get_access_expiry()
|
self.access_expires_at: Optional[int] = self.key_store.get_access_expiry()
|
||||||
|
self._auth_lock = threading.RLock()
|
||||||
self.refresh_base_url()
|
self.refresh_base_url()
|
||||||
self._configure_verify()
|
self._configure_verify()
|
||||||
if self.access_token:
|
if self.access_token:
|
||||||
@@ -570,11 +572,15 @@ class AgentHttpClient:
|
|||||||
# Enrollment & token management
|
# Enrollment & token management
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
def ensure_authenticated(self) -> None:
|
def ensure_authenticated(self) -> None:
|
||||||
|
with self._auth_lock:
|
||||||
|
self._ensure_authenticated_locked()
|
||||||
|
|
||||||
|
def _ensure_authenticated_locked(self) -> None:
|
||||||
self.refresh_base_url()
|
self.refresh_base_url()
|
||||||
if not self.guid or not self.refresh_token:
|
if not self.guid or not self.refresh_token:
|
||||||
self.perform_enrollment()
|
self._perform_enrollment_locked()
|
||||||
if not self.access_token or self._token_expiring_soon():
|
if not self.access_token or self._token_expiring_soon():
|
||||||
self.refresh_access_token()
|
self._refresh_access_token_locked()
|
||||||
|
|
||||||
def _token_expiring_soon(self) -> bool:
|
def _token_expiring_soon(self) -> bool:
|
||||||
if not self.access_token:
|
if not self.access_token:
|
||||||
@@ -584,6 +590,12 @@ class AgentHttpClient:
|
|||||||
return (self.access_expires_at - time.time()) < 60
|
return (self.access_expires_at - time.time()) < 60
|
||||||
|
|
||||||
def perform_enrollment(self) -> None:
|
def perform_enrollment(self) -> None:
|
||||||
|
with self._auth_lock:
|
||||||
|
self._perform_enrollment_locked()
|
||||||
|
|
||||||
|
def _perform_enrollment_locked(self) -> None:
|
||||||
|
if self.guid and self.refresh_token:
|
||||||
|
return
|
||||||
code = self._resolve_installer_code()
|
code = self._resolve_installer_code()
|
||||||
if not code:
|
if not code:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@@ -682,9 +694,13 @@ class AgentHttpClient:
|
|||||||
_log_agent(f"Enrollment finalized for guid={self.guid}", fname="agent.log")
|
_log_agent(f"Enrollment finalized for guid={self.guid}", fname="agent.log")
|
||||||
|
|
||||||
def refresh_access_token(self) -> None:
|
def refresh_access_token(self) -> None:
|
||||||
|
with self._auth_lock:
|
||||||
|
self._refresh_access_token_locked()
|
||||||
|
|
||||||
|
def _refresh_access_token_locked(self) -> None:
|
||||||
if not self.refresh_token or not self.guid:
|
if not self.refresh_token or not self.guid:
|
||||||
self.clear_tokens()
|
self._clear_tokens_locked()
|
||||||
self.perform_enrollment()
|
self._perform_enrollment_locked()
|
||||||
return
|
return
|
||||||
payload = {"guid": self.guid, "refresh_token": self.refresh_token}
|
payload = {"guid": self.guid, "refresh_token": self.refresh_token}
|
||||||
resp = self.session.post(
|
resp = self.session.post(
|
||||||
@@ -695,8 +711,8 @@ class AgentHttpClient:
|
|||||||
)
|
)
|
||||||
if resp.status_code in (401, 403):
|
if resp.status_code in (401, 403):
|
||||||
_log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log")
|
_log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log")
|
||||||
self.clear_tokens()
|
self._clear_tokens_locked()
|
||||||
self.perform_enrollment()
|
self._perform_enrollment_locked()
|
||||||
return
|
return
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
@@ -712,6 +728,10 @@ class AgentHttpClient:
|
|||||||
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
|
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
|
||||||
|
|
||||||
def clear_tokens(self) -> None:
|
def clear_tokens(self) -> None:
|
||||||
|
with self._auth_lock:
|
||||||
|
self._clear_tokens_locked()
|
||||||
|
|
||||||
|
def _clear_tokens_locked(self) -> None:
|
||||||
self.key_store.clear_tokens()
|
self.key_store.clear_tokens()
|
||||||
self.access_token = None
|
self.access_token = None
|
||||||
self.refresh_token = None
|
self.refresh_token = None
|
||||||
|
|||||||
@@ -28,10 +28,18 @@ def register(
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _auth_context():
|
||||||
|
ctx = getattr(g, "device_auth", None)
|
||||||
|
if ctx is None:
|
||||||
|
log("server", f"device auth context missing for {request.path}")
|
||||||
|
return ctx
|
||||||
|
|
||||||
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
|
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
|
||||||
@require_device_auth(auth_manager)
|
@require_device_auth(auth_manager)
|
||||||
def heartbeat():
|
def heartbeat():
|
||||||
ctx = getattr(g, "device_auth")
|
ctx = _auth_context()
|
||||||
|
if ctx is None:
|
||||||
|
return jsonify({"error": "auth_context_missing"}), 500
|
||||||
payload = request.get_json(force=True, silent=True) or {}
|
payload = request.get_json(force=True, silent=True) or {}
|
||||||
|
|
||||||
now_ts = int(time.time())
|
now_ts = int(time.time())
|
||||||
@@ -90,7 +98,9 @@ def register(
|
|||||||
@blueprint.route("/api/agent/script/request", methods=["POST"])
|
@blueprint.route("/api/agent/script/request", methods=["POST"])
|
||||||
@require_device_auth(auth_manager)
|
@require_device_auth(auth_manager)
|
||||||
def script_request():
|
def script_request():
|
||||||
ctx = getattr(g, "device_auth")
|
ctx = _auth_context()
|
||||||
|
if ctx is None:
|
||||||
|
return jsonify({"error": "auth_context_missing"}), 500
|
||||||
if ctx.status != "active":
|
if ctx.status != "active":
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user