mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 17:21:58 -06:00
Guard identity creation with cross-process lock
This commit is contained in:
@@ -3,6 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import errno
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
@@ -23,6 +25,16 @@ try:
|
||||
except Exception: # pragma: no cover - win32crypt missing
|
||||
win32crypt = None # type: ignore
|
||||
|
||||
try: # pragma: no cover - Windows only
|
||||
import msvcrt # type: ignore
|
||||
except Exception: # pragma: no cover - non-Windows
|
||||
msvcrt = None # type: ignore
|
||||
|
||||
try: # pragma: no cover - POSIX only
|
||||
import fcntl # type: ignore
|
||||
except Exception: # pragma: no cover - Windows
|
||||
fcntl = None # type: ignore
|
||||
|
||||
|
||||
def _ensure_dir(path: str) -> None:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
@@ -36,6 +48,103 @@ def _restrict_permissions(path: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _FileLock:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
self._handle = None
|
||||
|
||||
def acquire(self, *, timeout: float = 60.0, poll_interval: float = 0.2) -> None:
|
||||
directory = os.path.dirname(self.path)
|
||||
if directory:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
deadline = time.time() + timeout if timeout else None
|
||||
while True:
|
||||
handle = open(self.path, "a+b")
|
||||
try:
|
||||
self._try_lock(handle)
|
||||
except OSError as exc:
|
||||
handle.close()
|
||||
if not self._is_lock_unavailable(exc):
|
||||
raise
|
||||
if deadline and time.time() >= deadline:
|
||||
raise TimeoutError("Timed out waiting for file lock")
|
||||
time.sleep(poll_interval)
|
||||
continue
|
||||
except Exception:
|
||||
handle.close()
|
||||
raise
|
||||
|
||||
self._handle = handle
|
||||
try:
|
||||
handle.seek(0)
|
||||
handle.truncate(0)
|
||||
payload = f"pid={os.getpid()} ts={int(time.time())}\n".encode("utf-8")
|
||||
handle.write(payload)
|
||||
handle.flush()
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
def release(self) -> None:
|
||||
handle = self._handle
|
||||
if not handle:
|
||||
return
|
||||
try:
|
||||
self._unlock(handle)
|
||||
finally:
|
||||
try:
|
||||
handle.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._handle = None
|
||||
|
||||
def _try_lock(self, handle):
|
||||
if IS_WINDOWS:
|
||||
if msvcrt is None:
|
||||
raise OSError(errno.EINVAL, "msvcrt unavailable for locking")
|
||||
try:
|
||||
msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1) # type: ignore[attr-defined]
|
||||
except OSError as exc: # pragma: no cover - platform specific
|
||||
raise exc
|
||||
else:
|
||||
if fcntl is None:
|
||||
raise OSError(errno.EINVAL, "fcntl unavailable for locking")
|
||||
fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) # type: ignore[attr-defined]
|
||||
|
||||
def _unlock(self, handle):
|
||||
if IS_WINDOWS:
|
||||
if msvcrt is None:
|
||||
return
|
||||
try:
|
||||
msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1) # type: ignore[attr-defined]
|
||||
except OSError: # pragma: no cover - platform specific
|
||||
pass
|
||||
else:
|
||||
if fcntl is None:
|
||||
return
|
||||
try:
|
||||
fcntl.flock(handle.fileno(), fcntl.LOCK_UN) # type: ignore[attr-defined]
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _is_lock_unavailable(exc: OSError) -> bool:
|
||||
if hasattr(exc, "winerror"):
|
||||
return exc.winerror in (32, 33) # type: ignore[attr-defined]
|
||||
err = exc.errno if hasattr(exc, "errno") else None
|
||||
return err in (errno.EACCES, errno.EAGAIN, errno.EWOULDBLOCK)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _locked_file(path: str, *, timeout: float = 60.0):
|
||||
lock = _FileLock(path)
|
||||
lock.acquire(timeout=timeout)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
|
||||
def _protect(data: bytes, *, scope_system: bool) -> bytes:
|
||||
if not IS_WINDOWS or not win32crypt:
|
||||
return data
|
||||
@@ -120,17 +229,21 @@ class AgentKeyStore:
|
||||
self._token_meta_path = os.path.join(self.settings_dir, "access.meta.json")
|
||||
self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem")
|
||||
self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub")
|
||||
self._identity_lock_path = os.path.join(self.settings_dir, "identity.lock")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Identity management
|
||||
# ------------------------------------------------------------------
|
||||
def load_or_create_identity(self) -> AgentIdentity:
|
||||
if os.path.isfile(self._private_path) and os.path.isfile(self._public_path):
|
||||
try:
|
||||
return self._load_identity()
|
||||
except Exception:
|
||||
pass
|
||||
return self._create_identity()
|
||||
with _locked_file(self._identity_lock_path, timeout=120.0):
|
||||
if os.path.isfile(self._private_path) and os.path.isfile(self._public_path):
|
||||
try:
|
||||
return self._load_identity()
|
||||
except Exception:
|
||||
# If loading fails, fall back to regenerating the identity while
|
||||
# holding the lock so concurrent agents do not thrash the key files.
|
||||
pass
|
||||
return self._create_identity()
|
||||
|
||||
def _load_identity(self) -> AgentIdentity:
|
||||
with open(self._private_path, "rb") as fh:
|
||||
|
||||
Reference in New Issue
Block a user