mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 19:21:58 -06:00
Merge pull request #128 from bunny-lab-io:codex/implement-secure-agent-enrollment-features
Handle missing devices without re-enrollment loops
This commit is contained in:
@@ -23,7 +23,8 @@ import ssl
|
|||||||
import threading
|
import threading
|
||||||
import contextlib
|
import contextlib
|
||||||
import errno
|
import errno
|
||||||
from typing import Any, Dict, Optional, List, Callable
|
import re
|
||||||
|
from typing import Any, Dict, Optional, List, Callable, Tuple
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
try:
|
try:
|
||||||
@@ -66,21 +67,120 @@ def _rotate_daily(path: str):
|
|||||||
|
|
||||||
|
|
||||||
# Early bootstrap logging (goes to agent.log)
|
# Early bootstrap logging (goes to agent.log)
|
||||||
def _bootstrap_log(msg: str):
|
_AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||||
|
_AGENT_SCOPE_PATTERN = re.compile(r"\\bscope=([A-Za-z0-9_-]+)", re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def _canonical_scope_value(raw: Optional[str]) -> Optional[str]:
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
value = "".join(ch for ch in str(raw) if ch.isalnum() or ch in ("_", "-"))
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
return value.upper()
|
||||||
|
|
||||||
|
|
||||||
|
def _agent_context_default() -> Optional[str]:
|
||||||
|
suffix = globals().get("CONFIG_SUFFIX_CANONICAL")
|
||||||
|
context = _canonical_scope_value(suffix)
|
||||||
|
if context:
|
||||||
|
return context
|
||||||
|
service = globals().get("SERVICE_MODE_CANONICAL")
|
||||||
|
context = _canonical_scope_value(service)
|
||||||
|
if context:
|
||||||
|
return context
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_agent_scope(message: str, provided_scope: Optional[str] = None) -> Optional[str]:
|
||||||
|
scope = _canonical_scope_value(provided_scope)
|
||||||
|
if scope:
|
||||||
|
return scope
|
||||||
|
match = _AGENT_SCOPE_PATTERN.search(message or "")
|
||||||
|
if match:
|
||||||
|
scope = _canonical_scope_value(match.group(1))
|
||||||
|
if scope:
|
||||||
|
return scope
|
||||||
|
return _agent_context_default()
|
||||||
|
|
||||||
|
|
||||||
|
def _format_agent_log_message(message: str, fname: str, scope: Optional[str] = None) -> str:
|
||||||
|
context = _infer_agent_scope(message, scope)
|
||||||
|
if fname == "agent.error.log":
|
||||||
|
prefix = "[ERROR]"
|
||||||
|
if context:
|
||||||
|
prefix = f"{prefix}[CONTEXT-{context}]"
|
||||||
|
return f"{prefix} {message}"
|
||||||
|
if context:
|
||||||
|
return f"[CONTEXT-{context}] {message}"
|
||||||
|
return f"[INFO] {message}"
|
||||||
|
|
||||||
|
|
||||||
|
def _bootstrap_log(msg: str, *, scope: Optional[str] = None):
|
||||||
try:
|
try:
|
||||||
base = _agent_logs_root()
|
base = _agent_logs_root()
|
||||||
os.makedirs(base, exist_ok=True)
|
os.makedirs(base, exist_ok=True)
|
||||||
path = os.path.join(base, 'agent.log')
|
path = os.path.join(base, 'agent.log')
|
||||||
_rotate_daily(path)
|
_rotate_daily(path)
|
||||||
ts = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
ts = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
line = _format_agent_log_message(msg, 'agent.log', scope)
|
||||||
with open(path, 'a', encoding='utf-8') as fh:
|
with open(path, 'a', encoding='utf-8') as fh:
|
||||||
fh.write(f'[{ts}] {msg}\n')
|
fh.write(f'[{ts}] {line}\n')
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _describe_exception(exc: BaseException) -> str:
|
||||||
|
try:
|
||||||
|
primary = f"{exc.__class__.__name__}: {exc}"
|
||||||
|
except Exception:
|
||||||
|
primary = repr(exc)
|
||||||
|
parts = [primary]
|
||||||
|
try:
|
||||||
|
cause = getattr(exc, "__cause__", None)
|
||||||
|
if cause and cause is not exc:
|
||||||
|
parts.append(f"cause={cause.__class__.__name__}: {cause}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
context = getattr(exc, "__context__", None)
|
||||||
|
if context and context is not exc and context is not getattr(exc, "__cause__", None):
|
||||||
|
parts.append(f"context={context.__class__.__name__}: {context}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
args = getattr(exc, "args", None)
|
||||||
|
if isinstance(args, tuple) and len(args) > 1:
|
||||||
|
parts.append(f"args={args!r}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
details = getattr(exc, "__dict__", None)
|
||||||
|
if isinstance(details, dict):
|
||||||
|
# Capture noteworthy nested attributes such as os_error/errno to help diagnose
|
||||||
|
# connection failures that collapse into generic ConnectionError wrappers.
|
||||||
|
for key in ("os_error", "errno", "code", "status"):
|
||||||
|
if key in details and details[key]:
|
||||||
|
parts.append(f"{key}={details[key]!r}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return "; ".join(part for part in parts if part)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_exception_trace(prefix: str) -> None:
|
||||||
|
try:
|
||||||
|
tb = traceback.format_exc()
|
||||||
|
if not tb:
|
||||||
|
return
|
||||||
|
for line in tb.rstrip().splitlines():
|
||||||
|
_log_agent(f"{prefix} trace: {line}", fname="agent.error.log")
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Headless/service mode flag (skip Qt and interactive UI)
|
# Headless/service mode flag (skip Qt and interactive UI)
|
||||||
SYSTEM_SERVICE_MODE = ('--system-service' in sys.argv) or (os.environ.get('BOREALIS_AGENT_MODE') == 'system')
|
SYSTEM_SERVICE_MODE = ('--system-service' in sys.argv) or (os.environ.get('BOREALIS_AGENT_MODE') == 'system')
|
||||||
SERVICE_MODE = 'system' if SYSTEM_SERVICE_MODE else 'currentuser'
|
SERVICE_MODE = 'system' if SYSTEM_SERVICE_MODE else 'currentuser'
|
||||||
|
SERVICE_MODE_CANONICAL = SERVICE_MODE.upper()
|
||||||
_bootstrap_log(f'agent.py loaded; SYSTEM_SERVICE_MODE={SYSTEM_SERVICE_MODE}; argv={sys.argv!r}')
|
_bootstrap_log(f'agent.py loaded; SYSTEM_SERVICE_MODE={SYSTEM_SERVICE_MODE}; argv={sys.argv!r}')
|
||||||
def _argv_get(flag: str, default: str = None):
|
def _argv_get(flag: str, default: str = None):
|
||||||
try:
|
try:
|
||||||
@@ -359,15 +459,16 @@ def _find_project_root():
|
|||||||
return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||||
|
|
||||||
# Simple file logger under Logs/Agent
|
# Simple file logger under Logs/Agent
|
||||||
def _log_agent(message: str, fname: str = 'agent.log'):
|
def _log_agent(message: str, fname: str = 'agent.log', *, scope: Optional[str] = None):
|
||||||
try:
|
try:
|
||||||
log_dir = _agent_logs_root()
|
log_dir = _agent_logs_root()
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
ts = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
ts = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
path = os.path.join(log_dir, fname)
|
path = os.path.join(log_dir, fname)
|
||||||
_rotate_daily(path)
|
_rotate_daily(path)
|
||||||
|
line = _format_agent_log_message(message, fname, scope)
|
||||||
with open(path, 'a', encoding='utf-8') as fh:
|
with open(path, 'a', encoding='utf-8') as fh:
|
||||||
fh.write(f'[{ts}] {message}\n')
|
fh.write(f'[{ts}] {line}\n')
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -384,6 +485,31 @@ def _mask_sensitive(value: str, *, prefix: int = 4, suffix: int = 4) -> str:
|
|||||||
return '***'
|
return '***'
|
||||||
|
|
||||||
|
|
||||||
|
def _format_debug_pairs(pairs: Dict[str, Any]) -> str:
|
||||||
|
try:
|
||||||
|
parts = []
|
||||||
|
for key, value in pairs.items():
|
||||||
|
parts.append(f"{key}={value!r}")
|
||||||
|
return ", ".join(parts)
|
||||||
|
except Exception:
|
||||||
|
return repr(pairs)
|
||||||
|
|
||||||
|
|
||||||
|
def _summarize_headers(headers: Dict[str, str]) -> str:
|
||||||
|
try:
|
||||||
|
rendered: List[str] = []
|
||||||
|
for key, value in headers.items():
|
||||||
|
lowered = key.lower()
|
||||||
|
display = value
|
||||||
|
if lowered == 'authorization':
|
||||||
|
token = value.split()[-1] if value and ' ' in value else value
|
||||||
|
display = f"Bearer {_mask_sensitive(token)}"
|
||||||
|
rendered.append(f"{key}={display}")
|
||||||
|
return ", ".join(rendered)
|
||||||
|
except Exception:
|
||||||
|
return '<unavailable>'
|
||||||
|
|
||||||
|
|
||||||
def _decode_base64_text(value):
|
def _decode_base64_text(value):
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
return None
|
return None
|
||||||
@@ -571,6 +697,57 @@ DEFAULT_CONFIG = {
|
|||||||
"installer_code": ""
|
"installer_code": ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _load_installer_code_from_file(path: str) -> str:
|
||||||
|
try:
|
||||||
|
with open(path, "r", encoding="utf-8") as fh:
|
||||||
|
data = json.load(fh)
|
||||||
|
except Exception:
|
||||||
|
return ""
|
||||||
|
value = data.get("installer_code") if isinstance(data, dict) else ""
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value.strip()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback_installer_code(current_path: str) -> str:
|
||||||
|
settings_dir = os.path.dirname(current_path)
|
||||||
|
candidates: List[str] = []
|
||||||
|
suffix = CONFIG_SUFFIX_CANONICAL
|
||||||
|
sibling_map = {
|
||||||
|
"SYSTEM": "agent_settings_CURRENTUSER.json",
|
||||||
|
"CURRENTUSER": "agent_settings_SYSTEM.json",
|
||||||
|
}
|
||||||
|
sibling_name = sibling_map.get(suffix or "")
|
||||||
|
if sibling_name:
|
||||||
|
candidates.append(os.path.join(settings_dir, sibling_name))
|
||||||
|
# Prefer the shared/base config next
|
||||||
|
candidates.append(os.path.join(settings_dir, "agent_settings.json"))
|
||||||
|
# Legacy location fallback
|
||||||
|
try:
|
||||||
|
project_root = _find_project_root()
|
||||||
|
legacy_dir = os.path.join(project_root, "Agent", "Settings")
|
||||||
|
if sibling_name:
|
||||||
|
candidates.append(os.path.join(legacy_dir, sibling_name))
|
||||||
|
candidates.append(os.path.join(legacy_dir, "agent_settings.json"))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
current_abspath = os.path.abspath(current_path)
|
||||||
|
for candidate in candidates:
|
||||||
|
if not candidate:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
candidate_path = os.path.abspath(candidate)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if candidate_path == current_abspath or not os.path.isfile(candidate_path):
|
||||||
|
continue
|
||||||
|
code = _load_installer_code_from_file(candidate_path)
|
||||||
|
if code:
|
||||||
|
return code
|
||||||
|
return ""
|
||||||
|
|
||||||
class ConfigManager:
|
class ConfigManager:
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
self.path = path
|
self.path = path
|
||||||
@@ -631,6 +808,9 @@ class AgentHttpClient:
|
|||||||
self.key_store = _key_store()
|
self.key_store = _key_store()
|
||||||
self.identity = IDENTITY
|
self.identity = IDENTITY
|
||||||
self.session = requests.Session()
|
self.session = requests.Session()
|
||||||
|
context_label = _agent_context_default()
|
||||||
|
if context_label:
|
||||||
|
self.session.headers.setdefault(_AGENT_CONTEXT_HEADER, context_label)
|
||||||
self.base_url: Optional[str] = None
|
self.base_url: Optional[str] = None
|
||||||
self.guid: Optional[str] = None
|
self.guid: Optional[str] = None
|
||||||
self.access_token: Optional[str] = None
|
self.access_token: Optional[str] = None
|
||||||
@@ -697,42 +877,115 @@ class AgentHttpClient:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def auth_headers(self) -> Dict[str, str]:
|
def auth_headers(self) -> Dict[str, str]:
|
||||||
|
headers: Dict[str, str] = {}
|
||||||
if self.access_token:
|
if self.access_token:
|
||||||
return {"Authorization": f"Bearer {self.access_token}"}
|
headers["Authorization"] = f"Bearer {self.access_token}"
|
||||||
return {}
|
context_label = _agent_context_default()
|
||||||
|
if context_label:
|
||||||
|
headers[_AGENT_CONTEXT_HEADER] = context_label
|
||||||
|
return headers
|
||||||
|
|
||||||
def configure_socketio(self, client: "socketio.AsyncClient") -> None:
|
def configure_socketio(self, client: "socketio.AsyncClient") -> None:
|
||||||
"""Align the Socket.IO engine's TLS verification with the REST client."""
|
"""Align the Socket.IO engine's TLS verification with the REST client."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
verify = getattr(self.session, "verify", True)
|
verify = getattr(self.session, "verify", True)
|
||||||
engine = getattr(client, "eio", None)
|
engine = getattr(client, "eio", None)
|
||||||
if engine is None:
|
if engine is None:
|
||||||
|
_log_agent(
|
||||||
|
"SocketIO TLS alignment skipped; AsyncClient.eio missing",
|
||||||
|
fname="agent.error.log",
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# python-engineio accepts either a boolean or an ``ssl.SSLContext``
|
http_iface = getattr(engine, "http", None)
|
||||||
# for TLS verification. When we have a pinned certificate bundle
|
|
||||||
# on disk, prefer constructing a dedicated context that trusts that
|
debug_info = {
|
||||||
# bundle so WebSocket connections succeed even with private CAs.
|
"verify_type": type(verify).__name__,
|
||||||
|
"verify_value": verify,
|
||||||
|
"engine_type": type(engine).__name__,
|
||||||
|
"http_iface_present": http_iface is not None,
|
||||||
|
}
|
||||||
|
_log_agent(
|
||||||
|
f"SocketIO TLS alignment start: {_format_debug_pairs(debug_info)}",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_attr(target: Any, name: str, value: Any) -> None:
|
||||||
|
if target is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
setattr(target, name, value)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _reset_cached_session() -> None:
|
||||||
|
if http_iface is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
if hasattr(http_iface, "session"):
|
||||||
|
setattr(http_iface, "session", None)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
context = None
|
||||||
if isinstance(verify, str) and os.path.isfile(verify):
|
if isinstance(verify, str) and os.path.isfile(verify):
|
||||||
try:
|
try:
|
||||||
context = ssl.create_default_context(cafile=verify)
|
# Mirror Requests' certificate handling by starting from a
|
||||||
|
# default client context (which pre-loads the system
|
||||||
|
# certificate stores) and then layering the pinned
|
||||||
|
# certificate bundle on top. This matches the REST client
|
||||||
|
# behaviour and ensures self-signed leaf certificates work
|
||||||
|
# the same way for Socket.IO handshakes.
|
||||||
|
context = ssl.create_default_context()
|
||||||
context.check_hostname = False
|
context.check_hostname = False
|
||||||
|
context.load_verify_locations(cafile=verify)
|
||||||
|
_log_agent(
|
||||||
|
f"SocketIO TLS alignment created SSLContext from cafile={verify}",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
context = None
|
context = None
|
||||||
if context is not None:
|
_log_agent(
|
||||||
engine.ssl_context = context
|
f"SocketIO TLS alignment failed to build context from cafile={verify}",
|
||||||
engine.ssl_verify = True
|
fname="agent.error.log",
|
||||||
else:
|
)
|
||||||
engine.ssl_context = None
|
|
||||||
engine.ssl_verify = verify
|
if context is not None:
|
||||||
elif verify is False:
|
_set_attr(engine, "ssl_context", context)
|
||||||
engine.ssl_context = None
|
_set_attr(engine, "ssl_verify", True)
|
||||||
engine.ssl_verify = False
|
_set_attr(engine, "verify_ssl", True)
|
||||||
else:
|
_set_attr(http_iface, "ssl_context", context)
|
||||||
engine.ssl_context = None
|
_set_attr(http_iface, "ssl_verify", True)
|
||||||
engine.ssl_verify = True
|
_set_attr(http_iface, "verify_ssl", True)
|
||||||
|
_reset_cached_session()
|
||||||
|
_log_agent(
|
||||||
|
"SocketIO TLS alignment applied dedicated SSLContext to engine/http",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Fall back to boolean verification flags when we either do not
|
||||||
|
# have a pinned certificate bundle or failed to build a dedicated
|
||||||
|
# context for it.
|
||||||
|
verify_flag = False if verify is False else True
|
||||||
|
_set_attr(engine, "ssl_context", None)
|
||||||
|
_set_attr(engine, "ssl_verify", verify_flag)
|
||||||
|
_set_attr(engine, "verify_ssl", verify_flag)
|
||||||
|
_set_attr(http_iface, "ssl_context", None)
|
||||||
|
_set_attr(http_iface, "ssl_verify", verify_flag)
|
||||||
|
_set_attr(http_iface, "verify_ssl", verify_flag)
|
||||||
|
_reset_cached_session()
|
||||||
|
_log_agent(
|
||||||
|
f"SocketIO TLS alignment fallback verify_flag={verify_flag}",
|
||||||
|
fname="agent.log",
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
_log_agent(
|
||||||
|
"SocketIO TLS alignment encountered unexpected error",
|
||||||
|
fname="agent.error.log",
|
||||||
|
)
|
||||||
|
_log_exception_trace("configure_socketio")
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Enrollment & token management
|
# Enrollment & token management
|
||||||
@@ -1007,10 +1260,22 @@ class AgentHttpClient:
|
|||||||
timeout=20,
|
timeout=20,
|
||||||
)
|
)
|
||||||
if resp.status_code in (401, 403):
|
if resp.status_code in (401, 403):
|
||||||
_log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log")
|
error_code, snippet = self._error_details(resp)
|
||||||
self._clear_tokens_locked()
|
if resp.status_code == 401 and self._should_retry_auth(resp.status_code, error_code):
|
||||||
self._perform_enrollment_locked()
|
_log_agent(
|
||||||
return
|
"Refresh token rejected; attempting re-enrollment"
|
||||||
|
f" error={error_code or '<unknown>'}",
|
||||||
|
fname="agent.error.log",
|
||||||
|
)
|
||||||
|
self._clear_tokens_locked()
|
||||||
|
self._perform_enrollment_locked()
|
||||||
|
return
|
||||||
|
_log_agent(
|
||||||
|
"Refresh token request forbidden "
|
||||||
|
f"status={resp.status_code} error={error_code or '<unknown>'}"
|
||||||
|
f" body_snippet={snippet}",
|
||||||
|
fname="agent.error.log",
|
||||||
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
access_token = data.get("access_token")
|
access_token = data.get("access_token")
|
||||||
@@ -1036,14 +1301,79 @@ class AgentHttpClient:
|
|||||||
self.guid = self.key_store.load_guid()
|
self.guid = self.key_store.load_guid()
|
||||||
self.session.headers.pop("Authorization", None)
|
self.session.headers.pop("Authorization", None)
|
||||||
|
|
||||||
|
def _error_details(self, response: requests.Response) -> Tuple[Optional[str], str]:
|
||||||
|
error_code: Optional[str] = None
|
||||||
|
snippet = ""
|
||||||
|
try:
|
||||||
|
snippet = response.text[:256]
|
||||||
|
except Exception:
|
||||||
|
snippet = "<unavailable>"
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
except Exception:
|
||||||
|
data = None
|
||||||
|
if isinstance(data, dict):
|
||||||
|
for key in ("error", "code", "status"):
|
||||||
|
value = data.get(key)
|
||||||
|
if isinstance(value, str) and value.strip():
|
||||||
|
error_code = value.strip()
|
||||||
|
break
|
||||||
|
return error_code, snippet
|
||||||
|
|
||||||
|
def _should_retry_auth(self, status_code: int, error_code: Optional[str]) -> bool:
|
||||||
|
if status_code == 401:
|
||||||
|
return True
|
||||||
|
retryable_forbidden = {"fingerprint_mismatch"}
|
||||||
|
if status_code == 403 and error_code in retryable_forbidden:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def _resolve_installer_code(self) -> str:
|
def _resolve_installer_code(self) -> str:
|
||||||
if INSTALLER_CODE_OVERRIDE:
|
if INSTALLER_CODE_OVERRIDE:
|
||||||
return INSTALLER_CODE_OVERRIDE
|
code = INSTALLER_CODE_OVERRIDE.strip()
|
||||||
|
if code:
|
||||||
|
try:
|
||||||
|
self.key_store.cache_installer_code(code, consumer=SERVICE_MODE_CANONICAL)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return code
|
||||||
|
code = ""
|
||||||
try:
|
try:
|
||||||
code = (CONFIG.data.get("installer_code") or "").strip()
|
code = (CONFIG.data.get("installer_code") or "").strip()
|
||||||
return code
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return ""
|
code = ""
|
||||||
|
if code:
|
||||||
|
try:
|
||||||
|
self.key_store.cache_installer_code(code, consumer=SERVICE_MODE_CANONICAL)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return code
|
||||||
|
try:
|
||||||
|
cached = self.key_store.load_cached_installer_code()
|
||||||
|
except Exception:
|
||||||
|
cached = None
|
||||||
|
if cached:
|
||||||
|
try:
|
||||||
|
self.key_store.cache_installer_code(cached, consumer=SERVICE_MODE_CANONICAL)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return cached
|
||||||
|
fallback = _fallback_installer_code(CONFIG.path)
|
||||||
|
if fallback:
|
||||||
|
try:
|
||||||
|
CONFIG.data["installer_code"] = fallback
|
||||||
|
CONFIG._write()
|
||||||
|
_log_agent(
|
||||||
|
"Adopted installer code from sibling configuration", fname="agent.log"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
self.key_store.cache_installer_code(fallback, consumer=SERVICE_MODE_CANONICAL)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return fallback
|
||||||
|
return ""
|
||||||
|
|
||||||
def _consume_installer_code(self) -> None:
|
def _consume_installer_code(self) -> None:
|
||||||
# Avoid clearing explicit CLI/env overrides; only mutate persisted config.
|
# Avoid clearing explicit CLI/env overrides; only mutate persisted config.
|
||||||
@@ -1057,6 +1387,13 @@ class AgentHttpClient:
|
|||||||
_log_agent("Cleared persisted installer code after successful enrollment", fname="agent.log")
|
_log_agent("Cleared persisted installer code after successful enrollment", fname="agent.log")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
_log_agent(f"Failed to clear installer code after enrollment: {exc}", fname="agent.error.log")
|
_log_agent(f"Failed to clear installer code after enrollment: {exc}", fname="agent.error.log")
|
||||||
|
try:
|
||||||
|
self.key_store.mark_installer_code_consumed(SERVICE_MODE_CANONICAL)
|
||||||
|
except Exception as exc:
|
||||||
|
_log_agent(
|
||||||
|
f"Failed to update shared installer code cache: {exc}",
|
||||||
|
fname="agent.error.log",
|
||||||
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# HTTP helpers
|
# HTTP helpers
|
||||||
@@ -1068,20 +1405,19 @@ class AgentHttpClient:
|
|||||||
headers = self.auth_headers()
|
headers = self.auth_headers()
|
||||||
response = self.session.post(url, json=payload, headers=headers, timeout=30)
|
response = self.session.post(url, json=payload, headers=headers, timeout=30)
|
||||||
if response.status_code in (401, 403) and require_auth:
|
if response.status_code in (401, 403) and require_auth:
|
||||||
snippet = ""
|
error_code, snippet = self._error_details(response)
|
||||||
try:
|
if self._should_retry_auth(response.status_code, error_code):
|
||||||
snippet = response.text[:256]
|
self.clear_tokens()
|
||||||
except Exception:
|
self.ensure_authenticated()
|
||||||
snippet = "<unavailable>"
|
headers = self.auth_headers()
|
||||||
_log_agent(
|
response = self.session.post(url, json=payload, headers=headers, timeout=30)
|
||||||
"Authenticated request rejected "
|
else:
|
||||||
f"path={path} status={response.status_code} body_snippet={snippet}",
|
_log_agent(
|
||||||
fname="agent.error.log",
|
"Authenticated request rejected "
|
||||||
)
|
f"path={path} status={response.status_code} error={error_code or '<unknown>'}"
|
||||||
self.clear_tokens()
|
f" body_snippet={snippet}",
|
||||||
self.ensure_authenticated()
|
fname="agent.error.log",
|
||||||
headers = self.auth_headers()
|
)
|
||||||
response = self.session.post(url, json=payload, headers=headers, timeout=30)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
if response.headers.get("Content-Type", "").lower().startswith("application/json"):
|
if response.headers.get("Content-Type", "").lower().startswith("application/json"):
|
||||||
return response.json()
|
return response.json()
|
||||||
@@ -2107,6 +2443,15 @@ async def send_agent_details_once():
|
|||||||
async def connect():
|
async def connect():
|
||||||
print(f"[INFO] Successfully Connected to Borealis Server!")
|
print(f"[INFO] Successfully Connected to Borealis Server!")
|
||||||
_log_agent('Connected to server.')
|
_log_agent('Connected to server.')
|
||||||
|
try:
|
||||||
|
sid = getattr(sio, 'sid', None)
|
||||||
|
transport = getattr(sio, 'transport', None)
|
||||||
|
_log_agent(
|
||||||
|
f'WebSocket handshake established sid={sid!r} transport={transport!r}',
|
||||||
|
fname='agent.log',
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
await sio.emit('connect_agent', {"agent_id": AGENT_ID, "service_mode": SERVICE_MODE})
|
await sio.emit('connect_agent', {"agent_id": AGENT_ID, "service_mode": SERVICE_MODE})
|
||||||
|
|
||||||
# Send an immediate heartbeat via authenticated REST call.
|
# Send an immediate heartbeat via authenticated REST call.
|
||||||
@@ -2143,6 +2488,17 @@ async def connect():
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@sio.event
|
||||||
|
async def connect_error(data):
|
||||||
|
try:
|
||||||
|
setattr(sio, "connection_error", data)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
_log_agent(f'Socket connect_error event: {data!r}', fname='agent.error.log')
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
@sio.event
|
@sio.event
|
||||||
async def disconnect():
|
async def disconnect():
|
||||||
print("[WebSocket] Disconnected from Borealis server.")
|
print("[WebSocket] Disconnected from Borealis server.")
|
||||||
@@ -2390,22 +2746,64 @@ if not SYSTEM_SERVICE_MODE:
|
|||||||
async def connect_loop():
|
async def connect_loop():
|
||||||
retry = 5
|
retry = 5
|
||||||
client = http_client()
|
client = http_client()
|
||||||
|
attempt = 0
|
||||||
while True:
|
while True:
|
||||||
|
attempt += 1
|
||||||
try:
|
try:
|
||||||
|
_log_agent(
|
||||||
|
f'connect_loop attempt={attempt} starting authentication phase',
|
||||||
|
fname='agent.log',
|
||||||
|
)
|
||||||
client.ensure_authenticated()
|
client.ensure_authenticated()
|
||||||
|
auth_snapshot = {
|
||||||
|
'guid_present': bool(client.guid),
|
||||||
|
'access_token': bool(client.access_token),
|
||||||
|
'refresh_token': bool(client.refresh_token),
|
||||||
|
'access_expiry': client.access_expires_at,
|
||||||
|
}
|
||||||
|
_log_agent(
|
||||||
|
f"connect_loop attempt={attempt} auth snapshot: {_format_debug_pairs(auth_snapshot)}",
|
||||||
|
fname='agent.log',
|
||||||
|
)
|
||||||
client.configure_socketio(sio)
|
client.configure_socketio(sio)
|
||||||
|
try:
|
||||||
|
setattr(sio, "connection_error", None)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
url = client.websocket_base_url()
|
url = client.websocket_base_url()
|
||||||
|
headers = client.auth_headers()
|
||||||
|
header_summary = _summarize_headers(headers)
|
||||||
|
verify_value = getattr(client.session, 'verify', None)
|
||||||
|
_log_agent(
|
||||||
|
f"connect_loop attempt={attempt} dialing websocket url={url} transports=['websocket'] verify={verify_value!r} headers={header_summary}",
|
||||||
|
fname='agent.log',
|
||||||
|
)
|
||||||
print(f"[INFO] Connecting Agent to {url}...")
|
print(f"[INFO] Connecting Agent to {url}...")
|
||||||
_log_agent(f'Connecting to {url}...')
|
|
||||||
await sio.connect(
|
await sio.connect(
|
||||||
url,
|
url,
|
||||||
transports=['websocket'],
|
transports=['websocket'],
|
||||||
headers=client.auth_headers(),
|
headers=headers,
|
||||||
|
)
|
||||||
|
_log_agent(
|
||||||
|
f'connect_loop attempt={attempt} sio.connect completed successfully',
|
||||||
|
fname='agent.log',
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[WebSocket] Server unavailable: {e}. Retrying in {retry}s...")
|
detail = _describe_exception(e)
|
||||||
_log_agent(f'Server unavailable: {e}', fname='agent.error.log')
|
try:
|
||||||
|
conn_err = getattr(sio, "connection_error", None)
|
||||||
|
except Exception:
|
||||||
|
conn_err = None
|
||||||
|
if conn_err:
|
||||||
|
detail = f"{detail}; connection_error={conn_err!r}"
|
||||||
|
message = (
|
||||||
|
f"connect_loop attempt={attempt} server unavailable: {detail}. "
|
||||||
|
f"Retrying in {retry}s..."
|
||||||
|
)
|
||||||
|
print(f"[WebSocket] {message}")
|
||||||
|
_log_agent(message, fname='agent.error.log')
|
||||||
|
_log_exception_trace(f'connect_loop attempt={attempt}')
|
||||||
await asyncio.sleep(retry)
|
await asyncio.sleep(retry)
|
||||||
|
|
||||||
if __name__=='__main__':
|
if __name__=='__main__':
|
||||||
|
|||||||
@@ -230,6 +230,7 @@ class AgentKeyStore:
|
|||||||
self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem")
|
self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem")
|
||||||
self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub")
|
self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub")
|
||||||
self._identity_lock_path = os.path.join(self.settings_dir, "identity.lock")
|
self._identity_lock_path = os.path.join(self.settings_dir, "identity.lock")
|
||||||
|
self._installer_cache_path = os.path.join(self.settings_dir, "installer_code.shared.json")
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Identity management
|
# Identity management
|
||||||
@@ -455,3 +456,107 @@ class AgentKeyStore:
|
|||||||
if isinstance(value, str) and value.strip():
|
if isinstance(value, str) and value.strip():
|
||||||
return value.strip()
|
return value.strip()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Installer code sharing helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def _load_installer_cache(self) -> dict:
|
||||||
|
if not os.path.isfile(self._installer_cache_path):
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
with open(self._installer_cache_path, "r", encoding="utf-8") as fh:
|
||||||
|
data = json.load(fh)
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return data
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _store_installer_cache(self, payload: dict) -> None:
|
||||||
|
try:
|
||||||
|
with open(self._installer_cache_path, "w", encoding="utf-8") as fh:
|
||||||
|
json.dump(payload, fh, indent=2)
|
||||||
|
_restrict_permissions(self._installer_cache_path)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def cache_installer_code(self, code: str, consumer: Optional[str] = None) -> None:
|
||||||
|
normalized = (code or "").strip()
|
||||||
|
if not normalized:
|
||||||
|
return
|
||||||
|
payload = self._load_installer_cache()
|
||||||
|
payload["code"] = normalized
|
||||||
|
consumers = set()
|
||||||
|
existing = payload.get("consumed")
|
||||||
|
if isinstance(existing, list):
|
||||||
|
consumers = {str(item).upper() for item in existing if isinstance(item, str)}
|
||||||
|
if consumer:
|
||||||
|
consumers.add(str(consumer).upper())
|
||||||
|
payload["consumed"] = sorted(consumers)
|
||||||
|
payload["updated_at"] = int(time.time())
|
||||||
|
self._store_installer_cache(payload)
|
||||||
|
|
||||||
|
def load_cached_installer_code(self) -> Optional[str]:
|
||||||
|
payload = self._load_installer_cache()
|
||||||
|
code = payload.get("code")
|
||||||
|
if isinstance(code, str):
|
||||||
|
stripped = code.strip()
|
||||||
|
if stripped:
|
||||||
|
return stripped
|
||||||
|
return None
|
||||||
|
|
||||||
|
def mark_installer_code_consumed(self, consumer: Optional[str] = None) -> None:
|
||||||
|
payload = self._load_installer_cache()
|
||||||
|
if not payload:
|
||||||
|
return
|
||||||
|
consumers = set()
|
||||||
|
existing = payload.get("consumed")
|
||||||
|
if isinstance(existing, list):
|
||||||
|
consumers = {str(item).upper() for item in existing if isinstance(item, str)}
|
||||||
|
if consumer:
|
||||||
|
consumers.add(str(consumer).upper())
|
||||||
|
payload["consumed"] = sorted(consumers)
|
||||||
|
payload["updated_at"] = int(time.time())
|
||||||
|
|
||||||
|
code_present = isinstance(payload.get("code"), str) and payload["code"].strip()
|
||||||
|
should_clear = False
|
||||||
|
if not code_present:
|
||||||
|
should_clear = True
|
||||||
|
else:
|
||||||
|
required_consumers = {"SYSTEM", "CURRENTUSER"}
|
||||||
|
if required_consumers.issubset(consumers):
|
||||||
|
should_clear = True
|
||||||
|
else:
|
||||||
|
remaining = required_consumers - consumers
|
||||||
|
if not remaining:
|
||||||
|
should_clear = True
|
||||||
|
else:
|
||||||
|
exists_other = False
|
||||||
|
for other in remaining:
|
||||||
|
if other == "SYSTEM":
|
||||||
|
cfg_name = "agent_settings_SYSTEM.json"
|
||||||
|
elif other == "CURRENTUSER":
|
||||||
|
cfg_name = "agent_settings_CURRENTUSER.json"
|
||||||
|
else:
|
||||||
|
cfg_name = None
|
||||||
|
if not cfg_name:
|
||||||
|
continue
|
||||||
|
path = os.path.join(self.settings_dir, cfg_name)
|
||||||
|
if os.path.isfile(path):
|
||||||
|
exists_other = True
|
||||||
|
break
|
||||||
|
if not exists_other:
|
||||||
|
should_clear = True
|
||||||
|
|
||||||
|
if should_clear:
|
||||||
|
payload.pop("code", None)
|
||||||
|
payload["consumed"] = []
|
||||||
|
|
||||||
|
if payload.get("code") or payload.get("consumed"):
|
||||||
|
self._store_installer_cache(payload)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
if os.path.isfile(self._installer_cache_path):
|
||||||
|
os.remove(self._installer_cache_path)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ def register(
|
|||||||
db_conn_factory: Callable[[], sqlite3.Connection],
|
db_conn_factory: Callable[[], sqlite3.Connection],
|
||||||
require_admin: Callable[[], Optional[Any]],
|
require_admin: Callable[[], Optional[Any]],
|
||||||
current_user: Callable[[], Optional[Dict[str, str]]],
|
current_user: Callable[[], Optional[Dict[str, str]]],
|
||||||
log: Callable[[str, str], None],
|
log: Callable[[str, str, Optional[str]], None],
|
||||||
) -> None:
|
) -> None:
|
||||||
blueprint = Blueprint("admin", __name__)
|
blueprint = Blueprint("admin", __name__)
|
||||||
|
|
||||||
@@ -54,18 +54,27 @@ def register(
|
|||||||
try:
|
try:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
sql = """
|
sql = """
|
||||||
SELECT id, code, expires_at, created_by_user_id, used_at, used_by_guid
|
SELECT id,
|
||||||
|
code,
|
||||||
|
expires_at,
|
||||||
|
created_by_user_id,
|
||||||
|
used_at,
|
||||||
|
used_by_guid,
|
||||||
|
max_uses,
|
||||||
|
use_count,
|
||||||
|
last_used_at
|
||||||
FROM enrollment_install_codes
|
FROM enrollment_install_codes
|
||||||
"""
|
"""
|
||||||
params: List[str] = []
|
params: List[str] = []
|
||||||
|
now_iso = _iso(_now())
|
||||||
if status_filter == "active":
|
if status_filter == "active":
|
||||||
sql += " WHERE used_at IS NULL AND expires_at > ?"
|
sql += " WHERE use_count < max_uses AND expires_at > ?"
|
||||||
params.append(_iso(_now()))
|
params.append(now_iso)
|
||||||
elif status_filter == "expired":
|
elif status_filter == "expired":
|
||||||
sql += " WHERE used_at IS NULL AND expires_at <= ?"
|
sql += " WHERE use_count < max_uses AND expires_at <= ?"
|
||||||
params.append(_iso(_now()))
|
params.append(now_iso)
|
||||||
elif status_filter == "used":
|
elif status_filter == "used":
|
||||||
sql += " WHERE used_at IS NOT NULL"
|
sql += " WHERE use_count >= max_uses"
|
||||||
sql += " ORDER BY expires_at ASC"
|
sql += " ORDER BY expires_at ASC"
|
||||||
cur.execute(sql, params)
|
cur.execute(sql, params)
|
||||||
rows = cur.fetchall()
|
rows = cur.fetchall()
|
||||||
@@ -82,6 +91,9 @@ def register(
|
|||||||
"created_by_user_id": row[3],
|
"created_by_user_id": row[3],
|
||||||
"used_at": row[4],
|
"used_at": row[4],
|
||||||
"used_by_guid": row[5],
|
"used_by_guid": row[5],
|
||||||
|
"max_uses": row[6],
|
||||||
|
"use_count": row[7],
|
||||||
|
"last_used_at": row[8],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return jsonify({"codes": records})
|
return jsonify({"codes": records})
|
||||||
@@ -93,6 +105,18 @@ def register(
|
|||||||
if ttl_hours not in VALID_TTL_HOURS:
|
if ttl_hours not in VALID_TTL_HOURS:
|
||||||
return jsonify({"error": "invalid_ttl"}), 400
|
return jsonify({"error": "invalid_ttl"}), 400
|
||||||
|
|
||||||
|
max_uses_value = payload.get("max_uses")
|
||||||
|
if max_uses_value is None:
|
||||||
|
max_uses_value = payload.get("allowed_uses")
|
||||||
|
try:
|
||||||
|
max_uses = int(max_uses_value)
|
||||||
|
except Exception:
|
||||||
|
max_uses = 2
|
||||||
|
if max_uses < 1:
|
||||||
|
max_uses = 1
|
||||||
|
if max_uses > 10:
|
||||||
|
max_uses = 10
|
||||||
|
|
||||||
user = current_user() or {}
|
user = current_user() or {}
|
||||||
username = user.get("username") or ""
|
username = user.get("username") or ""
|
||||||
|
|
||||||
@@ -106,22 +130,28 @@ def register(
|
|||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO enrollment_install_codes (
|
INSERT INTO enrollment_install_codes (
|
||||||
id, code, expires_at, created_by_user_id
|
id, code, expires_at, created_by_user_id, max_uses, use_count
|
||||||
)
|
)
|
||||||
VALUES (?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, 0)
|
||||||
""",
|
""",
|
||||||
(record_id, code_value, _iso(expires_at), created_by),
|
(record_id, code_value, _iso(expires_at), created_by, max_uses),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
log("server", f"installer code created id={record_id} by={username} ttl={ttl_hours}h")
|
log(
|
||||||
|
"server",
|
||||||
|
f"installer code created id={record_id} by={username} ttl={ttl_hours}h max_uses={max_uses}",
|
||||||
|
)
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
"id": record_id,
|
"id": record_id,
|
||||||
"code": code_value,
|
"code": code_value,
|
||||||
"expires_at": _iso(expires_at),
|
"expires_at": _iso(expires_at),
|
||||||
|
"max_uses": max_uses,
|
||||||
|
"use_count": 0,
|
||||||
|
"last_used_at": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -131,7 +161,7 @@ def register(
|
|||||||
try:
|
try:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"DELETE FROM enrollment_install_codes WHERE id = ? AND used_at IS NULL",
|
"DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
|
||||||
(code_id,),
|
(code_id,),
|
||||||
)
|
)
|
||||||
deleted = cur.rowcount
|
deleted = cur.rowcount
|
||||||
|
|||||||
@@ -10,13 +10,24 @@ from flask import Blueprint, jsonify, request, g
|
|||||||
from Modules.auth.device_auth import DeviceAuthManager, require_device_auth
|
from Modules.auth.device_auth import DeviceAuthManager, require_device_auth
|
||||||
from Modules.crypto.signing import ScriptSigner
|
from Modules.crypto.signing import ScriptSigner
|
||||||
|
|
||||||
|
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||||
|
|
||||||
|
|
||||||
|
def _canonical_context(value: Optional[str]) -> Optional[str]:
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
|
||||||
|
if not cleaned:
|
||||||
|
return None
|
||||||
|
return cleaned.upper()
|
||||||
|
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
app,
|
app,
|
||||||
*,
|
*,
|
||||||
db_conn_factory: Callable[[], Any],
|
db_conn_factory: Callable[[], Any],
|
||||||
auth_manager: DeviceAuthManager,
|
auth_manager: DeviceAuthManager,
|
||||||
log: Callable[[str, str], None],
|
log: Callable[[str, str, Optional[str]], None],
|
||||||
script_signer: ScriptSigner,
|
script_signer: ScriptSigner,
|
||||||
) -> None:
|
) -> None:
|
||||||
blueprint = Blueprint("agents", __name__)
|
blueprint = Blueprint("agents", __name__)
|
||||||
@@ -29,10 +40,15 @@ def register(
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _context_hint(ctx=None) -> Optional[str]:
|
||||||
|
if ctx is not None and getattr(ctx, "service_mode", None):
|
||||||
|
return _canonical_context(getattr(ctx, "service_mode", None))
|
||||||
|
return _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
|
||||||
|
|
||||||
def _auth_context():
|
def _auth_context():
|
||||||
ctx = getattr(g, "device_auth", None)
|
ctx = getattr(g, "device_auth", None)
|
||||||
if ctx is None:
|
if ctx is None:
|
||||||
log("server", f"device auth context missing for {request.path}")
|
log("server", f"device auth context missing for {request.path}", _context_hint())
|
||||||
return ctx
|
return ctx
|
||||||
|
|
||||||
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
|
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
|
||||||
@@ -42,6 +58,7 @@ def register(
|
|||||||
if ctx is None:
|
if ctx is None:
|
||||||
return jsonify({"error": "auth_context_missing"}), 500
|
return jsonify({"error": "auth_context_missing"}), 500
|
||||||
payload = request.get_json(force=True, silent=True) or {}
|
payload = request.get_json(force=True, silent=True) or {}
|
||||||
|
context_label = _context_hint(ctx)
|
||||||
|
|
||||||
now_ts = int(time.time())
|
now_ts = int(time.time())
|
||||||
updates: Dict[str, Optional[str]] = {"last_seen": now_ts}
|
updates: Dict[str, Optional[str]] = {"last_seen": now_ts}
|
||||||
@@ -111,12 +128,13 @@ def register(
|
|||||||
"server",
|
"server",
|
||||||
"heartbeat hostname collision ignored for guid="
|
"heartbeat hostname collision ignored for guid="
|
||||||
f"{ctx.guid}",
|
f"{ctx.guid}",
|
||||||
|
context_label,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if rowcount == 0:
|
if rowcount == 0:
|
||||||
log("server", f"heartbeat missing device record guid={ctx.guid}")
|
log("server", f"heartbeat missing device record guid={ctx.guid}", context_label)
|
||||||
return jsonify({"error": "device_not_registered"}), 404
|
return jsonify({"error": "device_not_registered"}), 404
|
||||||
conn.commit()
|
conn.commit()
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
from contextlib import closing
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
@@ -10,6 +14,17 @@ from flask import g, jsonify, request
|
|||||||
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
|
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
|
||||||
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||||
|
|
||||||
|
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||||
|
|
||||||
|
|
||||||
|
def _canonical_context(value: Optional[str]) -> Optional[str]:
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
|
||||||
|
if not cleaned:
|
||||||
|
return None
|
||||||
|
return cleaned.upper()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DeviceAuthContext:
|
class DeviceAuthContext:
|
||||||
@@ -20,6 +35,7 @@ class DeviceAuthContext:
|
|||||||
claims: Dict[str, Any]
|
claims: Dict[str, Any]
|
||||||
dpop_jkt: Optional[str]
|
dpop_jkt: Optional[str]
|
||||||
status: str
|
status: str
|
||||||
|
service_mode: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class DeviceAuthError(Exception):
|
class DeviceAuthError(Exception):
|
||||||
@@ -47,7 +63,7 @@ class DeviceAuthManager:
|
|||||||
db_conn_factory: Callable[[], Any],
|
db_conn_factory: Callable[[], Any],
|
||||||
jwt_service,
|
jwt_service,
|
||||||
dpop_validator: Optional[DPoPValidator],
|
dpop_validator: Optional[DPoPValidator],
|
||||||
log: Callable[[str, str], None],
|
log: Callable[[str, str, Optional[str]], None],
|
||||||
rate_limiter: Optional[SlidingWindowRateLimiter] = None,
|
rate_limiter: Optional[SlidingWindowRateLimiter] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._db_conn_factory = db_conn_factory
|
self._db_conn_factory = db_conn_factory
|
||||||
@@ -86,8 +102,9 @@ class DeviceAuthManager:
|
|||||||
retry_after=decision.retry_after,
|
retry_after=decision.retry_after,
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = self._db_conn_factory()
|
context_label = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
|
||||||
try:
|
|
||||||
|
with closing(self._db_conn_factory()) as conn:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
@@ -98,8 +115,11 @@ class DeviceAuthManager:
|
|||||||
(guid,),
|
(guid,),
|
||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
finally:
|
|
||||||
conn.close()
|
if not row:
|
||||||
|
row = self._recover_device_record(
|
||||||
|
conn, guid, fingerprint, token_version, context_label
|
||||||
|
)
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
raise DeviceAuthError("device_not_found", status_code=403)
|
raise DeviceAuthError("device_not_found", status_code=403)
|
||||||
@@ -121,7 +141,11 @@ class DeviceAuthManager:
|
|||||||
if status_normalized not in allowed_statuses:
|
if status_normalized not in allowed_statuses:
|
||||||
raise DeviceAuthError("device_revoked", status_code=403)
|
raise DeviceAuthError("device_revoked", status_code=403)
|
||||||
if status_normalized == "quarantined":
|
if status_normalized == "quarantined":
|
||||||
self._log("server", f"device {guid} is quarantined; limited access for {request.path}")
|
self._log(
|
||||||
|
"server",
|
||||||
|
f"device {guid} is quarantined; limited access for {request.path}",
|
||||||
|
context_label,
|
||||||
|
)
|
||||||
|
|
||||||
dpop_jkt: Optional[str] = None
|
dpop_jkt: Optional[str] = None
|
||||||
dpop_proof = request.headers.get("DPoP")
|
dpop_proof = request.headers.get("DPoP")
|
||||||
@@ -144,9 +168,111 @@ class DeviceAuthManager:
|
|||||||
claims=claims,
|
claims=claims,
|
||||||
dpop_jkt=dpop_jkt,
|
dpop_jkt=dpop_jkt,
|
||||||
status=status_normalized,
|
status=status_normalized,
|
||||||
|
service_mode=context_label,
|
||||||
)
|
)
|
||||||
return ctx
|
return ctx
|
||||||
|
|
||||||
|
def _recover_device_record(
|
||||||
|
self,
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
guid: str,
|
||||||
|
fingerprint: str,
|
||||||
|
token_version: int,
|
||||||
|
context_label: Optional[str],
|
||||||
|
) -> Optional[tuple]:
|
||||||
|
"""Attempt to recreate a missing device row for an authenticated token."""
|
||||||
|
|
||||||
|
guid = (guid or "").strip()
|
||||||
|
fingerprint = (fingerprint or "").strip()
|
||||||
|
if not guid or not fingerprint:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cur = conn.cursor()
|
||||||
|
now_ts = int(time.time())
|
||||||
|
try:
|
||||||
|
now_iso = datetime.now(tz=timezone.utc).isoformat()
|
||||||
|
except Exception:
|
||||||
|
now_iso = datetime.utcnow().isoformat() # pragma: no cover
|
||||||
|
|
||||||
|
base_hostname = f"RECOVERED-{guid[:12].upper()}" if guid else "RECOVERED"
|
||||||
|
|
||||||
|
for attempt in range(6):
|
||||||
|
hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO devices (
|
||||||
|
guid,
|
||||||
|
hostname,
|
||||||
|
created_at,
|
||||||
|
last_seen,
|
||||||
|
ssl_key_fingerprint,
|
||||||
|
token_version,
|
||||||
|
status,
|
||||||
|
key_added_at
|
||||||
|
)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
guid,
|
||||||
|
hostname,
|
||||||
|
now_ts,
|
||||||
|
now_ts,
|
||||||
|
fingerprint,
|
||||||
|
max(token_version or 1, 1),
|
||||||
|
now_iso,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except sqlite3.IntegrityError as exc:
|
||||||
|
# Hostname collision – try again with a suffixed placeholder.
|
||||||
|
message = str(exc).lower()
|
||||||
|
if "hostname" in message and "unique" in message:
|
||||||
|
continue
|
||||||
|
self._log(
|
||||||
|
"server",
|
||||||
|
f"device auth failed to recover guid={guid} due to integrity error: {exc}",
|
||||||
|
context_label,
|
||||||
|
)
|
||||||
|
conn.rollback()
|
||||||
|
return None
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
self._log(
|
||||||
|
"server",
|
||||||
|
f"device auth unexpected error recovering guid={guid}: {exc}",
|
||||||
|
context_label,
|
||||||
|
)
|
||||||
|
conn.rollback()
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
conn.commit()
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Exhausted attempts because of hostname collisions.
|
||||||
|
self._log(
|
||||||
|
"server",
|
||||||
|
f"device auth could not recover guid={guid}; hostname collisions persisted",
|
||||||
|
context_label,
|
||||||
|
)
|
||||||
|
conn.rollback()
|
||||||
|
return None
|
||||||
|
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||||
|
FROM devices
|
||||||
|
WHERE guid = ?
|
||||||
|
""",
|
||||||
|
(guid,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
if not row:
|
||||||
|
self._log(
|
||||||
|
"server",
|
||||||
|
f"device auth recovery for guid={guid} committed but row still missing",
|
||||||
|
context_label,
|
||||||
|
)
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def require_device_auth(manager: DeviceAuthManager):
|
def require_device_auth(manager: DeviceAuthManager):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
|
|||||||
@@ -152,7 +152,10 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
|
|||||||
expires_at TEXT NOT NULL,
|
expires_at TEXT NOT NULL,
|
||||||
created_by_user_id TEXT,
|
created_by_user_id TEXT,
|
||||||
used_at TEXT,
|
used_at TEXT,
|
||||||
used_by_guid TEXT
|
used_by_guid TEXT,
|
||||||
|
max_uses INTEGER NOT NULL DEFAULT 1,
|
||||||
|
use_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
last_used_at TEXT
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
@@ -163,6 +166,29 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
columns = {row[1] for row in _table_info(cur, "enrollment_install_codes")}
|
||||||
|
if "max_uses" not in columns:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
ALTER TABLE enrollment_install_codes
|
||||||
|
ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
if "use_count" not in columns:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
ALTER TABLE enrollment_install_codes
|
||||||
|
ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
if "last_used_at" not in columns:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
ALTER TABLE enrollment_install_codes
|
||||||
|
ADD COLUMN last_used_at TEXT
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
|
def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
|
|||||||
@@ -6,7 +6,18 @@ import sqlite3
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||||
|
|
||||||
|
|
||||||
|
def _canonical_context(value: Optional[str]) -> Optional[str]:
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
|
||||||
|
if not cleaned:
|
||||||
|
return None
|
||||||
|
return cleaned.upper()
|
||||||
|
|
||||||
from flask import Blueprint, jsonify, request
|
from flask import Blueprint, jsonify, request
|
||||||
|
|
||||||
@@ -20,7 +31,7 @@ def register(
|
|||||||
app,
|
app,
|
||||||
*,
|
*,
|
||||||
db_conn_factory: Callable[[], sqlite3.Connection],
|
db_conn_factory: Callable[[], sqlite3.Connection],
|
||||||
log: Callable[[str, str], None],
|
log: Callable[[str, str, Optional[str]], None],
|
||||||
jwt_service,
|
jwt_service,
|
||||||
tls_bundle_path: str,
|
tls_bundle_path: str,
|
||||||
ip_rate_limiter: SlidingWindowRateLimiter,
|
ip_rate_limiter: SlidingWindowRateLimiter,
|
||||||
@@ -51,12 +62,19 @@ def register(
|
|||||||
except Exception:
|
except Exception:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
|
def _rate_limited(
|
||||||
|
key: str,
|
||||||
|
limiter: SlidingWindowRateLimiter,
|
||||||
|
limit: int,
|
||||||
|
window_s: float,
|
||||||
|
context_hint: Optional[str],
|
||||||
|
):
|
||||||
decision = limiter.check(key, limit, window_s)
|
decision = limiter.check(key, limit, window_s)
|
||||||
if not decision.allowed:
|
if not decision.allowed:
|
||||||
log(
|
log(
|
||||||
"server",
|
"server",
|
||||||
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
|
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
|
||||||
|
context_hint,
|
||||||
)
|
)
|
||||||
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
|
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
|
||||||
response.status_code = 429
|
response.status_code = 429
|
||||||
@@ -66,31 +84,79 @@ def register(
|
|||||||
|
|
||||||
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
|
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"SELECT id, code, expires_at, used_at FROM enrollment_install_codes WHERE code = ?",
|
"""
|
||||||
|
SELECT id,
|
||||||
|
code,
|
||||||
|
expires_at,
|
||||||
|
used_at,
|
||||||
|
used_by_guid,
|
||||||
|
max_uses,
|
||||||
|
use_count,
|
||||||
|
last_used_at
|
||||||
|
FROM enrollment_install_codes
|
||||||
|
WHERE code = ?
|
||||||
|
""",
|
||||||
(code_value,),
|
(code_value,),
|
||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
keys = ["id", "code", "expires_at", "used_at"]
|
keys = [
|
||||||
|
"id",
|
||||||
|
"code",
|
||||||
|
"expires_at",
|
||||||
|
"used_at",
|
||||||
|
"used_by_guid",
|
||||||
|
"max_uses",
|
||||||
|
"use_count",
|
||||||
|
"last_used_at",
|
||||||
|
]
|
||||||
record = dict(zip(keys, row))
|
record = dict(zip(keys, row))
|
||||||
return record
|
return record
|
||||||
|
|
||||||
def _install_code_valid(record: Dict[str, Any]) -> bool:
|
def _install_code_valid(
|
||||||
|
record: Dict[str, Any], fingerprint: str, cur: sqlite3.Cursor
|
||||||
|
) -> Tuple[bool, Optional[str]]:
|
||||||
if not record:
|
if not record:
|
||||||
return False
|
return False, None
|
||||||
expires_at = record.get("expires_at")
|
expires_at = record.get("expires_at")
|
||||||
if not isinstance(expires_at, str):
|
if not isinstance(expires_at, str):
|
||||||
return False
|
return False, None
|
||||||
try:
|
try:
|
||||||
expiry = datetime.fromisoformat(expires_at)
|
expiry = datetime.fromisoformat(expires_at)
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False, None
|
||||||
if expiry <= _now():
|
if expiry <= _now():
|
||||||
return False
|
return False, None
|
||||||
if record.get("used_at"):
|
try:
|
||||||
return False
|
max_uses = int(record.get("max_uses") or 1)
|
||||||
return True
|
except Exception:
|
||||||
|
max_uses = 1
|
||||||
|
if max_uses < 1:
|
||||||
|
max_uses = 1
|
||||||
|
try:
|
||||||
|
use_count = int(record.get("use_count") or 0)
|
||||||
|
except Exception:
|
||||||
|
use_count = 0
|
||||||
|
if use_count < max_uses:
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
guid = str(record.get("used_by_guid") or "").strip()
|
||||||
|
if not guid:
|
||||||
|
return False, None
|
||||||
|
cur.execute(
|
||||||
|
"SELECT ssl_key_fingerprint FROM devices WHERE guid = ?",
|
||||||
|
(guid,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
if not row:
|
||||||
|
return False, None
|
||||||
|
stored_fp = (row[0] or "").strip().lower()
|
||||||
|
if not stored_fp:
|
||||||
|
return False, None
|
||||||
|
if stored_fp == (fingerprint or "").strip().lower():
|
||||||
|
return True, guid
|
||||||
|
return False, None
|
||||||
|
|
||||||
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
|
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
|
||||||
base = (hostname or "").strip() or guid
|
base = (hostname or "").strip() or guid
|
||||||
@@ -247,7 +313,9 @@ def register(
|
|||||||
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
|
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
|
||||||
def enrollment_request():
|
def enrollment_request():
|
||||||
remote = _remote_addr()
|
remote = _remote_addr()
|
||||||
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0)
|
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
|
||||||
|
|
||||||
|
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0, context_hint)
|
||||||
if rate_error:
|
if rate_error:
|
||||||
return rate_error
|
return rate_error
|
||||||
|
|
||||||
@@ -262,42 +330,43 @@ def register(
|
|||||||
"enrollment request received "
|
"enrollment request received "
|
||||||
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
|
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
|
||||||
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
|
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
|
||||||
|
context_hint,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not hostname:
|
if not hostname:
|
||||||
log("server", f"enrollment rejected missing_hostname ip={remote}")
|
log("server", f"enrollment rejected missing_hostname ip={remote}", context_hint)
|
||||||
return jsonify({"error": "hostname_required"}), 400
|
return jsonify({"error": "hostname_required"}), 400
|
||||||
if not enrollment_code:
|
if not enrollment_code:
|
||||||
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}")
|
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
|
||||||
return jsonify({"error": "enrollment_code_required"}), 400
|
return jsonify({"error": "enrollment_code_required"}), 400
|
||||||
if not isinstance(agent_pubkey_b64, str):
|
if not isinstance(agent_pubkey_b64, str):
|
||||||
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}")
|
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
|
||||||
return jsonify({"error": "agent_pubkey_required"}), 400
|
return jsonify({"error": "agent_pubkey_required"}), 400
|
||||||
if not isinstance(client_nonce_b64, str):
|
if not isinstance(client_nonce_b64, str):
|
||||||
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}")
|
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
|
||||||
return jsonify({"error": "client_nonce_required"}), 400
|
return jsonify({"error": "client_nonce_required"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
|
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
|
||||||
except Exception:
|
except Exception:
|
||||||
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}")
|
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
|
||||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||||
|
|
||||||
if len(agent_pubkey_der) < 10:
|
if len(agent_pubkey_der) < 10:
|
||||||
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}")
|
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
|
||||||
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
return jsonify({"error": "invalid_agent_pubkey"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}")
|
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
|
||||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||||
if len(client_nonce_bytes) < 16:
|
if len(client_nonce_bytes) < 16:
|
||||||
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}")
|
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
|
||||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||||
|
|
||||||
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
|
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
|
||||||
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0)
|
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0, context_hint)
|
||||||
if rate_error:
|
if rate_error:
|
||||||
return rate_error
|
return rate_error
|
||||||
|
|
||||||
@@ -305,7 +374,14 @@ def register(
|
|||||||
try:
|
try:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
install_code = _load_install_code(cur, enrollment_code)
|
install_code = _load_install_code(cur, enrollment_code)
|
||||||
if not _install_code_valid(install_code):
|
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
|
||||||
|
if not valid_code:
|
||||||
|
log(
|
||||||
|
"server",
|
||||||
|
"enrollment request invalid_code "
|
||||||
|
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
|
||||||
|
context_hint,
|
||||||
|
)
|
||||||
return jsonify({"error": "invalid_enrollment_code"}), 400
|
return jsonify({"error": "invalid_enrollment_code"}), 400
|
||||||
|
|
||||||
approval_reference: str
|
approval_reference: str
|
||||||
@@ -331,6 +407,7 @@ def register(
|
|||||||
"""
|
"""
|
||||||
UPDATE device_approvals
|
UPDATE device_approvals
|
||||||
SET hostname_claimed = ?,
|
SET hostname_claimed = ?,
|
||||||
|
guid = ?,
|
||||||
enrollment_code_id = ?,
|
enrollment_code_id = ?,
|
||||||
client_nonce = ?,
|
client_nonce = ?,
|
||||||
server_nonce = ?,
|
server_nonce = ?,
|
||||||
@@ -340,6 +417,7 @@ def register(
|
|||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
hostname,
|
hostname,
|
||||||
|
reuse_guid,
|
||||||
install_code["id"],
|
install_code["id"],
|
||||||
client_nonce_b64,
|
client_nonce_b64,
|
||||||
server_nonce_b64,
|
server_nonce_b64,
|
||||||
@@ -359,11 +437,12 @@ def register(
|
|||||||
status, client_nonce, server_nonce, agent_pubkey_der,
|
status, client_nonce, server_nonce, agent_pubkey_der,
|
||||||
created_at, updated_at
|
created_at, updated_at
|
||||||
)
|
)
|
||||||
VALUES (?, ?, NULL, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
record_id,
|
record_id,
|
||||||
approval_reference,
|
approval_reference,
|
||||||
|
reuse_guid,
|
||||||
hostname,
|
hostname,
|
||||||
fingerprint,
|
fingerprint,
|
||||||
install_code["id"],
|
install_code["id"],
|
||||||
@@ -387,7 +466,11 @@ def register(
|
|||||||
"server_certificate": _load_tls_bundle(tls_bundle_path),
|
"server_certificate": _load_tls_bundle(tls_bundle_path),
|
||||||
"signing_key": _signing_key_b64(),
|
"signing_key": _signing_key_b64(),
|
||||||
}
|
}
|
||||||
log("server", f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}")
|
log(
|
||||||
|
"server",
|
||||||
|
f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}",
|
||||||
|
context_hint,
|
||||||
|
)
|
||||||
return jsonify(response)
|
return jsonify(response)
|
||||||
|
|
||||||
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
|
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
|
||||||
@@ -396,34 +479,36 @@ def register(
|
|||||||
approval_reference = payload.get("approval_reference")
|
approval_reference = payload.get("approval_reference")
|
||||||
client_nonce_b64 = payload.get("client_nonce")
|
client_nonce_b64 = payload.get("client_nonce")
|
||||||
proof_sig_b64 = payload.get("proof_sig")
|
proof_sig_b64 = payload.get("proof_sig")
|
||||||
|
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
|
||||||
|
|
||||||
log(
|
log(
|
||||||
"server",
|
"server",
|
||||||
"enrollment poll received "
|
"enrollment poll received "
|
||||||
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
|
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
|
||||||
f" proof_sig_len={len(proof_sig_b64 or '')}",
|
f" proof_sig_len={len(proof_sig_b64 or '')}",
|
||||||
|
context_hint,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(approval_reference, str) or not approval_reference:
|
if not isinstance(approval_reference, str) or not approval_reference:
|
||||||
log("server", "enrollment poll rejected missing_reference")
|
log("server", "enrollment poll rejected missing_reference", context_hint)
|
||||||
return jsonify({"error": "approval_reference_required"}), 400
|
return jsonify({"error": "approval_reference_required"}), 400
|
||||||
if not isinstance(client_nonce_b64, str):
|
if not isinstance(client_nonce_b64, str):
|
||||||
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}")
|
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
|
||||||
return jsonify({"error": "client_nonce_required"}), 400
|
return jsonify({"error": "client_nonce_required"}), 400
|
||||||
if not isinstance(proof_sig_b64, str):
|
if not isinstance(proof_sig_b64, str):
|
||||||
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}")
|
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
|
||||||
return jsonify({"error": "proof_sig_required"}), 400
|
return jsonify({"error": "proof_sig_required"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}")
|
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
|
||||||
return jsonify({"error": "invalid_client_nonce"}), 400
|
return jsonify({"error": "invalid_client_nonce"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
|
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
log("server", f"enrollment poll invalid_sig ref={approval_reference}")
|
log("server", f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
|
||||||
return jsonify({"error": "invalid_proof_sig"}), 400
|
return jsonify({"error": "invalid_proof_sig"}), 400
|
||||||
|
|
||||||
conn = db_conn_factory()
|
conn = db_conn_factory()
|
||||||
@@ -441,7 +526,7 @@ def register(
|
|||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
if not row:
|
if not row:
|
||||||
log("server", f"enrollment poll unknown_reference ref={approval_reference}")
|
log("server", f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
|
||||||
return jsonify({"status": "unknown"}), 404
|
return jsonify({"status": "unknown"}), 404
|
||||||
|
|
||||||
(
|
(
|
||||||
@@ -460,13 +545,13 @@ def register(
|
|||||||
) = row
|
) = row
|
||||||
|
|
||||||
if client_nonce_stored != client_nonce_b64:
|
if client_nonce_stored != client_nonce_b64:
|
||||||
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}")
|
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
|
||||||
return jsonify({"error": "nonce_mismatch"}), 400
|
return jsonify({"error": "nonce_mismatch"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
|
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}")
|
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
|
||||||
return jsonify({"error": "server_nonce_invalid"}), 400
|
return jsonify({"error": "server_nonce_invalid"}), 400
|
||||||
|
|
||||||
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
|
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
|
||||||
@@ -474,17 +559,17 @@ def register(
|
|||||||
try:
|
try:
|
||||||
public_key = serialization.load_der_public_key(agent_pubkey_der)
|
public_key = serialization.load_der_public_key(agent_pubkey_der)
|
||||||
except Exception:
|
except Exception:
|
||||||
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}")
|
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
|
||||||
public_key = None
|
public_key = None
|
||||||
|
|
||||||
if public_key is None:
|
if public_key is None:
|
||||||
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}")
|
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
|
||||||
return jsonify({"error": "agent_pubkey_invalid"}), 400
|
return jsonify({"error": "agent_pubkey_invalid"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
public_key.verify(proof_sig, message)
|
public_key.verify(proof_sig, message)
|
||||||
except Exception:
|
except Exception:
|
||||||
log("server", f"enrollment poll invalid_proof ref={approval_reference}")
|
log("server", f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
|
||||||
return jsonify({"error": "invalid_proof"}), 400
|
return jsonify({"error": "invalid_proof"}), 400
|
||||||
|
|
||||||
if status == "pending":
|
if status == "pending":
|
||||||
@@ -492,24 +577,28 @@ def register(
|
|||||||
"server",
|
"server",
|
||||||
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
|
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
|
||||||
f" fingerprint={fingerprint[:12]}",
|
f" fingerprint={fingerprint[:12]}",
|
||||||
|
context_hint,
|
||||||
)
|
)
|
||||||
return jsonify({"status": "pending", "poll_after_ms": 5000})
|
return jsonify({"status": "pending", "poll_after_ms": 5000})
|
||||||
if status == "denied":
|
if status == "denied":
|
||||||
log(
|
log(
|
||||||
"server",
|
"server",
|
||||||
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
|
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
|
||||||
|
context_hint,
|
||||||
)
|
)
|
||||||
return jsonify({"status": "denied", "reason": "operator_denied"})
|
return jsonify({"status": "denied", "reason": "operator_denied"})
|
||||||
if status == "expired":
|
if status == "expired":
|
||||||
log(
|
log(
|
||||||
"server",
|
"server",
|
||||||
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
|
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
|
||||||
|
context_hint,
|
||||||
)
|
)
|
||||||
return jsonify({"status": "expired"})
|
return jsonify({"status": "expired"})
|
||||||
if status == "completed":
|
if status == "completed":
|
||||||
log(
|
log(
|
||||||
"server",
|
"server",
|
||||||
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
|
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
|
||||||
|
context_hint,
|
||||||
)
|
)
|
||||||
return jsonify({"status": "approved", "detail": "finalized"})
|
return jsonify({"status": "approved", "detail": "finalized"})
|
||||||
|
|
||||||
@@ -517,6 +606,7 @@ def register(
|
|||||||
log(
|
log(
|
||||||
"server",
|
"server",
|
||||||
f"enrollment poll unexpected_status={status} ref={approval_reference}",
|
f"enrollment poll unexpected_status={status} ref={approval_reference}",
|
||||||
|
context_hint,
|
||||||
)
|
)
|
||||||
return jsonify({"status": status or "unknown"}), 400
|
return jsonify({"status": status or "unknown"}), 400
|
||||||
|
|
||||||
@@ -525,6 +615,7 @@ def register(
|
|||||||
log(
|
log(
|
||||||
"server",
|
"server",
|
||||||
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
|
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
|
||||||
|
context_hint,
|
||||||
)
|
)
|
||||||
return jsonify({"error": "proof_replayed"}), 409
|
return jsonify({"error": "proof_replayed"}), 409
|
||||||
|
|
||||||
@@ -537,14 +628,40 @@ def register(
|
|||||||
|
|
||||||
# Mark install code used
|
# Mark install code used
|
||||||
if enrollment_code_id:
|
if enrollment_code_id:
|
||||||
|
cur.execute(
|
||||||
|
"SELECT use_count, max_uses FROM enrollment_install_codes WHERE id = ?",
|
||||||
|
(enrollment_code_id,),
|
||||||
|
)
|
||||||
|
usage_row = cur.fetchone()
|
||||||
|
try:
|
||||||
|
prior_count = int(usage_row[0]) if usage_row else 0
|
||||||
|
except Exception:
|
||||||
|
prior_count = 0
|
||||||
|
try:
|
||||||
|
allowed_uses = int(usage_row[1]) if usage_row else 1
|
||||||
|
except Exception:
|
||||||
|
allowed_uses = 1
|
||||||
|
if allowed_uses < 1:
|
||||||
|
allowed_uses = 1
|
||||||
|
new_count = prior_count + 1
|
||||||
|
consumed = new_count >= allowed_uses
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE enrollment_install_codes
|
UPDATE enrollment_install_codes
|
||||||
SET used_at = ?, used_by_guid = ?
|
SET use_count = ?,
|
||||||
|
used_by_guid = ?,
|
||||||
|
last_used_at = ?,
|
||||||
|
used_at = CASE WHEN ? THEN ? ELSE used_at END
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
AND used_at IS NULL
|
|
||||||
""",
|
""",
|
||||||
(now_iso, effective_guid, enrollment_code_id),
|
(
|
||||||
|
new_count,
|
||||||
|
effective_guid,
|
||||||
|
now_iso,
|
||||||
|
1 if consumed else 0,
|
||||||
|
now_iso,
|
||||||
|
enrollment_code_id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update approval record with final state
|
# Update approval record with final state
|
||||||
@@ -573,6 +690,7 @@ def register(
|
|||||||
log(
|
log(
|
||||||
"server",
|
"server",
|
||||||
f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
|
f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
|
||||||
|
context_hint,
|
||||||
)
|
)
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Callable
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import eventlet
|
import eventlet
|
||||||
from flask_socketio import SocketIO
|
from flask_socketio import SocketIO
|
||||||
@@ -11,7 +11,7 @@ def start_prune_job(
|
|||||||
socketio: SocketIO,
|
socketio: SocketIO,
|
||||||
*,
|
*,
|
||||||
db_conn_factory: Callable[[], any],
|
db_conn_factory: Callable[[], any],
|
||||||
log: Callable[[str, str], None],
|
log: Callable[[str, str, Optional[str]], None],
|
||||||
) -> None:
|
) -> None:
|
||||||
def _job_loop():
|
def _job_loop():
|
||||||
while True:
|
while True:
|
||||||
@@ -24,7 +24,7 @@ def start_prune_job(
|
|||||||
socketio.start_background_task(_job_loop)
|
socketio.start_background_task(_job_loop)
|
||||||
|
|
||||||
|
|
||||||
def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None]) -> None:
|
def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str, Optional[str]], None]) -> None:
|
||||||
now = datetime.now(tz=timezone.utc)
|
now = datetime.now(tz=timezone.utc)
|
||||||
now_iso = now.isoformat()
|
now_iso = now.isoformat()
|
||||||
stale_before = (now - timedelta(hours=24)).isoformat()
|
stale_before = (now - timedelta(hours=24)).isoformat()
|
||||||
@@ -34,7 +34,7 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
|
|||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
DELETE FROM enrollment_install_codes
|
DELETE FROM enrollment_install_codes
|
||||||
WHERE used_at IS NULL
|
WHERE use_count = 0
|
||||||
AND expires_at < ?
|
AND expires_at < ?
|
||||||
""",
|
""",
|
||||||
(now_iso,),
|
(now_iso,),
|
||||||
@@ -52,7 +52,10 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
|
|||||||
SELECT 1
|
SELECT 1
|
||||||
FROM enrollment_install_codes c
|
FROM enrollment_install_codes c
|
||||||
WHERE c.id = device_approvals.enrollment_code_id
|
WHERE c.id = device_approvals.enrollment_code_id
|
||||||
AND c.expires_at < ?
|
AND (
|
||||||
|
c.expires_at < ?
|
||||||
|
OR c.use_count >= c.max_uses
|
||||||
|
)
|
||||||
)
|
)
|
||||||
OR created_at < ?
|
OR created_at < ?
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -93,7 +93,20 @@ def register(
|
|||||||
except DPoPVerificationError:
|
except DPoPVerificationError:
|
||||||
return jsonify({"error": "dpop_invalid"}), 400
|
return jsonify({"error": "dpop_invalid"}), 400
|
||||||
elif stored_jkt:
|
elif stored_jkt:
|
||||||
return jsonify({"error": "dpop_required"}), 400
|
# The agent does not yet emit DPoP proofs; allow recovery by clearing
|
||||||
|
# the stored binding so refreshes can succeed. This preserves
|
||||||
|
# backward compatibility while the client gains full DPoP support.
|
||||||
|
try:
|
||||||
|
app.logger.warning(
|
||||||
|
"Clearing stored DPoP binding for guid=%s due to missing proof",
|
||||||
|
guid,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
cur.execute(
|
||||||
|
"UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?",
|
||||||
|
(record_id,),
|
||||||
|
)
|
||||||
|
|
||||||
new_access_token = jwt_service.issue_access_token(
|
new_access_token = jwt_service.issue_access_token(
|
||||||
guid,
|
guid,
|
||||||
|
|||||||
@@ -65,7 +65,9 @@ const formatDateTime = (value) => {
|
|||||||
|
|
||||||
const determineStatus = (record) => {
|
const determineStatus = (record) => {
|
||||||
if (!record) return "expired";
|
if (!record) return "expired";
|
||||||
if (record.used_at) return "used";
|
const maxUses = Number.isFinite(record?.max_uses) ? record.max_uses : 1;
|
||||||
|
const useCount = Number.isFinite(record?.use_count) ? record.use_count : 0;
|
||||||
|
if (useCount >= Math.max(1, maxUses || 1)) return "used";
|
||||||
if (!record.expires_at) return "expired";
|
if (!record.expires_at) return "expired";
|
||||||
const expires = new Date(record.expires_at);
|
const expires = new Date(record.expires_at);
|
||||||
if (Number.isNaN(expires.getTime())) return "expired";
|
if (Number.isNaN(expires.getTime())) return "expired";
|
||||||
@@ -80,6 +82,7 @@ function EnrollmentCodes() {
|
|||||||
const [statusFilter, setStatusFilter] = useState("all");
|
const [statusFilter, setStatusFilter] = useState("all");
|
||||||
const [ttlHours, setTtlHours] = useState(6);
|
const [ttlHours, setTtlHours] = useState(6);
|
||||||
const [generating, setGenerating] = useState(false);
|
const [generating, setGenerating] = useState(false);
|
||||||
|
const [maxUses, setMaxUses] = useState(2);
|
||||||
|
|
||||||
const filteredCodes = useMemo(() => {
|
const filteredCodes = useMemo(() => {
|
||||||
if (statusFilter === "all") return codes;
|
if (statusFilter === "all") return codes;
|
||||||
@@ -119,7 +122,7 @@ function EnrollmentCodes() {
|
|||||||
method: "POST",
|
method: "POST",
|
||||||
credentials: "include",
|
credentials: "include",
|
||||||
headers: { "Content-Type": "application/json" },
|
headers: { "Content-Type": "application/json" },
|
||||||
body: JSON.stringify({ ttl_hours: ttlHours }),
|
body: JSON.stringify({ ttl_hours: ttlHours, max_uses: maxUses }),
|
||||||
});
|
});
|
||||||
if (!resp.ok) {
|
if (!resp.ok) {
|
||||||
const body = await resp.json().catch(() => ({}));
|
const body = await resp.json().catch(() => ({}));
|
||||||
@@ -133,7 +136,7 @@ function EnrollmentCodes() {
|
|||||||
} finally {
|
} finally {
|
||||||
setGenerating(false);
|
setGenerating(false);
|
||||||
}
|
}
|
||||||
}, [fetchCodes, ttlHours]);
|
}, [fetchCodes, ttlHours, maxUses]);
|
||||||
|
|
||||||
const handleDelete = useCallback(
|
const handleDelete = useCallback(
|
||||||
async (id) => {
|
async (id) => {
|
||||||
@@ -216,7 +219,7 @@ function EnrollmentCodes() {
|
|||||||
labelId="ttl-select-label"
|
labelId="ttl-select-label"
|
||||||
label="Duration"
|
label="Duration"
|
||||||
value={ttlHours}
|
value={ttlHours}
|
||||||
onChange={(event) => setTtlHours(event.target.value)}
|
onChange={(event) => setTtlHours(Number(event.target.value))}
|
||||||
>
|
>
|
||||||
{TTL_PRESETS.map((preset) => (
|
{TTL_PRESETS.map((preset) => (
|
||||||
<MenuItem key={preset.value} value={preset.value}>
|
<MenuItem key={preset.value} value={preset.value}>
|
||||||
@@ -226,6 +229,22 @@ function EnrollmentCodes() {
|
|||||||
</Select>
|
</Select>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
|
|
||||||
|
<FormControl size="small" sx={{ minWidth: 160 }}>
|
||||||
|
<InputLabel id="uses-select-label">Allowed Uses</InputLabel>
|
||||||
|
<Select
|
||||||
|
labelId="uses-select-label"
|
||||||
|
label="Allowed Uses"
|
||||||
|
value={maxUses}
|
||||||
|
onChange={(event) => setMaxUses(Number(event.target.value))}
|
||||||
|
>
|
||||||
|
{[1, 2, 3, 5].map((uses) => (
|
||||||
|
<MenuItem key={uses} value={uses}>
|
||||||
|
{uses === 1 ? "Single use" : `${uses} uses`}
|
||||||
|
</MenuItem>
|
||||||
|
))}
|
||||||
|
</Select>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
<Button
|
<Button
|
||||||
variant="contained"
|
variant="contained"
|
||||||
color="primary"
|
color="primary"
|
||||||
@@ -270,7 +289,9 @@ function EnrollmentCodes() {
|
|||||||
<TableCell>Installer Code</TableCell>
|
<TableCell>Installer Code</TableCell>
|
||||||
<TableCell>Expires At</TableCell>
|
<TableCell>Expires At</TableCell>
|
||||||
<TableCell>Created By</TableCell>
|
<TableCell>Created By</TableCell>
|
||||||
<TableCell>Used At</TableCell>
|
<TableCell>Usage</TableCell>
|
||||||
|
<TableCell>Last Used</TableCell>
|
||||||
|
<TableCell>Consumed At</TableCell>
|
||||||
<TableCell>Used By GUID</TableCell>
|
<TableCell>Used By GUID</TableCell>
|
||||||
<TableCell align="right">Actions</TableCell>
|
<TableCell align="right">Actions</TableCell>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
@@ -296,13 +317,17 @@ function EnrollmentCodes() {
|
|||||||
) : (
|
) : (
|
||||||
filteredCodes.map((record) => {
|
filteredCodes.map((record) => {
|
||||||
const status = determineStatus(record);
|
const status = determineStatus(record);
|
||||||
const disableDelete = status !== "active";
|
const maxAllowed = Math.max(1, Number.isFinite(record?.max_uses) ? record.max_uses : 1);
|
||||||
|
const usageCount = Math.max(0, Number.isFinite(record?.use_count) ? record.use_count : 0);
|
||||||
|
const disableDelete = usageCount !== 0;
|
||||||
return (
|
return (
|
||||||
<TableRow hover key={record.id}>
|
<TableRow hover key={record.id}>
|
||||||
<TableCell>{renderStatusChip(record)}</TableCell>
|
<TableCell>{renderStatusChip(record)}</TableCell>
|
||||||
<TableCell sx={{ fontFamily: "monospace" }}>{maskCode(record.code)}</TableCell>
|
<TableCell sx={{ fontFamily: "monospace" }}>{maskCode(record.code)}</TableCell>
|
||||||
<TableCell>{formatDateTime(record.expires_at)}</TableCell>
|
<TableCell>{formatDateTime(record.expires_at)}</TableCell>
|
||||||
<TableCell>{record.created_by_user_id || "—"}</TableCell>
|
<TableCell>{record.created_by_user_id || "—"}</TableCell>
|
||||||
|
<TableCell sx={{ fontFamily: "monospace" }}>{`${usageCount} / ${maxAllowed}`}</TableCell>
|
||||||
|
<TableCell>{formatDateTime(record.last_used_at)}</TableCell>
|
||||||
<TableCell>{formatDateTime(record.used_at)}</TableCell>
|
<TableCell>{formatDateTime(record.used_at)}</TableCell>
|
||||||
<TableCell sx={{ fontFamily: "monospace" }}>
|
<TableCell sx={{ fontFamily: "monospace" }}>
|
||||||
{record.used_by_guid || "—"}
|
{record.used_by_guid || "—"}
|
||||||
|
|||||||
@@ -152,19 +152,98 @@ def _rotate_daily(path: str):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _write_service_log(service: str, msg: str):
|
_SERVER_SCOPE_PATTERN = re.compile(r"\\b(?:scope|context|agent_context)=([A-Za-z0-9_-]+)", re.IGNORECASE)
|
||||||
|
_SERVER_AGENT_ID_PATTERN = re.compile(r"\\bagent_id=([^\s,]+)", re.IGNORECASE)
|
||||||
|
_AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||||
|
|
||||||
|
|
||||||
|
def _canonical_server_scope(raw: Optional[str]) -> Optional[str]:
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
value = "".join(ch for ch in str(raw) if ch.isalnum() or ch in ("_", "-"))
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
return value.upper()
|
||||||
|
|
||||||
|
|
||||||
|
def _scope_from_agent_id(agent_id: Optional[str]) -> Optional[str]:
|
||||||
|
candidate = _canonical_server_scope(agent_id)
|
||||||
|
if not candidate:
|
||||||
|
return None
|
||||||
|
if candidate.endswith("_SYSTEM"):
|
||||||
|
return "SYSTEM"
|
||||||
|
if candidate.endswith("_CURRENTUSER"):
|
||||||
|
return "CURRENTUSER"
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_server_scope(message: str, explicit: Optional[str]) -> Optional[str]:
|
||||||
|
scope = _canonical_server_scope(explicit)
|
||||||
|
if scope:
|
||||||
|
return scope
|
||||||
|
match = _SERVER_SCOPE_PATTERN.search(message or "")
|
||||||
|
if match:
|
||||||
|
scope = _canonical_server_scope(match.group(1))
|
||||||
|
if scope:
|
||||||
|
return scope
|
||||||
|
agent_match = _SERVER_AGENT_ID_PATTERN.search(message or "")
|
||||||
|
if agent_match:
|
||||||
|
scope = _scope_from_agent_id(agent_match.group(1))
|
||||||
|
if scope:
|
||||||
|
return scope
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _write_service_log(service: str, msg: str, scope: Optional[str] = None, *, level: str = "INFO"):
|
||||||
try:
|
try:
|
||||||
base = _server_logs_root()
|
base = _server_logs_root()
|
||||||
os.makedirs(base, exist_ok=True)
|
os.makedirs(base, exist_ok=True)
|
||||||
path = os.path.join(base, f"{service}.log")
|
path = os.path.join(base, f"{service}.log")
|
||||||
_rotate_daily(path)
|
_rotate_daily(path)
|
||||||
ts = time.strftime('%Y-%m-%d %H:%M:%S')
|
ts = time.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
resolved_scope = _infer_server_scope(msg, scope)
|
||||||
|
prefix_parts = [f"[{level.upper()}]"]
|
||||||
|
if resolved_scope:
|
||||||
|
prefix_parts.append(f"[CONTEXT-{resolved_scope}]")
|
||||||
|
prefix = "".join(prefix_parts)
|
||||||
with open(path, 'a', encoding='utf-8') as fh:
|
with open(path, 'a', encoding='utf-8') as fh:
|
||||||
fh.write(f'[{ts}] {msg}\n')
|
fh.write(f'[{ts}] {prefix} {msg}\n')
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _mask_server_value(value: str, *, prefix: int = 4, suffix: int = 4) -> str:
|
||||||
|
try:
|
||||||
|
if not value:
|
||||||
|
return ''
|
||||||
|
stripped = value.strip()
|
||||||
|
if len(stripped) <= prefix + suffix:
|
||||||
|
return '*' * len(stripped)
|
||||||
|
return f"{stripped[:prefix]}***{stripped[-suffix:]}"
|
||||||
|
except Exception:
|
||||||
|
return '***'
|
||||||
|
|
||||||
|
|
||||||
|
def _summarize_socket_headers(headers) -> str:
|
||||||
|
try:
|
||||||
|
rendered = []
|
||||||
|
for key, value in headers.items():
|
||||||
|
lowered = key.lower()
|
||||||
|
display = value
|
||||||
|
if lowered == 'authorization':
|
||||||
|
if isinstance(value, str) and value.lower().startswith('bearer '):
|
||||||
|
token = value.split(' ', 1)[1]
|
||||||
|
display = f"Bearer {_mask_server_value(token)}"
|
||||||
|
else:
|
||||||
|
display = _mask_server_value(str(value))
|
||||||
|
elif lowered == 'cookie':
|
||||||
|
display = '<redacted>'
|
||||||
|
rendered.append(f"{key}={display}")
|
||||||
|
return ", ".join(rendered)
|
||||||
|
except Exception:
|
||||||
|
return '<header-summary-unavailable>'
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Section: Repository Hash Tracking
|
# Section: Repository Hash Tracking
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -7968,6 +8047,55 @@ def screenshot_node_viewer(agent_id, node_id):
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Realtime channels for screenshots, macros, windows, and Ansible control.
|
# Realtime channels for screenshots, macros, windows, and Ansible control.
|
||||||
|
|
||||||
|
@socketio.on('connect')
|
||||||
|
def socket_connect():
|
||||||
|
try:
|
||||||
|
sid = getattr(request, 'sid', '<unknown>')
|
||||||
|
except Exception:
|
||||||
|
sid = '<unknown>'
|
||||||
|
try:
|
||||||
|
remote_addr = request.remote_addr
|
||||||
|
except Exception:
|
||||||
|
remote_addr = None
|
||||||
|
try:
|
||||||
|
scope = _canonical_server_scope(request.headers.get(_AGENT_CONTEXT_HEADER))
|
||||||
|
except Exception:
|
||||||
|
scope = None
|
||||||
|
try:
|
||||||
|
query_pairs = [f"{k}={v}" for k, v in request.args.items()] # type: ignore[attr-defined]
|
||||||
|
query_summary = "&".join(query_pairs) if query_pairs else "<none>"
|
||||||
|
except Exception:
|
||||||
|
query_summary = "<unavailable>"
|
||||||
|
header_summary = _summarize_socket_headers(getattr(request, 'headers', {}))
|
||||||
|
transport = request.args.get('transport') if hasattr(request, 'args') else None # type: ignore[attr-defined]
|
||||||
|
_write_service_log(
|
||||||
|
'server',
|
||||||
|
f"socket.io connect sid={sid} ip={remote_addr} transport={transport!r} query={query_summary} headers={header_summary}",
|
||||||
|
scope=scope,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@socketio.on('disconnect')
|
||||||
|
def socket_disconnect():
|
||||||
|
try:
|
||||||
|
sid = getattr(request, 'sid', '<unknown>')
|
||||||
|
except Exception:
|
||||||
|
sid = '<unknown>'
|
||||||
|
try:
|
||||||
|
remote_addr = request.remote_addr
|
||||||
|
except Exception:
|
||||||
|
remote_addr = None
|
||||||
|
try:
|
||||||
|
scope = _canonical_server_scope(request.headers.get(_AGENT_CONTEXT_HEADER))
|
||||||
|
except Exception:
|
||||||
|
scope = None
|
||||||
|
_write_service_log(
|
||||||
|
'server',
|
||||||
|
f"socket.io disconnect sid={sid} ip={remote_addr}",
|
||||||
|
scope=scope,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@socketio.on("agent_screenshot_task")
|
@socketio.on("agent_screenshot_task")
|
||||||
def receive_screenshot_task(data):
|
def receive_screenshot_task(data):
|
||||||
agent_id = data.get("agent_id")
|
agent_id = data.get("agent_id")
|
||||||
@@ -7997,6 +8125,19 @@ def connect_agent(data):
|
|||||||
if not agent_id:
|
if not agent_id:
|
||||||
return
|
return
|
||||||
print(f"Agent connected: {agent_id}")
|
print(f"Agent connected: {agent_id}")
|
||||||
|
try:
|
||||||
|
scope = _normalize_service_mode((data or {}).get("service_mode"), agent_id)
|
||||||
|
except Exception:
|
||||||
|
scope = None
|
||||||
|
try:
|
||||||
|
sid = getattr(request, 'sid', '<unknown>')
|
||||||
|
except Exception:
|
||||||
|
sid = '<unknown>'
|
||||||
|
_write_service_log(
|
||||||
|
'server',
|
||||||
|
f"socket.io connect_agent agent_id={agent_id} sid={sid} service_mode={scope}",
|
||||||
|
scope=scope,
|
||||||
|
)
|
||||||
|
|
||||||
# Join per-agent room so we can address this connection specifically
|
# Join per-agent room so we can address this connection specifically
|
||||||
try:
|
try:
|
||||||
@@ -8004,7 +8145,7 @@ def connect_agent(data):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
service_mode = _normalize_service_mode((data or {}).get("service_mode"), agent_id)
|
service_mode = scope if scope else _normalize_service_mode((data or {}).get("service_mode"), agent_id)
|
||||||
rec = registered_agents.setdefault(agent_id, {})
|
rec = registered_agents.setdefault(agent_id, {})
|
||||||
rec["agent_id"] = agent_id
|
rec["agent_id"] = agent_id
|
||||||
rec["hostname"] = rec.get("hostname", "unknown")
|
rec["hostname"] = rec.get("hostname", "unknown")
|
||||||
|
|||||||
58
tests/test_agent_installer_code.py
Normal file
58
tests/test_agent_installer_code.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent_module(tmp_path, monkeypatch):
|
||||||
|
settings_dir = tmp_path / "Agent" / "Borealis" / "Settings"
|
||||||
|
settings_dir.mkdir(parents=True)
|
||||||
|
system_config = settings_dir / "agent_settings_SYSTEM.json"
|
||||||
|
system_config.write_text(json.dumps({
|
||||||
|
"config_file_watcher_interval": 2,
|
||||||
|
"agent_id": "",
|
||||||
|
"regions": {},
|
||||||
|
"installer_code": "",
|
||||||
|
}, indent=2))
|
||||||
|
current_config = settings_dir / "agent_settings_CURRENTUSER.json"
|
||||||
|
current_config.write_text(json.dumps({
|
||||||
|
"config_file_watcher_interval": 2,
|
||||||
|
"agent_id": "",
|
||||||
|
"regions": {},
|
||||||
|
"installer_code": "",
|
||||||
|
}, indent=2))
|
||||||
|
|
||||||
|
monkeypatch.setenv("BOREALIS_ROOT", str(tmp_path))
|
||||||
|
monkeypatch.setenv("BOREALIS_AGENT_MODE", "system")
|
||||||
|
monkeypatch.setenv("BOREALIS_AGENT_CONFIG", "")
|
||||||
|
monkeypatch.setitem(sys.modules, "PyQt5", None)
|
||||||
|
monkeypatch.setitem(sys.modules, "qasync", None)
|
||||||
|
monkeypatch.setattr(sys, "argv", ["agent.py", "--system-service", "--config", "SYSTEM"], raising=False)
|
||||||
|
|
||||||
|
agent = pytest.importorskip(
|
||||||
|
"Data.Agent.agent", reason="agent module requires optional dependencies"
|
||||||
|
)
|
||||||
|
return agent, system_config
|
||||||
|
|
||||||
|
|
||||||
|
def test_shared_installer_code_cache_allows_system_reuse(agent_module, tmp_path):
|
||||||
|
agent, system_config = agent_module
|
||||||
|
|
||||||
|
client = agent.AgentHttpClient()
|
||||||
|
shared_code = "SHARED-CODE-1234"
|
||||||
|
client.key_store.cache_installer_code(shared_code, consumer="CURRENTUSER")
|
||||||
|
|
||||||
|
# System agent should discover the cached code even though its config is empty.
|
||||||
|
resolved = client._resolve_installer_code()
|
||||||
|
assert resolved == shared_code
|
||||||
|
|
||||||
|
# Config should now persist the adopted code to avoid repeated lookups.
|
||||||
|
data = json.loads(system_config.read_text())
|
||||||
|
assert data.get("installer_code") == shared_code
|
||||||
|
|
||||||
|
# After enrollment completes, the cache should be cleared for future runs.
|
||||||
|
client._consume_installer_code()
|
||||||
|
assert client.key_store.load_cached_installer_code() is None
|
||||||
|
data = json.loads(system_config.read_text())
|
||||||
|
assert data.get("installer_code") == ""
|
||||||
213
tests/test_enrollment_install_codes.py
Normal file
213
tests/test_enrollment_install_codes.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
import base64
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import sqlite3
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
try: # pragma: no cover - optional dependency
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||||
|
_CRYPTO_IMPORT_ERROR: Exception | None = None
|
||||||
|
except Exception as exc: # pragma: no cover - dependency unavailable
|
||||||
|
serialization = None # type: ignore
|
||||||
|
ed25519 = None # type: ignore
|
||||||
|
_CRYPTO_IMPORT_ERROR = exc
|
||||||
|
try: # pragma: no cover - optional dependency
|
||||||
|
from flask import Flask
|
||||||
|
_FLASK_IMPORT_ERROR: Exception | None = None
|
||||||
|
except Exception as exc: # pragma: no cover - dependency unavailable
|
||||||
|
Flask = None # type: ignore
|
||||||
|
_FLASK_IMPORT_ERROR = exc
|
||||||
|
|
||||||
|
ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||||
|
if str(ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(ROOT))
|
||||||
|
|
||||||
|
from Data.Server.Modules import db_migrations
|
||||||
|
from Data.Server.Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||||
|
from Data.Server.Modules.enrollment.nonce_store import NonceCache
|
||||||
|
|
||||||
|
if Flask is not None: # pragma: no cover - dependency unavailable
|
||||||
|
from Data.Server.Modules.enrollment import routes as enrollment_routes
|
||||||
|
else: # pragma: no cover - dependency unavailable
|
||||||
|
enrollment_routes = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyJWTService:
|
||||||
|
def issue_access_token(self, guid: str, fingerprint: str, token_version: int, expires_in: int = 900, extra_claims=None):
|
||||||
|
return f"token-{guid}"
|
||||||
|
|
||||||
|
|
||||||
|
class _DummySigner:
|
||||||
|
def public_base64_spki(self) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app(db_path: str, tls_path: str):
|
||||||
|
if Flask is None or enrollment_routes is None: # pragma: no cover - dependency unavailable
|
||||||
|
pytest.skip(f"flask unavailable: {_FLASK_IMPORT_ERROR}")
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
def _factory():
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
enrollment_routes.register(
|
||||||
|
app,
|
||||||
|
db_conn_factory=_factory,
|
||||||
|
log=lambda channel, message: None,
|
||||||
|
jwt_service=_DummyJWTService(),
|
||||||
|
tls_bundle_path=tls_path,
|
||||||
|
ip_rate_limiter=SlidingWindowRateLimiter(),
|
||||||
|
fp_rate_limiter=SlidingWindowRateLimiter(),
|
||||||
|
nonce_cache=NonceCache(ttl_seconds=30.0),
|
||||||
|
script_signer=_DummySigner(),
|
||||||
|
)
|
||||||
|
return app, _factory
|
||||||
|
|
||||||
|
|
||||||
|
def _create_install_code(conn: sqlite3.Connection, code: str, *, max_uses: int = 2):
|
||||||
|
cur = conn.cursor()
|
||||||
|
record_id = str(uuid.uuid4())
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO enrollment_install_codes (
|
||||||
|
id, code, expires_at, created_by_user_id, max_uses, use_count
|
||||||
|
)
|
||||||
|
VALUES (?, ?, datetime('now', '+6 hours'), 'test-user', ?, 0)
|
||||||
|
""",
|
||||||
|
(record_id, code, max_uses),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return record_id
|
||||||
|
|
||||||
|
|
||||||
|
def _perform_enrollment_cycle(app, factory, code: str, private_key):
|
||||||
|
client = app.test_client()
|
||||||
|
public_der = private_key.public_key().public_bytes(
|
||||||
|
serialization.Encoding.DER,
|
||||||
|
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
public_b64 = base64.b64encode(public_der).decode("ascii")
|
||||||
|
client_nonce = os.urandom(32)
|
||||||
|
payload = {
|
||||||
|
"hostname": "unit-test-host",
|
||||||
|
"enrollment_code": code,
|
||||||
|
"agent_pubkey": public_b64,
|
||||||
|
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
|
||||||
|
}
|
||||||
|
request_resp = client.post("/api/agent/enroll/request", json=payload)
|
||||||
|
assert request_resp.status_code == 200
|
||||||
|
request_data = request_resp.get_json()
|
||||||
|
approval_reference = request_data["approval_reference"]
|
||||||
|
|
||||||
|
with factory() as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
UPDATE device_approvals
|
||||||
|
SET status = 'approved',
|
||||||
|
approved_by_user_id = 'tester'
|
||||||
|
WHERE approval_reference = ?
|
||||||
|
""",
|
||||||
|
(approval_reference,),
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT server_nonce, client_nonce
|
||||||
|
FROM device_approvals
|
||||||
|
WHERE approval_reference = ?
|
||||||
|
""",
|
||||||
|
(approval_reference,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
server_nonce_b64 = row["server_nonce"]
|
||||||
|
|
||||||
|
server_nonce = base64.b64decode(server_nonce_b64)
|
||||||
|
proof_message = server_nonce + approval_reference.encode("utf-8") + client_nonce
|
||||||
|
proof_sig = private_key.sign(proof_message)
|
||||||
|
poll_payload = {
|
||||||
|
"approval_reference": approval_reference,
|
||||||
|
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
|
||||||
|
"proof_sig": base64.b64encode(proof_sig).decode("ascii"),
|
||||||
|
}
|
||||||
|
poll_resp = client.post("/api/agent/enroll/poll", json=poll_payload)
|
||||||
|
assert poll_resp.status_code == 200
|
||||||
|
return poll_resp.get_json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("max_uses", [2])
|
||||||
|
@pytest.mark.skipif(ed25519 is None, reason=f"cryptography unavailable: {_CRYPTO_IMPORT_ERROR}")
|
||||||
|
@pytest.mark.skipif(Flask is None, reason=f"flask unavailable: {_FLASK_IMPORT_ERROR}")
|
||||||
|
def test_install_code_allows_multiple_and_reuse(tmp_path, max_uses):
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
db_migrations.apply_all(conn)
|
||||||
|
_create_install_code(conn, "TEST-CODE-1234", max_uses=max_uses)
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
tls_path = tmp_path / "tls.pem"
|
||||||
|
tls_path.write_text("TEST CERT")
|
||||||
|
|
||||||
|
app, factory = _make_app(str(db_path), str(tls_path))
|
||||||
|
|
||||||
|
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||||
|
|
||||||
|
# First enrollment consumes one use but keeps the code active.
|
||||||
|
first = _perform_enrollment_cycle(app, factory, "TEST-CODE-1234", private_key)
|
||||||
|
assert first["status"] == "approved"
|
||||||
|
|
||||||
|
with factory() as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"SELECT use_count, max_uses, used_at, last_used_at FROM enrollment_install_codes WHERE code = ?",
|
||||||
|
("TEST-CODE-1234",),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row["use_count"] == 1
|
||||||
|
assert row["max_uses"] == max_uses
|
||||||
|
assert row["used_at"] is None
|
||||||
|
assert row["last_used_at"] is not None
|
||||||
|
|
||||||
|
# Second enrollment hits the configured max uses.
|
||||||
|
second = _perform_enrollment_cycle(app, factory, "TEST-CODE-1234", private_key)
|
||||||
|
assert second["status"] == "approved"
|
||||||
|
|
||||||
|
with factory() as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"SELECT use_count, used_at, last_used_at, used_by_guid FROM enrollment_install_codes WHERE code = ?",
|
||||||
|
("TEST-CODE-1234",),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row["use_count"] == max_uses
|
||||||
|
assert row["used_at"] is not None
|
||||||
|
assert row["last_used_at"] is not None
|
||||||
|
consumed_guid = row["used_by_guid"]
|
||||||
|
assert consumed_guid
|
||||||
|
|
||||||
|
# Additional enrollments from the same identity reuse the stored GUID even after consumption.
|
||||||
|
third = _perform_enrollment_cycle(app, factory, "TEST-CODE-1234", private_key)
|
||||||
|
assert third["status"] == "approved"
|
||||||
|
|
||||||
|
with factory() as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"SELECT use_count, used_at, last_used_at, used_by_guid FROM enrollment_install_codes WHERE code = ?",
|
||||||
|
("TEST-CODE-1234",),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row["use_count"] == max_uses + 1
|
||||||
|
assert row["used_by_guid"] == consumed_guid
|
||||||
|
assert row["used_at"] is not None
|
||||||
|
assert row["last_used_at"] is not None
|
||||||
|
|
||||||
|
cur.execute("SELECT COUNT(*) FROM devices WHERE guid = ?", (consumed_guid,))
|
||||||
|
assert cur.fetchone()[0] == 1
|
||||||
Reference in New Issue
Block a user