Merge pull request #128 from bunny-lab-io:codex/implement-secure-agent-enrollment-features

Handle missing devices without re-enrollment loops
This commit is contained in:
2025-10-18 05:46:33 -06:00
committed by GitHub
13 changed files with 1401 additions and 127 deletions

View File

@@ -23,7 +23,8 @@ import ssl
import threading import threading
import contextlib import contextlib
import errno import errno
from typing import Any, Dict, Optional, List, Callable import re
from typing import Any, Dict, Optional, List, Callable, Tuple
import requests import requests
try: try:
@@ -66,21 +67,120 @@ def _rotate_daily(path: str):
# Early bootstrap logging (goes to agent.log) # Early bootstrap logging (goes to agent.log)
def _bootstrap_log(msg: str): _AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
_AGENT_SCOPE_PATTERN = re.compile(r"\\bscope=([A-Za-z0-9_-]+)", re.IGNORECASE)
def _canonical_scope_value(raw: Optional[str]) -> Optional[str]:
if not raw:
return None
value = "".join(ch for ch in str(raw) if ch.isalnum() or ch in ("_", "-"))
if not value:
return None
return value.upper()
def _agent_context_default() -> Optional[str]:
suffix = globals().get("CONFIG_SUFFIX_CANONICAL")
context = _canonical_scope_value(suffix)
if context:
return context
service = globals().get("SERVICE_MODE_CANONICAL")
context = _canonical_scope_value(service)
if context:
return context
return None
def _infer_agent_scope(message: str, provided_scope: Optional[str] = None) -> Optional[str]:
scope = _canonical_scope_value(provided_scope)
if scope:
return scope
match = _AGENT_SCOPE_PATTERN.search(message or "")
if match:
scope = _canonical_scope_value(match.group(1))
if scope:
return scope
return _agent_context_default()
def _format_agent_log_message(message: str, fname: str, scope: Optional[str] = None) -> str:
context = _infer_agent_scope(message, scope)
if fname == "agent.error.log":
prefix = "[ERROR]"
if context:
prefix = f"{prefix}[CONTEXT-{context}]"
return f"{prefix} {message}"
if context:
return f"[CONTEXT-{context}] {message}"
return f"[INFO] {message}"
def _bootstrap_log(msg: str, *, scope: Optional[str] = None):
try: try:
base = _agent_logs_root() base = _agent_logs_root()
os.makedirs(base, exist_ok=True) os.makedirs(base, exist_ok=True)
path = os.path.join(base, 'agent.log') path = os.path.join(base, 'agent.log')
_rotate_daily(path) _rotate_daily(path)
ts = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') ts = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
line = _format_agent_log_message(msg, 'agent.log', scope)
with open(path, 'a', encoding='utf-8') as fh: with open(path, 'a', encoding='utf-8') as fh:
fh.write(f'[{ts}] {msg}\n') fh.write(f'[{ts}] {line}\n')
except Exception:
pass
def _describe_exception(exc: BaseException) -> str:
try:
primary = f"{exc.__class__.__name__}: {exc}"
except Exception:
primary = repr(exc)
parts = [primary]
try:
cause = getattr(exc, "__cause__", None)
if cause and cause is not exc:
parts.append(f"cause={cause.__class__.__name__}: {cause}")
except Exception:
pass
try:
context = getattr(exc, "__context__", None)
if context and context is not exc and context is not getattr(exc, "__cause__", None):
parts.append(f"context={context.__class__.__name__}: {context}")
except Exception:
pass
try:
args = getattr(exc, "args", None)
if isinstance(args, tuple) and len(args) > 1:
parts.append(f"args={args!r}")
except Exception:
pass
try:
details = getattr(exc, "__dict__", None)
if isinstance(details, dict):
# Capture noteworthy nested attributes such as os_error/errno to help diagnose
# connection failures that collapse into generic ConnectionError wrappers.
for key in ("os_error", "errno", "code", "status"):
if key in details and details[key]:
parts.append(f"{key}={details[key]!r}")
except Exception:
pass
return "; ".join(part for part in parts if part)
def _log_exception_trace(prefix: str) -> None:
try:
tb = traceback.format_exc()
if not tb:
return
for line in tb.rstrip().splitlines():
_log_agent(f"{prefix} trace: {line}", fname="agent.error.log")
except Exception: except Exception:
pass pass
# Headless/service mode flag (skip Qt and interactive UI) # Headless/service mode flag (skip Qt and interactive UI)
SYSTEM_SERVICE_MODE = ('--system-service' in sys.argv) or (os.environ.get('BOREALIS_AGENT_MODE') == 'system') SYSTEM_SERVICE_MODE = ('--system-service' in sys.argv) or (os.environ.get('BOREALIS_AGENT_MODE') == 'system')
SERVICE_MODE = 'system' if SYSTEM_SERVICE_MODE else 'currentuser' SERVICE_MODE = 'system' if SYSTEM_SERVICE_MODE else 'currentuser'
SERVICE_MODE_CANONICAL = SERVICE_MODE.upper()
_bootstrap_log(f'agent.py loaded; SYSTEM_SERVICE_MODE={SYSTEM_SERVICE_MODE}; argv={sys.argv!r}') _bootstrap_log(f'agent.py loaded; SYSTEM_SERVICE_MODE={SYSTEM_SERVICE_MODE}; argv={sys.argv!r}')
def _argv_get(flag: str, default: str = None): def _argv_get(flag: str, default: str = None):
try: try:
@@ -359,15 +459,16 @@ def _find_project_root():
return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
# Simple file logger under Logs/Agent # Simple file logger under Logs/Agent
def _log_agent(message: str, fname: str = 'agent.log'): def _log_agent(message: str, fname: str = 'agent.log', *, scope: Optional[str] = None):
try: try:
log_dir = _agent_logs_root() log_dir = _agent_logs_root()
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
ts = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') ts = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
path = os.path.join(log_dir, fname) path = os.path.join(log_dir, fname)
_rotate_daily(path) _rotate_daily(path)
line = _format_agent_log_message(message, fname, scope)
with open(path, 'a', encoding='utf-8') as fh: with open(path, 'a', encoding='utf-8') as fh:
fh.write(f'[{ts}] {message}\n') fh.write(f'[{ts}] {line}\n')
except Exception: except Exception:
pass pass
@@ -384,6 +485,31 @@ def _mask_sensitive(value: str, *, prefix: int = 4, suffix: int = 4) -> str:
return '***' return '***'
def _format_debug_pairs(pairs: Dict[str, Any]) -> str:
try:
parts = []
for key, value in pairs.items():
parts.append(f"{key}={value!r}")
return ", ".join(parts)
except Exception:
return repr(pairs)
def _summarize_headers(headers: Dict[str, str]) -> str:
try:
rendered: List[str] = []
for key, value in headers.items():
lowered = key.lower()
display = value
if lowered == 'authorization':
token = value.split()[-1] if value and ' ' in value else value
display = f"Bearer {_mask_sensitive(token)}"
rendered.append(f"{key}={display}")
return ", ".join(rendered)
except Exception:
return '<unavailable>'
def _decode_base64_text(value): def _decode_base64_text(value):
if not isinstance(value, str): if not isinstance(value, str):
return None return None
@@ -571,6 +697,57 @@ DEFAULT_CONFIG = {
"installer_code": "" "installer_code": ""
} }
def _load_installer_code_from_file(path: str) -> str:
try:
with open(path, "r", encoding="utf-8") as fh:
data = json.load(fh)
except Exception:
return ""
value = data.get("installer_code") if isinstance(data, dict) else ""
if isinstance(value, str):
return value.strip()
return ""
def _fallback_installer_code(current_path: str) -> str:
settings_dir = os.path.dirname(current_path)
candidates: List[str] = []
suffix = CONFIG_SUFFIX_CANONICAL
sibling_map = {
"SYSTEM": "agent_settings_CURRENTUSER.json",
"CURRENTUSER": "agent_settings_SYSTEM.json",
}
sibling_name = sibling_map.get(suffix or "")
if sibling_name:
candidates.append(os.path.join(settings_dir, sibling_name))
# Prefer the shared/base config next
candidates.append(os.path.join(settings_dir, "agent_settings.json"))
# Legacy location fallback
try:
project_root = _find_project_root()
legacy_dir = os.path.join(project_root, "Agent", "Settings")
if sibling_name:
candidates.append(os.path.join(legacy_dir, sibling_name))
candidates.append(os.path.join(legacy_dir, "agent_settings.json"))
except Exception:
pass
current_abspath = os.path.abspath(current_path)
for candidate in candidates:
if not candidate:
continue
try:
candidate_path = os.path.abspath(candidate)
except Exception:
continue
if candidate_path == current_abspath or not os.path.isfile(candidate_path):
continue
code = _load_installer_code_from_file(candidate_path)
if code:
return code
return ""
class ConfigManager: class ConfigManager:
def __init__(self, path): def __init__(self, path):
self.path = path self.path = path
@@ -631,6 +808,9 @@ class AgentHttpClient:
self.key_store = _key_store() self.key_store = _key_store()
self.identity = IDENTITY self.identity = IDENTITY
self.session = requests.Session() self.session = requests.Session()
context_label = _agent_context_default()
if context_label:
self.session.headers.setdefault(_AGENT_CONTEXT_HEADER, context_label)
self.base_url: Optional[str] = None self.base_url: Optional[str] = None
self.guid: Optional[str] = None self.guid: Optional[str] = None
self.access_token: Optional[str] = None self.access_token: Optional[str] = None
@@ -697,42 +877,115 @@ class AgentHttpClient:
pass pass
def auth_headers(self) -> Dict[str, str]: def auth_headers(self) -> Dict[str, str]:
headers: Dict[str, str] = {}
if self.access_token: if self.access_token:
return {"Authorization": f"Bearer {self.access_token}"} headers["Authorization"] = f"Bearer {self.access_token}"
return {} context_label = _agent_context_default()
if context_label:
headers[_AGENT_CONTEXT_HEADER] = context_label
return headers
def configure_socketio(self, client: "socketio.AsyncClient") -> None: def configure_socketio(self, client: "socketio.AsyncClient") -> None:
"""Align the Socket.IO engine's TLS verification with the REST client.""" """Align the Socket.IO engine's TLS verification with the REST client."""
try: try:
verify = getattr(self.session, "verify", True) verify = getattr(self.session, "verify", True)
engine = getattr(client, "eio", None) engine = getattr(client, "eio", None)
if engine is None: if engine is None:
_log_agent(
"SocketIO TLS alignment skipped; AsyncClient.eio missing",
fname="agent.error.log",
)
return return
# python-engineio accepts either a boolean or an ``ssl.SSLContext`` http_iface = getattr(engine, "http", None)
# for TLS verification. When we have a pinned certificate bundle
# on disk, prefer constructing a dedicated context that trusts that debug_info = {
# bundle so WebSocket connections succeed even with private CAs. "verify_type": type(verify).__name__,
"verify_value": verify,
"engine_type": type(engine).__name__,
"http_iface_present": http_iface is not None,
}
_log_agent(
f"SocketIO TLS alignment start: {_format_debug_pairs(debug_info)}",
fname="agent.log",
)
def _set_attr(target: Any, name: str, value: Any) -> None:
if target is None:
return
try:
setattr(target, name, value)
except Exception:
pass
def _reset_cached_session() -> None:
if http_iface is None:
return
try:
if hasattr(http_iface, "session"):
setattr(http_iface, "session", None)
except Exception:
pass
context = None
if isinstance(verify, str) and os.path.isfile(verify): if isinstance(verify, str) and os.path.isfile(verify):
try: try:
context = ssl.create_default_context(cafile=verify) # Mirror Requests' certificate handling by starting from a
# default client context (which pre-loads the system
# certificate stores) and then layering the pinned
# certificate bundle on top. This matches the REST client
# behaviour and ensures self-signed leaf certificates work
# the same way for Socket.IO handshakes.
context = ssl.create_default_context()
context.check_hostname = False context.check_hostname = False
context.load_verify_locations(cafile=verify)
_log_agent(
f"SocketIO TLS alignment created SSLContext from cafile={verify}",
fname="agent.log",
)
except Exception: except Exception:
context = None context = None
if context is not None: _log_agent(
engine.ssl_context = context f"SocketIO TLS alignment failed to build context from cafile={verify}",
engine.ssl_verify = True fname="agent.error.log",
else: )
engine.ssl_context = None
engine.ssl_verify = verify if context is not None:
elif verify is False: _set_attr(engine, "ssl_context", context)
engine.ssl_context = None _set_attr(engine, "ssl_verify", True)
engine.ssl_verify = False _set_attr(engine, "verify_ssl", True)
else: _set_attr(http_iface, "ssl_context", context)
engine.ssl_context = None _set_attr(http_iface, "ssl_verify", True)
engine.ssl_verify = True _set_attr(http_iface, "verify_ssl", True)
_reset_cached_session()
_log_agent(
"SocketIO TLS alignment applied dedicated SSLContext to engine/http",
fname="agent.log",
)
return
# Fall back to boolean verification flags when we either do not
# have a pinned certificate bundle or failed to build a dedicated
# context for it.
verify_flag = False if verify is False else True
_set_attr(engine, "ssl_context", None)
_set_attr(engine, "ssl_verify", verify_flag)
_set_attr(engine, "verify_ssl", verify_flag)
_set_attr(http_iface, "ssl_context", None)
_set_attr(http_iface, "ssl_verify", verify_flag)
_set_attr(http_iface, "verify_ssl", verify_flag)
_reset_cached_session()
_log_agent(
f"SocketIO TLS alignment fallback verify_flag={verify_flag}",
fname="agent.log",
)
except Exception: except Exception:
pass _log_agent(
"SocketIO TLS alignment encountered unexpected error",
fname="agent.error.log",
)
_log_exception_trace("configure_socketio")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Enrollment & token management # Enrollment & token management
@@ -1007,10 +1260,22 @@ class AgentHttpClient:
timeout=20, timeout=20,
) )
if resp.status_code in (401, 403): if resp.status_code in (401, 403):
_log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log") error_code, snippet = self._error_details(resp)
self._clear_tokens_locked() if resp.status_code == 401 and self._should_retry_auth(resp.status_code, error_code):
self._perform_enrollment_locked() _log_agent(
return "Refresh token rejected; attempting re-enrollment"
f" error={error_code or '<unknown>'}",
fname="agent.error.log",
)
self._clear_tokens_locked()
self._perform_enrollment_locked()
return
_log_agent(
"Refresh token request forbidden "
f"status={resp.status_code} error={error_code or '<unknown>'}"
f" body_snippet={snippet}",
fname="agent.error.log",
)
resp.raise_for_status() resp.raise_for_status()
data = resp.json() data = resp.json()
access_token = data.get("access_token") access_token = data.get("access_token")
@@ -1036,14 +1301,79 @@ class AgentHttpClient:
self.guid = self.key_store.load_guid() self.guid = self.key_store.load_guid()
self.session.headers.pop("Authorization", None) self.session.headers.pop("Authorization", None)
def _error_details(self, response: requests.Response) -> Tuple[Optional[str], str]:
error_code: Optional[str] = None
snippet = ""
try:
snippet = response.text[:256]
except Exception:
snippet = "<unavailable>"
try:
data = response.json()
except Exception:
data = None
if isinstance(data, dict):
for key in ("error", "code", "status"):
value = data.get(key)
if isinstance(value, str) and value.strip():
error_code = value.strip()
break
return error_code, snippet
def _should_retry_auth(self, status_code: int, error_code: Optional[str]) -> bool:
if status_code == 401:
return True
retryable_forbidden = {"fingerprint_mismatch"}
if status_code == 403 and error_code in retryable_forbidden:
return True
return False
def _resolve_installer_code(self) -> str: def _resolve_installer_code(self) -> str:
if INSTALLER_CODE_OVERRIDE: if INSTALLER_CODE_OVERRIDE:
return INSTALLER_CODE_OVERRIDE code = INSTALLER_CODE_OVERRIDE.strip()
if code:
try:
self.key_store.cache_installer_code(code, consumer=SERVICE_MODE_CANONICAL)
except Exception:
pass
return code
code = ""
try: try:
code = (CONFIG.data.get("installer_code") or "").strip() code = (CONFIG.data.get("installer_code") or "").strip()
return code
except Exception: except Exception:
return "" code = ""
if code:
try:
self.key_store.cache_installer_code(code, consumer=SERVICE_MODE_CANONICAL)
except Exception:
pass
return code
try:
cached = self.key_store.load_cached_installer_code()
except Exception:
cached = None
if cached:
try:
self.key_store.cache_installer_code(cached, consumer=SERVICE_MODE_CANONICAL)
except Exception:
pass
return cached
fallback = _fallback_installer_code(CONFIG.path)
if fallback:
try:
CONFIG.data["installer_code"] = fallback
CONFIG._write()
_log_agent(
"Adopted installer code from sibling configuration", fname="agent.log"
)
except Exception:
pass
try:
self.key_store.cache_installer_code(fallback, consumer=SERVICE_MODE_CANONICAL)
except Exception:
pass
return fallback
return ""
def _consume_installer_code(self) -> None: def _consume_installer_code(self) -> None:
# Avoid clearing explicit CLI/env overrides; only mutate persisted config. # Avoid clearing explicit CLI/env overrides; only mutate persisted config.
@@ -1057,6 +1387,13 @@ class AgentHttpClient:
_log_agent("Cleared persisted installer code after successful enrollment", fname="agent.log") _log_agent("Cleared persisted installer code after successful enrollment", fname="agent.log")
except Exception as exc: except Exception as exc:
_log_agent(f"Failed to clear installer code after enrollment: {exc}", fname="agent.error.log") _log_agent(f"Failed to clear installer code after enrollment: {exc}", fname="agent.error.log")
try:
self.key_store.mark_installer_code_consumed(SERVICE_MODE_CANONICAL)
except Exception as exc:
_log_agent(
f"Failed to update shared installer code cache: {exc}",
fname="agent.error.log",
)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# HTTP helpers # HTTP helpers
@@ -1068,20 +1405,19 @@ class AgentHttpClient:
headers = self.auth_headers() headers = self.auth_headers()
response = self.session.post(url, json=payload, headers=headers, timeout=30) response = self.session.post(url, json=payload, headers=headers, timeout=30)
if response.status_code in (401, 403) and require_auth: if response.status_code in (401, 403) and require_auth:
snippet = "" error_code, snippet = self._error_details(response)
try: if self._should_retry_auth(response.status_code, error_code):
snippet = response.text[:256] self.clear_tokens()
except Exception: self.ensure_authenticated()
snippet = "<unavailable>" headers = self.auth_headers()
_log_agent( response = self.session.post(url, json=payload, headers=headers, timeout=30)
"Authenticated request rejected " else:
f"path={path} status={response.status_code} body_snippet={snippet}", _log_agent(
fname="agent.error.log", "Authenticated request rejected "
) f"path={path} status={response.status_code} error={error_code or '<unknown>'}"
self.clear_tokens() f" body_snippet={snippet}",
self.ensure_authenticated() fname="agent.error.log",
headers = self.auth_headers() )
response = self.session.post(url, json=payload, headers=headers, timeout=30)
response.raise_for_status() response.raise_for_status()
if response.headers.get("Content-Type", "").lower().startswith("application/json"): if response.headers.get("Content-Type", "").lower().startswith("application/json"):
return response.json() return response.json()
@@ -2107,6 +2443,15 @@ async def send_agent_details_once():
async def connect(): async def connect():
print(f"[INFO] Successfully Connected to Borealis Server!") print(f"[INFO] Successfully Connected to Borealis Server!")
_log_agent('Connected to server.') _log_agent('Connected to server.')
try:
sid = getattr(sio, 'sid', None)
transport = getattr(sio, 'transport', None)
_log_agent(
f'WebSocket handshake established sid={sid!r} transport={transport!r}',
fname='agent.log',
)
except Exception:
pass
await sio.emit('connect_agent', {"agent_id": AGENT_ID, "service_mode": SERVICE_MODE}) await sio.emit('connect_agent', {"agent_id": AGENT_ID, "service_mode": SERVICE_MODE})
# Send an immediate heartbeat via authenticated REST call. # Send an immediate heartbeat via authenticated REST call.
@@ -2143,6 +2488,17 @@ async def connect():
except Exception: except Exception:
pass pass
@sio.event
async def connect_error(data):
try:
setattr(sio, "connection_error", data)
except Exception:
pass
try:
_log_agent(f'Socket connect_error event: {data!r}', fname='agent.error.log')
except Exception:
pass
@sio.event @sio.event
async def disconnect(): async def disconnect():
print("[WebSocket] Disconnected from Borealis server.") print("[WebSocket] Disconnected from Borealis server.")
@@ -2390,22 +2746,64 @@ if not SYSTEM_SERVICE_MODE:
async def connect_loop(): async def connect_loop():
retry = 5 retry = 5
client = http_client() client = http_client()
attempt = 0
while True: while True:
attempt += 1
try: try:
_log_agent(
f'connect_loop attempt={attempt} starting authentication phase',
fname='agent.log',
)
client.ensure_authenticated() client.ensure_authenticated()
auth_snapshot = {
'guid_present': bool(client.guid),
'access_token': bool(client.access_token),
'refresh_token': bool(client.refresh_token),
'access_expiry': client.access_expires_at,
}
_log_agent(
f"connect_loop attempt={attempt} auth snapshot: {_format_debug_pairs(auth_snapshot)}",
fname='agent.log',
)
client.configure_socketio(sio) client.configure_socketio(sio)
try:
setattr(sio, "connection_error", None)
except Exception:
pass
url = client.websocket_base_url() url = client.websocket_base_url()
headers = client.auth_headers()
header_summary = _summarize_headers(headers)
verify_value = getattr(client.session, 'verify', None)
_log_agent(
f"connect_loop attempt={attempt} dialing websocket url={url} transports=['websocket'] verify={verify_value!r} headers={header_summary}",
fname='agent.log',
)
print(f"[INFO] Connecting Agent to {url}...") print(f"[INFO] Connecting Agent to {url}...")
_log_agent(f'Connecting to {url}...')
await sio.connect( await sio.connect(
url, url,
transports=['websocket'], transports=['websocket'],
headers=client.auth_headers(), headers=headers,
)
_log_agent(
f'connect_loop attempt={attempt} sio.connect completed successfully',
fname='agent.log',
) )
break break
except Exception as e: except Exception as e:
print(f"[WebSocket] Server unavailable: {e}. Retrying in {retry}s...") detail = _describe_exception(e)
_log_agent(f'Server unavailable: {e}', fname='agent.error.log') try:
conn_err = getattr(sio, "connection_error", None)
except Exception:
conn_err = None
if conn_err:
detail = f"{detail}; connection_error={conn_err!r}"
message = (
f"connect_loop attempt={attempt} server unavailable: {detail}. "
f"Retrying in {retry}s..."
)
print(f"[WebSocket] {message}")
_log_agent(message, fname='agent.error.log')
_log_exception_trace(f'connect_loop attempt={attempt}')
await asyncio.sleep(retry) await asyncio.sleep(retry)
if __name__=='__main__': if __name__=='__main__':

