Guard identity creation with cross-process lock

This commit is contained in:
2025-10-18 02:24:19 -06:00
parent cf82474e07
commit f4902cf5b8
2 changed files with 129 additions and 6 deletions

View File

@@ -3,6 +3,8 @@
from __future__ import annotations
import base64
import contextlib
import errno
import hashlib
import json
import os
@@ -23,6 +25,16 @@ try:
except Exception: # pragma: no cover - win32crypt missing
win32crypt = None # type: ignore
try: # pragma: no cover - Windows only
import msvcrt # type: ignore
except Exception: # pragma: no cover - non-Windows
msvcrt = None # type: ignore
try: # pragma: no cover - POSIX only
import fcntl # type: ignore
except Exception: # pragma: no cover - Windows
fcntl = None # type: ignore
def _ensure_dir(path: str) -> None:
os.makedirs(path, exist_ok=True)
@@ -36,6 +48,103 @@ def _restrict_permissions(path: str) -> None:
pass
class _FileLock:
def __init__(self, path: str) -> None:
self.path = path
self._handle = None
def acquire(self, *, timeout: float = 60.0, poll_interval: float = 0.2) -> None:
directory = os.path.dirname(self.path)
if directory:
os.makedirs(directory, exist_ok=True)
deadline = time.time() + timeout if timeout else None
while True:
handle = open(self.path, "a+b")
try:
self._try_lock(handle)
except OSError as exc:
handle.close()
if not self._is_lock_unavailable(exc):
raise
if deadline and time.time() >= deadline:
raise TimeoutError("Timed out waiting for file lock")
time.sleep(poll_interval)
continue
except Exception:
handle.close()
raise
self._handle = handle
try:
handle.seek(0)
handle.truncate(0)
payload = f"pid={os.getpid()} ts={int(time.time())}\n".encode("utf-8")
handle.write(payload)
handle.flush()
except Exception:
pass
return
def release(self) -> None:
handle = self._handle
if not handle:
return
try:
self._unlock(handle)
finally:
try:
handle.close()
except Exception:
pass
self._handle = None
def _try_lock(self, handle):
if IS_WINDOWS:
if msvcrt is None:
raise OSError(errno.EINVAL, "msvcrt unavailable for locking")
try:
msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1) # type: ignore[attr-defined]
except OSError as exc: # pragma: no cover - platform specific
raise exc
else:
if fcntl is None:
raise OSError(errno.EINVAL, "fcntl unavailable for locking")
fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) # type: ignore[attr-defined]
def _unlock(self, handle):
if IS_WINDOWS:
if msvcrt is None:
return
try:
msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1) # type: ignore[attr-defined]
except OSError: # pragma: no cover - platform specific
pass
else:
if fcntl is None:
return
try:
fcntl.flock(handle.fileno(), fcntl.LOCK_UN) # type: ignore[attr-defined]
except OSError:
pass
@staticmethod
def _is_lock_unavailable(exc: OSError) -> bool:
if hasattr(exc, "winerror"):
return exc.winerror in (32, 33) # type: ignore[attr-defined]
err = exc.errno if hasattr(exc, "errno") else None
return err in (errno.EACCES, errno.EAGAIN, errno.EWOULDBLOCK)
@contextlib.contextmanager
def _locked_file(path: str, *, timeout: float = 60.0):
lock = _FileLock(path)
lock.acquire(timeout=timeout)
try:
yield
finally:
lock.release()
def _protect(data: bytes, *, scope_system: bool) -> bytes:
if not IS_WINDOWS or not win32crypt:
return data
@@ -120,17 +229,21 @@ class AgentKeyStore:
self._token_meta_path = os.path.join(self.settings_dir, "access.meta.json")
self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem")
self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub")
self._identity_lock_path = os.path.join(self.settings_dir, "identity.lock")
# ------------------------------------------------------------------
# Identity management
# ------------------------------------------------------------------
def load_or_create_identity(self) -> AgentIdentity:
if os.path.isfile(self._private_path) and os.path.isfile(self._public_path):
try:
return self._load_identity()
except Exception:
pass
return self._create_identity()
with _locked_file(self._identity_lock_path, timeout=120.0):
if os.path.isfile(self._private_path) and os.path.isfile(self._public_path):
try:
return self._load_identity()
except Exception:
# If loading fails, fall back to regenerating the identity while
# holding the lock so concurrent agents do not thrash the key files.
pass
return self._create_identity()
def _load_identity(self) -> AgentIdentity:
with open(self._private_path, "rb") as fh: