Merge pull request #127 from bunny-lab-io:codex/implement-security-features-for-borealis

Harden agent WebSocket TLS verification
This commit is contained in:
2025-10-18 02:31:18 -06:00
committed by GitHub
4 changed files with 771 additions and 120 deletions

View File

@@ -20,7 +20,10 @@ import datetime
import shutil import shutil
import string import string
import ssl import ssl
from typing import Any, Dict, Optional, List import threading
import contextlib
import errno
from typing import Any, Dict, Optional, List, Callable
import requests import requests
try: try:
@@ -132,6 +135,133 @@ def _settings_dir():
return os.path.abspath(os.path.join(os.path.dirname(__file__), 'Settings')) return os.path.abspath(os.path.join(os.path.dirname(__file__), 'Settings'))
class _CrossProcessFileLock:
def __init__(self, path: str) -> None:
self.path = path
self._handle = None
def acquire(
self,
*,
timeout: float = 120.0,
poll_interval: float = 0.5,
on_wait: Optional[Callable[[], None]] = None,
) -> None:
directory = os.path.dirname(self.path)
if directory:
os.makedirs(directory, exist_ok=True)
deadline = time.time() + timeout if timeout else None
last_wait_log = 0.0
while True:
handle = open(self.path, 'a+b')
try:
self._try_lock(handle)
self._handle = handle
try:
handle.seek(0)
handle.truncate(0)
handle.write(f"pid={os.getpid()} ts={int(time.time())}\n".encode('utf-8'))
handle.flush()
except Exception:
pass
return
except OSError as exc:
handle.close()
if not self._is_lock_unavailable(exc):
raise
now = time.time()
if on_wait and (now - last_wait_log) >= 2.0:
try:
on_wait()
except Exception:
pass
last_wait_log = now
if deadline and now >= deadline:
raise TimeoutError('Timed out waiting for enrollment lock')
time.sleep(poll_interval)
except Exception:
handle.close()
raise
def release(self) -> None:
handle = self._handle
if not handle:
return
try:
self._unlock(handle)
finally:
try:
handle.close()
except Exception:
pass
self._handle = None
@staticmethod
def _is_lock_unavailable(exc: OSError) -> bool:
err = exc.errno
winerr = getattr(exc, 'winerror', None)
unavailable = {errno.EACCES, errno.EAGAIN, getattr(errno, 'EWOULDBLOCK', errno.EAGAIN)}
if err in unavailable:
return True
if winerr in (32, 33):
return True
return False
@staticmethod
def _try_lock(handle) -> None:
handle.seek(0, os.SEEK_END)
if handle.tell() == 0:
try:
handle.write(b'0')
handle.flush()
except Exception:
pass
handle.seek(0)
if os.name == 'nt':
import msvcrt # type: ignore
try:
msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1)
except OSError:
raise
else:
import fcntl # type: ignore
fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
@staticmethod
def _unlock(handle) -> None:
if os.name == 'nt':
import msvcrt # type: ignore
try:
msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1)
except OSError:
pass
else:
import fcntl # type: ignore
try:
fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
except OSError:
pass
_ENROLLMENT_FILE_LOCK: Optional[_CrossProcessFileLock] = None
@contextlib.contextmanager
def _acquire_enrollment_lock(*, timeout: float = 180.0, on_wait: Optional[Callable[[], None]] = None):
global _ENROLLMENT_FILE_LOCK
if _ENROLLMENT_FILE_LOCK is None:
_ENROLLMENT_FILE_LOCK = _CrossProcessFileLock(os.path.join(_settings_dir(), 'enrollment.lock'))
_ENROLLMENT_FILE_LOCK.acquire(timeout=timeout, on_wait=on_wait)
try:
yield
finally:
_ENROLLMENT_FILE_LOCK.release()
_KEY_STORE_INSTANCE = None _KEY_STORE_INSTANCE = None
@@ -242,6 +372,18 @@ def _log_agent(message: str, fname: str = 'agent.log'):
pass pass
def _mask_sensitive(value: str, *, prefix: int = 4, suffix: int = 4) -> str:
try:
if not value:
return ''
trimmed = value.strip()
if len(trimmed) <= prefix + suffix:
return '*' * len(trimmed)
return f"{trimmed[:prefix]}***{trimmed[-suffix:]}"
except Exception:
return '***'
def _decode_base64_text(value): def _decode_base64_text(value):
if not isinstance(value, str): if not isinstance(value, str):
return None return None
@@ -490,14 +632,15 @@ class AgentHttpClient:
self.identity = IDENTITY self.identity = IDENTITY
self.session = requests.Session() self.session = requests.Session()
self.base_url: Optional[str] = None self.base_url: Optional[str] = None
self.guid: Optional[str] = self.key_store.load_guid() self.guid: Optional[str] = None
self.access_token: Optional[str] = self.key_store.load_access_token() self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = self.key_store.load_refresh_token() self.refresh_token: Optional[str] = None
self.access_expires_at: Optional[int] = self.key_store.get_access_expiry() self.access_expires_at: Optional[int] = None
self._auth_lock = threading.RLock()
self._active_installer_code: Optional[str] = None
self.refresh_base_url() self.refresh_base_url()
self._configure_verify() self._configure_verify()
if self.access_token: self._reload_tokens_from_disk()
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
self.session.headers.setdefault("User-Agent", "Borealis-Agent/secure") self.session.headers.setdefault("User-Agent", "Borealis-Agent/secure")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -528,6 +671,31 @@ class AgentHttpClient:
except Exception: except Exception:
pass pass
def _reload_tokens_from_disk(self) -> None:
guid = self.key_store.load_guid()
access_token = self.key_store.load_access_token()
refresh_token = self.key_store.load_refresh_token()
access_expiry = self.key_store.get_access_expiry()
self.guid = guid if guid else None
self.access_token = access_token if access_token else None
self.refresh_token = refresh_token if refresh_token else None
self.access_expires_at = access_expiry if access_expiry else None
if self.access_token:
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
else:
self.session.headers.pop("Authorization", None)
try:
_log_agent(
"Reloaded tokens from disk "
f"guid={'yes' if self.guid else 'no'} "
f"access={'yes' if self.access_token else 'no'} "
f"refresh={'yes' if self.refresh_token else 'no'} "
f"expiry={self.access_expires_at}",
fname="agent.log",
)
except Exception:
pass
def auth_headers(self) -> Dict[str, str]: def auth_headers(self) -> Dict[str, str]:
if self.access_token: if self.access_token:
return {"Authorization": f"Bearer {self.access_token}"} return {"Authorization": f"Bearer {self.access_token}"}
@@ -540,12 +708,28 @@ class AgentHttpClient:
engine = getattr(client, "eio", None) engine = getattr(client, "eio", None)
if engine is None: if engine is None:
return return
# python-engineio accepts bool, path, or ssl.SSLContext for ssl_verify
# python-engineio accepts either a boolean or an ``ssl.SSLContext``
# for TLS verification. When we have a pinned certificate bundle
# on disk, prefer constructing a dedicated context that trusts that
# bundle so WebSocket connections succeed even with private CAs.
if isinstance(verify, str) and os.path.isfile(verify): if isinstance(verify, str) and os.path.isfile(verify):
engine.ssl_verify = verify try:
context = ssl.create_default_context(cafile=verify)
context.check_hostname = False
except Exception:
context = None
if context is not None:
engine.ssl_context = context
engine.ssl_verify = True
else:
engine.ssl_context = None
engine.ssl_verify = verify
elif verify is False: elif verify is False:
engine.ssl_context = None
engine.ssl_verify = False engine.ssl_verify = False
else: else:
engine.ssl_context = None
engine.ssl_verify = True engine.ssl_verify = True
except Exception: except Exception:
pass pass
@@ -554,11 +738,15 @@ class AgentHttpClient:
# Enrollment & token management # Enrollment & token management
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def ensure_authenticated(self) -> None: def ensure_authenticated(self) -> None:
with self._auth_lock:
self._ensure_authenticated_locked()
def _ensure_authenticated_locked(self) -> None:
self.refresh_base_url() self.refresh_base_url()
if not self.guid or not self.refresh_token: if not self.guid or not self.refresh_token:
self.perform_enrollment() self._perform_enrollment_locked()
if not self.access_token or self._token_expiring_soon(): if not self.access_token or self._token_expiring_soon():
self.refresh_access_token() self._refresh_access_token_locked()
def _token_expiring_soon(self) -> bool: def _token_expiring_soon(self) -> bool:
if not self.access_token: if not self.access_token:
@@ -568,69 +756,199 @@ class AgentHttpClient:
return (self.access_expires_at - time.time()) < 60 return (self.access_expires_at - time.time()) < 60
def perform_enrollment(self) -> None: def perform_enrollment(self) -> None:
with self._auth_lock:
self._perform_enrollment_locked()
def _perform_enrollment_locked(self) -> None:
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
return
code = self._resolve_installer_code() code = self._resolve_installer_code()
if not code: if not code:
raise RuntimeError( raise RuntimeError(
"Installer code is required for enrollment. " "Installer code is required for enrollment. "
"Set BOREALIS_INSTALLER_CODE, pass --installer-code, or update agent_settings.json." "Set BOREALIS_INSTALLER_CODE, pass --installer-code, or update agent_settings.json."
) )
self.refresh_base_url() self._active_installer_code = code
client_nonce = os.urandom(32)
payload = { wait_state = {"count": 0, "tokens_seen": False}
"hostname": socket.gethostname(),
"enrollment_code": code, def _on_lock_wait() -> None:
"agent_pubkey": PUBLIC_KEY_B64, wait_state["count"] += 1
"client_nonce": base64.b64encode(client_nonce).decode("ascii"), _log_agent(
} f"Enrollment waiting for shared lock scope={SERVICE_MODE} attempt={wait_state['count']}",
request_url = f"{self.base_url}/api/agent/enroll/request" fname="agent.log",
_log_agent("Starting enrollment request...", fname="agent.log")
resp = self.session.post(request_url, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
if data.get("server_certificate"):
self.key_store.save_server_certificate(data["server_certificate"])
self._configure_verify()
signing_key = data.get("signing_key")
if signing_key:
try:
self.store_server_signing_key(signing_key)
except Exception as exc:
_log_agent(f'Unable to persist signing key from enrollment handshake: {exc}', fname='agent.error.log')
if data.get("status") != "pending":
raise RuntimeError(f"Unexpected enrollment status: {data}")
approval_reference = data.get("approval_reference")
server_nonce_b64 = data.get("server_nonce")
if not approval_reference or not server_nonce_b64:
raise RuntimeError("Enrollment response missing approval_reference or server_nonce")
server_nonce = base64.b64decode(server_nonce_b64)
poll_delay = max(int(data.get("poll_after_ms", 3000)) / 1000, 1)
while True:
time.sleep(min(poll_delay, 15))
signature = self.identity.sign(server_nonce + approval_reference.encode("utf-8") + client_nonce)
poll_payload = {
"approval_reference": approval_reference,
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
"proof_sig": base64.b64encode(signature).decode("ascii"),
}
poll_resp = self.session.post(
f"{self.base_url}/api/agent/enroll/poll",
json=poll_payload,
timeout=30,
) )
poll_resp.raise_for_status() if not wait_state["tokens_seen"]:
poll_data = poll_resp.json() self._reload_tokens_from_disk()
status = poll_data.get("status") if self.guid and self.refresh_token:
if status == "pending": wait_state["tokens_seen"] = True
poll_delay = max(int(poll_data.get("poll_after_ms", 5000)) / 1000, 1) _log_agent(
continue "Enrollment credentials detected while waiting for lock; will reuse when available",
if status == "denied": fname="agent.log",
raise RuntimeError("Enrollment denied by operator") )
if status in ("expired", "unknown"):
raise RuntimeError(f"Enrollment failed with status={status}") try:
if status in ("approved", "completed"): with _acquire_enrollment_lock(timeout=180.0, on_wait=_on_lock_wait):
self._finalize_enrollment(poll_data) self._reload_tokens_from_disk()
break if self.guid and self.refresh_token:
raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}") _log_agent(
"Enrollment skipped after acquiring lock; credentials already present",
fname="agent.log",
)
return
self.refresh_base_url()
base_url = self.base_url or "https://localhost:5000"
code_masked = _mask_sensitive(code)
_log_agent(
"Enrollment handshake starting "
f"base_url={base_url} scope={SERVICE_MODE} "
f"fingerprint={SSL_KEY_FINGERPRINT[:16]} installer_code={code_masked}",
fname="agent.log",
)
client_nonce = os.urandom(32)
payload = {
"hostname": socket.gethostname(),
"enrollment_code": code,
"agent_pubkey": PUBLIC_KEY_B64,
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
}
request_url = f"{self.base_url}/api/agent/enroll/request"
_log_agent(
"Starting enrollment request... "
f"url={request_url} hostname={payload['hostname']} pubkey_prefix={PUBLIC_KEY_B64[:24]}",
fname="agent.log",
)
resp = self.session.post(request_url, json=payload, timeout=30)
_log_agent(
f"Enrollment request HTTP status={resp.status_code} retry_after={resp.headers.get('Retry-After')}"
f" body_len={len(resp.content)}",
fname="agent.log",
)
try:
resp.raise_for_status()
except requests.HTTPError:
snippet = resp.text[:512] if hasattr(resp, "text") else ""
_log_agent(
f"Enrollment request failed status={resp.status_code} body_snippet={snippet}",
fname="agent.error.log",
)
if resp.status_code == 400:
try:
err_payload = resp.json()
except Exception:
err_payload = {}
if (err_payload or {}).get("error") in {"invalid_enrollment_code", "code_consumed"}:
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
_log_agent(
"Enrollment code rejected but existing credentials are present; skipping re-enrollment",
fname="agent.log",
)
return
raise
data = resp.json()
_log_agent(
"Enrollment request accepted "
f"status={data.get('status')} approval_ref={data.get('approval_reference')} "
f"poll_after_ms={data.get('poll_after_ms')}"
f" server_cert={'yes' if data.get('server_certificate') else 'no'}",
fname="agent.log",
)
if data.get("server_certificate"):
self.key_store.save_server_certificate(data["server_certificate"])
self._configure_verify()
signing_key = data.get("signing_key")
if signing_key:
try:
self.store_server_signing_key(signing_key)
except Exception as exc:
_log_agent(f'Unable to persist signing key from enrollment handshake: {exc}', fname='agent.error.log')
if data.get("status") != "pending":
raise RuntimeError(f"Unexpected enrollment status: {data}")
approval_reference = data.get("approval_reference")
server_nonce_b64 = data.get("server_nonce")
if not approval_reference or not server_nonce_b64:
raise RuntimeError("Enrollment response missing approval_reference or server_nonce")
server_nonce = base64.b64decode(server_nonce_b64)
poll_delay = max(int(data.get("poll_after_ms", 3000)) / 1000, 1)
attempt = 1
while True:
time.sleep(min(poll_delay, 15))
signature = self.identity.sign(server_nonce + approval_reference.encode("utf-8") + client_nonce)
poll_payload = {
"approval_reference": approval_reference,
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
"proof_sig": base64.b64encode(signature).decode("ascii"),
}
_log_agent(
f"Enrollment poll attempt={attempt} ref={approval_reference} delay={poll_delay}s",
fname="agent.log",
)
poll_resp = self.session.post(
f"{self.base_url}/api/agent/enroll/poll",
json=poll_payload,
timeout=30,
)
_log_agent(
"Enrollment poll response "
f"status_code={poll_resp.status_code} retry_after={poll_resp.headers.get('Retry-After')}"
f" body_len={len(poll_resp.content)}",
fname="agent.log",
)
try:
poll_resp.raise_for_status()
except requests.HTTPError:
snippet = poll_resp.text[:512] if hasattr(poll_resp, "text") else ""
_log_agent(
f"Enrollment poll failed attempt={attempt} status={poll_resp.status_code} "
f"body_snippet={snippet}",
fname="agent.error.log",
)
raise
poll_data = poll_resp.json()
_log_agent(
f"Enrollment poll decoded attempt={attempt} status={poll_data.get('status')}"
f" next_delay={poll_data.get('poll_after_ms')}"
f" guid_hint={poll_data.get('guid')}",
fname="agent.log",
)
status = poll_data.get("status")
if status == "pending":
poll_delay = max(int(poll_data.get("poll_after_ms", 5000)) / 1000, 1)
_log_agent(
f"Enrollment still pending attempt={attempt} new_delay={poll_delay}s",
fname="agent.log",
)
attempt += 1
continue
if status == "denied":
_log_agent("Enrollment denied by operator", fname="agent.error.log")
raise RuntimeError("Enrollment denied by operator")
if status in ("expired", "unknown"):
_log_agent(
f"Enrollment failed status={status} attempt={attempt}",
fname="agent.error.log",
)
raise RuntimeError(f"Enrollment failed with status={status}")
if status in ("approved", "completed"):
_log_agent(
f"Enrollment approved attempt={attempt} ref={approval_reference}",
fname="agent.log",
)
self._finalize_enrollment(poll_data)
break
raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}")
except TimeoutError:
self._reload_tokens_from_disk()
if self.guid and self.refresh_token:
_log_agent(
"Enrollment lock wait timed out but credentials materialized; reusing existing tokens",
fname="agent.log",
)
return
raise
def _finalize_enrollment(self, payload: Dict[str, Any]) -> None: def _finalize_enrollment(self, payload: Dict[str, Any]) -> None:
server_cert = payload.get("server_certificate") server_cert = payload.get("server_certificate")
@@ -649,6 +967,12 @@ class AgentHttpClient:
expires_in = int(payload.get("expires_in") or 900) expires_in = int(payload.get("expires_in") or 900)
if not (guid and access_token and refresh_token): if not (guid and access_token and refresh_token):
raise RuntimeError("Enrollment approval response missing tokens or guid") raise RuntimeError("Enrollment approval response missing tokens or guid")
_log_agent(
"Enrollment approval payload received "
f"guid={guid} access_token_len={len(access_token)} refresh_token_len={len(refresh_token)} "
f"expires_in={expires_in}",
fname="agent.log",
)
self.guid = str(guid).strip() self.guid = str(guid).strip()
self.access_token = access_token.strip() self.access_token = access_token.strip()
self.refresh_token = refresh_token.strip() self.refresh_token = refresh_token.strip()
@@ -663,12 +987,17 @@ class AgentHttpClient:
_update_agent_id_for_guid(self.guid) _update_agent_id_for_guid(self.guid)
except Exception as exc: except Exception as exc:
_log_agent(f"Failed to update agent id after enrollment: {exc}", fname="agent.error.log") _log_agent(f"Failed to update agent id after enrollment: {exc}", fname="agent.error.log")
self._consume_installer_code()
_log_agent(f"Enrollment finalized for guid={self.guid}", fname="agent.log") _log_agent(f"Enrollment finalized for guid={self.guid}", fname="agent.log")
def refresh_access_token(self) -> None: def refresh_access_token(self) -> None:
with self._auth_lock:
self._refresh_access_token_locked()
def _refresh_access_token_locked(self) -> None:
if not self.refresh_token or not self.guid: if not self.refresh_token or not self.guid:
self.clear_tokens() self._clear_tokens_locked()
self.perform_enrollment() self._perform_enrollment_locked()
return return
payload = {"guid": self.guid, "refresh_token": self.refresh_token} payload = {"guid": self.guid, "refresh_token": self.refresh_token}
resp = self.session.post( resp = self.session.post(
@@ -679,8 +1008,8 @@ class AgentHttpClient:
) )
if resp.status_code in (401, 403): if resp.status_code in (401, 403):
_log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log") _log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log")
self.clear_tokens() self._clear_tokens_locked()
self.perform_enrollment() self._perform_enrollment_locked()
return return
resp.raise_for_status() resp.raise_for_status()
data = resp.json() data = resp.json()
@@ -696,6 +1025,10 @@ class AgentHttpClient:
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"}) self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
def clear_tokens(self) -> None: def clear_tokens(self) -> None:
with self._auth_lock:
self._clear_tokens_locked()
def _clear_tokens_locked(self) -> None:
self.key_store.clear_tokens() self.key_store.clear_tokens()
self.access_token = None self.access_token = None
self.refresh_token = None self.refresh_token = None
@@ -712,6 +1045,19 @@ class AgentHttpClient:
except Exception: except Exception:
return "" return ""
def _consume_installer_code(self) -> None:
# Avoid clearing explicit CLI/env overrides; only mutate persisted config.
self._active_installer_code = None
if INSTALLER_CODE_OVERRIDE:
return
try:
if CONFIG.data.get("installer_code"):
CONFIG.data["installer_code"] = ""
CONFIG._write()
_log_agent("Cleared persisted installer code after successful enrollment", fname="agent.log")
except Exception as exc:
_log_agent(f"Failed to clear installer code after enrollment: {exc}", fname="agent.error.log")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# HTTP helpers # HTTP helpers
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -722,6 +1068,16 @@ class AgentHttpClient:
headers = self.auth_headers() headers = self.auth_headers()
response = self.session.post(url, json=payload, headers=headers, timeout=30) response = self.session.post(url, json=payload, headers=headers, timeout=30)
if response.status_code in (401, 403) and require_auth: if response.status_code in (401, 403) and require_auth:
snippet = ""
try:
snippet = response.text[:256]
except Exception:
snippet = "<unavailable>"
_log_agent(
"Authenticated request rejected "
f"path={path} status={response.status_code} body_snippet={snippet}",
fname="agent.error.log",
)
self.clear_tokens() self.clear_tokens()
self.ensure_authenticated() self.ensure_authenticated()
headers = self.auth_headers() headers = self.auth_headers()

