Merge pull request #127 from bunny-lab-io:codex/implement-security-features-for-borealis

Harden agent WebSocket TLS verification
This commit is contained in:
2025-10-18 02:31:18 -06:00
committed by GitHub
4 changed files with 771 additions and 120 deletions

View File

@@ -20,7 +20,10 @@ import datetime
import shutil
import string
import ssl
from typing import Any, Dict, Optional, List
import threading
import contextlib
import errno
from typing import Any, Dict, Optional, List, Callable
import requests
try:
@@ -132,6 +135,133 @@ def _settings_dir():
return os.path.abspath(os.path.join(os.path.dirname(__file__), 'Settings'))
class _CrossProcessFileLock:
def __init__(self, path: str) -> None:
self.path = path
self._handle = None
def acquire(
self,
*,
timeout: float = 120.0,
poll_interval: float = 0.5,
on_wait: Optional[Callable[[], None]] = None,
) -> None:
directory = os.path.dirname(self.path)
if directory:
os.makedirs(directory, exist_ok=True)
deadline = time.time() + timeout if timeout else None
last_wait_log = 0.0
while True:
handle = open(self.path, 'a+b')
try:
self._try_lock(handle)
self._handle = handle
try:
handle.seek(0)
handle.truncate(0)
handle.write(f"pid={os.getpid()} ts={int(time.time())}\n".encode('utf-8'))
handle.flush()
except Exception:
pass
return
except OSError as exc:
handle.close()
if not self._is_lock_unavailable(exc):
raise
now = time.time()
if on_wait and (now - last_wait_log) >= 2.0:
try:
on_wait()
except Exception:
pass
last_wait_log = now
if deadline and now >= deadline:
raise TimeoutError('Timed out waiting for enrollment lock')
time.sleep(poll_interval)
except Exception:
handle.close()
raise
def release(self) -> None:
handle = self._handle
if not handle:
return
try:
self._unlock(handle)
finally:
try:
handle.close()
except Exception:
pass
self._handle = None
@staticmethod
def _is_lock_unavailable(exc: OSError) -> bool:
err = exc.errno
winerr = getattr(exc, 'winerror', None)
unavailable = {errno.EACCES, errno.EAGAIN, getattr(errno, 'EWOULDBLOCK', errno.EAGAIN)}
if err in unavailable:
return True
if winerr in (32, 33):
return True
return False
@staticmethod
def _try_lock(handle) -> None:
handle.seek(0, os.SEEK_END)
if handle.tell() == 0:
try:
handle.write(b'0')
handle.flush()
except Exception:
pass
handle.seek(0)
if os.name == 'nt':
import msvcrt # type: ignore
try:
msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1)
except OSError:
raise
else:
import fcntl # type: ignore
fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
@staticmethod
def _unlock(handle) -> None:
if os.name == 'nt':
import msvcrt # type: ignore
try:
msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1)
except OSError:
pass
else:
import fcntl # type: ignore
try:
fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
except OSError:
pass
_ENROLLMENT_FILE_LOCK: Optional[_CrossProcessFileLock] = None
@contextlib.contextmanager
def _acquire_enrollment_lock(*, timeout: float = 180.0, on_wait: Optional[Callable[[], None]] = None):
global _ENROLLMENT_FILE_LOCK
if _ENROLLMENT_FILE_LOCK is None:
_ENROLLMENT_FILE_LOCK = _CrossProcessFileLock(os.path.join(_settings_dir(), 'enrollment.lock'))
_ENROLLMENT_FILE_LOCK.acquire(timeout=timeout, on_wait=on_wait)
try:
yield
finally:
_ENROLLMENT_FILE_LOCK.release()
_KEY_STORE_INSTANCE = None
@@ -242,6 +372,18 @@ def _log_agent(message: str, fname: str = 'agent.log'):
pass
def _mask_sensitive(value: str, *, prefix: int = 4, suffix: int = 4) -> str:
try:
if not value:
return ''
trimmed = value.strip()
if len(trimmed) <= prefix + suffix:
return '*' * len(trimmed)
return f"{trimmed[:prefix]}***{trimmed[-suffix:]}"
except Exception:
return '***'
def _decode_base64_text(value):
if not isinstance(value, str):
return None
@@ -490,14 +632,15 @@ class AgentHttpClient:
self.identity = IDENTITY
self.session = requests.Session()
self.base_url: Optional[str] = None
self.guid: Optional[str] = self.key_store.load_guid()
self.access_token: Optional[str] = self.key_store.load_access_token()
self.refresh_token: Optional[str] = self.key_store.load_refresh_token()
self.access_expires_at: Optional[int] = self.key_store.get_access_expiry()
self.guid: Optional[str] = None
self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = None
self.access_expires_at: Optional[int] = None
self._auth_lock = threading.RLock()
self._active_installer_code: Optional[str] = None
self.refresh_base_url()
self._configure_verify()
if self.access_token:
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
self._reload_tokens_from_disk()
self.session.headers.setdefault("User-Agent", "Borealis-Agent/secure")
# ------------------------------------------------------------------
@@ -528,6 +671,31 @@ class AgentHttpClient:
except Exception:
pass
def _reload_tokens_from_disk(self) -> None:
guid = self.key_store.load_guid()
access_token = self.key_store.load_access_token()
refresh_token = self.key_store.load_refresh_token()
access_expiry = self.key_store.get_access_expiry()
self.guid = guid if guid else None
self.access_token = access_token if access_token else None
self.refresh_token = refresh_token if refresh_token else None
self.access_expires_at = access_expiry if access_expiry else None
if self.access_token:
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
else:
self.session.headers.pop("Authorization", None)
try:
_log_agent(
"Reloaded tokens from disk "
f"guid={'yes' if self.guid else 'no'} "
f"access={'yes' if self.access_token else 'no'} "
f"refresh={'yes' if self.refresh_token else 'no'} "
f"expiry={self.access_expires_at}",
fname="agent.log",
)
except Exception:
pass
def auth_headers(self) -> Dict[str, str]:
if self.access_token:
return {"Authorization": f"Bearer {self.access_token}"}
@@ -540,12 +708,28 @@ class AgentHttpClient:
engine = getattr(client, "eio", None)
if engine is None:
return
# python-engineio accepts bool, path, or ssl.SSLContext for ssl_verify
# python-engineio accepts either a boolean or an ``ssl.SSLContext``
# for TLS verification. When we have a pinned certificate bundle
# on disk, prefer constructing a dedicated context that trusts that
# bundle so WebSocket connections succeed even with private CAs.
if isinstance(verify, str) and os.path.isfile(verify):
engine.ssl_verify = verify
try:
context = ssl.create_default_context(cafile=verify)
context.check_hostname = False
except Exception:
context = None
if context is not None:
engine.ssl_context = context
engine.ssl_verify = True
else:
engine.ssl_context = None
engine.ssl_verify = verify
elif verify is False:
engine.ssl_context = None
engine.ssl_verify = False
else:
engine.ssl_context = None
engine.ssl_verify = True
except Exception:
pass
@@ -554,11 +738,15 @@ class AgentHttpClient:
# Enrollment & token management
# ------------------------------------------------------------------
def ensure_authenticated(self) -> None:
with self._auth_lock:
self._ensure_authenticated_locked()
def _ensure_authenticated_locked(self) -> None:
self.refresh_base_url()
if not self.guid or not self.refresh_token:
self.perform_enrollment()
self._perform_enrollment_locked()
if not self.access_token or self._token_expiring_soon():
self.refresh_access_token()
self._refresh_access_token_locked()
def _token_expiring_soon(self) -> bool:
if not self.access_token:
@@ -568,69 +756,199 @@ class AgentHttpClient:
return (self.access_expires_at - time.time()) < 60
def perform_enrollment(self) -> None:
with self._auth_lock:
self._perform_enrollment_locked()
def _perform_enrollment_locked(self) -> None:
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
return
code = self._resolve_installer_code()
if not code:
raise RuntimeError(
"Installer code is required for enrollment. "
"Set BOREALIS_INSTALLER_CODE, pass --installer-code, or update agent_settings.json."
)
self.refresh_base_url()
client_nonce = os.urandom(32)
payload = {
"hostname": socket.gethostname(),
"enrollment_code": code,
"agent_pubkey": PUBLIC_KEY_B64,
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
}
request_url = f"{self.base_url}/api/agent/enroll/request"
_log_agent("Starting enrollment request...", fname="agent.log")
resp = self.session.post(request_url, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
if data.get("server_certificate"):
self.key_store.save_server_certificate(data["server_certificate"])
self._configure_verify()
signing_key = data.get("signing_key")
if signing_key:
try:
self.store_server_signing_key(signing_key)
except Exception as exc:
_log_agent(f'Unable to persist signing key from enrollment handshake: {exc}', fname='agent.error.log')
if data.get("status") != "pending":
raise RuntimeError(f"Unexpected enrollment status: {data}")
approval_reference = data.get("approval_reference")
server_nonce_b64 = data.get("server_nonce")
if not approval_reference or not server_nonce_b64:
raise RuntimeError("Enrollment response missing approval_reference or server_nonce")
server_nonce = base64.b64decode(server_nonce_b64)
poll_delay = max(int(data.get("poll_after_ms", 3000)) / 1000, 1)
while True:
time.sleep(min(poll_delay, 15))
signature = self.identity.sign(server_nonce + approval_reference.encode("utf-8") + client_nonce)
poll_payload = {
"approval_reference": approval_reference,
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
"proof_sig": base64.b64encode(signature).decode("ascii"),
}
poll_resp = self.session.post(
f"{self.base_url}/api/agent/enroll/poll",
json=poll_payload,
timeout=30,
self._active_installer_code = code
wait_state = {"count": 0, "tokens_seen": False}
def _on_lock_wait() -> None:
wait_state["count"] += 1
_log_agent(
f"Enrollment waiting for shared lock scope={SERVICE_MODE} attempt={wait_state['count']}",
fname="agent.log",
)
poll_resp.raise_for_status()
poll_data = poll_resp.json()
status = poll_data.get("status")
if status == "pending":
poll_delay = max(int(poll_data.get("poll_after_ms", 5000)) / 1000, 1)
continue
if status == "denied":
raise RuntimeError("Enrollment denied by operator")
if status in ("expired", "unknown"):
raise RuntimeError(f"Enrollment failed with status={status}")
if status in ("approved", "completed"):
self._finalize_enrollment(poll_data)
break
raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}")
if not wait_state["tokens_seen"]:
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
wait_state["tokens_seen"] = True
_log_agent(
"Enrollment credentials detected while waiting for lock; will reuse when available",
fname="agent.log",
)
try:
with _acquire_enrollment_lock(timeout=180.0, on_wait=_on_lock_wait):
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
_log_agent(
"Enrollment skipped after acquiring lock; credentials already present",
fname="agent.log",
)
return
self.refresh_base_url()
base_url = self.base_url or "https://localhost:5000"
code_masked = _mask_sensitive(code)
_log_agent(
"Enrollment handshake starting "
f"base_url={base_url} scope={SERVICE_MODE} "
f"fingerprint={SSL_KEY_FINGERPRINT[:16]} installer_code={code_masked}",
fname="agent.log",
)
client_nonce = os.urandom(32)
payload = {
"hostname": socket.gethostname(),
"enrollment_code": code,
"agent_pubkey": PUBLIC_KEY_B64,
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
}
request_url = f"{self.base_url}/api/agent/enroll/request"
_log_agent(
"Starting enrollment request... "
f"url={request_url} hostname={payload['hostname']} pubkey_prefix={PUBLIC_KEY_B64[:24]}",
fname="agent.log",
)
resp = self.session.post(request_url, json=payload, timeout=30)
_log_agent(
f"Enrollment request HTTP status={resp.status_code} retry_after={resp.headers.get('Retry-After')}"
f" body_len={len(resp.content)}",
fname="agent.log",
)
try:
resp.raise_for_status()
except requests.HTTPError:
snippet = resp.text[:512] if hasattr(resp, "text") else ""
_log_agent(
f"Enrollment request failed status={resp.status_code} body_snippet={snippet}",
fname="agent.error.log",
)
if resp.status_code == 400:
try:
err_payload = resp.json()
except Exception:
err_payload = {}
if (err_payload or {}).get("error") in {"invalid_enrollment_code", "code_consumed"}:
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
_log_agent(
"Enrollment code rejected but existing credentials are present; skipping re-enrollment",
fname="agent.log",
)
return
raise
data = resp.json()
_log_agent(
"Enrollment request accepted "
f"status={data.get('status')} approval_ref={data.get('approval_reference')} "
f"poll_after_ms={data.get('poll_after_ms')}"
f" server_cert={'yes' if data.get('server_certificate') else 'no'}",
fname="agent.log",
)
if data.get("server_certificate"):
self.key_store.save_server_certificate(data["server_certificate"])
self._configure_verify()
signing_key = data.get("signing_key")
if signing_key:
try:
self.store_server_signing_key(signing_key)
except Exception as exc:
_log_agent(f'Unable to persist signing key from enrollment handshake: {exc}', fname='agent.error.log')
if data.get("status") != "pending":
raise RuntimeError(f"Unexpected enrollment status: {data}")
approval_reference = data.get("approval_reference")
server_nonce_b64 = data.get("server_nonce")
if not approval_reference or not server_nonce_b64:
raise RuntimeError("Enrollment response missing approval_reference or server_nonce")
server_nonce = base64.b64decode(server_nonce_b64)
poll_delay = max(int(data.get("poll_after_ms", 3000)) / 1000, 1)
attempt = 1
while True:
time.sleep(min(poll_delay, 15))
signature = self.identity.sign(server_nonce + approval_reference.encode("utf-8") + client_nonce)
poll_payload = {
"approval_reference": approval_reference,
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
"proof_sig": base64.b64encode(signature).decode("ascii"),
}
_log_agent(
f"Enrollment poll attempt={attempt} ref={approval_reference} delay={poll_delay}s",
fname="agent.log",
)
poll_resp = self.session.post(
f"{self.base_url}/api/agent/enroll/poll",
json=poll_payload,
timeout=30,
)
_log_agent(
"Enrollment poll response "
f"status_code={poll_resp.status_code} retry_after={poll_resp.headers.get('Retry-After')}"
f" body_len={len(poll_resp.content)}",
fname="agent.log",
)
try:
poll_resp.raise_for_status()
except requests.HTTPError:
snippet = poll_resp.text[:512] if hasattr(poll_resp, "text") else ""
_log_agent(
f"Enrollment poll failed attempt={attempt} status={poll_resp.status_code} "
f"body_snippet={snippet}",
fname="agent.error.log",
)
raise
poll_data = poll_resp.json()
_log_agent(
f"Enrollment poll decoded attempt={attempt} status={poll_data.get('status')}"
f" next_delay={poll_data.get('poll_after_ms')}"
f" guid_hint={poll_data.get('guid')}",
fname="agent.log",
)
status = poll_data.get("status")
if status == "pending":
poll_delay = max(int(poll_data.get("poll_after_ms", 5000)) / 1000, 1)
_log_agent(
f"Enrollment still pending attempt={attempt} new_delay={poll_delay}s",
fname="agent.log",
)
attempt += 1
continue
if status == "denied":
_log_agent("Enrollment denied by operator", fname="agent.error.log")
raise RuntimeError("Enrollment denied by operator")
if status in ("expired", "unknown"):
_log_agent(
f"Enrollment failed status={status} attempt={attempt}",
fname="agent.error.log",
)
raise RuntimeError(f"Enrollment failed with status={status}")
if status in ("approved", "completed"):
_log_agent(
f"Enrollment approved attempt={attempt} ref={approval_reference}",
fname="agent.log",
)
self._finalize_enrollment(poll_data)
break
raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}")
except TimeoutError:
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
_log_agent(
"Enrollment lock wait timed out but credentials materialized; reusing existing tokens",
fname="agent.log",
)
return
raise
def _finalize_enrollment(self, payload: Dict[str, Any]) -> None:
server_cert = payload.get("server_certificate")
@@ -649,6 +967,12 @@ class AgentHttpClient:
expires_in = int(payload.get("expires_in") or 900)
if not (guid and access_token and refresh_token):
raise RuntimeError("Enrollment approval response missing tokens or guid")
_log_agent(
"Enrollment approval payload received "
f"guid={guid} access_token_len={len(access_token)} refresh_token_len={len(refresh_token)} "
f"expires_in={expires_in}",
fname="agent.log",
)
self.guid = str(guid).strip()
self.access_token = access_token.strip()
self.refresh_token = refresh_token.strip()
@@ -663,12 +987,17 @@ class AgentHttpClient:
_update_agent_id_for_guid(self.guid)
except Exception as exc:
_log_agent(f"Failed to update agent id after enrollment: {exc}", fname="agent.error.log")
self._consume_installer_code()
_log_agent(f"Enrollment finalized for guid={self.guid}", fname="agent.log")
def refresh_access_token(self) -> None:
with self._auth_lock:
self._refresh_access_token_locked()
def _refresh_access_token_locked(self) -> None:
if not self.refresh_token or not self.guid:
self.clear_tokens()
self.perform_enrollment()
self._clear_tokens_locked()
self._perform_enrollment_locked()
return
payload = {"guid": self.guid, "refresh_token": self.refresh_token}
resp = self.session.post(
@@ -679,8 +1008,8 @@ class AgentHttpClient:
)
if resp.status_code in (401, 403):
_log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log")
self.clear_tokens()
self.perform_enrollment()
self._clear_tokens_locked()
self._perform_enrollment_locked()
return
resp.raise_for_status()
data = resp.json()
@@ -696,6 +1025,10 @@ class AgentHttpClient:
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
def clear_tokens(self) -> None:
with self._auth_lock:
self._clear_tokens_locked()
def _clear_tokens_locked(self) -> None:
self.key_store.clear_tokens()
self.access_token = None
self.refresh_token = None
@@ -712,6 +1045,19 @@ class AgentHttpClient:
except Exception:
return ""
def _consume_installer_code(self) -> None:
# Avoid clearing explicit CLI/env overrides; only mutate persisted config.
self._active_installer_code = None
if INSTALLER_CODE_OVERRIDE:
return
try:
if CONFIG.data.get("installer_code"):
CONFIG.data["installer_code"] = ""
CONFIG._write()
_log_agent("Cleared persisted installer code after successful enrollment", fname="agent.log")
except Exception as exc:
_log_agent(f"Failed to clear installer code after enrollment: {exc}", fname="agent.error.log")
# ------------------------------------------------------------------
# HTTP helpers
# ------------------------------------------------------------------
@@ -722,6 +1068,16 @@ class AgentHttpClient:
headers = self.auth_headers()
response = self.session.post(url, json=payload, headers=headers, timeout=30)
if response.status_code in (401, 403) and require_auth:
snippet = ""
try:
snippet = response.text[:256]
except Exception:
snippet = "<unavailable>"
_log_agent(
"Authenticated request rejected "
f"path={path} status={response.status_code} body_snippet={snippet}",
fname="agent.error.log",
)
self.clear_tokens()
self.ensure_authenticated()
headers = self.auth_headers()

