mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-27 01:41:58 -06:00
Merge pull request #127 from bunny-lab-io:codex/implement-security-features-for-borealis
Harden agent WebSocket TLS verification
This commit is contained in:
@@ -20,7 +20,10 @@ import datetime
|
|||||||
import shutil
|
import shutil
|
||||||
import string
|
import string
|
||||||
import ssl
|
import ssl
|
||||||
from typing import Any, Dict, Optional, List
|
import threading
|
||||||
|
import contextlib
|
||||||
|
import errno
|
||||||
|
from typing import Any, Dict, Optional, List, Callable
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
try:
|
try:
|
||||||
@@ -132,6 +135,133 @@ def _settings_dir():
|
|||||||
return os.path.abspath(os.path.join(os.path.dirname(__file__), 'Settings'))
|
return os.path.abspath(os.path.join(os.path.dirname(__file__), 'Settings'))
|
||||||
|
|
||||||
|
|
||||||
|
class _CrossProcessFileLock:
|
||||||
|
def __init__(self, path: str) -> None:
|
||||||
|
self.path = path
|
||||||
|
self._handle = None
|
||||||
|
|
||||||
|
def acquire(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
timeout: float = 120.0,
|
||||||
|
poll_interval: float = 0.5,
|
||||||
|
on_wait: Optional[Callable[[], None]] = None,
|
||||||
|
) -> None:
|
||||||
|
directory = os.path.dirname(self.path)
|
||||||
|
if directory:
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
deadline = time.time() + timeout if timeout else None
|
||||||
|
last_wait_log = 0.0
|
||||||
|
while True:
|
||||||
|
handle = open(self.path, 'a+b')
|
||||||
|
try:
|
||||||
|
self._try_lock(handle)
|
||||||
|
self._handle = handle
|
||||||
|
try:
|
||||||
|
handle.seek(0)
|
||||||
|
handle.truncate(0)
|
||||||
|
handle.write(f"pid={os.getpid()} ts={int(time.time())}\n".encode('utf-8'))
|
||||||
|
handle.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return
|
||||||
|
except OSError as exc:
|
||||||
|
handle.close()
|
||||||
|
if not self._is_lock_unavailable(exc):
|
||||||
|
raise
|
||||||
|
now = time.time()
|
||||||
|
if on_wait and (now - last_wait_log) >= 2.0:
|
||||||
|
try:
|
||||||
|
on_wait()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
last_wait_log = now
|
||||||
|
if deadline and now >= deadline:
|
||||||
|
raise TimeoutError('Timed out waiting for enrollment lock')
|
||||||
|
time.sleep(poll_interval)
|
||||||
|
except Exception:
|
||||||
|
handle.close()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def release(self) -> None:
|
||||||
|
handle = self._handle
|
||||||
|
if not handle:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self._unlock(handle)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
handle.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._handle = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_lock_unavailable(exc: OSError) -> bool:
|
||||||
|
err = exc.errno
|
||||||
|
winerr = getattr(exc, 'winerror', None)
|
||||||
|
unavailable = {errno.EACCES, errno.EAGAIN, getattr(errno, 'EWOULDBLOCK', errno.EAGAIN)}
|
||||||
|
if err in unavailable:
|
||||||
|
return True
|
||||||
|
if winerr in (32, 33):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _try_lock(handle) -> None:
|
||||||
|
handle.seek(0, os.SEEK_END)
|
||||||
|
if handle.tell() == 0:
|
||||||
|
try:
|
||||||
|
handle.write(b'0')
|
||||||
|
handle.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
handle.seek(0)
|
||||||
|
if os.name == 'nt':
|
||||||
|
import msvcrt # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1)
|
||||||
|
except OSError:
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
import fcntl # type: ignore
|
||||||
|
|
||||||
|
fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unlock(handle) -> None:
|
||||||
|
if os.name == 'nt':
|
||||||
|
import msvcrt # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
import fcntl # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_ENROLLMENT_FILE_LOCK: Optional[_CrossProcessFileLock] = None
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _acquire_enrollment_lock(*, timeout: float = 180.0, on_wait: Optional[Callable[[], None]] = None):
|
||||||
|
global _ENROLLMENT_FILE_LOCK
|
||||||
|
if _ENROLLMENT_FILE_LOCK is None:
|
||||||
|
_ENROLLMENT_FILE_LOCK = _CrossProcessFileLock(os.path.join(_settings_dir(), 'enrollment.lock'))
|
||||||
|
_ENROLLMENT_FILE_LOCK.acquire(timeout=timeout, on_wait=on_wait)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
_ENROLLMENT_FILE_LOCK.release()
|
||||||
|
|
||||||
|
|
||||||
_KEY_STORE_INSTANCE = None
|
_KEY_STORE_INSTANCE = None
|
||||||
|
|
||||||
|
|
||||||
@@ -242,6 +372,18 @@ def _log_agent(message: str, fname: str = 'agent.log'):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _mask_sensitive(value: str, *, prefix: int = 4, suffix: int = 4) -> str:
|
||||||
|
try:
|
||||||
|
if not value:
|
||||||
|
return ''
|
||||||
|
trimmed = value.strip()
|
||||||
|
if len(trimmed) <= prefix + suffix:
|
||||||
|
return '*' * len(trimmed)
|
||||||
|
return f"{trimmed[:prefix]}***{trimmed[-suffix:]}"
|
||||||
|
except Exception:
|
||||||
|
return '***'
|
||||||
|
|
||||||
|
|
||||||
def _decode_base64_text(value):
|
def _decode_base64_text(value):
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
return None
|
return None
|
||||||
@@ -490,14 +632,15 @@ class AgentHttpClient:
|
|||||||
self.identity = IDENTITY
|
self.identity = IDENTITY
|
||||||
self.session = requests.Session()
|
self.session = requests.Session()
|
||||||
self.base_url: Optional[str] = None
|
self.base_url: Optional[str] = None
|
||||||
self.guid: Optional[str] = self.key_store.load_guid()
|
self.guid: Optional[str] = None
|
||||||
self.access_token: Optional[str] = self.key_store.load_access_token()
|
self.access_token: Optional[str] = None
|
||||||
self.refresh_token: Optional[str] = self.key_store.load_refresh_token()
|
self.refresh_token: Optional[str] = None
|
||||||
self.access_expires_at: Optional[int] = self.key_store.get_access_expiry()
|
self.access_expires_at: Optional[int] = None
|
||||||
|
self._auth_lock = threading.RLock()
|
||||||
|
self._active_installer_code: Optional[str] = None
|
||||||
self.refresh_base_url()
|
self.refresh_base_url()
|
||||||
self._configure_verify()
|
self._configure_verify()
|
||||||
if self.access_token:
|
self._reload_tokens_from_disk()
|
||||||
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
|
|
||||||
self.session.headers.setdefault("User-Agent", "Borealis-Agent/secure")
|
self.session.headers.setdefault("User-Agent", "Borealis-Agent/secure")
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -528,6 +671,31 @@ class AgentHttpClient:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _reload_tokens_from_disk(self) -> None:
|
||||||
|
guid = self.key_store.load_guid()
|
||||||
|
access_token = self.key_store.load_access_token()
|
||||||
|
refresh_token = self.key_store.load_refresh_token()
|
||||||
|
access_expiry = self.key_store.get_access_expiry()
|
||||||
|
self.guid = guid if guid else None
|
||||||
|
self.access_token = access_token if access_token else None
|
||||||
|
self.refresh_token = refresh_token if refresh_token else None
|
||||||
|
self.access_expires_at = access_expiry if access_expiry else None
|
||||||
|
if self.access_token:
|
||||||
|
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
|
||||||
|
else:
|
||||||
|
self.session.headers.pop("Authorization", None)
|
||||||
|
try:
|
||||||
|
_log_agent(
|
||||||
|
"Reloaded tokens from disk "
|
||||||
|
f"guid={'yes' if self.guid else 'no'} "
|
||||||
|
f"access={'yes' if self.access_token else 'no'} "
|
||||||
|
f"refresh={'yes' if self.refresh_token else 'no'} "
|
||||||
|
f"expiry={self.access_expires_at}",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def auth_headers(self) -> Dict[str, str]:
|
def auth_headers(self) -> Dict[str, str]:
|
||||||
if self.access_token:
|
if self.access_token:
|
||||||
return {"Authorization": f"Bearer {self.access_token}"}
|
return {"Authorization": f"Bearer {self.access_token}"}
|
||||||
@@ -540,12 +708,28 @@ class AgentHttpClient:
|
|||||||
engine = getattr(client, "eio", None)
|
engine = getattr(client, "eio", None)
|
||||||
if engine is None:
|
if engine is None:
|
||||||
return
|
return
|
||||||
# python-engineio accepts bool, path, or ssl.SSLContext for ssl_verify
|
|
||||||
|
# python-engineio accepts either a boolean or an ``ssl.SSLContext``
|
||||||
|
# for TLS verification. When we have a pinned certificate bundle
|
||||||
|
# on disk, prefer constructing a dedicated context that trusts that
|
||||||
|
# bundle so WebSocket connections succeed even with private CAs.
|
||||||
if isinstance(verify, str) and os.path.isfile(verify):
|
if isinstance(verify, str) and os.path.isfile(verify):
|
||||||
|
try:
|
||||||
|
context = ssl.create_default_context(cafile=verify)
|
||||||
|
context.check_hostname = False
|
||||||
|
except Exception:
|
||||||
|
context = None
|
||||||
|
if context is not None:
|
||||||
|
engine.ssl_context = context
|
||||||
|
engine.ssl_verify = True
|
||||||
|
else:
|
||||||
|
engine.ssl_context = None
|
||||||
engine.ssl_verify = verify
|
engine.ssl_verify = verify
|
||||||
elif verify is False:
|
elif verify is False:
|
||||||
|
engine.ssl_context = None
|
||||||
engine.ssl_verify = False
|
engine.ssl_verify = False
|
||||||
else:
|
else:
|
||||||
|
engine.ssl_context = None
|
||||||
engine.ssl_verify = True
|
engine.ssl_verify = True
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@@ -554,11 +738,15 @@ class AgentHttpClient:
|
|||||||
# Enrollment & token management
|
# Enrollment & token management
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
def ensure_authenticated(self) -> None:
|
def ensure_authenticated(self) -> None:
|
||||||
|
with self._auth_lock:
|
||||||
|
self._ensure_authenticated_locked()
|
||||||
|
|
||||||
|
def _ensure_authenticated_locked(self) -> None:
|
||||||
self.refresh_base_url()
|
self.refresh_base_url()
|
||||||
if not self.guid or not self.refresh_token:
|
if not self.guid or not self.refresh_token:
|
||||||
self.perform_enrollment()
|
self._perform_enrollment_locked()
|
||||||
if not self.access_token or self._token_expiring_soon():
|
if not self.access_token or self._token_expiring_soon():
|
||||||
self.refresh_access_token()
|
self._refresh_access_token_locked()
|
||||||
|
|
||||||
def _token_expiring_soon(self) -> bool:
|
def _token_expiring_soon(self) -> bool:
|
||||||
if not self.access_token:
|
if not self.access_token:
|
||||||
@@ -568,13 +756,57 @@ class AgentHttpClient:
|
|||||||
return (self.access_expires_at - time.time()) < 60
|
return (self.access_expires_at - time.time()) < 60
|
||||||
|
|
||||||
def perform_enrollment(self) -> None:
|
def perform_enrollment(self) -> None:
|
||||||
|
with self._auth_lock:
|
||||||
|
self._perform_enrollment_locked()
|
||||||
|
|
||||||
|
def _perform_enrollment_locked(self) -> None:
|
||||||
|
self._reload_tokens_from_disk()
|
||||||
|
if self.guid and self.refresh_token:
|
||||||
|
return
|
||||||
code = self._resolve_installer_code()
|
code = self._resolve_installer_code()
|
||||||
if not code:
|
if not code:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Installer code is required for enrollment. "
|
"Installer code is required for enrollment. "
|
||||||
"Set BOREALIS_INSTALLER_CODE, pass --installer-code, or update agent_settings.json."
|
"Set BOREALIS_INSTALLER_CODE, pass --installer-code, or update agent_settings.json."
|
||||||
)
|
)
|
||||||
|
self._active_installer_code = code
|
||||||
|
|
||||||
|
wait_state = {"count": 0, "tokens_seen": False}
|
||||||
|
|
||||||
|
def _on_lock_wait() -> None:
|
||||||
|
wait_state["count"] += 1
|
||||||
|
_log_agent(
|
||||||
|
f"Enrollment waiting for shared lock scope={SERVICE_MODE} attempt={wait_state['count']}",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
|
if not wait_state["tokens_seen"]:
|
||||||
|
self._reload_tokens_from_disk()
|
||||||
|
if self.guid and self.refresh_token:
|
||||||
|
wait_state["tokens_seen"] = True
|
||||||
|
_log_agent(
|
||||||
|
"Enrollment credentials detected while waiting for lock; will reuse when available",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with _acquire_enrollment_lock(timeout=180.0, on_wait=_on_lock_wait):
|
||||||
|
self._reload_tokens_from_disk()
|
||||||
|
if self.guid and self.refresh_token:
|
||||||
|
_log_agent(
|
||||||
|
"Enrollment skipped after acquiring lock; credentials already present",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
self.refresh_base_url()
|
self.refresh_base_url()
|
||||||
|
base_url = self.base_url or "https://localhost:5000"
|
||||||
|
code_masked = _mask_sensitive(code)
|
||||||
|
_log_agent(
|
||||||
|
"Enrollment handshake starting "
|
||||||
|
f"base_url={base_url} scope={SERVICE_MODE} "
|
||||||
|
f"fingerprint={SSL_KEY_FINGERPRINT[:16]} installer_code={code_masked}",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
client_nonce = os.urandom(32)
|
client_nonce = os.urandom(32)
|
||||||
payload = {
|
payload = {
|
||||||
"hostname": socket.gethostname(),
|
"hostname": socket.gethostname(),
|
||||||
@@ -583,10 +815,47 @@ class AgentHttpClient:
|
|||||||
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
|
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
|
||||||
}
|
}
|
||||||
request_url = f"{self.base_url}/api/agent/enroll/request"
|
request_url = f"{self.base_url}/api/agent/enroll/request"
|
||||||
_log_agent("Starting enrollment request...", fname="agent.log")
|
_log_agent(
|
||||||
|
"Starting enrollment request... "
|
||||||
|
f"url={request_url} hostname={payload['hostname']} pubkey_prefix={PUBLIC_KEY_B64[:24]}",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
resp = self.session.post(request_url, json=payload, timeout=30)
|
resp = self.session.post(request_url, json=payload, timeout=30)
|
||||||
|
_log_agent(
|
||||||
|
f"Enrollment request HTTP status={resp.status_code} retry_after={resp.headers.get('Retry-After')}"
|
||||||
|
f" body_len={len(resp.content)}",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
|
try:
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
|
except requests.HTTPError:
|
||||||
|
snippet = resp.text[:512] if hasattr(resp, "text") else ""
|
||||||
|
_log_agent(
|
||||||
|
f"Enrollment request failed status={resp.status_code} body_snippet={snippet}",
|
||||||
|
fname="agent.error.log",
|
||||||
|
)
|
||||||
|
if resp.status_code == 400:
|
||||||
|
try:
|
||||||
|
err_payload = resp.json()
|
||||||
|
except Exception:
|
||||||
|
err_payload = {}
|
||||||
|
if (err_payload or {}).get("error") in {"invalid_enrollment_code", "code_consumed"}:
|
||||||
|
self._reload_tokens_from_disk()
|
||||||
|
if self.guid and self.refresh_token:
|
||||||
|
_log_agent(
|
||||||
|
"Enrollment code rejected but existing credentials are present; skipping re-enrollment",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
raise
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
_log_agent(
|
||||||
|
"Enrollment request accepted "
|
||||||
|
f"status={data.get('status')} approval_ref={data.get('approval_reference')} "
|
||||||
|
f"poll_after_ms={data.get('poll_after_ms')}"
|
||||||
|
f" server_cert={'yes' if data.get('server_certificate') else 'no'}",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
if data.get("server_certificate"):
|
if data.get("server_certificate"):
|
||||||
self.key_store.save_server_certificate(data["server_certificate"])
|
self.key_store.save_server_certificate(data["server_certificate"])
|
||||||
self._configure_verify()
|
self._configure_verify()
|
||||||
@@ -604,6 +873,7 @@ class AgentHttpClient:
|
|||||||
raise RuntimeError("Enrollment response missing approval_reference or server_nonce")
|
raise RuntimeError("Enrollment response missing approval_reference or server_nonce")
|
||||||
server_nonce = base64.b64decode(server_nonce_b64)
|
server_nonce = base64.b64decode(server_nonce_b64)
|
||||||
poll_delay = max(int(data.get("poll_after_ms", 3000)) / 1000, 1)
|
poll_delay = max(int(data.get("poll_after_ms", 3000)) / 1000, 1)
|
||||||
|
attempt = 1
|
||||||
while True:
|
while True:
|
||||||
time.sleep(min(poll_delay, 15))
|
time.sleep(min(poll_delay, 15))
|
||||||
signature = self.identity.sign(server_nonce + approval_reference.encode("utf-8") + client_nonce)
|
signature = self.identity.sign(server_nonce + approval_reference.encode("utf-8") + client_nonce)
|
||||||
@@ -612,25 +882,73 @@ class AgentHttpClient:
|
|||||||
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
|
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
|
||||||
"proof_sig": base64.b64encode(signature).decode("ascii"),
|
"proof_sig": base64.b64encode(signature).decode("ascii"),
|
||||||
}
|
}
|
||||||
|
_log_agent(
|
||||||
|
f"Enrollment poll attempt={attempt} ref={approval_reference} delay={poll_delay}s",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
poll_resp = self.session.post(
|
poll_resp = self.session.post(
|
||||||
f"{self.base_url}/api/agent/enroll/poll",
|
f"{self.base_url}/api/agent/enroll/poll",
|
||||||
json=poll_payload,
|
json=poll_payload,
|
||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
|
_log_agent(
|
||||||
|
"Enrollment poll response "
|
||||||
|
f"status_code={poll_resp.status_code} retry_after={poll_resp.headers.get('Retry-After')}"
|
||||||
|
f" body_len={len(poll_resp.content)}",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
|
try:
|
||||||
poll_resp.raise_for_status()
|
poll_resp.raise_for_status()
|
||||||
|
except requests.HTTPError:
|
||||||
|
snippet = poll_resp.text[:512] if hasattr(poll_resp, "text") else ""
|
||||||
|
_log_agent(
|
||||||
|
f"Enrollment poll failed attempt={attempt} status={poll_resp.status_code} "
|
||||||
|
f"body_snippet={snippet}",
|
||||||
|
fname="agent.error.log",
|
||||||
|
)
|
||||||
|
raise
|
||||||
poll_data = poll_resp.json()
|
poll_data = poll_resp.json()
|
||||||
|
_log_agent(
|
||||||
|
f"Enrollment poll decoded attempt={attempt} status={poll_data.get('status')}"
|
||||||
|
f" next_delay={poll_data.get('poll_after_ms')}"
|
||||||
|
f" guid_hint={poll_data.get('guid')}",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
status = poll_data.get("status")
|
status = poll_data.get("status")
|
||||||
if status == "pending":
|
if status == "pending":
|
||||||
poll_delay = max(int(poll_data.get("poll_after_ms", 5000)) / 1000, 1)
|
poll_delay = max(int(poll_data.get("poll_after_ms", 5000)) / 1000, 1)
|
||||||
|
_log_agent(
|
||||||
|
f"Enrollment still pending attempt={attempt} new_delay={poll_delay}s",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
|
attempt += 1
|
||||||
continue
|
continue
|
||||||
if status == "denied":
|
if status == "denied":
|
||||||
|
_log_agent("Enrollment denied by operator", fname="agent.error.log")
|
||||||
raise RuntimeError("Enrollment denied by operator")
|
raise RuntimeError("Enrollment denied by operator")
|
||||||
if status in ("expired", "unknown"):
|
if status in ("expired", "unknown"):
|
||||||
|
_log_agent(
|
||||||
|
f"Enrollment failed status={status} attempt={attempt}",
|
||||||
|
fname="agent.error.log",
|
||||||
|
)
|
||||||
raise RuntimeError(f"Enrollment failed with status={status}")
|
raise RuntimeError(f"Enrollment failed with status={status}")
|
||||||
if status in ("approved", "completed"):
|
if status in ("approved", "completed"):
|
||||||
|
_log_agent(
|
||||||
|
f"Enrollment approved attempt={attempt} ref={approval_reference}",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
self._finalize_enrollment(poll_data)
|
self._finalize_enrollment(poll_data)
|
||||||
break
|
break
|
||||||
raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}")
|
raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}")
|
||||||
|
except TimeoutError:
|
||||||
|
self._reload_tokens_from_disk()
|
||||||
|
if self.guid and self.refresh_token:
|
||||||
|
_log_agent(
|
||||||
|
"Enrollment lock wait timed out but credentials materialized; reusing existing tokens",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
raise
|
||||||
|
|
||||||
def _finalize_enrollment(self, payload: Dict[str, Any]) -> None:
|
def _finalize_enrollment(self, payload: Dict[str, Any]) -> None:
|
||||||
server_cert = payload.get("server_certificate")
|
server_cert = payload.get("server_certificate")
|
||||||
@@ -649,6 +967,12 @@ class AgentHttpClient:
|
|||||||
expires_in = int(payload.get("expires_in") or 900)
|
expires_in = int(payload.get("expires_in") or 900)
|
||||||
if not (guid and access_token and refresh_token):
|
if not (guid and access_token and refresh_token):
|
||||||
raise RuntimeError("Enrollment approval response missing tokens or guid")
|
raise RuntimeError("Enrollment approval response missing tokens or guid")
|
||||||
|
_log_agent(
|
||||||
|
"Enrollment approval payload received "
|
||||||
|
f"guid={guid} access_token_len={len(access_token)} refresh_token_len={len(refresh_token)} "
|
||||||
|
f"expires_in={expires_in}",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
self.guid = str(guid).strip()
|
self.guid = str(guid).strip()
|
||||||
self.access_token = access_token.strip()
|
self.access_token = access_token.strip()
|
||||||
self.refresh_token = refresh_token.strip()
|
self.refresh_token = refresh_token.strip()
|
||||||
@@ -663,12 +987,17 @@ class AgentHttpClient:
|
|||||||
_update_agent_id_for_guid(self.guid)
|
_update_agent_id_for_guid(self.guid)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
_log_agent(f"Failed to update agent id after enrollment: {exc}", fname="agent.error.log")
|
_log_agent(f"Failed to update agent id after enrollment: {exc}", fname="agent.error.log")
|
||||||
|
self._consume_installer_code()
|
||||||
_log_agent(f"Enrollment finalized for guid={self.guid}", fname="agent.log")
|
_log_agent(f"Enrollment finalized for guid={self.guid}", fname="agent.log")
|
||||||
|
|
||||||
def refresh_access_token(self) -> None:
|
def refresh_access_token(self) -> None:
|
||||||
|
with self._auth_lock:
|
||||||
|
self._refresh_access_token_locked()
|
||||||
|
|
||||||
|
def _refresh_access_token_locked(self) -> None:
|
||||||
if not self.refresh_token or not self.guid:
|
if not self.refresh_token or not self.guid:
|
||||||
self.clear_tokens()
|
self._clear_tokens_locked()
|
||||||
self.perform_enrollment()
|
self._perform_enrollment_locked()
|
||||||
return
|
return
|
||||||
payload = {"guid": self.guid, "refresh_token": self.refresh_token}
|
payload = {"guid": self.guid, "refresh_token": self.refresh_token}
|
||||||
resp = self.session.post(
|
resp = self.session.post(
|
||||||
@@ -679,8 +1008,8 @@ class AgentHttpClient:
|
|||||||
)
|
)
|
||||||
if resp.status_code in (401, 403):
|
if resp.status_code in (401, 403):
|
||||||
_log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log")
|
_log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log")
|
||||||
self.clear_tokens()
|
self._clear_tokens_locked()
|
||||||
self.perform_enrollment()
|
self._perform_enrollment_locked()
|
||||||
return
|
return
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
@@ -696,6 +1025,10 @@ class AgentHttpClient:
|
|||||||
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
|
self.session.headers.update({"Authorization": f"Bearer {self.access_token}"})
|
||||||
|
|
||||||
def clear_tokens(self) -> None:
|
def clear_tokens(self) -> None:
|
||||||
|
with self._auth_lock:
|
||||||
|
self._clear_tokens_locked()
|
||||||
|
|
||||||
|
def _clear_tokens_locked(self) -> None:
|
||||||
self.key_store.clear_tokens()
|
self.key_store.clear_tokens()
|
||||||
self.access_token = None
|
self.access_token = None
|
||||||
self.refresh_token = None
|
self.refresh_token = None
|
||||||
@@ -712,6 +1045,19 @@ class AgentHttpClient:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
def _consume_installer_code(self) -> None:
|
||||||
|
# Avoid clearing explicit CLI/env overrides; only mutate persisted config.
|
||||||
|
self._active_installer_code = None
|
||||||
|
if INSTALLER_CODE_OVERRIDE:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
if CONFIG.data.get("installer_code"):
|
||||||
|
CONFIG.data["installer_code"] = ""
|
||||||
|
CONFIG._write()
|
||||||
|
_log_agent("Cleared persisted installer code after successful enrollment", fname="agent.log")
|
||||||
|
except Exception as exc:
|
||||||
|
_log_agent(f"Failed to clear installer code after enrollment: {exc}", fname="agent.error.log")
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# HTTP helpers
|
# HTTP helpers
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -722,6 +1068,16 @@ class AgentHttpClient:
|
|||||||
headers = self.auth_headers()
|
headers = self.auth_headers()
|
||||||
response = self.session.post(url, json=payload, headers=headers, timeout=30)
|
response = self.session.post(url, json=payload, headers=headers, timeout=30)
|
||||||
if response.status_code in (401, 403) and require_auth:
|
if response.status_code in (401, 403) and require_auth:
|
||||||
|
snippet = ""
|
||||||
|
try:
|
||||||
|
snippet = response.text[:256]
|
||||||
|
except Exception:
|
||||||
|
snippet = "<unavailable>"
|
||||||
|
_log_agent(
|
||||||
|
"Authenticated request rejected "
|
||||||
|
f"path={path} status={response.status_code} body_snippet={snippet}",
|
||||||
|
fname="agent.error.log",
|
||||||
|
)
|
||||||
self.clear_tokens()
|
self.clear_tokens()
|
||||||
self.ensure_authenticated()
|
self.ensure_authenticated()
|
||||||
headers = self.auth_headers()
|
headers = self.auth_headers()
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
import contextlib
|
||||||
|
import errno
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -23,6 +25,16 @@ try:
|
|||||||
except Exception: # pragma: no cover - win32crypt missing
|
except Exception: # pragma: no cover - win32crypt missing
|
||||||
win32crypt = None # type: ignore
|
win32crypt = None # type: ignore
|
||||||
|
|
||||||
|
try: # pragma: no cover - Windows only
|
||||||
|
import msvcrt # type: ignore
|
||||||
|
except Exception: # pragma: no cover - non-Windows
|
||||||
|
msvcrt = None # type: ignore
|
||||||
|
|
||||||
|
try: # pragma: no cover - POSIX only
|
||||||
|
import fcntl # type: ignore
|
||||||
|
except Exception: # pragma: no cover - Windows
|
||||||
|
fcntl = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def _ensure_dir(path: str) -> None:
|
def _ensure_dir(path: str) -> None:
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
@@ -36,16 +48,122 @@ def _restrict_permissions(path: str) -> None:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _FileLock:
|
||||||
|
def __init__(self, path: str) -> None:
|
||||||
|
self.path = path
|
||||||
|
self._handle = None
|
||||||
|
|
||||||
|
def acquire(self, *, timeout: float = 60.0, poll_interval: float = 0.2) -> None:
|
||||||
|
directory = os.path.dirname(self.path)
|
||||||
|
if directory:
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
deadline = time.time() + timeout if timeout else None
|
||||||
|
while True:
|
||||||
|
handle = open(self.path, "a+b")
|
||||||
|
try:
|
||||||
|
self._try_lock(handle)
|
||||||
|
except OSError as exc:
|
||||||
|
handle.close()
|
||||||
|
if not self._is_lock_unavailable(exc):
|
||||||
|
raise
|
||||||
|
if deadline and time.time() >= deadline:
|
||||||
|
raise TimeoutError("Timed out waiting for file lock")
|
||||||
|
time.sleep(poll_interval)
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
handle.close()
|
||||||
|
raise
|
||||||
|
|
||||||
|
self._handle = handle
|
||||||
|
try:
|
||||||
|
handle.seek(0)
|
||||||
|
handle.truncate(0)
|
||||||
|
payload = f"pid={os.getpid()} ts={int(time.time())}\n".encode("utf-8")
|
||||||
|
handle.write(payload)
|
||||||
|
handle.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return
|
||||||
|
|
||||||
|
def release(self) -> None:
|
||||||
|
handle = self._handle
|
||||||
|
if not handle:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self._unlock(handle)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
handle.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._handle = None
|
||||||
|
|
||||||
|
def _try_lock(self, handle):
|
||||||
|
if IS_WINDOWS:
|
||||||
|
if msvcrt is None:
|
||||||
|
raise OSError(errno.EINVAL, "msvcrt unavailable for locking")
|
||||||
|
try:
|
||||||
|
msvcrt.locking(handle.fileno(), msvcrt.LK_NBLCK, 1) # type: ignore[attr-defined]
|
||||||
|
except OSError as exc: # pragma: no cover - platform specific
|
||||||
|
raise exc
|
||||||
|
else:
|
||||||
|
if fcntl is None:
|
||||||
|
raise OSError(errno.EINVAL, "fcntl unavailable for locking")
|
||||||
|
fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
def _unlock(self, handle):
|
||||||
|
if IS_WINDOWS:
|
||||||
|
if msvcrt is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1) # type: ignore[attr-defined]
|
||||||
|
except OSError: # pragma: no cover - platform specific
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if fcntl is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
fcntl.flock(handle.fileno(), fcntl.LOCK_UN) # type: ignore[attr-defined]
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_lock_unavailable(exc: OSError) -> bool:
|
||||||
|
if hasattr(exc, "winerror"):
|
||||||
|
return exc.winerror in (32, 33) # type: ignore[attr-defined]
|
||||||
|
err = exc.errno if hasattr(exc, "errno") else None
|
||||||
|
return err in (errno.EACCES, errno.EAGAIN, errno.EWOULDBLOCK)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _locked_file(path: str, *, timeout: float = 60.0):
|
||||||
|
lock = _FileLock(path)
|
||||||
|
lock.acquire(timeout=timeout)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
lock.release()
|
||||||
|
|
||||||
|
|
||||||
def _protect(data: bytes, *, scope_system: bool) -> bytes:
|
def _protect(data: bytes, *, scope_system: bool) -> bytes:
|
||||||
if not IS_WINDOWS or not win32crypt:
|
if not IS_WINDOWS or not win32crypt:
|
||||||
return data
|
return data
|
||||||
flags = 0
|
scopes = [scope_system]
|
||||||
|
# Always include the alternate scope so we can fall back if the preferred
|
||||||
|
# protection attempt fails (e.g., running under a limited account that
|
||||||
|
# lacks access to the desired DPAPI scope).
|
||||||
if scope_system:
|
if scope_system:
|
||||||
|
scopes.append(False)
|
||||||
|
else:
|
||||||
|
scopes.append(True)
|
||||||
|
for scope in scopes:
|
||||||
|
flags = 0
|
||||||
|
if scope:
|
||||||
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
|
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
|
||||||
try:
|
try:
|
||||||
protected = win32crypt.CryptProtectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
|
protected = win32crypt.CryptProtectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
|
||||||
except Exception:
|
except Exception:
|
||||||
return data
|
continue
|
||||||
blob = protected[1]
|
blob = protected[1]
|
||||||
if isinstance(blob, memoryview):
|
if isinstance(blob, memoryview):
|
||||||
return blob.tobytes()
|
return blob.tobytes()
|
||||||
@@ -59,13 +177,19 @@ def _protect(data: bytes, *, scope_system: bool) -> bytes:
|
|||||||
def _unprotect(data: bytes, *, scope_system: bool) -> bytes:
|
def _unprotect(data: bytes, *, scope_system: bool) -> bytes:
|
||||||
if not IS_WINDOWS or not win32crypt:
|
if not IS_WINDOWS or not win32crypt:
|
||||||
return data
|
return data
|
||||||
flags = 0
|
scopes = [scope_system]
|
||||||
if scope_system:
|
if scope_system:
|
||||||
|
scopes.append(False)
|
||||||
|
else:
|
||||||
|
scopes.append(True)
|
||||||
|
for scope in scopes:
|
||||||
|
flags = 0
|
||||||
|
if scope:
|
||||||
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
|
flags = getattr(win32crypt, "CRYPTPROTECT_LOCAL_MACHINE", 0x4)
|
||||||
try:
|
try:
|
||||||
unwrapped = win32crypt.CryptUnprotectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
|
unwrapped = win32crypt.CryptUnprotectData(data, None, None, None, None, flags) # type: ignore[attr-defined]
|
||||||
except Exception:
|
except Exception:
|
||||||
return data
|
continue
|
||||||
blob = unwrapped[1]
|
blob = unwrapped[1]
|
||||||
if isinstance(blob, memoryview):
|
if isinstance(blob, memoryview):
|
||||||
return blob.tobytes()
|
return blob.tobytes()
|
||||||
@@ -105,15 +229,19 @@ class AgentKeyStore:
|
|||||||
self._token_meta_path = os.path.join(self.settings_dir, "access.meta.json")
|
self._token_meta_path = os.path.join(self.settings_dir, "access.meta.json")
|
||||||
self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem")
|
self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem")
|
||||||
self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub")
|
self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub")
|
||||||
|
self._identity_lock_path = os.path.join(self.settings_dir, "identity.lock")
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Identity management
|
# Identity management
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
def load_or_create_identity(self) -> AgentIdentity:
|
def load_or_create_identity(self) -> AgentIdentity:
|
||||||
|
with _locked_file(self._identity_lock_path, timeout=120.0):
|
||||||
if os.path.isfile(self._private_path) and os.path.isfile(self._public_path):
|
if os.path.isfile(self._private_path) and os.path.isfile(self._public_path):
|
||||||
try:
|
try:
|
||||||
return self._load_identity()
|
return self._load_identity()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# If loading fails, fall back to regenerating the identity while
|
||||||
|
# holding the lock so concurrent agents do not thrash the key files.
|
||||||
pass
|
pass
|
||||||
return self._create_identity()
|
return self._create_identity()
|
||||||
|
|
||||||
@@ -212,11 +340,23 @@ class AgentKeyStore:
|
|||||||
try:
|
try:
|
||||||
with open(self._refresh_token_path, "rb") as fh:
|
with open(self._refresh_token_path, "rb") as fh:
|
||||||
protected = fh.read()
|
protected = fh.read()
|
||||||
raw = _unprotect(protected, scope_system=self.scope_system)
|
|
||||||
return raw.decode("utf-8")
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Try both scopes (preferred first) and decode once a UTF-8 payload is recovered.
|
||||||
|
for scope_first in (self.scope_system, not self.scope_system):
|
||||||
|
try:
|
||||||
|
candidate = _unprotect(protected, scope_system=scope_first)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if not candidate:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
return candidate.decode("utf-8")
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
|
||||||
def clear_tokens(self) -> None:
|
def clear_tokens(self) -> None:
|
||||||
for path in (self._access_token_path, self._refresh_token_path, self._token_meta_path):
|
for path in (self._access_token_path, self._refresh_token_path, self._token_meta_path):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
import sqlite3
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
|
||||||
from flask import Blueprint, jsonify, request, g
|
from flask import Blueprint, jsonify, request, g
|
||||||
@@ -28,10 +29,18 @@ def register(
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _auth_context():
|
||||||
|
ctx = getattr(g, "device_auth", None)
|
||||||
|
if ctx is None:
|
||||||
|
log("server", f"device auth context missing for {request.path}")
|
||||||
|
return ctx
|
||||||
|
|
||||||
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
|
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
|
||||||
@require_device_auth(auth_manager)
|
@require_device_auth(auth_manager)
|
||||||
def heartbeat():
|
def heartbeat():
|
||||||
ctx = getattr(g, "device_auth")
|
ctx = _auth_context()
|
||||||
|
if ctx is None:
|
||||||
|
return jsonify({"error": "auth_context_missing"}), 500
|
||||||
payload = request.get_json(force=True, silent=True) or {}
|
payload = request.get_json(force=True, silent=True) or {}
|
||||||
|
|
||||||
now_ts = int(time.time())
|
now_ts = int(time.time())
|
||||||
@@ -71,6 +80,10 @@ def register(
|
|||||||
conn = db_conn_factory()
|
conn = db_conn_factory()
|
||||||
try:
|
try:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
|
|
||||||
|
def _apply_updates() -> int:
|
||||||
|
if not updates:
|
||||||
|
return 0
|
||||||
columns = ", ".join(f"{col} = ?" for col in updates.keys())
|
columns = ", ".join(f"{col} = ?" for col in updates.keys())
|
||||||
params = list(updates.values())
|
params = list(updates.values())
|
||||||
params.append(ctx.guid)
|
params.append(ctx.guid)
|
||||||
@@ -78,7 +91,31 @@ def register(
|
|||||||
f"UPDATE devices SET {columns} WHERE guid = ?",
|
f"UPDATE devices SET {columns} WHERE guid = ?",
|
||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
if cur.rowcount == 0:
|
return cur.rowcount
|
||||||
|
|
||||||
|
try:
|
||||||
|
rowcount = _apply_updates()
|
||||||
|
except sqlite3.IntegrityError as exc:
|
||||||
|
if "devices.hostname" in str(exc) and "UNIQUE" in str(exc).upper():
|
||||||
|
# Another device already claims this hostname; keep the existing
|
||||||
|
# canonical hostname assigned during enrollment to avoid breaking
|
||||||
|
# the unique constraint and continue updating the remaining fields.
|
||||||
|
if "hostname" in updates:
|
||||||
|
updates.pop("hostname", None)
|
||||||
|
try:
|
||||||
|
rowcount = _apply_updates()
|
||||||
|
except sqlite3.IntegrityError:
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
log(
|
||||||
|
"server",
|
||||||
|
"heartbeat hostname collision ignored for guid="
|
||||||
|
f"{ctx.guid}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
if rowcount == 0:
|
||||||
log("server", f"heartbeat missing device record guid={ctx.guid}")
|
log("server", f"heartbeat missing device record guid={ctx.guid}")
|
||||||
return jsonify({"error": "device_not_registered"}), 404
|
return jsonify({"error": "device_not_registered"}), 404
|
||||||
conn.commit()
|
conn.commit()
|
||||||
@@ -90,7 +127,9 @@ def register(
|
|||||||
@blueprint.route("/api/agent/script/request", methods=["POST"])
|
@blueprint.route("/api/agent/script/request", methods=["POST"])
|
||||||
@require_device_auth(auth_manager)
|
@require_device_auth(auth_manager)
|
||||||
def script_request():
|
def script_request():
|
||||||
ctx = getattr(g, "device_auth")
|
ctx = _auth_context()
|
||||||
|
if ctx is None:
|
||||||
|
return jsonify({"error": "auth_context_missing"}), 500
|
||||||
if ctx.status != "active":
|
if ctx.status != "active":
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -54,6 +54,10 @@ def register(
|
|||||||
def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
|
def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
|
||||||
decision = limiter.check(key, limit, window_s)
|
decision = limiter.check(key, limit, window_s)
|
||||||
if not decision.allowed:
|
if not decision.allowed:
|
||||||
|
log(
|
||||||
|
"server",
|
||||||
|
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
|
||||||
|
)
|
||||||
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
|
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
|
||||||
response.status_code = 429
|
response.status_code = 429
|
||||||
response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}"
|
response.headers["Retry-After"] = f"{int(decision.retry_after) or 1}"
|
||||||
@@ -128,19 +132,63 @@ def register(
|
|||||||
|
|
||||||
def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
|
def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"SELECT guid, hostname, token_version, status, ssl_key_fingerprint FROM devices WHERE guid = ?",
|
"""
|
||||||
|
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
|
||||||
|
FROM devices
|
||||||
|
WHERE guid = ?
|
||||||
|
""",
|
||||||
(guid,),
|
(guid,),
|
||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
if row:
|
if row:
|
||||||
keys = ["guid", "hostname", "token_version", "status", "ssl_key_fingerprint"]
|
keys = [
|
||||||
|
"guid",
|
||||||
|
"hostname",
|
||||||
|
"token_version",
|
||||||
|
"status",
|
||||||
|
"ssl_key_fingerprint",
|
||||||
|
"key_added_at",
|
||||||
|
]
|
||||||
record = dict(zip(keys, row))
|
record = dict(zip(keys, row))
|
||||||
if not record.get("ssl_key_fingerprint"):
|
stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower()
|
||||||
|
new_fp = (fingerprint or "").strip().lower()
|
||||||
|
if not stored_fp and new_fp:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
|
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
|
||||||
(fingerprint, _iso(_now()), guid),
|
(fingerprint, _iso(_now()), guid),
|
||||||
)
|
)
|
||||||
record["ssl_key_fingerprint"] = fingerprint
|
record["ssl_key_fingerprint"] = fingerprint
|
||||||
|
elif new_fp and stored_fp != new_fp:
|
||||||
|
now_iso = _iso(_now())
|
||||||
|
try:
|
||||||
|
current_version = int(record.get("token_version") or 1)
|
||||||
|
except Exception:
|
||||||
|
current_version = 1
|
||||||
|
new_version = max(current_version + 1, 1)
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
UPDATE devices
|
||||||
|
SET ssl_key_fingerprint = ?,
|
||||||
|
key_added_at = ?,
|
||||||
|
token_version = ?,
|
||||||
|
status = 'active'
|
||||||
|
WHERE guid = ?
|
||||||
|
""",
|
||||||
|
(fingerprint, now_iso, new_version, guid),
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
UPDATE refresh_tokens
|
||||||
|
SET revoked_at = ?
|
||||||
|
WHERE guid = ?
|
||||||
|
AND revoked_at IS NULL
|
||||||
|
""",
|
||||||
|
(now_iso, guid),
|
||||||
|
)
|
||||||
|
record["ssl_key_fingerprint"] = fingerprint
|
||||||
|
record["token_version"] = new_version
|
||||||
|
record["status"] = "active"
|
||||||
|
record["key_added_at"] = now_iso
|
||||||
return record
|
return record
|
||||||
|
|
||||||
resolved_hostname = _normalize_host(hostname, guid, cur)
|
resolved_hostname = _normalize_host(hostname, guid, cur)
|
||||||
@@ -169,6 +217,7 @@ def register(
|
|||||||
"token_version": 1,
|
"token_version": 1,
|
||||||
"status": "active",
|
"status": "active",
|
||||||
"ssl_key_fingerprint": fingerprint,
|
"ssl_key_fingerprint": fingerprint,
|
||||||
|
"key_added_at": key_added_at,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _hash_refresh_token(token: str) -> str:
|
def _hash_refresh_token(token: str) -> str:
|
||||||
@@ -198,7 +247,7 @@ def register(
|
|||||||
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
|
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
|
||||||
def enrollment_request():
|
def enrollment_request():
|
||||||
remote = _remote_addr()
|
remote = _remote_addr()
|
||||||
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 10, 60.0)
|
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0)
|
||||||
if rate_error:
|
if rate_error:
|
||||||
return rate_error
|
return rate_error
|
||||||
|
|
||||||
@@ -208,32 +257,47 @@ def register(
|
|||||||
agent_pubkey_b64 = payload.get("agent_pubkey")
|
agent_pubkey_b64 = payload.get("agent_pubkey")
|
||||||
client_nonce_b64 = payload.get("client_nonce")
|
client_nonce_b64 = payload.get("client_nonce")
|
||||||
|
|
||||||
|
log(
|
||||||
|
"server",
|
||||||
|
"enrollment request received "
|
||||||
|
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
|
||||||
|
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
|
||||||
|
)
|
||||||
|
|
||||||
if not hostname:
|
if not hostname:
|
||||||
|
log("server", f"enrollment rejected missing_hostname ip={remote}")
|
||||||
return jsonify({"error": "hostname_required"}), 400
|
return jsonify({"error": "hostname_required"}), 400
|
||||||
if not enrollment_code:
|
if not enrollment_code:
|
||||||
|
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}")
|
||||||
return jsonify({"error": "enrollment_code_required"}), 400
|
return jsonify({"error": "enrollment_code_required"}), 400
|
||||||
if not isinstance(agent_pubkey_b64, str):
|
if not isinstance(agent_pubkey_b64, str):
|
||||||
|
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}")
|
||||||
return jsonify({"error": "agent_pubkey_required"}), 400
|
return jsonify({"error": "agent_pubkey_required"}), 400
|
||||||
if not isinstance(client_nonce_b64, str):
|
if not isinstance(client_nonce_b64, str):
|
||||||
|
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}")
|
||||||
return jsonify({"error": "client_nonce_required"}), 400
|
return jsonify({"error": "client_nonce_required"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
|
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}")
|
||||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||||
|
|
||||||
if len(agent_pubkey_der) < 10:
|
if len(agent_pubkey_der) < 10:
|
||||||
|
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}")
|
||||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}")
|
||||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||||
if len(client_nonce_bytes) < 16:
|
if len(client_nonce_bytes) < 16:
|
||||||
|
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}")
|
||||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||||
|
|
||||||
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
|
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
|
||||||
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 3, 60.0)
|
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0)
|
||||||
if rate_error:
|
if rate_error:
|
||||||
return rate_error
|
return rate_error
|
||||||
|
|
||||||
@@ -333,21 +397,33 @@ def register(
|
|||||||
client_nonce_b64 = payload.get("client_nonce")
|
client_nonce_b64 = payload.get("client_nonce")
|
||||||
proof_sig_b64 = payload.get("proof_sig")
|
proof_sig_b64 = payload.get("proof_sig")
|
||||||
|
|
||||||
|
log(
|
||||||
|
"server",
|
||||||
|
"enrollment poll received "
|
||||||
|
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
|
||||||
|
f" proof_sig_len={len(proof_sig_b64 or '')}",
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(approval_reference, str) or not approval_reference:
|
if not isinstance(approval_reference, str) or not approval_reference:
|
||||||
|
log("server", "enrollment poll rejected missing_reference")
|
||||||
return jsonify({"error": "approval_reference_required"}), 400
|
return jsonify({"error": "approval_reference_required"}), 400
|
||||||
if not isinstance(client_nonce_b64, str):
|
if not isinstance(client_nonce_b64, str):
|
||||||
|
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}")
|
||||||
return jsonify({"error": "client_nonce_required"}), 400
|
return jsonify({"error": "client_nonce_required"}), 400
|
||||||
if not isinstance(proof_sig_b64, str):
|
if not isinstance(proof_sig_b64, str):
|
||||||
|
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}")
|
||||||
return jsonify({"error": "proof_sig_required"}), 400
|
return jsonify({"error": "proof_sig_required"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}")
|
||||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
|
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
log("server", f"enrollment poll invalid_sig ref={approval_reference}")
|
||||||
return jsonify({"error": "invalid_proof_sig"}), 400
|
return jsonify({"error": "invalid_proof_sig"}), 400
|
||||||
|
|
||||||
conn = db_conn_factory()
|
conn = db_conn_factory()
|
||||||
@@ -365,6 +441,7 @@ def register(
|
|||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
if not row:
|
if not row:
|
||||||
|
log("server", f"enrollment poll unknown_reference ref={approval_reference}")
|
||||||
return jsonify({"status": "unknown"}), 404
|
return jsonify({"status": "unknown"}), 404
|
||||||
|
|
||||||
(
|
(
|
||||||
@@ -383,11 +460,13 @@ def register(
|
|||||||
) = row
|
) = row
|
||||||
|
|
||||||
if client_nonce_stored != client_nonce_b64:
|
if client_nonce_stored != client_nonce_b64:
|
||||||
|
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}")
|
||||||
return jsonify({"error": "nonce_mismatch"}), 400
|
return jsonify({"error": "nonce_mismatch"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
|
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}")
|
||||||
return jsonify({"error": "server_nonce_invalid"}), 400
|
return jsonify({"error": "server_nonce_invalid"}), 400
|
||||||
|
|
||||||
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
|
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
|
||||||
@@ -395,30 +474,58 @@ def register(
|
|||||||
try:
|
try:
|
||||||
public_key = serialization.load_der_public_key(agent_pubkey_der)
|
public_key = serialization.load_der_public_key(agent_pubkey_der)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}")
|
||||||
public_key = None
|
public_key = None
|
||||||
|
|
||||||
if public_key is None:
|
if public_key is None:
|
||||||
|
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}")
|
||||||
return jsonify({"error": "agent_pubkey_invalid"}), 400
|
return jsonify({"error": "agent_pubkey_invalid"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
public_key.verify(proof_sig, message)
|
public_key.verify(proof_sig, message)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
log("server", f"enrollment poll invalid_proof ref={approval_reference}")
|
||||||
return jsonify({"error": "invalid_proof"}), 400
|
return jsonify({"error": "invalid_proof"}), 400
|
||||||
|
|
||||||
if status == "pending":
|
if status == "pending":
|
||||||
|
log(
|
||||||
|
"server",
|
||||||
|
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
|
||||||
|
f" fingerprint={fingerprint[:12]}",
|
||||||
|
)
|
||||||
return jsonify({"status": "pending", "poll_after_ms": 5000})
|
return jsonify({"status": "pending", "poll_after_ms": 5000})
|
||||||
if status == "denied":
|
if status == "denied":
|
||||||
|
log(
|
||||||
|
"server",
|
||||||
|
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
|
||||||
|
)
|
||||||
return jsonify({"status": "denied", "reason": "operator_denied"})
|
return jsonify({"status": "denied", "reason": "operator_denied"})
|
||||||
if status == "expired":
|
if status == "expired":
|
||||||
|
log(
|
||||||
|
"server",
|
||||||
|
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
|
||||||
|
)
|
||||||
return jsonify({"status": "expired"})
|
return jsonify({"status": "expired"})
|
||||||
if status == "completed":
|
if status == "completed":
|
||||||
|
log(
|
||||||
|
"server",
|
||||||
|
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
|
||||||
|
)
|
||||||
return jsonify({"status": "approved", "detail": "finalized"})
|
return jsonify({"status": "approved", "detail": "finalized"})
|
||||||
|
|
||||||
if status != "approved":
|
if status != "approved":
|
||||||
|
log(
|
||||||
|
"server",
|
||||||
|
f"enrollment poll unexpected_status={status} ref={approval_reference}",
|
||||||
|
)
|
||||||
return jsonify({"status": status or "unknown"}), 400
|
return jsonify({"status": status or "unknown"}), 400
|
||||||
|
|
||||||
nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
|
nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
|
||||||
if not nonce_cache.consume(nonce_key):
|
if not nonce_cache.consume(nonce_key):
|
||||||
|
log(
|
||||||
|
"server",
|
||||||
|
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
|
||||||
|
)
|
||||||
return jsonify({"error": "proof_replayed"}), 409
|
return jsonify({"error": "proof_replayed"}), 409
|
||||||
|
|
||||||
# Finalize enrollment
|
# Finalize enrollment
|
||||||
@@ -489,3 +596,12 @@ def _load_tls_bundle(path: str) -> str:
|
|||||||
return fh.read()
|
return fh.read()
|
||||||
except Exception:
|
except Exception:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _mask_code(code: str) -> str:
|
||||||
|
if not code:
|
||||||
|
return "<missing>"
|
||||||
|
trimmed = str(code).strip()
|
||||||
|
if len(trimmed) <= 6:
|
||||||
|
return "***"
|
||||||
|
return f"{trimmed[:3]}***{trimmed[-3:]}"
|
||||||
|
|||||||
Reference in New Issue
Block a user