Serialize agent enrollment across processes

This commit is contained in:
2025-10-18 01:38:06 -06:00
parent 07a9cfeb65
commit df16b22a5e

View File

@@ -21,7 +21,9 @@ import shutil
import string
import ssl
import threading
from typing import Any, Dict, Optional, List
import contextlib
import errno
from typing import Any, Dict, Optional, List, Callable
import requests
try:
@@ -133,6 +135,133 @@ def _settings_dir():
return os.path.abspath(os.path.join(os.path.dirname(__file__), 'Settings'))
class _CrossProcessFileLock:
    """Advisory, non-reentrant lock shared between processes through a lock file.

    The lock is taken on byte 0 of ``path`` with ``msvcrt.locking`` on Windows
    and on the whole file with ``fcntl.flock`` elsewhere. While held, the file
    contains a ``pid=<pid> ts=<epoch>`` line identifying the owner; that line
    is informational only and is never read back by this class.
    """

    def __init__(self, path: str) -> None:
        self.path = path
        # Open file object that currently owns the OS-level lock, or None.
        self._handle = None

    def acquire(
        self,
        *,
        timeout: float = 120.0,
        poll_interval: float = 0.5,
        on_wait: Optional[Callable[[], None]] = None,
    ) -> None:
        """Block until the lock is obtained.

        Args:
            timeout: Seconds to keep trying. A falsy value (``0`` or ``None``)
                disables the deadline and waits indefinitely.
            poll_interval: Seconds to sleep between attempts.
            on_wait: Optional callback invoked while the lock is contended,
                at most once every two seconds. Exceptions it raises are
                ignored so a faulty callback cannot abort the wait.

        Raises:
            TimeoutError: The deadline passed while the lock was still held
                elsewhere.
            OSError: Opening or locking the file failed for a reason other
                than contention.
        """
        directory = os.path.dirname(self.path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Use the monotonic clock so wall-clock adjustments (NTP, DST, manual
        # changes) can neither shorten nor extend the wait.
        deadline = time.monotonic() + timeout if timeout else None
        last_wait_log: Optional[float] = None

        while True:
            handle = open(self.path, 'a+b')
            try:
                self._try_lock(handle)
            except OSError as exc:
                handle.close()
                if not self._is_lock_unavailable(exc):
                    raise
            except BaseException:
                # Never leak the descriptor, even on KeyboardInterrupt.
                handle.close()
                raise
            else:
                self._handle = handle
                try:
                    # Best-effort owner stamp; failing to write it must not
                    # give up a lock that was successfully obtained.
                    handle.seek(0)
                    handle.truncate(0)
                    handle.write(f"pid={os.getpid()} ts={int(time.time())}\n".encode('utf-8'))
                    handle.flush()
                except Exception:
                    pass
                return

            # Lock is held elsewhere: notify, check the deadline, then retry.
            now = time.monotonic()
            if on_wait and (last_wait_log is None or (now - last_wait_log) >= 2.0):
                try:
                    on_wait()
                except Exception:
                    pass
                last_wait_log = now
            if deadline is not None and now >= deadline:
                raise TimeoutError('Timed out waiting for enrollment lock')
            time.sleep(poll_interval)

    def release(self) -> None:
        """Release the lock if this instance holds it; otherwise do nothing."""
        # Detach the handle *before* unlocking. Another thread sharing this
        # instance can only store its own handle after the unlock below, so
        # clearing the attribute first guarantees we never wipe its handle.
        handle, self._handle = self._handle, None
        if handle is None:
            return
        try:
            self._unlock(handle)
        finally:
            try:
                handle.close()
            except Exception:
                pass

    @staticmethod
    def _is_lock_unavailable(exc: OSError) -> bool:
        """Return True when ``exc`` means "locked by someone else"."""
        contention_errnos = {
            errno.EACCES,
            errno.EAGAIN,
            getattr(errno, 'EWOULDBLOCK', errno.EAGAIN),
        }
        if exc.errno in contention_errnos:
            return True
        # Windows error codes 32 / 33: sharing violation / lock violation.
        return getattr(exc, 'winerror', None) in (32, 33)

    @staticmethod
    def _try_lock(handle) -> None:
        """Make one non-blocking attempt to lock ``handle``.

        Raises:
            OSError: The lock could not be taken; use
                ``_is_lock_unavailable`` to tell contention from real errors.
        """
        handle.seek(0, os.SEEK_END)
        if handle.tell() == 0:
            # Ensure there is a byte at offset 0 for the byte-range lock.
            try:
                handle.write(b'0')
                handle.flush()
            except Exception:
                pass
        handle.seek(0)
        if os.name == 'nt':
            import msvcrt  # type: ignore
            # Locks one byte starting at the current position (offset 0).
            msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1)
        else:
            import fcntl  # type: ignore
            fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)

    @staticmethod
    def _unlock(handle) -> None:
        """Drop the OS-level lock on ``handle``; OS errors are ignored."""
        if os.name == 'nt':
            import msvcrt  # type: ignore
            try:
                # msvcrt.locking operates from the current file position. The
                # owner stamp written after locking moved the position away
                # from offset 0, so rewind to unlock the byte that was locked.
                handle.seek(0)
                msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1)
            except OSError:
                pass
        else:
            import fcntl  # type: ignore
            try:
                fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
            except OSError:
                pass
