mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 17:41:58 -06:00
Merge pull request #127 from bunny-lab-io:codex/implement-security-features-for-borealis
Harden agent WebSocket TLS verification
This commit is contained in:
@@ -20,7 +20,10 @@ import datetime
|
||||
import shutil
|
||||
import string
|
||||
import ssl
|
||||
from typing import Any, Dict, Optional, List
|
||||
import threading
|
||||
import contextlib
|
||||
import errno
|
||||
from typing import Any, Dict, Optional, List, Callable
|
||||
|
||||
import requests
|
||||
try:
|
||||
@@ -132,6 +135,133 @@ def _settings_dir():
|
||||
return os.path.abspath(os.path.join(os.path.dirname(__file__), 'Settings'))
|
||||
|
||||
|
||||
class _CrossProcessFileLock:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
self._handle = None
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
*,
|
||||
timeout: float = 120.0,
|
||||
poll_interval: float = 0.5,
|
||||
on_wait: Optional[Callable[[], None]] = None,
|
||||
) -> None:
|
||||
directory = os.path.dirname(self.path)
|
||||
if directory:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
deadline = time.time() + timeout if timeout else None
|
||||
last_wait_log = 0.0
|
||||
while True:
|
||||
handle = open(self.path, 'a+b')
|
||||
try:
|
||||
self._try_lock(handle)
|
||||
self._handle = handle
|
||||
try:
|
||||
handle.seek(0)
|
||||
handle.truncate(0)
|
||||
handle.write(f"pid={os.getpid()} ts={int(time.time())}\n".encode('utf-8'))
|
||||
handle.flush()
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
except OSError as exc:
|
||||
handle.close()
|
||||
if not self._is_lock_unavailable(exc):
|
||||
raise
|
||||
now = time.time()
|
||||
if on_wait and (now - last_wait_log) >= 2.0:
|
||||
try:
|
||||
on_wait()
|
||||
except Exception:
|
||||
pass
|
||||
last_wait_log = now
|
||||
if deadline and now >= deadline:
|
||||
raise TimeoutError('Timed out waiting for enrollment lock')
|
||||
time.sleep(poll_interval)
|
||||
except Exception:
|
||||
handle.close()
|
||||
raise
|
||||
|
||||
def release(self) -> None:
|
||||
handle = self._handle
|
||||
if not handle:
|
||||
return
|
||||
try:
|
||||
self._unlock(handle)
|
||||
finally:
|
||||
try:
|
||||
handle.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._handle = None
|
||||
|
||||
@staticmethod
|
||||
def _is_lock_unavailable(exc: OSError) -> bool:
|
||||
err = exc.errno
|
||||
winerr = getattr(exc, 'winerror', None)
|
||||
unavailable = {errno.EACCES, errno.EAGAIN, getattr(errno, 'EWOULDBLOCK', errno.EAGAIN)}
|
||||
if err in unavailable:
|
||||
return True
|
||||
if winerr in (32, 33):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _try_lock(handle) -> None:
|
||||
handle.seek(0, os.SEEK_END)
|
||||
if handle.tell() == 0:
|
||||
try:
|
||||
handle.write(b'0')
|
||||
handle.flush()
|
||||
except Exception:
|
||||
pass
|
||||
handle.seek(0)
|
||||
if os.name == 'nt':
|
||||
import msvcrt # type: ignore
|
||||
|
||||
try:
|
||||
msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1)
|
||||
except OSError:
|
||||
raise
|
||||
else:
|
||||
import fcntl # type: ignore
|
||||
|
||||
fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||
|
||||
@staticmethod
|
||||
def _unlock(handle) -> None:
|
||||
if os.name == 'nt':
|
||||
import msvcrt # type: ignore
|
||||
|
||||
try:
|
||||
msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
import fcntl # type: ignore
|
||||
|
||||
try:
|
||||
fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
_ENROLLMENT_FILE_LOCK: Optional[_CrossProcessFileLock] = None
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _acquire_enrollment_lock(*, timeout: float = 180.0, on_wait: Optional[Callable[[], None]] = None):
    """Context manager that holds the shared ``enrollment.lock`` file lock.

    The module-level lock object is created lazily on first use, pointing at
    ``enrollment.lock`` inside the settings directory, and is reused by every
    later call.  ``timeout`` and ``on_wait`` are forwarded to its ``acquire``.
    """
    global _ENROLLMENT_FILE_LOCK
    shared_lock = _ENROLLMENT_FILE_LOCK
    if shared_lock is None:
        lock_path = os.path.join(_settings_dir(), 'enrollment.lock')
        shared_lock = _CrossProcessFileLock(lock_path)
        _ENROLLMENT_FILE_LOCK = shared_lock
    shared_lock.acquire(timeout=timeout, on_wait=on_wait)
    with contextlib.ExitStack() as cleanup:
        # Registered only after a successful acquire, so a failed acquire
        # never triggers a release.
        cleanup.callback(shared_lock.release)
        yield
