Serialize agent enrollment across processes

This commit is contained in:
2025-10-18 01:38:06 -06:00
parent 07a9cfeb65
commit df16b22a5e

View File

@@ -21,7 +21,9 @@ import shutil
import string import string
import ssl import ssl
import threading import threading
from typing import Any, Dict, Optional, List import contextlib
import errno
from typing import Any, Dict, Optional, List, Callable
import requests import requests
try: try:
@@ -133,6 +135,133 @@ def _settings_dir():
return os.path.abspath(os.path.join(os.path.dirname(__file__), 'Settings')) return os.path.abspath(os.path.join(os.path.dirname(__file__), 'Settings'))
class _CrossProcessFileLock:
    """Advisory, non-reentrant lock backed by a lock file on disk.

    The lock is taken with ``fcntl.flock`` when ``os.name`` is not ``'nt'``
    and with ``msvcrt.locking`` on the first byte of the file when it is.
    While held, the file contains a ``pid=<pid> ts=<epoch>`` line describing
    the holder (written best-effort, for diagnostics only).
    """

    # Minimum number of seconds between two successive ``on_wait`` callbacks.
    _WAIT_NOTIFY_INTERVAL = 2.0

    def __init__(self, path: str) -> None:
        """Remember the lock-file *path*; nothing is opened until ``acquire``."""
        self.path = path
        self._handle = None

    def acquire(
        self,
        *,
        timeout: float = 120.0,
        poll_interval: float = 0.5,
        on_wait: Optional[Callable[[], None]] = None,
    ) -> None:
        """Poll until the lock is obtained or *timeout* seconds have elapsed.

        Args:
            timeout: Maximum time to wait in seconds. A falsy value (``0`` or
                ``None``) disables the deadline and waits indefinitely.
            poll_interval: Seconds to sleep between two locking attempts.
            on_wait: Optional callback invoked while the lock is busy, at most
                once every two seconds. Exceptions it raises are ignored.

        Raises:
            TimeoutError: The lock was still busy when the deadline passed.
            OSError: Opening or locking the file failed for a reason other
                than "currently locked by someone else".
        """
        directory = os.path.dirname(self.path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Use the monotonic clock for the deadline and the callback throttle:
        # wall-clock adjustments must neither expire nor extend the wait.
        deadline = time.monotonic() + timeout if timeout else None
        last_notified: Optional[float] = None

        while True:
            handle = open(self.path, 'a+b')
            try:
                self._try_lock(handle)
            except OSError as exc:
                handle.close()
                if not self._is_lock_unavailable(exc):
                    raise
            except BaseException:
                # Never leak the descriptor, even on KeyboardInterrupt.
                handle.close()
                raise
            else:
                self._handle = handle
                self._write_owner_stamp(handle)
                return

            now = time.monotonic()
            if on_wait and (
                last_notified is None
                or now - last_notified >= self._WAIT_NOTIFY_INTERVAL
            ):
                try:
                    on_wait()
                except Exception:
                    # The callback is informational; it must not abort the wait.
                    pass
                last_notified = now

            if deadline is None:
                time.sleep(poll_interval)
                continue
            if now >= deadline:
                raise TimeoutError('Timed out waiting for enrollment lock')
            # Do not sleep past the deadline.
            time.sleep(min(poll_interval, deadline - now))

    def release(self) -> None:
        """Unlock and close the lock file; a no-op when the lock is not held."""
        # Detach the handle first so this cleanup can never clear a handle
        # stored by a later acquire().
        handle, self._handle = self._handle, None
        if handle is None:
            return
        try:
            self._unlock(handle)
        finally:
            try:
                handle.close()
            except Exception:
                pass

    @staticmethod
    def _write_owner_stamp(handle) -> None:
        """Replace the file contents with the holder's pid and a timestamp."""
        try:
            handle.seek(0)
            handle.truncate(0)
            handle.write(f"pid={os.getpid()} ts={int(time.time())}\n".encode('utf-8'))
            handle.flush()
        except Exception:
            # Diagnostic data only; failing to record it must not fail acquire().
            pass

    @staticmethod
    def _is_lock_unavailable(exc: OSError) -> bool:
        """Return True when *exc* means "lock is busy" rather than a real failure."""
        busy_errnos = {errno.EACCES, errno.EAGAIN, getattr(errno, 'EWOULDBLOCK', errno.EAGAIN)}
        if exc.errno in busy_errnos:
            return True
        # Windows: 32 = ERROR_SHARING_VIOLATION, 33 = ERROR_LOCK_VIOLATION.
        return getattr(exc, 'winerror', None) in (32, 33)

    @staticmethod
    def _try_lock(handle) -> None:
        """Attempt a non-blocking exclusive lock; raise ``OSError`` if it fails."""
        # Make sure the file is non-empty before locking its first byte.
        handle.seek(0, os.SEEK_END)
        if handle.tell() == 0:
            try:
                handle.write(b'0')
                handle.flush()
            except Exception:
                pass
        # msvcrt.locking works from the current file position, so rewind first.
        handle.seek(0)
        if os.name == 'nt':
            import msvcrt  # type: ignore
            msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1)
        else:
            import fcntl  # type: ignore
            fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)

    @staticmethod
    def _unlock(handle) -> None:
        """Release the lock taken by ``_try_lock``; ``OSError`` is ignored."""
        try:
            if os.name == 'nt':
                import msvcrt  # type: ignore
                # The lock covers byte 0, but the position moved when the
                # owner stamp was written; rewind so the same byte is unlocked.
                handle.seek(0)
                msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1)
            else:
                import fcntl  # type: ignore
                fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
        except OSError:
            # Closing the handle afterwards drops the lock in any case.
            pass
