mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 22:01:59 -06:00
Removed Experimental Engine
This commit is contained in:
@@ -1,104 +0,0 @@
|
||||
"""Application services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthService",
|
||||
"DeviceRecord",
|
||||
"RefreshTokenRecord",
|
||||
"TokenRefreshError",
|
||||
"TokenRefreshErrorCode",
|
||||
"TokenService",
|
||||
"EnrollmentService",
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"EnrollmentValidationError",
|
||||
"PollingResult",
|
||||
"AgentRealtimeService",
|
||||
"AgentRecord",
|
||||
"SchedulerService",
|
||||
"GitHubService",
|
||||
"GitHubTokenPayload",
|
||||
"EnrollmentAdminService",
|
||||
"SiteService",
|
||||
"DeviceInventoryService",
|
||||
"DeviceViewService",
|
||||
"CredentialService",
|
||||
"AssemblyService",
|
||||
"AssemblyListing",
|
||||
"AssemblyLoadResult",
|
||||
"AssemblyMutationResult",
|
||||
]
|
||||
|
||||
_LAZY_TARGETS: Dict[str, Tuple[str, str]] = {
|
||||
"DeviceAuthService": ("Data.Engine.services.auth.device_auth_service", "DeviceAuthService"),
|
||||
"DeviceRecord": ("Data.Engine.services.auth.device_auth_service", "DeviceRecord"),
|
||||
"RefreshTokenRecord": ("Data.Engine.services.auth.device_auth_service", "RefreshTokenRecord"),
|
||||
"TokenService": ("Data.Engine.services.auth.token_service", "TokenService"),
|
||||
"TokenRefreshError": ("Data.Engine.services.auth.token_service", "TokenRefreshError"),
|
||||
"TokenRefreshErrorCode": ("Data.Engine.services.auth.token_service", "TokenRefreshErrorCode"),
|
||||
"EnrollmentService": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentService"),
|
||||
"EnrollmentRequestResult": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentRequestResult"),
|
||||
"EnrollmentStatus": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentStatus"),
|
||||
"EnrollmentTokenBundle": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentTokenBundle"),
|
||||
"PollingResult": ("Data.Engine.services.enrollment.enrollment_service", "PollingResult"),
|
||||
"EnrollmentValidationError": ("Data.Engine.domain.device_enrollment", "EnrollmentValidationError"),
|
||||
"AgentRealtimeService": ("Data.Engine.services.realtime.agent_registry", "AgentRealtimeService"),
|
||||
"AgentRecord": ("Data.Engine.services.realtime.agent_registry", "AgentRecord"),
|
||||
"SchedulerService": ("Data.Engine.services.jobs.scheduler_service", "SchedulerService"),
|
||||
"GitHubService": ("Data.Engine.services.github.github_service", "GitHubService"),
|
||||
"GitHubTokenPayload": ("Data.Engine.services.github.github_service", "GitHubTokenPayload"),
|
||||
"EnrollmentAdminService": (
|
||||
"Data.Engine.services.enrollment.admin_service",
|
||||
"EnrollmentAdminService",
|
||||
),
|
||||
"SiteService": ("Data.Engine.services.sites.site_service", "SiteService"),
|
||||
"DeviceInventoryService": (
|
||||
"Data.Engine.services.devices.device_inventory_service",
|
||||
"DeviceInventoryService",
|
||||
),
|
||||
"DeviceViewService": (
|
||||
"Data.Engine.services.devices.device_view_service",
|
||||
"DeviceViewService",
|
||||
),
|
||||
"CredentialService": (
|
||||
"Data.Engine.services.credentials.credential_service",
|
||||
"CredentialService",
|
||||
),
|
||||
"AssemblyService": (
|
||||
"Data.Engine.services.assemblies.assembly_service",
|
||||
"AssemblyService",
|
||||
),
|
||||
"AssemblyListing": (
|
||||
"Data.Engine.services.assemblies.assembly_service",
|
||||
"AssemblyListing",
|
||||
),
|
||||
"AssemblyLoadResult": (
|
||||
"Data.Engine.services.assemblies.assembly_service",
|
||||
"AssemblyLoadResult",
|
||||
),
|
||||
"AssemblyMutationResult": (
|
||||
"Data.Engine.services.assemblies.assembly_service",
|
||||
"AssemblyMutationResult",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
try:
|
||||
module_name, attribute = _LAZY_TARGETS[name]
|
||||
except KeyError as exc:
|
||||
raise AttributeError(name) from exc
|
||||
|
||||
module = import_module(module_name)
|
||||
value = getattr(module, attribute)
|
||||
globals()[name] = value
|
||||
return value
|
||||
|
||||
|
||||
def __dir__() -> Any: # pragma: no cover - interactive helper
|
||||
return sorted(set(__all__))
|
||||
@@ -1,10 +0,0 @@
|
||||
"""Assembly management services."""
|
||||
|
||||
from .assembly_service import AssemblyService, AssemblyMutationResult, AssemblyLoadResult, AssemblyListing
|
||||
|
||||
__all__ = [
|
||||
"AssemblyService",
|
||||
"AssemblyMutationResult",
|
||||
"AssemblyLoadResult",
|
||||
"AssemblyListing",
|
||||
]
|
||||
@@ -1,715 +0,0 @@
|
||||
"""Filesystem-backed assembly management service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
__all__ = [
|
||||
"AssemblyService",
|
||||
"AssemblyListing",
|
||||
"AssemblyLoadResult",
|
||||
"AssemblyMutationResult",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AssemblyListing:
|
||||
"""Listing payload for an assembly island."""
|
||||
|
||||
root: Path
|
||||
items: List[Dict[str, Any]]
|
||||
folders: List[str]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"root": str(self.root),
|
||||
"items": self.items,
|
||||
"folders": self.folders,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AssemblyLoadResult:
|
||||
"""Container describing a loaded assembly artifact."""
|
||||
|
||||
payload: Dict[str, Any]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return dict(self.payload)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AssemblyMutationResult:
|
||||
"""Mutation acknowledgement for create/edit/rename operations."""
|
||||
|
||||
status: str = "ok"
|
||||
rel_path: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {"status": self.status}
|
||||
if self.rel_path:
|
||||
payload["rel_path"] = self.rel_path
|
||||
return payload
|
||||
|
||||
|
||||
class AssemblyService:
|
||||
"""Provide CRUD helpers for workflow/script/ansible assemblies."""
|
||||
|
||||
_ISLAND_DIR_MAP = {
|
||||
"workflows": "Workflows",
|
||||
"workflow": "Workflows",
|
||||
"scripts": "Scripts",
|
||||
"script": "Scripts",
|
||||
"ansible": "Ansible_Playbooks",
|
||||
"ansible_playbooks": "Ansible_Playbooks",
|
||||
"ansible-playbooks": "Ansible_Playbooks",
|
||||
"playbooks": "Ansible_Playbooks",
|
||||
}
|
||||
|
||||
_SCRIPT_EXTENSIONS = (".json", ".ps1", ".bat", ".sh")
|
||||
_ANSIBLE_EXTENSIONS = (".json", ".yml")
|
||||
|
||||
def __init__(self, *, root: Path, logger: Optional[logging.Logger] = None) -> None:
|
||||
self._root = root.resolve()
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.assemblies")
|
||||
try:
|
||||
self._root.mkdir(parents=True, exist_ok=True)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.warning("failed to ensure assemblies root %s: %s", self._root, exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
def list_items(self, island: str) -> AssemblyListing:
|
||||
root = self._resolve_island_root(island)
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
items: List[Dict[str, Any]] = []
|
||||
folders: List[str] = []
|
||||
|
||||
isl = (island or "").strip().lower()
|
||||
if isl in {"workflows", "workflow"}:
|
||||
for dirpath, dirnames, filenames in os.walk(root):
|
||||
rel_root = os.path.relpath(dirpath, root)
|
||||
if rel_root != ".":
|
||||
folders.append(rel_root.replace(os.sep, "/"))
|
||||
for fname in filenames:
|
||||
if not fname.lower().endswith(".json"):
|
||||
continue
|
||||
abs_path = Path(dirpath) / fname
|
||||
rel_path = abs_path.relative_to(root).as_posix()
|
||||
try:
|
||||
mtime = abs_path.stat().st_mtime
|
||||
except OSError:
|
||||
mtime = 0.0
|
||||
obj = self._safe_read_json(abs_path)
|
||||
tab = self._extract_tab_name(obj)
|
||||
items.append(
|
||||
{
|
||||
"file_name": fname,
|
||||
"rel_path": rel_path,
|
||||
"type": "workflow",
|
||||
"tab_name": tab,
|
||||
"last_edited": time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%S", time.localtime(mtime)
|
||||
),
|
||||
"last_edited_epoch": mtime,
|
||||
}
|
||||
)
|
||||
elif isl in {"scripts", "script"}:
|
||||
for dirpath, dirnames, filenames in os.walk(root):
|
||||
rel_root = os.path.relpath(dirpath, root)
|
||||
if rel_root != ".":
|
||||
folders.append(rel_root.replace(os.sep, "/"))
|
||||
for fname in filenames:
|
||||
if not fname.lower().endswith(self._SCRIPT_EXTENSIONS):
|
||||
continue
|
||||
abs_path = Path(dirpath) / fname
|
||||
rel_path = abs_path.relative_to(root).as_posix()
|
||||
try:
|
||||
mtime = abs_path.stat().st_mtime
|
||||
except OSError:
|
||||
mtime = 0.0
|
||||
script_type = self._detect_script_type(abs_path)
|
||||
doc = self._load_assembly_document(abs_path, "scripts", script_type)
|
||||
items.append(
|
||||
{
|
||||
"file_name": fname,
|
||||
"rel_path": rel_path,
|
||||
"type": doc.get("type", script_type),
|
||||
"name": doc.get("name"),
|
||||
"category": doc.get("category"),
|
||||
"description": doc.get("description"),
|
||||
"last_edited": time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%S", time.localtime(mtime)
|
||||
),
|
||||
"last_edited_epoch": mtime,
|
||||
}
|
||||
)
|
||||
elif isl in {
|
||||
"ansible",
|
||||
"ansible_playbooks",
|
||||
"ansible-playbooks",
|
||||
"playbooks",
|
||||
}:
|
||||
for dirpath, dirnames, filenames in os.walk(root):
|
||||
rel_root = os.path.relpath(dirpath, root)
|
||||
if rel_root != ".":
|
||||
folders.append(rel_root.replace(os.sep, "/"))
|
||||
for fname in filenames:
|
||||
if not fname.lower().endswith(self._ANSIBLE_EXTENSIONS):
|
||||
continue
|
||||
abs_path = Path(dirpath) / fname
|
||||
rel_path = abs_path.relative_to(root).as_posix()
|
||||
try:
|
||||
mtime = abs_path.stat().st_mtime
|
||||
except OSError:
|
||||
mtime = 0.0
|
||||
script_type = self._detect_script_type(abs_path)
|
||||
doc = self._load_assembly_document(abs_path, "ansible", script_type)
|
||||
items.append(
|
||||
{
|
||||
"file_name": fname,
|
||||
"rel_path": rel_path,
|
||||
"type": doc.get("type", "ansible"),
|
||||
"name": doc.get("name"),
|
||||
"category": doc.get("category"),
|
||||
"description": doc.get("description"),
|
||||
"last_edited": time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%S", time.localtime(mtime)
|
||||
),
|
||||
"last_edited_epoch": mtime,
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid_island")
|
||||
|
||||
items.sort(key=lambda entry: entry.get("last_edited_epoch", 0.0), reverse=True)
|
||||
return AssemblyListing(root=root, items=items, folders=folders)
|
||||
|
||||
def load_item(self, island: str, rel_path: str) -> AssemblyLoadResult:
|
||||
root, abs_path, _ = self._resolve_assembly_path(island, rel_path)
|
||||
if not abs_path.is_file():
|
||||
raise FileNotFoundError("file_not_found")
|
||||
|
||||
isl = (island or "").strip().lower()
|
||||
if isl in {"workflows", "workflow"}:
|
||||
payload = self._safe_read_json(abs_path)
|
||||
return AssemblyLoadResult(payload=payload)
|
||||
|
||||
doc = self._load_assembly_document(abs_path, island)
|
||||
rel = abs_path.relative_to(root).as_posix()
|
||||
payload = {
|
||||
"file_name": abs_path.name,
|
||||
"rel_path": rel,
|
||||
"type": doc.get("type"),
|
||||
"assembly": doc,
|
||||
"content": doc.get("script"),
|
||||
}
|
||||
return AssemblyLoadResult(payload=payload)
|
||||
|
||||
def create_item(
|
||||
self,
|
||||
island: str,
|
||||
*,
|
||||
kind: str,
|
||||
rel_path: str,
|
||||
content: Any,
|
||||
item_type: Optional[str] = None,
|
||||
) -> AssemblyMutationResult:
|
||||
root, abs_path, rel_norm = self._resolve_assembly_path(island, rel_path)
|
||||
if not rel_norm:
|
||||
raise ValueError("path_required")
|
||||
|
||||
normalized_kind = (kind or "").strip().lower()
|
||||
if normalized_kind == "folder":
|
||||
abs_path.mkdir(parents=True, exist_ok=True)
|
||||
return AssemblyMutationResult()
|
||||
if normalized_kind != "file":
|
||||
raise ValueError("invalid_kind")
|
||||
|
||||
target_path = abs_path
|
||||
if not target_path.suffix:
|
||||
target_path = target_path.with_suffix(
|
||||
self._default_ext_for_island(island, item_type or "")
|
||||
)
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
isl = (island or "").strip().lower()
|
||||
if isl in {"workflows", "workflow"}:
|
||||
payload = self._ensure_workflow_document(content)
|
||||
base_name = target_path.stem
|
||||
payload.setdefault("tab_name", base_name)
|
||||
self._write_json(target_path, payload)
|
||||
else:
|
||||
document = self._normalize_assembly_document(
|
||||
content,
|
||||
self._default_type_for_island(island, item_type or ""),
|
||||
target_path.stem,
|
||||
)
|
||||
self._write_json(target_path, self._prepare_assembly_storage(document))
|
||||
|
||||
rel_new = target_path.relative_to(root).as_posix()
|
||||
return AssemblyMutationResult(rel_path=rel_new)
|
||||
|
||||
def edit_item(
|
||||
self,
|
||||
island: str,
|
||||
*,
|
||||
rel_path: str,
|
||||
content: Any,
|
||||
item_type: Optional[str] = None,
|
||||
) -> AssemblyMutationResult:
|
||||
root, abs_path, _ = self._resolve_assembly_path(island, rel_path)
|
||||
if not abs_path.exists():
|
||||
raise FileNotFoundError("file_not_found")
|
||||
|
||||
target_path = abs_path
|
||||
if not target_path.suffix:
|
||||
target_path = target_path.with_suffix(
|
||||
self._default_ext_for_island(island, item_type or "")
|
||||
)
|
||||
|
||||
isl = (island or "").strip().lower()
|
||||
if isl in {"workflows", "workflow"}:
|
||||
payload = self._ensure_workflow_document(content)
|
||||
self._write_json(target_path, payload)
|
||||
else:
|
||||
document = self._normalize_assembly_document(
|
||||
content,
|
||||
self._default_type_for_island(island, item_type or ""),
|
||||
target_path.stem,
|
||||
)
|
||||
self._write_json(target_path, self._prepare_assembly_storage(document))
|
||||
|
||||
if target_path != abs_path and abs_path.exists():
|
||||
try:
|
||||
abs_path.unlink()
|
||||
except OSError: # pragma: no cover - best effort cleanup
|
||||
pass
|
||||
|
||||
rel_new = target_path.relative_to(root).as_posix()
|
||||
return AssemblyMutationResult(rel_path=rel_new)
|
||||
|
||||
def rename_item(
|
||||
self,
|
||||
island: str,
|
||||
*,
|
||||
kind: str,
|
||||
rel_path: str,
|
||||
new_name: str,
|
||||
item_type: Optional[str] = None,
|
||||
) -> AssemblyMutationResult:
|
||||
root, old_path, _ = self._resolve_assembly_path(island, rel_path)
|
||||
|
||||
normalized_kind = (kind or "").strip().lower()
|
||||
if normalized_kind not in {"file", "folder"}:
|
||||
raise ValueError("invalid_kind")
|
||||
|
||||
if normalized_kind == "folder":
|
||||
if not old_path.is_dir():
|
||||
raise FileNotFoundError("folder_not_found")
|
||||
destination = old_path.parent / new_name
|
||||
else:
|
||||
if not old_path.is_file():
|
||||
raise FileNotFoundError("file_not_found")
|
||||
candidate = Path(new_name)
|
||||
if not candidate.suffix:
|
||||
candidate = candidate.with_suffix(
|
||||
self._default_ext_for_island(island, item_type or "")
|
||||
)
|
||||
destination = old_path.parent / candidate.name
|
||||
|
||||
destination = destination.resolve()
|
||||
if not str(destination).startswith(str(root)):
|
||||
raise ValueError("invalid_destination")
|
||||
|
||||
old_path.rename(destination)
|
||||
|
||||
isl = (island or "").strip().lower()
|
||||
if normalized_kind == "file" and isl in {"workflows", "workflow"}:
|
||||
try:
|
||||
obj = self._safe_read_json(destination)
|
||||
base_name = destination.stem
|
||||
for key in ["tabName", "tab_name", "name", "title"]:
|
||||
if key in obj:
|
||||
obj[key] = base_name
|
||||
obj.setdefault("tab_name", base_name)
|
||||
self._write_json(destination, obj)
|
||||
except Exception: # pragma: no cover - best effort update
|
||||
self._log.debug("failed to normalize workflow metadata for %s", destination)
|
||||
|
||||
rel_new = destination.relative_to(root).as_posix()
|
||||
return AssemblyMutationResult(rel_path=rel_new)
|
||||
|
||||
def move_item(
|
||||
self,
|
||||
island: str,
|
||||
*,
|
||||
rel_path: str,
|
||||
new_path: str,
|
||||
kind: Optional[str] = None,
|
||||
) -> AssemblyMutationResult:
|
||||
root, old_path, _ = self._resolve_assembly_path(island, rel_path)
|
||||
_, dest_path, _ = self._resolve_assembly_path(island, new_path)
|
||||
|
||||
normalized_kind = (kind or "").strip().lower()
|
||||
if normalized_kind == "folder":
|
||||
if not old_path.is_dir():
|
||||
raise FileNotFoundError("folder_not_found")
|
||||
else:
|
||||
if not old_path.exists():
|
||||
raise FileNotFoundError("file_not_found")
|
||||
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(old_path), str(dest_path))
|
||||
return AssemblyMutationResult()
|
||||
|
||||
def delete_item(
|
||||
self,
|
||||
island: str,
|
||||
*,
|
||||
rel_path: str,
|
||||
kind: str,
|
||||
) -> AssemblyMutationResult:
|
||||
_, abs_path, rel_norm = self._resolve_assembly_path(island, rel_path)
|
||||
if not rel_norm:
|
||||
raise ValueError("cannot_delete_root")
|
||||
|
||||
normalized_kind = (kind or "").strip().lower()
|
||||
if normalized_kind == "folder":
|
||||
if not abs_path.is_dir():
|
||||
raise FileNotFoundError("folder_not_found")
|
||||
shutil.rmtree(abs_path)
|
||||
elif normalized_kind == "file":
|
||||
if not abs_path.is_file():
|
||||
raise FileNotFoundError("file_not_found")
|
||||
abs_path.unlink()
|
||||
else:
|
||||
raise ValueError("invalid_kind")
|
||||
|
||||
return AssemblyMutationResult()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _resolve_island_root(self, island: str) -> Path:
|
||||
key = (island or "").strip().lower()
|
||||
subdir = self._ISLAND_DIR_MAP.get(key)
|
||||
if not subdir:
|
||||
raise ValueError("invalid_island")
|
||||
root = (self._root / subdir).resolve()
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
def _resolve_assembly_path(self, island: str, rel_path: str) -> Tuple[Path, Path, str]:
|
||||
root = self._resolve_island_root(island)
|
||||
rel_norm = self._normalize_relpath(rel_path)
|
||||
abs_path = (root / rel_norm).resolve()
|
||||
if not str(abs_path).startswith(str(root)):
|
||||
raise ValueError("invalid_path")
|
||||
return root, abs_path, rel_norm
|
||||
|
||||
@staticmethod
|
||||
def _normalize_relpath(value: str) -> str:
|
||||
return (value or "").replace("\\", "/").strip("/")
|
||||
|
||||
@staticmethod
|
||||
def _default_ext_for_island(island: str, item_type: str) -> str:
|
||||
isl = (island or "").strip().lower()
|
||||
if isl in {"workflows", "workflow"}:
|
||||
return ".json"
|
||||
if isl in {"ansible", "ansible_playbooks", "ansible-playbooks", "playbooks"}:
|
||||
return ".json"
|
||||
if isl in {"scripts", "script"}:
|
||||
return ".json"
|
||||
typ = (item_type or "").strip().lower()
|
||||
if typ in {"bash", "batch", "powershell"}:
|
||||
return ".json"
|
||||
return ".json"
|
||||
|
||||
@staticmethod
|
||||
def _default_type_for_island(island: str, item_type: str) -> str:
|
||||
isl = (island or "").strip().lower()
|
||||
if isl in {"ansible", "ansible_playbooks", "ansible-playbooks", "playbooks"}:
|
||||
return "ansible"
|
||||
typ = (item_type or "").strip().lower()
|
||||
if typ in {"powershell", "batch", "bash", "ansible"}:
|
||||
return typ
|
||||
return "powershell"
|
||||
|
||||
@staticmethod
|
||||
def _empty_assembly_document(default_type: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"version": 1,
|
||||
"name": "",
|
||||
"description": "",
|
||||
"category": "application" if default_type.lower() == "ansible" else "script",
|
||||
"type": default_type or "powershell",
|
||||
"script": "",
|
||||
"timeout_seconds": 3600,
|
||||
"sites": {"mode": "all", "values": []},
|
||||
"variables": [],
|
||||
"files": [],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64_text(value: Any) -> Optional[str]:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
return ""
|
||||
try:
|
||||
cleaned = re.sub(r"\s+", "", stripped)
|
||||
except Exception:
|
||||
cleaned = stripped
|
||||
try:
|
||||
decoded = base64.b64decode(cleaned, validate=True)
|
||||
except Exception:
|
||||
return None
|
||||
try:
|
||||
return decoded.decode("utf-8")
|
||||
except Exception:
|
||||
return decoded.decode("utf-8", errors="replace")
|
||||
|
||||
def _decode_script_content(self, value: Any, encoding_hint: str = "") -> str:
|
||||
encoding = (encoding_hint or "").strip().lower()
|
||||
if isinstance(value, str):
|
||||
if encoding in {"base64", "b64", "base-64"}:
|
||||
decoded = self._decode_base64_text(value)
|
||||
if decoded is not None:
|
||||
return decoded.replace("\r\n", "\n")
|
||||
decoded = self._decode_base64_text(value)
|
||||
if decoded is not None:
|
||||
return decoded.replace("\r\n", "\n")
|
||||
return value.replace("\r\n", "\n")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _encode_script_content(script_text: Any) -> str:
|
||||
if not isinstance(script_text, str):
|
||||
if script_text is None:
|
||||
script_text = ""
|
||||
else:
|
||||
script_text = str(script_text)
|
||||
normalized = script_text.replace("\r\n", "\n")
|
||||
if not normalized:
|
||||
return ""
|
||||
encoded = base64.b64encode(normalized.encode("utf-8"))
|
||||
return encoded.decode("ascii")
|
||||
|
||||
def _prepare_assembly_storage(self, document: Dict[str, Any]) -> Dict[str, Any]:
|
||||
stored: Dict[str, Any] = {}
|
||||
for key, value in (document or {}).items():
|
||||
if key == "script":
|
||||
stored[key] = self._encode_script_content(value)
|
||||
else:
|
||||
stored[key] = value
|
||||
stored["script_encoding"] = "base64"
|
||||
return stored
|
||||
|
||||
def _normalize_assembly_document(
|
||||
self,
|
||||
obj: Any,
|
||||
default_type: str,
|
||||
base_name: str,
|
||||
) -> Dict[str, Any]:
|
||||
doc = self._empty_assembly_document(default_type)
|
||||
if not isinstance(obj, dict):
|
||||
obj = {}
|
||||
base = (base_name or "assembly").strip()
|
||||
doc["name"] = str(obj.get("name") or obj.get("display_name") or base)
|
||||
doc["description"] = str(obj.get("description") or "")
|
||||
category = str(obj.get("category") or doc["category"]).strip().lower()
|
||||
if category in {"script", "application"}:
|
||||
doc["category"] = category
|
||||
typ = str(obj.get("type") or obj.get("script_type") or default_type or "powershell").strip().lower()
|
||||
if typ in {"powershell", "batch", "bash", "ansible"}:
|
||||
doc["type"] = typ
|
||||
script_val = obj.get("script")
|
||||
content_val = obj.get("content")
|
||||
script_lines = obj.get("script_lines")
|
||||
if isinstance(script_lines, list):
|
||||
try:
|
||||
doc["script"] = "\n".join(str(line) for line in script_lines)
|
||||
except Exception:
|
||||
doc["script"] = ""
|
||||
elif isinstance(script_val, str):
|
||||
doc["script"] = script_val
|
||||
elif isinstance(content_val, str):
|
||||
doc["script"] = content_val
|
||||
encoding_hint = str(
|
||||
obj.get("script_encoding") or obj.get("scriptEncoding") or ""
|
||||
).strip().lower()
|
||||
doc["script"] = self._decode_script_content(doc.get("script"), encoding_hint)
|
||||
if encoding_hint in {"base64", "b64", "base-64"}:
|
||||
doc["script_encoding"] = "base64"
|
||||
else:
|
||||
probe_source = ""
|
||||
if isinstance(script_val, str) and script_val:
|
||||
probe_source = script_val
|
||||
elif isinstance(content_val, str) and content_val:
|
||||
probe_source = content_val
|
||||
decoded_probe = self._decode_base64_text(probe_source) if probe_source else None
|
||||
if decoded_probe is not None:
|
||||
doc["script_encoding"] = "base64"
|
||||
doc["script"] = decoded_probe.replace("\r\n", "\n")
|
||||
else:
|
||||
doc["script_encoding"] = "plain"
|
||||
timeout_val = obj.get("timeout_seconds", obj.get("timeout"))
|
||||
if timeout_val is not None:
|
||||
try:
|
||||
doc["timeout_seconds"] = max(0, int(timeout_val))
|
||||
except Exception:
|
||||
pass
|
||||
sites = obj.get("sites") if isinstance(obj.get("sites"), dict) else {}
|
||||
values = sites.get("values") if isinstance(sites.get("values"), list) else []
|
||||
mode = str(sites.get("mode") or ("specific" if values else "all")).strip().lower()
|
||||
if mode not in {"all", "specific"}:
|
||||
mode = "all"
|
||||
doc["sites"] = {
|
||||
"mode": mode,
|
||||
"values": [
|
||||
str(v).strip()
|
||||
for v in values
|
||||
if isinstance(v, (str, int, float)) and str(v).strip()
|
||||
],
|
||||
}
|
||||
vars_in = obj.get("variables") if isinstance(obj.get("variables"), list) else []
|
||||
doc_vars: List[Dict[str, Any]] = []
|
||||
for entry in vars_in:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
name = str(entry.get("name") or entry.get("key") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
vtype = str(entry.get("type") or "string").strip().lower()
|
||||
if vtype not in {"string", "number", "boolean", "credential"}:
|
||||
vtype = "string"
|
||||
default_val = entry.get("default", entry.get("default_value"))
|
||||
doc_vars.append(
|
||||
{
|
||||
"name": name,
|
||||
"label": str(entry.get("label") or ""),
|
||||
"type": vtype,
|
||||
"default": default_val,
|
||||
"required": bool(entry.get("required")),
|
||||
"description": str(entry.get("description") or ""),
|
||||
}
|
||||
)
|
||||
doc["variables"] = doc_vars
|
||||
files_in = obj.get("files") if isinstance(obj.get("files"), list) else []
|
||||
doc_files: List[Dict[str, Any]] = []
|
||||
for record in files_in:
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
fname = record.get("file_name") or record.get("name")
|
||||
data = record.get("data")
|
||||
if not fname or not isinstance(data, str):
|
||||
continue
|
||||
size_val = record.get("size")
|
||||
try:
|
||||
size_int = int(size_val)
|
||||
except Exception:
|
||||
size_int = 0
|
||||
doc_files.append(
|
||||
{
|
||||
"file_name": str(fname),
|
||||
"size": size_int,
|
||||
"mime_type": str(record.get("mime_type") or record.get("mimeType") or ""),
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
doc["files"] = doc_files
|
||||
try:
|
||||
doc["version"] = int(obj.get("version") or doc["version"])
|
||||
except Exception:
|
||||
pass
|
||||
return doc
|
||||
|
||||
def _load_assembly_document(
|
||||
self,
|
||||
abs_path: Path,
|
||||
island: str,
|
||||
type_hint: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
base_name = abs_path.stem
|
||||
default_type = self._default_type_for_island(island, type_hint)
|
||||
if abs_path.suffix.lower() == ".json":
|
||||
data = self._safe_read_json(abs_path)
|
||||
return self._normalize_assembly_document(data, default_type, base_name)
|
||||
try:
|
||||
content = abs_path.read_text(encoding="utf-8", errors="replace")
|
||||
except Exception:
|
||||
content = ""
|
||||
document = self._empty_assembly_document(default_type)
|
||||
document["name"] = base_name
|
||||
document["script"] = (content or "").replace("\r\n", "\n")
|
||||
if default_type == "ansible":
|
||||
document["category"] = "application"
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def _safe_read_json(path: Path) -> Dict[str, Any]:
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _extract_tab_name(obj: Dict[str, Any]) -> str:
|
||||
if not isinstance(obj, dict):
|
||||
return ""
|
||||
for key in ["tabName", "tab_name", "name", "title"]:
|
||||
value = obj.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return ""
|
||||
|
||||
def _detect_script_type(self, path: Path) -> str:
|
||||
lower = path.name.lower()
|
||||
if lower.endswith(".json") and path.is_file():
|
||||
obj = self._safe_read_json(path)
|
||||
if isinstance(obj, dict):
|
||||
typ = str(
|
||||
obj.get("type") or obj.get("script_type") or ""
|
||||
).strip().lower()
|
||||
if typ in {"powershell", "batch", "bash", "ansible"}:
|
||||
return typ
|
||||
return "powershell"
|
||||
if lower.endswith(".yml"):
|
||||
return "ansible"
|
||||
if lower.endswith(".ps1"):
|
||||
return "powershell"
|
||||
if lower.endswith(".bat"):
|
||||
return "batch"
|
||||
if lower.endswith(".sh"):
|
||||
return "bash"
|
||||
return "unknown"
|
||||
|
||||
@staticmethod
|
||||
def _ensure_workflow_document(content: Any) -> Dict[str, Any]:
|
||||
payload = content
|
||||
if isinstance(payload, str):
|
||||
try:
|
||||
payload = json.loads(payload)
|
||||
except Exception:
|
||||
payload = {}
|
||||
if not isinstance(payload, dict):
|
||||
payload = {}
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _write_json(path: Path, payload: Dict[str, Any]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||
@@ -1,63 +0,0 @@
|
||||
"""Authentication services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .device_auth_service import DeviceAuthService, DeviceRecord
|
||||
from .dpop import DPoPReplayError, DPoPVerificationError, DPoPValidator
|
||||
from .jwt_service import JWTService, load_service as load_jwt_service
|
||||
from .token_service import (
|
||||
RefreshTokenRecord,
|
||||
TokenRefreshError,
|
||||
TokenRefreshErrorCode,
|
||||
TokenService,
|
||||
)
|
||||
from .operator_account_service import (
|
||||
AccountNotFoundError,
|
||||
CannotModifySelfError,
|
||||
InvalidPasswordHashError,
|
||||
InvalidRoleError,
|
||||
LastAdminError,
|
||||
LastUserError,
|
||||
OperatorAccountError,
|
||||
OperatorAccountRecord,
|
||||
OperatorAccountService,
|
||||
UsernameAlreadyExistsError,
|
||||
)
|
||||
from .operator_auth_service import (
|
||||
InvalidCredentialsError,
|
||||
InvalidMFACodeError,
|
||||
MFAUnavailableError,
|
||||
MFASessionError,
|
||||
OperatorAuthError,
|
||||
OperatorAuthService,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthService",
|
||||
"DeviceRecord",
|
||||
"DPoPReplayError",
|
||||
"DPoPVerificationError",
|
||||
"DPoPValidator",
|
||||
"JWTService",
|
||||
"load_jwt_service",
|
||||
"RefreshTokenRecord",
|
||||
"TokenRefreshError",
|
||||
"TokenRefreshErrorCode",
|
||||
"TokenService",
|
||||
"OperatorAccountService",
|
||||
"OperatorAccountError",
|
||||
"OperatorAccountRecord",
|
||||
"UsernameAlreadyExistsError",
|
||||
"AccountNotFoundError",
|
||||
"LastAdminError",
|
||||
"LastUserError",
|
||||
"CannotModifySelfError",
|
||||
"InvalidRoleError",
|
||||
"InvalidPasswordHashError",
|
||||
"OperatorAuthService",
|
||||
"OperatorAuthError",
|
||||
"InvalidCredentialsError",
|
||||
"InvalidMFACodeError",
|
||||
"MFAUnavailableError",
|
||||
"MFASessionError",
|
||||
]
|
||||
@@ -1,237 +0,0 @@
|
||||
"""Device authentication service copied from the legacy server stack."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Mapping, Optional, Protocol
|
||||
import logging
|
||||
|
||||
from Data.Engine.builders.device_auth import DeviceAuthRequest
|
||||
from Data.Engine.domain.device_auth import (
|
||||
AccessTokenClaims,
|
||||
DeviceAuthContext,
|
||||
DeviceAuthErrorCode,
|
||||
DeviceAuthFailure,
|
||||
DeviceFingerprint,
|
||||
DeviceGuid,
|
||||
DeviceIdentity,
|
||||
DeviceStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeviceAuthService",
|
||||
"DeviceRecord",
|
||||
"DPoPValidator",
|
||||
"DPoPVerificationError",
|
||||
"DPoPReplayError",
|
||||
"RateLimiter",
|
||||
"RateLimitDecision",
|
||||
"DeviceRepository",
|
||||
]
|
||||
|
||||
|
||||
class RateLimitDecision(Protocol):
|
||||
allowed: bool
|
||||
retry_after: Optional[float]
|
||||
|
||||
|
||||
class RateLimiter(Protocol):
|
||||
def check(self, key: str, max_requests: int, window_seconds: float) -> RateLimitDecision: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class JWTDecoder(Protocol):
|
||||
def decode(self, token: str) -> Mapping[str, object]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class DPoPValidator(Protocol):
|
||||
def verify(
|
||||
self,
|
||||
method: str,
|
||||
htu: str,
|
||||
proof: str,
|
||||
access_token: Optional[str] = None,
|
||||
) -> str: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class DPoPVerificationError(Exception):
|
||||
"""Raised when a DPoP proof fails validation."""
|
||||
|
||||
|
||||
class DPoPReplayError(DPoPVerificationError):
|
||||
"""Raised when a DPoP proof is replayed."""
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceRecord:
|
||||
"""Snapshot of a device record required for authentication."""
|
||||
|
||||
identity: DeviceIdentity
|
||||
token_version: int
|
||||
status: DeviceStatus
|
||||
|
||||
|
||||
class DeviceRepository(Protocol):
|
||||
"""Port that exposes the minimal device persistence operations."""
|
||||
|
||||
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def recover_missing(
|
||||
self,
|
||||
guid: DeviceGuid,
|
||||
fingerprint: DeviceFingerprint,
|
||||
token_version: int,
|
||||
service_context: Optional[str],
|
||||
) -> Optional[DeviceRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class DeviceAuthService:
|
||||
"""Authenticate devices using access tokens, repositories, and DPoP proofs."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
device_repository: DeviceRepository,
|
||||
jwt_service: JWTDecoder,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
rate_limiter: Optional[RateLimiter] = None,
|
||||
dpop_validator: Optional[DPoPValidator] = None,
|
||||
) -> None:
|
||||
self._repository = device_repository
|
||||
self._jwt = jwt_service
|
||||
self._log = logger or logging.getLogger("borealis.engine.auth")
|
||||
self._rate_limiter = rate_limiter
|
||||
self._dpop_validator = dpop_validator
|
||||
|
||||
def authenticate(self, request: DeviceAuthRequest, *, path: str) -> DeviceAuthContext:
|
||||
"""Authenticate an access token and return the resulting context."""
|
||||
|
||||
claims = self._decode_claims(request.access_token)
|
||||
rate_limit_key = f"fp:{claims.fingerprint.value}"
|
||||
if self._rate_limiter is not None:
|
||||
decision = self._rate_limiter.check(rate_limit_key, 60, 60.0)
|
||||
if not decision.allowed:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.RATE_LIMITED,
|
||||
http_status=429,
|
||||
retry_after=decision.retry_after,
|
||||
)
|
||||
|
||||
record = self._repository.fetch_by_guid(claims.guid)
|
||||
if record is None:
|
||||
record = self._repository.recover_missing(
|
||||
claims.guid,
|
||||
claims.fingerprint,
|
||||
claims.token_version,
|
||||
request.service_context,
|
||||
)
|
||||
|
||||
if record is None:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DEVICE_NOT_FOUND,
|
||||
http_status=403,
|
||||
)
|
||||
|
||||
self._validate_identity(record, claims)
|
||||
|
||||
dpop_jkt = self._validate_dpop(request, record, claims)
|
||||
|
||||
context = DeviceAuthContext(
|
||||
identity=record.identity,
|
||||
access_token=request.access_token,
|
||||
claims=claims,
|
||||
status=record.status,
|
||||
service_context=request.service_context,
|
||||
dpop_jkt=dpop_jkt,
|
||||
)
|
||||
|
||||
if context.is_quarantined:
|
||||
self._log.warning(
|
||||
"device %s is quarantined; limited access for %s",
|
||||
record.identity.guid,
|
||||
path,
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
def _decode_claims(self, token: str) -> AccessTokenClaims:
|
||||
try:
|
||||
raw_claims = self._jwt.decode(token)
|
||||
except Exception as exc: # pragma: no cover - defensive fallback
|
||||
if self._is_expired_signature(exc):
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.TOKEN_EXPIRED) from exc
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_TOKEN) from exc
|
||||
|
||||
try:
|
||||
return AccessTokenClaims.from_mapping(raw_claims)
|
||||
except Exception as exc:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.INVALID_CLAIMS) from exc
|
||||
|
||||
@staticmethod
|
||||
def _is_expired_signature(exc: Exception) -> bool:
|
||||
name = exc.__class__.__name__
|
||||
return name == "ExpiredSignatureError"
|
||||
|
||||
def _validate_identity(
|
||||
self,
|
||||
record: DeviceRecord,
|
||||
claims: AccessTokenClaims,
|
||||
) -> None:
|
||||
if record.identity.guid.value != claims.guid.value:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DEVICE_GUID_MISMATCH,
|
||||
http_status=403,
|
||||
)
|
||||
|
||||
if record.identity.fingerprint.value:
|
||||
if record.identity.fingerprint.value != claims.fingerprint.value:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.FINGERPRINT_MISMATCH,
|
||||
http_status=403,
|
||||
)
|
||||
|
||||
if record.token_version > claims.token_version:
|
||||
raise DeviceAuthFailure(DeviceAuthErrorCode.TOKEN_VERSION_REVOKED)
|
||||
|
||||
if not record.status.allows_access:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DEVICE_REVOKED,
|
||||
http_status=403,
|
||||
)
|
||||
|
||||
def _validate_dpop(
|
||||
self,
|
||||
request: DeviceAuthRequest,
|
||||
record: DeviceRecord,
|
||||
claims: AccessTokenClaims,
|
||||
) -> Optional[str]:
|
||||
if not request.dpop_proof:
|
||||
return None
|
||||
|
||||
if self._dpop_validator is None:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DPOP_NOT_SUPPORTED,
|
||||
http_status=400,
|
||||
)
|
||||
|
||||
try:
|
||||
return self._dpop_validator.verify(
|
||||
request.http_method,
|
||||
request.htu,
|
||||
request.dpop_proof,
|
||||
request.access_token,
|
||||
)
|
||||
except DPoPReplayError as exc:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DPOP_REPLAYED,
|
||||
http_status=400,
|
||||
) from exc
|
||||
except DPoPVerificationError as exc:
|
||||
raise DeviceAuthFailure(
|
||||
DeviceAuthErrorCode.DPOP_INVALID,
|
||||
http_status=400,
|
||||
) from exc
|
||||
@@ -1,105 +0,0 @@
|
||||
"""DPoP proof validation for Engine services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict, Optional
|
||||
|
||||
import jwt
|
||||
|
||||
__all__ = ["DPoPValidator", "DPoPVerificationError", "DPoPReplayError"]
|
||||
|
||||
|
||||
_DP0P_MAX_SKEW = 300.0
|
||||
|
||||
|
||||
class DPoPVerificationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DPoPReplayError(DPoPVerificationError):
|
||||
pass
|
||||
|
||||
|
||||
class DPoPValidator:
|
||||
def __init__(self) -> None:
|
||||
self._observed_jti: Dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def verify(
|
||||
self,
|
||||
method: str,
|
||||
htu: str,
|
||||
proof: str,
|
||||
access_token: Optional[str] = None,
|
||||
) -> str:
|
||||
if not proof:
|
||||
raise DPoPVerificationError("DPoP proof missing")
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(proof)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP header") from exc
|
||||
|
||||
jwk = header.get("jwk")
|
||||
alg = header.get("alg")
|
||||
if not jwk or not isinstance(jwk, dict):
|
||||
raise DPoPVerificationError("missing jwk in DPoP header")
|
||||
if alg not in ("EdDSA", "ES256", "ES384", "ES512"):
|
||||
raise DPoPVerificationError(f"unsupported DPoP alg {alg}")
|
||||
|
||||
try:
|
||||
key = jwt.PyJWK(jwk)
|
||||
public_key = key.key
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid jwk in DPoP header") from exc
|
||||
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
proof,
|
||||
public_key,
|
||||
algorithms=[alg],
|
||||
options={"require": ["htm", "htu", "jti", "iat"]},
|
||||
)
|
||||
except Exception as exc:
|
||||
raise DPoPVerificationError("invalid DPoP signature") from exc
|
||||
|
||||
htm = claims.get("htm")
|
||||
proof_htu = claims.get("htu")
|
||||
jti = claims.get("jti")
|
||||
iat = claims.get("iat")
|
||||
ath = claims.get("ath")
|
||||
|
||||
if not isinstance(htm, str) or htm.lower() != method.lower():
|
||||
raise DPoPVerificationError("DPoP htm mismatch")
|
||||
if not isinstance(proof_htu, str) or proof_htu != htu:
|
||||
raise DPoPVerificationError("DPoP htu mismatch")
|
||||
if not isinstance(jti, str):
|
||||
raise DPoPVerificationError("DPoP jti missing")
|
||||
if not isinstance(iat, (int, float)):
|
||||
raise DPoPVerificationError("DPoP iat missing")
|
||||
|
||||
now = time.time()
|
||||
if abs(now - float(iat)) > _DP0P_MAX_SKEW:
|
||||
raise DPoPVerificationError("DPoP proof outside allowed skew")
|
||||
|
||||
if ath and access_token:
|
||||
expected_ath = jwt.utils.base64url_encode(
|
||||
hashlib.sha256(access_token.encode("utf-8")).digest()
|
||||
).decode("ascii")
|
||||
if expected_ath != ath:
|
||||
raise DPoPVerificationError("DPoP ath mismatch")
|
||||
|
||||
with self._lock:
|
||||
expiry = self._observed_jti.get(jti)
|
||||
if expiry and expiry > now:
|
||||
raise DPoPReplayError("DPoP proof replay detected")
|
||||
self._observed_jti[jti] = now + _DP0P_MAX_SKEW
|
||||
stale = [key for key, exp in self._observed_jti.items() if exp <= now]
|
||||
for key in stale:
|
||||
self._observed_jti.pop(key, None)
|
||||
|
||||
thumbprint = jwt.PyJWK(jwk).thumbprint()
|
||||
return thumbprint.decode("ascii")
|
||||
@@ -1,124 +0,0 @@
|
||||
"""JWT issuance utilities for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from Data.Engine.runtime import ensure_runtime_dir, runtime_path
|
||||
|
||||
__all__ = ["JWTService", "load_service"]
|
||||
|
||||
|
||||
_KEY_DIR = runtime_path("auth_keys")
|
||||
_KEY_FILE = _KEY_DIR / "engine-jwt-ed25519.key"
|
||||
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-jwt-ed25519.key"
|
||||
|
||||
|
||||
class JWTService:
|
||||
def __init__(self, private_key: ed25519.Ed25519PrivateKey, key_id: str) -> None:
|
||||
self._private_key = private_key
|
||||
self._public_key = private_key.public_key()
|
||||
self._key_id = key_id
|
||||
|
||||
@property
|
||||
def key_id(self) -> str:
|
||||
return self._key_id
|
||||
|
||||
def issue_access_token(
|
||||
self,
|
||||
guid: str,
|
||||
ssl_key_fingerprint: str,
|
||||
token_version: int,
|
||||
expires_in: int = 900,
|
||||
extra_claims: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
now = int(time.time())
|
||||
payload: Dict[str, Any] = {
|
||||
"sub": f"device:{guid}",
|
||||
"guid": guid,
|
||||
"ssl_key_fingerprint": ssl_key_fingerprint,
|
||||
"token_version": int(token_version),
|
||||
"iat": now,
|
||||
"nbf": now,
|
||||
"exp": now + int(expires_in),
|
||||
}
|
||||
if extra_claims:
|
||||
payload.update(extra_claims)
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
self._private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
),
|
||||
algorithm="EdDSA",
|
||||
headers={"kid": self._key_id},
|
||||
)
|
||||
return token
|
||||
|
||||
def decode(self, token: str, *, audience: Optional[str] = None) -> Dict[str, Any]:
|
||||
options = {"require": ["exp", "iat", "sub"]}
|
||||
public_pem = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
return jwt.decode(
|
||||
token,
|
||||
public_pem,
|
||||
algorithms=["EdDSA"],
|
||||
audience=audience,
|
||||
options=options,
|
||||
)
|
||||
|
||||
def public_jwk(self) -> Dict[str, Any]:
|
||||
public_bytes = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw,
|
||||
)
|
||||
jwk_x = jwt.utils.base64url_encode(public_bytes).decode("ascii")
|
||||
return {"kty": "OKP", "crv": "Ed25519", "kid": self._key_id, "alg": "EdDSA", "use": "sig", "x": jwk_x}
|
||||
|
||||
|
||||
def load_service() -> JWTService:
|
||||
private_key = _load_or_create_private_key()
|
||||
public_bytes = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
key_id = hashlib.sha256(public_bytes).hexdigest()[:16]
|
||||
return JWTService(private_key, key_id)
|
||||
|
||||
|
||||
def _load_or_create_private_key() -> ed25519.Ed25519PrivateKey:
|
||||
ensure_runtime_dir("auth_keys")
|
||||
|
||||
if _KEY_FILE.exists():
|
||||
with _KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
if _LEGACY_KEY_FILE.exists():
|
||||
with _LEGACY_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
_KEY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
with _KEY_FILE.open("wb") as fh:
|
||||
fh.write(pem)
|
||||
try:
|
||||
if hasattr(_KEY_FILE, "chmod"):
|
||||
_KEY_FILE.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
return private_key
|
||||
@@ -1,211 +0,0 @@
|
||||
"""Operator account management service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from Data.Engine.domain import OperatorAccount
|
||||
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
|
||||
|
||||
|
||||
class OperatorAccountError(Exception):
|
||||
"""Base class for operator account management failures."""
|
||||
|
||||
|
||||
class UsernameAlreadyExistsError(OperatorAccountError):
|
||||
"""Raised when attempting to create an operator with a duplicate username."""
|
||||
|
||||
|
||||
class AccountNotFoundError(OperatorAccountError):
|
||||
"""Raised when the requested operator account cannot be located."""
|
||||
|
||||
|
||||
class LastAdminError(OperatorAccountError):
|
||||
"""Raised when attempting to demote or delete the last remaining admin."""
|
||||
|
||||
|
||||
class LastUserError(OperatorAccountError):
|
||||
"""Raised when attempting to delete the final operator account."""
|
||||
|
||||
|
||||
class CannotModifySelfError(OperatorAccountError):
|
||||
"""Raised when the caller attempts to delete themselves."""
|
||||
|
||||
|
||||
class InvalidRoleError(OperatorAccountError):
|
||||
"""Raised when a role value is invalid."""
|
||||
|
||||
|
||||
class InvalidPasswordHashError(OperatorAccountError):
|
||||
"""Raised when a password hash is malformed."""
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class OperatorAccountRecord:
|
||||
username: str
|
||||
display_name: str
|
||||
role: str
|
||||
last_login: int
|
||||
created_at: int
|
||||
updated_at: int
|
||||
mfa_enabled: bool
|
||||
|
||||
|
||||
class OperatorAccountService:
|
||||
"""High-level operations for managing operator accounts."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repository: SQLiteUserRepository,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.operator_accounts")
|
||||
|
||||
def list_accounts(self) -> list[OperatorAccountRecord]:
|
||||
return [_to_record(account) for account in self._repository.list_accounts()]
|
||||
|
||||
def create_account(
|
||||
self,
|
||||
*,
|
||||
username: str,
|
||||
password_sha512: str,
|
||||
role: str,
|
||||
display_name: Optional[str] = None,
|
||||
) -> OperatorAccountRecord:
|
||||
normalized_role = self._normalize_role(role)
|
||||
username = (username or "").strip()
|
||||
password_sha512 = (password_sha512 or "").strip().lower()
|
||||
display_name = (display_name or username or "").strip()
|
||||
|
||||
if not username or not password_sha512:
|
||||
raise InvalidPasswordHashError("username and password are required")
|
||||
if len(password_sha512) != 128:
|
||||
raise InvalidPasswordHashError("password hash must be 128 hex characters")
|
||||
|
||||
now = int(time.time())
|
||||
try:
|
||||
self._repository.create_account(
|
||||
username=username,
|
||||
display_name=display_name or username,
|
||||
password_sha512=password_sha512,
|
||||
role=normalized_role,
|
||||
timestamp=now,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - sqlite integrity errors are deterministic
|
||||
import sqlite3
|
||||
|
||||
if isinstance(exc, sqlite3.IntegrityError):
|
||||
raise UsernameAlreadyExistsError("username already exists") from exc
|
||||
raise
|
||||
|
||||
account = self._repository.fetch_by_username(username)
|
||||
if not account: # pragma: no cover - sanity guard
|
||||
raise AccountNotFoundError("account creation failed")
|
||||
return _to_record(account)
|
||||
|
||||
def delete_account(self, username: str, *, actor: Optional[str] = None) -> None:
|
||||
username = (username or "").strip()
|
||||
if not username:
|
||||
raise AccountNotFoundError("invalid username")
|
||||
|
||||
if actor and actor.strip().lower() == username.lower():
|
||||
raise CannotModifySelfError("cannot delete yourself")
|
||||
|
||||
total_accounts = self._repository.count_accounts()
|
||||
if total_accounts <= 1:
|
||||
raise LastUserError("cannot delete the last user")
|
||||
|
||||
target = self._repository.fetch_by_username(username)
|
||||
if not target:
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
if target.role.lower() == "admin" and self._repository.count_admins() <= 1:
|
||||
raise LastAdminError("cannot delete the last admin")
|
||||
|
||||
if not self._repository.delete_account(username):
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
def reset_password(self, username: str, password_sha512: str) -> None:
|
||||
username = (username or "").strip()
|
||||
password_sha512 = (password_sha512 or "").strip().lower()
|
||||
if len(password_sha512) != 128:
|
||||
raise InvalidPasswordHashError("invalid password hash")
|
||||
|
||||
now = int(time.time())
|
||||
if not self._repository.update_password(username, password_sha512, timestamp=now):
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
def change_role(self, username: str, role: str, *, actor: Optional[str] = None) -> OperatorAccountRecord:
|
||||
username = (username or "").strip()
|
||||
normalized_role = self._normalize_role(role)
|
||||
|
||||
account = self._repository.fetch_by_username(username)
|
||||
if not account:
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
if account.role.lower() == "admin" and normalized_role.lower() != "admin":
|
||||
if self._repository.count_admins() <= 1:
|
||||
raise LastAdminError("cannot demote the last admin")
|
||||
|
||||
now = int(time.time())
|
||||
if not self._repository.update_role(username, normalized_role, timestamp=now):
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
updated = self._repository.fetch_by_username(username)
|
||||
if not updated: # pragma: no cover - guard
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
record = _to_record(updated)
|
||||
if actor and actor.strip().lower() == username.lower():
|
||||
self._log.info("actor-role-updated", extra={"username": username, "role": record.role})
|
||||
return record
|
||||
|
||||
def update_mfa(self, username: str, *, enabled: bool, reset_secret: bool) -> None:
|
||||
username = (username or "").strip()
|
||||
if not username:
|
||||
raise AccountNotFoundError("invalid username")
|
||||
|
||||
now = int(time.time())
|
||||
if not self._repository.update_mfa(username, enabled=enabled, reset_secret=reset_secret, timestamp=now):
|
||||
raise AccountNotFoundError("user not found")
|
||||
|
||||
def fetch_account(self, username: str) -> Optional[OperatorAccountRecord]:
|
||||
account = self._repository.fetch_by_username(username)
|
||||
return _to_record(account) if account else None
|
||||
|
||||
def _normalize_role(self, role: str) -> str:
|
||||
normalized = (role or "").strip().title() or "User"
|
||||
if normalized not in {"User", "Admin"}:
|
||||
raise InvalidRoleError("invalid role")
|
||||
return normalized
|
||||
|
||||
|
||||
def _to_record(account: OperatorAccount) -> OperatorAccountRecord:
|
||||
return OperatorAccountRecord(
|
||||
username=account.username,
|
||||
display_name=account.display_name or account.username,
|
||||
role=account.role or "User",
|
||||
last_login=int(account.last_login or 0),
|
||||
created_at=int(account.created_at or 0),
|
||||
updated_at=int(account.updated_at or 0),
|
||||
mfa_enabled=bool(account.mfa_enabled),
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OperatorAccountService",
|
||||
"OperatorAccountError",
|
||||
"UsernameAlreadyExistsError",
|
||||
"AccountNotFoundError",
|
||||
"LastAdminError",
|
||||
"LastUserError",
|
||||
"CannotModifySelfError",
|
||||
"InvalidRoleError",
|
||||
"InvalidPasswordHashError",
|
||||
"OperatorAccountRecord",
|
||||
]
|
||||
@@ -1,236 +0,0 @@
|
||||
"""Operator authentication service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
try: # pragma: no cover - optional dependencies mirror legacy server behaviour
|
||||
import pyotp # type: ignore
|
||||
except Exception: # pragma: no cover - gracefully degrade when unavailable
|
||||
pyotp = None # type: ignore
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import qrcode # type: ignore
|
||||
except Exception: # pragma: no cover - gracefully degrade when unavailable
|
||||
qrcode = None # type: ignore
|
||||
|
||||
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
|
||||
|
||||
from Data.Engine.builders.operator_auth import (
|
||||
OperatorLoginRequest,
|
||||
OperatorMFAVerificationRequest,
|
||||
)
|
||||
from Data.Engine.domain import (
|
||||
OperatorAccount,
|
||||
OperatorLoginSuccess,
|
||||
OperatorMFAChallenge,
|
||||
)
|
||||
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
|
||||
|
||||
|
||||
class OperatorAuthError(Exception):
|
||||
"""Base class for operator authentication errors."""
|
||||
|
||||
|
||||
class InvalidCredentialsError(OperatorAuthError):
|
||||
"""Raised when username/password verification fails."""
|
||||
|
||||
|
||||
class MFAUnavailableError(OperatorAuthError):
|
||||
"""Raised when MFA functionality is requested but dependencies are missing."""
|
||||
|
||||
|
||||
class InvalidMFACodeError(OperatorAuthError):
|
||||
"""Raised when the submitted MFA code is invalid."""
|
||||
|
||||
|
||||
class MFASessionError(OperatorAuthError):
|
||||
"""Raised when the MFA session state cannot be validated."""
|
||||
|
||||
|
||||
class OperatorAuthService:
|
||||
"""Authenticate operator accounts and manage MFA challenges."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repository: SQLiteUserRepository,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.operator_auth")
|
||||
|
||||
def authenticate(
|
||||
self, request: OperatorLoginRequest
|
||||
) -> OperatorLoginSuccess | OperatorMFAChallenge:
|
||||
account = self._repository.fetch_by_username(request.username)
|
||||
if not account:
|
||||
raise InvalidCredentialsError("invalid username or password")
|
||||
|
||||
if not self._password_matches(account, request.password_sha512):
|
||||
raise InvalidCredentialsError("invalid username or password")
|
||||
|
||||
if not account.mfa_enabled:
|
||||
return self._finalize_login(account)
|
||||
|
||||
stage = "verify" if account.mfa_secret else "setup"
|
||||
return self._build_mfa_challenge(account, stage)
|
||||
|
||||
def verify_mfa(
|
||||
self,
|
||||
challenge: OperatorMFAChallenge,
|
||||
request: OperatorMFAVerificationRequest,
|
||||
) -> OperatorLoginSuccess:
|
||||
now = int(time.time())
|
||||
if challenge.pending_token != request.pending_token:
|
||||
raise MFASessionError("invalid_session")
|
||||
if challenge.expires_at < now:
|
||||
raise MFASessionError("expired")
|
||||
|
||||
if challenge.stage == "setup":
|
||||
secret = (challenge.secret or "").strip()
|
||||
if not secret:
|
||||
raise MFASessionError("mfa_not_configured")
|
||||
totp = self._totp_for_secret(secret)
|
||||
if not totp.verify(request.code, valid_window=1):
|
||||
raise InvalidMFACodeError("invalid_code")
|
||||
self._repository.store_mfa_secret(challenge.username, secret, timestamp=now)
|
||||
else:
|
||||
account = self._repository.fetch_by_username(challenge.username)
|
||||
if not account or not account.mfa_secret:
|
||||
raise MFASessionError("mfa_not_configured")
|
||||
totp = self._totp_for_secret(account.mfa_secret)
|
||||
if not totp.verify(request.code, valid_window=1):
|
||||
raise InvalidMFACodeError("invalid_code")
|
||||
|
||||
account = self._repository.fetch_by_username(challenge.username)
|
||||
if not account:
|
||||
raise InvalidCredentialsError("invalid username or password")
|
||||
return self._finalize_login(account)
|
||||
|
||||
def issue_token(self, username: str, role: str) -> str:
|
||||
serializer = self._token_serializer()
|
||||
payload = {"u": username, "r": role or "User", "ts": int(time.time())}
|
||||
return serializer.dumps(payload)
|
||||
|
||||
def resolve_token(self, token: str, *, max_age: int = 30 * 24 * 3600) -> Optional[OperatorAccount]:
|
||||
"""Return the account associated with *token* if it is valid."""
|
||||
|
||||
token = (token or "").strip()
|
||||
if not token:
|
||||
return None
|
||||
|
||||
serializer = self._token_serializer()
|
||||
try:
|
||||
payload = serializer.loads(token, max_age=max_age)
|
||||
except (BadSignature, SignatureExpired):
|
||||
return None
|
||||
|
||||
username = str(payload.get("u") or "").strip()
|
||||
if not username:
|
||||
return None
|
||||
|
||||
return self._repository.fetch_by_username(username)
|
||||
|
||||
def fetch_account(self, username: str) -> Optional[OperatorAccount]:
|
||||
"""Return the operator account for *username* if it exists."""
|
||||
|
||||
username = (username or "").strip()
|
||||
if not username:
|
||||
return None
|
||||
return self._repository.fetch_by_username(username)
|
||||
|
||||
def _finalize_login(self, account: OperatorAccount) -> OperatorLoginSuccess:
|
||||
now = int(time.time())
|
||||
self._repository.update_last_login(account.username, now)
|
||||
token = self.issue_token(account.username, account.role)
|
||||
return OperatorLoginSuccess(username=account.username, role=account.role, token=token)
|
||||
|
||||
def _password_matches(self, account: OperatorAccount, provided_hash: str) -> bool:
|
||||
expected = (account.password_sha512 or "").strip().lower()
|
||||
candidate = (provided_hash or "").strip().lower()
|
||||
return bool(expected and candidate and expected == candidate)
|
||||
|
||||
def _build_mfa_challenge(
|
||||
self,
|
||||
account: OperatorAccount,
|
||||
stage: str,
|
||||
) -> OperatorMFAChallenge:
|
||||
now = int(time.time())
|
||||
pending_token = uuid.uuid4().hex
|
||||
secret = None
|
||||
otpauth_url = None
|
||||
qr_image = None
|
||||
|
||||
if stage == "setup":
|
||||
secret = self._generate_totp_secret()
|
||||
otpauth_url = self._totp_provisioning_uri(secret, account.username)
|
||||
qr_image = self._totp_qr_data_uri(otpauth_url) if otpauth_url else None
|
||||
|
||||
return OperatorMFAChallenge(
|
||||
username=account.username,
|
||||
role=account.role,
|
||||
stage="verify" if stage == "verify" else "setup",
|
||||
pending_token=pending_token,
|
||||
expires_at=now + 300,
|
||||
secret=secret,
|
||||
otpauth_url=otpauth_url,
|
||||
qr_image=qr_image,
|
||||
)
|
||||
|
||||
def _token_serializer(self) -> URLSafeTimedSerializer:
|
||||
secret = os.getenv("BOREALIS_FLASK_SECRET_KEY") or "change-me"
|
||||
return URLSafeTimedSerializer(secret, salt="borealis-auth")
|
||||
|
||||
def _generate_totp_secret(self) -> str:
|
||||
if not pyotp:
|
||||
raise MFAUnavailableError("pyotp is not installed; MFA unavailable")
|
||||
return pyotp.random_base32() # type: ignore[no-any-return]
|
||||
|
||||
def _totp_for_secret(self, secret: str):
|
||||
if not pyotp:
|
||||
raise MFAUnavailableError("pyotp is not installed; MFA unavailable")
|
||||
normalized = secret.replace(" ", "").strip().upper()
|
||||
if not normalized:
|
||||
raise MFASessionError("mfa_not_configured")
|
||||
return pyotp.TOTP(normalized, digits=6, interval=30)
|
||||
|
||||
def _totp_provisioning_uri(self, secret: str, username: str) -> Optional[str]:
|
||||
try:
|
||||
totp = self._totp_for_secret(secret)
|
||||
except OperatorAuthError:
|
||||
return None
|
||||
issuer = os.getenv("BOREALIS_MFA_ISSUER", "Borealis")
|
||||
try:
|
||||
return totp.provisioning_uri(name=username, issuer_name=issuer)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return None
|
||||
|
||||
def _totp_qr_data_uri(self, payload: str) -> Optional[str]:
|
||||
if not payload or qrcode is None:
|
||||
return None
|
||||
try:
|
||||
img = qrcode.make(payload, box_size=6, border=4)
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="PNG")
|
||||
encoded = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
return f"data:image/png;base64,{encoded}"
|
||||
except Exception: # pragma: no cover - defensive
|
||||
self._log.warning("failed to generate MFA QR code", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OperatorAuthService",
|
||||
"OperatorAuthError",
|
||||
"InvalidCredentialsError",
|
||||
"MFAUnavailableError",
|
||||
"InvalidMFACodeError",
|
||||
"MFASessionError",
|
||||
]
|
||||
@@ -1,190 +0,0 @@
|
||||
"""Token refresh service extracted from the legacy blueprint."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Protocol
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from Data.Engine.builders.device_auth import RefreshTokenRequest
|
||||
from Data.Engine.domain.device_auth import DeviceGuid
|
||||
|
||||
from .device_auth_service import (
|
||||
DeviceRecord,
|
||||
DeviceRepository,
|
||||
DPoPReplayError,
|
||||
DPoPVerificationError,
|
||||
DPoPValidator,
|
||||
)
|
||||
|
||||
__all__ = ["RefreshTokenRecord", "TokenService", "TokenRefreshError", "TokenRefreshErrorCode"]
|
||||
|
||||
|
||||
class JWTIssuer(Protocol):
|
||||
def issue_access_token(self, guid: str, fingerprint: str, token_version: int) -> str: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class TokenRefreshErrorCode(str):
|
||||
INVALID_REFRESH_TOKEN = "invalid_refresh_token"
|
||||
REFRESH_TOKEN_REVOKED = "refresh_token_revoked"
|
||||
REFRESH_TOKEN_EXPIRED = "refresh_token_expired"
|
||||
DEVICE_NOT_FOUND = "device_not_found"
|
||||
DEVICE_REVOKED = "device_revoked"
|
||||
DPOP_REPLAYED = "dpop_replayed"
|
||||
DPOP_INVALID = "dpop_invalid"
|
||||
|
||||
|
||||
class TokenRefreshError(Exception):
|
||||
def __init__(self, code: str, *, http_status: int = 400) -> None:
|
||||
self.code = code
|
||||
self.http_status = http_status
|
||||
super().__init__(code)
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
return {"error": self.code}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RefreshTokenRecord:
|
||||
record_id: str
|
||||
guid: DeviceGuid
|
||||
token_hash: str
|
||||
dpop_jkt: Optional[str]
|
||||
created_at: datetime
|
||||
expires_at: Optional[datetime]
|
||||
revoked_at: Optional[datetime]
|
||||
|
||||
@classmethod
|
||||
def from_row(
|
||||
cls,
|
||||
*,
|
||||
record_id: str,
|
||||
guid: DeviceGuid,
|
||||
token_hash: str,
|
||||
dpop_jkt: Optional[str],
|
||||
created_at: datetime,
|
||||
expires_at: Optional[datetime],
|
||||
revoked_at: Optional[datetime],
|
||||
) -> "RefreshTokenRecord":
|
||||
return cls(
|
||||
record_id=record_id,
|
||||
guid=guid,
|
||||
token_hash=token_hash,
|
||||
dpop_jkt=dpop_jkt,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
revoked_at=revoked_at,
|
||||
)
|
||||
|
||||
|
||||
class RefreshTokenRepository(Protocol):
|
||||
def fetch(self, guid: DeviceGuid, token_hash: str) -> Optional[RefreshTokenRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def clear_dpop_binding(self, record_id: str) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def touch(self, record_id: str, *, last_used_at: datetime, dpop_jkt: Optional[str]) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AccessTokenResponse:
|
||||
access_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
refresh_token_repository: RefreshTokenRepository,
|
||||
device_repository: DeviceRepository,
|
||||
jwt_service: JWTIssuer,
|
||||
dpop_validator: Optional[DPoPValidator] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._refresh_tokens = refresh_token_repository
|
||||
self._devices = device_repository
|
||||
self._jwt = jwt_service
|
||||
self._dpop_validator = dpop_validator
|
||||
self._log = logger or logging.getLogger("borealis.engine.auth")
|
||||
|
||||
def refresh_access_token(
|
||||
self,
|
||||
request: RefreshTokenRequest,
|
||||
) -> AccessTokenResponse:
|
||||
record = self._refresh_tokens.fetch(
|
||||
request.guid,
|
||||
self._hash_token(request.refresh_token),
|
||||
)
|
||||
if record is None:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.INVALID_REFRESH_TOKEN, http_status=401)
|
||||
|
||||
if record.guid.value != request.guid.value:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.INVALID_REFRESH_TOKEN, http_status=401)
|
||||
|
||||
if record.revoked_at is not None:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.REFRESH_TOKEN_REVOKED, http_status=401)
|
||||
|
||||
if record.expires_at is not None and record.expires_at <= self._now():
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.REFRESH_TOKEN_EXPIRED, http_status=401)
|
||||
|
||||
device = self._devices.fetch_by_guid(request.guid)
|
||||
if device is None:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DEVICE_NOT_FOUND, http_status=404)
|
||||
|
||||
if not device.status.allows_access:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DEVICE_REVOKED, http_status=403)
|
||||
|
||||
dpop_jkt = record.dpop_jkt or ""
|
||||
if request.dpop_proof:
|
||||
if self._dpop_validator is None:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DPOP_INVALID)
|
||||
try:
|
||||
dpop_jkt = self._dpop_validator.verify(
|
||||
request.http_method,
|
||||
request.htu,
|
||||
request.dpop_proof,
|
||||
None,
|
||||
)
|
||||
except DPoPReplayError as exc:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DPOP_REPLAYED) from exc
|
||||
except DPoPVerificationError as exc:
|
||||
raise TokenRefreshError(TokenRefreshErrorCode.DPOP_INVALID) from exc
|
||||
elif record.dpop_jkt:
|
||||
self._log.warning(
|
||||
"Clearing stored DPoP binding for guid=%s due to missing proof",
|
||||
request.guid.value,
|
||||
)
|
||||
self._refresh_tokens.clear_dpop_binding(record.record_id)
|
||||
|
||||
access_token = self._jwt.issue_access_token(
|
||||
request.guid.value,
|
||||
device.identity.fingerprint.value,
|
||||
max(device.token_version, 1),
|
||||
)
|
||||
|
||||
self._refresh_tokens.touch(
|
||||
record.record_id,
|
||||
last_used_at=self._now(),
|
||||
dpop_jkt=dpop_jkt or None,
|
||||
)
|
||||
|
||||
return AccessTokenResponse(
|
||||
access_token=access_token,
|
||||
expires_in=900,
|
||||
token_type="Bearer",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _hash_token(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
@@ -1,238 +0,0 @@
|
||||
"""Service container assembly for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from Data.Engine.config import EngineSettings
|
||||
from Data.Engine.integrations.github import GitHubArtifactProvider
|
||||
from Data.Engine.repositories.sqlite import (
|
||||
SQLiteConnectionFactory,
|
||||
SQLiteDeviceRepository,
|
||||
SQLiteDeviceInventoryRepository,
|
||||
SQLiteDeviceViewRepository,
|
||||
SQLiteCredentialRepository,
|
||||
SQLiteEnrollmentRepository,
|
||||
SQLiteGitHubRepository,
|
||||
SQLiteJobRepository,
|
||||
SQLiteRefreshTokenRepository,
|
||||
SQLiteSiteRepository,
|
||||
SQLiteUserRepository,
|
||||
)
|
||||
from Data.Engine.services.auth import (
|
||||
DeviceAuthService,
|
||||
DPoPValidator,
|
||||
OperatorAccountService,
|
||||
OperatorAuthService,
|
||||
JWTService,
|
||||
TokenService,
|
||||
load_jwt_service,
|
||||
)
|
||||
from Data.Engine.services.crypto.signing import ScriptSigner, load_signer
|
||||
from Data.Engine.services.enrollment import EnrollmentService
|
||||
from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService
|
||||
from Data.Engine.services.enrollment.nonce_cache import NonceCache
|
||||
from Data.Engine.services.devices import DeviceInventoryService
|
||||
from Data.Engine.services.devices import DeviceViewService
|
||||
from Data.Engine.services.credentials import CredentialService
|
||||
from Data.Engine.services.github import GitHubService
|
||||
from Data.Engine.services.jobs import SchedulerService
|
||||
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
|
||||
from Data.Engine.services.realtime import AgentRealtimeService
|
||||
from Data.Engine.services.sites import SiteService
|
||||
from Data.Engine.services.assemblies import AssemblyService
|
||||
|
||||
__all__ = ["EngineServiceContainer", "build_service_container"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EngineServiceContainer:
|
||||
device_auth: DeviceAuthService
|
||||
device_inventory: DeviceInventoryService
|
||||
device_view_service: DeviceViewService
|
||||
credential_service: CredentialService
|
||||
token_service: TokenService
|
||||
enrollment_service: EnrollmentService
|
||||
enrollment_admin_service: EnrollmentAdminService
|
||||
site_service: SiteService
|
||||
jwt_service: JWTService
|
||||
dpop_validator: DPoPValidator
|
||||
agent_realtime: AgentRealtimeService
|
||||
scheduler_service: SchedulerService
|
||||
github_service: GitHubService
|
||||
operator_auth_service: OperatorAuthService
|
||||
operator_account_service: OperatorAccountService
|
||||
assembly_service: AssemblyService
|
||||
script_signer: Optional[ScriptSigner]
|
||||
|
||||
|
||||
def build_service_container(
|
||||
settings: EngineSettings,
|
||||
*,
|
||||
db_factory: SQLiteConnectionFactory,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> EngineServiceContainer:
|
||||
log = logger or logging.getLogger("borealis.engine.services")
|
||||
|
||||
device_repo = SQLiteDeviceRepository(db_factory, logger=log.getChild("devices"))
|
||||
device_inventory_repo = SQLiteDeviceInventoryRepository(
|
||||
db_factory, logger=log.getChild("devices.inventory")
|
||||
)
|
||||
device_view_repo = SQLiteDeviceViewRepository(
|
||||
db_factory, logger=log.getChild("devices.views")
|
||||
)
|
||||
credential_repo = SQLiteCredentialRepository(
|
||||
db_factory, logger=log.getChild("credentials.repo")
|
||||
)
|
||||
token_repo = SQLiteRefreshTokenRepository(db_factory, logger=log.getChild("tokens"))
|
||||
enrollment_repo = SQLiteEnrollmentRepository(db_factory, logger=log.getChild("enrollment"))
|
||||
job_repo = SQLiteJobRepository(db_factory, logger=log.getChild("jobs"))
|
||||
github_repo = SQLiteGitHubRepository(db_factory, logger=log.getChild("github_repo"))
|
||||
site_repo = SQLiteSiteRepository(db_factory, logger=log.getChild("sites.repo"))
|
||||
user_repo = SQLiteUserRepository(db_factory, logger=log.getChild("users"))
|
||||
|
||||
jwt_service = load_jwt_service()
|
||||
dpop_validator = DPoPValidator()
|
||||
rate_limiter = SlidingWindowRateLimiter()
|
||||
|
||||
token_service = TokenService(
|
||||
refresh_token_repository=token_repo,
|
||||
device_repository=device_repo,
|
||||
jwt_service=jwt_service,
|
||||
dpop_validator=dpop_validator,
|
||||
logger=log.getChild("token_service"),
|
||||
)
|
||||
|
||||
script_signer = _load_script_signer(log)
|
||||
|
||||
enrollment_service = EnrollmentService(
|
||||
device_repository=device_repo,
|
||||
enrollment_repository=enrollment_repo,
|
||||
token_repository=token_repo,
|
||||
jwt_service=jwt_service,
|
||||
tls_bundle_loader=_tls_bundle_loader(settings),
|
||||
ip_rate_limiter=SlidingWindowRateLimiter(),
|
||||
fingerprint_rate_limiter=SlidingWindowRateLimiter(),
|
||||
nonce_cache=NonceCache(),
|
||||
script_signer=script_signer,
|
||||
logger=log.getChild("enrollment"),
|
||||
)
|
||||
|
||||
enrollment_admin_service = EnrollmentAdminService(
|
||||
repository=enrollment_repo,
|
||||
user_repository=user_repo,
|
||||
logger=log.getChild("enrollment_admin"),
|
||||
)
|
||||
|
||||
device_auth = DeviceAuthService(
|
||||
device_repository=device_repo,
|
||||
jwt_service=jwt_service,
|
||||
logger=log.getChild("device_auth"),
|
||||
rate_limiter=rate_limiter,
|
||||
dpop_validator=dpop_validator,
|
||||
)
|
||||
|
||||
agent_realtime = AgentRealtimeService(
|
||||
device_repository=device_repo,
|
||||
logger=log.getChild("agent_realtime"),
|
||||
)
|
||||
|
||||
scheduler_service = SchedulerService(
|
||||
job_repository=job_repo,
|
||||
assemblies_root=settings.project_root / "Assemblies",
|
||||
logger=log.getChild("scheduler"),
|
||||
)
|
||||
|
||||
operator_auth_service = OperatorAuthService(
|
||||
repository=user_repo,
|
||||
logger=log.getChild("operator_auth"),
|
||||
)
|
||||
operator_account_service = OperatorAccountService(
|
||||
repository=user_repo,
|
||||
logger=log.getChild("operator_accounts"),
|
||||
)
|
||||
device_inventory = DeviceInventoryService(
|
||||
repository=device_inventory_repo,
|
||||
logger=log.getChild("device_inventory"),
|
||||
)
|
||||
device_view_service = DeviceViewService(
|
||||
repository=device_view_repo,
|
||||
logger=log.getChild("device_views"),
|
||||
)
|
||||
credential_service = CredentialService(
|
||||
repository=credential_repo,
|
||||
logger=log.getChild("credentials"),
|
||||
)
|
||||
site_service = SiteService(
|
||||
repository=site_repo,
|
||||
logger=log.getChild("sites"),
|
||||
)
|
||||
|
||||
assembly_service = AssemblyService(
|
||||
root=settings.project_root / "Assemblies",
|
||||
logger=log.getChild("assemblies"),
|
||||
)
|
||||
|
||||
github_provider = GitHubArtifactProvider(
|
||||
cache_file=settings.github.cache_file,
|
||||
default_repo=settings.github.default_repo,
|
||||
default_branch=settings.github.default_branch,
|
||||
refresh_interval=settings.github.refresh_interval_seconds,
|
||||
logger=log.getChild("github.provider"),
|
||||
)
|
||||
github_service = GitHubService(
|
||||
repository=github_repo,
|
||||
provider=github_provider,
|
||||
logger=log.getChild("github"),
|
||||
)
|
||||
github_service.start_background_refresh()
|
||||
|
||||
return EngineServiceContainer(
|
||||
device_auth=device_auth,
|
||||
token_service=token_service,
|
||||
enrollment_service=enrollment_service,
|
||||
enrollment_admin_service=enrollment_admin_service,
|
||||
jwt_service=jwt_service,
|
||||
dpop_validator=dpop_validator,
|
||||
agent_realtime=agent_realtime,
|
||||
scheduler_service=scheduler_service,
|
||||
github_service=github_service,
|
||||
operator_auth_service=operator_auth_service,
|
||||
operator_account_service=operator_account_service,
|
||||
device_inventory=device_inventory,
|
||||
device_view_service=device_view_service,
|
||||
credential_service=credential_service,
|
||||
site_service=site_service,
|
||||
assembly_service=assembly_service,
|
||||
script_signer=script_signer,
|
||||
)
|
||||
|
||||
|
||||
def _tls_bundle_loader(settings: EngineSettings) -> Callable[[], str]:
|
||||
candidates = [
|
||||
Path(os.getenv("BOREALIS_TLS_BUNDLE", "")),
|
||||
settings.project_root / "Certificates" / "Server" / "borealis-server-bundle.pem",
|
||||
]
|
||||
|
||||
def loader() -> str:
|
||||
for candidate in candidates:
|
||||
if candidate and candidate.is_file():
|
||||
try:
|
||||
return candidate.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
continue
|
||||
return ""
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
def _load_script_signer(logger: logging.Logger) -> Optional[ScriptSigner]:
|
||||
try:
|
||||
return load_signer()
|
||||
except Exception as exc:
|
||||
logger.warning("script signer unavailable: %s", exc)
|
||||
return None
|
||||
@@ -1,3 +0,0 @@
|
||||
from .credential_service import CredentialService
|
||||
|
||||
__all__ = ["CredentialService"]
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Expose read access to stored credentials."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from Data.Engine.repositories.sqlite.credential_repository import SQLiteCredentialRepository
|
||||
|
||||
__all__ = ["CredentialService"]
|
||||
|
||||
|
||||
class CredentialService:
|
||||
def __init__(
|
||||
self,
|
||||
repository: SQLiteCredentialRepository,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._repo = repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.credentials")
|
||||
|
||||
def list_credentials(
|
||||
self,
|
||||
*,
|
||||
site_id: Optional[int] = None,
|
||||
connection_type: Optional[str] = None,
|
||||
) -> List[dict]:
|
||||
return self._repo.list_credentials(site_id=site_id, connection_type=connection_type)
|
||||
@@ -1,366 +0,0 @@
|
||||
"""Server TLS certificate management for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import ssl
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
|
||||
|
||||
from Data.Engine.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path
|
||||
|
||||
__all__ = [
|
||||
"build_ssl_context",
|
||||
"certificate_paths",
|
||||
"ensure_certificate",
|
||||
]
|
||||
|
||||
_CERT_DIR = server_certificates_path()
|
||||
_CERT_FILE = _CERT_DIR / "borealis-server-cert.pem"
|
||||
_KEY_FILE = _CERT_DIR / "borealis-server-key.pem"
|
||||
_BUNDLE_FILE = _CERT_DIR / "borealis-server-bundle.pem"
|
||||
_CA_KEY_FILE = _CERT_DIR / "borealis-root-ca-key.pem"
|
||||
_CA_CERT_FILE = _CERT_DIR / "borealis-root-ca.pem"
|
||||
|
||||
_LEGACY_CERT_DIR = runtime_path("certs")
|
||||
_LEGACY_CERT_FILE = _LEGACY_CERT_DIR / "borealis-server-cert.pem"
|
||||
_LEGACY_KEY_FILE = _LEGACY_CERT_DIR / "borealis-server-key.pem"
|
||||
_LEGACY_BUNDLE_FILE = _LEGACY_CERT_DIR / "borealis-server-bundle.pem"
|
||||
|
||||
_ROOT_COMMON_NAME = "Borealis Root CA"
|
||||
_ORG_NAME = "Borealis"
|
||||
_ROOT_VALIDITY = timedelta(days=365 * 100)
|
||||
_SERVER_VALIDITY = timedelta(days=365 * 5)
|
||||
|
||||
|
||||
def ensure_certificate(common_name: str = "Borealis Engine") -> Tuple[Path, Path, Path]:
|
||||
"""Ensure the root CA, server certificate, and bundle exist on disk."""
|
||||
|
||||
ensure_server_certificates_dir()
|
||||
_migrate_legacy_material_if_present()
|
||||
|
||||
ca_key, ca_cert, ca_regenerated = _ensure_root_ca()
|
||||
|
||||
server_cert = _load_certificate(_CERT_FILE)
|
||||
needs_regen = ca_regenerated or _server_certificate_needs_regeneration(server_cert, ca_cert)
|
||||
if needs_regen:
|
||||
server_cert = _generate_server_certificate(common_name, ca_key, ca_cert)
|
||||
|
||||
if server_cert is None:
|
||||
server_cert = _generate_server_certificate(common_name, ca_key, ca_cert)
|
||||
|
||||
_write_bundle(server_cert, ca_cert)
|
||||
|
||||
return _CERT_FILE, _KEY_FILE, _BUNDLE_FILE
|
||||
|
||||
|
||||
def _migrate_legacy_material_if_present() -> None:
|
||||
if not _CERT_FILE.exists() or not _KEY_FILE.exists():
|
||||
legacy_cert = _LEGACY_CERT_FILE
|
||||
legacy_key = _LEGACY_KEY_FILE
|
||||
if legacy_cert.exists() and legacy_key.exists():
|
||||
try:
|
||||
ensure_server_certificates_dir()
|
||||
if not _CERT_FILE.exists():
|
||||
_safe_copy(legacy_cert, _CERT_FILE)
|
||||
if not _KEY_FILE.exists():
|
||||
_safe_copy(legacy_key, _KEY_FILE)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _ensure_root_ca() -> Tuple[ec.EllipticCurvePrivateKey, x509.Certificate, bool]:
|
||||
regenerated = False
|
||||
|
||||
ca_key: Optional[ec.EllipticCurvePrivateKey] = None
|
||||
ca_cert: Optional[x509.Certificate] = None
|
||||
|
||||
if _CA_KEY_FILE.exists() and _CA_CERT_FILE.exists():
|
||||
try:
|
||||
ca_key = _load_private_key(_CA_KEY_FILE)
|
||||
ca_cert = _load_certificate(_CA_CERT_FILE)
|
||||
if ca_cert is not None and ca_key is not None:
|
||||
expiry = _cert_not_after(ca_cert)
|
||||
subject = ca_cert.subject
|
||||
subject_cn = ""
|
||||
try:
|
||||
subject_cn = subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value # type: ignore[index]
|
||||
except Exception:
|
||||
subject_cn = ""
|
||||
try:
|
||||
basic = ca_cert.extensions.get_extension_for_class(x509.BasicConstraints).value # type: ignore[attr-defined]
|
||||
is_ca = bool(basic.ca)
|
||||
except Exception:
|
||||
is_ca = False
|
||||
if (
|
||||
expiry <= datetime.now(tz=timezone.utc)
|
||||
or not is_ca
|
||||
or subject_cn != _ROOT_COMMON_NAME
|
||||
):
|
||||
regenerated = True
|
||||
else:
|
||||
regenerated = True
|
||||
except Exception:
|
||||
regenerated = True
|
||||
else:
|
||||
regenerated = True
|
||||
|
||||
if regenerated or ca_key is None or ca_cert is None:
|
||||
ca_key = ec.generate_private_key(ec.SECP384R1())
|
||||
public_key = ca_key.public_key()
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
builder = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, _ROOT_COMMON_NAME),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
|
||||
]
|
||||
)
|
||||
)
|
||||
.issuer_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, _ROOT_COMMON_NAME),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
|
||||
]
|
||||
)
|
||||
)
|
||||
.public_key(public_key)
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now - timedelta(minutes=5))
|
||||
.not_valid_after(now + _ROOT_VALIDITY)
|
||||
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
content_commitment=False,
|
||||
key_encipherment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
key_cert_sign=True,
|
||||
crl_sign=True,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.SubjectKeyIdentifier.from_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
)
|
||||
|
||||
builder = builder.add_extension(
|
||||
x509.AuthorityKeyIdentifier.from_issuer_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
|
||||
ca_cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA384())
|
||||
|
||||
_CA_KEY_FILE.write_bytes(
|
||||
ca_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
)
|
||||
_CA_CERT_FILE.write_bytes(ca_cert.public_bytes(serialization.Encoding.PEM))
|
||||
|
||||
_tighten_permissions(_CA_KEY_FILE)
|
||||
_tighten_permissions(_CA_CERT_FILE)
|
||||
else:
|
||||
regenerated = False
|
||||
|
||||
return ca_key, ca_cert, regenerated
|
||||
|
||||
|
||||
def _server_certificate_needs_regeneration(
|
||||
server_cert: Optional[x509.Certificate],
|
||||
ca_cert: x509.Certificate,
|
||||
) -> bool:
|
||||
if server_cert is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
if server_cert.issuer != ca_cert.subject:
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
try:
|
||||
expiry = _cert_not_after(server_cert)
|
||||
if expiry <= datetime.now(tz=timezone.utc):
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
try:
|
||||
basic = server_cert.extensions.get_extension_for_class(x509.BasicConstraints).value # type: ignore[attr-defined]
|
||||
if basic.ca:
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
try:
|
||||
eku = server_cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value # type: ignore[attr-defined]
|
||||
if ExtendedKeyUsageOID.SERVER_AUTH not in eku:
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _generate_server_certificate(
|
||||
common_name: str,
|
||||
ca_key: ec.EllipticCurvePrivateKey,
|
||||
ca_cert: x509.Certificate,
|
||||
) -> x509.Certificate:
|
||||
private_key = ec.generate_private_key(ec.SECP384R1())
|
||||
public_key = private_key.public_key()
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
ca_expiry = _cert_not_after(ca_cert)
|
||||
candidate_expiry = now + _SERVER_VALIDITY
|
||||
not_after = min(ca_expiry - timedelta(days=1), candidate_expiry)
|
||||
|
||||
builder = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, _ORG_NAME),
|
||||
]
|
||||
)
|
||||
)
|
||||
.issuer_name(ca_cert.subject)
|
||||
.public_key(public_key)
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now - timedelta(minutes=5))
|
||||
.not_valid_after(not_after)
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName(
|
||||
[
|
||||
x509.DNSName("localhost"),
|
||||
x509.DNSName("127.0.0.1"),
|
||||
x509.DNSName("::1"),
|
||||
]
|
||||
),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
content_commitment=False,
|
||||
key_encipherment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
key_cert_sign=False,
|
||||
crl_sign=False,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.SubjectKeyIdentifier.from_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_key.public_key()),
|
||||
critical=False,
|
||||
)
|
||||
)
|
||||
|
||||
certificate = builder.sign(private_key=ca_key, algorithm=hashes.SHA384())
|
||||
|
||||
_KEY_FILE.write_bytes(
|
||||
private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
)
|
||||
_CERT_FILE.write_bytes(certificate.public_bytes(serialization.Encoding.PEM))
|
||||
|
||||
_tighten_permissions(_KEY_FILE)
|
||||
_tighten_permissions(_CERT_FILE)
|
||||
|
||||
return certificate
|
||||
|
||||
|
||||
def _write_bundle(server_cert: x509.Certificate, ca_cert: x509.Certificate) -> None:
|
||||
try:
|
||||
server_pem = server_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8").strip()
|
||||
ca_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8").strip()
|
||||
except Exception:
|
||||
return
|
||||
|
||||
bundle = f"{server_pem}\n{ca_pem}\n"
|
||||
_BUNDLE_FILE.write_text(bundle, encoding="utf-8")
|
||||
_tighten_permissions(_BUNDLE_FILE)
|
||||
|
||||
|
||||
def _safe_copy(src: Path, dst: Path) -> None:
|
||||
try:
|
||||
dst.write_bytes(src.read_bytes())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _tighten_permissions(path: Path) -> None:
|
||||
try:
|
||||
if os.name == "posix":
|
||||
path.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _load_private_key(path: Path) -> ec.EllipticCurvePrivateKey:
|
||||
with path.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
|
||||
def _load_certificate(path: Path) -> Optional[x509.Certificate]:
|
||||
try:
|
||||
return x509.load_pem_x509_certificate(path.read_bytes())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _cert_not_after(cert: x509.Certificate) -> datetime:
|
||||
try:
|
||||
return cert.not_valid_after_utc # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
value = cert.not_valid_after
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=timezone.utc)
|
||||
return value
|
||||
|
||||
|
||||
def build_ssl_context() -> ssl.SSLContext:
|
||||
cert_path, key_path, bundle_path = ensure_certificate()
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.minimum_version = ssl.TLSVersion.TLSv1_3
|
||||
context.load_cert_chain(certfile=str(bundle_path), keyfile=str(key_path))
|
||||
return context
|
||||
|
||||
|
||||
def certificate_paths() -> Tuple[str, str, str]:
|
||||
cert_path, key_path, bundle_path = ensure_certificate()
|
||||
return str(cert_path), str(key_path), str(bundle_path)
|
||||
@@ -1,75 +0,0 @@
|
||||
"""Script signing utilities for the Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from Data.Engine.integrations.crypto.keys import base64_from_spki_der
|
||||
from Data.Engine.runtime import ensure_server_certificates_dir, runtime_path, server_certificates_path
|
||||
|
||||
__all__ = ["ScriptSigner", "load_signer"]
|
||||
|
||||
|
||||
_KEY_DIR = server_certificates_path("Code-Signing")
|
||||
_SIGNING_KEY_FILE = _KEY_DIR / "engine-script-ed25519.key"
|
||||
_SIGNING_PUB_FILE = _KEY_DIR / "engine-script-ed25519.pub"
|
||||
_LEGACY_KEY_FILE = runtime_path("keys") / "borealis-script-ed25519.key"
|
||||
_LEGACY_PUB_FILE = runtime_path("keys") / "borealis-script-ed25519.pub"
|
||||
|
||||
|
||||
class ScriptSigner:
|
||||
def __init__(self, private_key: ed25519.Ed25519PrivateKey) -> None:
|
||||
self._private = private_key
|
||||
self._public = private_key.public_key()
|
||||
|
||||
def sign(self, payload: bytes) -> bytes:
|
||||
return self._private.sign(payload)
|
||||
|
||||
def public_spki_der(self) -> bytes:
|
||||
return self._public.public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def public_base64_spki(self) -> str:
|
||||
return base64_from_spki_der(self.public_spki_der())
|
||||
|
||||
|
||||
def load_signer() -> ScriptSigner:
|
||||
private_key = _load_or_create()
|
||||
return ScriptSigner(private_key)
|
||||
|
||||
|
||||
def _load_or_create() -> ed25519.Ed25519PrivateKey:
|
||||
ensure_server_certificates_dir("Code-Signing")
|
||||
|
||||
if _SIGNING_KEY_FILE.exists():
|
||||
with _SIGNING_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
if _LEGACY_KEY_FILE.exists():
|
||||
with _LEGACY_KEY_FILE.open("rb") as fh:
|
||||
return serialization.load_pem_private_key(fh.read(), password=None)
|
||||
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
_KEY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
_SIGNING_KEY_FILE.write_bytes(pem)
|
||||
try:
|
||||
if hasattr(_SIGNING_KEY_FILE, "chmod"):
|
||||
_SIGNING_KEY_FILE.chmod(0o600)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
pub_der = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
_SIGNING_PUB_FILE.write_bytes(pub_der)
|
||||
|
||||
return private_key
|
||||
@@ -1,15 +0,0 @@
|
||||
from .device_inventory_service import (
|
||||
DeviceDescriptionError,
|
||||
DeviceDetailsError,
|
||||
DeviceInventoryService,
|
||||
RemoteDeviceError,
|
||||
)
|
||||
from .device_view_service import DeviceViewService
|
||||
|
||||
__all__ = [
|
||||
"DeviceInventoryService",
|
||||
"RemoteDeviceError",
|
||||
"DeviceViewService",
|
||||
"DeviceDetailsError",
|
||||
"DeviceDescriptionError",
|
||||
]
|
||||
@@ -1,575 +0,0 @@
|
||||
"""Mirrors the legacy device inventory HTTP behaviour."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from Data.Engine.repositories.sqlite.device_inventory_repository import (
|
||||
SQLiteDeviceInventoryRepository,
|
||||
)
|
||||
from Data.Engine.domain.device_auth import DeviceAuthContext, normalize_guid
|
||||
from Data.Engine.domain.devices import clean_device_str, coerce_int, ts_to_human
|
||||
|
||||
__all__ = [
|
||||
"DeviceInventoryService",
|
||||
"RemoteDeviceError",
|
||||
"DeviceHeartbeatError",
|
||||
"DeviceDetailsError",
|
||||
"DeviceDescriptionError",
|
||||
]
|
||||
|
||||
|
||||
class RemoteDeviceError(Exception):
|
||||
def __init__(self, code: str, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or code)
|
||||
self.code = code
|
||||
|
||||
|
||||
class DeviceHeartbeatError(Exception):
|
||||
def __init__(self, code: str, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or code)
|
||||
self.code = code
|
||||
|
||||
|
||||
class DeviceDetailsError(Exception):
|
||||
def __init__(self, code: str, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or code)
|
||||
self.code = code
|
||||
|
||||
|
||||
class DeviceDescriptionError(Exception):
|
||||
def __init__(self, code: str, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or code)
|
||||
self.code = code
|
||||
|
||||
|
||||
class DeviceInventoryService:
|
||||
def __init__(
|
||||
self,
|
||||
repository: SQLiteDeviceInventoryRepository,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._repo = repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.devices")
|
||||
|
||||
def list_devices(self) -> List[Dict[str, object]]:
|
||||
return self._repo.fetch_devices()
|
||||
|
||||
def list_agent_devices(self) -> List[Dict[str, object]]:
|
||||
return self._repo.fetch_devices(only_agents=True)
|
||||
|
||||
def list_remote_devices(self, connection_type: str) -> List[Dict[str, object]]:
|
||||
return self._repo.fetch_devices(connection_type=connection_type)
|
||||
|
||||
def get_device_by_guid(self, guid: str) -> Optional[Dict[str, object]]:
|
||||
snapshot = self._repo.load_snapshot(guid=guid)
|
||||
if not snapshot:
|
||||
return None
|
||||
devices = self._repo.fetch_devices(hostname=snapshot.get("hostname"))
|
||||
return devices[0] if devices else None
|
||||
|
||||
def get_device_details(self, hostname: str) -> Dict[str, object]:
|
||||
normalized_host = clean_device_str(hostname)
|
||||
if not normalized_host:
|
||||
return {}
|
||||
|
||||
snapshot = self._repo.load_snapshot(hostname=normalized_host)
|
||||
if not snapshot:
|
||||
return {}
|
||||
|
||||
summary = dict(snapshot.get("summary") or {})
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"details": snapshot.get("details", {}),
|
||||
"summary": summary,
|
||||
"description": snapshot.get("description")
|
||||
or summary.get("description")
|
||||
or "",
|
||||
"created_at": snapshot.get("created_at") or 0,
|
||||
"agent_hash": snapshot.get("agent_hash")
|
||||
or summary.get("agent_hash")
|
||||
or "",
|
||||
"agent_guid": snapshot.get("agent_guid")
|
||||
or summary.get("agent_guid")
|
||||
or "",
|
||||
"memory": snapshot.get("memory", []),
|
||||
"network": snapshot.get("network", []),
|
||||
"software": snapshot.get("software", []),
|
||||
"storage": snapshot.get("storage", []),
|
||||
"cpu": snapshot.get("cpu", {}),
|
||||
"device_type": snapshot.get("device_type")
|
||||
or summary.get("device_type")
|
||||
or "",
|
||||
"domain": snapshot.get("domain")
|
||||
or summary.get("domain")
|
||||
or "",
|
||||
"external_ip": snapshot.get("external_ip")
|
||||
or summary.get("external_ip")
|
||||
or "",
|
||||
"internal_ip": snapshot.get("internal_ip")
|
||||
or summary.get("internal_ip")
|
||||
or "",
|
||||
"last_reboot": snapshot.get("last_reboot")
|
||||
or summary.get("last_reboot")
|
||||
or "",
|
||||
"last_seen": snapshot.get("last_seen")
|
||||
or summary.get("last_seen")
|
||||
or 0,
|
||||
"last_user": snapshot.get("last_user")
|
||||
or summary.get("last_user")
|
||||
or "",
|
||||
"operating_system": snapshot.get("operating_system")
|
||||
or summary.get("operating_system")
|
||||
or summary.get("agent_operating_system")
|
||||
or "",
|
||||
"uptime": snapshot.get("uptime")
|
||||
or summary.get("uptime")
|
||||
or 0,
|
||||
"agent_id": snapshot.get("agent_id")
|
||||
or summary.get("agent_id")
|
||||
or "",
|
||||
}
|
||||
|
||||
return payload
|
||||
|
||||
def collect_agent_hash_records(self) -> List[Dict[str, object]]:
|
||||
records: List[Dict[str, object]] = []
|
||||
key_to_index: Dict[str, int] = {}
|
||||
|
||||
for device in self._repo.fetch_devices():
|
||||
summary = device.get("summary", {}) if isinstance(device, dict) else {}
|
||||
agent_id = (summary.get("agent_id") or "").strip()
|
||||
agent_guid = (summary.get("agent_guid") or "").strip()
|
||||
hostname = (summary.get("hostname") or device.get("hostname") or "").strip()
|
||||
agent_hash = (summary.get("agent_hash") or device.get("agent_hash") or "").strip()
|
||||
|
||||
keys: List[str] = []
|
||||
if agent_id:
|
||||
keys.append(f"id:{agent_id.lower()}")
|
||||
if agent_guid:
|
||||
keys.append(f"guid:{agent_guid.lower()}")
|
||||
if hostname:
|
||||
keys.append(f"host:{hostname.lower()}")
|
||||
|
||||
payload = {
|
||||
"agent_id": agent_id or None,
|
||||
"agent_guid": agent_guid or None,
|
||||
"hostname": hostname or None,
|
||||
"agent_hash": agent_hash or None,
|
||||
"source": "database",
|
||||
}
|
||||
|
||||
if not keys:
|
||||
records.append(payload)
|
||||
continue
|
||||
|
||||
existing_index = None
|
||||
for key in keys:
|
||||
if key in key_to_index:
|
||||
existing_index = key_to_index[key]
|
||||
break
|
||||
|
||||
if existing_index is None:
|
||||
existing_index = len(records)
|
||||
records.append(payload)
|
||||
for key in keys:
|
||||
key_to_index[key] = existing_index
|
||||
continue
|
||||
|
||||
merged = records[existing_index]
|
||||
for key in ("agent_id", "agent_guid", "hostname", "agent_hash"):
|
||||
if not merged.get(key) and payload.get(key):
|
||||
merged[key] = payload[key]
|
||||
|
||||
return records
|
||||
|
||||
def upsert_remote_device(
|
||||
self,
|
||||
connection_type: str,
|
||||
hostname: str,
|
||||
address: Optional[str],
|
||||
description: Optional[str],
|
||||
os_hint: Optional[str],
|
||||
*,
|
||||
ensure_existing_type: Optional[str],
|
||||
) -> Dict[str, object]:
|
||||
normalized_type = (connection_type or "").strip().lower()
|
||||
if not normalized_type:
|
||||
raise RemoteDeviceError("invalid_type", "connection type required")
|
||||
normalized_host = (hostname or "").strip()
|
||||
if not normalized_host:
|
||||
raise RemoteDeviceError("invalid_hostname", "hostname is required")
|
||||
|
||||
existing = self._repo.load_snapshot(hostname=normalized_host)
|
||||
existing_type = (existing or {}).get("summary", {}).get("connection_type") or ""
|
||||
existing_type = existing_type.strip().lower()
|
||||
|
||||
if ensure_existing_type and existing_type != ensure_existing_type.lower():
|
||||
raise RemoteDeviceError("not_found", "device not found")
|
||||
if ensure_existing_type is None and existing_type and existing_type != normalized_type:
|
||||
raise RemoteDeviceError("conflict", "device already exists with different connection type")
|
||||
|
||||
created_ts = None
|
||||
if existing:
|
||||
created_ts = existing.get("summary", {}).get("created_at")
|
||||
|
||||
endpoint = (address or "").strip() or (existing or {}).get("summary", {}).get("connection_endpoint") or ""
|
||||
if not endpoint:
|
||||
raise RemoteDeviceError("address_required", "address is required")
|
||||
|
||||
description_val = description if description is not None else (existing or {}).get("summary", {}).get("description")
|
||||
os_value = os_hint or (existing or {}).get("summary", {}).get("operating_system")
|
||||
os_value = (os_value or "").strip()
|
||||
|
||||
device_type_label = "SSH Remote" if normalized_type == "ssh" else "WinRM Remote"
|
||||
|
||||
summary_payload = {
|
||||
"connection_type": normalized_type,
|
||||
"connection_endpoint": endpoint,
|
||||
"internal_ip": endpoint,
|
||||
"external_ip": endpoint,
|
||||
"device_type": device_type_label,
|
||||
"operating_system": os_value or "",
|
||||
"last_seen": 0,
|
||||
"description": (description_val or ""),
|
||||
}
|
||||
|
||||
try:
|
||||
self._repo.upsert_device(
|
||||
normalized_host,
|
||||
description_val,
|
||||
{"summary": summary_payload},
|
||||
created_ts,
|
||||
)
|
||||
except sqlite3.DatabaseError as exc: # type: ignore[name-defined]
|
||||
raise RemoteDeviceError("storage_error", str(exc)) from exc
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
raise RemoteDeviceError("storage_error", str(exc)) from exc
|
||||
|
||||
devices = self._repo.fetch_devices(hostname=normalized_host)
|
||||
if not devices:
|
||||
raise RemoteDeviceError("reload_failed", "failed to load device after upsert")
|
||||
return devices[0]
|
||||
|
||||
def delete_remote_device(self, connection_type: str, hostname: str) -> None:
|
||||
normalized_host = (hostname or "").strip()
|
||||
if not normalized_host:
|
||||
raise RemoteDeviceError("invalid_hostname", "invalid hostname")
|
||||
existing = self._repo.load_snapshot(hostname=normalized_host)
|
||||
if not existing:
|
||||
raise RemoteDeviceError("not_found", "device not found")
|
||||
existing_type = (existing.get("summary", {}) or {}).get("connection_type") or ""
|
||||
if (existing_type or "").strip().lower() != (connection_type or "").strip().lower():
|
||||
raise RemoteDeviceError("not_found", "device not found")
|
||||
self._repo.delete_device_by_hostname(normalized_host)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Agent heartbeats
|
||||
# ------------------------------------------------------------------
|
||||
def record_heartbeat(
|
||||
self,
|
||||
*,
|
||||
context: DeviceAuthContext,
|
||||
payload: Mapping[str, Any],
|
||||
) -> None:
|
||||
guid = context.identity.guid.value
|
||||
snapshot = self._repo.load_snapshot(guid=guid)
|
||||
if not snapshot:
|
||||
raise DeviceHeartbeatError("device_not_registered", "device not registered")
|
||||
|
||||
summary = dict(snapshot.get("summary") or {})
|
||||
details = dict(snapshot.get("details") or {})
|
||||
|
||||
now_ts = int(time.time())
|
||||
summary["last_seen"] = now_ts
|
||||
summary["agent_guid"] = guid
|
||||
|
||||
existing_hostname = clean_device_str(summary.get("hostname")) or clean_device_str(
|
||||
snapshot.get("hostname")
|
||||
)
|
||||
incoming_hostname = clean_device_str(payload.get("hostname"))
|
||||
raw_metrics = payload.get("metrics")
|
||||
metrics = raw_metrics if isinstance(raw_metrics, Mapping) else {}
|
||||
metrics_hostname = clean_device_str(metrics.get("hostname")) if metrics else None
|
||||
hostname = incoming_hostname or metrics_hostname or existing_hostname
|
||||
if not hostname:
|
||||
hostname = f"RECOVERED-{guid[:12]}"
|
||||
summary["hostname"] = hostname
|
||||
|
||||
if metrics:
|
||||
last_user = metrics.get("last_user") or metrics.get("username") or metrics.get("user")
|
||||
if last_user:
|
||||
cleaned_user = clean_device_str(last_user)
|
||||
if cleaned_user:
|
||||
summary["last_user"] = cleaned_user
|
||||
operating_system = metrics.get("operating_system")
|
||||
if operating_system:
|
||||
cleaned_os = clean_device_str(operating_system)
|
||||
if cleaned_os:
|
||||
summary["operating_system"] = cleaned_os
|
||||
uptime = metrics.get("uptime")
|
||||
if uptime is not None:
|
||||
coerced = coerce_int(uptime)
|
||||
if coerced is not None:
|
||||
summary["uptime"] = coerced
|
||||
agent_id = metrics.get("agent_id")
|
||||
if agent_id:
|
||||
cleaned_agent = clean_device_str(agent_id)
|
||||
if cleaned_agent:
|
||||
summary["agent_id"] = cleaned_agent
|
||||
|
||||
for field in ("external_ip", "internal_ip", "device_type"):
|
||||
value = payload.get(field)
|
||||
cleaned = clean_device_str(value)
|
||||
if cleaned:
|
||||
summary[field] = cleaned
|
||||
|
||||
summary.setdefault("description", summary.get("description") or "")
|
||||
created_at = coerce_int(summary.get("created_at"))
|
||||
if created_at is None:
|
||||
created_at = coerce_int(snapshot.get("created_at"))
|
||||
if created_at is None:
|
||||
created_at = now_ts
|
||||
summary["created_at"] = created_at
|
||||
|
||||
raw_inventory = payload.get("inventory")
|
||||
inventory = raw_inventory if isinstance(raw_inventory, Mapping) else {}
|
||||
memory = inventory.get("memory") if isinstance(inventory.get("memory"), list) else details.get("memory")
|
||||
network = inventory.get("network") if isinstance(inventory.get("network"), list) else details.get("network")
|
||||
software = (
|
||||
inventory.get("software") if isinstance(inventory.get("software"), list) else details.get("software")
|
||||
)
|
||||
storage = inventory.get("storage") if isinstance(inventory.get("storage"), list) else details.get("storage")
|
||||
cpu = inventory.get("cpu") if isinstance(inventory.get("cpu"), Mapping) else details.get("cpu")
|
||||
|
||||
merged_details: Dict[str, Any] = {
|
||||
"summary": summary,
|
||||
"memory": memory,
|
||||
"network": network,
|
||||
"software": software,
|
||||
"storage": storage,
|
||||
"cpu": cpu,
|
||||
}
|
||||
|
||||
try:
|
||||
self._repo.upsert_device(
|
||||
summary["hostname"],
|
||||
summary.get("description"),
|
||||
merged_details,
|
||||
summary.get("created_at"),
|
||||
agent_hash=clean_device_str(summary.get("agent_hash")),
|
||||
guid=guid,
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
self._log.warning(
|
||||
"device-heartbeat-conflict guid=%s hostname=%s error=%s",
|
||||
guid,
|
||||
summary["hostname"],
|
||||
exc,
|
||||
)
|
||||
raise DeviceHeartbeatError("storage_conflict", str(exc)) from exc
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
self._log.exception(
|
||||
"device-heartbeat-failure guid=%s hostname=%s",
|
||||
guid,
|
||||
summary["hostname"],
|
||||
exc_info=exc,
|
||||
)
|
||||
raise DeviceHeartbeatError("storage_error", "failed to persist heartbeat") from exc
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Agent details
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _is_empty(value: Any) -> bool:
|
||||
return value in (None, "", [], {})
|
||||
|
||||
@classmethod
|
||||
def _deep_merge_preserve(cls, prev: Dict[str, Any], incoming: Dict[str, Any]) -> Dict[str, Any]:
|
||||
merged: Dict[str, Any] = dict(prev or {})
|
||||
for key, value in (incoming or {}).items():
|
||||
if isinstance(value, Mapping):
|
||||
existing = merged.get(key)
|
||||
if not isinstance(existing, Mapping):
|
||||
existing = {}
|
||||
merged[key] = cls._deep_merge_preserve(dict(existing), dict(value))
|
||||
elif isinstance(value, list):
|
||||
if value:
|
||||
merged[key] = value
|
||||
else:
|
||||
if cls._is_empty(value):
|
||||
continue
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
def save_agent_details(
|
||||
self,
|
||||
*,
|
||||
context: DeviceAuthContext,
|
||||
payload: Mapping[str, Any],
|
||||
) -> None:
|
||||
hostname = clean_device_str(payload.get("hostname"))
|
||||
details_raw = payload.get("details")
|
||||
agent_id = clean_device_str(payload.get("agent_id"))
|
||||
agent_hash = clean_device_str(payload.get("agent_hash"))
|
||||
|
||||
if not isinstance(details_raw, Mapping):
|
||||
raise DeviceDetailsError("invalid_payload", "details object required")
|
||||
|
||||
details_dict: Dict[str, Any]
|
||||
try:
|
||||
details_dict = json.loads(json.dumps(details_raw))
|
||||
except Exception:
|
||||
details_dict = dict(details_raw)
|
||||
|
||||
incoming_summary = dict(details_dict.get("summary") or {})
|
||||
if not hostname:
|
||||
hostname = clean_device_str(incoming_summary.get("hostname"))
|
||||
if not hostname:
|
||||
raise DeviceDetailsError("invalid_payload", "hostname required")
|
||||
|
||||
snapshot = self._repo.load_snapshot(hostname=hostname)
|
||||
if not snapshot:
|
||||
snapshot = {}
|
||||
|
||||
previous_details = snapshot.get("details")
|
||||
if isinstance(previous_details, Mapping):
|
||||
try:
|
||||
prev_details = json.loads(json.dumps(previous_details))
|
||||
except Exception:
|
||||
prev_details = dict(previous_details)
|
||||
else:
|
||||
prev_details = {}
|
||||
|
||||
prev_summary = dict(prev_details.get("summary") or {})
|
||||
|
||||
existing_guid = clean_device_str(snapshot.get("guid") or snapshot.get("summary", {}).get("agent_guid"))
|
||||
normalized_existing_guid = normalize_guid(existing_guid)
|
||||
auth_guid = context.identity.guid.value
|
||||
|
||||
if normalized_existing_guid and normalized_existing_guid != auth_guid:
|
||||
raise DeviceDetailsError("guid_mismatch", "device guid mismatch")
|
||||
|
||||
fingerprint = context.identity.fingerprint.value.lower()
|
||||
stored_fp = clean_device_str(snapshot.get("summary", {}).get("ssl_key_fingerprint"))
|
||||
if stored_fp and stored_fp.lower() != fingerprint:
|
||||
raise DeviceDetailsError("fingerprint_mismatch", "device fingerprint mismatch")
|
||||
|
||||
incoming_summary.setdefault("hostname", hostname)
|
||||
if agent_id and not incoming_summary.get("agent_id"):
|
||||
incoming_summary["agent_id"] = agent_id
|
||||
if agent_hash:
|
||||
incoming_summary["agent_hash"] = agent_hash
|
||||
incoming_summary["agent_guid"] = auth_guid
|
||||
if fingerprint:
|
||||
incoming_summary["ssl_key_fingerprint"] = fingerprint
|
||||
if not incoming_summary.get("last_seen") and prev_summary.get("last_seen"):
|
||||
incoming_summary["last_seen"] = prev_summary.get("last_seen")
|
||||
|
||||
details_dict["summary"] = incoming_summary
|
||||
merged_details = self._deep_merge_preserve(prev_details, details_dict)
|
||||
merged_summary = merged_details.setdefault("summary", {})
|
||||
|
||||
if not merged_summary.get("last_user") and prev_summary.get("last_user"):
|
||||
merged_summary["last_user"] = prev_summary.get("last_user")
|
||||
|
||||
created_at = coerce_int(merged_summary.get("created_at"))
|
||||
if created_at is None:
|
||||
created_at = coerce_int(snapshot.get("created_at"))
|
||||
if created_at is None:
|
||||
created_at = int(time.time())
|
||||
merged_summary["created_at"] = created_at
|
||||
if not merged_summary.get("created"):
|
||||
merged_summary["created"] = ts_to_human(created_at)
|
||||
|
||||
if fingerprint:
|
||||
merged_summary["ssl_key_fingerprint"] = fingerprint
|
||||
if not merged_summary.get("key_added_at"):
|
||||
merged_summary["key_added_at"] = datetime.now(timezone.utc).isoformat()
|
||||
if merged_summary.get("token_version") is None:
|
||||
merged_summary["token_version"] = 1
|
||||
if not merged_summary.get("status") and snapshot.get("summary", {}).get("status"):
|
||||
merged_summary["status"] = snapshot.get("summary", {}).get("status")
|
||||
uptime_val = merged_summary.get("uptime")
|
||||
if merged_summary.get("uptime_sec") is None and uptime_val is not None:
|
||||
coerced = coerce_int(uptime_val)
|
||||
if coerced is not None:
|
||||
merged_summary["uptime_sec"] = coerced
|
||||
merged_summary.setdefault("uptime_seconds", coerced)
|
||||
if merged_summary.get("uptime_seconds") is None and merged_summary.get("uptime_sec") is not None:
|
||||
merged_summary["uptime_seconds"] = merged_summary.get("uptime_sec")
|
||||
|
||||
description = clean_device_str(merged_summary.get("description"))
|
||||
existing_description = snapshot.get("description") if snapshot else ""
|
||||
description_to_store = description if description is not None else (existing_description or "")
|
||||
|
||||
existing_hash = clean_device_str(snapshot.get("agent_hash") or snapshot.get("summary", {}).get("agent_hash"))
|
||||
effective_hash = agent_hash or existing_hash
|
||||
|
||||
try:
|
||||
self._repo.upsert_device(
|
||||
hostname,
|
||||
description_to_store,
|
||||
merged_details,
|
||||
created_at,
|
||||
agent_hash=effective_hash,
|
||||
guid=auth_guid,
|
||||
)
|
||||
except sqlite3.DatabaseError as exc:
|
||||
raise DeviceDetailsError("storage_error", str(exc)) from exc
|
||||
|
||||
added_at = merged_summary.get("key_added_at") or datetime.now(timezone.utc).isoformat()
|
||||
self._repo.record_device_fingerprint(auth_guid, fingerprint, added_at)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Description management
|
||||
# ------------------------------------------------------------------
|
||||
def update_device_description(self, hostname: str, description: Optional[str]) -> None:
|
||||
normalized_host = clean_device_str(hostname)
|
||||
if not normalized_host:
|
||||
raise DeviceDescriptionError("invalid_hostname", "invalid hostname")
|
||||
|
||||
snapshot = self._repo.load_snapshot(hostname=normalized_host)
|
||||
if not snapshot:
|
||||
raise DeviceDescriptionError("not_found", "device not found")
|
||||
|
||||
details = snapshot.get("details")
|
||||
if isinstance(details, Mapping):
|
||||
try:
|
||||
existing = json.loads(json.dumps(details))
|
||||
except Exception:
|
||||
existing = dict(details)
|
||||
else:
|
||||
existing = {}
|
||||
|
||||
summary = dict(existing.get("summary") or {})
|
||||
summary["description"] = description or ""
|
||||
existing["summary"] = summary
|
||||
|
||||
created_at = coerce_int(summary.get("created_at"))
|
||||
if created_at is None:
|
||||
created_at = coerce_int(snapshot.get("created_at"))
|
||||
if created_at is None:
|
||||
created_at = int(time.time())
|
||||
|
||||
agent_hash = clean_device_str(summary.get("agent_hash") or snapshot.get("agent_hash"))
|
||||
guid = clean_device_str(summary.get("agent_guid") or snapshot.get("guid"))
|
||||
|
||||
try:
|
||||
self._repo.upsert_device(
|
||||
normalized_host,
|
||||
description or (snapshot.get("description") or ""),
|
||||
existing,
|
||||
created_at,
|
||||
agent_hash=agent_hash,
|
||||
guid=guid,
|
||||
)
|
||||
except sqlite3.DatabaseError as exc:
|
||||
raise DeviceDescriptionError("storage_error", str(exc)) from exc
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Service exposing CRUD for saved device list views."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from Data.Engine.domain.device_views import DeviceListView
|
||||
from Data.Engine.repositories.sqlite.device_view_repository import SQLiteDeviceViewRepository
|
||||
|
||||
__all__ = ["DeviceViewService"]
|
||||
|
||||
|
||||
class DeviceViewService:
|
||||
def __init__(
|
||||
self,
|
||||
repository: SQLiteDeviceViewRepository,
|
||||
*,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._repo = repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.device_views")
|
||||
|
||||
def list_views(self) -> List[DeviceListView]:
|
||||
return self._repo.list_views()
|
||||
|
||||
def get_view(self, view_id: int) -> Optional[DeviceListView]:
|
||||
return self._repo.get_view(view_id)
|
||||
|
||||
def create_view(self, name: str, columns: List[str], filters: dict) -> DeviceListView:
|
||||
normalized_name = (name or "").strip()
|
||||
if not normalized_name:
|
||||
raise ValueError("missing_name")
|
||||
if normalized_name.lower() == "default view":
|
||||
raise ValueError("reserved")
|
||||
return self._repo.create_view(normalized_name, list(columns), dict(filters))
|
||||
|
||||
def update_view(
|
||||
self,
|
||||
view_id: int,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
columns: Optional[List[str]] = None,
|
||||
filters: Optional[dict] = None,
|
||||
) -> DeviceListView:
|
||||
updates: dict = {}
|
||||
if name is not None:
|
||||
normalized = (name or "").strip()
|
||||
if not normalized:
|
||||
raise ValueError("missing_name")
|
||||
if normalized.lower() == "default view":
|
||||
raise ValueError("reserved")
|
||||
updates["name"] = normalized
|
||||
if columns is not None:
|
||||
if not isinstance(columns, list) or not all(isinstance(col, str) for col in columns):
|
||||
raise ValueError("invalid_columns")
|
||||
updates["columns"] = list(columns)
|
||||
if filters is not None:
|
||||
if not isinstance(filters, dict):
|
||||
raise ValueError("invalid_filters")
|
||||
updates["filters"] = dict(filters)
|
||||
if not updates:
|
||||
raise ValueError("no_fields")
|
||||
return self._repo.update_view(
|
||||
view_id,
|
||||
name=updates.get("name"),
|
||||
columns=updates.get("columns"),
|
||||
filters=updates.get("filters"),
|
||||
)
|
||||
|
||||
def delete_view(self, view_id: int) -> bool:
|
||||
return self._repo.delete_view(view_id)
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
"""Enrollment services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentService",
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"PollingResult",
|
||||
"EnrollmentValidationError",
|
||||
"EnrollmentAdminService",
|
||||
]
|
||||
|
||||
_LAZY: dict[str, tuple[str, str]] = {
|
||||
"EnrollmentService": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentService"),
|
||||
"EnrollmentRequestResult": (
|
||||
"Data.Engine.services.enrollment.enrollment_service",
|
||||
"EnrollmentRequestResult",
|
||||
),
|
||||
"EnrollmentStatus": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentStatus"),
|
||||
"EnrollmentTokenBundle": (
|
||||
"Data.Engine.services.enrollment.enrollment_service",
|
||||
"EnrollmentTokenBundle",
|
||||
),
|
||||
"PollingResult": ("Data.Engine.services.enrollment.enrollment_service", "PollingResult"),
|
||||
"EnrollmentValidationError": (
|
||||
"Data.Engine.domain.device_enrollment",
|
||||
"EnrollmentValidationError",
|
||||
),
|
||||
"EnrollmentAdminService": (
|
||||
"Data.Engine.services.enrollment.admin_service",
|
||||
"EnrollmentAdminService",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
try:
|
||||
module_name, attribute = _LAZY[name]
|
||||
except KeyError as exc: # pragma: no cover - defensive
|
||||
raise AttributeError(name) from exc
|
||||
|
||||
module = import_module(module_name)
|
||||
value = getattr(module, attribute)
|
||||
globals()[name] = value
|
||||
return value
|
||||
|
||||
|
||||
def __dir__() -> list[str]: # pragma: no cover - interactive helper
|
||||
return sorted(set(__all__))
|
||||
|
||||
@@ -1,245 +0,0 @@
|
||||
"""Administrative helpers for enrollment workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from Data.Engine.domain.device_auth import DeviceGuid, normalize_guid
|
||||
from Data.Engine.domain.device_enrollment import EnrollmentApprovalStatus
|
||||
from Data.Engine.domain.enrollment_admin import DeviceApprovalRecord, EnrollmentCodeRecord
|
||||
from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository
|
||||
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
|
||||
|
||||
__all__ = ["EnrollmentAdminService", "DeviceApprovalActionResult"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DeviceApprovalActionResult:
|
||||
"""Outcome metadata returned after mutating an approval."""
|
||||
|
||||
status: str
|
||||
conflict_resolution: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
payload = {"status": self.status}
|
||||
if self.conflict_resolution:
|
||||
payload["conflict_resolution"] = self.conflict_resolution
|
||||
return payload
|
||||
|
||||
|
||||
class EnrollmentAdminService:
|
||||
"""Expose administrative enrollment operations."""
|
||||
|
||||
_VALID_TTL_HOURS = {1, 3, 6, 12, 24}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: SQLiteEnrollmentRepository,
|
||||
user_repository: SQLiteUserRepository,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
clock: Optional[Callable[[], datetime]] = None,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._users = user_repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.enrollment_admin")
|
||||
self._clock = clock or (lambda: datetime.now(tz=timezone.utc))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Enrollment install codes
|
||||
# ------------------------------------------------------------------
|
||||
def list_install_codes(self, *, status: Optional[str] = None) -> List[EnrollmentCodeRecord]:
|
||||
return self._repository.list_install_codes(status=status, now=self._clock())
|
||||
|
||||
def create_install_code(
|
||||
self,
|
||||
*,
|
||||
ttl_hours: int,
|
||||
max_uses: int,
|
||||
created_by: Optional[str],
|
||||
) -> EnrollmentCodeRecord:
|
||||
if ttl_hours not in self._VALID_TTL_HOURS:
|
||||
raise ValueError("invalid_ttl")
|
||||
|
||||
normalized_max = self._normalize_max_uses(max_uses)
|
||||
|
||||
now = self._clock()
|
||||
expires_at = now + timedelta(hours=ttl_hours)
|
||||
record_id = str(uuid.uuid4())
|
||||
code = self._generate_install_code()
|
||||
|
||||
created_by_identifier = None
|
||||
if created_by:
|
||||
created_by_identifier = self._users.resolve_identifier(created_by)
|
||||
if not created_by_identifier:
|
||||
created_by_identifier = created_by.strip() or None
|
||||
|
||||
record = self._repository.insert_install_code(
|
||||
record_id=record_id,
|
||||
code=code,
|
||||
expires_at=expires_at,
|
||||
created_by=created_by_identifier,
|
||||
max_uses=normalized_max,
|
||||
)
|
||||
|
||||
self._log.info(
|
||||
"install code created id=%s ttl=%sh max_uses=%s",
|
||||
record.record_id,
|
||||
ttl_hours,
|
||||
normalized_max,
|
||||
)
|
||||
|
||||
return record
|
||||
|
||||
def delete_install_code(self, record_id: str) -> bool:
|
||||
deleted = self._repository.delete_install_code_if_unused(record_id)
|
||||
if deleted:
|
||||
self._log.info("install code deleted id=%s", record_id)
|
||||
return deleted
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Device approvals
|
||||
# ------------------------------------------------------------------
|
||||
def list_device_approvals(self, *, status: Optional[str] = None) -> List[DeviceApprovalRecord]:
|
||||
return self._repository.list_device_approvals(status=status)
|
||||
|
||||
def approve_device_approval(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
actor: Optional[str],
|
||||
guid: Optional[str] = None,
|
||||
conflict_resolution: Optional[str] = None,
|
||||
) -> DeviceApprovalActionResult:
|
||||
return self._set_device_approval_status(
|
||||
record_id,
|
||||
EnrollmentApprovalStatus.APPROVED,
|
||||
actor=actor,
|
||||
guid=guid,
|
||||
conflict_resolution=conflict_resolution,
|
||||
)
|
||||
|
||||
def deny_device_approval(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
actor: Optional[str],
|
||||
) -> DeviceApprovalActionResult:
|
||||
return self._set_device_approval_status(
|
||||
record_id,
|
||||
EnrollmentApprovalStatus.DENIED,
|
||||
actor=actor,
|
||||
guid=None,
|
||||
conflict_resolution=None,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _generate_install_code() -> str:
|
||||
raw = secrets.token_hex(16).upper()
|
||||
return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4))
|
||||
|
||||
@staticmethod
|
||||
def _normalize_max_uses(value: int) -> int:
|
||||
try:
|
||||
count = int(value)
|
||||
except Exception:
|
||||
count = 2
|
||||
if count < 1:
|
||||
return 1
|
||||
if count > 10:
|
||||
return 10
|
||||
return count
|
||||
|
||||
def _set_device_approval_status(
|
||||
self,
|
||||
record_id: str,
|
||||
status: EnrollmentApprovalStatus,
|
||||
*,
|
||||
actor: Optional[str],
|
||||
guid: Optional[str],
|
||||
conflict_resolution: Optional[str],
|
||||
) -> DeviceApprovalActionResult:
|
||||
approval = self._repository.fetch_device_approval(record_id)
|
||||
if approval is None:
|
||||
raise LookupError("not_found")
|
||||
|
||||
if approval.status is not EnrollmentApprovalStatus.PENDING:
|
||||
raise ValueError("approval_not_pending")
|
||||
|
||||
normalized_guid = normalize_guid(guid) or (approval.guid.value if approval.guid else "")
|
||||
resolution_normalized = (conflict_resolution or "").strip().lower() or None
|
||||
|
||||
fingerprint_match = False
|
||||
conflict_guid: Optional[str] = None
|
||||
|
||||
if status is EnrollmentApprovalStatus.APPROVED:
|
||||
pending_records = self._repository.list_device_approvals(status="pending")
|
||||
current_record = next(
|
||||
(record for record in pending_records if record.record_id == approval.record_id),
|
||||
None,
|
||||
)
|
||||
|
||||
conflict = current_record.hostname_conflict if current_record else None
|
||||
if conflict:
|
||||
conflict_guid = normalize_guid(conflict.guid)
|
||||
fingerprint_match = bool(conflict.fingerprint_match)
|
||||
|
||||
if fingerprint_match:
|
||||
normalized_guid = conflict_guid or normalized_guid or ""
|
||||
if resolution_normalized is None:
|
||||
resolution_normalized = "auto_merge_fingerprint"
|
||||
elif resolution_normalized == "overwrite":
|
||||
normalized_guid = conflict_guid or normalized_guid or ""
|
||||
elif resolution_normalized == "coexist":
|
||||
pass
|
||||
else:
|
||||
raise ValueError("conflict_resolution_required")
|
||||
|
||||
if normalized_guid:
|
||||
try:
|
||||
guid_value = DeviceGuid(normalized_guid)
|
||||
except ValueError as exc:
|
||||
raise ValueError("invalid_guid") from exc
|
||||
else:
|
||||
guid_value = None
|
||||
|
||||
actor_identifier = None
|
||||
if actor:
|
||||
actor_identifier = self._users.resolve_identifier(actor)
|
||||
if not actor_identifier:
|
||||
actor_identifier = actor.strip() or None
|
||||
if not actor_identifier:
|
||||
actor_identifier = "system"
|
||||
|
||||
self._repository.update_device_approval_status(
|
||||
approval.record_id,
|
||||
status=status,
|
||||
updated_at=self._clock(),
|
||||
approved_by=actor_identifier,
|
||||
guid=guid_value,
|
||||
)
|
||||
|
||||
if status is EnrollmentApprovalStatus.APPROVED:
|
||||
self._log.info(
|
||||
"device approval %s approved resolution=%s guid=%s",
|
||||
approval.record_id,
|
||||
resolution_normalized or "",
|
||||
guid_value.value if guid_value else normalized_guid or "",
|
||||
)
|
||||
else:
|
||||
self._log.info("device approval %s denied", approval.record_id)
|
||||
|
||||
return DeviceApprovalActionResult(
|
||||
status=status.value,
|
||||
conflict_resolution=resolution_normalized,
|
||||
)
|
||||
|
||||
@@ -1,487 +0,0 @@
|
||||
"""Enrollment workflow orchestration for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable, Optional, Protocol
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
from Data.Engine.builders.device_enrollment import EnrollmentRequestInput
|
||||
from Data.Engine.domain.device_auth import (
|
||||
DeviceFingerprint,
|
||||
DeviceGuid,
|
||||
sanitize_service_context,
|
||||
)
|
||||
from Data.Engine.domain.device_enrollment import (
|
||||
EnrollmentApproval,
|
||||
EnrollmentApprovalStatus,
|
||||
EnrollmentCode,
|
||||
EnrollmentValidationError,
|
||||
)
|
||||
from Data.Engine.services.auth.device_auth_service import DeviceRecord
|
||||
from Data.Engine.services.auth.token_service import JWTIssuer
|
||||
from Data.Engine.services.enrollment.nonce_cache import NonceCache
|
||||
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
|
||||
|
||||
__all__ = [
|
||||
"EnrollmentRequestResult",
|
||||
"EnrollmentService",
|
||||
"EnrollmentStatus",
|
||||
"EnrollmentTokenBundle",
|
||||
"PollingResult",
|
||||
]
|
||||
|
||||
|
||||
class DeviceRepository(Protocol):
|
||||
def fetch_by_guid(self, guid: DeviceGuid) -> Optional[DeviceRecord]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def ensure_device_record(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
hostname: str,
|
||||
fingerprint: DeviceFingerprint,
|
||||
) -> DeviceRecord: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def record_device_key(
|
||||
self,
|
||||
*,
|
||||
guid: DeviceGuid,
|
||||
fingerprint: DeviceFingerprint,
|
||||
added_at: datetime,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class EnrollmentRepository(Protocol):
|
||||
def fetch_install_code(self, code: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_install_code_by_id(self, record_id: str) -> Optional[EnrollmentCode]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_install_code_usage(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
use_count_increment: int,
|
||||
last_used_at: datetime,
|
||||
used_by_guid: Optional[DeviceGuid] = None,
|
||||
mark_first_use: bool = False,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def fetch_pending_approval_by_fingerprint(
|
||||
self, fingerprint: DeviceFingerprint
|
||||
) -> Optional[EnrollmentApproval]: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_pending_approval(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
hostname: str,
|
||||
guid: Optional[DeviceGuid],
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
updated_at: datetime,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def create_device_approval(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
reference: str,
|
||||
claimed_hostname: str,
|
||||
claimed_fingerprint: DeviceFingerprint,
|
||||
enrollment_code_id: Optional[str],
|
||||
client_nonce_b64: str,
|
||||
server_nonce_b64: str,
|
||||
agent_pubkey_der: bytes,
|
||||
created_at: datetime,
|
||||
status: EnrollmentApprovalStatus = EnrollmentApprovalStatus.PENDING,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> EnrollmentApproval: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def update_device_approval_status(
|
||||
self,
|
||||
record_id: str,
|
||||
*,
|
||||
status: EnrollmentApprovalStatus,
|
||||
updated_at: datetime,
|
||||
approved_by: Optional[str] = None,
|
||||
guid: Optional[DeviceGuid] = None,
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class RefreshTokenRepository(Protocol):
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
record_id: str,
|
||||
guid: DeviceGuid,
|
||||
token_hash: str,
|
||||
created_at: datetime,
|
||||
expires_at: Optional[datetime],
|
||||
) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class ScriptSigner(Protocol):
|
||||
def public_base64_spki(self) -> str: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class EnrollmentStatus(str):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
EXPIRED = "expired"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentTokenBundle:
|
||||
guid: DeviceGuid
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_in: int
|
||||
token_type: str = "Bearer"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EnrollmentRequestResult:
|
||||
status: EnrollmentStatus
|
||||
server_certificate: str
|
||||
signing_key: str
|
||||
approval_reference: Optional[str] = None
|
||||
server_nonce: Optional[str] = None
|
||||
poll_after_ms: Optional[int] = None
|
||||
http_status: int = 200
|
||||
retry_after: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PollingResult:
|
||||
status: EnrollmentStatus
|
||||
http_status: int
|
||||
poll_after_ms: Optional[int] = None
|
||||
reason: Optional[str] = None
|
||||
detail: Optional[str] = None
|
||||
tokens: Optional[EnrollmentTokenBundle] = None
|
||||
server_certificate: Optional[str] = None
|
||||
signing_key: Optional[str] = None
|
||||
|
||||
|
||||
class EnrollmentService:
|
||||
"""Coordinate the Borealis device enrollment handshake."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
device_repository: DeviceRepository,
|
||||
enrollment_repository: EnrollmentRepository,
|
||||
token_repository: RefreshTokenRepository,
|
||||
jwt_service: JWTIssuer,
|
||||
tls_bundle_loader: Callable[[], str],
|
||||
ip_rate_limiter: SlidingWindowRateLimiter,
|
||||
fingerprint_rate_limiter: SlidingWindowRateLimiter,
|
||||
nonce_cache: NonceCache,
|
||||
script_signer: Optional[ScriptSigner] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._devices = device_repository
|
||||
self._enrollment = enrollment_repository
|
||||
self._tokens = token_repository
|
||||
self._jwt = jwt_service
|
||||
self._load_tls_bundle = tls_bundle_loader
|
||||
self._ip_rate_limiter = ip_rate_limiter
|
||||
self._fp_rate_limiter = fingerprint_rate_limiter
|
||||
self._nonce_cache = nonce_cache
|
||||
self._signer = script_signer
|
||||
self._log = logger or logging.getLogger("borealis.engine.enrollment")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
def request_enrollment(
|
||||
self,
|
||||
payload: EnrollmentRequestInput,
|
||||
*,
|
||||
remote_addr: str,
|
||||
) -> EnrollmentRequestResult:
|
||||
context_hint = sanitize_service_context(payload.service_context)
|
||||
self._log.info(
|
||||
"enrollment-request ip=%s host=%s code_mask=%s", remote_addr, payload.hostname, self._mask_code(payload.enrollment_code)
|
||||
)
|
||||
|
||||
self._enforce_rate_limit(self._ip_rate_limiter, f"ip:{remote_addr}")
|
||||
self._enforce_rate_limit(self._fp_rate_limiter, f"fp:{payload.fingerprint.value}")
|
||||
|
||||
install_code = self._enrollment.fetch_install_code(payload.enrollment_code)
|
||||
reuse_guid = self._determine_reuse_guid(install_code, payload.fingerprint)
|
||||
|
||||
server_nonce_bytes = secrets.token_bytes(32)
|
||||
server_nonce_b64 = base64.b64encode(server_nonce_bytes).decode("ascii")
|
||||
|
||||
now = self._now()
|
||||
approval = self._enrollment.fetch_pending_approval_by_fingerprint(payload.fingerprint)
|
||||
if approval:
|
||||
self._enrollment.update_pending_approval(
|
||||
approval.record_id,
|
||||
hostname=payload.hostname,
|
||||
guid=reuse_guid,
|
||||
enrollment_code_id=install_code.identifier if install_code else None,
|
||||
client_nonce_b64=payload.client_nonce_b64,
|
||||
server_nonce_b64=server_nonce_b64,
|
||||
agent_pubkey_der=payload.agent_public_key_der,
|
||||
updated_at=now,
|
||||
)
|
||||
approval_reference = approval.reference
|
||||
else:
|
||||
record_id = str(uuid.uuid4())
|
||||
approval_reference = str(uuid.uuid4())
|
||||
approval = self._enrollment.create_device_approval(
|
||||
record_id=record_id,
|
||||
reference=approval_reference,
|
||||
claimed_hostname=payload.hostname,
|
||||
claimed_fingerprint=payload.fingerprint,
|
||||
enrollment_code_id=install_code.identifier if install_code else None,
|
||||
client_nonce_b64=payload.client_nonce_b64,
|
||||
server_nonce_b64=server_nonce_b64,
|
||||
agent_pubkey_der=payload.agent_public_key_der,
|
||||
created_at=now,
|
||||
guid=reuse_guid,
|
||||
)
|
||||
|
||||
signing_key = self._signer.public_base64_spki() if self._signer else ""
|
||||
certificate = self._load_tls_bundle()
|
||||
|
||||
return EnrollmentRequestResult(
|
||||
status=EnrollmentStatus.PENDING,
|
||||
approval_reference=approval.reference,
|
||||
server_nonce=server_nonce_b64,
|
||||
poll_after_ms=3000,
|
||||
server_certificate=certificate,
|
||||
signing_key=signing_key,
|
||||
)
|
||||
|
||||
def poll_enrollment(
|
||||
self,
|
||||
*,
|
||||
approval_reference: str,
|
||||
client_nonce_b64: str,
|
||||
proof_signature_b64: str,
|
||||
) -> PollingResult:
|
||||
if not approval_reference:
|
||||
raise EnrollmentValidationError("approval_reference_required")
|
||||
if not client_nonce_b64:
|
||||
raise EnrollmentValidationError("client_nonce_required")
|
||||
if not proof_signature_b64:
|
||||
raise EnrollmentValidationError("proof_sig_required")
|
||||
|
||||
approval = self._enrollment.fetch_device_approval_by_reference(approval_reference)
|
||||
if approval is None:
|
||||
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=404)
|
||||
|
||||
client_nonce = self._decode_base64(client_nonce_b64, "invalid_client_nonce")
|
||||
server_nonce = self._decode_base64(approval.server_nonce_b64, "server_nonce_invalid")
|
||||
proof_sig = self._decode_base64(proof_signature_b64, "invalid_proof_sig")
|
||||
|
||||
if approval.client_nonce_b64 != client_nonce_b64:
|
||||
raise EnrollmentValidationError("nonce_mismatch")
|
||||
|
||||
self._verify_proof_signature(
|
||||
approval=approval,
|
||||
client_nonce=client_nonce,
|
||||
server_nonce=server_nonce,
|
||||
signature=proof_sig,
|
||||
)
|
||||
|
||||
status = approval.status
|
||||
if status is EnrollmentApprovalStatus.PENDING:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.PENDING,
|
||||
http_status=200,
|
||||
poll_after_ms=5000,
|
||||
)
|
||||
if status is EnrollmentApprovalStatus.DENIED:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.DENIED,
|
||||
http_status=200,
|
||||
reason="operator_denied",
|
||||
)
|
||||
if status is EnrollmentApprovalStatus.EXPIRED:
|
||||
return PollingResult(status=EnrollmentStatus.EXPIRED, http_status=200)
|
||||
if status is EnrollmentApprovalStatus.COMPLETED:
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.APPROVED,
|
||||
http_status=200,
|
||||
detail="finalized",
|
||||
)
|
||||
if status is not EnrollmentApprovalStatus.APPROVED:
|
||||
return PollingResult(status=EnrollmentStatus.UNKNOWN, http_status=400)
|
||||
|
||||
nonce_key = f"{approval.reference}:{proof_signature_b64}"
|
||||
if not self._nonce_cache.consume(nonce_key):
|
||||
raise EnrollmentValidationError("proof_replayed", http_status=409)
|
||||
|
||||
token_bundle = self._finalize_approval(approval)
|
||||
signing_key = self._signer.public_base64_spki() if self._signer else ""
|
||||
certificate = self._load_tls_bundle()
|
||||
|
||||
return PollingResult(
|
||||
status=EnrollmentStatus.APPROVED,
|
||||
http_status=200,
|
||||
tokens=token_bundle,
|
||||
server_certificate=certificate,
|
||||
signing_key=signing_key,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internals
|
||||
# ------------------------------------------------------------------
|
||||
def _enforce_rate_limit(
|
||||
self,
|
||||
limiter: SlidingWindowRateLimiter,
|
||||
key: str,
|
||||
*,
|
||||
limit: int = 60,
|
||||
window_seconds: float = 60.0,
|
||||
) -> None:
|
||||
decision = limiter.check(key, limit, window_seconds)
|
||||
if not decision.allowed:
|
||||
raise EnrollmentValidationError(
|
||||
"rate_limited", http_status=429, retry_after=max(decision.retry_after, 1.0)
|
||||
)
|
||||
|
||||
def _determine_reuse_guid(
|
||||
self,
|
||||
install_code: Optional[EnrollmentCode],
|
||||
fingerprint: DeviceFingerprint,
|
||||
) -> Optional[DeviceGuid]:
|
||||
if install_code is None:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
if install_code.is_expired:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
if not install_code.is_exhausted:
|
||||
return None
|
||||
if not install_code.used_by_guid:
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
|
||||
existing = self._devices.fetch_by_guid(install_code.used_by_guid)
|
||||
if existing and existing.identity.fingerprint.value == fingerprint.value:
|
||||
return install_code.used_by_guid
|
||||
raise EnrollmentValidationError("invalid_enrollment_code")
|
||||
|
||||
def _finalize_approval(self, approval: EnrollmentApproval) -> EnrollmentTokenBundle:
|
||||
now = self._now()
|
||||
effective_guid = approval.guid or DeviceGuid(str(uuid.uuid4()))
|
||||
device_record = self._devices.ensure_device_record(
|
||||
guid=effective_guid,
|
||||
hostname=approval.claimed_hostname,
|
||||
fingerprint=approval.claimed_fingerprint,
|
||||
)
|
||||
self._devices.record_device_key(
|
||||
guid=effective_guid,
|
||||
fingerprint=approval.claimed_fingerprint,
|
||||
added_at=now,
|
||||
)
|
||||
|
||||
if approval.enrollment_code_id:
|
||||
code = self._enrollment.fetch_install_code_by_id(approval.enrollment_code_id)
|
||||
if code is not None:
|
||||
mark_first = code.used_at is None
|
||||
self._enrollment.update_install_code_usage(
|
||||
approval.enrollment_code_id,
|
||||
use_count_increment=1,
|
||||
last_used_at=now,
|
||||
used_by_guid=effective_guid,
|
||||
mark_first_use=mark_first,
|
||||
)
|
||||
|
||||
refresh_token = secrets.token_urlsafe(48)
|
||||
refresh_id = str(uuid.uuid4())
|
||||
expires_at = now + timedelta(days=30)
|
||||
token_hash = hashlib.sha256(refresh_token.encode("utf-8")).hexdigest()
|
||||
self._tokens.create(
|
||||
record_id=refresh_id,
|
||||
guid=effective_guid,
|
||||
token_hash=token_hash,
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
access_token = self._jwt.issue_access_token(
|
||||
effective_guid.value,
|
||||
device_record.identity.fingerprint.value,
|
||||
max(device_record.token_version, 1),
|
||||
)
|
||||
|
||||
self._enrollment.update_device_approval_status(
|
||||
approval.record_id,
|
||||
status=EnrollmentApprovalStatus.COMPLETED,
|
||||
updated_at=now,
|
||||
guid=effective_guid,
|
||||
)
|
||||
|
||||
return EnrollmentTokenBundle(
|
||||
guid=effective_guid,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=900,
|
||||
)
|
||||
|
||||
def _verify_proof_signature(
|
||||
self,
|
||||
*,
|
||||
approval: EnrollmentApproval,
|
||||
client_nonce: bytes,
|
||||
server_nonce: bytes,
|
||||
signature: bytes,
|
||||
) -> None:
|
||||
message = server_nonce + approval.reference.encode("utf-8") + client_nonce
|
||||
try:
|
||||
public_key = serialization.load_der_public_key(approval.agent_pubkey_der)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError("agent_pubkey_invalid") from exc
|
||||
|
||||
try:
|
||||
public_key.verify(signature, message)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError("invalid_proof") from exc
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64(value: str, error_code: str) -> bytes:
|
||||
try:
|
||||
return base64.b64decode(value, validate=True)
|
||||
except Exception as exc:
|
||||
raise EnrollmentValidationError(error_code) from exc
|
||||
|
||||
@staticmethod
|
||||
def _mask_code(code: str) -> str:
|
||||
trimmed = (code or "").strip()
|
||||
if len(trimmed) <= 6:
|
||||
return "***"
|
||||
return f"{trimmed[:3]}***{trimmed[-3:]}"
|
||||
|
||||
@staticmethod
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
@@ -1,32 +0,0 @@
|
||||
"""Nonce replay protection for enrollment workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict
|
||||
|
||||
__all__ = ["NonceCache"]
|
||||
|
||||
|
||||
class NonceCache:
|
||||
"""Track recently observed nonces to prevent replay."""
|
||||
|
||||
def __init__(self, ttl_seconds: float = 300.0) -> None:
|
||||
self._ttl = ttl_seconds
|
||||
self._entries: Dict[str, float] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def consume(self, key: str) -> bool:
|
||||
"""Consume *key* if it has not been seen recently."""
|
||||
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
expiry = self._entries.get(key)
|
||||
if expiry and expiry > now:
|
||||
return False
|
||||
self._entries[key] = now + self._ttl
|
||||
stale = [nonce for nonce, ttl in self._entries.items() if ttl <= now]
|
||||
for nonce in stale:
|
||||
self._entries.pop(nonce, None)
|
||||
return True
|
||||
@@ -1,8 +0,0 @@
|
||||
"""GitHub-oriented services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .github_service import GitHubService, GitHubTokenPayload
|
||||
|
||||
__all__ = ["GitHubService", "GitHubTokenPayload"]
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
"""GitHub service layer bridging repositories and integrations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
from Data.Engine.domain.github import GitHubRepoRef, GitHubTokenStatus, RepoHeadSnapshot
|
||||
from Data.Engine.integrations.github import GitHubArtifactProvider
|
||||
from Data.Engine.repositories.sqlite.github_repository import SQLiteGitHubRepository
|
||||
|
||||
__all__ = ["GitHubService", "GitHubTokenPayload"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GitHubTokenPayload:
|
||||
token: Optional[str]
|
||||
status: GitHubTokenStatus
|
||||
checked_at: int
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
payload = self.status.to_dict()
|
||||
payload.update(
|
||||
{
|
||||
"token": self.token or "",
|
||||
"checked_at": self.checked_at,
|
||||
}
|
||||
)
|
||||
return payload
|
||||
|
||||
|
||||
class GitHubService:
|
||||
"""Coordinate GitHub caching, verification, and persistence."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: SQLiteGitHubRepository,
|
||||
provider: GitHubArtifactProvider,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
clock: Optional[Callable[[], float]] = None,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._provider = provider
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.github")
|
||||
self._clock = clock or time.time
|
||||
self._token_cache: Optional[str] = None
|
||||
self._token_loaded_at: float = 0.0
|
||||
|
||||
initial_token = self._repository.load_token()
|
||||
self._apply_token(initial_token)
|
||||
|
||||
def get_repo_head(
|
||||
self,
|
||||
owner_repo: Optional[str],
|
||||
branch: Optional[str],
|
||||
*,
|
||||
ttl_seconds: int,
|
||||
force_refresh: bool = False,
|
||||
) -> RepoHeadSnapshot:
|
||||
repo_str = (owner_repo or self._provider.default_repo).strip()
|
||||
branch_name = (branch or self._provider.default_branch).strip()
|
||||
repo = GitHubRepoRef.parse(repo_str, branch_name)
|
||||
ttl = max(30, min(ttl_seconds, 3600))
|
||||
return self._provider.fetch_repo_head(repo, ttl_seconds=ttl, force_refresh=force_refresh)
|
||||
|
||||
def refresh_default_repo(self, *, force: bool = False) -> RepoHeadSnapshot:
|
||||
return self._provider.refresh_default_repo_head(force=force)
|
||||
|
||||
def get_token_status(self, *, force_refresh: bool = False) -> GitHubTokenPayload:
|
||||
token = self._load_token(force_refresh=force_refresh)
|
||||
status = self._provider.verify_token(token)
|
||||
return GitHubTokenPayload(token=token, status=status, checked_at=int(self._clock()))
|
||||
|
||||
def update_token(self, token: Optional[str]) -> GitHubTokenPayload:
|
||||
normalized = (token or "").strip()
|
||||
self._repository.store_token(normalized)
|
||||
self._apply_token(normalized)
|
||||
status = self._provider.verify_token(normalized)
|
||||
self._provider.start_background_refresh()
|
||||
self._log.info("github-token updated valid=%s", status.valid)
|
||||
return GitHubTokenPayload(token=normalized or None, status=status, checked_at=int(self._clock()))
|
||||
|
||||
def start_background_refresh(self) -> None:
|
||||
self._provider.start_background_refresh()
|
||||
|
||||
@property
|
||||
def default_refresh_interval(self) -> int:
|
||||
return self._provider.refresh_interval
|
||||
|
||||
def _load_token(self, *, force_refresh: bool = False) -> Optional[str]:
|
||||
now = self._clock()
|
||||
if not force_refresh and self._token_cache is not None and (now - self._token_loaded_at) < 15.0:
|
||||
return self._token_cache
|
||||
|
||||
token = self._repository.load_token()
|
||||
self._apply_token(token)
|
||||
return token
|
||||
|
||||
def _apply_token(self, token: Optional[str]) -> None:
|
||||
self._token_cache = (token or "").strip() or None
|
||||
self._token_loaded_at = self._clock()
|
||||
self._provider.set_token(self._token_cache)
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Job-related services for the Borealis Engine."""
|
||||
|
||||
from .scheduler_service import SchedulerService
|
||||
|
||||
__all__ = ["SchedulerService"]
|
||||
@@ -1,373 +0,0 @@
|
||||
"""Background scheduler service for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import calendar
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Mapping, Optional
|
||||
|
||||
from Data.Engine.builders.job_fabricator import JobFabricator, JobManifest
|
||||
from Data.Engine.repositories.sqlite.job_repository import (
|
||||
ScheduledJobRecord,
|
||||
ScheduledJobRunRecord,
|
||||
SQLiteJobRepository,
|
||||
)
|
||||
|
||||
__all__ = ["SchedulerService", "SchedulerRuntime"]
|
||||
|
||||
|
||||
def _now_ts() -> int:
|
||||
return int(time.time())
|
||||
|
||||
|
||||
def _floor_minute(ts: int) -> int:
|
||||
ts = int(ts or 0)
|
||||
return ts - (ts % 60)
|
||||
|
||||
|
||||
def _parse_ts(val: Any) -> Optional[int]:
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, (int, float)):
|
||||
return int(val)
|
||||
try:
|
||||
s = str(val).strip().replace("Z", "+00:00")
|
||||
return int(datetime.fromisoformat(s).timestamp())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _parse_expiration(raw: Optional[str]) -> Optional[int]:
|
||||
if not raw or raw == "no_expire":
|
||||
return None
|
||||
try:
|
||||
s = raw.strip().lower()
|
||||
unit = s[-1]
|
||||
value = int(s[:-1])
|
||||
if unit == "m":
|
||||
return value * 60
|
||||
if unit == "h":
|
||||
return value * 3600
|
||||
if unit == "d":
|
||||
return value * 86400
|
||||
return int(s) * 60
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _add_months(dt: datetime, months: int) -> datetime:
|
||||
year = dt.year + (dt.month - 1 + months) // 12
|
||||
month = ((dt.month - 1 + months) % 12) + 1
|
||||
last_day = calendar.monthrange(year, month)[1]
|
||||
day = min(dt.day, last_day)
|
||||
return dt.replace(year=year, month=month, day=day)
|
||||
|
||||
|
||||
def _add_years(dt: datetime, years: int) -> datetime:
|
||||
year = dt.year + years
|
||||
month = dt.month
|
||||
last_day = calendar.monthrange(year, month)[1]
|
||||
day = min(dt.day, last_day)
|
||||
return dt.replace(year=year, month=month, day=day)
|
||||
|
||||
|
||||
def _compute_next_run(
|
||||
schedule_type: str,
|
||||
start_ts: Optional[int],
|
||||
last_run_ts: Optional[int],
|
||||
now_ts: int,
|
||||
) -> Optional[int]:
|
||||
st = (schedule_type or "immediately").strip().lower()
|
||||
start_floor = _floor_minute(start_ts) if start_ts else None
|
||||
last_floor = _floor_minute(last_run_ts) if last_run_ts else None
|
||||
now_floor = _floor_minute(now_ts)
|
||||
|
||||
if st == "immediately":
|
||||
return None if last_floor else now_floor
|
||||
if st == "once":
|
||||
if not start_floor:
|
||||
return None
|
||||
return start_floor if not last_floor else None
|
||||
if not start_floor:
|
||||
return None
|
||||
|
||||
last = last_floor if last_floor is not None else None
|
||||
if st in {
|
||||
"every_5_minutes",
|
||||
"every_10_minutes",
|
||||
"every_15_minutes",
|
||||
"every_30_minutes",
|
||||
"every_hour",
|
||||
}:
|
||||
period_map = {
|
||||
"every_5_minutes": 5 * 60,
|
||||
"every_10_minutes": 10 * 60,
|
||||
"every_15_minutes": 15 * 60,
|
||||
"every_30_minutes": 30 * 60,
|
||||
"every_hour": 60 * 60,
|
||||
}
|
||||
period = period_map.get(st)
|
||||
candidate = (last + period) if last else start_floor
|
||||
while candidate is not None and candidate <= now_floor - 1:
|
||||
candidate += period
|
||||
return candidate
|
||||
if st == "daily":
|
||||
period = 86400
|
||||
candidate = (last + period) if last else start_floor
|
||||
while candidate is not None and candidate <= now_floor - 1:
|
||||
candidate += period
|
||||
return candidate
|
||||
if st == "weekly":
|
||||
period = 7 * 86400
|
||||
candidate = (last + period) if last else start_floor
|
||||
while candidate is not None and candidate <= now_floor - 1:
|
||||
candidate += period
|
||||
return candidate
|
||||
if st == "monthly":
|
||||
base = datetime.utcfromtimestamp(last) if last else datetime.utcfromtimestamp(start_floor)
|
||||
candidate = _add_months(base, 1 if last else 0)
|
||||
while int(candidate.timestamp()) <= now_floor - 1:
|
||||
candidate = _add_months(candidate, 1)
|
||||
return int(candidate.timestamp())
|
||||
if st == "yearly":
|
||||
base = datetime.utcfromtimestamp(last) if last else datetime.utcfromtimestamp(start_floor)
|
||||
candidate = _add_years(base, 1 if last else 0)
|
||||
while int(candidate.timestamp()) <= now_floor - 1:
|
||||
candidate = _add_years(candidate, 1)
|
||||
return int(candidate.timestamp())
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerRuntime:
|
||||
thread: Optional[threading.Thread]
|
||||
stop_event: threading.Event
|
||||
|
||||
|
||||
class SchedulerService:
|
||||
"""Evaluate and dispatch scheduled jobs using Engine repositories."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
job_repository: SQLiteJobRepository,
|
||||
assemblies_root: Path,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
poll_interval: int = 30,
|
||||
) -> None:
|
||||
self._jobs = job_repository
|
||||
self._fabricator = JobFabricator(assemblies_root=assemblies_root, logger=logger)
|
||||
self._log = logger or logging.getLogger("borealis.engine.scheduler")
|
||||
self._poll_interval = max(5, poll_interval)
|
||||
self._socketio: Optional[Any] = None
|
||||
self._runtime = SchedulerRuntime(thread=None, stop_event=threading.Event())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
def start(self, socketio: Optional[Any] = None) -> None:
|
||||
self._socketio = socketio
|
||||
if self._runtime.thread and self._runtime.thread.is_alive():
|
||||
return
|
||||
self._runtime.stop_event.clear()
|
||||
thread = threading.Thread(target=self._run_loop, name="borealis-engine-scheduler", daemon=True)
|
||||
thread.start()
|
||||
self._runtime.thread = thread
|
||||
self._log.info("scheduler-started")
|
||||
|
||||
def stop(self) -> None:
|
||||
self._runtime.stop_event.set()
|
||||
thread = self._runtime.thread
|
||||
if thread and thread.is_alive():
|
||||
thread.join(timeout=5)
|
||||
self._log.info("scheduler-stopped")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HTTP orchestration helpers
|
||||
# ------------------------------------------------------------------
|
||||
def list_jobs(self) -> list[dict[str, Any]]:
|
||||
return [self._serialize_job(job) for job in self._jobs.list_jobs()]
|
||||
|
||||
def get_job(self, job_id: int) -> Optional[dict[str, Any]]:
|
||||
record = self._jobs.fetch_job(job_id)
|
||||
return self._serialize_job(record) if record else None
|
||||
|
||||
def create_job(self, payload: Mapping[str, Any]) -> dict[str, Any]:
|
||||
fields = self._normalize_payload(payload)
|
||||
record = self._jobs.create_job(**fields)
|
||||
return self._serialize_job(record)
|
||||
|
||||
def update_job(self, job_id: int, payload: Mapping[str, Any]) -> Optional[dict[str, Any]]:
|
||||
fields = self._normalize_payload(payload)
|
||||
record = self._jobs.update_job(job_id, **fields)
|
||||
return self._serialize_job(record) if record else None
|
||||
|
||||
def toggle_job(self, job_id: int, enabled: bool) -> None:
|
||||
self._jobs.set_enabled(job_id, enabled)
|
||||
|
||||
def delete_job(self, job_id: int) -> None:
|
||||
self._jobs.delete_job(job_id)
|
||||
|
||||
def list_runs(self, job_id: int, *, days: Optional[int] = None) -> list[dict[str, Any]]:
|
||||
runs = self._jobs.list_runs(job_id, days=days)
|
||||
return [self._serialize_run(run) for run in runs]
|
||||
|
||||
def purge_runs(self, job_id: int) -> None:
|
||||
self._jobs.purge_runs(job_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Scheduling loop
|
||||
# ------------------------------------------------------------------
|
||||
def tick(self, *, now_ts: Optional[int] = None) -> None:
|
||||
self._evaluate_jobs(now_ts=now_ts or _now_ts())
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
stop_event = self._runtime.stop_event
|
||||
while not stop_event.wait(timeout=self._poll_interval):
|
||||
try:
|
||||
self._evaluate_jobs(now_ts=_now_ts())
|
||||
except Exception as exc: # pragma: no cover - safety net
|
||||
self._log.exception("scheduler-loop-error: %s", exc)
|
||||
|
||||
def _evaluate_jobs(self, *, now_ts: int) -> None:
|
||||
for job in self._jobs.list_enabled_jobs():
|
||||
try:
|
||||
self._evaluate_job(job, now_ts=now_ts)
|
||||
except Exception as exc:
|
||||
self._log.exception("job-evaluation-error job_id=%s error=%s", job.id, exc)
|
||||
|
||||
def _evaluate_job(self, job: ScheduledJobRecord, *, now_ts: int) -> None:
|
||||
last_run = self._jobs.fetch_last_run(job.id)
|
||||
last_ts = None
|
||||
if last_run:
|
||||
last_ts = last_run.scheduled_ts or last_run.started_ts or last_run.created_at
|
||||
next_run = _compute_next_run(job.schedule_type, job.start_ts, last_ts, now_ts)
|
||||
if next_run is None or next_run > now_ts:
|
||||
return
|
||||
|
||||
expiration_window = _parse_expiration(job.expiration)
|
||||
if expiration_window and job.start_ts:
|
||||
if job.start_ts + expiration_window <= now_ts:
|
||||
self._log.info(
|
||||
"job-expired",
|
||||
extra={"job_id": job.id, "start_ts": job.start_ts, "expiration": job.expiration},
|
||||
)
|
||||
return
|
||||
|
||||
manifest = self._fabricator.build(job, occurrence_ts=next_run)
|
||||
targets = manifest.targets or ("<unassigned>",)
|
||||
for target in targets:
|
||||
run_id = self._jobs.create_run(job.id, next_run, target_hostname=None if target == "<unassigned>" else target)
|
||||
self._jobs.mark_run_started(run_id, started_ts=now_ts)
|
||||
self._emit_run_event("job_run_started", job, run_id, target, manifest)
|
||||
self._jobs.mark_run_finished(run_id, status="Success", finished_ts=now_ts)
|
||||
self._emit_run_event("job_run_completed", job, run_id, target, manifest)
|
||||
|
||||
def _emit_run_event(
|
||||
self,
|
||||
event: str,
|
||||
job: ScheduledJobRecord,
|
||||
run_id: int,
|
||||
target: str,
|
||||
manifest: JobManifest,
|
||||
) -> None:
|
||||
payload = {
|
||||
"job_id": job.id,
|
||||
"run_id": run_id,
|
||||
"target": target,
|
||||
"schedule_type": job.schedule_type,
|
||||
"occurrence_ts": manifest.occurrence_ts,
|
||||
}
|
||||
if self._socketio is not None:
|
||||
try:
|
||||
self._socketio.emit(event, payload) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
self._log.debug("socketio-emit-failed event=%s payload=%s", event, payload)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Serialization helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _serialize_job(self, job: ScheduledJobRecord) -> dict[str, Any]:
|
||||
return {
|
||||
"id": job.id,
|
||||
"name": job.name,
|
||||
"components": job.components,
|
||||
"targets": job.targets,
|
||||
"schedule": {
|
||||
"type": job.schedule_type,
|
||||
"start": job.start_ts,
|
||||
},
|
||||
"schedule_type": job.schedule_type,
|
||||
"start_ts": job.start_ts,
|
||||
"duration_stop_enabled": job.duration_stop_enabled,
|
||||
"expiration": job.expiration or "no_expire",
|
||||
"execution_context": job.execution_context,
|
||||
"credential_id": job.credential_id,
|
||||
"use_service_account": job.use_service_account,
|
||||
"enabled": job.enabled,
|
||||
"created_at": job.created_at,
|
||||
"updated_at": job.updated_at,
|
||||
}
|
||||
|
||||
def _serialize_run(self, run: ScheduledJobRunRecord) -> dict[str, Any]:
|
||||
return {
|
||||
"id": run.id,
|
||||
"job_id": run.job_id,
|
||||
"scheduled_ts": run.scheduled_ts,
|
||||
"started_ts": run.started_ts,
|
||||
"finished_ts": run.finished_ts,
|
||||
"status": run.status,
|
||||
"error": run.error,
|
||||
"target_hostname": run.target_hostname,
|
||||
"created_at": run.created_at,
|
||||
"updated_at": run.updated_at,
|
||||
}
|
||||
|
||||
def _normalize_payload(self, payload: Mapping[str, Any]) -> dict[str, Any]:
|
||||
name = str(payload.get("name") or "").strip()
|
||||
components = payload.get("components") or []
|
||||
targets = payload.get("targets") or []
|
||||
schedule_block = payload.get("schedule") if isinstance(payload.get("schedule"), Mapping) else {}
|
||||
schedule_type = str(schedule_block.get("type") or payload.get("schedule_type") or "immediately").strip().lower()
|
||||
start_value = schedule_block.get("start") if isinstance(schedule_block, Mapping) else None
|
||||
if start_value is None:
|
||||
start_value = payload.get("start")
|
||||
start_ts = _parse_ts(start_value)
|
||||
duration_block = payload.get("duration") if isinstance(payload.get("duration"), Mapping) else {}
|
||||
duration_stop = bool(duration_block.get("stopAfterEnabled") or payload.get("duration_stop_enabled"))
|
||||
expiration = str(duration_block.get("expiration") or payload.get("expiration") or "no_expire").strip()
|
||||
execution_context = str(payload.get("execution_context") or "system").strip().lower()
|
||||
credential_id = payload.get("credential_id")
|
||||
try:
|
||||
credential_id = int(credential_id) if credential_id is not None else None
|
||||
except Exception:
|
||||
credential_id = None
|
||||
use_service_account_raw = payload.get("use_service_account")
|
||||
use_service_account = bool(use_service_account_raw) if execution_context == "winrm" else False
|
||||
enabled = bool(payload.get("enabled", True))
|
||||
|
||||
if not name:
|
||||
raise ValueError("job name is required")
|
||||
if not isinstance(components, Iterable) or not list(components):
|
||||
raise ValueError("at least one component is required")
|
||||
if not isinstance(targets, Iterable) or not list(targets):
|
||||
raise ValueError("at least one target is required")
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"components": list(components),
|
||||
"targets": list(targets),
|
||||
"schedule_type": schedule_type,
|
||||
"start_ts": start_ts,
|
||||
"duration_stop_enabled": duration_stop,
|
||||
"expiration": expiration,
|
||||
"execution_context": execution_context,
|
||||
"credential_id": credential_id,
|
||||
"use_service_account": use_service_account,
|
||||
"enabled": enabled,
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
"""In-process rate limiting utilities for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import Deque, Dict
|
||||
|
||||
__all__ = ["RateLimitDecision", "SlidingWindowRateLimiter"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RateLimitDecision:
|
||||
"""Result of a rate limit check."""
|
||||
|
||||
allowed: bool
|
||||
retry_after: float
|
||||
|
||||
|
||||
class SlidingWindowRateLimiter:
|
||||
"""Tiny in-memory sliding window limiter suitable for single-process use."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buckets: Dict[str, Deque[float]] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def check(self, key: str, limit: int, window_seconds: float) -> RateLimitDecision:
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
bucket = self._buckets.get(key)
|
||||
if bucket is None:
|
||||
bucket = deque()
|
||||
self._buckets[key] = bucket
|
||||
|
||||
while bucket and now - bucket[0] > window_seconds:
|
||||
bucket.popleft()
|
||||
|
||||
if len(bucket) >= limit:
|
||||
retry_after = max(0.0, window_seconds - (now - bucket[0]))
|
||||
return RateLimitDecision(False, retry_after)
|
||||
|
||||
bucket.append(now)
|
||||
return RateLimitDecision(True, 0.0)
|
||||
@@ -1,10 +0,0 @@
|
||||
"""Realtime coordination services for the Borealis Engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .agent_registry import AgentRealtimeService, AgentRecord
|
||||
|
||||
__all__ = [
|
||||
"AgentRealtimeService",
|
||||
"AgentRecord",
|
||||
]
|
||||
@@ -1,301 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Mapping, Optional, Tuple
|
||||
|
||||
from Data.Engine.repositories.sqlite import SQLiteDeviceRepository
|
||||
|
||||
__all__ = ["AgentRealtimeService", "AgentRecord"]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AgentRecord:
|
||||
"""In-memory representation of a connected agent."""
|
||||
|
||||
agent_id: str
|
||||
hostname: str = "unknown"
|
||||
agent_operating_system: str = "-"
|
||||
last_seen: int = 0
|
||||
status: str = "orphaned"
|
||||
service_mode: str = "currentuser"
|
||||
is_script_agent: bool = False
|
||||
collector_active_ts: Optional[float] = None
|
||||
|
||||
|
||||
class AgentRealtimeService:
|
||||
"""Track realtime agent presence and provide persistence hooks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
device_repository: SQLiteDeviceRepository,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._device_repository = device_repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.realtime.agents")
|
||||
self._agents: Dict[str, AgentRecord] = {}
|
||||
self._configs: Dict[str, Dict[str, Any]] = {}
|
||||
self._screenshots: Dict[str, Dict[str, Any]] = {}
|
||||
self._task_screenshots: Dict[Tuple[str, str], Dict[str, Any]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Agent presence management
|
||||
# ------------------------------------------------------------------
|
||||
def register_connection(self, agent_id: str, service_mode: Optional[str]) -> AgentRecord:
|
||||
record = self._agents.get(agent_id) or AgentRecord(agent_id=agent_id)
|
||||
mode = self.normalize_service_mode(service_mode, agent_id)
|
||||
now = int(time.time())
|
||||
|
||||
record.service_mode = mode
|
||||
record.is_script_agent = self._is_script_agent(agent_id)
|
||||
record.last_seen = now
|
||||
record.status = "provisioned" if agent_id in self._configs else "orphaned"
|
||||
self._agents[agent_id] = record
|
||||
|
||||
self._persist_activity(
|
||||
hostname=record.hostname,
|
||||
last_seen=record.last_seen,
|
||||
agent_id=agent_id,
|
||||
operating_system=record.agent_operating_system,
|
||||
)
|
||||
return record
|
||||
|
||||
def heartbeat(self, payload: Mapping[str, Any]) -> Optional[AgentRecord]:
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
agent_id = payload.get("agent_id")
|
||||
if not agent_id:
|
||||
return None
|
||||
|
||||
hostname = payload.get("hostname") or ""
|
||||
mode = self.normalize_service_mode(payload.get("service_mode"), agent_id)
|
||||
is_script_agent = self._is_script_agent(agent_id)
|
||||
last_seen = self._coerce_int(payload.get("last_seen"), default=int(time.time()))
|
||||
operating_system = (payload.get("agent_operating_system") or "-").strip() or "-"
|
||||
|
||||
if hostname:
|
||||
self._reconcile_hostname_collisions(
|
||||
hostname=hostname,
|
||||
agent_id=agent_id,
|
||||
incoming_mode=mode,
|
||||
is_script_agent=is_script_agent,
|
||||
last_seen=last_seen,
|
||||
)
|
||||
|
||||
record = self._agents.get(agent_id) or AgentRecord(agent_id=agent_id)
|
||||
if hostname:
|
||||
record.hostname = hostname
|
||||
record.agent_operating_system = operating_system
|
||||
record.last_seen = last_seen
|
||||
record.service_mode = mode
|
||||
record.is_script_agent = is_script_agent
|
||||
record.status = "provisioned" if agent_id in self._configs else record.status or "orphaned"
|
||||
self._agents[agent_id] = record
|
||||
|
||||
self._persist_activity(
|
||||
hostname=record.hostname or hostname,
|
||||
last_seen=record.last_seen,
|
||||
agent_id=agent_id,
|
||||
operating_system=record.agent_operating_system,
|
||||
)
|
||||
return record
|
||||
|
||||
def collector_status(self, payload: Mapping[str, Any]) -> None:
|
||||
if not payload:
|
||||
return
|
||||
|
||||
agent_id = payload.get("agent_id")
|
||||
if not agent_id:
|
||||
return
|
||||
|
||||
hostname = payload.get("hostname") or ""
|
||||
mode = self.normalize_service_mode(payload.get("service_mode"), agent_id)
|
||||
active = bool(payload.get("active"))
|
||||
last_user = (payload.get("last_user") or "").strip()
|
||||
|
||||
record = self._agents.get(agent_id) or AgentRecord(agent_id=agent_id)
|
||||
if hostname:
|
||||
record.hostname = hostname
|
||||
if mode:
|
||||
record.service_mode = mode
|
||||
record.is_script_agent = self._is_script_agent(agent_id) or record.is_script_agent
|
||||
if active:
|
||||
record.collector_active_ts = time.time()
|
||||
self._agents[agent_id] = record
|
||||
|
||||
if (
|
||||
last_user
|
||||
and hostname
|
||||
and self._is_valid_interactive_user(last_user)
|
||||
and not self._is_system_service_agent(agent_id, record.is_script_agent)
|
||||
):
|
||||
self._persist_activity(
|
||||
hostname=hostname,
|
||||
last_seen=int(time.time()),
|
||||
agent_id=agent_id,
|
||||
operating_system=record.agent_operating_system,
|
||||
last_user=last_user,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Configuration management
|
||||
# ------------------------------------------------------------------
|
||||
def set_agent_config(self, agent_id: str, config: Mapping[str, Any]) -> None:
|
||||
self._configs[agent_id] = dict(config)
|
||||
record = self._agents.get(agent_id)
|
||||
if record:
|
||||
record.status = "provisioned"
|
||||
|
||||
def get_agent_config(self, agent_id: str) -> Optional[Dict[str, Any]]:
|
||||
config = self._configs.get(agent_id)
|
||||
if config is None:
|
||||
return None
|
||||
return dict(config)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Screenshot caches
|
||||
# ------------------------------------------------------------------
|
||||
def store_agent_screenshot(self, agent_id: str, image_base64: str) -> None:
|
||||
self._screenshots[agent_id] = {
|
||||
"image_base64": image_base64,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
def store_task_screenshot(self, agent_id: str, node_id: str, image_base64: str) -> None:
|
||||
self._task_screenshots[(agent_id, node_id)] = {
|
||||
"image_base64": image_base64,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def normalize_service_mode(value: Optional[str], agent_id: Optional[str] = None) -> str:
|
||||
text = ""
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
text = value.strip().lower()
|
||||
except Exception:
|
||||
text = ""
|
||||
|
||||
if not text and agent_id:
|
||||
try:
|
||||
lowered = agent_id.lower()
|
||||
if "-svc-" in lowered or lowered.endswith("-svc"):
|
||||
return "system"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if text in {"system", "svc", "service", "system_service"}:
|
||||
return "system"
|
||||
if text in {"interactive", "currentuser", "user", "current_user"}:
|
||||
return "currentuser"
|
||||
return "currentuser"
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, *, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _is_script_agent(agent_id: Optional[str]) -> bool:
|
||||
try:
|
||||
return bool(isinstance(agent_id, str) and agent_id.lower().endswith("-script"))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_interactive_user(candidate: Optional[str]) -> bool:
|
||||
if not candidate:
|
||||
return False
|
||||
try:
|
||||
text = str(candidate).strip()
|
||||
except Exception:
|
||||
return False
|
||||
if not text:
|
||||
return False
|
||||
upper = text.upper()
|
||||
if text.endswith("$"):
|
||||
return False
|
||||
if "NT AUTHORITY\\" in upper or "NT SERVICE\\" in upper:
|
||||
return False
|
||||
if upper.endswith("\\SYSTEM"):
|
||||
return False
|
||||
if upper.endswith("\\LOCAL SERVICE"):
|
||||
return False
|
||||
if upper.endswith("\\NETWORK SERVICE"):
|
||||
return False
|
||||
if upper == "ANONYMOUS LOGON":
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _is_system_service_agent(agent_id: str, is_script_agent: bool) -> bool:
|
||||
try:
|
||||
lowered = agent_id.lower()
|
||||
except Exception:
|
||||
lowered = ""
|
||||
if is_script_agent:
|
||||
return False
|
||||
return "-svc-" in lowered or lowered.endswith("-svc")
|
||||
|
||||
def _reconcile_hostname_collisions(
|
||||
self,
|
||||
*,
|
||||
hostname: str,
|
||||
agent_id: str,
|
||||
incoming_mode: str,
|
||||
is_script_agent: bool,
|
||||
last_seen: int,
|
||||
) -> None:
|
||||
transferred_config = False
|
||||
for existing_id, info in list(self._agents.items()):
|
||||
if existing_id == agent_id:
|
||||
continue
|
||||
if info.hostname != hostname:
|
||||
continue
|
||||
existing_mode = self.normalize_service_mode(info.service_mode, existing_id)
|
||||
if existing_mode != incoming_mode:
|
||||
continue
|
||||
if is_script_agent and not info.is_script_agent:
|
||||
self._persist_activity(
|
||||
hostname=hostname,
|
||||
last_seen=last_seen,
|
||||
agent_id=existing_id,
|
||||
operating_system=info.agent_operating_system,
|
||||
)
|
||||
return
|
||||
if not transferred_config and existing_id in self._configs and agent_id not in self._configs:
|
||||
self._configs[agent_id] = dict(self._configs[existing_id])
|
||||
transferred_config = True
|
||||
self._agents.pop(existing_id, None)
|
||||
if existing_id != agent_id:
|
||||
self._configs.pop(existing_id, None)
|
||||
|
||||
def _persist_activity(
|
||||
self,
|
||||
*,
|
||||
hostname: Optional[str],
|
||||
last_seen: Optional[int],
|
||||
agent_id: Optional[str],
|
||||
operating_system: Optional[str],
|
||||
last_user: Optional[str] = None,
|
||||
) -> None:
|
||||
if not hostname:
|
||||
return
|
||||
try:
|
||||
self._device_repository.update_device_summary(
|
||||
hostname=hostname,
|
||||
last_seen=last_seen,
|
||||
agent_id=agent_id,
|
||||
operating_system=operating_system,
|
||||
last_user=last_user,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._log.debug("failed to persist device activity for %s: %s", hostname, exc)
|
||||
@@ -1,3 +0,0 @@
|
||||
from .site_service import SiteService
|
||||
|
||||
__all__ = ["SiteService"]
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Site management service that mirrors the legacy Flask behaviour."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
from Data.Engine.domain.sites import SiteDeviceMapping, SiteSummary
|
||||
from Data.Engine.repositories.sqlite.site_repository import SQLiteSiteRepository
|
||||
|
||||
__all__ = ["SiteService"]
|
||||
|
||||
|
||||
class SiteService:
|
||||
def __init__(self, repository: SQLiteSiteRepository, *, logger: Optional[logging.Logger] = None) -> None:
|
||||
self._repo = repository
|
||||
self._log = logger or logging.getLogger("borealis.engine.services.sites")
|
||||
|
||||
def list_sites(self) -> List[SiteSummary]:
|
||||
return self._repo.list_sites()
|
||||
|
||||
def create_site(self, name: str, description: str) -> SiteSummary:
|
||||
normalized_name = (name or "").strip()
|
||||
normalized_description = (description or "").strip()
|
||||
if not normalized_name:
|
||||
raise ValueError("missing_name")
|
||||
try:
|
||||
return self._repo.create_site(normalized_name, normalized_description)
|
||||
except ValueError as exc:
|
||||
if str(exc) == "duplicate":
|
||||
raise ValueError("duplicate") from exc
|
||||
raise
|
||||
|
||||
def delete_sites(self, ids: Iterable[int]) -> int:
|
||||
normalized = []
|
||||
for value in ids:
|
||||
try:
|
||||
normalized.append(int(value))
|
||||
except Exception:
|
||||
continue
|
||||
if not normalized:
|
||||
return 0
|
||||
return self._repo.delete_sites(tuple(normalized))
|
||||
|
||||
def rename_site(self, site_id: int, new_name: str) -> SiteSummary:
|
||||
normalized_name = (new_name or "").strip()
|
||||
if not normalized_name:
|
||||
raise ValueError("missing_name")
|
||||
try:
|
||||
return self._repo.rename_site(int(site_id), normalized_name)
|
||||
except ValueError as exc:
|
||||
if str(exc) == "duplicate":
|
||||
raise ValueError("duplicate") from exc
|
||||
raise
|
||||
|
||||
def map_devices(self, hostnames: Optional[Iterable[str]] = None) -> Dict[str, SiteDeviceMapping]:
|
||||
return self._repo.map_devices(hostnames)
|
||||
|
||||
def assign_devices(self, site_id: int, hostnames: Iterable[str]) -> None:
|
||||
try:
|
||||
numeric_id = int(site_id)
|
||||
except Exception as exc:
|
||||
raise ValueError("invalid_site_id") from exc
|
||||
normalized = [hn for hn in hostnames if isinstance(hn, str) and hn.strip()]
|
||||
if not normalized:
|
||||
raise ValueError("invalid_hostnames")
|
||||
try:
|
||||
self._repo.assign_devices(numeric_id, normalized)
|
||||
except LookupError as exc:
|
||||
if str(exc) == "not_found":
|
||||
raise LookupError("not_found") from exc
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user