View File

@@ -230,6 +230,7 @@ class AgentKeyStore:
self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem") self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem")
self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub") self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub")
self._identity_lock_path = os.path.join(self.settings_dir, "identity.lock") self._identity_lock_path = os.path.join(self.settings_dir, "identity.lock")
self._installer_cache_path = os.path.join(self.settings_dir, "installer_code.shared.json")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Identity management # Identity management
@@ -455,3 +456,107 @@ class AgentKeyStore:
if isinstance(value, str) and value.strip(): if isinstance(value, str) and value.strip():
return value.strip() return value.strip()
return None return None
# ------------------------------------------------------------------
# Installer code sharing helpers
# ------------------------------------------------------------------
def _load_installer_cache(self) -> dict:
if not os.path.isfile(self._installer_cache_path):
return {}
try:
with open(self._installer_cache_path, "r", encoding="utf-8") as fh:
data = json.load(fh)
if isinstance(data, dict):
return data
except Exception:
pass
return {}
def _store_installer_cache(self, payload: dict) -> None:
try:
with open(self._installer_cache_path, "w", encoding="utf-8") as fh:
json.dump(payload, fh, indent=2)
_restrict_permissions(self._installer_cache_path)
except Exception:
pass
def cache_installer_code(self, code: str, consumer: Optional[str] = None) -> None:
normalized = (code or "").strip()
if not normalized:
return
payload = self._load_installer_cache()
payload["code"] = normalized
consumers = set()
existing = payload.get("consumed")
if isinstance(existing, list):
consumers = {str(item).upper() for item in existing if isinstance(item, str)}
if consumer:
consumers.add(str(consumer).upper())
payload["consumed"] = sorted(consumers)
payload["updated_at"] = int(time.time())
self._store_installer_cache(payload)
def load_cached_installer_code(self) -> Optional[str]:
payload = self._load_installer_cache()
code = payload.get("code")
if isinstance(code, str):
stripped = code.strip()
if stripped:
return stripped
return None
def mark_installer_code_consumed(self, consumer: Optional[str] = None) -> None:
payload = self._load_installer_cache()
if not payload:
return
consumers = set()
existing = payload.get("consumed")
if isinstance(existing, list):
consumers = {str(item).upper() for item in existing if isinstance(item, str)}
if consumer:
consumers.add(str(consumer).upper())
payload["consumed"] = sorted(consumers)
payload["updated_at"] = int(time.time())
code_present = isinstance(payload.get("code"), str) and payload["code"].strip()
should_clear = False
if not code_present:
should_clear = True
else:
required_consumers = {"SYSTEM", "CURRENTUSER"}
if required_consumers.issubset(consumers):
should_clear = True
else:
remaining = required_consumers - consumers
if not remaining:
should_clear = True
else:
exists_other = False
for other in remaining:
if other == "SYSTEM":
cfg_name = "agent_settings_SYSTEM.json"
elif other == "CURRENTUSER":
cfg_name = "agent_settings_CURRENTUSER.json"
else:
cfg_name = None
if not cfg_name:
continue
path = os.path.join(self.settings_dir, cfg_name)
if os.path.isfile(path):
exists_other = True
break
if not exists_other:
should_clear = True
if should_clear:
payload.pop("code", None)
payload["consumed"] = []
if payload.get("code") or payload.get("consumed"):
self._store_installer_cache(payload)
else:
try:
if os.path.isfile(self._installer_cache_path):
os.remove(self._installer_cache_path)
except Exception:
pass