# Lazily created, process-wide lock object shared by every enrollment attempt.
_ENROLLMENT_FILE_LOCK: Optional[_CrossProcessFileLock] = None


@contextlib.contextmanager
def _acquire_enrollment_lock(*, timeout: float = 180.0, on_wait: Optional[Callable[[], None]] = None):
    """Hold the shared ``enrollment.lock`` file lock for the duration of a ``with`` block.

    The lock file lives in the directory returned by ``_settings_dir()``.
    *timeout* and *on_wait* are forwarded unchanged to
    ``_CrossProcessFileLock.acquire``; the lock is released when the block
    exits, whether normally or through an exception.
    """
    global _ENROLLMENT_FILE_LOCK
    lock = _ENROLLMENT_FILE_LOCK
    if lock is None:
        lock_path = os.path.join(_settings_dir(), 'enrollment.lock')
        lock = _ENROLLMENT_FILE_LOCK = _CrossProcessFileLock(lock_path)
    lock.acquire(timeout=timeout, on_wait=on_wait)
    try:
        yield
    finally:
        lock.release()
_KEY_STORE_INSTANCE = None _KEY_STORE_INSTANCE = None
@@ -503,15 +632,14 @@ class AgentHttpClient:
self.identity = IDENTITY self.identity = IDENTITY
self.session = requests.Session() self.session = requests.Session()
self.base_url: Optional[str] = None self.base_url: Optional[str] = None
self.guid: Optional[str] = self.key_store.load_guid() self.guid: Optional[str] = None
self.access_token: Optional[str] = self.key_store.load_access_token() self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = self.key_store.load_refresh_token() self.refresh_token: Optional[str] = None
self.access_expires_at: Optional[int] = self.key_store.get_access_expiry() self.access_expires_at: Optional[int] = None
self._auth_lock = threading.RLock() self._auth_lock = threading.RLock()
self.refresh_base_url() self.refresh_base_url()
self._configure_verify() self._configure_verify()
if self.access_token: self._reload_tokens_from_disk()
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
self.session.headers.setdefault("User-Agent", "Borealis-Agent/secure") self.session.headers.setdefault("User-Agent", "Borealis-Agent/secure")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -542,6 +670,20 @@ class AgentHttpClient:
except Exception: except Exception:
pass pass
def _reload_tokens_from_disk(self) -> None:
    """Re-read guid, tokens and access expiry from the key store.

    Falsy stored values are normalised to ``None``. The session's
    ``Authorization`` header is then set from the access token, or removed
    when there is no access token.
    """
    store = self.key_store
    # Read everything before assigning so a failing read changes no attribute.
    loaded = (
        store.load_guid(),
        store.load_access_token(),
        store.load_refresh_token(),
        store.get_access_expiry(),
    )
    (
        self.guid,
        self.access_token,
        self.refresh_token,
        self.access_expires_at,
    ) = [value or None for value in loaded]

    headers = self.session.headers
    token = self.access_token
    if not token:
        headers.pop("Authorization", None)
        return
    headers.update({"Authorization": f"Bearer {token}"})
def auth_headers(self) -> Dict[str, str]: def auth_headers(self) -> Dict[str, str]:
if self.access_token: if self.access_token:
return {"Authorization": f"Bearer {self.access_token}"} return {"Authorization": f"Bearer {self.access_token}"}
@@ -606,6 +748,7 @@ class AgentHttpClient:
self._perform_enrollment_locked() self._perform_enrollment_locked()
def _perform_enrollment_locked(self) -> None: def _perform_enrollment_locked(self) -> None:
self._reload_tokens_from_disk()
if self.guid and self.refresh_token: if self.guid and self.refresh_token:
return return
code = self._resolve_installer_code() code = self._resolve_installer_code()
@@ -614,6 +757,34 @@ class AgentHttpClient:
"Installer code is required for enrollment. " "Installer code is required for enrollment. "
"Set BOREALIS_INSTALLER_CODE, pass --installer-code, or update agent_settings.json." "Set BOREALIS_INSTALLER_CODE, pass --installer-code, or update agent_settings.json."
) )
wait_state = {"count": 0, "tokens_seen": False}
def _on_lock_wait() -> None:
    """Log each lock-wait notification and note once when credentials appear on disk."""
    wait_state["count"] += 1
    _log_agent(
        f"Enrollment waiting for shared lock scope={SERVICE_MODE} attempt={wait_state['count']}",
        fname="agent.log",
    )
    # Once credentials have been seen there is nothing more to check.
    if wait_state["tokens_seen"]:
        return
    self._reload_tokens_from_disk()
    if not (self.guid and self.refresh_token):
        return
    wait_state["tokens_seen"] = True
    _log_agent(
        "Enrollment credentials detected while waiting for lock; will reuse when available",
        fname="agent.log",
    )
try:
with _acquire_enrollment_lock(timeout=180.0, on_wait=_on_lock_wait):
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
_log_agent(
"Enrollment skipped after acquiring lock; credentials already present",
fname="agent.log",
)
return
self.refresh_base_url() self.refresh_base_url()
base_url = self.base_url or "https://localhost:5000" base_url = self.base_url or "https://localhost:5000"
code_masked = _mask_sensitive(code) code_masked = _mask_sensitive(code)
@@ -743,6 +914,15 @@ class AgentHttpClient:
self._finalize_enrollment(poll_data) self._finalize_enrollment(poll_data)
break break
raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}") raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}")
except TimeoutError:
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
_log_agent(
"Enrollment lock wait timed out but credentials materialized; reusing existing tokens",
fname="agent.log",
)
return
raise
def _finalize_enrollment(self, payload: Dict[str, Any]) -> None: def _finalize_enrollment(self, payload: Dict[str, Any]) -> None:
server_cert = payload.get("server_certificate") server_cert = payload.get("server_certificate")