|
||||
|
||||
|
||||
_KEY_STORE_INSTANCE = None
|
||||
|
||||
|
||||
@@ -242,6 +372,18 @@ def _log_agent(message: str, fname: str = 'agent.log'):
|
||||
pass
|
||||
|
||||
|
||||
def _mask_sensitive(value: str, *, prefix: int = 4, suffix: int = 4) -> str:
|
||||
try:
|
||||
if not value:
|
||||
return ''
|
||||
trimmed = value.strip()
|
||||
if len(trimmed) <= prefix + suffix:
|
||||
return '*' * len(trimmed)
|
||||
return f"{trimmed[:prefix]}***{trimmed[-suffix:]}"
|
||||
except Exception:
|
||||
return '***'
|
||||
|
||||
|
||||
def _decode_base64_text(value):
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
@@ -490,14 +632,15 @@ class AgentHttpClient:
|
||||
self.identity = IDENTITY
|
||||
self.session = requests.Session()
|
||||
self.base_url: Optional[str] = None
|
||||
self.guid: Optional[str] = self.key_store.load_guid()
|
||||
self.access_token: Optional[str] = self.key_store.load_access_token()
|
||||
self.refresh_token: Optional[str] = self.key_store.load_refresh_token()
|
||||
self.access_expires_at: Optional[int] = self.key_store.get_access_expiry()
|
||||
self.guid: Optional[str] = None
|
||||
self.access_token: Optional[str] = None
|
||||
self.refresh_token: Optional[str] = None
|
||||
self.access_expires_at: Optional[int] = None
|
||||
self._auth_lock = threading.RLock()
|
||||
self._active_installer_code: Optional[str] = None
|
||||
self.refresh_base_url()
|
||||
self._configure_verify()
|
||||
if self.access_token:
|
||||
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
|
||||
self._reload_tokens_from_disk()
|
||||
self.session.headers.setdefault("User-Agent", "Borealis-Agent/secure")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -528,6 +671,31 @@ class AgentHttpClient:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _reload_tokens_from_disk(self) -> None:
|
||||
guid = self.key_store.load_guid()
|
||||
access_token = self.key_store.load_access_token()
|
||||
refresh_token = self.key_store.load_refresh_token()
|
||||
access_expiry = self.key_store.get_access_expiry()
|
||||
self.guid = guid if guid else None
|
||||
self.access_token = access_token if access_token else None
|
||||
self.refresh_token = refresh_token if refresh_token else None
|
||||
self.access_expires_at = access_expiry if access_expiry else None
|
||||
if self.access_token:
|
||||
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
|
||||
else:
|
||||
self.session.headers.pop("Authorization", None)
|
||||
try:
|
||||
_log_agent(
|
||||
"Reloaded tokens from disk "
|
||||
f"guid={'yes' if self.guid else 'no'} "
|
||||
f"access={'yes' if self.access_token else 'no'} "
|
||||
f"refresh={'yes' if self.refresh_token else 'no'} "
|
||||
f"expiry={self.access_expires_at}",
|
||||
fname="agent.log",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def auth_headers(self) -> Dict[str, str]:
|
||||
if self.access_token:
|
||||
return {"Authorization": f"Bearer {self.access_token}"}
|
||||
@@ -540,12 +708,28 @@ class AgentHttpClient:
|
||||
engine = getattr(client, "eio", None)
|
||||
if engine is None:
|
||||
return
|
||||
# python-engineio accepts bool, path, or ssl.SSLContext for ssl_verify
|
||||
|
||||
# python-engineio accepts either a boolean or an ``ssl.SSLContext``
|
||||
# for TLS verification. When we have a pinned certificate bundle
|
||||
# on disk, prefer constructing a dedicated context that trusts that
|
||||
# bundle so WebSocket connections succeed even with private CAs.
|
||||
if isinstance(verify, str) and os.path.isfile(verify):
|
||||
engine.ssl_verify = verify
|
||||
try:
|
||||
context = ssl.create_default_context(cafile=verify)
|
||||
context.check_hostname = False
|
||||
except Exception:
|
||||
context = None
|
||||
if context is not None:
|
||||
engine.ssl_context = context
|
||||
engine.ssl_verify = True
|
||||
else:
|
||||
engine.ssl_context = None
|
||||
engine.ssl_verify = verify
|
||||
elif verify is False:
|
||||
engine.ssl_context = None
|
||||
engine.ssl_verify = False
|
||||
else:
|
||||
engine.ssl_context = None
|
||||
engine.ssl_verify = True
|
||||
except Exception:
|
||||
pass
|
||||
@@ -554,11 +738,15 @@ class AgentHttpClient:
|
||||
# Enrollment & token management
|
||||
# ------------------------------------------------------------------
|
||||
def ensure_authenticated(self) -> None:
    """Run ``_ensure_authenticated_locked`` while holding ``self._auth_lock``."""
    with self._auth_lock:
        self._ensure_authenticated_locked()
|
||||
|
||||
def _ensure_authenticated_locked(self) -> None:
|
||||
self.refresh_base_url()
|
||||
if not self.guid or not self.refresh_token:
|
||||
self.perform_enrollment()
|
||||
self._perform_enrollment_locked()
|
||||
if not self.access_token or self._token_expiring_soon():
|
||||
self.refresh_access_token()
|
||||
self._refresh_access_token_locked()
|
||||
|
||||
def _token_expiring_soon(self) -> bool:
|
||||
if not self.access_token:
|
||||
@@ -568,69 +756,199 @@ class AgentHttpClient:
|
||||
return (self.access_expires_at - time.time()) < 60
|
||||
|
||||
def perform_enrollment(self) -> None:
    """Run ``_perform_enrollment_locked`` while holding ``self._auth_lock``."""
    with self._auth_lock:
        self._perform_enrollment_locked()
|
||||
|
||||
def _perform_enrollment_locked(self) -> None:
|
||||
self._reload_tokens_from_disk()
|
||||
if self.guid and self.refresh_token:
|
||||
return
|
||||
code = self._resolve_installer_code()
|
||||
if not code:
|
||||
raise RuntimeError(
|
||||
"Installer code is required for enrollment. "
|
||||
"Set BOREALIS_INSTALLER_CODE, pass --installer-code, or update agent_settings.json."
|
||||
)
|
||||
self.refresh_base_url()
|
||||
client_nonce = os.urandom(32)
|
||||
payload = {
|
||||
"hostname": socket.gethostname(),
|
||||
"enrollment_code": code,
|
||||
"agent_pubkey": PUBLIC_KEY_B64,
|
||||
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
|
||||
}
|
||||
request_url = f"{self.base_url}/api/agent/enroll/request"
|
||||
_log_agent("Starting enrollment request...", fname="agent.log")
|
||||
resp = self.session.post(request_url, json=payload, timeout=30)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if data.get("server_certificate"):
|
||||
self.key_store.save_server_certificate(data["server_certificate"])
|
||||
self._configure_verify()
|
||||
signing_key = data.get("signing_key")
|
||||
if signing_key:
|
||||
try:
|
||||
self.store_server_signing_key(signing_key)
|
||||
except Exception as exc:
|
||||
_log_agent(f'Unable to persist signing key from enrollment handshake: {exc}', fname='agent.error.log')
|
||||
if data.get("status") != "pending":
|
||||
raise RuntimeError(f"Unexpected enrollment status: {data}")
|
||||
approval_reference = data.get("approval_reference")
|
||||
server_nonce_b64 = data.get("server_nonce")
|
||||
if not approval_reference or not server_nonce_b64:
|
||||
raise RuntimeError("Enrollment response missing approval_reference or server_nonce")
|
||||
server_nonce = base64.b64decode(server_nonce_b64)
|
||||
poll_delay = max(int(data.get("poll_after_ms", 3000)) / 1000, 1)
|
||||
while True:
|
||||
time.sleep(min(poll_delay, 15))
|
||||
signature = self.identity.sign(server_nonce + approval_reference.encode("utf-8") + client_nonce)
|
||||
poll_payload = {
|
||||
"approval_reference": approval_reference,
|
||||
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
|
||||
"proof_sig": base64.b64encode(signature).decode("ascii"),
|
||||
}
|
||||
poll_resp = self.session.post(
|
||||
f"{self.base_url}/api/agent/enroll/poll",
|
||||
json=poll_payload,
|
||||
timeout=30,
|
||||
self._active_installer_code = code
|
||||
|
||||
wait_state = {"count": 0, "tokens_seen": False}
|
||||
|
||||
def _on_lock_wait() -> None:
|
||||
wait_state["count"] += 1
|
||||
_log_agent(
|
||||
f"Enrollment waiting for shared lock scope={SERVICE_MODE} attempt={wait_state['count']}",
|
||||
fname="agent.log",
|
||||
)
|
||||
poll_resp.raise_for_status()
|
||||
poll_data = poll_resp.json()
|
||||
status = poll_data.get("status")
|
||||
if status == "pending":
|
||||
poll_delay = max(int(poll_data.get("poll_after_ms", 5000)) / 1000, 1)
|
||||
continue
|
||||
if status == "denied":
|
||||
raise RuntimeError("Enrollment denied by operator")
|
||||
if status in ("expired", "unknown"):
|
||||
raise RuntimeError(f"Enrollment failed with status={status}")
|
||||
if status in ("approved", "completed"):
|
||||
self._finalize_enrollment(poll_data)
|
||||
break
|
||||
raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}")
|
||||
if not wait_state["tokens_seen"]:
|
||||
self._reload_tokens_from_disk()
|
||||
if self.guid and self.refresh_token:
|
||||
wait_state["tokens_seen"] = True
|
||||
_log_agent(
|
||||
"Enrollment credentials detected while waiting for lock; will reuse when available",
|
||||
fname="agent.log",
|
||||
)
|
||||
|
||||
try:
|
||||
with _acquire_enrollment_lock(timeout=180.0, on_wait=_on_lock_wait):
|
||||
self._reload_tokens_from_disk()
|
||||
if self.guid and self.refresh_token:
|
||||
_log_agent(
|
||||
"Enrollment skipped after acquiring lock; credentials already present",
|
||||
fname="agent.log",
|
||||
)
|
||||
return
|
||||
|
||||
self.refresh_base_url()
|
||||
base_url = self.base_url or "https://localhost:5000"
|
||||
code_masked = _mask_sensitive(code)
|
||||
_log_agent(
|
||||
"Enrollment handshake starting "
|
||||
f"base_url={base_url} scope={SERVICE_MODE} "
|
||||
f"fingerprint={SSL_KEY_FINGERPRINT[:16]} installer_code={code_masked}",
|
||||
fname="agent.log",
|
||||
)
|
||||
client_nonce = os.urandom(32)
|
||||
payload = {
|
||||
"hostname": socket.gethostname(),
|
||||
"enrollment_code": code,
|
||||
"agent_pubkey": PUBLIC_KEY_B64,
|
||||
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
|
||||
}
|
||||
request_url = f"{self.base_url}/api/agent/enroll/request"
|
||||
_log_agent(
|
||||
"Starting enrollment request... "
|
||||
f"url={request_url} hostname={payload['hostname']} pubkey_prefix={PUBLIC_KEY_B64[:24]}",
|
||||
fname="agent.log",
|
||||
)
|
||||
resp = self.session.post(request_url, json=payload, timeout=30)
|
||||
_log_agent(
|
||||
f"Enrollment request HTTP status={resp.status_code} retry_after={resp.headers.get('Retry-After')}"
|
||||
f" body_len={len(resp.content)}",
|
||||
fname="agent.log",
|
||||
)
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except requests.HTTPError:
|
||||
snippet = resp.text[:512] if hasattr(resp, "text") else ""
|
||||
_log_agent(
|
||||
f"Enrollment request failed status={resp.status_code} body_snippet={snippet}",
|
||||
fname="agent.error.log",
|
||||
)
|
||||
if resp.status_code == 400:
|
||||
try:
|
||||
err_payload = resp.json()
|
||||
except Exception:
|
||||
err_payload = {}
|
||||
if (err_payload or {}).get("error") in {"invalid_enrollment_code", "code_consumed"}:
|
||||
self._reload_tokens_from_disk()
|
||||
if self.guid and self.refresh_token:
|
||||
_log_agent(
|
||||
"Enrollment code rejected but existing credentials are present; skipping re-enrollment",
|
||||
fname="agent.log",
|
||||
)
|
||||
return
|
||||
raise
|
||||
data = resp.json()
|
||||
_log_agent(
|
||||
"Enrollment request accepted "
|
||||
f"status={data.get('status')} approval_ref={data.get('approval_reference')} "
|
||||
f"poll_after_ms={data.get('poll_after_ms')}"
|
||||
f" server_cert={'yes' if data.get('server_certificate') else 'no'}",
|
||||
fname="agent.log",
|
||||
)
|
||||
if data.get("server_certificate"):
|
||||
self.key_store.save_server_certificate(data["server_certificate"])
|
||||
self._configure_verify()
|
||||
signing_key = data.get("signing_key")
|
||||
if signing_key:
|
||||
try:
|
||||
self.store_server_signing_key(signing_key)
|
||||
except Exception as exc:
|
||||
_log_agent(f'Unable to persist signing key from enrollment handshake: {exc}', fname='agent.error.log')
|
||||
if data.get("status") != "pending":
|
||||
raise RuntimeError(f"Unexpected enrollment status: {data}")
|
||||
approval_reference = data.get("approval_reference")
|
||||
server_nonce_b64 = data.get("server_nonce")
|
||||
if not approval_reference or not server_nonce_b64:
|
||||
raise RuntimeError("Enrollment response missing approval_reference or server_nonce")
|
||||
server_nonce = base64.b64decode(server_nonce_b64)
|
||||
poll_delay = max(int(data.get("poll_after_ms", 3000)) / 1000, 1)
|
||||
attempt = 1
|
||||
while True:
|
||||
time.sleep(min(poll_delay, 15))
|
||||
signature = self.identity.sign(server_nonce + approval_reference.encode("utf-8") + client_nonce)
|
||||
poll_payload = {
|
||||
"approval_reference": approval_reference,
|
||||
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
|
||||
"proof_sig": base64.b64encode(signature).decode("ascii"),
|
||||
}
|
||||
_log_agent(
|
||||
f"Enrollment poll attempt={attempt} ref={approval_reference} delay={poll_delay}s",
|
||||
fname="agent.log",
|
||||
)
|
||||
poll_resp = self.session.post(
|
||||
f"{self.base_url}/api/agent/enroll/poll",
|
||||
json=poll_payload,
|
||||
timeout=30,
|
||||
)
|
||||
_log_agent(
|
||||
"Enrollment poll response "
|
||||
f"status_code={poll_resp.status_code} retry_after={poll_resp.headers.get('Retry-After')}"
|
||||
f" body_len={len(poll_resp.content)}",
|
||||
fname="agent.log",
|
||||
)
|
||||
try:
|
||||
poll_resp.raise_for_status()
|
||||
except requests.HTTPError:
|
||||
snippet = poll_resp.text[:512] if hasattr(poll_resp, "text") else ""
|
||||
_log_agent(
|
||||
f"Enrollment poll failed attempt={attempt} status={poll_resp.status_code} "
|
||||
f"body_snippet={snippet}",
|
||||
fname="agent.error.log",
|
||||
)
|
||||
raise
|
||||
poll_data = poll_resp.json()
|
||||
_log_agent(
|
||||
f"Enrollment poll decoded attempt={attempt} status={poll_data.get('status')}"
|
||||
f" next_delay={poll_data.get('poll_after_ms')}"
|
||||
f" guid_hint={poll_data.get('guid')}",
|
||||
fname="agent.log",
|
||||
)
|
||||
status = poll_data.get("status")
|
||||
if status == "pending":
|
||||
poll_delay = max(int(poll_data.get("poll_after_ms", 5000)) / 1000, 1)
|
||||
_log_agent(
|
||||
f"Enrollment still pending attempt={attempt} new_delay={poll_delay}s",
|
||||
fname="agent.log",
|
||||
)
|
||||
attempt += 1
|
||||
continue
|
||||
if status == "denied":
|
||||
_log_agent("Enrollment denied by operator", fname="agent.error.log")
|
||||
raise RuntimeError("Enrollment denied by operator")
|
||||
if status in ("expired", "unknown"):
|
||||
_log_agent(
|
||||
f"Enrollment failed status={status} attempt={attempt}",
|
||||
fname="agent.error.log",
|
||||
)
|
||||
raise RuntimeError(f"Enrollment failed with status={status}")
|
||||
if status in ("approved", "completed"):
|
||||
_log_agent(
|
||||
f"Enrollment approved attempt={attempt} ref={approval_reference}",
|
||||
fname="agent.log",
|
||||
)
|
||||
self._finalize_enrollment(poll_data)
|
||||
break
|
||||
raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}")
|
||||
except TimeoutError:
|
||||
self._reload_tokens_from_disk()
|
||||
if self.guid and self.refresh_token:
|
||||
_log_agent(
|
||||
"Enrollment lock wait timed out but credentials materialized; reusing existing tokens",
|
||||
fname="agent.log",
|
||||
)
|
||||
return
|
||||
raise
|
||||
|
||||
def _finalize_enrollment(self, payload: Dict[str, Any]) -> None:
|
||||
server_cert = payload.get("server_certificate")
|
||||
@@ -649,6 +967,12 @@ class AgentHttpClient:
|
||||
expires_in = int(payload.get("expires_in") or 900)
|
||||
if not (guid and access_token and refresh_token):
|
||||
raise RuntimeError("Enrollment approval response missing tokens or guid")
|
||||
_log_agent(
|
||||
"Enrollment approval payload received "
|
||||
f"guid={guid} access_token_len={len(access_token)} refresh_token_len={len(refresh_token)} "
|
||||
f"expires_in={expires_in}",
|
||||
fname="agent.log",
|
||||
)
|
||||
self.guid = str(guid).strip()
|
||||
self.access_token = access_token.strip()
|
||||
self.refresh_token = refresh_token.strip()
|
||||
@@ -663,12 +987,17 @@ class AgentHttpClient:
|
||||
_update_agent_id_for_guid(self.guid)
|
||||
except Exception as exc:
|
||||
_log_agent(f"Failed to update agent id after enrollment: {exc}", fname="agent.error.log")
|
||||
self._consume_installer_code()
|
||||
_log_agent(f"Enrollment finalized for guid={self.guid}", fname="agent.log")
|
||||
|
||||
def refresh_access_token(self) -> None:
    """Run ``_refresh_access_token_locked`` while holding ``self._auth_lock``."""
    with self._auth_lock:
        self._refresh_access_token_locked()