View File

@@ -18,7 +18,7 @@ def register(
db_conn_factory: Callable[[], sqlite3.Connection], db_conn_factory: Callable[[], sqlite3.Connection],
require_admin: Callable[[], Optional[Any]], require_admin: Callable[[], Optional[Any]],
current_user: Callable[[], Optional[Dict[str, str]]], current_user: Callable[[], Optional[Dict[str, str]]],
log: Callable[[str, str], None], log: Callable[[str, str, Optional[str]], None],
) -> None: ) -> None:
blueprint = Blueprint("admin", __name__) blueprint = Blueprint("admin", __name__)
@@ -54,18 +54,27 @@ def register(
try: try:
cur = conn.cursor() cur = conn.cursor()
sql = """ sql = """
SELECT id, code, expires_at, created_by_user_id, used_at, used_by_guid SELECT id,
code,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at
FROM enrollment_install_codes FROM enrollment_install_codes
""" """
params: List[str] = [] params: List[str] = []
now_iso = _iso(_now())
if status_filter == "active": if status_filter == "active":
sql += " WHERE used_at IS NULL AND expires_at > ?" sql += " WHERE use_count < max_uses AND expires_at > ?"
params.append(_iso(_now())) params.append(now_iso)
elif status_filter == "expired": elif status_filter == "expired":
sql += " WHERE used_at IS NULL AND expires_at <= ?" sql += " WHERE use_count < max_uses AND expires_at <= ?"
params.append(_iso(_now())) params.append(now_iso)
elif status_filter == "used": elif status_filter == "used":
sql += " WHERE used_at IS NOT NULL" sql += " WHERE use_count >= max_uses"
sql += " ORDER BY expires_at ASC" sql += " ORDER BY expires_at ASC"
cur.execute(sql, params) cur.execute(sql, params)
rows = cur.fetchall() rows = cur.fetchall()
@@ -82,6 +91,9 @@ def register(
"created_by_user_id": row[3], "created_by_user_id": row[3],
"used_at": row[4], "used_at": row[4],
"used_by_guid": row[5], "used_by_guid": row[5],
"max_uses": row[6],
"use_count": row[7],
"last_used_at": row[8],
} }
) )
return jsonify({"codes": records}) return jsonify({"codes": records})
@@ -93,6 +105,18 @@ def register(
if ttl_hours not in VALID_TTL_HOURS: if ttl_hours not in VALID_TTL_HOURS:
return jsonify({"error": "invalid_ttl"}), 400 return jsonify({"error": "invalid_ttl"}), 400
max_uses_value = payload.get("max_uses")
if max_uses_value is None:
max_uses_value = payload.get("allowed_uses")
try:
max_uses = int(max_uses_value)
except Exception:
max_uses = 2
if max_uses < 1:
max_uses = 1
if max_uses > 10:
max_uses = 10
user = current_user() or {} user = current_user() or {}
username = user.get("username") or "" username = user.get("username") or ""
@@ -106,22 +130,28 @@ def register(
cur.execute( cur.execute(
""" """
INSERT INTO enrollment_install_codes ( INSERT INTO enrollment_install_codes (
id, code, expires_at, created_by_user_id id, code, expires_at, created_by_user_id, max_uses, use_count
) )
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, 0)
""", """,
(record_id, code_value, _iso(expires_at), created_by), (record_id, code_value, _iso(expires_at), created_by, max_uses),
) )
conn.commit() conn.commit()
finally: finally:
conn.close() conn.close()
log("server", f"installer code created id={record_id} by={username} ttl={ttl_hours}h") log(
"server",
f"installer code created id={record_id} by={username} ttl={ttl_hours}h max_uses={max_uses}",
)
return jsonify( return jsonify(
{ {
"id": record_id, "id": record_id,
"code": code_value, "code": code_value,
"expires_at": _iso(expires_at), "expires_at": _iso(expires_at),
"max_uses": max_uses,
"use_count": 0,
"last_used_at": None,
} }
) )
@@ -131,7 +161,7 @@ def register(
try: try:
cur = conn.cursor() cur = conn.cursor()
cur.execute( cur.execute(
"DELETE FROM enrollment_install_codes WHERE id = ? AND used_at IS NULL", "DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
(code_id,), (code_id,),
) )
deleted = cur.rowcount deleted = cur.rowcount

View File

@@ -10,13 +10,24 @@ from flask import Blueprint, jsonify, request, g
from Modules.auth.device_auth import DeviceAuthManager, require_device_auth from Modules.auth.device_auth import DeviceAuthManager, require_device_auth
from Modules.crypto.signing import ScriptSigner from Modules.crypto.signing import ScriptSigner
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
if not value:
return None
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
if not cleaned:
return None
return cleaned.upper()
def register( def register(
app, app,
*, *,
db_conn_factory: Callable[[], Any], db_conn_factory: Callable[[], Any],
auth_manager: DeviceAuthManager, auth_manager: DeviceAuthManager,
log: Callable[[str, str], None], log: Callable[[str, str, Optional[str]], None],
script_signer: ScriptSigner, script_signer: ScriptSigner,
) -> None: ) -> None:
blueprint = Blueprint("agents", __name__) blueprint = Blueprint("agents", __name__)
@@ -29,10 +40,15 @@ def register(
except Exception: except Exception:
return None return None
def _context_hint(ctx=None) -> Optional[str]:
if ctx is not None and getattr(ctx, "service_mode", None):
return _canonical_context(getattr(ctx, "service_mode", None))
return _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
def _auth_context(): def _auth_context():
ctx = getattr(g, "device_auth", None) ctx = getattr(g, "device_auth", None)
if ctx is None: if ctx is None:
log("server", f"device auth context missing for {request.path}") log("server", f"device auth context missing for {request.path}", _context_hint())
return ctx return ctx
@blueprint.route("/api/agent/heartbeat", methods=["POST"]) @blueprint.route("/api/agent/heartbeat", methods=["POST"])
@@ -42,6 +58,7 @@ def register(
if ctx is None: if ctx is None:
return jsonify({"error": "auth_context_missing"}), 500 return jsonify({"error": "auth_context_missing"}), 500
payload = request.get_json(force=True, silent=True) or {} payload = request.get_json(force=True, silent=True) or {}
context_label = _context_hint(ctx)
now_ts = int(time.time()) now_ts = int(time.time())
updates: Dict[str, Optional[str]] = {"last_seen": now_ts} updates: Dict[str, Optional[str]] = {"last_seen": now_ts}
@@ -111,12 +128,13 @@ def register(
"server", "server",
"heartbeat hostname collision ignored for guid=" "heartbeat hostname collision ignored for guid="
f"{ctx.guid}", f"{ctx.guid}",
context_label,
) )
else: else:
raise raise
if rowcount == 0: if rowcount == 0:
log("server", f"heartbeat missing device record guid={ctx.guid}") log("server", f"heartbeat missing device record guid={ctx.guid}", context_label)
return jsonify({"error": "device_not_registered"}), 404 return jsonify({"error": "device_not_registered"}), 404
conn.commit() conn.commit()
finally: finally:

View File

@@ -1,7 +1,11 @@
from __future__ import annotations from __future__ import annotations
import functools import functools
import sqlite3
import time
from contextlib import closing
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Dict, Optional
import jwt import jwt
@@ -10,6 +14,17 @@ from flask import g, jsonify, request
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
from Modules.auth.rate_limit import SlidingWindowRateLimiter from Modules.auth.rate_limit import SlidingWindowRateLimiter
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
if not value:
return None
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
if not cleaned:
return None
return cleaned.upper()
@dataclass @dataclass
class DeviceAuthContext: class DeviceAuthContext:
@@ -20,6 +35,7 @@ class DeviceAuthContext:
claims: Dict[str, Any] claims: Dict[str, Any]
dpop_jkt: Optional[str] dpop_jkt: Optional[str]
status: str status: str
service_mode: Optional[str]
class DeviceAuthError(Exception): class DeviceAuthError(Exception):
@@ -47,7 +63,7 @@ class DeviceAuthManager:
db_conn_factory: Callable[[], Any], db_conn_factory: Callable[[], Any],
jwt_service, jwt_service,
dpop_validator: Optional[DPoPValidator], dpop_validator: Optional[DPoPValidator],
log: Callable[[str, str], None], log: Callable[[str, str, Optional[str]], None],
rate_limiter: Optional[SlidingWindowRateLimiter] = None, rate_limiter: Optional[SlidingWindowRateLimiter] = None,
) -> None: ) -> None:
self._db_conn_factory = db_conn_factory self._db_conn_factory = db_conn_factory
@@ -86,8 +102,9 @@ class DeviceAuthManager:
retry_after=decision.retry_after, retry_after=decision.retry_after,
) )
conn = self._db_conn_factory() context_label = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
try:
with closing(self._db_conn_factory()) as conn:
cur = conn.cursor() cur = conn.cursor()
cur.execute( cur.execute(
""" """
@@ -98,8 +115,11 @@ class DeviceAuthManager:
(guid,), (guid,),
) )
row = cur.fetchone() row = cur.fetchone()
finally:
conn.close() if not row:
row = self._recover_device_record(
conn, guid, fingerprint, token_version, context_label
)
if not row: if not row:
raise DeviceAuthError("device_not_found", status_code=403) raise DeviceAuthError("device_not_found", status_code=403)
@@ -121,7 +141,11 @@ class DeviceAuthManager:
if status_normalized not in allowed_statuses: if status_normalized not in allowed_statuses:
raise DeviceAuthError("device_revoked", status_code=403) raise DeviceAuthError("device_revoked", status_code=403)
if status_normalized == "quarantined": if status_normalized == "quarantined":
self._log("server", f"device {guid} is quarantined; limited access for {request.path}") self._log(
"server",
f"device {guid} is quarantined; limited access for {request.path}",
context_label,
)
dpop_jkt: Optional[str] = None dpop_jkt: Optional[str] = None
dpop_proof = request.headers.get("DPoP") dpop_proof = request.headers.get("DPoP")
@@ -144,9 +168,111 @@ class DeviceAuthManager:
claims=claims, claims=claims,
dpop_jkt=dpop_jkt, dpop_jkt=dpop_jkt,
status=status_normalized, status=status_normalized,
service_mode=context_label,
) )
return ctx return ctx
def _recover_device_record(
self,
conn: sqlite3.Connection,
guid: str,
fingerprint: str,
token_version: int,
context_label: Optional[str],
) -> Optional[tuple]:
"""Attempt to recreate a missing device row for an authenticated token."""
guid = (guid or "").strip()
fingerprint = (fingerprint or "").strip()
if not guid or not fingerprint:
return None
cur = conn.cursor()
now_ts = int(time.time())
try:
now_iso = datetime.now(tz=timezone.utc).isoformat()
except Exception:
now_iso = datetime.utcnow().isoformat() # pragma: no cover
base_hostname = f"RECOVERED-{guid[:12].upper()}" if guid else "RECOVERED"
for attempt in range(6):
hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
try:
cur.execute(
"""
INSERT INTO devices (
guid,
hostname,
created_at,
last_seen,
ssl_key_fingerprint,
token_version,
status,
key_added_at
)
VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
""",
(
guid,
hostname,
now_ts,
now_ts,
fingerprint,
max(token_version or 1, 1),
now_iso,
),
)
except sqlite3.IntegrityError as exc:
# Hostname collision try again with a suffixed placeholder.
message = str(exc).lower()
if "hostname" in message and "unique" in message:
continue
self._log(
"server",
f"device auth failed to recover guid={guid} due to integrity error: {exc}",
context_label,
)
conn.rollback()
return None
except Exception as exc: # pragma: no cover - defensive logging
self._log(
"server",
f"device auth unexpected error recovering guid={guid}: {exc}",
context_label,
)
conn.rollback()
return None
else:
conn.commit()
break
else:
# Exhausted attempts because of hostname collisions.
self._log(
"server",
f"device auth could not recover guid={guid}; hostname collisions persisted",
context_label,
)
conn.rollback()
return None
cur.execute(
"""
SELECT guid, ssl_key_fingerprint, token_version, status
FROM devices
WHERE guid = ?
""",
(guid,),
)
row = cur.fetchone()
if not row:
self._log(
"server",
f"device auth recovery for guid={guid} committed but row still missing",
context_label,
)
return row
def require_device_auth(manager: DeviceAuthManager): def require_device_auth(manager: DeviceAuthManager):
def decorator(func): def decorator(func):

