diff --git a/Data/Agent/agent.py b/Data/Agent/agent.py index 3a99f7b..4bd933f 100644 --- a/Data/Agent/agent.py +++ b/Data/Agent/agent.py @@ -20,7 +20,10 @@ import datetime import shutil import string import ssl -from typing import Any, Dict, Optional, List +import threading +import contextlib +import errno +from typing import Any, Dict, Optional, List, Callable import requests try: @@ -132,6 +135,133 @@ def _settings_dir(): return os.path.abspath(os.path.join(os.path.dirname(__file__), 'Settings')) +class _CrossProcessFileLock: + def __init__(self, path: str) -> None: + self.path = path + self._handle = None + + def acquire( + self, + *, + timeout: float = 120.0, + poll_interval: float = 0.5, + on_wait: Optional[Callable[[], None]] = None, + ) -> None: + directory = os.path.dirname(self.path) + if directory: + os.makedirs(directory, exist_ok=True) + deadline = time.time() + timeout if timeout else None + last_wait_log = 0.0 + while True: + handle = open(self.path, 'a+b') + try: + self._try_lock(handle) + self._handle = handle + try: + handle.seek(0) + handle.truncate(0) + handle.write(f"pid={os.getpid()} ts={int(time.time())}\n".encode('utf-8')) + handle.flush() + except Exception: + pass + return + except OSError as exc: + handle.close() + if not self._is_lock_unavailable(exc): + raise + now = time.time() + if on_wait and (now - last_wait_log) >= 2.0: + try: + on_wait() + except Exception: + pass + last_wait_log = now + if deadline and now >= deadline: + raise TimeoutError('Timed out waiting for enrollment lock') + time.sleep(poll_interval) + except Exception: + handle.close() + raise + + def release(self) -> None: + handle = self._handle + if not handle: + return + try: + self._unlock(handle) + finally: + try: + handle.close() + except Exception: + pass + self._handle = None + + @staticmethod + def _is_lock_unavailable(exc: OSError) -> bool: + err = exc.errno + winerr = getattr(exc, 'winerror', None) + unavailable = {errno.EACCES, errno.EAGAIN, getattr(errno, 'EWOULDBLOCK', errno.EAGAIN)} + if err in unavailable: + return True + if winerr in (32, 33): + return True + return False + + @staticmethod + def _try_lock(handle) -> None: + handle.seek(0, os.SEEK_END) + if handle.tell() == 0: + try: + handle.write(b'0') + handle.flush() + except Exception: + pass + handle.seek(0) + if os.name == 'nt': + import msvcrt # type: ignore + + try: + msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1) + except OSError: + raise + else: + import fcntl # type: ignore + + fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + + @staticmethod + def _unlock(handle) -> None: + if os.name == 'nt': + import msvcrt # type: ignore + + try: + msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1) + except OSError: + pass + else: + import fcntl # type: ignore + + try: + fcntl.flock(handle.fileno(), fcntl.LOCK_UN) + except OSError: + pass + + +_ENROLLMENT_FILE_LOCK: Optional[_CrossProcessFileLock] = None + + +@contextlib.contextmanager +def _acquire_enrollment_lock(*, timeout: float = 180.0, on_wait: Optional[Callable[[], None]] = None): + global _ENROLLMENT_FILE_LOCK + if _ENROLLMENT_FILE_LOCK is None: + _ENROLLMENT_FILE_LOCK = _CrossProcessFileLock(os.path.join(_settings_dir(), 'enrollment.lock')) + _ENROLLMENT_FILE_LOCK.acquire(timeout=timeout, on_wait=on_wait) + try: + yield + finally: + _ENROLLMENT_FILE_LOCK.release() + + _KEY_STORE_INSTANCE = None @@ -242,6 +372,18 @@ def _log_agent(message: str, fname: str = 'agent.log'): pass +def _mask_sensitive(value: str, *, prefix: int = 4, suffix: int = 4) -> str: + try: + if not value: + return '' + trimmed = value.strip() + if len(trimmed) <= prefix + suffix: + return '*' * len(trimmed) + return f"{trimmed[:prefix]}***{trimmed[-suffix:]}" + except Exception: + return '***' + + def _decode_base64_text(value): if not isinstance(value, str): return None @@ -490,14 +632,15 @@ class AgentHttpClient: self.identity = IDENTITY self.session = requests.Session() self.base_url: Optional[str] = None - self.guid: Optional[str] = self.key_store.load_guid() - self.access_token: Optional[str] = self.key_store.load_access_token() - self.refresh_token: Optional[str] = self.key_store.load_refresh_token() - self.access_expires_at: Optional[int] = self.key_store.get_access_expiry() + self.guid: Optional[str] = None + self.access_token: Optional[str] = None + self.refresh_token: Optional[str] = None + self.access_expires_at: Optional[int] = None + self._auth_lock = threading.RLock() + self._active_installer_code: Optional[str] = None self.refresh_base_url() self._configure_verify() - if self.access_token: - self.session.headers.update({"Authorization": f"Bearer {self.access_token}"}) + self._reload_tokens_from_disk() self.session.headers.setdefault("User-Agent", "Borealis-Agent/secure") # ------------------------------------------------------------------ @@ -528,6 +671,31 @@ class AgentHttpClient: except Exception: pass + def _reload_tokens_from_disk(self) -> None: + guid = self.key_store.load_guid() + access_token = self.key_store.load_access_token() + refresh_token = self.key_store.load_refresh_token() + access_expiry = self.key_store.get_access_expiry() + self.guid = guid if guid else None + self.access_token = access_token if access_token else None + self.refresh_token = refresh_token if refresh_token else None + self.access_expires_at = access_expiry if access_expiry else None + if self.access_token: + self.session.headers.update({"Authorization": f"Bearer {self.access_token}"}) + else: + self.session.headers.pop("Authorization", None) + try: + _log_agent( + "Reloaded tokens from disk " + f"guid={'yes' if self.guid else 'no'} " + f"access={'yes' if self.access_token else 'no'} " + f"refresh={'yes' if self.refresh_token else 'no'} " + f"expiry={self.access_expires_at}", + fname="agent.log", + ) + except Exception: + pass + def auth_headers(self) -> Dict[str, str]: if self.access_token: return {"Authorization": f"Bearer {self.access_token}"} @@ -540,12 +708,28 @@ class AgentHttpClient: engine = getattr(client, "eio", None) if engine is None: return - # python-engineio accepts bool, path, or ssl.SSLContext for ssl_verify + + # python-engineio accepts either a boolean or an ``ssl.SSLContext`` + # for TLS verification. When we have a pinned certificate bundle + # on disk, prefer constructing a dedicated context that trusts that + # bundle so WebSocket connections succeed even with private CAs. if isinstance(verify, str) and os.path.isfile(verify): - engine.ssl_verify = verify + try: + context = ssl.create_default_context(cafile=verify) + context.check_hostname = False + except Exception: + context = None + if context is not None: + engine.ssl_context = context + engine.ssl_verify = True + else: + engine.ssl_context = None + engine.ssl_verify = verify elif verify is False: + engine.ssl_context = None engine.ssl_verify = False else: + engine.ssl_context = None engine.ssl_verify = True except Exception: pass @@ -554,11 +738,15 @@ class AgentHttpClient: # Enrollment & token management # ------------------------------------------------------------------ def ensure_authenticated(self) -> None: + with self._auth_lock: + self._ensure_authenticated_locked() + + def _ensure_authenticated_locked(self) -> None: self.refresh_base_url() if not self.guid or not self.refresh_token: - self.perform_enrollment() + self._perform_enrollment_locked() if not self.access_token or self._token_expiring_soon(): - self.refresh_access_token() + self._refresh_access_token_locked() def _token_expiring_soon(self) -> bool: if not self.access_token: @@ -568,69 +756,199 @@ class AgentHttpClient: return (self.access_expires_at - time.time()) < 60 def perform_enrollment(self) -> None: + with self._auth_lock: + self._perform_enrollment_locked() + + def _perform_enrollment_locked(self) -> None: + self._reload_tokens_from_disk() + if self.guid and self.refresh_token: + return code = self._resolve_installer_code() if not code: raise RuntimeError( "Installer code is required for enrollment. " "Set BOREALIS_INSTALLER_CODE, pass --installer-code, or update agent_settings.json." ) - self.refresh_base_url() - client_nonce = os.urandom(32) - payload = { - "hostname": socket.gethostname(), - "enrollment_code": code, - "agent_pubkey": PUBLIC_KEY_B64, - "client_nonce": base64.b64encode(client_nonce).decode("ascii"), - } - request_url = f"{self.base_url}/api/agent/enroll/request" - _log_agent("Starting enrollment request...", fname="agent.log") - resp = self.session.post(request_url, json=payload, timeout=30) - resp.raise_for_status() - data = resp.json() - if data.get("server_certificate"): - self.key_store.save_server_certificate(data["server_certificate"]) - self._configure_verify() - signing_key = data.get("signing_key") - if signing_key: - try: - self.store_server_signing_key(signing_key) - except Exception as exc: - _log_agent(f'Unable to persist signing key from enrollment handshake: {exc}', fname='agent.error.log') - if data.get("status") != "pending": - raise RuntimeError(f"Unexpected enrollment status: {data}") - approval_reference = data.get("approval_reference") - server_nonce_b64 = data.get("server_nonce") - if not approval_reference or not server_nonce_b64: - raise RuntimeError("Enrollment response missing approval_reference or server_nonce") - server_nonce = base64.b64decode(server_nonce_b64) - poll_delay = max(int(data.get("poll_after_ms", 3000)) / 1000, 1) - while True: - time.sleep(min(poll_delay, 15)) - signature = self.identity.sign(server_nonce + approval_reference.encode("utf-8") + client_nonce) - poll_payload = { - "approval_reference": approval_reference, - "client_nonce": base64.b64encode(client_nonce).decode("ascii"), - "proof_sig": base64.b64encode(signature).decode("ascii"), - } - poll_resp = self.session.post( - f"{self.base_url}/api/agent/enroll/poll", - json=poll_payload, - timeout=30, + self._active_installer_code = code + + wait_state = {"count": 0, "tokens_seen": False} + + def _on_lock_wait() -> None: + wait_state["count"] += 1 + _log_agent( + f"Enrollment waiting for shared lock scope={SERVICE_MODE} attempt={wait_state['count']}", + fname="agent.log", ) - poll_resp.raise_for_status() - poll_data = poll_resp.json() - status = poll_data.get("status") - if status == "pending": - poll_delay = max(int(poll_data.get("poll_after_ms", 5000)) / 1000, 1) - continue - if status == "denied": - raise RuntimeError("Enrollment denied by operator") - if status in ("expired", "unknown"): - raise RuntimeError(f"Enrollment failed with status={status}") - if status in ("approved", "completed"): - self._finalize_enrollment(poll_data) - break - raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}") + if not wait_state["tokens_seen"]: + self._reload_tokens_from_disk() + if self.guid and self.refresh_token: + wait_state["tokens_seen"] = True + _log_agent( + "Enrollment credentials detected while waiting for lock; will reuse when available", + fname="agent.log", + ) + + try: + with _acquire_enrollment_lock(timeout=180.0, on_wait=_on_lock_wait): + self._reload_tokens_from_disk() + if self.guid and self.refresh_token: + _log_agent( + "Enrollment skipped after acquiring lock; credentials already present", + fname="agent.log", + ) + return + + self.refresh_base_url() + base_url = self.base_url or "https://localhost:5000" + code_masked = _mask_sensitive(code) + _log_agent( + "Enrollment handshake starting " + f"base_url={base_url} scope={SERVICE_MODE} " + f"fingerprint={SSL_KEY_FINGERPRINT[:16]} installer_code={code_masked}", + fname="agent.log", + ) + client_nonce = os.urandom(32) + payload = { + "hostname": socket.gethostname(), + "enrollment_code": code, + "agent_pubkey": PUBLIC_KEY_B64, + "client_nonce": base64.b64encode(client_nonce).decode("ascii"), + } + request_url = f"{self.base_url}/api/agent/enroll/request" + _log_agent( + "Starting enrollment request... " + f"url={request_url} hostname={payload['hostname']} pubkey_prefix={PUBLIC_KEY_B64[:24]}", + fname="agent.log", + ) + resp = self.session.post(request_url, json=payload, timeout=30) + _log_agent( + f"Enrollment request HTTP status={resp.status_code} retry_after={resp.headers.get('Retry-After')}" + f" body_len={len(resp.content)}", + fname="agent.log", + ) + try: + resp.raise_for_status() + except requests.HTTPError: + snippet = resp.text[:512] if hasattr(resp, "text") else "" + _log_agent( + f"Enrollment request failed status={resp.status_code} body_snippet={snippet}", + fname="agent.error.log", + ) + if resp.status_code == 400: + try: + err_payload = resp.json() + except Exception: + err_payload = {} + if (err_payload or {}).get("error") in {"invalid_enrollment_code", "code_consumed"}: + self._reload_tokens_from_disk() + if self.guid and self.refresh_token: + _log_agent( + "Enrollment code rejected but existing credentials are present; skipping re-enrollment", + fname="agent.log", + ) + return + raise + data = resp.json() + _log_agent( + "Enrollment request accepted " + f"status={data.get('status')} approval_ref={data.get('approval_reference')} " + f"poll_after_ms={data.get('poll_after_ms')}" + f" server_cert={'yes' if data.get('server_certificate') else 'no'}", + fname="agent.log", + ) + if data.get("server_certificate"): + self.key_store.save_server_certificate(data["server_certificate"]) + self._configure_verify() + signing_key = data.get("signing_key") + if signing_key: + try: + self.store_server_signing_key(signing_key) + except Exception as exc: + _log_agent(f'Unable to persist signing key from enrollment handshake: {exc}', fname='agent.error.log') + if data.get("status") != "pending": + raise RuntimeError(f"Unexpected enrollment status: {data}") + approval_reference = data.get("approval_reference") + server_nonce_b64 = data.get("server_nonce") + if not approval_reference or not server_nonce_b64: + raise RuntimeError("Enrollment response missing approval_reference or server_nonce") + server_nonce = base64.b64decode(server_nonce_b64) + poll_delay = max(int(data.get("poll_after_ms", 3000)) / 1000, 1) + attempt = 1 + while True: + time.sleep(min(poll_delay, 15)) + signature = self.identity.sign(server_nonce + approval_reference.encode("utf-8") + client_nonce) + poll_payload = { + "approval_reference": approval_reference, + "client_nonce": base64.b64encode(client_nonce).decode("ascii"), + "proof_sig": base64.b64encode(signature).decode("ascii"), + } + _log_agent( + f"Enrollment poll attempt={attempt} ref={approval_reference} delay={poll_delay}s", + fname="agent.log", + ) + poll_resp = self.session.post( + f"{self.base_url}/api/agent/enroll/poll", + json=poll_payload, + timeout=30, + ) + _log_agent( + "Enrollment poll response " + f"status_code={poll_resp.status_code} retry_after={poll_resp.headers.get('Retry-After')}" + f" body_len={len(poll_resp.content)}", + fname="agent.log", + ) + try: + poll_resp.raise_for_status() + except requests.HTTPError: + snippet = poll_resp.text[:512] if hasattr(poll_resp, "text") else "" + _log_agent( + f"Enrollment poll failed attempt={attempt} status={poll_resp.status_code} " + f"body_snippet={snippet}", + fname="agent.error.log", + ) + raise + poll_data = poll_resp.json() + _log_agent( + f"Enrollment poll decoded attempt={attempt} status={poll_data.get('status')}" + f" next_delay={poll_data.get('poll_after_ms')}" + f" guid_hint={poll_data.get('guid')}", + fname="agent.log", + ) + status = poll_data.get("status") + if status == "pending": + poll_delay = max(int(poll_data.get("poll_after_ms", 5000)) / 1000, 1) + _log_agent( + f"Enrollment still pending attempt={attempt} new_delay={poll_delay}s", + fname="agent.log", + ) + attempt += 1 + continue + if status == "denied": + _log_agent("Enrollment denied by operator", fname="agent.error.log") + raise RuntimeError("Enrollment denied by operator") + if status in ("expired", "unknown"): + _log_agent( + f"Enrollment failed status={status} attempt={attempt}", + fname="agent.error.log", + ) + raise RuntimeError(f"Enrollment failed with status={status}") + if status in ("approved", "completed"): + _log_agent( + f"Enrollment approved attempt={attempt} ref={approval_reference}", + fname="agent.log", + ) + self._finalize_enrollment(poll_data) + break + raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}") + except TimeoutError: + self._reload_tokens_from_disk() + if self.guid and self.refresh_token: + _log_agent( + "Enrollment lock wait timed out but credentials materialized; reusing existing tokens", + fname="agent.log", + ) + return + raise def _finalize_enrollment(self, payload: Dict[str, Any]) -> None: server_cert = payload.get("server_certificate") @@ -649,6 +967,12 @@ class AgentHttpClient: expires_in = int(payload.get("expires_in") or 900) if not (guid and access_token and refresh_token): raise RuntimeError("Enrollment approval response missing tokens or guid") + _log_agent( + "Enrollment approval payload received " + f"guid={guid} access_token_len={len(access_token)} refresh_token_len={len(refresh_token)} " + f"expires_in={expires_in}", + fname="agent.log", + ) self.guid = str(guid).strip() self.access_token = access_token.strip() self.refresh_token = refresh_token.strip() @@ -663,12 +987,17 @@ class AgentHttpClient: _update_agent_id_for_guid(self.guid) except Exception as exc: _log_agent(f"Failed to update agent id after enrollment: {exc}", fname="agent.error.log") + self._consume_installer_code() _log_agent(f"Enrollment finalized for guid={self.guid}", fname="agent.log") def refresh_access_token(self) -> None: + with self._auth_lock: + self._refresh_access_token_locked() + + def _refresh_access_token_locked(self) -> None: if not self.refresh_token or not self.guid: - self.clear_tokens() - self.perform_enrollment() + self._clear_tokens_locked() + self._perform_enrollment_locked() return payload = {"guid": self.guid, "refresh_token": self.refresh_token} resp = self.session.post( @@ -679,8 +1008,8 @@ class AgentHttpClient: ) if resp.status_code in (401, 403): _log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log") - self.clear_tokens() - self.perform_enrollment() + self._clear_tokens_locked() + self._perform_enrollment_locked() return resp.raise_for_status() data = resp.json() @@ -696,6 +1025,10 @@ class AgentHttpClient: self.session.headers.update({"Authorization": f"Bearer {self.access_token}"}) def clear_tokens(self) -> None: + with self._auth_lock: + self._clear_tokens_locked() + + def _clear_tokens_locked(self) -> None: self.key_store.clear_tokens() self.access_token = None self.refresh_token = None @@ -712,6 +1045,19 @@ class AgentHttpClient: except Exception: return "" + def _consume_installer_code(self) -> None: + # Avoid clearing explicit CLI/env overrides; only mutate persisted config. + self._active_installer_code = None + if INSTALLER_CODE_OVERRIDE: + return + try: + if CONFIG.data.get("installer_code"): + CONFIG.data["installer_code"] = "" + CONFIG._write() + _log_agent("Cleared persisted installer code after successful enrollment", fname="agent.log") + except Exception as exc: + _log_agent(f"Failed to clear installer code after enrollment: {exc}", fname="agent.error.log") + # ------------------------------------------------------------------ # HTTP helpers # ------------------------------------------------------------------ @@ -722,6 +1068,16 @@ class AgentHttpClient: headers = self.auth_headers() response = self.session.post(url, json=payload, headers=headers, timeout=30) if response.status_code in (401, 403) and require_auth: + snippet = "" + try: + snippet = response.text[:256] + except Exception: + snippet = "" + _log_agent( + "Authenticated request rejected " + f"path={path} status={response.status_code} body_snippet={snippet}", + fname="agent.error.log", + ) self.clear_tokens() self.ensure_authenticated() headers = self.auth_headers() diff --git a/Data/Agent/security.py b/Data/Agent/security.py index 475d1ee..443a0dd 100644 --- a/Data/Agent/security.py +++ b/Data/Agent/security.py @@ -3,6 +3,8 @@ from __future__ import annotations import base64 +import contextlib +import errno import hashlib import json import os @@ -23,6 +25,16 @@ try: except Exception: # pragma: no cover - win32crypt missing win32crypt = None # type: ignore +try: # pragma: no cover - Windows only + import msvcrt # type: ignore +except Exception: # pragma: no cover - non-Windows + msvcrt = None # type: ignore + +try: # pragma: no cover - POSIX only + import fcntl # type: ignore +except Exception: # pragma: no cover - Windows + fcntl = None # type: ignore + def _ensure_dir(path: str) -> None: os.makedirs(path, exist_ok=True) @@ -36,43 +48,155 @@ def _restrict_permissions(path: str) -> None: pass +class _FileLock: + def __init__(self, path: str) -> None: + self.path = path + self._handle = None + + def acquire(self, *, timeout: float = 60.0, poll_interval: float = 0.2) -> None: + directory = os.path.dirname(self.path) + if directory: + os.makedirs(directory, exist_ok=True) + deadline = time.time() + timeout if timeout else None + while True: + handle = open(self.path, "a+b") + try: + self._try_lock(handle) + except OSError as exc: + handle.close() + if not self._is_lock_unavailable(exc): + raise + if deadline and time.time() >= deadline: + raise TimeoutError("Timed out waiting for file lock") + time.sleep(poll_interval) + continue + except Exception: + handle.close() + raise + + self._handle = handle + try: + handle.seek(0) + handle.truncate(0) + payload = f"pid={os.getpid()} ts={int(time.time())}\n".encode("utf-8") + handle.write(payload) + handle.flush() + except Exception: + pass + return + + def release(self) -> None: + handle = self._handle + if not handle: + return + try: + self._unlock(handle) + finally: + try: + handle.close() + except Exception: + pass + self._handle = None + + def _try_lock(self, handle): + if IS_WINDOWS: + if msvcrt is None: + raise OSError(errno.EINVAL, "msvcrt unavailable for locking") + try: + msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1) # type: ignore[attr-defined] + except OSError as exc: # pragma: no cover - platform specific + raise exc + else: + if fcntl is None: + raise OSError(errno.EINVAL, "fcntl unavailable for locking") + fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) # type: ignore[attr-defined] + + def _unlock(self, handle): + if IS_WINDOWS: + if msvcrt is None: + return + try: + msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1) # type: ignore[attr-defined] + except OSError: # pragma: no cover - platform specific + pass + else: + if fcntl is None: + return + try: + fcntl.flock(handle.fileno(), fcntl.LOCK_UN) # type: ignore[attr-defined] + except OSError: + pass + + @staticmethod + def _is_lock_unavailable(exc: OSError) -> bool: + if hasattr(exc, "winerror"): + return exc.winerror in (32, 33) # type: ignore[attr-defined] + err = exc.errno if hasattr(exc, "errno") else None + return err in (errno.EACCES, errno.EAGAIN, errno.EWOULDBLOCK) + + +@contextlib.contextmanager +def _locked_file(path: str, *, timeout: float = 60.0): + lock = _FileLock(path) + lock.acquire(timeout=timeout) + try: + yield + finally: + lock.release() + + def _protect(data: bytes, *, scope_system: bool) -> bytes: if not IS_WINDOWS or not win32crypt: return data - flags = 0 + scopes = [scope_system] + # Always include the alternate scope so we can fall back if the preferred + # protection attempt fails (e.g., running under a limited account that + # lacks access to the desired DPAPI scope). if scope_system: - flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4) - try: - protected = win32crypt.CryptProtectData(data, None, None, None, None, flags) # type: ignore[attr-defined] - except Exception: - return data - blob = protected[1] - if isinstance(blob, memoryview): - return blob.tobytes() - if isinstance(blob, bytearray): - return bytes(blob) - if isinstance(blob, bytes): - return blob + scopes.append(False) + else: + scopes.append(True) + for scope in scopes: + flags = 0 + if scope: + flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4) + try: + protected = win32crypt.CryptProtectData(data, None, None, None, None, flags) # type: ignore[attr-defined] + except Exception: + continue + blob = protected[1] + if isinstance(blob, memoryview): + return blob.tobytes() + if isinstance(blob, bytearray): + return bytes(blob) + if isinstance(blob, bytes): + return blob return data def _unprotect(data: bytes, *, scope_system: bool) -> bytes: if not IS_WINDOWS or not win32crypt: return data - flags = 0 + scopes = [scope_system] if scope_system: - flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4) - try: - unwrapped = win32crypt.CryptUnprotectData(data, None, None, None, None, flags) # type: ignore[attr-defined] - except Exception: - return data - blob = unwrapped[1] - if isinstance(blob, memoryview): - return blob.tobytes() - if isinstance(blob, bytearray): - return bytes(blob) - if isinstance(blob, bytes): - return blob + scopes.append(False) + else: + scopes.append(True) + for scope in scopes: + flags = 0 + if scope: + flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4) + try: + unwrapped = win32crypt.CryptUnprotectData(data, None, None, None, None, flags) # type: ignore[attr-defined] + except Exception: + continue + blob = unwrapped[1] + if isinstance(blob, memoryview): + return blob.tobytes() + if isinstance(blob, bytearray): + return bytes(blob) + if isinstance(blob, bytes): + return blob return data @@ -105,17 +229,21 @@ class AgentKeyStore: self._token_meta_path = os.path.join(self.settings_dir, "access.meta.json") self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem") self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub") + self._identity_lock_path = os.path.join(self.settings_dir, "identity.lock") # ------------------------------------------------------------------ # Identity management # ------------------------------------------------------------------ def load_or_create_identity(self) -> AgentIdentity: - if os.path.isfile(self._private_path) and os.path.isfile(self._public_path): - try: - return self._load_identity() - except Exception: - pass - return self._create_identity() + with _locked_file(self._identity_lock_path, timeout=120.0): + if os.path.isfile(self._private_path) and os.path.isfile(self._public_path): + try: + return self._load_identity() + except Exception: + # If loading fails, fall back to regenerating the identity while + # holding the lock so concurrent agents do not thrash the key files. + pass + return self._create_identity() def _load_identity(self) -> AgentIdentity: with open(self._private_path, "rb") as fh: @@ -212,11 +340,23 @@ class AgentKeyStore: try: with open(self._refresh_token_path, "rb") as fh: protected = fh.read() - raw = _unprotect(protected, scope_system=self.scope_system) - return raw.decode("utf-8") except Exception: return None + # Try both scopes (preferred first) and decode once a UTF-8 payload is recovered. + for scope_first in (self.scope_system, not self.scope_system): + try: + candidate = _unprotect(protected, scope_system=scope_first) + except Exception: + continue + if not candidate: + continue + try: + return candidate.decode("utf-8") + except Exception: + continue + return None + def clear_tokens(self) -> None: for path in (self._access_token_path, self._refresh_token_path, self._token_meta_path): try: diff --git a/Data/Server/Modules/agents/routes.py b/Data/Server/Modules/agents/routes.py index e312d6c..0b96bb2 100644 --- a/Data/Server/Modules/agents/routes.py +++ b/Data/Server/Modules/agents/routes.py @@ -2,6 +2,7 @@ from __future__ import annotations import json import time +import sqlite3 from typing import Any, Callable, Dict, Optional from flask import Blueprint, jsonify, request, g @@ -28,10 +29,18 @@ def register( except Exception: return None + def _auth_context(): + ctx = getattr(g, "device_auth", None) + if ctx is None: + log("server", f"device auth context missing for {request.path}") + return ctx + @blueprint.route("/api/agent/heartbeat", methods=["POST"]) @require_device_auth(auth_manager) def heartbeat(): - ctx = getattr(g, "device_auth") + ctx = _auth_context() + if ctx is None: + return jsonify({"error": "auth_context_missing"}), 500 payload = request.get_json(force=True, silent=True) or {} now_ts = int(time.time()) @@ -71,14 +80,42 @@ def register( conn = db_conn_factory() try: cur = conn.cursor() - columns = ", ".join(f"{col} = ?" for col in updates.keys()) - params = list(updates.values()) - params.append(ctx.guid) - cur.execute( - f"UPDATE devices SET {columns} WHERE guid = ?", - params, - ) - if cur.rowcount == 0: + + def _apply_updates() -> int: + if not updates: + return 0 + columns = ", ".join(f"{col} = ?" for col in updates.keys()) + params = list(updates.values()) + params.append(ctx.guid) + cur.execute( + f"UPDATE devices SET {columns} WHERE guid = ?", + params, + ) + return cur.rowcount + + try: + rowcount = _apply_updates() + except sqlite3.IntegrityError as exc: + if "devices.hostname" in str(exc) and "UNIQUE" in str(exc).upper(): + # Another device already claims this hostname; keep the existing + # canonical hostname assigned during enrollment to avoid breaking + # the unique constraint and continue updating the remaining fields. + if "hostname" in updates: + updates.pop("hostname", None) + try: + rowcount = _apply_updates() + except sqlite3.IntegrityError: + raise + else: + log( + "server", + "heartbeat hostname collision ignored for guid=" + f"{ctx.guid}", + ) + else: + raise + + if rowcount == 0: log("server", f"heartbeat missing device record guid={ctx.guid}") return jsonify({"error": "device_not_registered"}), 404 conn.commit() @@ -90,7 +127,9 @@ def register( @blueprint.route("/api/agent/script/request", methods=["POST"]) @require_device_auth(auth_manager) def script_request(): - ctx = getattr(g, "device_auth") + ctx = _auth_context() + if ctx is None: + return jsonify({"error": "auth_context_missing"}), 500 if ctx.status != "active": return jsonify( { diff --git a/Data/Server/Modules/enrollment/routes.py b/Data/Server/Modules/enrollment/routes.py index 7a883af..c408bcd 100644 --- a/Data/Server/Modules/enrollment/routes.py +++ b/Data/Server/Modules/enrollment/routes.py @@ -54,6 +54,10 @@ def register( def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float): decision = limiter.check(key, limit, window_s) if not decision.allowed: + log( + "server", + f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}", + ) response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after}) response.status_code = 429 response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}" @@ -128,19 +132,63 @@ def register( def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]: cur.execute( - "SELECT guid, hostname, token_version, status, ssl_key_fingerprint FROM devices WHERE guid = ?", + """ + SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at + FROM devices + WHERE guid = ? + """, (guid,), ) row = cur.fetchone() if row: - keys = ["guid", "hostname", "token_version", "status", "ssl_key_fingerprint"] + keys = [ + "guid", + "hostname", + "token_version", + "status", + "ssl_key_fingerprint", + "key_added_at", + ] record = dict(zip(keys, row)) - if not record.get("ssl_key_fingerprint"): + stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower() + new_fp = (fingerprint or "").strip().lower() + if not stored_fp and new_fp: cur.execute( "UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?", (fingerprint, _iso(_now()), guid), ) record["ssl_key_fingerprint"] = fingerprint + elif new_fp and stored_fp != new_fp: + now_iso = _iso(_now()) + try: + current_version = int(record.get("token_version") or 1) + except Exception: + current_version = 1 + new_version = max(current_version + 1, 1) + cur.execute( + """ + UPDATE devices + SET ssl_key_fingerprint = ?, + key_added_at = ?, + token_version = ?, + status = 'active' + WHERE guid = ? + """, + (fingerprint, now_iso, new_version, guid), + ) + cur.execute( + """ + UPDATE refresh_tokens + SET revoked_at = ? + WHERE guid = ? + AND revoked_at IS NULL + """, + (now_iso, guid), + ) + record["ssl_key_fingerprint"] = fingerprint + record["token_version"] = new_version + record["status"] = "active" + record["key_added_at"] = now_iso return record resolved_hostname = _normalize_host(hostname, guid, cur) @@ -169,6 +217,7 @@ def register( "token_version": 1, "status": "active", "ssl_key_fingerprint": fingerprint, + "key_added_at": key_added_at, } def _hash_refresh_token(token: str) -> str: @@ -198,7 +247,7 @@ def register( @blueprint.route("/api/agent/enroll/request", methods=["POST"]) def enrollment_request(): remote = _remote_addr() - rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 10, 60.0) + rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0) if rate_error: return rate_error @@ -208,32 +257,47 @@ def register( agent_pubkey_b64 = payload.get("agent_pubkey") client_nonce_b64 = payload.get("client_nonce") + log( + "server", + "enrollment request received " + f"ip={remote} hostname={hostname or ''} code_mask={_mask_code(enrollment_code)} " + f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}", + ) + if not hostname: + log("server", f"enrollment rejected missing_hostname ip={remote}") return jsonify({"error": "hostname_required"}), 400 if not enrollment_code: + log("server", f"enrollment rejected missing_code ip={remote} host={hostname}") return jsonify({"error": "enrollment_code_required"}), 400 if not isinstance(agent_pubkey_b64, str): + log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}") return jsonify({"error": "agent_pubkey_required"}), 400 if not isinstance(client_nonce_b64, str): + log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}") return jsonify({"error": "client_nonce_required"}), 400 try: agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64) except Exception: + log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}") return jsonify({"error": "invalid_agent_pubkey"}), 400 if len(agent_pubkey_der) < 10: + log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}") return jsonify({"error": "invalid_agent_pubkey"}), 400 try: client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True) except Exception: + log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}") return jsonify({"error": "invalid_client_nonce"}), 400 if len(client_nonce_bytes) < 16: + log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}") return jsonify({"error": "invalid_client_nonce"}), 400 fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der) - rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 3, 60.0) + rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0) if rate_error: return rate_error @@ -333,21 +397,33 @@ def register( client_nonce_b64 = payload.get("client_nonce") proof_sig_b64 = payload.get("proof_sig") + log( + "server", + "enrollment poll received " + f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}" + f" proof_sig_len={len(proof_sig_b64 or '')}", + ) + if not isinstance(approval_reference, str) or not approval_reference: + log("server", "enrollment poll rejected missing_reference") return jsonify({"error": "approval_reference_required"}), 400 if not isinstance(client_nonce_b64, str): + log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}") return jsonify({"error": "client_nonce_required"}), 400 if not isinstance(proof_sig_b64, str): + log("server", f"enrollment poll rejected missing_sig ref={approval_reference}") return jsonify({"error": "proof_sig_required"}), 400 try: client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True) except Exception: + log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}") return jsonify({"error": "invalid_client_nonce"}), 400 try: proof_sig = base64.b64decode(proof_sig_b64, validate=True) except Exception: + log("server", f"enrollment poll invalid_sig ref={approval_reference}") return jsonify({"error": "invalid_proof_sig"}), 400 conn = db_conn_factory() @@ -365,6 +441,7 @@ def register( ) row = cur.fetchone() if not row: + log("server", f"enrollment poll unknown_reference ref={approval_reference}") return jsonify({"status": "unknown"}), 404 ( @@ -383,11 +460,13 @@ def register( ) = row if client_nonce_stored != client_nonce_b64: + log("server", f"enrollment poll nonce_mismatch ref={approval_reference}") return jsonify({"error": "nonce_mismatch"}), 400 try: server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True) except Exception: + log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}") return jsonify({"error": "server_nonce_invalid"}), 400 message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes @@ -395,30 +474,58 @@ def register( try: public_key = serialization.load_der_public_key(agent_pubkey_der) except Exception: + log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}") public_key = None if public_key is None: + log("server", f"enrollment poll invalid_pubkey ref={approval_reference}") return jsonify({"error": "agent_pubkey_invalid"}), 400 try: public_key.verify(proof_sig, message) except Exception: + log("server", f"enrollment poll invalid_proof ref={approval_reference}") return jsonify({"error": "invalid_proof"}), 400 if status == "pending": + log( + "server", + f"enrollment poll pending ref={approval_reference} host={hostname_claimed}" + f" fingerprint={fingerprint[:12]}", + ) return jsonify({"status": "pending", "poll_after_ms": 5000}) if status == "denied": + log( + "server", + f"enrollment poll denied ref={approval_reference} host={hostname_claimed}", + ) return jsonify({"status": "denied", "reason": "operator_denied"}) if status == "expired": + log( + "server", + f"enrollment poll expired ref={approval_reference} host={hostname_claimed}", + ) return jsonify({"status": "expired"}) if status == "completed": + log( + "server", + f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}", + ) return jsonify({"status": "approved", "detail": "finalized"}) if status != "approved": + log( + "server", + f"enrollment poll unexpected_status={status} ref={approval_reference}", + ) return jsonify({"status": status or "unknown"}), 400 nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}" if not nonce_cache.consume(nonce_key): + log( + "server", + f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}", + ) return jsonify({"error": "proof_replayed"}), 409 # Finalize enrollment @@ -489,3 +596,12 @@ def _load_tls_bundle(path: str) -> str: return fh.read() except Exception: return "" + + +def _mask_code(code: str) -> str: + if not code: + return "" + trimmed = str(code).strip() + if len(trimmed) <= 6: + return "***" + return f"{trimmed[:3]}***{trimmed[-3:]}"