Merge pull request #126 from bunny-lab-io:codex/review-and-resolve-enrollment-implementation-issues

Fix agent keystore bootstrap ordering and DPAPI fallback
This commit is contained in:
2025-10-17 22:01:32 -06:00
committed by GitHub
3 changed files with 136 additions and 41 deletions

View File

@@ -125,6 +125,24 @@ def _agent_guid_path() -> str:
return os.path.abspath(os.path.join(os.path.dirname(__file__), 'agent_GUID'))
def _settings_dir():
try:
return os.path.join(_find_project_root(), 'Agent', 'Borealis', 'Settings')
except Exception:
return os.path.abspath(os.path.join(os.path.dirname(__file__), 'Settings'))
_KEY_STORE_INSTANCE = None
def _key_store() -> AgentKeyStore:
global _KEY_STORE_INSTANCE
if _KEY_STORE_INSTANCE is None:
scope = 'SYSTEM' if SYSTEM_SERVICE_MODE else 'CURRENTUSER'
_KEY_STORE_INSTANCE = AgentKeyStore(_settings_dir(), scope=scope)
return _KEY_STORE_INSTANCE
def _persist_agent_guid_local(guid: str):
guid = _normalize_agent_guid(guid)
if not guid:
@@ -515,21 +533,22 @@ class AgentHttpClient:
return {"Authorization": f"Bearer {self.access_token}"}
return {}
def websocket_kwargs(self) -> Dict[str, Any]:
kwargs: Dict[str, Any] = {}
verify = getattr(self.session, "verify", True)
if isinstance(verify, str) and os.path.isfile(verify):
try:
ctx = ssl.create_default_context(cafile=verify)
kwargs["ssl"] = ctx
except Exception:
pass
elif verify is False:
try:
kwargs["ssl"] = ssl._create_unverified_context()
except Exception:
pass
return kwargs
def configure_socketio(self, client: "socketio.AsyncClient") -> None:
"""Align the Socket.IO engine's TLS verification with the REST client."""
try:
verify = getattr(self.session, "verify", True)
engine = getattr(client, "eio", None)
if engine is None:
return
# python-engineio accepts bool, path, or ssl.SSLContext for ssl_verify
if isinstance(verify, str) and os.path.isfile(verify):
engine.ssl_verify = verify
elif verify is False:
engine.ssl_verify = False
else:
engine.ssl_verify = True
except Exception:
pass
# ------------------------------------------------------------------
# Enrollment & token management
@@ -1028,24 +1047,6 @@ def _collect_heartbeat_metrics() -> Dict[str, Any]:
return metrics
def _settings_dir():
try:
return os.path.join(_find_project_root(), 'Agent', 'Borealis', 'Settings')
except Exception:
return os.path.abspath(os.path.join(os.path.dirname(__file__), 'Settings'))
_KEY_STORE_INSTANCE = None
def _key_store() -> AgentKeyStore:
global _KEY_STORE_INSTANCE
if _KEY_STORE_INSTANCE is None:
scope = 'SYSTEM' if SYSTEM_SERVICE_MODE else 'CURRENTUSER'
_KEY_STORE_INSTANCE = AgentKeyStore(_settings_dir(), scope=scope)
return _KEY_STORE_INSTANCE
SERVER_CERT_PATH = _key_store().server_certificate_path()
@@ -2036,6 +2037,7 @@ async def connect_loop():
while True:
try:
client.ensure_authenticated()
client.configure_socketio(sio)
url = client.websocket_base_url()
print(f"[INFO] Connecting Agent to {url}...")
_log_agent(f'Connecting to {url}...')
@@ -2043,7 +2045,6 @@ async def connect_loop():
url,
transports=['websocket'],
headers=client.auth_headers(),
ssl_verify=client.session.verify,
)
break
except Exception as e:

View File

@@ -39,17 +39,41 @@ def _restrict_permissions(path: str) -> None:
def _protect(data: bytes, *, scope_system: bool) -> bytes:
if not IS_WINDOWS or not win32crypt:
return data
flags = win32crypt.CRYPTPROTECT_LOCAL_MACHINE if scope_system else 0
protected = win32crypt.CryptProtectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
return protected[1]
flags = 0
if scope_system:
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
try:
protected = win32crypt.CryptProtectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
except Exception:
return data
blob = protected[1]
if isinstance(blob, memoryview):
return blob.tobytes()
if isinstance(blob, bytearray):
return bytes(blob)
if isinstance(blob, bytes):
return blob
return data
def _unprotect(data: bytes, *, scope_system: bool) -> bytes:
if not IS_WINDOWS or not win32crypt:
return data
flags = win32crypt.CRYPTPROTECT_LOCAL_MACHINE if scope_system else 0
unwrapped = win32crypt.CryptUnprotectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
return unwrapped[1]
flags = 0
if scope_system:
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
try:
unwrapped = win32crypt.CryptUnprotectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
except Exception:
return data
blob = unwrapped[1]
if isinstance(blob, memoryview):
return blob.tobytes()
if isinstance(blob, bytearray):
return bytes(blob)
if isinstance(blob, bytes):
return blob
return data
def _fingerprint_der(public_der: bytes) -> str: