Merge pull request #128 from bunny-lab-io:codex/implement-secure-agent-enrollment-features

Handle missing devices without re-enrollment loops
This commit is contained in:
2025-10-18 05:46:33 -06:00
committed by GitHub
13 changed files with 1401 additions and 127 deletions

View File

@@ -23,7 +23,8 @@ import ssl
import threading
import contextlib
import errno
from typing import Any, Dict, Optional, List, Callable
import re
from typing import Any, Dict, Optional, List, Callable, Tuple
import requests
try:
@@ -66,21 +67,120 @@ def _rotate_daily(path: str):
# Early bootstrap logging (goes to agent.log)
def _bootstrap_log(msg: str):
_AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
_AGENT_SCOPE_PATTERN = re.compile(r"\\bscope=([A-Za-z0-9_-]+)", re.IGNORECASE)
def _canonical_scope_value(raw: Optional[str]) -> Optional[str]:
if not raw:
return None
value = "".join(ch for ch in str(raw) if ch.isalnum() or ch in ("_", "-"))
if not value:
return None
return value.upper()
def _agent_context_default() -> Optional[str]:
suffix = globals().get("CONFIG_SUFFIX_CANONICAL")
context = _canonical_scope_value(suffix)
if context:
return context
service = globals().get("SERVICE_MODE_CANONICAL")
context = _canonical_scope_value(service)
if context:
return context
return None
def _infer_agent_scope(message: str, provided_scope: Optional[str] = None) -> Optional[str]:
scope = _canonical_scope_value(provided_scope)
if scope:
return scope
match = _AGENT_SCOPE_PATTERN.search(message or "")
if match:
scope = _canonical_scope_value(match.group(1))
if scope:
return scope
return _agent_context_default()
def _format_agent_log_message(message: str, fname: str, scope: Optional[str] = None) -> str:
context = _infer_agent_scope(message, scope)
if fname == "agent.error.log":
prefix = "[ERROR]"
if context:
prefix = f"{prefix}[CONTEXT-{context}]"
return f"{prefix} {message}"
if context:
return f"[CONTEXT-{context}] {message}"
return f"[INFO] {message}"
def _bootstrap_log(msg: str, *, scope: Optional[str] = None):
try:
base = _agent_logs_root()
os.makedirs(base, exist_ok=True)
path = os.path.join(base, 'agent.log')
_rotate_daily(path)
ts = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
line = _format_agent_log_message(msg, 'agent.log', scope)
with open(path, 'a', encoding='utf-8') as fh:
fh.write(f'[{ts}] {msg}\n')
fh.write(f'[{ts}] {line}\n')
except Exception:
pass
def _describe_exception(exc: BaseException) -> str:
try:
primary = f"{exc.__class__.__name__}: {exc}"
except Exception:
primary = repr(exc)
parts = [primary]
try:
cause = getattr(exc, "__cause__", None)
if cause and cause is not exc:
parts.append(f"cause={cause.__class__.__name__}: {cause}")
except Exception:
pass
try:
context = getattr(exc, "__context__", None)
if context and context is not exc and context is not getattr(exc, "__cause__", None):
parts.append(f"context={context.__class__.__name__}: {context}")
except Exception:
pass
try:
args = getattr(exc, "args", None)
if isinstance(args, tuple) and len(args) > 1:
parts.append(f"args={args!r}")
except Exception:
pass
try:
details = getattr(exc, "__dict__", None)
if isinstance(details, dict):
# Capture noteworthy nested attributes such as os_error/errno to help diagnose
# connection failures that collapse into generic ConnectionError wrappers.
for key in ("os_error", "errno", "code", "status"):
if key in details and details[key]:
parts.append(f"{key}={details[key]!r}")
except Exception:
pass
return "; ".join(part for part in parts if part)
def _log_exception_trace(prefix: str) -> None:
try:
tb = traceback.format_exc()
if not tb:
return
for line in tb.rstrip().splitlines():
_log_agent(f"{prefix} trace: {line}", fname="agent.error.log")
except Exception:
pass
# Headless/service mode flag (skip Qt and interactive UI)
SYSTEM_SERVICE_MODE = ('--system-service' in sys.argv) or (os.environ.get('BOREALIS_AGENT_MODE') == 'system')
SERVICE_MODE = 'system' if SYSTEM_SERVICE_MODE else 'currentuser'
SERVICE_MODE_CANONICAL = SERVICE_MODE.upper()
_bootstrap_log(f'agent.py loaded; SYSTEM_SERVICE_MODE={SYSTEM_SERVICE_MODE}; argv={sys.argv!r}')
def _argv_get(flag: str, default: str = None):
try:
@@ -359,15 +459,16 @@ def _find_project_root():
return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
# Simple file logger under Logs/Agent
def _log_agent(message: str, fname: str = 'agent.log'):
def _log_agent(message: str, fname: str = 'agent.log', *, scope: Optional[str] = None):
try:
log_dir = _agent_logs_root()
os.makedirs(log_dir, exist_ok=True)
ts = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
path = os.path.join(log_dir, fname)
_rotate_daily(path)
line = _format_agent_log_message(message, fname, scope)
with open(path, 'a', encoding='utf-8') as fh:
fh.write(f'[{ts}] {message}\n')
fh.write(f'[{ts}] {line}\n')
except Exception:
pass
@@ -384,6 +485,31 @@ def _mask_sensitive(value: str, *, prefix: int = 4, suffix: int = 4) -> str:
return '***'
def _format_debug_pairs(pairs: Dict[str, Any]) -> str:
try:
parts = []
for key, value in pairs.items():
parts.append(f"{key}={value!r}")
return ", ".join(parts)
except Exception:
return repr(pairs)
def _summarize_headers(headers: Dict[str, str]) -> str:
try:
rendered: List[str] = []
for key, value in headers.items():
lowered = key.lower()
display = value
if lowered == 'authorization':
token = value.split()[-1] if value and ' ' in value else value
display = f"Bearer {_mask_sensitive(token)}"
rendered.append(f"{key}={display}")
return ", ".join(rendered)
except Exception:
return '<unavailable>'
def _decode_base64_text(value):
if not isinstance(value, str):
return None
@@ -571,6 +697,57 @@ DEFAULT_CONFIG = {
"installer_code": ""
}
def _load_installer_code_from_file(path: str) -> str:
try:
with open(path, "r", encoding="utf-8") as fh:
data = json.load(fh)
except Exception:
return ""
value = data.get("installer_code") if isinstance(data, dict) else ""
if isinstance(value, str):
return value.strip()
return ""
def _fallback_installer_code(current_path: str) -> str:
settings_dir = os.path.dirname(current_path)
candidates: List[str] = []
suffix = CONFIG_SUFFIX_CANONICAL
sibling_map = {
"SYSTEM": "agent_settings_CURRENTUSER.json",
"CURRENTUSER": "agent_settings_SYSTEM.json",
}
sibling_name = sibling_map.get(suffix or "")
if sibling_name:
candidates.append(os.path.join(settings_dir, sibling_name))
# Prefer the shared/base config next
candidates.append(os.path.join(settings_dir, "agent_settings.json"))
# Legacy location fallback
try:
project_root = _find_project_root()
legacy_dir = os.path.join(project_root, "Agent", "Settings")
if sibling_name:
candidates.append(os.path.join(legacy_dir, sibling_name))
candidates.append(os.path.join(legacy_dir, "agent_settings.json"))
except Exception:
pass
current_abspath = os.path.abspath(current_path)
for candidate in candidates:
if not candidate:
continue
try:
candidate_path = os.path.abspath(candidate)
except Exception:
continue
if candidate_path == current_abspath or not os.path.isfile(candidate_path):
continue
code = _load_installer_code_from_file(candidate_path)
if code:
return code
return ""
class ConfigManager:
def __init__(self, path):
self.path = path
@@ -631,6 +808,9 @@ class AgentHttpClient:
self.key_store = _key_store()
self.identity = IDENTITY
self.session = requests.Session()
context_label = _agent_context_default()
if context_label:
self.session.headers.setdefault(_AGENT_CONTEXT_HEADER, context_label)
self.base_url: Optional[str] = None
self.guid: Optional[str] = None
self.access_token: Optional[str] = None
@@ -697,42 +877,115 @@ class AgentHttpClient:
pass
def auth_headers(self) -> Dict[str, str]:
headers: Dict[str, str] = {}
if self.access_token:
return {"Authorization": f"Bearer {self.access_token}"}
return {}
headers["Authorization"] = f"Bearer {self.access_token}"
context_label = _agent_context_default()
if context_label:
headers[_AGENT_CONTEXT_HEADER] = context_label
return headers
def configure_socketio(self, client: "socketio.AsyncClient") -> None:
"""Align the Socket.IO engine's TLS verification with the REST client."""
try:
verify = getattr(self.session, "verify", True)
engine = getattr(client, "eio", None)
if engine is None:
_log_agent(
"SocketIO TLS alignment skipped; AsyncClient.eio missing",
fname="agent.error.log",
)
return
# python-engineio accepts either a boolean or an ``ssl.SSLContext``
# for TLS verification. When we have a pinned certificate bundle
# on disk, prefer constructing a dedicated context that trusts that
# bundle so WebSocket connections succeed even with private CAs.
http_iface = getattr(engine, "http", None)
debug_info = {
"verify_type": type(verify).__name__,
"verify_value": verify,
"engine_type": type(engine).__name__,
"http_iface_present": http_iface is not None,
}
_log_agent(
f"SocketIO TLS alignment start: {_format_debug_pairs(debug_info)}",
fname="agent.log",
)
def _set_attr(target: Any, name: str, value: Any) -> None:
if target is None:
return
try:
setattr(target, name, value)
except Exception:
pass
def _reset_cached_session() -> None:
if http_iface is None:
return
try:
if hasattr(http_iface, "session"):
setattr(http_iface, "session", None)
except Exception:
pass
context = None
if isinstance(verify, str) and os.path.isfile(verify):
try:
context = ssl.create_default_context(cafile=verify)
# Mirror Requests' certificate handling by starting from a
# default client context (which pre-loads the system
# certificate stores) and then layering the pinned
# certificate bundle on top. This matches the REST client
# behaviour and ensures self-signed leaf certificates work
# the same way for Socket.IO handshakes.
context = ssl.create_default_context()
context.check_hostname = False
context.load_verify_locations(cafile=verify)
_log_agent(
f"SocketIO TLS alignment created SSLContext from cafile={verify}",
fname="agent.log",
)
except Exception:
context = None
if context is not None:
engine.ssl_context = context
engine.ssl_verify = True
else:
engine.ssl_context = None
engine.ssl_verify = verify
elif verify is False:
engine.ssl_context = None
engine.ssl_verify = False
else:
engine.ssl_context = None
engine.ssl_verify = True
_log_agent(
f"SocketIO TLS alignment failed to build context from cafile={verify}",
fname="agent.error.log",
)
if context is not None:
_set_attr(engine, "ssl_context", context)
_set_attr(engine, "ssl_verify", True)
_set_attr(engine, "verify_ssl", True)
_set_attr(http_iface, "ssl_context", context)
_set_attr(http_iface, "ssl_verify", True)
_set_attr(http_iface, "verify_ssl", True)
_reset_cached_session()
_log_agent(
"SocketIO TLS alignment applied dedicated SSLContext to engine/http",
fname="agent.log",
)
return
# Fall back to boolean verification flags when we either do not
# have a pinned certificate bundle or failed to build a dedicated
# context for it.
verify_flag = False if verify is False else True
_set_attr(engine, "ssl_context", None)
_set_attr(engine, "ssl_verify", verify_flag)
_set_attr(engine, "verify_ssl", verify_flag)
_set_attr(http_iface, "ssl_context", None)
_set_attr(http_iface, "ssl_verify", verify_flag)
_set_attr(http_iface, "verify_ssl", verify_flag)
_reset_cached_session()
_log_agent(
f"SocketIO TLS alignment fallback verify_flag={verify_flag}",
fname="agent.log",
)
except Exception:
pass
_log_agent(
"SocketIO TLS alignment encountered unexpected error",
fname="agent.error.log",
)
_log_exception_trace("configure_socketio")
# ------------------------------------------------------------------
# Enrollment & token management
@@ -1007,10 +1260,22 @@ class AgentHttpClient:
timeout=20,
)
if resp.status_code in (401, 403):
_log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log")
self._clear_tokens_locked()
self._perform_enrollment_locked()
return
error_code, snippet = self._error_details(resp)
if resp.status_code == 401 and self._should_retry_auth(resp.status_code, error_code):
_log_agent(
"Refresh token rejected; attempting re-enrollment"
f" error={error_code or '<unknown>'}",
fname="agent.error.log",
)
self._clear_tokens_locked()
self._perform_enrollment_locked()
return
_log_agent(
"Refresh token request forbidden "
f"status={resp.status_code} error={error_code or '<unknown>'}"
f" body_snippet={snippet}",
fname="agent.error.log",
)
resp.raise_for_status()
data = resp.json()
access_token = data.get("access_token")
@@ -1036,14 +1301,79 @@ class AgentHttpClient:
self.guid = self.key_store.load_guid()
self.session.headers.pop("Authorization", None)
def _error_details(self, response: requests.Response) -> Tuple[Optional[str], str]:
error_code: Optional[str] = None
snippet = ""
try:
snippet = response.text[:256]
except Exception:
snippet = "<unavailable>"
try:
data = response.json()
except Exception:
data = None
if isinstance(data, dict):
for key in ("error", "code", "status"):
value = data.get(key)
if isinstance(value, str) and value.strip():
error_code = value.strip()
break
return error_code, snippet
def _should_retry_auth(self, status_code: int, error_code: Optional[str]) -> bool:
if status_code == 401:
return True
retryable_forbidden = {"fingerprint_mismatch"}
if status_code == 403 and error_code in retryable_forbidden:
return True
return False
def _resolve_installer_code(self) -> str:
if INSTALLER_CODE_OVERRIDE:
return INSTALLER_CODE_OVERRIDE
code = INSTALLER_CODE_OVERRIDE.strip()
if code:
try:
self.key_store.cache_installer_code(code, consumer=SERVICE_MODE_CANONICAL)
except Exception:
pass
return code
code = ""
try:
code = (CONFIG.data.get("installer_code") or "").strip()
return code
except Exception:
return ""
code = ""
if code:
try:
self.key_store.cache_installer_code(code, consumer=SERVICE_MODE_CANONICAL)
except Exception:
pass
return code
try:
cached = self.key_store.load_cached_installer_code()
except Exception:
cached = None
if cached:
try:
self.key_store.cache_installer_code(cached, consumer=SERVICE_MODE_CANONICAL)
except Exception:
pass
return cached
fallback = _fallback_installer_code(CONFIG.path)
if fallback:
try:
CONFIG.data["installer_code"] = fallback
CONFIG._write()
_log_agent(
"Adopted installer code from sibling configuration", fname="agent.log"
)
except Exception:
pass
try:
self.key_store.cache_installer_code(fallback, consumer=SERVICE_MODE_CANONICAL)
except Exception:
pass
return fallback
return ""
def _consume_installer_code(self) -> None:
# Avoid clearing explicit CLI/env overrides; only mutate persisted config.
@@ -1057,6 +1387,13 @@ class AgentHttpClient:
_log_agent("Cleared persisted installer code after successful enrollment", fname="agent.log")
except Exception as exc:
_log_agent(f"Failed to clear installer code after enrollment: {exc}", fname="agent.error.log")
try:
self.key_store.mark_installer_code_consumed(SERVICE_MODE_CANONICAL)
except Exception as exc:
_log_agent(
f"Failed to update shared installer code cache: {exc}",
fname="agent.error.log",
)
# ------------------------------------------------------------------
# HTTP helpers
@@ -1068,20 +1405,19 @@ class AgentHttpClient:
headers = self.auth_headers()
response = self.session.post(url, json=payload, headers=headers, timeout=30)
if response.status_code in (401, 403) and require_auth:
snippet = ""
try:
snippet = response.text[:256]
except Exception:
snippet = "<unavailable>"
_log_agent(
"Authenticated request rejected "
f"path={path} status={response.status_code} body_snippet={snippet}",
fname="agent.error.log",
)
self.clear_tokens()
self.ensure_authenticated()
headers = self.auth_headers()
response = self.session.post(url, json=payload, headers=headers, timeout=30)
error_code, snippet = self._error_details(response)
if self._should_retry_auth(response.status_code, error_code):
self.clear_tokens()
self.ensure_authenticated()
headers = self.auth_headers()
response = self.session.post(url, json=payload, headers=headers, timeout=30)
else:
_log_agent(
"Authenticated request rejected "
f"path={path} status={response.status_code} error={error_code or '<unknown>'}"
f" body_snippet={snippet}",
fname="agent.error.log",
)
response.raise_for_status()
if response.headers.get("Content-Type", "").lower().startswith("application/json"):
return response.json()
@@ -2107,6 +2443,15 @@ async def send_agent_details_once():
async def connect():
print(f"[INFO] Successfully Connected to Borealis Server!")
_log_agent('Connected to server.')
try:
sid = getattr(sio, 'sid', None)
transport = getattr(sio, 'transport', None)
_log_agent(
f'WebSocket handshake established sid={sid!r} transport={transport!r}',
fname='agent.log',
)
except Exception:
pass
await sio.emit('connect_agent', {"agent_id": AGENT_ID, "service_mode": SERVICE_MODE})
# Send an immediate heartbeat via authenticated REST call.
@@ -2143,6 +2488,17 @@ async def connect():
except Exception:
pass
@sio.event
async def connect_error(data):
try:
setattr(sio, "connection_error", data)
except Exception:
pass
try:
_log_agent(f'Socket connect_error event: {data!r}', fname='agent.error.log')
except Exception:
pass
@sio.event
async def disconnect():
print("[WebSocket] Disconnected from Borealis server.")
@@ -2390,22 +2746,64 @@ if not SYSTEM_SERVICE_MODE:
async def connect_loop():
retry = 5
client = http_client()
attempt = 0
while True:
attempt += 1
try:
_log_agent(
f'connect_loop attempt={attempt} starting authentication phase',
fname='agent.log',
)
client.ensure_authenticated()
auth_snapshot = {
'guid_present': bool(client.guid),
'access_token': bool(client.access_token),
'refresh_token': bool(client.refresh_token),
'access_expiry': client.access_expires_at,
}
_log_agent(
f"connect_loop attempt={attempt} auth snapshot: {_format_debug_pairs(auth_snapshot)}",
fname='agent.log',
)
client.configure_socketio(sio)
try:
setattr(sio, "connection_error", None)
except Exception:
pass
url = client.websocket_base_url()
headers = client.auth_headers()
header_summary = _summarize_headers(headers)
verify_value = getattr(client.session, 'verify', None)
_log_agent(
f"connect_loop attempt={attempt} dialing websocket url={url} transports=['websocket'] verify={verify_value!r} headers={header_summary}",
fname='agent.log',
)
print(f"[INFO] Connecting Agent to {url}...")
_log_agent(f'Connecting to {url}...')
await sio.connect(
url,
transports=['websocket'],
headers=client.auth_headers(),
headers=headers,
)
_log_agent(
f'connect_loop attempt={attempt} sio.connect completed successfully',
fname='agent.log',
)
break
except Exception as e:
print(f"[WebSocket] Server unavailable: {e}. Retrying in {retry}s...")
_log_agent(f'Server unavailable: {e}', fname='agent.error.log')
detail = _describe_exception(e)
try:
conn_err = getattr(sio, "connection_error", None)
except Exception:
conn_err = None
if conn_err:
detail = f"{detail}; connection_error={conn_err!r}"
message = (
f"connect_loop attempt={attempt} server unavailable: {detail}. "
f"Retrying in {retry}s..."
)
print(f"[WebSocket] {message}")
_log_agent(message, fname='agent.error.log')
_log_exception_trace(f'connect_loop attempt={attempt}')
await asyncio.sleep(retry)
if __name__=='__main__':

View File

@@ -230,6 +230,7 @@ class AgentKeyStore:
self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem")
self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub")
self._identity_lock_path = os.path.join(self.settings_dir, "identity.lock")
self._installer_cache_path = os.path.join(self.settings_dir, "installer_code.shared.json")
# ------------------------------------------------------------------
# Identity management
@@ -455,3 +456,107 @@ class AgentKeyStore:
if isinstance(value, str) and value.strip():
return value.strip()
return None
# ------------------------------------------------------------------
# Installer code sharing helpers
# ------------------------------------------------------------------
def _load_installer_cache(self) -> dict:
if not os.path.isfile(self._installer_cache_path):
return {}
try:
with open(self._installer_cache_path, "r", encoding="utf-8") as fh:
data = json.load(fh)
if isinstance(data, dict):
return data
except Exception:
pass
return {}
def _store_installer_cache(self, payload: dict) -> None:
try:
with open(self._installer_cache_path, "w", encoding="utf-8") as fh:
json.dump(payload, fh, indent=2)
_restrict_permissions(self._installer_cache_path)
except Exception:
pass
def cache_installer_code(self, code: str, consumer: Optional[str] = None) -> None:
normalized = (code or "").strip()
if not normalized:
return
payload = self._load_installer_cache()
payload["code"] = normalized
consumers = set()
existing = payload.get("consumed")
if isinstance(existing, list):
consumers = {str(item).upper() for item in existing if isinstance(item, str)}
if consumer:
consumers.add(str(consumer).upper())
payload["consumed"] = sorted(consumers)
payload["updated_at"] = int(time.time())
self._store_installer_cache(payload)
def load_cached_installer_code(self) -> Optional[str]:
payload = self._load_installer_cache()
code = payload.get("code")
if isinstance(code, str):
stripped = code.strip()
if stripped:
return stripped
return None
def mark_installer_code_consumed(self, consumer: Optional[str] = None) -> None:
payload = self._load_installer_cache()
if not payload:
return
consumers = set()
existing = payload.get("consumed")
if isinstance(existing, list):
consumers = {str(item).upper() for item in existing if isinstance(item, str)}
if consumer:
consumers.add(str(consumer).upper())
payload["consumed"] = sorted(consumers)
payload["updated_at"] = int(time.time())
code_present = isinstance(payload.get("code"), str) and payload["code"].strip()
should_clear = False
if not code_present:
should_clear = True
else:
required_consumers = {"SYSTEM", "CURRENTUSER"}
if required_consumers.issubset(consumers):
should_clear = True
else:
remaining = required_consumers - consumers
if not remaining:
should_clear = True
else:
exists_other = False
for other in remaining:
if other == "SYSTEM":
cfg_name = "agent_settings_SYSTEM.json"
elif other == "CURRENTUSER":
cfg_name = "agent_settings_CURRENTUSER.json"
else:
cfg_name = None
if not cfg_name:
continue
path = os.path.join(self.settings_dir, cfg_name)
if os.path.isfile(path):
exists_other = True
break
if not exists_other:
should_clear = True
if should_clear:
payload.pop("code", None)
payload["consumed"] = []
if payload.get("code") or payload.get("consumed"):
self._store_installer_cache(payload)
else:
try:
if os.path.isfile(self._installer_cache_path):
os.remove(self._installer_cache_path)
except Exception:
pass

View File

@@ -18,7 +18,7 @@ def register(
db_conn_factory: Callable[[], sqlite3.Connection],
require_admin: Callable[[], Optional[Any]],
current_user: Callable[[], Optional[Dict[str, str]]],
log: Callable[[str, str], None],
log: Callable[[str, str, Optional[str]], None],
) -> None:
blueprint = Blueprint("admin", __name__)
@@ -54,18 +54,27 @@ def register(
try:
cur = conn.cursor()
sql = """
SELECT id, code, expires_at, created_by_user_id, used_at, used_by_guid
SELECT id,
code,
expires_at,
created_by_user_id,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at
FROM enrollment_install_codes
"""
params: List[str] = []
now_iso = _iso(_now())
if status_filter == "active":
sql += " WHERE used_at IS NULL AND expires_at > ?"
params.append(_iso(_now()))
sql += " WHERE use_count < max_uses AND expires_at > ?"
params.append(now_iso)
elif status_filter == "expired":
sql += " WHERE used_at IS NULL AND expires_at <= ?"
params.append(_iso(_now()))
sql += " WHERE use_count < max_uses AND expires_at <= ?"
params.append(now_iso)
elif status_filter == "used":
sql += " WHERE used_at IS NOT NULL"
sql += " WHERE use_count >= max_uses"
sql += " ORDER BY expires_at ASC"
cur.execute(sql, params)
rows = cur.fetchall()
@@ -82,6 +91,9 @@ def register(
"created_by_user_id": row[3],
"used_at": row[4],
"used_by_guid": row[5],
"max_uses": row[6],
"use_count": row[7],
"last_used_at": row[8],
}
)
return jsonify({"codes": records})
@@ -93,6 +105,18 @@ def register(
if ttl_hours not in VALID_TTL_HOURS:
return jsonify({"error": "invalid_ttl"}), 400
max_uses_value = payload.get("max_uses")
if max_uses_value is None:
max_uses_value = payload.get("allowed_uses")
try:
max_uses = int(max_uses_value)
except Exception:
max_uses = 2
if max_uses < 1:
max_uses = 1
if max_uses > 10:
max_uses = 10
user = current_user() or {}
username = user.get("username") or ""
@@ -106,22 +130,28 @@ def register(
cur.execute(
"""
INSERT INTO enrollment_install_codes (
id, code, expires_at, created_by_user_id
id, code, expires_at, created_by_user_id, max_uses, use_count
)
VALUES (?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?, 0)
""",
(record_id, code_value, _iso(expires_at), created_by),
(record_id, code_value, _iso(expires_at), created_by, max_uses),
)
conn.commit()
finally:
conn.close()
log("server", f"installer code created id={record_id} by={username} ttl={ttl_hours}h")
log(
"server",
f"installer code created id={record_id} by={username} ttl={ttl_hours}h max_uses={max_uses}",
)
return jsonify(
{
"id": record_id,
"code": code_value,
"expires_at": _iso(expires_at),
"max_uses": max_uses,
"use_count": 0,
"last_used_at": None,
}
)
@@ -131,7 +161,7 @@ def register(
try:
cur = conn.cursor()
cur.execute(
"DELETE FROM enrollment_install_codes WHERE id = ? AND used_at IS NULL",
"DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
(code_id,),
)
deleted = cur.rowcount

View File

@@ -10,13 +10,24 @@ from flask import Blueprint, jsonify, request, g
from Modules.auth.device_auth import DeviceAuthManager, require_device_auth
from Modules.crypto.signing import ScriptSigner
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
if not value:
return None
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
if not cleaned:
return None
return cleaned.upper()
def register(
app,
*,
db_conn_factory: Callable[[], Any],
auth_manager: DeviceAuthManager,
log: Callable[[str, str], None],
log: Callable[[str, str, Optional[str]], None],
script_signer: ScriptSigner,
) -> None:
blueprint = Blueprint("agents", __name__)
@@ -29,10 +40,15 @@ def register(
except Exception:
return None
def _context_hint(ctx=None) -> Optional[str]:
if ctx is not None and getattr(ctx, "service_mode", None):
return _canonical_context(getattr(ctx, "service_mode", None))
return _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
def _auth_context():
ctx = getattr(g, "device_auth", None)
if ctx is None:
log("server", f"device auth context missing for {request.path}")
log("server", f"device auth context missing for {request.path}", _context_hint())
return ctx
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
@@ -42,6 +58,7 @@ def register(
if ctx is None:
return jsonify({"error": "auth_context_missing"}), 500
payload = request.get_json(force=True, silent=True) or {}
context_label = _context_hint(ctx)
now_ts = int(time.time())
updates: Dict[str, Optional[str]] = {"last_seen": now_ts}
@@ -111,12 +128,13 @@ def register(
"server",
"heartbeat hostname collision ignored for guid="
f"{ctx.guid}",
context_label,
)
else:
raise
if rowcount == 0:
log("server", f"heartbeat missing device record guid={ctx.guid}")
log("server", f"heartbeat missing device record guid={ctx.guid}", context_label)
return jsonify({"error": "device_not_registered"}), 404
conn.commit()
finally:

View File

@@ -1,7 +1,11 @@
from __future__ import annotations
import functools
import sqlite3
import time
from contextlib import closing
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional
import jwt
@@ -10,6 +14,17 @@ from flask import g, jsonify, request
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
from Modules.auth.rate_limit import SlidingWindowRateLimiter
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
if not value:
return None
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
if not cleaned:
return None
return cleaned.upper()
@dataclass
class DeviceAuthContext:
@@ -20,6 +35,7 @@ class DeviceAuthContext:
claims: Dict[str, Any]
dpop_jkt: Optional[str]
status: str
service_mode: Optional[str]
class DeviceAuthError(Exception):
@@ -47,7 +63,7 @@ class DeviceAuthManager:
db_conn_factory: Callable[[], Any],
jwt_service,
dpop_validator: Optional[DPoPValidator],
log: Callable[[str, str], None],
log: Callable[[str, str, Optional[str]], None],
rate_limiter: Optional[SlidingWindowRateLimiter] = None,
) -> None:
self._db_conn_factory = db_conn_factory
@@ -86,8 +102,9 @@ class DeviceAuthManager:
retry_after=decision.retry_after,
)
conn = self._db_conn_factory()
try:
context_label = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
with closing(self._db_conn_factory()) as conn:
cur = conn.cursor()
cur.execute(
"""
@@ -98,8 +115,11 @@ class DeviceAuthManager:
(guid,),
)
row = cur.fetchone()
finally:
conn.close()
if not row:
row = self._recover_device_record(
conn, guid, fingerprint, token_version, context_label
)
if not row:
raise DeviceAuthError("device_not_found", status_code=403)
@@ -121,7 +141,11 @@ class DeviceAuthManager:
if status_normalized not in allowed_statuses:
raise DeviceAuthError("device_revoked", status_code=403)
if status_normalized == "quarantined":
self._log("server", f"device {guid} is quarantined; limited access for {request.path}")
self._log(
"server",
f"device {guid} is quarantined; limited access for {request.path}",
context_label,
)
dpop_jkt: Optional[str] = None
dpop_proof = request.headers.get("DPoP")
@@ -144,9 +168,111 @@ class DeviceAuthManager:
claims=claims,
dpop_jkt=dpop_jkt,
status=status_normalized,
service_mode=context_label,
)
return ctx
def _recover_device_record(
self,
conn: sqlite3.Connection,
guid: str,
fingerprint: str,
token_version: int,
context_label: Optional[str],
) -> Optional[tuple]:
"""Attempt to recreate a missing device row for an authenticated token."""
guid = (guid or "").strip()
fingerprint = (fingerprint or "").strip()
if not guid or not fingerprint:
return None
cur = conn.cursor()
now_ts = int(time.time())
try:
now_iso = datetime.now(tz=timezone.utc).isoformat()
except Exception:
now_iso = datetime.utcnow().isoformat() # pragma: no cover
base_hostname = f"RECOVERED-{guid[:12].upper()}" if guid else "RECOVERED"
for attempt in range(6):
hostname = base_hostname if attempt == 0 else f"{base_hostname}-{attempt}"
try:
cur.execute(
"""
INSERT INTO devices (
guid,
hostname,
created_at,
last_seen,
ssl_key_fingerprint,
token_version,
status,
key_added_at
)
VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
""",
(
guid,
hostname,
now_ts,
now_ts,
fingerprint,
max(token_version or 1, 1),
now_iso,
),
)
except sqlite3.IntegrityError as exc:
# Hostname collision try again with a suffixed placeholder.
message = str(exc).lower()
if "hostname" in message and "unique" in message:
continue
self._log(
"server",
f"device auth failed to recover guid={guid} due to integrity error: {exc}",
context_label,
)
conn.rollback()
return None
except Exception as exc: # pragma: no cover - defensive logging
self._log(
"server",
f"device auth unexpected error recovering guid={guid}: {exc}",
context_label,
)
conn.rollback()
return None
else:
conn.commit()
break
else:
# Exhausted attempts because of hostname collisions.
self._log(
"server",
f"device auth could not recover guid={guid}; hostname collisions persisted",
context_label,
)
conn.rollback()
return None
cur.execute(
"""
SELECT guid, ssl_key_fingerprint, token_version, status
FROM devices
WHERE guid = ?
""",
(guid,),
)
row = cur.fetchone()
if not row:
self._log(
"server",
f"device auth recovery for guid={guid} committed but row still missing",
context_label,
)
return row
def require_device_auth(manager: DeviceAuthManager):
def decorator(func):

View File

@@ -152,7 +152,10 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
expires_at TEXT NOT NULL,
created_by_user_id TEXT,
used_at TEXT,
used_by_guid TEXT
used_by_guid TEXT,
max_uses INTEGER NOT NULL DEFAULT 1,
use_count INTEGER NOT NULL DEFAULT 0,
last_used_at TEXT
)
"""
)
@@ -163,6 +166,29 @@ def _ensure_install_code_table(conn: sqlite3.Connection) -> None:
"""
)
columns = {row[1] for row in _table_info(cur, "enrollment_install_codes")}
if "max_uses" not in columns:
cur.execute(
"""
ALTER TABLE enrollment_install_codes
ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 1
"""
)
if "use_count" not in columns:
cur.execute(
"""
ALTER TABLE enrollment_install_codes
ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0
"""
)
if "last_used_at" not in columns:
cur.execute(
"""
ALTER TABLE enrollment_install_codes
ADD COLUMN last_used_at TEXT
"""
)
def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor()

View File

@@ -6,7 +6,18 @@ import sqlite3
import uuid
from datetime import datetime, timezone, timedelta
import time
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Tuple
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_context(value: Optional[str]) -> Optional[str]:
if not value:
return None
cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in ("_", "-"))
if not cleaned:
return None
return cleaned.upper()
from flask import Blueprint, jsonify, request
@@ -20,7 +31,7 @@ def register(
app,
*,
db_conn_factory: Callable[[], sqlite3.Connection],
log: Callable[[str, str], None],
log: Callable[[str, str, Optional[str]], None],
jwt_service,
tls_bundle_path: str,
ip_rate_limiter: SlidingWindowRateLimiter,
@@ -51,12 +62,19 @@ def register(
except Exception:
return ""
def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
def _rate_limited(
key: str,
limiter: SlidingWindowRateLimiter,
limit: int,
window_s: float,
context_hint: Optional[str],
):
decision = limiter.check(key, limit, window_s)
if not decision.allowed:
log(
"server",
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
context_hint,
)
response = jsonify({"error": "rate_limited", "retry_after": decision.retry_after})
response.status_code = 429
@@ -66,31 +84,79 @@ def register(
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
cur.execute(
"SELECT id, code, expires_at, used_at FROM enrollment_install_codes WHERE code = ?",
"""
SELECT id,
code,
expires_at,
used_at,
used_by_guid,
max_uses,
use_count,
last_used_at
FROM enrollment_install_codes
WHERE code = ?
""",
(code_value,),
)
row = cur.fetchone()
if not row:
return None
keys = ["id", "code", "expires_at", "used_at"]
keys = [
"id",
"code",
"expires_at",
"used_at",
"used_by_guid",
"max_uses",
"use_count",
"last_used_at",
]
record = dict(zip(keys, row))
return record
def _install_code_valid(record: Dict[str, Any]) -> bool:
def _install_code_valid(
record: Dict[str, Any], fingerprint: str, cur: sqlite3.Cursor
) -> Tuple[bool, Optional[str]]:
if not record:
return False
return False, None
expires_at = record.get("expires_at")
if not isinstance(expires_at, str):
return False
return False, None
try:
expiry = datetime.fromisoformat(expires_at)
except Exception:
return False
return False, None
if expiry <= _now():
return False
if record.get("used_at"):
return False
return True
return False, None
try:
max_uses = int(record.get("max_uses") or 1)
except Exception:
max_uses = 1
if max_uses < 1:
max_uses = 1
try:
use_count = int(record.get("use_count") or 0)
except Exception:
use_count = 0
if use_count < max_uses:
return True, None
guid = str(record.get("used_by_guid") or "").strip()
if not guid:
return False, None
cur.execute(
"SELECT ssl_key_fingerprint FROM devices WHERE guid = ?",
(guid,),
)
row = cur.fetchone()
if not row:
return False, None
stored_fp = (row[0] or "").strip().lower()
if not stored_fp:
return False, None
if stored_fp == (fingerprint or "").strip().lower():
return True, guid
return False, None
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
base = (hostname or "").strip() or guid
@@ -247,7 +313,9 @@ def register(
@blueprint.route("/api/agent/enroll/request", methods=["POST"])
def enrollment_request():
remote = _remote_addr()
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0)
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
rate_error = _rate_limited(f"ip:{remote}", ip_rate_limiter, 40, 60.0, context_hint)
if rate_error:
return rate_error
@@ -262,42 +330,43 @@ def register(
"enrollment request received "
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
context_hint,
)
if not hostname:
log("server", f"enrollment rejected missing_hostname ip={remote}")
log("server", f"enrollment rejected missing_hostname ip={remote}", context_hint)
return jsonify({"error": "hostname_required"}), 400
if not enrollment_code:
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}")
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
return jsonify({"error": "enrollment_code_required"}), 400
if not isinstance(agent_pubkey_b64, str):
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}")
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "agent_pubkey_required"}), 400
if not isinstance(client_nonce_b64, str):
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}")
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "client_nonce_required"}), 400
try:
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
except Exception:
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}")
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_agent_pubkey"}), 400
if len(agent_pubkey_der) < 10:
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}")
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_agent_pubkey"}), 400
try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}")
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
if len(client_nonce_bytes) < 16:
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}")
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0)
rate_error = _rate_limited(f"fp:{fingerprint}", fp_rate_limiter, 12, 60.0, context_hint)
if rate_error:
return rate_error
@@ -305,7 +374,14 @@ def register(
try:
cur = conn.cursor()
install_code = _load_install_code(cur, enrollment_code)
if not _install_code_valid(install_code):
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
if not valid_code:
log(
"server",
"enrollment request invalid_code "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
)
return jsonify({"error": "invalid_enrollment_code"}), 400
approval_reference: str
@@ -331,6 +407,7 @@ def register(
"""
UPDATE device_approvals
SET hostname_claimed = ?,
guid = ?,
enrollment_code_id = ?,
client_nonce = ?,
server_nonce = ?,
@@ -340,6 +417,7 @@ def register(
""",
(
hostname,
reuse_guid,
install_code["id"],
client_nonce_b64,
server_nonce_b64,
@@ -359,11 +437,12 @@ def register(
status, client_nonce, server_nonce, agent_pubkey_der,
created_at, updated_at
)
VALUES (?, ?, NULL, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?, ?, ?)
""",
(
record_id,
approval_reference,
reuse_guid,
hostname,
fingerprint,
install_code["id"],
@@ -387,7 +466,11 @@ def register(
"server_certificate": _load_tls_bundle(tls_bundle_path),
"signing_key": _signing_key_b64(),
}
log("server", f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}")
log(
"server",
f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}",
context_hint,
)
return jsonify(response)
@blueprint.route("/api/agent/enroll/poll", methods=["POST"])
@@ -396,34 +479,36 @@ def register(
approval_reference = payload.get("approval_reference")
client_nonce_b64 = payload.get("client_nonce")
proof_sig_b64 = payload.get("proof_sig")
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
log(
"server",
"enrollment poll received "
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
f" proof_sig_len={len(proof_sig_b64 or '')}",
context_hint,
)
if not isinstance(approval_reference, str) or not approval_reference:
log("server", "enrollment poll rejected missing_reference")
log("server", "enrollment poll rejected missing_reference", context_hint)
return jsonify({"error": "approval_reference_required"}), 400
if not isinstance(client_nonce_b64, str):
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}")
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "client_nonce_required"}), 400
if not isinstance(proof_sig_b64, str):
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}")
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
return jsonify({"error": "proof_sig_required"}), 400
try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}")
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
try:
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_sig ref={approval_reference}")
log("server", f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_proof_sig"}), 400
conn = db_conn_factory()
@@ -441,7 +526,7 @@ def register(
)
row = cur.fetchone()
if not row:
log("server", f"enrollment poll unknown_reference ref={approval_reference}")
log("server", f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
return jsonify({"status": "unknown"}), 404
(
@@ -460,13 +545,13 @@ def register(
) = row
if client_nonce_stored != client_nonce_b64:
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}")
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
return jsonify({"error": "nonce_mismatch"}), 400
try:
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}")
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "server_nonce_invalid"}), 400
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
@@ -474,17 +559,17 @@ def register(
try:
public_key = serialization.load_der_public_key(agent_pubkey_der)
except Exception:
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}")
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
public_key = None
if public_key is None:
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}")
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
return jsonify({"error": "agent_pubkey_invalid"}), 400
try:
public_key.verify(proof_sig, message)
except Exception:
log("server", f"enrollment poll invalid_proof ref={approval_reference}")
log("server", f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_proof"}), 400
if status == "pending":
@@ -492,24 +577,28 @@ def register(
"server",
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
f" fingerprint={fingerprint[:12]}",
context_hint,
)
return jsonify({"status": "pending", "poll_after_ms": 5000})
if status == "denied":
log(
"server",
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "denied", "reason": "operator_denied"})
if status == "expired":
log(
"server",
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "expired"})
if status == "completed":
log(
"server",
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "approved", "detail": "finalized"})
@@ -517,6 +606,7 @@ def register(
log(
"server",
f"enrollment poll unexpected_status={status} ref={approval_reference}",
context_hint,
)
return jsonify({"status": status or "unknown"}), 400
@@ -525,6 +615,7 @@ def register(
log(
"server",
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
context_hint,
)
return jsonify({"error": "proof_replayed"}), 409
@@ -537,14 +628,40 @@ def register(
# Mark install code used
if enrollment_code_id:
cur.execute(
"SELECT use_count, max_uses FROM enrollment_install_codes WHERE id = ?",
(enrollment_code_id,),
)
usage_row = cur.fetchone()
try:
prior_count = int(usage_row[0]) if usage_row else 0
except Exception:
prior_count = 0
try:
allowed_uses = int(usage_row[1]) if usage_row else 1
except Exception:
allowed_uses = 1
if allowed_uses < 1:
allowed_uses = 1
new_count = prior_count + 1
consumed = new_count >= allowed_uses
cur.execute(
"""
UPDATE enrollment_install_codes
SET used_at = ?, used_by_guid = ?
SET use_count = ?,
used_by_guid = ?,
last_used_at = ?,
used_at = CASE WHEN ? THEN ? ELSE used_at END
WHERE id = ?
AND used_at IS NULL
""",
(now_iso, effective_guid, enrollment_code_id),
(
new_count,
effective_guid,
now_iso,
1 if consumed else 0,
now_iso,
enrollment_code_id,
),
)
# Update approval record with final state
@@ -573,6 +690,7 @@ def register(
log(
"server",
f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
context_hint,
)
return jsonify(
{

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Callable
from typing import Callable, Optional
import eventlet
from flask_socketio import SocketIO
@@ -11,7 +11,7 @@ def start_prune_job(
socketio: SocketIO,
*,
db_conn_factory: Callable[[], any],
log: Callable[[str, str], None],
log: Callable[[str, str, Optional[str]], None],
) -> None:
def _job_loop():
while True:
@@ -24,7 +24,7 @@ def start_prune_job(
socketio.start_background_task(_job_loop)
def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None]) -> None:
def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str, Optional[str]], None]) -> None:
now = datetime.now(tz=timezone.utc)
now_iso = now.isoformat()
stale_before = (now - timedelta(hours=24)).isoformat()
@@ -34,7 +34,7 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
cur.execute(
"""
DELETE FROM enrollment_install_codes
WHERE used_at IS NULL
WHERE use_count = 0
AND expires_at < ?
""",
(now_iso,),
@@ -52,7 +52,10 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
SELECT 1
FROM enrollment_install_codes c
WHERE c.id = device_approvals.enrollment_code_id
AND c.expires_at < ?
AND (
c.expires_at < ?
OR c.use_count >= c.max_uses
)
)
OR created_at < ?
)

View File

@@ -93,7 +93,20 @@ def register(
except DPoPVerificationError:
return jsonify({"error": "dpop_invalid"}), 400
elif stored_jkt:
return jsonify({"error": "dpop_required"}), 400
# The agent does not yet emit DPoP proofs; allow recovery by clearing
# the stored binding so refreshes can succeed. This preserves
# backward compatibility while the client gains full DPoP support.
try:
app.logger.warning(
"Clearing stored DPoP binding for guid=%s due to missing proof",
guid,
)
except Exception:
pass
cur.execute(
"UPDATE refresh_tokens SET dpop_jkt = NULL WHERE id = ?",
(record_id,),
)
new_access_token = jwt_service.issue_access_token(
guid,

View File

@@ -65,7 +65,9 @@ const formatDateTime = (value) => {
const determineStatus = (record) => {
if (!record) return "expired";
if (record.used_at) return "used";
const maxUses = Number.isFinite(record?.max_uses) ? record.max_uses : 1;
const useCount = Number.isFinite(record?.use_count) ? record.use_count : 0;
if (useCount >= Math.max(1, maxUses || 1)) return "used";
if (!record.expires_at) return "expired";
const expires = new Date(record.expires_at);
if (Number.isNaN(expires.getTime())) return "expired";
@@ -80,6 +82,7 @@ function EnrollmentCodes() {
const [statusFilter, setStatusFilter] = useState("all");
const [ttlHours, setTtlHours] = useState(6);
const [generating, setGenerating] = useState(false);
const [maxUses, setMaxUses] = useState(2);
const filteredCodes = useMemo(() => {
if (statusFilter === "all") return codes;
@@ -119,7 +122,7 @@ function EnrollmentCodes() {
method: "POST",
credentials: "include",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ ttl_hours: ttlHours }),
body: JSON.stringify({ ttl_hours: ttlHours, max_uses: maxUses }),
});
if (!resp.ok) {
const body = await resp.json().catch(() => ({}));
@@ -133,7 +136,7 @@ function EnrollmentCodes() {
} finally {
setGenerating(false);
}
}, [fetchCodes, ttlHours]);
}, [fetchCodes, ttlHours, maxUses]);
const handleDelete = useCallback(
async (id) => {
@@ -216,7 +219,7 @@ function EnrollmentCodes() {
labelId="ttl-select-label"
label="Duration"
value={ttlHours}
onChange={(event) => setTtlHours(event.target.value)}
onChange={(event) => setTtlHours(Number(event.target.value))}
>
{TTL_PRESETS.map((preset) => (
<MenuItem key={preset.value} value={preset.value}>
@@ -226,6 +229,22 @@ function EnrollmentCodes() {
</Select>
</FormControl>
<FormControl size="small" sx={{ minWidth: 160 }}>
<InputLabel id="uses-select-label">Allowed Uses</InputLabel>
<Select
labelId="uses-select-label"
label="Allowed Uses"
value={maxUses}
onChange={(event) => setMaxUses(Number(event.target.value))}
>
{[1, 2, 3, 5].map((uses) => (
<MenuItem key={uses} value={uses}>
{uses === 1 ? "Single use" : `${uses} uses`}
</MenuItem>
))}
</Select>
</FormControl>
<Button
variant="contained"
color="primary"
@@ -270,7 +289,9 @@ function EnrollmentCodes() {
<TableCell>Installer Code</TableCell>
<TableCell>Expires At</TableCell>
<TableCell>Created By</TableCell>
<TableCell>Used At</TableCell>
<TableCell>Usage</TableCell>
<TableCell>Last Used</TableCell>
<TableCell>Consumed At</TableCell>
<TableCell>Used By GUID</TableCell>
<TableCell align="right">Actions</TableCell>
</TableRow>
@@ -296,13 +317,17 @@ function EnrollmentCodes() {
) : (
filteredCodes.map((record) => {
const status = determineStatus(record);
const disableDelete = status !== "active";
const maxAllowed = Math.max(1, Number.isFinite(record?.max_uses) ? record.max_uses : 1);
const usageCount = Math.max(0, Number.isFinite(record?.use_count) ? record.use_count : 0);
const disableDelete = usageCount !== 0;
return (
<TableRow hover key={record.id}>
<TableCell>{renderStatusChip(record)}</TableCell>
<TableCell sx={{ fontFamily: "monospace" }}>{maskCode(record.code)}</TableCell>
<TableCell>{formatDateTime(record.expires_at)}</TableCell>
<TableCell>{record.created_by_user_id || "—"}</TableCell>
<TableCell sx={{ fontFamily: "monospace" }}>{`${usageCount} / ${maxAllowed}`}</TableCell>
<TableCell>{formatDateTime(record.last_used_at)}</TableCell>
<TableCell>{formatDateTime(record.used_at)}</TableCell>
<TableCell sx={{ fontFamily: "monospace" }}>
{record.used_by_guid || "—"}

View File

@@ -152,19 +152,98 @@ def _rotate_daily(path: str):
pass
def _write_service_log(service: str, msg: str):
_SERVER_SCOPE_PATTERN = re.compile(r"\\b(?:scope|context|agent_context)=([A-Za-z0-9_-]+)", re.IGNORECASE)
_SERVER_AGENT_ID_PATTERN = re.compile(r"\\bagent_id=([^\s,]+)", re.IGNORECASE)
_AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
def _canonical_server_scope(raw: Optional[str]) -> Optional[str]:
if not raw:
return None
value = "".join(ch for ch in str(raw) if ch.isalnum() or ch in ("_", "-"))
if not value:
return None
return value.upper()
def _scope_from_agent_id(agent_id: Optional[str]) -> Optional[str]:
candidate = _canonical_server_scope(agent_id)
if not candidate:
return None
if candidate.endswith("_SYSTEM"):
return "SYSTEM"
if candidate.endswith("_CURRENTUSER"):
return "CURRENTUSER"
return candidate
def _infer_server_scope(message: str, explicit: Optional[str]) -> Optional[str]:
scope = _canonical_server_scope(explicit)
if scope:
return scope
match = _SERVER_SCOPE_PATTERN.search(message or "")
if match:
scope = _canonical_server_scope(match.group(1))
if scope:
return scope
agent_match = _SERVER_AGENT_ID_PATTERN.search(message or "")
if agent_match:
scope = _scope_from_agent_id(agent_match.group(1))
if scope:
return scope
return None
def _write_service_log(service: str, msg: str, scope: Optional[str] = None, *, level: str = "INFO"):
try:
base = _server_logs_root()
os.makedirs(base, exist_ok=True)
path = os.path.join(base, f"{service}.log")
_rotate_daily(path)
ts = time.strftime('%Y-%m-%d %H:%M:%S')
resolved_scope = _infer_server_scope(msg, scope)
prefix_parts = [f"[{level.upper()}]"]
if resolved_scope:
prefix_parts.append(f"[CONTEXT-{resolved_scope}]")
prefix = "".join(prefix_parts)
with open(path, 'a', encoding='utf-8') as fh:
fh.write(f'[{ts}] {msg}\n')
fh.write(f'[{ts}] {prefix} {msg}\n')
except Exception:
pass
def _mask_server_value(value: str, *, prefix: int = 4, suffix: int = 4) -> str:
try:
if not value:
return ''
stripped = value.strip()
if len(stripped) <= prefix + suffix:
return '*' * len(stripped)
return f"{stripped[:prefix]}***{stripped[-suffix:]}"
except Exception:
return '***'
def _summarize_socket_headers(headers) -> str:
try:
rendered = []
for key, value in headers.items():
lowered = key.lower()
display = value
if lowered == 'authorization':
if isinstance(value, str) and value.lower().startswith('bearer '):
token = value.split(' ', 1)[1]
display = f"Bearer {_mask_server_value(token)}"
else:
display = _mask_server_value(str(value))
elif lowered == 'cookie':
display = '<redacted>'
rendered.append(f"{key}={display}")
return ", ".join(rendered)
except Exception:
return '<header-summary-unavailable>'
# =============================================================================
# Section: Repository Hash Tracking
# =============================================================================
@@ -7968,6 +8047,55 @@ def screenshot_node_viewer(agent_id, node_id):
# =============================================================================
# Realtime channels for screenshots, macros, windows, and Ansible control.
@socketio.on('connect')
def socket_connect():
try:
sid = getattr(request, 'sid', '<unknown>')
except Exception:
sid = '<unknown>'
try:
remote_addr = request.remote_addr
except Exception:
remote_addr = None
try:
scope = _canonical_server_scope(request.headers.get(_AGENT_CONTEXT_HEADER))
except Exception:
scope = None
try:
query_pairs = [f"{k}={v}" for k, v in request.args.items()] # type: ignore[attr-defined]
query_summary = "&".join(query_pairs) if query_pairs else "<none>"
except Exception:
query_summary = "<unavailable>"
header_summary = _summarize_socket_headers(getattr(request, 'headers', {}))
transport = request.args.get('transport') if hasattr(request, 'args') else None # type: ignore[attr-defined]
_write_service_log(
'server',
f"socket.io connect sid={sid} ip={remote_addr} transport={transport!r} query={query_summary} headers={header_summary}",
scope=scope,
)
@socketio.on('disconnect')
def socket_disconnect():
try:
sid = getattr(request, 'sid', '<unknown>')
except Exception:
sid = '<unknown>'
try:
remote_addr = request.remote_addr
except Exception:
remote_addr = None
try:
scope = _canonical_server_scope(request.headers.get(_AGENT_CONTEXT_HEADER))
except Exception:
scope = None
_write_service_log(
'server',
f"socket.io disconnect sid={sid} ip={remote_addr}",
scope=scope,
)
@socketio.on("agent_screenshot_task")
def receive_screenshot_task(data):
agent_id = data.get("agent_id")
@@ -7997,6 +8125,19 @@ def connect_agent(data):
if not agent_id:
return
print(f"Agent connected: {agent_id}")
try:
scope = _normalize_service_mode((data or {}).get("service_mode"), agent_id)
except Exception:
scope = None
try:
sid = getattr(request, 'sid', '<unknown>')
except Exception:
sid = '<unknown>'
_write_service_log(
'server',
f"socket.io connect_agent agent_id={agent_id} sid={sid} service_mode={scope}",
scope=scope,
)
# Join per-agent room so we can address this connection specifically
try:
@@ -8004,7 +8145,7 @@ def connect_agent(data):
except Exception:
pass
service_mode = _normalize_service_mode((data or {}).get("service_mode"), agent_id)
service_mode = scope if scope else _normalize_service_mode((data or {}).get("service_mode"), agent_id)
rec = registered_agents.setdefault(agent_id, {})
rec["agent_id"] = agent_id
rec["hostname"] = rec.get("hostname", "unknown")

View File

@@ -0,0 +1,58 @@
import json
import sys
import pytest
@pytest.fixture
def agent_module(tmp_path, monkeypatch):
settings_dir = tmp_path / "Agent" / "Borealis" / "Settings"
settings_dir.mkdir(parents=True)
system_config = settings_dir / "agent_settings_SYSTEM.json"
system_config.write_text(json.dumps({
"config_file_watcher_interval": 2,
"agent_id": "",
"regions": {},
"installer_code": "",
}, indent=2))
current_config = settings_dir / "agent_settings_CURRENTUSER.json"
current_config.write_text(json.dumps({
"config_file_watcher_interval": 2,
"agent_id": "",
"regions": {},
"installer_code": "",
}, indent=2))
monkeypatch.setenv("BOREALIS_ROOT", str(tmp_path))
monkeypatch.setenv("BOREALIS_AGENT_MODE", "system")
monkeypatch.setenv("BOREALIS_AGENT_CONFIG", "")
monkeypatch.setitem(sys.modules, "PyQt5", None)
monkeypatch.setitem(sys.modules, "qasync", None)
monkeypatch.setattr(sys, "argv", ["agent.py", "--system-service", "--config", "SYSTEM"], raising=False)
agent = pytest.importorskip(
"Data.Agent.agent", reason="agent module requires optional dependencies"
)
return agent, system_config
def test_shared_installer_code_cache_allows_system_reuse(agent_module, tmp_path):
agent, system_config = agent_module
client = agent.AgentHttpClient()
shared_code = "SHARED-CODE-1234"
client.key_store.cache_installer_code(shared_code, consumer="CURRENTUSER")
# System agent should discover the cached code even though its config is empty.
resolved = client._resolve_installer_code()
assert resolved == shared_code
# Config should now persist the adopted code to avoid repeated lookups.
data = json.loads(system_config.read_text())
assert data.get("installer_code") == shared_code
# After enrollment completes, the cache should be cleared for future runs.
client._consume_installer_code()
assert client.key_store.load_cached_installer_code() is None
data = json.loads(system_config.read_text())
assert data.get("installer_code") == ""

View File

@@ -0,0 +1,213 @@
import base64
import os
import pathlib
import sqlite3
import sys
import uuid
import pytest
try: # pragma: no cover - optional dependency
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
_CRYPTO_IMPORT_ERROR: Exception | None = None
except Exception as exc: # pragma: no cover - dependency unavailable
serialization = None # type: ignore
ed25519 = None # type: ignore
_CRYPTO_IMPORT_ERROR = exc
try: # pragma: no cover - optional dependency
from flask import Flask
_FLASK_IMPORT_ERROR: Exception | None = None
except Exception as exc: # pragma: no cover - dependency unavailable
Flask = None # type: ignore
_FLASK_IMPORT_ERROR = exc
ROOT = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from Data.Server.Modules import db_migrations
from Data.Server.Modules.auth.rate_limit import SlidingWindowRateLimiter
from Data.Server.Modules.enrollment.nonce_store import NonceCache
if Flask is not None: # pragma: no cover - dependency unavailable
from Data.Server.Modules.enrollment import routes as enrollment_routes
else: # pragma: no cover - dependency unavailable
enrollment_routes = None # type: ignore
class _DummyJWTService:
def issue_access_token(self, guid: str, fingerprint: str, token_version: int, expires_in: int = 900, extra_claims=None):
return f"token-{guid}"
class _DummySigner:
def public_base64_spki(self) -> str:
return ""
def _make_app(db_path: str, tls_path: str):
if Flask is None or enrollment_routes is None: # pragma: no cover - dependency unavailable
pytest.skip(f"flask unavailable: {_FLASK_IMPORT_ERROR}")
app = Flask(__name__)
def _factory():
conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row
return conn
enrollment_routes.register(
app,
db_conn_factory=_factory,
log=lambda channel, message: None,
jwt_service=_DummyJWTService(),
tls_bundle_path=tls_path,
ip_rate_limiter=SlidingWindowRateLimiter(),
fp_rate_limiter=SlidingWindowRateLimiter(),
nonce_cache=NonceCache(ttl_seconds=30.0),
script_signer=_DummySigner(),
)
return app, _factory
def _create_install_code(conn: sqlite3.Connection, code: str, *, max_uses: int = 2):
cur = conn.cursor()
record_id = str(uuid.uuid4())
cur.execute(
"""
INSERT INTO enrollment_install_codes (
id, code, expires_at, created_by_user_id, max_uses, use_count
)
VALUES (?, ?, datetime('now', '+6 hours'), 'test-user', ?, 0)
""",
(record_id, code, max_uses),
)
conn.commit()
return record_id
def _perform_enrollment_cycle(app, factory, code: str, private_key):
client = app.test_client()
public_der = private_key.public_key().public_bytes(
serialization.Encoding.DER,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
public_b64 = base64.b64encode(public_der).decode("ascii")
client_nonce = os.urandom(32)
payload = {
"hostname": "unit-test-host",
"enrollment_code": code,
"agent_pubkey": public_b64,
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
}
request_resp = client.post("/api/agent/enroll/request", json=payload)
assert request_resp.status_code == 200
request_data = request_resp.get_json()
approval_reference = request_data["approval_reference"]
with factory() as conn:
cur = conn.cursor()
cur.execute(
"""
UPDATE device_approvals
SET status = 'approved',
approved_by_user_id = 'tester'
WHERE approval_reference = ?
""",
(approval_reference,),
)
cur.execute(
"""
SELECT server_nonce, client_nonce
FROM device_approvals
WHERE approval_reference = ?
""",
(approval_reference,),
)
row = cur.fetchone()
assert row is not None
server_nonce_b64 = row["server_nonce"]
server_nonce = base64.b64decode(server_nonce_b64)
proof_message = server_nonce + approval_reference.encode("utf-8") + client_nonce
proof_sig = private_key.sign(proof_message)
poll_payload = {
"approval_reference": approval_reference,
"client_nonce": base64.b64encode(client_nonce).decode("ascii"),
"proof_sig": base64.b64encode(proof_sig).decode("ascii"),
}
poll_resp = client.post("/api/agent/enroll/poll", json=poll_payload)
assert poll_resp.status_code == 200
return poll_resp.get_json()
@pytest.mark.parametrize("max_uses", [2])
@pytest.mark.skipif(ed25519 is None, reason=f"cryptography unavailable: {_CRYPTO_IMPORT_ERROR}")
@pytest.mark.skipif(Flask is None, reason=f"flask unavailable: {_FLASK_IMPORT_ERROR}")
def test_install_code_allows_multiple_and_reuse(tmp_path, max_uses):
db_path = tmp_path / "test.db"
conn = sqlite3.connect(db_path)
db_migrations.apply_all(conn)
_create_install_code(conn, "TEST-CODE-1234", max_uses=max_uses)
conn.close()
tls_path = tmp_path / "tls.pem"
tls_path.write_text("TEST CERT")
app, factory = _make_app(str(db_path), str(tls_path))
private_key = ed25519.Ed25519PrivateKey.generate()
# First enrollment consumes one use but keeps the code active.
first = _perform_enrollment_cycle(app, factory, "TEST-CODE-1234", private_key)
assert first["status"] == "approved"
with factory() as conn:
cur = conn.cursor()
cur.execute(
"SELECT use_count, max_uses, used_at, last_used_at FROM enrollment_install_codes WHERE code = ?",
("TEST-CODE-1234",),
)
row = cur.fetchone()
assert row is not None
assert row["use_count"] == 1
assert row["max_uses"] == max_uses
assert row["used_at"] is None
assert row["last_used_at"] is not None
# Second enrollment hits the configured max uses.
second = _perform_enrollment_cycle(app, factory, "TEST-CODE-1234", private_key)
assert second["status"] == "approved"
with factory() as conn:
cur = conn.cursor()
cur.execute(
"SELECT use_count, used_at, last_used_at, used_by_guid FROM enrollment_install_codes WHERE code = ?",
("TEST-CODE-1234",),
)
row = cur.fetchone()
assert row is not None
assert row["use_count"] == max_uses
assert row["used_at"] is not None
assert row["last_used_at"] is not None
consumed_guid = row["used_by_guid"]
assert consumed_guid
# Additional enrollments from the same identity reuse the stored GUID even after consumption.
third = _perform_enrollment_cycle(app, factory, "TEST-CODE-1234", private_key)
assert third["status"] == "approved"
with factory() as conn:
cur = conn.cursor()
cur.execute(
"SELECT use_count, used_at, last_used_at, used_by_guid FROM enrollment_install_codes WHERE code = ?",
("TEST-CODE-1234",),
)
row = cur.fetchone()
assert row is not None
assert row["use_count"] == max_uses + 1
assert row["used_by_guid"] == consumed_guid
assert row["used_at"] is not None
assert row["last_used_at"] is not None
cur.execute("SELECT COUNT(*) FROM devices WHERE guid = ?", (consumed_guid,))
assert cur.fetchone()[0] == 1