View File

@@ -3,6 +3,8 @@
from __future__ import annotations from __future__ import annotations
import base64 import base64
import contextlib
import errno
import hashlib import hashlib
import json import json
import os import os
@@ -23,6 +25,16 @@ try:
except Exception: # pragma: no cover - win32crypt missing except Exception: # pragma: no cover - win32crypt missing
win32crypt = None # type: ignore win32crypt = None # type: ignore
try: # pragma: no cover - Windows only
import msvcrt # type: ignore
except Exception: # pragma: no cover - non-Windows
msvcrt = None # type: ignore
try: # pragma: no cover - POSIX only
import fcntl # type: ignore
except Exception: # pragma: no cover - Windows
fcntl = None # type: ignore
def _ensure_dir(path: str) -> None: def _ensure_dir(path: str) -> None:
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
@@ -36,43 +48,155 @@ def _restrict_permissions(path: str) -> None:
pass pass
class _FileLock:
def __init__(self, path: str) -> None:
self.path = path
self._handle = None
def acquire(self, *, timeout: float = 60.0, poll_interval: float = 0.2) -> None:
directory = os.path.dirname(self.path)
if directory:
os.makedirs(directory, exist_ok=True)
deadline = time.time() + timeout if timeout else None
while True:
handle = open(self.path, "a+b")
try:
self._try_lock(handle)
except OSError as exc:
handle.close()
if not self._is_lock_unavailable(exc):
raise
if deadline and time.time() >= deadline:
raise TimeoutError("Timed out waiting for file lock")
time.sleep(poll_interval)
continue
except Exception:
handle.close()
raise
self._handle = handle
try:
handle.seek(0)
handle.truncate(0)
payload = f"pid={os.getpid()} ts={int(time.time())}\n".encode("utf-8")
handle.write(payload)
handle.flush()
except Exception:
pass
return
def release(self) -> None:
handle = self._handle
if not handle:
return
try:
self._unlock(handle)
finally:
try:
handle.close()
except Exception:
pass
self._handle = None
def _try_lock(self, handle):
if IS_WINDOWS:
if msvcrt is None:
raise OSError(errno.EINVAL, "msvcrt unavailable for locking")
try:
msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1) # type: ignore[attr-defined]
except OSError as exc: # pragma: no cover - platform specific
raise exc
else:
if fcntl is None:
raise OSError(errno.EINVAL, "fcntl unavailable for locking")
fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) # type: ignore[attr-defined]
def _unlock(self, handle):
if IS_WINDOWS:
if msvcrt is None:
return
try:
msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1) # type: ignore[attr-defined]
except OSError: # pragma: no cover - platform specific
pass
else:
if fcntl is None:
return
try:
fcntl.flock(handle.fileno(), fcntl.LOCK_UN) # type: ignore[attr-defined]
except OSError:
pass
@staticmethod
def _is_lock_unavailable(exc: OSError) -> bool:
if hasattr(exc, "winerror"):
return exc.winerror in (32, 33) # type: ignore[attr-defined]
err = exc.errno if hasattr(exc, "errno") else None
return err in (errno.EACCES, errno.EAGAIN, errno.EWOULDBLOCK)
@contextlib.contextmanager
def _locked_file(path: str, *, timeout: float = 60.0):
lock = _FileLock(path)
lock.acquire(timeout=timeout)
try:
yield
finally:
lock.release()
def _protect(data: bytes, *, scope_system: bool) -> bytes: def _protect(data: bytes, *, scope_system: bool) -> bytes:
if not IS_WINDOWS or not win32crypt: if not IS_WINDOWS or not win32crypt:
return data return data
flags = 0 scopes = [scope_system]
# Always include the alternate scope so we can fall back if the preferred
# protection attempt fails (e.g., running under a limited account that
# lacks access to the desired DPAPI scope).
if scope_system: if scope_system:
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4) scopes.append(False)
try: else:
protected = win32crypt.CryptProtectData(data, None, None, None, None, flags) # type: ignore[attr-defined] scopes.append(True)
except Exception: for scope in scopes:
return data flags = 0
blob = protected[1] if scope:
if isinstance(blob, memoryview): flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
return blob.tobytes() try:
if isinstance(blob, bytearray): protected = win32crypt.CryptProtectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
return bytes(blob) except Exception:
if isinstance(blob, bytes): continue
return blob blob = protected[1]
if isinstance(blob, memoryview):
return blob.tobytes()
if isinstance(blob, bytearray):
return bytes(blob)
if isinstance(blob, bytes):
return blob
return data return data
def _unprotect(data: bytes, *, scope_system: bool) -> bytes: def _unprotect(data: bytes, *, scope_system: bool) -> bytes:
if not IS_WINDOWS or not win32crypt: if not IS_WINDOWS or not win32crypt:
return data return data
flags = 0 scopes = [scope_system]
if scope_system: if scope_system:
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4) scopes.append(False)
try: else:
unwrapped = win32crypt.CryptUnprotectData(data, None, None, None, None, flags) # type: ignore[attr-defined] scopes.append(True)
except Exception: for scope in scopes:
return data flags = 0
blob = unwrapped[1] if scope:
if isinstance(blob, memoryview): flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
return blob.tobytes() try:
if isinstance(blob, bytearray): unwrapped = win32crypt.CryptUnprotectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
return bytes(blob) except Exception:
if isinstance(blob, bytes): continue
return blob blob = unwrapped[1]
if isinstance(blob, memoryview):
return blob.tobytes()
if isinstance(blob, bytearray):
return bytes(blob)
if isinstance(blob, bytes):
return blob
return data return data
@@ -105,17 +229,21 @@ class AgentKeyStore:
self._token_meta_path = os.path.join(self.settings_dir, "access.meta.json") self._token_meta_path = os.path.join(self.settings_dir, "access.meta.json")
self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem") self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem")
self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub") self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub")
self._identity_lock_path = os.path.join(self.settings_dir, "identity.lock")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Identity management # Identity management
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def load_or_create_identity(self) -> AgentIdentity: def load_or_create_identity(self) -> AgentIdentity:
if os.path.isfile(self._private_path) and os.path.isfile(self._public_path): with _locked_file(self._identity_lock_path, timeout=120.0):
try: if os.path.isfile(self._private_path) and os.path.isfile(self._public_path):
return self._load_identity() try:
except Exception: return self._load_identity()
pass except Exception:
return self._create_identity() # If loading fails, fall back to regenerating the identity while
# holding the lock so concurrent agents do not thrash the key files.
pass
return self._create_identity()
def _load_identity(self) -> AgentIdentity: def _load_identity(self) -> AgentIdentity:
with open(self._private_path, "rb") as fh: with open(self._private_path, "rb") as fh:
@@ -212,11 +340,23 @@ class AgentKeyStore:
try: try:
with open(self._refresh_token_path, "rb") as fh: with open(self._refresh_token_path, "rb") as fh:
protected = fh.read() protected = fh.read()
raw = _unprotect(protected, scope_system=self.scope_system)
return raw.decode("utf-8")
except Exception: except Exception:
return None return None
# Try both scopes (preferred first) and decode once a UTF-8 payload is recovered.
for scope_first in (self.scope_system, not self.scope_system):
try:
candidate = _unprotect(protected, scope_system=scope_first)
except Exception:
continue
if not candidate:
continue
try:
return candidate.decode("utf-8")
except Exception:
continue
return None
def clear_tokens(self) -> None: def clear_tokens(self) -> None:
for path in (self._access_token_path, self._refresh_token_path, self._token_meta_path): for path in (self._access_token_path, self._refresh_token_path, self._token_meta_path):
try: try:

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import json import json
import time import time
import sqlite3
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Dict, Optional
from flask import Blueprint, jsonify, request, g from flask import Blueprint, jsonify, request, g
@@ -28,10 +29,18 @@ def register(
except Exception: except Exception:
return None return None
def _auth_context():
ctx = getattr(g, "device_auth", None)
if ctx is None:
log("server", f"device auth context missing for {request.path}")
return ctx
@blueprint.route("/api/agent/heartbeat", methods=["POST"]) @blueprint.route("/api/agent/heartbeat", methods=["POST"])
@require_device_auth(auth_manager) @require_device_auth(auth_manager)
def heartbeat(): def heartbeat():
ctx = getattr(g, "device_auth") ctx = _auth_context()
if ctx is None:
return jsonify({"error": "auth_context_missing"}), 500
payload = request.get_json(force=True, silent=True) or {} payload = request.get_json(force=True, silent=True) or {}
now_ts = int(time.time()) now_ts = int(time.time())
@@ -71,14 +80,42 @@ def register(
conn = db_conn_factory() conn = db_conn_factory()
try: try:
cur = conn.cursor() cur = conn.cursor()
columns = ", ".join(f"{col} = ?" for col in updates.keys())
params = list(updates.values()) def _apply_updates() -> int:
params.append(ctx.guid) if not updates:
cur.execute( return 0
f"UPDATE devices SET {columns} WHERE guid = ?", columns = ", ".join(f"{col} = ?" for col in updates.keys())
params, params = list(updates.values())
) params.append(ctx.guid)
if cur.rowcount == 0: cur.execute(
f"UPDATE devices SET {columns} WHERE guid = ?",
params,
)
return cur.rowcount
try:
rowcount = _apply_updates()
except sqlite3.IntegrityError as exc:
if "devices.hostname" in str(exc) and "UNIQUE" in str(exc).upper():
# Another device already claims this hostname; keep the existing
# canonical hostname assigned during enrollment to avoid breaking
# the unique constraint and continue updating the remaining fields.
if "hostname" in updates:
updates.pop("hostname", None)
try:
rowcount = _apply_updates()
except sqlite3.IntegrityError:
raise
else:
log(
"server",
"heartbeat hostname collision ignored for guid="
f"{ctx.guid}",
)
else:
raise
if rowcount == 0:
log("server", f"heartbeat missing device record guid={ctx.guid}") log("server", f"heartbeat missing device record guid={ctx.guid}")
return jsonify({"error": "device_not_registered"}), 404 return jsonify({"error": "device_not_registered"}), 404
conn.commit() conn.commit()
@@ -90,7 +127,9 @@ def register(
@blueprint.route("/api/agent/script/request", methods=["POST"]) @blueprint.route("/api/agent/script/request", methods=["POST"])
@require_device_auth(auth_manager) @require_device_auth(auth_manager)
def script_request(): def script_request():
ctx = getattr(g, "device_auth") ctx = _auth_context()
if ctx is None:
return jsonify({"error": "auth_context_missing"}), 500
if ctx.status != "active": if ctx.status != "active":
return jsonify( return jsonify(
{ {

View File

@@ -54,6 +54,10 @@ def register(
def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float): def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
decision = limiter.check(key, limit, window_s) decision = limiter.check(key, limit, window_s)
if not decision.allowed: if not decision.allowed:
log(
"server",
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
)
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after}) response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
response.status_code = 429 response.status_code = 429
response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}" response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}"
@@ -128,19 +132,63 @@ def register(
def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]: def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
cur.execute( cur.execute(
"SELECT guid, hostname, token_version, status, ssl_key_fingerprint FROM devices WHERE guid = ?", """
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
FROM devices
WHERE guid = ?
""",
(guid,), (guid,),
) )
row = cur.fetchone() row = cur.fetchone()
if row: if row:
keys = ["guid", "hostname", "token_version", "status", "ssl_key_fingerprint"] keys = [
"guid",
"hostname",
"token_version",
"status",
"ssl_key_fingerprint",
"key_added_at",
]
record = dict(zip(keys, row)) record = dict(zip(keys, row))
if not record.get("ssl_key_fingerprint"): stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower()
new_fp = (fingerprint or "").strip().lower()
if not stored_fp and new_fp:
cur.execute( cur.execute(
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?", "UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
(fingerprint, _iso(_now()), guid), (fingerprint, _iso(_now()), guid),
) )
record["ssl_key_fingerprint"] = fingerprint record["ssl_key_fingerprint"] = fingerprint
elif new_fp and stored_fp != new_fp:
now_iso = _iso(_now())
try:
current_version = int(record.get("token_version") or 1)
except Exception:
current_version = 1
new_version = max(current_version + 1, 1)
cur.execute(
"""
UPDATE devices
SET ssl_key_fingerprint = ?,
key_added_at = ?,
token_version = ?,
status = 'active'
WHERE guid = ?
""",
(fingerprint, now_iso, new_version, guid),
)
cur.execute(
"""
UPDATE refresh_tokens
SET revoked_at = ?
WHERE guid = ?
AND revoked_at IS NULL
""",
(now_iso, guid),
)
record["ssl_key_fingerprint"] = fingerprint
record["token_version"] = new_version
record["status"] = "active"
record["key_added_at"] = now_iso
return record return record
resolved_hostname = _normalize_host(hostname, guid, cur) resolved_hostname = _normalize_host(hostname, guid, cur)
@@ -169,6 +217,7 @@ def register(
"token_version": 1, "token_version": 1,
"status": "active", "status": "active",
"ssl_key_fingerprint": fingerprint, "ssl_key_fingerprint": fingerprint,
"key_added_at": key_added_at,
} }
def _hash_refresh_token(token: str) -> str: def _hash_refresh_token(token: str) -> str:
@@ -198,7 +247,7 @@ def register(
@blueprint.route("/api/agent/enroll/request", methods=["POST"]) @blueprint.route("/api/agent/enroll/request", methods=["POST"])
def enrollment_request(): def enrollment_request():
remote = _remote_addr() remote = _remote_addr()
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 10, 60.0) rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0)
if rate_error: if rate_error:
return rate_error return rate_error
@@ -208,32 +257,47 @@ def register(
agent_pubkey_b64 = payload.get("agent_pubkey") agent_pubkey_b64 = payload.get("agent_pubkey")
client_nonce_b64 = payload.get("client_nonce") client_nonce_b64 = payload.get("client_nonce")
log(
"server",
"enrollment request received "
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
)
if not hostname: if not hostname:
log("server", f"enrollment rejected missing_hostname ip={remote}")
return jsonify({"error": "hostname_required"}), 400 return jsonify({"error": "hostname_required"}), 400
if not enrollment_code: if not enrollment_code:
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}")
return jsonify({"error": "enrollment_code_required"}), 400 return jsonify({"error": "enrollment_code_required"}), 400
if not isinstance(agent_pubkey_b64, str): if not isinstance(agent_pubkey_b64, str):
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}")
return jsonify({"error": "agent_pubkey_required"}), 400 return jsonify({"error": "agent_pubkey_required"}), 400
if not isinstance(client_nonce_b64, str): if not isinstance(client_nonce_b64, str):
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}")
return jsonify({"error": "client_nonce_required"}), 400 return jsonify({"error": "client_nonce_required"}), 400
try: try:
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64) agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
except Exception: except Exception:
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}")
return jsonify({"error": "invalid_agent_pubkey"}), 400 return jsonify({"error": "invalid_agent_pubkey"}), 400
if len(agent_pubkey_der) < 10: if len(agent_pubkey_der) < 10:
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}")
return jsonify({"error": "invalid_agent_pubkey"}), 400 return jsonify({"error": "invalid_agent_pubkey"}), 400
try: try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True) client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception: except Exception:
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}")
return jsonify({"error": "invalid_client_nonce"}), 400 return jsonify({"error": "invalid_client_nonce"}), 400
if len(client_nonce_bytes) < 16: if len(client_nonce_bytes) < 16:
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}")
return jsonify({"error": "invalid_client_nonce"}), 400 return jsonify({"error": "invalid_client_nonce"}), 400
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der) fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 3, 60.0) rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0)
if rate_error: if rate_error:
return rate_error return rate_error
@@ -333,21 +397,33 @@ def register(
client_nonce_b64 = payload.get("client_nonce") client_nonce_b64 = payload.get("client_nonce")
proof_sig_b64 = payload.get("proof_sig") proof_sig_b64 = payload.get("proof_sig")
log(
"server",
"enrollment poll received "
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
f" proof_sig_len={len(proof_sig_b64 or '')}",
)
if not isinstance(approval_reference, str) or not approval_reference: if not isinstance(approval_reference, str) or not approval_reference:
log("server", "enrollment poll rejected missing_reference")
return jsonify({"error": "approval_reference_required"}), 400 return jsonify({"error": "approval_reference_required"}), 400
if not isinstance(client_nonce_b64, str): if not isinstance(client_nonce_b64, str):
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}")
return jsonify({"error": "client_nonce_required"}), 400 return jsonify({"error": "client_nonce_required"}), 400
if not isinstance(proof_sig_b64, str): if not isinstance(proof_sig_b64, str):
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}")
return jsonify({"error": "proof_sig_required"}), 400 return jsonify({"error": "proof_sig_required"}), 400
try: try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True) client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception: except Exception:
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}")
return jsonify({"error": "invalid_client_nonce"}), 400 return jsonify({"error": "invalid_client_nonce"}), 400
try: try:
proof_sig = base64.b64decode(proof_sig_b64, validate=True) proof_sig = base64.b64decode(proof_sig_b64, validate=True)
except Exception: except Exception:
log("server", f"enrollment poll invalid_sig ref={approval_reference}")
return jsonify({"error": "invalid_proof_sig"}), 400 return jsonify({"error": "invalid_proof_sig"}), 400
conn = db_conn_factory() conn = db_conn_factory()
@@ -365,6 +441,7 @@ def register(
) )
row = cur.fetchone() row = cur.fetchone()
if not row: if not row:
log("server", f"enrollment poll unknown_reference ref={approval_reference}")
return jsonify({"status": "unknown"}), 404 return jsonify({"status": "unknown"}), 404
( (
@@ -383,11 +460,13 @@ def register(
) = row ) = row
if client_nonce_stored != client_nonce_b64: if client_nonce_stored != client_nonce_b64:
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}")
return jsonify({"error": "nonce_mismatch"}), 400 return jsonify({"error": "nonce_mismatch"}), 400
try: try:
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True) server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
except Exception: except Exception:
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}")
return jsonify({"error": "server_nonce_invalid"}), 400 return jsonify({"error": "server_nonce_invalid"}), 400
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
@@ -395,30 +474,58 @@ def register(
try: try:
public_key = serialization.load_der_public_key(agent_pubkey_der) public_key = serialization.load_der_public_key(agent_pubkey_der)
except Exception: except Exception:
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}")
public_key = None public_key = None
if public_key is None: if public_key is None:
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}")
return jsonify({"error": "agent_pubkey_invalid"}), 400 return jsonify({"error": "agent_pubkey_invalid"}), 400
try: try:
public_key.verify(proof_sig, message) public_key.verify(proof_sig, message)
except Exception: except Exception:
log("server", f"enrollment poll invalid_proof ref={approval_reference}")
return jsonify({"error": "invalid_proof"}), 400 return jsonify({"error": "invalid_proof"}), 400
if status == "pending": if status == "pending":
log(
"server",
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
f" fingerprint={fingerprint[:12]}",
)
return jsonify({"status": "pending", "poll_after_ms": 5000}) return jsonify({"status": "pending", "poll_after_ms": 5000})
if status == "denied": if status == "denied":
log(
"server",
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
)
return jsonify({"status": "denied", "reason": "operator_denied"}) return jsonify({"status": "denied", "reason": "operator_denied"})
if status == "expired": if status == "expired":
log(
"server",
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
)
return jsonify({"status": "expired"}) return jsonify({"status": "expired"})
if status == "completed": if status == "completed":
log(
"server",
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
)
return jsonify({"status": "approved", "detail": "finalized"}) return jsonify({"status": "approved", "detail": "finalized"})
if status != "approved": if status != "approved":
log(
"server",
f"enrollment poll unexpected_status={status} ref={approval_reference}",
)
return jsonify({"status": status or "unknown"}), 400 return jsonify({"status": status or "unknown"}), 400
nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}" nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
if not nonce_cache.consume(nonce_key): if not nonce_cache.consume(nonce_key):
log(
"server",
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
)
return jsonify({"error": "proof_replayed"}), 409 return jsonify({"error": "proof_replayed"}), 409
# Finalize enrollment # Finalize enrollment
@@ -489,3 +596,12 @@ def _load_tls_bundle(path: str) -> str:
return fh.read() return fh.read()
except Exception: except Exception:
return "" return ""
def _mask_code(code: str) -> str:
if not code:
return "<missing>"
trimmed = str(code).strip()
if len(trimmed) <= 6:
return "***"
return f"{trimmed[:3]}***{trimmed[-3:]}"