From 174cea55494735b03fd3b98d9527687e001af6b9 Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Fri, 17 Oct 2025 18:47:15 -0600 Subject: [PATCH] additional changes. --- AGENTS.md | 13 +- Data/Agent/agent.py | 477 ++++++++++++++++++++++++++++------- Data/Agent/security.py | 50 ++++ tests/test_agent_security.py | 57 +++++ 4 files changed, 507 insertions(+), 90 deletions(-) create mode 100644 tests/test_agent_security.py diff --git a/AGENTS.md b/AGENTS.md index e406e78..fcbf654 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -48,7 +48,17 @@ Today the stable core focuses on workflow-driven API and automation scenarios. R ## Agent Responsibilities ### Communication Channels -Agents establish REST calls to the Flask backend on port 5000 and keep a WebSocket session for interactive features such as screenshot capture. Future plans include WebRTC for higher-performance remote desktop. No authentication or enrollment handshake exists yet, so agents are implicitly trusted once launched. This will be secured in future updates to Borealis. +Agents establish TLS-secured REST calls to the Flask backend on port 5000 and keep an authenticated WebSocket session for interactive features such as screenshot capture. Future plans include WebRTC for higher-performance remote desktop. Every agent now performs an enrollment handshake (see **Secure Enrollment & Tokens** below) prior to opening either channel; all API access is bound to short-lived Ed25519-signed JWTs. + +### Secure Enrollment & Tokens +- On first launch the agent generates an Ed25519 identity and stores the private key under `Agent/Borealis/Settings/agent_key.ed25519` (protected with DPAPI on Windows or chmod 600 elsewhere). The public key is retained as SPKI DER and fingerprinted with SHA-256. +- Enrollment starts with an installer code (minted in the Web UI) and proves key possession by signing the server nonce. Upon operator approval the server issues: + - The canonical device GUID (persisted to `guid.txt` alongside the key material). + - A short-lived access token (EdDSA/JWT) and a long-lived refresh token (stored encrypted via DPAPI and hashed server-side). + - The server TLS certificate and script-signing public key so the agent can pin both for future sessions. +- Access tokens are automatically refreshed before expiry. Refresh failures trigger a re-enrollment. +- All REST calls (heartbeat, script polling, device details, service check-in) use these tokens; WebSocket connections include the `Authorization` header as well. +- Specify the installer code via `--installer-code `, `BOREALIS_INSTALLER_CODE`, or by adding `"installer_code": ""` to `Agent/Borealis/Settings/agent_settings.json`. ### Execution Contexts The agent runs in the interactive user session. SYSTEM-level script execution is provided by the ScriptExec SYSTEM role using ephemeral scheduled tasks; no separate supervisor or watchdog is required. @@ -195,4 +205,3 @@ This section summarizes what is considered usable vs. experimental today. - diff --git a/Data/Agent/agent.py b/Data/Agent/agent.py index 619435a..15ce11c 100644 --- a/Data/Agent/agent.py +++ b/Data/Agent/agent.py @@ -19,6 +19,8 @@ import getpass import datetime import shutil import string +import ssl +from typing import Any, Dict, Optional import requests try: @@ -107,6 +109,11 @@ def _canonical_config_suffix(raw_suffix: str) -> str: CONFIG_SUFFIX_CANONICAL = _canonical_config_suffix(CONFIG_NAME_SUFFIX) +INSTALLER_CODE_OVERRIDE = ( + (_argv_get('--installer-code') or os.environ.get('BOREALIS_INSTALLER_CODE') or '') + .strip() +) + def _agent_guid_path() -> str: try: @@ -383,7 +390,8 @@ CONFIG_PATH = _resolve_config_path() DEFAULT_CONFIG = { "config_file_watcher_interval": 2, "agent_id": "", - "regions": {} + "regions": {}, + "installer_code": "" } class ConfigManager: @@ -440,6 +448,277 @@ class ConfigManager: CONFIG = ConfigManager(CONFIG_PATH) CONFIG.load() + +class AgentHttpClient: + def __init__(self): + self.key_store = _key_store() + self.identity = IDENTITY + self.session = requests.Session() + self.base_url: Optional[str] = None + self.guid: Optional[str] = self.key_store.load_guid() + self.access_token: Optional[str] = self.key_store.load_access_token() + self.refresh_token: Optional[str] = self.key_store.load_refresh_token() + self.access_expires_at: Optional[int] = self.key_store.get_access_expiry() + self.refresh_base_url() + self._configure_verify() + if self.access_token: + self.session.headers.update({"Authorization": f"Bearer {self.access_token}"}) + self.session.headers.setdefault("User-Agent", "Borealis-Agent/secure") + + # ------------------------------------------------------------------ + # Session helpers + # ------------------------------------------------------------------ + def refresh_base_url(self) -> None: + try: + url = (get_server_url() or "").strip() + except Exception: + url = "" + if not url: + url = "https://localhost:5000" + if url.endswith("/"): + url = url[:-1] + if url != self.base_url: + self.base_url = url + + def _configure_verify(self) -> None: + cert_path = self.key_store.server_certificate_path() + if cert_path and os.path.isfile(cert_path): + self.session.verify = cert_path + else: + self.session.verify = False + try: + import urllib3 # type: ignore + + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) # type: ignore[attr-defined] + except Exception: + pass + + def auth_headers(self) -> Dict[str, str]: + if self.access_token: + return {"Authorization": f"Bearer {self.access_token}"} + return {} + + def websocket_kwargs(self) -> Dict[str, Any]: + kwargs: Dict[str, Any] = {} + verify = getattr(self.session, "verify", True) + if isinstance(verify, str) and os.path.isfile(verify): + try: + ctx = ssl.create_default_context(cafile=verify) + kwargs["ssl"] = ctx + except Exception: + pass + elif verify is False: + try: + kwargs["ssl"] = ssl._create_unverified_context() + except Exception: + pass + return kwargs + + # ------------------------------------------------------------------ + # Enrollment & token management + # ------------------------------------------------------------------ + def ensure_authenticated(self) -> None: + self.refresh_base_url() + if not self.guid or not self.refresh_token: + self.perform_enrollment() + if not self.access_token or self._token_expiring_soon(): + self.refresh_access_token() + + def _token_expiring_soon(self) -> bool: + if not self.access_token: + return True + if not self.access_expires_at: + return True + return (self.access_expires_at - time.time()) < 60 + + def perform_enrollment(self) -> None: + code = self._resolve_installer_code() + if not code: + raise RuntimeError( + "Installer code is required for enrollment. " + "Set BOREALIS_INSTALLER_CODE, pass --installer-code, or update agent_settings.json." + ) + self.refresh_base_url() + client_nonce = os.urandom(32) + payload = { + "hostname": socket.gethostname(), + "enrollment_code": code, + "agent_pubkey": PUBLIC_KEY_B64, + "client_nonce": base64.b64encode(client_nonce).decode("ascii"), + } + request_url = f"{self.base_url}/api/agent/enroll/request" + _log_agent("Starting enrollment request...", fname="agent.log") + resp = self.session.post(request_url, json=payload, timeout=30) + resp.raise_for_status() + data = resp.json() + if data.get("server_certificate"): + self.key_store.save_server_certificate(data["server_certificate"]) + self._configure_verify() + if data.get("status") != "pending": + raise RuntimeError(f"Unexpected enrollment status: {data}") + approval_reference = data.get("approval_reference") + server_nonce_b64 = data.get("server_nonce") + if not approval_reference or not server_nonce_b64: + raise RuntimeError("Enrollment response missing approval_reference or server_nonce") + server_nonce = base64.b64decode(server_nonce_b64) + poll_delay = max(int(data.get("poll_after_ms", 3000)) / 1000, 1) + while True: + time.sleep(min(poll_delay, 15)) + signature = self.identity.sign(server_nonce + approval_reference.encode("utf-8") + client_nonce) + poll_payload = { + "approval_reference": approval_reference, + "client_nonce": base64.b64encode(client_nonce).decode("ascii"), + "proof_sig": base64.b64encode(signature).decode("ascii"), + } + poll_resp = self.session.post( + f"{self.base_url}/api/agent/enroll/poll", + json=poll_payload, + timeout=30, + ) + poll_resp.raise_for_status() + poll_data = poll_resp.json() + status = poll_data.get("status") + if status == "pending": + poll_delay = max(int(poll_data.get("poll_after_ms", 5000)) / 1000, 1) + continue + if status == "denied": + raise RuntimeError("Enrollment denied by operator") + if status in ("expired", "unknown"): + raise RuntimeError(f"Enrollment failed with status={status}") + if status in ("approved", "completed"): + self._finalize_enrollment(poll_data) + break + raise RuntimeError(f"Unexpected enrollment poll response: {poll_data}") + + def _finalize_enrollment(self, payload: Dict[str, Any]) -> None: + server_cert = payload.get("server_certificate") + if server_cert: + self.key_store.save_server_certificate(server_cert) + self._configure_verify() + guid = payload.get("guid") + access_token = payload.get("access_token") + refresh_token = payload.get("refresh_token") + expires_in = int(payload.get("expires_in") or 900) + if not (guid and access_token and refresh_token): + raise RuntimeError("Enrollment approval response missing tokens or guid") + self.guid = str(guid).strip() + self.access_token = access_token.strip() + self.refresh_token = refresh_token.strip() + expiry = int(time.time()) + max(expires_in - 5, 0) + self.access_expires_at = expiry + self.key_store.save_guid(self.guid) + self.key_store.save_refresh_token(self.refresh_token) + self.key_store.save_access_token(self.access_token, expires_at=expiry) + self.key_store.set_access_binding(SSL_KEY_FINGERPRINT) + self.session.headers.update({"Authorization": f"Bearer {self.access_token}"}) + try: + _update_agent_id_for_guid(self.guid) + except Exception as exc: + _log_agent(f"Failed to update agent id after enrollment: {exc}", fname="agent.error.log") + _log_agent(f"Enrollment finalized for guid={self.guid}", fname="agent.log") + + def refresh_access_token(self) -> None: + if not self.refresh_token or not self.guid: + self.clear_tokens() + self.perform_enrollment() + return + payload = {"guid": self.guid, "refresh_token": self.refresh_token} + resp = self.session.post( + f"{self.base_url}/api/agent/token/refresh", + json=payload, + headers=self.auth_headers(), + timeout=20, + ) + if resp.status_code in (401, 403): + _log_agent("Refresh token rejected; re-enrolling", fname="agent.error.log") + self.clear_tokens() + self.perform_enrollment() + return + resp.raise_for_status() + data = resp.json() + access_token = data.get("access_token") + expires_in = int(data.get("expires_in") or 900) + if not access_token: + raise RuntimeError("Token refresh response missing access_token") + self.access_token = access_token.strip() + expiry = int(time.time()) + max(expires_in - 5, 0) + self.access_expires_at = expiry + self.key_store.save_access_token(self.access_token, expires_at=expiry) + self.key_store.set_access_binding(SSL_KEY_FINGERPRINT) + self.session.headers.update({"Authorization": f"Bearer {self.access_token}"}) + + def clear_tokens(self) -> None: + self.key_store.clear_tokens() + self.access_token = None + self.refresh_token = None + self.access_expires_at = None + self.guid = self.key_store.load_guid() + self.session.headers.pop("Authorization", None) + + def _resolve_installer_code(self) -> str: + if INSTALLER_CODE_OVERRIDE: + return INSTALLER_CODE_OVERRIDE + try: + code = (CONFIG.data.get("installer_code") or "").strip() + return code + except Exception: + return "" + + # ------------------------------------------------------------------ + # HTTP helpers + # ------------------------------------------------------------------ + def post_json(self, path: str, payload: Optional[Dict[str, Any]] = None, *, require_auth: bool = True) -> Any: + if require_auth: + self.ensure_authenticated() + url = f"{self.base_url}{path}" + headers = self.auth_headers() + response = self.session.post(url, json=payload, headers=headers, timeout=30) + if response.status_code in (401, 403) and require_auth: + self.clear_tokens() + self.ensure_authenticated() + headers = self.auth_headers() + response = self.session.post(url, json=payload, headers=headers, timeout=30) + response.raise_for_status() + if response.headers.get("Content-Type", "").lower().startswith("application/json"): + return response.json() + return response.text + + async def async_post_json( + self, + path: str, + payload: Optional[Dict[str, Any]] = None, + *, + require_auth: bool = True, + ) -> Any: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self.post_json, path, payload, require_auth) + + def websocket_base_url(self) -> str: + self.refresh_base_url() + return self.base_url or "https://localhost:5000" + + def store_server_signing_key(self, value: str) -> None: + try: + self.key_store.save_server_signing_key(value) + except Exception as exc: + _log_agent(f"Unable to store server signing key: {exc}", fname="agent.error.log") + + def load_server_signing_key(self) -> Optional[str]: + try: + return self.key_store.load_server_signing_key() + except Exception: + return None + + +HTTP_CLIENT: Optional[AgentHttpClient] = None + + +def http_client() -> AgentHttpClient: + global HTTP_CLIENT + if HTTP_CLIENT is None: + HTTP_CLIENT = AgentHttpClient() + return HTTP_CLIENT + def _get_context_label() -> str: return 'SYSTEM' if SYSTEM_SERVICE_MODE else 'CURRENTUSER' @@ -679,6 +958,46 @@ def detect_agent_os(): print(f"[WARN] OS detection failed: {e}") return "Unknown" + +def _system_uptime_seconds() -> Optional[int]: + try: + if psutil and hasattr(psutil, "boot_time"): + return int(time.time() - psutil.boot_time()) + except Exception: + pass + return None + + +def _collect_heartbeat_metrics() -> Dict[str, Any]: + metrics: Dict[str, Any] = { + "operating_system": detect_agent_os(), + "service_mode": SERVICE_MODE, + } + uptime = _system_uptime_seconds() + if uptime is not None: + metrics["uptime"] = uptime + try: + metrics["hostname"] = socket.gethostname() + except Exception: + pass + try: + metrics["username"] = getpass.getuser() + except Exception: + pass + if psutil: + try: + cpu = psutil.cpu_percent(interval=None) + metrics["cpu_percent"] = cpu + except Exception: + pass + try: + mem = psutil.virtual_memory() + metrics["memory_percent"] = getattr(mem, "percent", None) + except Exception: + pass + return metrics + + def _settings_dir(): try: return os.path.join(_find_project_root(), 'Agent', 'Borealis', 'Settings') @@ -697,6 +1016,9 @@ def _key_store() -> AgentKeyStore: return _KEY_STORE_INSTANCE +SERVER_CERT_PATH = _key_store().server_certificate_path() + + IDENTITY = _key_store().load_or_create_identity() SSL_KEY_FINGERPRINT = IDENTITY.fingerprint PUBLIC_KEY_B64 = IDENTITY.public_key_b64 @@ -985,46 +1307,40 @@ async def send_heartbeat(): Periodically send agent heartbeat to the server so the Devices page can show hostname, OS, and last_seen. """ - # Initial heartbeat is sent in the WebSocket 'connect' handler. - # Delay the loop start so we don't double-send immediately. - await asyncio.sleep(60) + await asyncio.sleep(15) + client = http_client() while True: try: + client.ensure_authenticated() payload = { - "agent_id": AGENT_ID, + "guid": client.guid or _read_agent_guid_from_disk(), "hostname": socket.gethostname(), - "agent_operating_system": detect_agent_os(), - "last_seen": int(time.time()), - "service_mode": SERVICE_MODE, + "inventory": {}, + "metrics": _collect_heartbeat_metrics(), } - await sio.emit("agent_heartbeat", payload) - # Also report collector status alive ping. - # To avoid clobbering last_user with SYSTEM/machine accounts, - # only include last_user from the interactive agent. - try: - if not SYSTEM_SERVICE_MODE: - import getpass - await sio.emit('collector_status', { - 'agent_id': AGENT_ID, - 'hostname': socket.gethostname(), - 'active': True, - 'service_mode': SERVICE_MODE, - 'last_user': f"{os.environ.get('USERDOMAIN') or socket.gethostname()}\\{getpass.getuser()}" - }) - else: - await sio.emit('collector_status', { - 'agent_id': AGENT_ID, - 'hostname': socket.gethostname(), - 'active': True, - 'service_mode': SERVICE_MODE, - }) - except Exception: - pass - except Exception as e: - print(f"[WARN] heartbeat emit failed: {e}") - # Send periodic heartbeats every 60 seconds + await client.async_post_json("/api/agent/heartbeat", payload, require_auth=True) + except Exception as exc: + _log_agent(f'Heartbeat post failed: {exc}', fname='agent.error.log') await asyncio.sleep(60) + +async def poll_script_requests(): + await asyncio.sleep(20) + client = http_client() + while True: + try: + client.ensure_authenticated() + payload = {"guid": client.guid or _read_agent_guid_from_disk()} + response = await client.async_post_json("/api/agent/script/request", payload, require_auth=True) + if isinstance(response, dict): + signing_key = response.get("signing_key") + if signing_key: + client.store_server_signing_key(signing_key) + # Placeholder: future script execution handling lives here. + except Exception as exc: + _log_agent(f'script request poll failed: {exc}', fname='agent.error.log') + await asyncio.sleep(30) + # ---------------- Detailed Agent Data ---------------- ## Moved to agent_info module @@ -1302,14 +1618,13 @@ async def send_agent_details(): "storage": collect_storage(), "network": collect_network(), } - url = get_server_url().rstrip('/') + "/api/agent/details" payload = { "agent_id": AGENT_ID, "hostname": details.get("summary", {}).get("hostname", socket.gethostname()), "details": details, } - async with aiohttp.ClientSession() as session: - await session.post(url, json=payload, timeout=10) + client = http_client() + await client.async_post_json("/api/agent/details", payload, require_auth=True) _log_agent('Posted agent details to server.') except Exception as e: print(f"[WARN] Failed to send agent details: {e}") @@ -1325,14 +1640,13 @@ async def send_agent_details_once(): "storage": collect_storage(), "network": collect_network(), } - url = get_server_url().rstrip('/') + "/api/agent/details" payload = { "agent_id": AGENT_ID, "hostname": details.get("summary", {}).get("hostname", socket.gethostname()), "details": details, } - async with aiohttp.ClientSession() as session: - await session.post(url, json=payload, timeout=10) + client = http_client() + await client.async_post_json("/api/agent/details", payload, require_auth=True) _log_agent('Posted agent details (once) to server.') except Exception as e: _log_agent(f'Failed to post agent details once: {e}', fname='agent.error.log') @@ -1343,39 +1657,19 @@ async def connect(): _log_agent('Connected to server.') await sio.emit('connect_agent', {"agent_id": AGENT_ID, "service_mode": SERVICE_MODE}) - # Send an immediate heartbeat so the UI can populate instantly. + # Send an immediate heartbeat via authenticated REST call. try: - await sio.emit("agent_heartbeat", { - "agent_id": AGENT_ID, + client = http_client() + client.ensure_authenticated() + payload = { + "guid": client.guid or _read_agent_guid_from_disk(), "hostname": socket.gethostname(), - "agent_operating_system": detect_agent_os(), - "last_seen": int(time.time()), - "service_mode": SERVICE_MODE, - }) - except Exception as e: - print(f"[WARN] initial heartbeat failed: {e}") - _log_agent(f'Initial heartbeat failed: {e}', fname='agent.error.log') - - # Let server know collector is active; send last_user only from interactive agent - try: - if not SYSTEM_SERVICE_MODE: - import getpass - await sio.emit('collector_status', { - 'agent_id': AGENT_ID, - 'hostname': socket.gethostname(), - 'active': True, - 'service_mode': SERVICE_MODE, - 'last_user': f"{os.environ.get('USERDOMAIN') or socket.gethostname()}\\{getpass.getuser()}" - }) - else: - await sio.emit('collector_status', { - 'agent_id': AGENT_ID, - 'hostname': socket.gethostname(), - 'active': True, - 'service_mode': SERVICE_MODE, - }) - except Exception: - pass + "inventory": {}, + "metrics": _collect_heartbeat_metrics(), + } + await client.async_post_json("/api/agent/heartbeat", payload, require_auth=True) + except Exception as exc: + _log_agent(f'Initial REST heartbeat failed: {exc}', fname='agent.error.log') await sio.emit('request_config', {"agent_id": AGENT_ID}) # Inventory details posting is managed by the DeviceAudit role (SYSTEM). No one-shot post here. @@ -1383,21 +1677,14 @@ async def connect(): try: async def _svc_checkin_once(): try: - url = get_server_url().rstrip('/') + "/api/agent/checkin" payload = {"agent_id": AGENT_ID, "hostname": socket.gethostname(), "username": ".\\svcBorealis"} - timeout = aiohttp.ClientTimeout(total=10) - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post(url, json=payload) as resp: - if resp.status == 200: - try: - data = await resp.json(content_type=None) - except Exception: - data = None - if isinstance(data, dict): - guid_value = (data.get('agent_guid') or '').strip() - if guid_value: - _persist_agent_guid_local(guid_value) - _update_agent_id_for_guid(guid_value) + client = http_client() + response = await client.async_post_json("/api/agent/checkin", payload, require_auth=True) + if isinstance(response, dict): + guid_value = (response.get('agent_guid') or '').strip() + if guid_value: + _persist_agent_guid_local(guid_value) + _update_agent_id_for_guid(guid_value) except Exception: pass asyncio.create_task(_svc_checkin_once()) @@ -1649,13 +1936,20 @@ if not SYSTEM_SERVICE_MODE: # MAIN & EVENT LOOP # ////////////////////////////////////////////////////////////////////////// async def connect_loop(): - retry=5 + retry = 5 + client = http_client() while True: try: - url=get_server_url() + client.ensure_authenticated() + url = client.websocket_base_url() print(f"[INFO] Connecting Agent to {url}...") _log_agent(f'Connecting to {url}...') - await sio.connect(url,transports=['websocket']) + await sio.connect( + url, + transports=['websocket'], + headers=client.auth_headers(), + ssl_verify=client.session.verify, + ) break except Exception as e: print(f"[WebSocket] Server unavailable: {e}. Retrying in {retry}s...") @@ -1683,6 +1977,12 @@ if __name__=='__main__': dummy_window=PersistentWindow(); dummy_window.show() # Initialize roles context for role tasks # Initialize role manager and hot-load roles from Roles/ + client = http_client() + try: + client.ensure_authenticated() + except Exception as exc: + _log_agent(f'Authentication bootstrap failed: {exc}', fname='agent.error.log') + print(f"[WARN] Authentication bootstrap failed: {exc}") try: base_hooks = {'send_service_control': send_service_control, 'get_server_url': get_server_url} if not SYSTEM_SERVICE_MODE: @@ -1723,6 +2023,7 @@ if __name__=='__main__': background_tasks.append(loop.create_task(idle_task())) # Start periodic heartbeats background_tasks.append(loop.create_task(send_heartbeat())) + background_tasks.append(loop.create_task(poll_script_requests())) # Inventory upload is handled by the DeviceAudit role running in SYSTEM context. # Do not schedule the legacy agent-level details poster to avoid duplicates. diff --git a/Data/Agent/security.py b/Data/Agent/security.py index 7f37f15..e3feb9d 100644 --- a/Data/Agent/security.py +++ b/Data/Agent/security.py @@ -79,6 +79,8 @@ class AgentKeyStore: self._access_token_path = os.path.join(self.settings_dir, "access.jwt") self._refresh_token_path = os.path.join(self.settings_dir, "refresh.token") self._token_meta_path = os.path.join(self.settings_dir, "access.meta.json") + self._server_certificate_path = os.path.join(self.settings_dir, "server_certificate.pem") + self._server_signing_key_path = os.path.join(self.settings_dir, "server_signing_key.pub") # ------------------------------------------------------------------ # Identity management @@ -198,6 +200,54 @@ class AgentKeyStore: os.remove(path) except Exception: pass + # ------------------------------------------------------------------ + # Server certificate & signing key helpers + # ------------------------------------------------------------------ + def server_certificate_path(self) -> str: + return self._server_certificate_path + + def save_server_certificate(self, pem_text: str) -> None: + if not pem_text: + return + normalized = pem_text.strip() + if not normalized: + return + if not normalized.endswith("\n"): + normalized += "\n" + with open(self._server_certificate_path, "w", encoding="utf-8") as fh: + fh.write(normalized) + _restrict_permissions(self._server_certificate_path) + + def load_server_certificate(self) -> Optional[str]: + try: + if os.path.isfile(self._server_certificate_path): + with open(self._server_certificate_path, "r", encoding="utf-8") as fh: + return fh.read() + except Exception: + return None + return None + + def save_server_signing_key(self, value: str) -> None: + if not value: + return + normalized = value.strip() + if not normalized: + return + with open(self._server_signing_key_path, "w", encoding="utf-8") as fh: + fh.write(normalized) + fh.write("\n") + _restrict_permissions(self._server_signing_key_path) + + def load_server_signing_key(self) -> Optional[str]: + try: + if os.path.isfile(self._server_signing_key_path): + with open(self._server_signing_key_path, "r", encoding="utf-8") as fh: + value = fh.read().strip() + return value or None + except Exception: + return None + return None + # ------------------------------------------------------------------ # Token metadata (e.g., expiry, fingerprint binding) diff --git a/tests/test_agent_security.py b/tests/test_agent_security.py new file mode 100644 index 0000000..db2dd69 --- /dev/null +++ b/tests/test_agent_security.py @@ -0,0 +1,57 @@ + +from __future__ import annotations + +import shutil +import tempfile +import unittest + +import pathlib +import sys + +ROOT = pathlib.Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +try: + from Data.Agent.security import AgentKeyStore # type: ignore + _IMPORT_ERROR: Exception | None = None +except Exception as exc: # pragma: no cover - handled via skip + AgentKeyStore = None # type: ignore + _IMPORT_ERROR = exc + + +@unittest.skipIf(AgentKeyStore is None, f"security module unavailable: {_IMPORT_ERROR}") +class AgentKeyStoreTests(unittest.TestCase): + def test_roundtrip(self): + tmp_dir = tempfile.mkdtemp(prefix="akstest-") + try: + store = AgentKeyStore(tmp_dir, scope="CURRENTUSER") + identity = store.load_or_create_identity() + + self.assertTrue(identity.public_key_b64) + self.assertEqual(len(identity.fingerprint), 64) + + store.save_guid("ABC-123") + self.assertEqual(store.load_guid(), "ABC-123") + + store.save_access_token("access-token", expires_at=12345) + self.assertEqual(store.load_access_token(), "access-token") + self.assertEqual(store.get_access_expiry(), 12345) + + store.save_refresh_token("refresh-token") + self.assertEqual(store.load_refresh_token(), "refresh-token") + + store.set_access_binding(identity.fingerprint) + self.assertEqual(store.get_access_binding(), identity.fingerprint) + + store.save_server_certificate("-----BEGIN CERT-----\nABC\n-----END CERT-----") + self.assertIn("BEGIN CERT", store.load_server_certificate() or "") + + store.save_server_signing_key("PUBKEYDATA") + self.assertEqual(store.load_server_signing_key(), "PUBKEYDATA") + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + + +if __name__ == "__main__": + unittest.main()