View File

@@ -3,6 +3,8 @@
from __future__ import annotations
import base64
import contextlib
import errno
import hashlib
import json
import os
@@ -23,6 +25,16 @@ try:
except Exception: # pragma: no cover - win32crypt missing
win32crypt = None # type: ignore
try: # pragma: no cover - Windows only
import msvcrt # type: ignore
except Exception: # pragma: no cover - non-Windows
msvcrt = None # type: ignore
try: # pragma: no cover - POSIX only
import fcntl # type: ignore
except Exception: # pragma: no cover - Windows
fcntl = None # type: ignore
def _ensure_dir(path: str) -> None:
os.makedirs(path, exist_ok=True)
@@ -36,43 +48,155 @@ def _restrict_permissions(path: str) -> None:
pass
class _FileLock:
def __init__(self, path: str) -> None:
self.path = path
self._handle = None
def acquire(self, *, timeout: float = 60.0, poll_interval: float = 0.2) -> None:
directory = os.path.dirname(self.path)
if directory:
os.makedirs(directory, exist_ok=True)
deadline = time.time() + timeout if timeout else None
while True:
handle = open(self.path, "a+b")
try:
self._try_lock(handle)
except OSError as exc:
handle.close()
if not self._is_lock_unavailable(exc):
raise
if deadline and time.time() >= deadline:
raise TimeoutError("Timed out waiting for file lock")
time.sleep(poll_interval)
continue
except Exception:
handle.close()
raise
self._handle = handle
try:
handle.seek(0)
handle.truncate(0)
payload = f"pid={os.getpid()} ts={int(time.time())}\n".encode("utf-8")
handle.write(payload)
handle.flush()
except Exception:
pass
return
def release(self) -> None:
handle = self._handle
if not handle:
return
try:
self._unlock(handle)
finally:
try:
handle.close()
except Exception:
pass
self._handle = None
def _try_lock(self, handle):
if IS_WINDOWS:
if msvcrt is None:
raise OSError(errno.EINVAL, "msvcrt unavailable for locking")
try:
msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1) # type: ignore[attr-defined]
except OSError as exc: # pragma: no cover - platform specific
raise exc
else:
if fcntl is None:
raise OSError(errno.EINVAL, "fcntl unavailable for locking")
fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) # type: ignore[attr-defined]
def _unlock(self, handle):
if IS_WINDOWS:
if msvcrt is None:
return
try:
msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1) # type: ignore[attr-defined]
except OSError: # pragma: no cover - platform specific
pass
else:
if fcntl is None:
return
try:
fcntl.flock(handle.fileno(), fcntl.LOCK_UN) # type: ignore[attr-defined]
except OSError:
pass
@staticmethod
def _is_lock_unavailable(exc: OSError) -> bool:
if hasattr(exc, "winerror"):
return exc.winerror in (32, 33) # type: ignore[attr-defined]
err = exc.errno if hasattr(exc, "errno") else None
return err in (errno.EACCES, errno.EAGAIN, errno.EWOULDBLOCK)
@contextlib.contextmanager
def _locked_file(path: str, *, timeout: float = 60.0):
lock = _FileLock(path)
lock.acquire(timeout=timeout)
try:
yield
finally:
lock.release()
def _protect(data: bytes, *, scope_system: bool) -> bytes:
if not IS_WINDOWS or not win32crypt:
return data
flags = 0
scopes = [scope_system]
# Always include the alternate scope so we can fall back if the preferred
# protection attempt fails (e.g., running under a limited account that
# lacks access to the desired DPAPI scope).
if scope_system:
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
try:
protected = win32crypt.CryptProtectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
except Exception:
return data
blob = protected[1]
if isinstance(blob, memoryview):
return blob.tobytes()
if isinstance(blob, bytearray):
return bytes(blob)
if isinstance(blob, bytes):
return blob
scopes.append(False)
else:
scopes.append(True)
for scope in scopes:
flags = 0
if scope:
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
try:
protected = win32crypt.CryptProtectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
except Exception:
continue
blob = protected[1]
if isinstance(blob, memoryview):
return blob.tobytes()
if isinstance(blob, bytearray):
return bytes(blob)
if isinstance(blob, bytes):
return blob
return data
def _unprotect(data: bytes, *, scope_system: bool) -> bytes:
if not IS_WINDOWS or not win32crypt:
return data
flags = 0
scopes = [scope_system]
if scope_system:
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
try:
unwrapped = win32crypt.CryptUnprotectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
except Exception:
return data
blob = unwrapped[1]
if isinstance(blob, memoryview):
return blob.tobytes()
if isinstance(blob, bytearray):
return bytes(blob)
if isinstance(blob, bytes):
return blob
scopes.append(False)
else:
scopes.append(True)
for scope in scopes:
flags = 0
if scope:
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
try:
unwrapped = win32crypt.CryptUnprotectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
except Exception:
continue
blob = unwrapped[1]
if isinstance(blob, memoryview):
return blob.tobytes()
if isinstance(blob, bytearray):
return bytes(blob)
if isinstance(blob, bytes):
return blob
return data
@@ -105,17 +229,21 @@ class AgentKeyStore:
self._token_meta_path = os.path.join(self.settings_dir, "access.meta.json")
self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem")
self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub")
self._identity_lock_path = os.path.join(self.settings_dir, "identity.lock")
# ------------------------------------------------------------------
# Identity management
# ------------------------------------------------------------------
def load_or_create_identity(self) -> AgentIdentity:
if os.path.isfile(self._private_path) and os.path.isfile(self._public_path):
try:
return self._load_identity()
except Exception:
pass
return self._create_identity()
with _locked_file(self._identity_lock_path, timeout=120.0):
if os.path.isfile(self._private_path) and os.path.isfile(self._public_path):
try:
return self._load_identity()
except Exception:
# If loading fails, fall back to regenerating the identity while
# holding the lock so concurrent agents do not thrash the key files.
pass
return self._create_identity()
def _load_identity(self) -> AgentIdentity:
with open(self._private_path, "rb") as fh:
@@ -212,11 +340,23 @@ class AgentKeyStore:
try:
with open(self._refresh_token_path, "rb") as fh:
protected = fh.read()
raw = _unprotect(protected, scope_system=self.scope_system)
return raw.decode("utf-8")
except Exception:
return None
# Try both scopes (preferred first) and decode once a UTF-8 payload is recovered.
for scope_first in (self.scope_system, not self.scope_system):
try:
candidate = _unprotect(protected, scope_system=scope_first)
except Exception:
continue
if not candidate:
continue
try:
return candidate.decode("utf-8")
except Exception:
continue
return None
def clear_tokens(self) -> None:
for path in (self._access_token_path, self._refresh_token_path, self._token_meta_path):
try:

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import json
import time
import sqlite3
from typing import Any, Callable, Dict, Optional
from flask import Blueprint, jsonify, request, g
@@ -28,10 +29,18 @@ def register(
except Exception:
return None
def _auth_context():
ctx = getattr(g, "device_auth", None)
if ctx is None:
log("server", f"device auth context missing for {request.path}")
return ctx
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
@require_device_auth(auth_manager)
def heartbeat():
ctx = getattr(g, "device_auth")
ctx = _auth_context()
if ctx is None:
return jsonify({"error": "auth_context_missing"}), 500
payload = request.get_json(force=True, silent=True) or {}
now_ts = int(time.time())
@@ -71,14 +80,42 @@ def register(
conn = db_conn_factory()
try:
cur = conn.cursor()
columns = ", ".join(f"{col} = ?" for col in updates.keys())
params = list(updates.values())
params.append(ctx.guid)
cur.execute(
f"UPDATE devices SET {columns} WHERE guid = ?",
params,
)
if cur.rowcount == 0:
def _apply_updates() -> int:
if not updates:
return 0
columns = ", ".join(f"{col} = ?" for col in updates.keys())
params = list(updates.values())
params.append(ctx.guid)
cur.execute(
f"UPDATE devices SET {columns} WHERE guid = ?",
params,
)
return cur.rowcount
try:
rowcount = _apply_updates()
except sqlite3.IntegrityError as exc:
if "devices.hostname" in str(exc) and "UNIQUE" in str(exc).upper():
# Another device already claims this hostname; keep the existing
# canonical hostname assigned during enrollment to avoid breaking
# the unique constraint and continue updating the remaining fields.
if "hostname" in updates:
updates.pop("hostname", None)
try:
rowcount = _apply_updates()
except sqlite3.IntegrityError:
raise
else:
log(
"server",
"heartbeat hostname collision ignored for guid="
f"{ctx.guid}",
)
else:
raise
if rowcount == 0:
log("server", f"heartbeat missing device record guid={ctx.guid}")
return jsonify({"error": "device_not_registered"}), 404
conn.commit()
@@ -90,7 +127,9 @@ def register(
@blueprint.route("/api/agent/script/request", methods=["POST"])
@require_device_auth(auth_manager)
def script_request():
ctx = getattr(g, "device_auth")
ctx = _auth_context()
if ctx is None:
return jsonify({"error": "auth_context_missing"}), 500
if ctx.status != "active":
return jsonify(
{

View File

@@ -54,6 +54,10 @@ def register(
def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
decision = limiter.check(key, limit, window_s)
if not decision.allowed:
log(
"server",
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
)
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
response.status_code = 429
response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}"
@@ -128,19 +132,63 @@ def register(
def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
cur.execute(
"SELECT guid, hostname, token_version, status, ssl_key_fingerprint FROM devices WHERE guid = ?",
"""
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
FROM devices
WHERE guid = ?
""",
(guid,),
)
row = cur.fetchone()
if row:
keys = ["guid", "hostname", "token_version", "status", "ssl_key_fingerprint"]
keys = [
"guid",
"hostname",
"token_version",
"status",
"ssl_key_fingerprint",
"key_added_at",
]
record = dict(zip(keys, row))
if not record.get("ssl_key_fingerprint"):
stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower()
new_fp = (fingerprint or "").strip().lower()
if not stored_fp and new_fp:
cur.execute(
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
(fingerprint, _iso(_now()), guid),
)
record["ssl_key_fingerprint"] = fingerprint
elif new_fp and stored_fp != new_fp:
now_iso = _iso(_now())
try:
current_version = int(record.get("token_version") or 1)
except Exception:
current_version = 1
new_version = max(current_version + 1, 1)
cur.execute(
"""
UPDATE devices
SET ssl_key_fingerprint = ?,
key_added_at = ?,
token_version = ?,
status = 'active'
WHERE guid = ?
""",
(fingerprint, now_iso, new_version, guid),
)
cur.execute(
"""
UPDATE refresh_tokens
SET revoked_at = ?
WHERE guid = ?
AND revoked_at IS NULL
""",
(now_iso, guid),
)
record["ssl_key_fingerprint"] = fingerprint
record["token_version"] = new_version
record["status"] = "active"
record["key_added_at"] = now_iso
return record
resolved_hostname = _normalize_host(hostname, guid, cur)
@@ -169,6 +217,7 @@ def register(
"token_version": 1,
"status": "active",
"ssl_key_fingerprint": fingerprint,
"key_added_at": key_added_at,
}
def _hash_refresh_token(token: str) -> str:
@@ -198,7 +247,7 @@ def register(
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
def enrollment_request():
remote = _remote_addr()
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 10, 60.0)
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0)
if rate_error:
return rate_error
@@ -208,32 +257,47 @@ def register(
agent_pubkey_b64 = payload.get("agent_pubkey")
client_nonce_b64 = payload.get("client_nonce")
log(
"server",
"enrollment request received "
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
)
if not hostname:
log("server", f"enrollment rejected missing_hostname ip={remote}")
return jsonify({"error": "hostname_required"}), 400
if not enrollment_code:
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}")
return jsonify({"error": "enrollment_code_required"}), 400
if not isinstance(agent_pubkey_b64, str):
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}")
return jsonify({"error": "agent_pubkey_required"}), 400
if not isinstance(client_nonce_b64, str):
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}")
return jsonify({"error": "client_nonce_required"}), 400
try:
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
except Exception:
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}")
return jsonify({"error": "invalid_agent_pubkey"}), 400
if len(agent_pubkey_der) < 10:
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}")
return jsonify({"error": "invalid_agent_pubkey"}), 400
try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}")
return jsonify({"error": "invalid_client_nonce"}), 400
if len(client_nonce_bytes) < 16:
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}")
return jsonify({"error": "invalid_client_nonce"}), 400
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 3, 60.0)
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0)
if rate_error:
return rate_error
@@ -333,21 +397,33 @@ def register(
client_nonce_b64 = payload.get("client_nonce")
proof_sig_b64 = payload.get("proof_sig")
log(
"server",
"enrollment poll received "
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
f" proof_sig_len={len(proof_sig_b64 or '')}",
)
if not isinstance(approval_reference, str) or not approval_reference:
log("server", "enrollment poll rejected missing_reference")
return jsonify({"error": "approval_reference_required"}), 400
if not isinstance(client_nonce_b64, str):
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}")
return jsonify({"error": "client_nonce_required"}), 400
if not isinstance(proof_sig_b64, str):
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}")
return jsonify({"error": "proof_sig_required"}), 400
try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}")
return jsonify({"error": "invalid_client_nonce"}), 400
try:
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_sig ref={approval_reference}")
return jsonify({"error": "invalid_proof_sig"}), 400
conn = db_conn_factory()
@@ -365,6 +441,7 @@ def register(
)
row = cur.fetchone()
if not row:
log("server", f"enrollment poll unknown_reference ref={approval_reference}")
return jsonify({"status": "unknown"}), 404
(
@@ -383,11 +460,13 @@ def register(
) = row
if client_nonce_stored != client_nonce_b64:
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}")
return jsonify({"error": "nonce_mismatch"}), 400
try:
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}")
return jsonify({"error": "server_nonce_invalid"}), 400
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
@@ -395,30 +474,58 @@ def register(
try:
public_key = serialization.load_der_public_key(agent_pubkey_der)
except Exception:
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}")
public_key = None
if public_key is None:
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}")
return jsonify({"error": "agent_pubkey_invalid"}), 400
try:
public_key.verify(proof_sig, message)
except Exception:
log("server", f"enrollment poll invalid_proof ref={approval_reference}")
return jsonify({"error": "invalid_proof"}), 400
if status == "pending":
log(
"server",
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
f" fingerprint={fingerprint[:12]}",
)
return jsonify({"status": "pending", "poll_after_ms": 5000})
if status == "denied":
log(
"server",
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
)
return jsonify({"status": "denied", "reason": "operator_denied"})
if status == "expired":
log(
"server",
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
)
return jsonify({"status": "expired"})
if status == "completed":
log(
"server",
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
)
return jsonify({"status": "approved", "detail": "finalized"})
if status != "approved":
log(
"server",
f"enrollment poll unexpected_status={status} ref={approval_reference}",
)
return jsonify({"status": status or "unknown"}), 400
nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
if not nonce_cache.consume(nonce_key):
log(
"server",
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
)
return jsonify({"error": "proof_replayed"}), 409
# Finalize enrollment
@@ -489,3 +596,12 @@ def _load_tls_bundle(path: str) -> str:
return fh.read()
except Exception:
return ""
def _mask_code(code: str) -> str:
if not code:
return "<missing>"
trimmed = str(code).strip()
if len(trimmed) <= 6:
return "***"
return f"{trimmed[:3]}***{trimmed[-3:]}"