View File

@@ -152,7 +152,10 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
expires_at TEXT NOT NULL, expires_at TEXT NOT NULL,
created_by_user_id TEXT, created_by_user_id TEXT,
used_at TEXT, used_at TEXT,
used_by_guid TEXT used_by_guid TEXT,
max_uses INTEGER NOT NULL DEFAULT 1,
use_count INTEGER NOT NULL DEFAULT 0,
last_used_at TEXT
) )
""" """
) )
@@ -163,6 +166,29 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
""" """
) )
columns = {row[1] for row in _table_info(cur, "enrollment_install_codes")}
if "max_uses" not in columns:
cur.execute(
"""
ALTER TABLE enrollment_install_codes
ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 1
"""
)
if "use_count" not in columns:
cur.execute(
"""
ALTER TABLE enrollment_install_codes
ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0
"""
)
if "last_used_at" not in columns:
cur.execute(
"""
ALTER TABLE enrollment_install_codes
ADD COLUMN last_used_at TEXT
"""
)
def _ensure_device_approval_table(conn: sqlite3.Connection) -> None: def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor() cur = conn.cursor()

View File

@@ -6,7 +6,18 @@ import sqlite3
import uuid import uuid
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
import time import time
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Dict, Optional, Tuple
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
if not value:
return None
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
if not cleaned:
return None
return cleaned.upper()
from flask import Blueprint, jsonify, request from flask import Blueprint, jsonify, request
@@ -20,7 +31,7 @@ def register(
app, app,
*, *,
db_conn_factory: Callable[[], sqlite3.Connection], db_conn_factory: Callable[[], sqlite3.Connection],
log: Callable[[str, str], None], log: Callable[[str, str, Optional[str]], None],
jwt_service, jwt_service,
tls_bundle_path: str, tls_bundle_path: str,
ip_rate_limiter: SlidingWindowRateLimiter, ip_rate_limiter: SlidingWindowRateLimiter,
@@ -51,12 +62,19 @@ def register(
except Exception: except Exception:
return "" return ""
def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float): def _rate_limited(
key: str,
limiter: SlidingWindowRateLimiter,
limit: int,
window_s: float,
context_hint: Optional[str],
):
decision = limiter.check(key, limit, window_s) decision = limiter.check(key, limit, window_s)
if not decision.allowed: if not decision.allowed:
log( log(
"server", "server",
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}", f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
context_hint,
) )
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after}) response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
response.status_code = 429 response.status_code = 429
@@ -66,31 +84,79 @@ def register(
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]: def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
cur.execute( cur.execute(
"SELECT id, code, expires_at, used_at FROM enrollment_install_codes WHERE code = ?", """
SELECT id,
code,
expires_at,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at
FROM enrollment_install_codes
WHERE code = ?
""",
(code_value,), (code_value,),
) )
row = cur.fetchone() row = cur.fetchone()
if not row: if not row:
return None return None
keys = ["id", "code", "expires_at", "used_at"] keys = [
"id",
"code",
"expires_at",
"used_at",
"used_by_guid",
"max_uses",
"use_count",
"last_used_at",
]
record = dict(zip(keys, row)) record = dict(zip(keys, row))
return record return record
def _install_code_valid(record: Dict[str, Any]) -> bool: def _install_code_valid(
record: Dict[str, Any], fingerprint: str, cur: sqlite3.Cursor
) -> Tuple[bool, Optional[str]]:
if not record: if not record:
return False return False, None
expires_at = record.get("expires_at") expires_at = record.get("expires_at")
if not isinstance(expires_at, str): if not isinstance(expires_at, str):
return False return False, None
try: try:
expiry = datetime.fromisoformat(expires_at) expiry = datetime.fromisoformat(expires_at)
except Exception: except Exception:
return False return False, None
if expiry <= _now(): if expiry <= _now():
return False return False, None
if record.get("used_at"): try:
return False max_uses = int(record.get("max_uses") or 1)
return True except Exception:
max_uses = 1
if max_uses < 1:
max_uses = 1
try:
use_count = int(record.get("use_count") or 0)
except Exception:
use_count = 0
if use_count < max_uses:
return True, None
guid = str(record.get("used_by_guid") or "").strip()
if not guid:
return False, None
cur.execute(
"SELECT ssl_key_fingerprint FROM devices WHERE guid = ?",
(guid,),
)
row = cur.fetchone()
if not row:
return False, None
stored_fp = (row[0] or "").strip().lower()
if not stored_fp:
return False, None
if stored_fp == (fingerprint or "").strip().lower():
return True, guid
return False, None
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str: def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
base = (hostname or "").strip() or guid base = (hostname or "").strip() or guid
@@ -247,7 +313,9 @@ def register(
@blueprint.route("/api/agent/enroll/request", methods=["POST"]) @blueprint.route("/api/agent/enroll/request", methods=["POST"])
def enrollment_request(): def enrollment_request():
remote = _remote_addr() remote = _remote_addr()
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0) context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0, context_hint)
if rate_error: if rate_error:
return rate_error return rate_error
@@ -262,42 +330,43 @@ def register(
"enrollment request received " "enrollment request received "
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} " f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}", f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
context_hint,
) )
if not hostname: if not hostname:
log("server", f"enrollment rejected missing_hostname ip={remote}") log("server", f"enrollment rejected missing_hostname ip={remote}", context_hint)
return jsonify({"error": "hostname_required"}), 400 return jsonify({"error": "hostname_required"}), 400
if not enrollment_code: if not enrollment_code:
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}") log("server", f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
return jsonify({"error": "enrollment_code_required"}), 400 return jsonify({"error": "enrollment_code_required"}), 400
if not isinstance(agent_pubkey_b64, str): if not isinstance(agent_pubkey_b64, str):
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}") log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "agent_pubkey_required"}), 400 return jsonify({"error": "agent_pubkey_required"}), 400
if not isinstance(client_nonce_b64, str): if not isinstance(client_nonce_b64, str):
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}") log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "client_nonce_required"}), 400 return jsonify({"error": "client_nonce_required"}), 400
try: try:
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64) agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
except Exception: except Exception:
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}") log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_agent_pubkey"}), 400 return jsonify({"error": "invalid_agent_pubkey"}), 400
if len(agent_pubkey_der) < 10: if len(agent_pubkey_der) < 10:
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}") log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_agent_pubkey"}), 400 return jsonify({"error": "invalid_agent_pubkey"}), 400
try: try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True) client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception: except Exception:
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}") log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400 return jsonify({"error": "invalid_client_nonce"}), 400
if len(client_nonce_bytes) < 16: if len(client_nonce_bytes) < 16:
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}") log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400 return jsonify({"error": "invalid_client_nonce"}), 400
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der) fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0) rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0, context_hint)
if rate_error: if rate_error:
return rate_error return rate_error
@@ -305,7 +374,14 @@ def register(
try: try:
cur = conn.cursor() cur = conn.cursor()
install_code = _load_install_code(cur, enrollment_code) install_code = _load_install_code(cur, enrollment_code)
if not _install_code_valid(install_code): valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
if not valid_code:
log(
"server",
"enrollment request invalid_code "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
)
return jsonify({"error": "invalid_enrollment_code"}), 400 return jsonify({"error": "invalid_enrollment_code"}), 400
approval_reference: str approval_reference: str
@@ -331,6 +407,7 @@ def register(
""" """
UPDATE device_approvals UPDATE device_approvals
SET hostname_claimed = ?, SET hostname_claimed = ?,
guid = ?,
enrollment_code_id = ?, enrollment_code_id = ?,
client_nonce = ?, client_nonce = ?,
server_nonce = ?, server_nonce = ?,
@@ -340,6 +417,7 @@ def register(
""", """,
( (
hostname, hostname,
reuse_guid,
install_code["id"], install_code["id"],
client_nonce_b64, client_nonce_b64,
server_nonce_b64, server_nonce_b64,
@@ -359,11 +437,12 @@ def register(
status, client_nonce, server_nonce, agent_pubkey_der, status, client_nonce, server_nonce, agent_pubkey_der,
created_at, updated_at created_at, updated_at
) )
VALUES (?, ?, NULL, ?, ?, ?, 'pending', ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
""", """,
( (
record_id, record_id,
approval_reference, approval_reference,
reuse_guid,
hostname, hostname,
fingerprint, fingerprint,
install_code["id"], install_code["id"],
@@ -387,7 +466,11 @@ def register(
"server_certificate": _load_tls_bundle(tls_bundle_path), "server_certificate": _load_tls_bundle(tls_bundle_path),
"signing_key": _signing_key_b64(), "signing_key": _signing_key_b64(),
} }
log("server", f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}") log(
"server",
f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}",
context_hint,
)
return jsonify(response) return jsonify(response)
@blueprint.route("/api/agent/enroll/poll", methods=["POST"]) @blueprint.route("/api/agent/enroll/poll", methods=["POST"])
@@ -396,34 +479,36 @@ def register(
approval_reference = payload.get("approval_reference") approval_reference = payload.get("approval_reference")
client_nonce_b64 = payload.get("client_nonce") client_nonce_b64 = payload.get("client_nonce")
proof_sig_b64 = payload.get("proof_sig") proof_sig_b64 = payload.get("proof_sig")
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
log( log(
"server", "server",
"enrollment poll received " "enrollment poll received "
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}" f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
f" proof_sig_len={len(proof_sig_b64 or '')}", f" proof_sig_len={len(proof_sig_b64 or '')}",
context_hint,
) )
if not isinstance(approval_reference, str) or not approval_reference: if not isinstance(approval_reference, str) or not approval_reference:
log("server", "enrollment poll rejected missing_reference") log("server", "enrollment poll rejected missing_reference", context_hint)
return jsonify({"error": "approval_reference_required"}), 400 return jsonify({"error": "approval_reference_required"}), 400
if not isinstance(client_nonce_b64, str): if not isinstance(client_nonce_b64, str):
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}") log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "client_nonce_required"}), 400 return jsonify({"error": "client_nonce_required"}), 400
if not isinstance(proof_sig_b64, str): if not isinstance(proof_sig_b64, str):
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}") log("server", f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
return jsonify({"error": "proof_sig_required"}), 400 return jsonify({"error": "proof_sig_required"}), 400
try: try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True) client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception: except Exception:
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}") log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400 return jsonify({"error": "invalid_client_nonce"}), 400
try: try:
proof_sig = base64.b64decode(proof_sig_b64, validate=True) proof_sig = base64.b64decode(proof_sig_b64, validate=True)
except Exception: except Exception:
log("server", f"enrollment poll invalid_sig ref={approval_reference}") log("server", f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_proof_sig"}), 400 return jsonify({"error": "invalid_proof_sig"}), 400
conn = db_conn_factory() conn = db_conn_factory()
@@ -441,7 +526,7 @@ def register(
) )
row = cur.fetchone() row = cur.fetchone()
if not row: if not row:
log("server", f"enrollment poll unknown_reference ref={approval_reference}") log("server", f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
return jsonify({"status": "unknown"}), 404 return jsonify({"status": "unknown"}), 404
( (
@@ -460,13 +545,13 @@ def register(
) = row ) = row
if client_nonce_stored != client_nonce_b64: if client_nonce_stored != client_nonce_b64:
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}") log("server", f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
return jsonify({"error": "nonce_mismatch"}), 400 return jsonify({"error": "nonce_mismatch"}), 400
try: try:
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True) server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
except Exception: except Exception:
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}") log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "server_nonce_invalid"}), 400 return jsonify({"error": "server_nonce_invalid"}), 400
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
@@ -474,17 +559,17 @@ def register(
try: try:
public_key = serialization.load_der_public_key(agent_pubkey_der) public_key = serialization.load_der_public_key(agent_pubkey_der)
except Exception: except Exception:
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}") log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
public_key = None public_key = None
if public_key is None: if public_key is None:
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}") log("server", f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
return jsonify({"error": "agent_pubkey_invalid"}), 400 return jsonify({"error": "agent_pubkey_invalid"}), 400
try: try:
public_key.verify(proof_sig, message) public_key.verify(proof_sig, message)
except Exception: except Exception:
log("server", f"enrollment poll invalid_proof ref={approval_reference}") log("server", f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_proof"}), 400 return jsonify({"error": "invalid_proof"}), 400
if status == "pending": if status == "pending":
@@ -492,24 +577,28 @@ def register(
"server", "server",
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}" f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
f" fingerprint={fingerprint[:12]}", f" fingerprint={fingerprint[:12]}",
context_hint,
) )
return jsonify({"status": "pending", "poll_after_ms": 5000}) return jsonify({"status": "pending", "poll_after_ms": 5000})
if status == "denied": if status == "denied":
log( log(
"server", "server",
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}", f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
context_hint,
) )
return jsonify({"status": "denied", "reason": "operator_denied"}) return jsonify({"status": "denied", "reason": "operator_denied"})
if status == "expired": if status == "expired":
log( log(
"server", "server",
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}", f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
context_hint,
) )
return jsonify({"status": "expired"}) return jsonify({"status": "expired"})
if status == "completed": if status == "completed":
log( log(
"server", "server",
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}", f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
context_hint,
) )
return jsonify({"status": "approved", "detail": "finalized"}) return jsonify({"status": "approved", "detail": "finalized"})
@@ -517,6 +606,7 @@ def register(
log( log(
"server", "server",
f"enrollment poll unexpected_status={status} ref={approval_reference}", f"enrollment poll unexpected_status={status} ref={approval_reference}",
context_hint,
) )
return jsonify({"status": status or "unknown"}), 400 return jsonify({"status": status or "unknown"}), 400
@@ -525,6 +615,7 @@ def register(
log( log(
"server", "server",
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}", f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
context_hint,
) )
return jsonify({"error": "proof_replayed"}), 409 return jsonify({"error": "proof_replayed"}), 409
@@ -537,14 +628,40 @@ def register(
# Mark install code used # Mark install code used
if enrollment_code_id: if enrollment_code_id:
cur.execute(
"SELECT use_count, max_uses FROM enrollment_install_codes WHERE id = ?",
(enrollment_code_id,),
)
usage_row = cur.fetchone()
try:
prior_count = int(usage_row[0]) if usage_row else 0
except Exception:
prior_count = 0
try:
allowed_uses = int(usage_row[1]) if usage_row else 1
except Exception:
allowed_uses = 1
if allowed_uses < 1:
allowed_uses = 1
new_count = prior_count + 1
consumed = new_count >= allowed_uses
cur.execute( cur.execute(
""" """
UPDATE enrollment_install_codes UPDATE enrollment_install_codes
SET used_at = ?, used_by_guid = ? SET use_count = ?,
used_by_guid = ?,
last_used_at = ?,
used_at = CASE WHEN ? THEN ? ELSE used_at END
WHERE id = ? WHERE id = ?
AND used_at IS NULL
""", """,
(now_iso, effective_guid, enrollment_code_id), (
new_count,
effective_guid,
now_iso,
1 if consumed else 0,
now_iso,
enrollment_code_id,
),
) )
# Update approval record with final state # Update approval record with final state
@@ -573,6 +690,7 @@ def register(
log( log(
"server", "server",
f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}", f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
context_hint,
) )
return jsonify( return jsonify(
{ {

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Callable from typing import Callable, Optional
import eventlet import eventlet
from flask_socketio import SocketIO from flask_socketio import SocketIO
@@ -11,7 +11,7 @@ def start_prune_job(
socketio: SocketIO, socketio: SocketIO,
*, *,
db_conn_factory: Callable[[], any], db_conn_factory: Callable[[], any],
log: Callable[[str, str], None], log: Callable[[str, str, Optional[str]], None],
) -> None: ) -> None:
def _job_loop(): def _job_loop():
while True: while True:
@@ -24,7 +24,7 @@ def start_prune_job(
socketio.start_background_task(_job_loop) socketio.start_background_task(_job_loop)
def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None]) -> None: def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str, Optional[str]], None]) -> None:
now = datetime.now(tz=timezone.utc) now = datetime.now(tz=timezone.utc)
now_iso = now.isoformat() now_iso = now.isoformat()
stale_before = (now - timedelta(hours=24)).isoformat() stale_before = (now - timedelta(hours=24)).isoformat()
@@ -34,7 +34,7 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
cur.execute( cur.execute(
""" """
DELETE FROM enrollment_install_codes DELETE FROM enrollment_install_codes
WHERE used_at IS NULL WHERE use_count = 0
AND expires_at < ? AND expires_at < ?
""", """,
(now_iso,), (now_iso,),
@@ -52,7 +52,10 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
SELECT 1 SELECT 1
FROM enrollment_install_codes c FROM enrollment_install_codes c
WHERE c.id = device_approvals.enrollment_code_id WHERE c.id = device_approvals.enrollment_code_id
AND c.expires_at < ? AND (
c.expires_at < ?
OR c.use_count >= c.max_uses
)
) )
OR created_at < ? OR created_at < ?
) )