# Process-wide lock object for enrollment. Stays None until
# _acquire_enrollment_lock() creates it on first use.
_ENROLLMENT_FILE_LOCK: Optional[_CrossProcessFileLock] = None
@contextlib.contextmanager
def _acquire_enrollment_lock(*, timeout: float = 180.0, on_wait: Optional[Callable[[], None]] = None):
    """Hold the shared enrollment file lock for the duration of a ``with`` block.

    The module-level lock object is built on first use and points at
    ``enrollment.lock`` inside the settings directory. ``timeout`` and
    ``on_wait`` are passed straight through to the lock's ``acquire``; the
    lock is released when the block exits, whether normally or by exception.
    """
    global _ENROLLMENT_FILE_LOCK

    # Build the shared lock object lazily, the first time anyone asks for it.
    if _ENROLLMENT_FILE_LOCK is None:
        lock_path = os.path.join(_settings_dir(), 'enrollment.lock')
        _ENROLLMENT_FILE_LOCK = _CrossProcessFileLock(lock_path)

    _ENROLLMENT_FILE_LOCK.acquire(timeout=timeout, on_wait=on_wait)
    try:
        yield
    finally:
        _ENROLLMENT_FILE_LOCK.release()
# NOTE(review): presumably a lazily populated module-level key-store singleton;
# the code that assigns it is outside this view — confirm.
_KEY_STORE_INSTANCE = None
@@ -503,15 +632,14 @@ class AgentHttpClient:
self.identity = IDENTITY
self.session = requests.Session()
self.base_url: Optional[str] = None
self.guid: Optional[str] = self.key_store.load_guid()
self.access_token: Optional[str] = self.key_store.load_access_token()
self.refresh_token: Optional[str] = self.key_store.load_refresh_token()
self.access_expires_at: Optional[int] = self.key_store.get_access_expiry()
self.guid: Optional[str] = None
self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = None
self.access_expires_at: Optional[int] = None
self._auth_lock = threading.RLock()
self.refresh_base_url()
self._configure_verify()
if self.access_token:
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
self._reload_tokens_from_disk()
self.session.headers.setdefault("User-Agent", "Borealis-Agent/secure")
# ------------------------------------------------------------------
@@ -542,6 +670,20 @@ class AgentHttpClient:
except Exception:
pass
def _reload_tokens_from_disk(self) -> None:
    """Refresh cached credentials from the key store and sync the session header.

    Reads the GUID, access token, refresh token and access-token expiry from
    ``self.key_store`` and stores each on the instance, normalising any falsy
    value to ``None``. The session's ``Authorization`` header is then set to
    the bearer token, or removed when no access token is available.
    """
    store = self.key_store
    # Read everything first so a failing load leaves the instance untouched.
    loaded = (
        store.load_guid(),
        store.load_access_token(),
        store.load_refresh_token(),
        store.get_access_expiry(),
    )
    (
        self.guid,
        self.access_token,
        self.refresh_token,
        self.access_expires_at,
    ) = (value or None for value in loaded)

    headers = self.session.headers
    if not self.access_token:
        headers.pop("Authorization", None)
        return
    headers.update({"Authorization": f"Bearer {self.access_token}"})
def auth_headers(self) -> Dict[str, str]:
if self.access_token:
return {"Authorization": f"Bearer {self.access_token}"}
@@ -606,6 +748,7 @@ class AgentHttpClient:
self._perform_enrollment_locked()
def _perform_enrollment_locked(self) -> None:
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
return
code = self._resolve_installer_code()
@@ -614,6 +757,34 @@ class AgentHttpClient:
"Installer code is required for enrollment. "
"Set BOREALIS_INSTALLER_CODE, pass --installer-code, or update agent_settings.json."
)
wait_state = {"count": 0, "tokens_seen": False}
def _on_lock_wait() -> None:
wait_state["count"] += 1
_log_agent(
f"Enrollment waiting for shared lock scope={SERVICE_MODE} attempt={wait_state['count']}",
fname="agent.log",
)
if not wait_state["tokens_seen"]:
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
wait_state["tokens_seen"] = True
_log_agent(
"Enrollment credentials detected while waiting for lock; will reuse when available",
fname="agent.log",
)
try:
with _acquire_enrollment_lock(timeout=180.0, on_wait=_on_lock_wait):
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
_log_agent(
"Enrollment skipped after acquiring lock; credentials already present",
fname="agent.log",
)
return
self.refresh_base_url()
base_url = self.base_url or "https://localhost:5000"
code_masked = _mask_sensitive(code)
@@ -743,6 +914,15 @@ class AgentHttpClient:
self._finalize_enrollment(poll_data)
break
raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}")
except TimeoutError:
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
_log_agent(
"Enrollment lock wait timed out but credentials materialized; reusing existing tokens",
fname="agent.log",
)
return
raise
def _finalize_enrollment(self, payload: Dict[str, Any]) -> None:
server_cert = payload.get("server_certificate")