|
||||
|
||||
def _refresh_access_token_locked(self) -> None:
|
||||
if not self.refresh_token or not self.guid:
|
||||
self.clear_tokens()
|
||||
self.perform_enrollment()
|
||||
self._clear_tokens_locked()
|
||||
self._perform_enrollment_locked()
|
||||
return
|
||||
payload = {"guid": self.guid, "refresh_token": self.refresh_token}
|
||||
resp = self.session.post(
|
||||
@@ -679,8 +1008,8 @@ class AgentHttpClient:
|
||||
)
|
||||
if resp.status_code in (401, 403):
|
||||
_log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log")
|
||||
self.clear_tokens()
|
||||
self.perform_enrollment()
|
||||
self._clear_tokens_locked()
|
||||
self._perform_enrollment_locked()
|
||||
return
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
@@ -696,6 +1025,10 @@ class AgentHttpClient:
|
||||
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
|
||||
|
||||
def clear_tokens(self) -> None:
    """Run ``_clear_tokens_locked`` while holding ``self._auth_lock``."""
    with self._auth_lock:
        self._clear_tokens_locked()
|
||||
|
||||
def _clear_tokens_locked(self) -> None:
|
||||
self.key_store.clear_tokens()
|
||||
self.access_token = None
|
||||
self.refresh_token = None
|
||||
@@ -712,6 +1045,19 @@ class AgentHttpClient:
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def _consume_installer_code(self) -> None:
    """Forget the installer code once it has been used.

    The in-memory code is always dropped.  The persisted ``installer_code``
    entry is blanked and written back only when no explicit override is
    active, so CLI/env supplied codes are left alone.  Failures while
    updating the persisted config are logged, never raised.
    """
    self._active_installer_code = None
    if not INSTALLER_CODE_OVERRIDE:
        try:
            if not CONFIG.data.get("installer_code"):
                return
            CONFIG.data["installer_code"] = ""
            CONFIG._write()
            _log_agent("Cleared persisted installer code after successful enrollment", fname="agent.log")
        except Exception as exc:
            _log_agent(f"Failed to clear installer code after enrollment: {exc}", fname="agent.error.log")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HTTP helpers