View File

@@ -93,7 +93,20 @@ def register(
except DPoPVerificationError: except DPoPVerificationError:
return jsonify({"error": "dpop_invalid"}), 400 return jsonify({"error": "dpop_invalid"}), 400
elif stored_jkt: elif stored_jkt:
return jsonify({"error": "dpop_required"}), 400 # The agent does not yet emit DPoP proofs; allow recovery by clearing
# the stored binding so refreshes can succeed. This preserves
# backward compatibility while the client gains full DPoP support.
try:
app.logger.warning(
"Clearing stored DPoP binding for guid=%s due to missing proof",
guid,
)
except Exception:
pass
cur.execute(
"UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?",
(record_id,),
)
new_access_token = jwt_service.issue_access_token( new_access_token = jwt_service.issue_access_token(
guid, guid,

View File

@@ -65,7 +65,9 @@ const formatDateTime = (value) => {
const determineStatus = (record) => { const determineStatus = (record) => {
if (!record) return "expired"; if (!record) return "expired";
if (record.used_at) return "used"; const maxUses = Number.isFinite(record?.max_uses) ? record.max_uses : 1;
const useCount = Number.isFinite(record?.use_count) ? record.use_count : 0;
if (useCount >= Math.max(1, maxUses || 1)) return "used";
if (!record.expires_at) return "expired"; if (!record.expires_at) return "expired";
const expires = new Date(record.expires_at); const expires = new Date(record.expires_at);
if (Number.isNaN(expires.getTime())) return "expired"; if (Number.isNaN(expires.getTime())) return "expired";
@@ -80,6 +82,7 @@ function EnrollmentCodes() {
const [statusFilter, setStatusFilter] = useState("all"); const [statusFilter, setStatusFilter] = useState("all");
const [ttlHours, setTtlHours] = useState(6); const [ttlHours, setTtlHours] = useState(6);
const [generating, setGenerating] = useState(false); const [generating, setGenerating] = useState(false);
const [maxUses, setMaxUses] = useState(2);
const filteredCodes = useMemo(() => { const filteredCodes = useMemo(() => {
if (statusFilter === "all") return codes; if (statusFilter === "all") return codes;
@@ -119,7 +122,7 @@ function EnrollmentCodes() {
method: "POST", method: "POST",
credentials: "include", credentials: "include",
headers: { "Content-Type": "application/json" }, headers: { "Content-Type": "application/json" },
body: JSON.stringify({ ttl_hours: ttlHours }), body: JSON.stringify({ ttl_hours: ttlHours, max_uses: maxUses }),
}); });
if (!resp.ok) { if (!resp.ok) {
const body = await resp.json().catch(() => ({})); const body = await resp.json().catch(() => ({}));
@@ -133,7 +136,7 @@ function EnrollmentCodes() {
} finally { } finally {
setGenerating(false); setGenerating(false);
} }
}, [fetchCodes, ttlHours]); }, [fetchCodes, ttlHours, maxUses]);
const handleDelete = useCallback( const handleDelete = useCallback(
async (id) => { async (id) => {
@@ -216,7 +219,7 @@ function EnrollmentCodes() {
labelId="ttl-select-label" labelId="ttl-select-label"
label="Duration" label="Duration"
value={ttlHours} value={ttlHours}
onChange={(event) => setTtlHours(event.target.value)} onChange={(event) => setTtlHours(Number(event.target.value))}
> >
{TTL_PRESETS.map((preset) => ( {TTL_PRESETS.map((preset) => (
<MenuItem key={preset.value} value={preset.value}> <MenuItem key={preset.value} value={preset.value}>
@@ -226,6 +229,22 @@ function EnrollmentCodes() {
</Select> </Select>
</FormControl> </FormControl>
<FormControl size="small" sx={{ minWidth: 160 }}>
<InputLabel id="uses-select-label">Allowed Uses</InputLabel>
<Select
labelId="uses-select-label"
label="Allowed Uses"
value={maxUses}
onChange={(event) => setMaxUses(Number(event.target.value))}
>
{[1, 2, 3, 5].map((uses) => (
<MenuItem key={uses} value={uses}>
{uses === 1 ? "Single use" : `${uses} uses`}
</MenuItem>
))}
</Select>
</FormControl>
<Button <Button
variant="contained" variant="contained"
color="primary" color="primary"
@@ -270,7 +289,9 @@ function EnrollmentCodes() {
<TableCell>Installer Code</TableCell> <TableCell>Installer Code</TableCell>
<TableCell>Expires At</TableCell> <TableCell>Expires At</TableCell>
<TableCell>Created By</TableCell> <TableCell>Created By</TableCell>
<TableCell>Used At</TableCell> <TableCell>Usage</TableCell>
<TableCell>Last Used</TableCell>
<TableCell>Consumed At</TableCell>
<TableCell>Used By GUID</TableCell> <TableCell>Used By GUID</TableCell>
<TableCell align="right">Actions</TableCell> <TableCell align="right">Actions</TableCell>
</TableRow> </TableRow>
@@ -296,13 +317,17 @@ function EnrollmentCodes() {
) : ( ) : (
filteredCodes.map((record) => { filteredCodes.map((record) => {
const status = determineStatus(record); const status = determineStatus(record);
const disableDelete = status !== "active"; const maxAllowed = Math.max(1, Number.isFinite(record?.max_uses) ? record.max_uses : 1);
const usageCount = Math.max(0, Number.isFinite(record?.use_count) ? record.use_count : 0);
const disableDelete = usageCount !== 0;
return ( return (
<TableRow hover key={record.id}> <TableRow hover key={record.id}>
<TableCell>{renderStatusChip(record)}</TableCell> <TableCell>{renderStatusChip(record)}</TableCell>
<TableCell sx={{ fontFamily: "monospace" }}>{maskCode(record.code)}</TableCell> <TableCell sx={{ fontFamily: "monospace" }}>{maskCode(record.code)}</TableCell>
<TableCell>{formatDateTime(record.expires_at)}</TableCell> <TableCell>{formatDateTime(record.expires_at)}</TableCell>
<TableCell>{record.created_by_user_id || "—"}</TableCell> <TableCell>{record.created_by_user_id || "—"}</TableCell>
<TableCell sx={{ fontFamily: "monospace" }}>{`${usageCount} / ${maxAllowed}`}</TableCell>
<TableCell>{formatDateTime(record.last_used_at)}</TableCell>
<TableCell>{formatDateTime(record.used_at)}</TableCell> <TableCell>{formatDateTime(record.used_at)}</TableCell>
<TableCell sx={{ fontFamily: "monospace" }}> <TableCell sx={{ fontFamily: "monospace" }}>
{record.used_by_guid || "—"} {record.used_by_guid || "—"}

View File

@@ -152,19 +152,98 @@ def _rotate_daily(path: str):
pass pass
def _write_service_log(service: str, msg: str): _SERVER_SCOPE_PATTERN = re.compile(r"\\b(?:scope|context|agent_context)=([A-Za-z0-9_-]+)", re.IGNORECASE)
_SERVER_AGENT_ID_PATTERN = re.compile(r"\\bagent_id=([^\s,]+)", re.IGNORECASE)
_AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_server_scope(raw: Optional[str]) -> Optional[str]:
if not raw:
return None
value = "".join(ch for ch in str(raw) if ch.isalnum() or ch in ("_", "-"))
if not value:
return None
return value.upper()
def _scope_from_agent_id(agent_id: Optional[str]) -> Optional[str]:
candidate = _canonical_server_scope(agent_id)
if not candidate:
return None
if candidate.endswith("_SYSTEM"):
return "SYSTEM"
if candidate.endswith("_CURRENTUSER"):
return "CURRENTUSER"
return candidate
def _infer_server_scope(message: str, explicit: Optional[str]) -> Optional[str]:
scope = _canonical_server_scope(explicit)
if scope:
return scope
match = _SERVER_SCOPE_PATTERN.search(message or "")
if match:
scope = _canonical_server_scope(match.group(1))
if scope:
return scope
agent_match = _SERVER_AGENT_ID_PATTERN.search(message or "")
if agent_match:
scope = _scope_from_agent_id(agent_match.group(1))
if scope:
return scope
return None
def _write_service_log(service: str, msg: str, scope: Optional[str] = None, *, level: str = "INFO"):
try: try:
base = _server_logs_root() base = _server_logs_root()
os.makedirs(base, exist_ok=True) os.makedirs(base, exist_ok=True)
path = os.path.join(base, f"{service}.log") path = os.path.join(base, f"{service}.log")
_rotate_daily(path) _rotate_daily(path)
ts = time.strftime('%Y-%m-%d %H:%M:%S') ts = time.strftime('%Y-%m-%d %H:%M:%S')
resolved_scope = _infer_server_scope(msg, scope)
prefix_parts = [f"[{level.upper()}]"]
if resolved_scope:
prefix_parts.append(f"[CONTEXT-{resolved_scope}]")
prefix = "".join(prefix_parts)
with open(path, 'a', encoding='utf-8') as fh: with open(path, 'a', encoding='utf-8') as fh:
fh.write(f'[{ts}] {msg}\n') fh.write(f'[{ts}] {prefix} {msg}\n')
except Exception: except Exception:
pass pass
def _mask_server_value(value: str, *, prefix: int = 4, suffix: int = 4) -> str:
try:
if not value:
return ''
stripped = value.strip()
if len(stripped) <= prefix + suffix:
return '*' * len(stripped)
return f"{stripped[:prefix]}***{stripped[-suffix:]}"
except Exception:
return '***'
def _summarize_socket_headers(headers) -> str:
try:
rendered = []
for key, value in headers.items():
lowered = key.lower()
display = value
if lowered == 'authorization':
if isinstance(value, str) and value.lower().startswith('bearer '):
token = value.split(' ', 1)[1]
display = f"Bearer {_mask_server_value(token)}"
else:
display = _mask_server_value(str(value))
elif lowered == 'cookie':
display = '<redacted>'
rendered.append(f"{key}={display}")
return ", ".join(rendered)
except Exception:
return '<header-summary-unavailable>'
# ============================================================================= # =============================================================================
# Section: Repository Hash Tracking # Section: Repository Hash Tracking
# ============================================================================= # =============================================================================
@@ -7968,6 +8047,55 @@ def screenshot_node_viewer(agent_id, node_id):
# ============================================================================= # =============================================================================
# Realtime channels for screenshots, macros, windows, and Ansible control. # Realtime channels for screenshots, macros, windows, and Ansible control.
@socketio.on('connect')
def socket_connect():
try:
sid = getattr(request, 'sid', '<unknown>')
except Exception:
sid = '<unknown>'
try:
remote_addr = request.remote_addr
except Exception:
remote_addr = None
try:
scope = _canonical_server_scope(request.headers.get(_AGENT_CONTEXT_HEADER))
except Exception:
scope = None
try:
query_pairs = [f"{k}={v}" for k, v in request.args.items()] # type: ignore[attr-defined]
query_summary = "&".join(query_pairs) if query_pairs else "<none>"
except Exception:
query_summary = "<unavailable>"
header_summary = _summarize_socket_headers(getattr(request, 'headers', {}))
transport = request.args.get('transport') if hasattr(request, 'args') else None # type: ignore[attr-defined]
_write_service_log(
'server',
f"socket.io connect sid={sid} ip={remote_addr} transport={transport!r} query={query_summary} headers={header_summary}",
scope=scope,
)
@socketio.on('disconnect')
def socket_disconnect():
try:
sid = getattr(request, 'sid', '<unknown>')
except Exception:
sid = '<unknown>'
try:
remote_addr = request.remote_addr
except Exception:
remote_addr = None
try:
scope = _canonical_server_scope(request.headers.get(_AGENT_CONTEXT_HEADER))
except Exception:
scope = None
_write_service_log(
'server',
f"socket.io disconnect sid={sid} ip={remote_addr}",
scope=scope,
)
@socketio.on("agent_screenshot_task") @socketio.on("agent_screenshot_task")
def receive_screenshot_task(data): def receive_screenshot_task(data):
agent_id = data.get("agent_id") agent_id = data.get("agent_id")
@@ -7997,6 +8125,19 @@ def connect_agent(data):
if not agent_id: if not agent_id:
return return
print(f"Agent connected: {agent_id}") print(f"Agent connected: {agent_id}")
try:
scope = _normalize_service_mode((data or {}).get("service_mode"), agent_id)
except Exception:
scope = None
try:
sid = getattr(request, 'sid', '<unknown>')
except Exception:
sid = '<unknown>'
_write_service_log(
'server',
f"socket.io connect_agent agent_id={agent_id} sid={sid} service_mode={scope}",
scope=scope,
)
# Join per-agent room so we can address this connection specifically # Join per-agent room so we can address this connection specifically
try: try:
@@ -8004,7 +8145,7 @@ def connect_agent(data):
except Exception: except Exception:
pass pass
service_mode = _normalize_service_mode((data or {}).get("service_mode"), agent_id) service_mode = scope if scope else _normalize_service_mode((data or {}).get("service_mode"), agent_id)
rec = registered_agents.setdefault(agent_id, {}) rec = registered_agents.setdefault(agent_id, {})
rec["agent_id"] = agent_id rec["agent_id"] = agent_id
rec["hostname"] = rec.get("hostname", "unknown") rec["hostname"] = rec.get("hostname", "unknown")

View File

@@ -0,0 +1,58 @@
import json
import sys
import pytest
@pytest.fixture
def agent_module(tmp_path, monkeypatch):
settings_dir = tmp_path / "Agent" / "Borealis" / "Settings"
settings_dir.mkdir(parents=True)
system_config = settings_dir / "agent_settings_SYSTEM.json"
system_config.write_text(json.dumps({
"config_file_watcher_interval": 2,
"agent_id": "",
"regions": {},
"installer_code": "",
}, indent=2))
current_config = settings_dir / "agent_settings_CURRENTUSER.json"
current_config.write_text(json.dumps({
"config_file_watcher_interval": 2,
"agent_id": "",
"regions": {},
"installer_code": "",
}, indent=2))
monkeypatch.setenv("BOREALIS_ROOT", str(tmp_path))
monkeypatch.setenv("BOREALIS_AGENT_MODE", "system")
monkeypatch.setenv("BOREALIS_AGENT_CONFIG", "")
monkeypatch.setitem(sys.modules, "PyQt5", None)
monkeypatch.setitem(sys.modules, "qasync", None)
monkeypatch.setattr(sys, "argv", ["agent.py", "--system-service", "--config", "SYSTEM"], raising=False)
agent = pytest.importorskip(
"Data.Agent.agent", reason="agent module requires optional dependencies"
)
return agent, system_config
def test_shared_installer_code_cache_allows_system_reuse(agent_module, tmp_path):
agent, system_config = agent_module
client = agent.AgentHttpClient()
shared_code = "SHARED-CODE-1234"
client.key_store.cache_installer_code(shared_code, consumer="CURRENTUSER")
# System agent should discover the cached code even though its config is empty.
resolved = client._resolve_installer_code()
assert resolved == shared_code
# Config should now persist the adopted code to avoid repeated lookups.
data = json.loads(system_config.read_text())
assert data.get("installer_code") == shared_code
# After enrollment completes, the cache should be cleared for future runs.
client._consume_installer_code()
assert client.key_store.load_cached_installer_code() is None
data = json.loads(system_config.read_text())
assert data.get("installer_code") == ""

View File

@@ -0,0 +1,213 @@
import base64
import os
import pathlib
import sqlite3
import sys
import uuid
import pytest
try: # pragma: no cover - optional dependency
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
_CRYPTO_IMPORT_ERROR: Exception | None = None
except Exception as exc: # pragma: no cover - dependency unavailable
serialization = None # type: ignore
ed25519 = None # type: ignore
_CRYPTO_IMPORT_ERROR = exc
try: # pragma: no cover - optional dependency
from flask import Flask
_FLASK_IMPORT_ERROR: Exception | None = None
except Exception as exc: # pragma: no cover - dependency unavailable
Flask = None # type: ignore
_FLASK_IMPORT_ERROR = exc
ROOT = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from Data.Server.Modules import db_migrations
from Data.Server.Modules.auth.rate_limit import SlidingWindowRateLimiter
from Data.Server.Modules.enrollment.nonce_store import NonceCache
if Flask is not None: # pragma: no cover - dependency unavailable
from Data.Server.Modules.enrollment import routes as enrollment_routes
else: # pragma: no cover - dependency unavailable
enrollment_routes = None # type: ignore
class _DummyJWTService:
def issue_access_token(self, guid: str, fingerprint: str, token_version: int, expires_in: int = 900, extra_claims=None):
return f"token-{guid}"
class _DummySigner:
def public_base64_spki(self) -> str:
return ""
def _make_app(db_path: str, tls_path: str):
if Flask is None or enrollment_routes is None: # pragma: no cover - dependency unavailable
pytest.skip(f"flask unavailable: {_FLASK_IMPORT_ERROR}")
app = Flask(__name__)
def _factory():
conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row
return conn
enrollment_routes.register(
app,
db_conn_factory=_factory,
log=lambda channel, message: None,
jwt_service=_DummyJWTService(),
tls_bundle_path=tls_path,
ip_rate_limiter=SlidingWindowRateLimiter(),
fp_rate_limiter=SlidingWindowRateLimiter(),
nonce_cache=NonceCache(ttl_seconds=30.0),
script_signer=_DummySigner(),
)
return app, _factory
def _create_install_code(conn: sqlite3.Connection, code: str, *, max_uses: int = 2):
cur = conn.cursor()
record_id = str(uuid.uuid4())
cur.execute(
"""
INSERT INTO enrollment_install_codes (
id, code, expires_at, created_by_user_id, max_uses, use_count
)
VALUES (?, ?, datetime('now', '+6 hours'), 'test-user', ?, 0)
""",
(record_id, code, max_uses),
)
conn.commit()
return record_id
def _perform_enrollment_cycle(app, factory, code: str, private_key):
client = app.test_client()
public_der = private_key.public_key().public_bytes(
serialization.Encoding.DER,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
public_b64 = base64.b64encode(public_der).decode("ascii")
client_nonce = os.urandom(32)
payload = {
"hostname": "unit-test-host",
"enrollment_code": code,
"agent_pubkey": public_b64,
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
}
request_resp = client.post("/api/agent/enroll/request", json=payload)
assert request_resp.status_code == 200
request_data = request_resp.get_json()
approval_reference = request_data["approval_reference"]
with factory() as conn:
cur = conn.cursor()
cur.execute(
"""
UPDATE device_approvals
SET status = 'approved',
approved_by_user_id = 'tester'
WHERE approval_reference = ?
""",
(approval_reference,),
)
cur.execute(
"""
SELECT server_nonce, client_nonce
FROM device_approvals
WHERE approval_reference = ?
""",
(approval_reference,),
)
row = cur.fetchone()
assert row is not None
server_nonce_b64 = row["server_nonce"]
server_nonce = base64.b64decode(server_nonce_b64)
proof_message = server_nonce + approval_reference.encode("utf-8") + client_nonce
proof_sig = private_key.sign(proof_message)
poll_payload = {
"approval_reference": approval_reference,
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
"proof_sig": base64.b64encode(proof_sig).decode("ascii"),
}
poll_resp = client.post("/api/agent/enroll/poll", json=poll_payload)
assert poll_resp.status_code == 200
return poll_resp.get_json()
@pytest.mark.parametrize("max_uses", [2])
@pytest.mark.skipif(ed25519 is None, reason=f"cryptography unavailable: {_CRYPTO_IMPORT_ERROR}")
@pytest.mark.skipif(Flask is None, reason=f"flask unavailable: {_FLASK_IMPORT_ERROR}")
def test_install_code_allows_multiple_and_reuse(tmp_path, max_uses):
db_path = tmp_path / "test.db"
conn = sqlite3.connect(db_path)
db_migrations.apply_all(conn)
_create_install_code(conn, "TEST-CODE-1234", max_uses=max_uses)
conn.close()
tls_path = tmp_path / "tls.pem"
tls_path.write_text("TEST CERT")
app, factory = _make_app(str(db_path), str(tls_path))
private_key = ed25519.Ed25519PrivateKey.generate()
# First enrollment consumes one use but keeps the code active.
first = _perform_enrollment_cycle(app, factory, "TEST-CODE-1234", private_key)
assert first["status"] == "approved"
with factory() as conn:
cur = conn.cursor()
cur.execute(
"SELECT use_count, max_uses, used_at, last_used_at FROM enrollment_install_codes WHERE code = ?",
("TEST-CODE-1234",),
)
row = cur.fetchone()
assert row is not None
assert row["use_count"] == 1
assert row["max_uses"] == max_uses
assert row["used_at"] is None
assert row["last_used_at"] is not None
# Second enrollment hits the configured max uses.
second = _perform_enrollment_cycle(app, factory, "TEST-CODE-1234", private_key)
assert second["status"] == "approved"
with factory() as conn:
cur = conn.cursor()
cur.execute(
"SELECT use_count, used_at, last_used_at, used_by_guid FROM enrollment_install_codes WHERE code = ?",
("TEST-CODE-1234",),
)
row = cur.fetchone()
assert row is not None
assert row["use_count"] == max_uses
assert row["used_at"] is not None
assert row["last_used_at"] is not None
consumed_guid = row["used_by_guid"]
assert consumed_guid
# Additional enrollments from the same identity reuse the stored GUID even after consumption.
third = _perform_enrollment_cycle(app, factory, "TEST-CODE-1234", private_key)
assert third["status"] == "approved"
with factory() as conn:
cur = conn.cursor()
cur.execute(
"SELECT use_count, used_at, last_used_at, used_by_guid FROM enrollment_install_codes WHERE code = ?",
("TEST-CODE-1234",),
)
row = cur.fetchone()
assert row is not None
assert row["use_count"] == max_uses + 1
assert row["used_by_guid"] == consumed_guid
assert row["used_at"] is not None
assert row["last_used_at"] is not None
cur.execute("SELECT COUNT(*) FROM devices WHERE guid = ?", (consumed_guid,))
assert cur.fetchone()[0] == 1