|
||||
# ------------------------------------------------------------------
|
||||
@@ -722,6 +1068,16 @@ class AgentHttpClient:
|
||||
headers = self.auth_headers()
|
||||
response = self.session.post(url, json=payload, headers=headers, timeout=30)
|
||||
if response.status_code in (401, 403) and require_auth:
|
||||
snippet = ""
|
||||
try:
|
||||
snippet = response.text[:256]
|
||||
except Exception:
|
||||
snippet = "<unavailable>"
|
||||
_log_agent(
|
||||
"Authenticated request rejected "
|
||||
f"path={path} status={response.status_code} body_snippet={snippet}",
|
||||
fname="agent.error.log",
|
||||
)
|
||||
self.clear_tokens()
|
||||
self.ensure_authenticated()
|
||||
headers = self.auth_headers()
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import errno
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
@@ -23,6 +25,16 @@ try:
|
||||
except Exception: # pragma: no cover - win32crypt missing
|
||||
win32crypt = None # type: ignore
|
||||
|
||||
try: # pragma: no cover - Windows only
|
||||
import msvcrt # type: ignore
|
||||
except Exception: # pragma: no cover - non-Windows
|
||||
msvcrt = None # type: ignore
|
||||
|
||||
try: # pragma: no cover - POSIX only
|
||||
import fcntl # type: ignore
|
||||
except Exception: # pragma: no cover - Windows
|
||||
fcntl = None # type: ignore
|
||||
|
||||
|
||||
def _ensure_dir(path: str) -> None:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
@@ -36,43 +48,155 @@ def _restrict_permissions(path: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _FileLock:
    """Exclusive, non-blocking advisory lock backed by a file on disk.

    Uses ``msvcrt.locking`` when ``IS_WINDOWS`` is true and ``fcntl.flock``
    otherwise.  While held, the lock file contains a ``pid=<pid> ts=<epoch>``
    line describing the owner (best effort).
    """

    def __init__(self, path: str) -> None:
        self.path = path
        # Open file object while the lock is held; ``None`` otherwise.
        self._handle = None

    def acquire(self, *, timeout: float = 60.0, poll_interval: float = 0.2) -> None:
        """Poll until the lock is obtained.

        Args:
            timeout: Seconds to wait before giving up.  A falsy value
                disables the deadline and waits forever.
            poll_interval: Seconds to sleep between attempts.

        Raises:
            TimeoutError: The deadline passed while the lock was still held
                elsewhere.
            OSError: Locking failed for a reason other than contention.
        """
        directory = os.path.dirname(self.path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        deadline = time.time() + timeout if timeout else None
        while True:
            handle = open(self.path, "a+b")
            try:
                self._try_lock(handle)
            except OSError as exc:
                handle.close()
                if not self._is_lock_unavailable(exc):
                    raise
                if deadline and time.time() >= deadline:
                    raise TimeoutError("Timed out waiting for file lock")
                time.sleep(poll_interval)
                continue
            except Exception:
                handle.close()
                raise

            self._handle = handle
            try:
                handle.seek(0)
                handle.truncate(0)
                payload = f"pid={os.getpid()} ts={int(time.time())}\n".encode("utf-8")
                handle.write(payload)
                handle.flush()
            except Exception:
                # The owner stamp is purely diagnostic; the lock is already
                # held, so a failed write must not abort the acquire.
                pass
            return

    def release(self) -> None:
        """Unlock and close the lock file; a no-op when the lock is not held."""
        handle = self._handle
        if not handle:
            return
        try:
            self._unlock(handle)
        finally:
            try:
                handle.close()
            except Exception:
                pass
            self._handle = None

    def _try_lock(self, handle):
        """Attempt a single non-blocking exclusive lock on *handle*.

        Raises:
            OSError: The lock could not be taken, or the platform locking
                module is unavailable (``errno.EINVAL``).
        """
        if IS_WINDOWS:
            if msvcrt is None:
                raise OSError(errno.EINVAL, "msvcrt unavailable for locking")
            msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1)  # type: ignore[attr-defined]
        else:
            if fcntl is None:
                raise OSError(errno.EINVAL, "fcntl unavailable for locking")
            fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)  # type: ignore[attr-defined]

    def _unlock(self, handle):
        """Release the lock taken by :meth:`_try_lock`, ignoring ``OSError``."""
        if IS_WINDOWS:
            if msvcrt is None:
                return
            try:
                msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1)  # type: ignore[attr-defined]
            except OSError:  # pragma: no cover - platform specific
                pass
        else:
            if fcntl is None:
                return
            try:
                fcntl.flock(handle.fileno(), fcntl.LOCK_UN)  # type: ignore[attr-defined]
            except OSError:
                pass

    @staticmethod
    def _is_lock_unavailable(exc: OSError) -> bool:
        """Return True when *exc* means "held by someone else", not a real failure.

        On Windows every ``OSError`` carries a ``winerror`` attribute, and it
        is ``None`` for the errno-based ``EACCES`` that ``msvcrt.locking``
        raises on contention.  Both fields are therefore checked; trusting
        ``winerror`` alone would turn ordinary contention into a hard error.
        """
        # Windows: 32 = ERROR_SHARING_VIOLATION, 33 = ERROR_LOCK_VIOLATION.
        if getattr(exc, "winerror", None) in (32, 33):
            return True
        err = getattr(exc, "errno", None)
        return err in (errno.EACCES, errno.EAGAIN, errno.EWOULDBLOCK)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _locked_file(path: str, *, timeout: float = 60.0):
    """Context manager that holds a ``_FileLock`` on *path* for its body.

    A new lock object is created per call; ``timeout`` is forwarded to its
    ``acquire``.  The lock is released when the body exits, normally or not.
    """
    guard = _FileLock(path)
    guard.acquire(timeout=timeout)
    with contextlib.ExitStack() as cleanup:
        # Registered only after a successful acquire, so a failed acquire
        # never triggers a release.
        cleanup.callback(guard.release)
        yield
|
||||
|
||||
|
||||
def _protect(data: bytes, *, scope_system: bool) -> bytes:
|
||||
if not IS_WINDOWS or not win32crypt:
|
||||
return data
|
||||
flags = 0
|
||||
scopes = [scope_system]
|
||||
# Always include the alternate scope so we can fall back if the preferred
|
||||
# protection attempt fails (e.g., running under a limited account that
|
||||
# lacks access to the desired DPAPI scope).
|
||||
if scope_system:
|
||||
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
|
||||
try:
|
||||
protected = win32crypt.CryptProtectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return data
|
||||
blob = protected[1]
|
||||
if isinstance(blob, memoryview):
|
||||
return blob.tobytes()
|
||||
if isinstance(blob, bytearray):
|
||||
return bytes(blob)
|
||||
if isinstance(blob, bytes):
|
||||
return blob
|
||||
scopes.append(False)
|
||||
else:
|
||||
scopes.append(True)
|
||||
for scope in scopes:
|
||||
flags = 0
|
||||
if scope:
|
||||
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
|
||||
try:
|
||||
protected = win32crypt.CryptProtectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
continue
|
||||
blob = protected[1]
|
||||
if isinstance(blob, memoryview):
|
||||
return blob.tobytes()
|
||||
if isinstance(blob, bytearray):
|
||||
return bytes(blob)
|
||||
if isinstance(blob, bytes):
|
||||
return blob
|
||||
return data
|
||||
|
||||
|
||||
def _unprotect(data: bytes, *, scope_system: bool) -> bytes:
|
||||
if not IS_WINDOWS or not win32crypt:
|
||||
return data
|
||||
flags = 0
|
||||
scopes = [scope_system]
|
||||
if scope_system:
|
||||
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
|
||||
try:
|
||||
unwrapped = win32crypt.CryptUnprotectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return data
|
||||
blob = unwrapped[1]
|
||||
if isinstance(blob, memoryview):
|
||||
return blob.tobytes()
|
||||
if isinstance(blob, bytearray):
|
||||
return bytes(blob)
|
||||
if isinstance(blob, bytes):
|
||||
return blob
|
||||
scopes.append(False)
|
||||
else:
|
||||
scopes.append(True)
|
||||
for scope in scopes:
|
||||
flags = 0
|
||||
if scope:
|
||||
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
|
||||
try:
|
||||
unwrapped = win32crypt.CryptUnprotectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
continue
|
||||
blob = unwrapped[1]
|
||||
if isinstance(blob, memoryview):
|
||||
return blob.tobytes()
|
||||
if isinstance(blob, bytearray):
|
||||
return bytes(blob)
|
||||
if isinstance(blob, bytes):
|
||||
return blob
|
||||
return data
|
||||
|
||||
|
||||
@@ -105,17 +229,21 @@ class AgentKeyStore:
|
||||
self._token_meta_path = os.path.join(self.settings_dir, "access.meta.json")
|
||||
self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem")
|
||||
self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub")
|
||||
self._identity_lock_path = os.path.join(self.settings_dir, "identity.lock")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Identity management
|
||||
# ------------------------------------------------------------------
|
||||
def load_or_create_identity(self) -> AgentIdentity:
|
||||
if os.path.isfile(self._private_path) and os.path.isfile(self._public_path):
|
||||
try:
|
||||
return self._load_identity()
|
||||
except Exception:
|
||||
pass
|
||||
return self._create_identity()
|
||||
with _locked_file(self._identity_lock_path, timeout=120.0):
|
||||
if os.path.isfile(self._private_path) and os.path.isfile(self._public_path):
|
||||
try:
|
||||
return self._load_identity()
|
||||
except Exception:
|
||||
# If loading fails, fall back to regenerating the identity while
|
||||
# holding the lock so concurrent agents do not thrash the key files.
|
||||
pass
|
||||
return self._create_identity()
|
||||
|
||||
def _load_identity(self) -> AgentIdentity:
|
||||
with open(self._private_path, "rb") as fh:
|
||||
@@ -212,11 +340,23 @@ class AgentKeyStore:
|
||||
try:
|
||||
with open(self._refresh_token_path, "rb") as fh:
|
||||
protected = fh.read()
|
||||
raw = _unprotect(protected, scope_system=self.scope_system)
|
||||
return raw.decode("utf-8")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# Try both scopes (preferred first) and decode once a UTF-8 payload is recovered.
|
||||
for scope_first in (self.scope_system, not self.scope_system):
|
||||
try:
|
||||
candidate = _unprotect(protected, scope_system=scope_first)
|
||||
except Exception:
|
||||
continue
|
||||
if not candidate:
|
||||
continue
|
||||
try:
|
||||
return candidate.decode("utf-8")
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
def clear_tokens(self) -> None:
|
||||
for path in (self._access_token_path, self._refresh_token_path, self._token_meta_path):
|
||||
try:
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import sqlite3
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from flask import Blueprint, jsonify, request, g
|
||||
@@ -28,10 +29,18 @@ def register(
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _auth_context():
    """Return the ``device_auth`` attribute of Flask's ``g``, or ``None``.

    When the attribute is absent (or ``None``), a "server" log entry naming
    the current request path is written; the caller decides how to respond.
    """
    ctx = getattr(g, "device_auth", None)
    if ctx is None:
        log("server", f"device auth context missing for {request.path}")
    return ctx
|
||||
|
||||
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
|
||||
@require_device_auth(auth_manager)
|
||||
def heartbeat():
|
||||
ctx = getattr(g, "device_auth")
|
||||
ctx = _auth_context()
|
||||
if ctx is None:
|
||||
return jsonify({"error": "auth_context_missing"}), 500
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
|
||||
now_ts = int(time.time())
|
||||
@@ -71,14 +80,42 @@ def register(
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
columns = ", ".join(f"{col} = ?" for col in updates.keys())
|
||||
params = list(updates.values())
|
||||
params.append(ctx.guid)
|
||||
cur.execute(
|
||||
f"UPDATE devices SET {columns} WHERE guid = ?",
|
||||
params,
|
||||
)
|
||||
if cur.rowcount == 0:
|
||||
|
||||
def _apply_updates() -> int:
    """Write the pending ``updates`` to the ``devices`` row for ``ctx.guid``.

    Reads ``updates``, ``ctx`` and ``cur`` from the enclosing scope, so a
    caller may mutate ``updates`` (e.g. drop a column) and call this again.

    Returns:
        ``cur.rowcount`` after the UPDATE, or 0 when ``updates`` is empty
        (no statement is executed in that case).
    """
    if not updates:
        return 0
    # Column names are interpolated into the statement; only the values are
    # bound as parameters.
    # NOTE(review): assumes the keys of ``updates`` are chosen by server code,
    # not taken from the request payload - confirm where ``updates`` is built.
    columns = ", ".join(f"{col} = ?" for col in updates)
    params = list(updates.values())
    params.append(ctx.guid)
    cur.execute(
        f"UPDATE devices SET {columns} WHERE guid = ?",
        params,
    )
    return cur.rowcount
|
||||
|
||||
try:
|
||||
rowcount = _apply_updates()
|
||||
except sqlite3.IntegrityError as exc:
|
||||
if "devices.hostname" in str(exc) and "UNIQUE" in str(exc).upper():
|
||||
# Another device already claims this hostname; keep the existing
|
||||
# canonical hostname assigned during enrollment to avoid breaking
|
||||
# the unique constraint and continue updating the remaining fields.
|
||||
if "hostname" in updates:
|
||||
updates.pop("hostname", None)
|
||||
try:
|
||||
rowcount = _apply_updates()
|
||||
except sqlite3.IntegrityError:
|
||||
raise
|
||||
else:
|
||||
log(
|
||||
"server",
|
||||
"heartbeat hostname collision ignored for guid="
|
||||
f"{ctx.guid}",
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
if rowcount == 0:
|
||||
log("server", f"heartbeat missing device record guid={ctx.guid}")
|
||||
return jsonify({"error": "device_not_registered"}), 404
|
||||
conn.commit()
|
||||
@@ -90,7 +127,9 @@ def register(
|
||||
@blueprint.route("/api/agent/script/request", methods=["POST"])
|
||||
@require_device_auth(auth_manager)
|
||||
def script_request():
|
||||
ctx = getattr(g, "device_auth")
|
||||
ctx = _auth_context()
|
||||
if ctx is None:
|
||||
return jsonify({"error": "auth_context_missing"}), 500
|
||||
if ctx.status != "active":
|
||||
return jsonify(
|
||||
{
|
||||
|
||||
@@ -54,6 +54,10 @@ def register(
|
||||
def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
|
||||
decision = limiter.check(key, limit, window_s)
|
||||
if not decision.allowed:
|
||||
log(
|
||||
"server",
|
||||
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
|
||||
)
|
||||
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
|
||||
response.status_code = 429
|
||||
response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}"
|
||||
@@ -128,19 +132,63 @@ def register(
|
||||
|
||||
def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
|
||||
cur.execute(
|
||||
"SELECT guid, hostname, token_version, status, ssl_key_fingerprint FROM devices WHERE guid = ?",
|
||||
"""
|
||||
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
|
||||
FROM devices
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(guid,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
keys = ["guid", "hostname", "token_version", "status", "ssl_key_fingerprint"]
|
||||
keys = [
|
||||
"guid",
|
||||
"hostname",
|
||||
"token_version",
|
||||
"status",
|
||||
"ssl_key_fingerprint",
|
||||
"key_added_at",
|
||||
]
|
||||
record = dict(zip(keys, row))
|
||||
if not record.get("ssl_key_fingerprint"):
|
||||
stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower()
|
||||
new_fp = (fingerprint or "").strip().lower()
|
||||
if not stored_fp and new_fp:
|
||||
cur.execute(
|
||||
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
|
||||
(fingerprint, _iso(_now()), guid),
|
||||
)
|
||||
record["ssl_key_fingerprint"] = fingerprint
|
||||
elif new_fp and stored_fp != new_fp:
|
||||
now_iso = _iso(_now())
|
||||
try:
|
||||
current_version = int(record.get("token_version") or 1)
|
||||
except Exception:
|
||||
current_version = 1
|
||||
new_version = max(current_version + 1, 1)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE devices
|
||||
SET ssl_key_fingerprint = ?,
|
||||
key_added_at = ?,
|
||||
token_version = ?,
|
||||
status = 'active'
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(fingerprint, now_iso, new_version, guid),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = ?
|
||||
WHERE guid = ?
|
||||
AND revoked_at IS NULL
|
||||
""",
|
||||
(now_iso, guid),
|
||||
)
|
||||
record["ssl_key_fingerprint"] = fingerprint
|
||||
record["token_version"] = new_version
|
||||
record["status"] = "active"
|
||||
record["key_added_at"] = now_iso
|
||||
return record
|
||||
|
||||
resolved_hostname = _normalize_host(hostname, guid, cur)
|
||||
@@ -169,6 +217,7 @@ def register(
|
||||
"token_version": 1,
|
||||
"status": "active",
|
||||
"ssl_key_fingerprint": fingerprint,
|
||||
"key_added_at": key_added_at,
|
||||
}
|
||||
|
||||
def _hash_refresh_token(token: str) -> str:
|
||||
@@ -198,7 +247,7 @@ def register(
|
||||
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
|
||||
def enrollment_request():
|
||||
remote = _remote_addr()
|
||||
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 10, 60.0)
|
||||
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0)
|
||||
if rate_error:
|
||||
return rate_error
|
||||
|
||||
@@ -208,32 +257,47 @@ def register(
|
||||
agent_pubkey_b64 = payload.get("agent_pubkey")
|
||||
client_nonce_b64 = payload.get("client_nonce")
|
||||
|
||||
log(
|
||||
"server",
|
||||
"enrollment request received "
|
||||
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
|
||||
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
|
||||
)
|
||||
|
||||
if not hostname:
|
||||
log("server", f"enrollment rejected missing_hostname ip={remote}")
|
||||
return jsonify({"error": "hostname_required"}), 400
|
||||
if not enrollment_code:
|
||||
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}")
|
||||
return jsonify({"error": "enrollment_code_required"}), 400
|
||||
if not isinstance(agent_pubkey_b64, str):
|
||||
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}")
|
||||
return jsonify({"error": "agent_pubkey_required"}), 400
|
||||
if not isinstance(client_nonce_b64, str):
|
||||
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}")
|
||||
return jsonify({"error": "client_nonce_required"}), 400
|
||||
|
||||
try:
|
||||
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
|
||||
except Exception:
|
||||
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}")
|
||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||
|
||||
if len(agent_pubkey_der) < 10:
|
||||
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}")
|
||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||
|
||||
try:
|
||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}")
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
if len(client_nonce_bytes) < 16:
|
||||
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}")
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
|
||||
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
|
||||
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 3, 60.0)
|
||||
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0)
|
||||
if rate_error:
|
||||
return rate_error
|
||||
|
||||
@@ -333,21 +397,33 @@ def register(
|
||||
client_nonce_b64 = payload.get("client_nonce")
|
||||
proof_sig_b64 = payload.get("proof_sig")
|
||||
|
||||
log(
|
||||
"server",
|
||||
"enrollment poll received "
|
||||
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
|
||||
f" proof_sig_len={len(proof_sig_b64 or '')}",
|
||||
)
|
||||
|
||||
if not isinstance(approval_reference, str) or not approval_reference:
|
||||
log("server", "enrollment poll rejected missing_reference")
|
||||
return jsonify({"error": "approval_reference_required"}), 400
|
||||
if not isinstance(client_nonce_b64, str):
|
||||
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}")
|
||||
return jsonify({"error": "client_nonce_required"}), 400
|
||||
if not isinstance(proof_sig_b64, str):
|
||||
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}")
|
||||
return jsonify({"error": "proof_sig_required"}), 400
|
||||
|
||||
try:
|
||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}")
|
||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||
|
||||
try:
|
||||
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_sig ref={approval_reference}")
|
||||
return jsonify({"error": "invalid_proof_sig"}), 400
|
||||
|
||||
conn = db_conn_factory()
|
||||
@@ -365,6 +441,7 @@ def register(
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
log("server", f"enrollment poll unknown_reference ref={approval_reference}")
|
||||
return jsonify({"status": "unknown"}), 404
|
||||
|
||||
(
|
||||
@@ -383,11 +460,13 @@ def register(
|
||||
) = row
|
||||
|
||||
if client_nonce_stored != client_nonce_b64:
|
||||
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}")
|
||||
return jsonify({"error": "nonce_mismatch"}), 400
|
||||
|
||||
try:
|
||||
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}")
|
||||
return jsonify({"error": "server_nonce_invalid"}), 400
|
||||
|
||||
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
|
||||
@@ -395,30 +474,58 @@ def register(
|
||||
try:
|
||||
public_key = serialization.load_der_public_key(agent_pubkey_der)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}")
|
||||
public_key = None
|
||||
|
||||
if public_key is None:
|
||||
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}")
|
||||
return jsonify({"error": "agent_pubkey_invalid"}), 400
|
||||
|
||||
try:
|
||||
public_key.verify(proof_sig, message)
|
||||
except Exception:
|
||||
log("server", f"enrollment poll invalid_proof ref={approval_reference}")
|
||||
return jsonify({"error": "invalid_proof"}), 400
|
||||
|
||||
if status == "pending":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
|
||||
f" fingerprint={fingerprint[:12]}",
|
||||
)
|
||||
return jsonify({"status": "pending", "poll_after_ms": 5000})
|
||||
if status == "denied":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
|
||||
)
|
||||
return jsonify({"status": "denied", "reason": "operator_denied"})
|
||||
if status == "expired":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
|
||||
)
|
||||
return jsonify({"status": "expired"})
|
||||
if status == "completed":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
|
||||
)
|
||||
return jsonify({"status": "approved", "detail": "finalized"})
|
||||
|
||||
if status != "approved":
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll unexpected_status={status} ref={approval_reference}",
|
||||
)
|
||||
return jsonify({"status": status or "unknown"}), 400
|
||||
|
||||
nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
|
||||
if not nonce_cache.consume(nonce_key):
|
||||
log(
|
||||
"server",
|
||||
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
|
||||
)
|
||||
return jsonify({"error": "proof_replayed"}), 409
|
||||
|
||||
# Finalize enrollment
|
||||
@@ -489,3 +596,12 @@ def _load_tls_bundle(path: str) -> str:
|
||||
return fh.read()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _mask_code(code: str) -> str:
|
||||
if not code:
|
||||
return "<missing>"
|
||||
trimmed = str(code).strip()
|
||||
if len(trimmed) <= 6:
|
||||
return "***"
|
||||
return f"{trimmed[:3]}***{trimmed[-3:]}"
|
||||
|
||||
Reference in New Issue
Block a user