Merge pull request #135 from bunny-lab-io:codex/fix-404-error-on-web-interface

Fix Engine static root fallback for legacy WebUI
This commit is contained in:
2025-10-23 04:56:26 -06:00
committed by GitHub
58 changed files with 8062 additions and 23 deletions

View File

@@ -66,10 +66,15 @@ def bootstrap() -> EngineRuntime:
else:
logger.info("migrations-skipped")
with sqlite_connection.connection_scope(settings.database_path) as conn:
sqlite_migrations.ensure_default_admin(conn)
logger.info("default-admin-ensured")
app = create_app(settings, db_factory=db_factory)
services = build_service_container(settings, db_factory=db_factory, logger=logger.getChild("services"))
app.extensions["engine_services"] = services
register_http_interfaces(app, services)
socketio = create_socket_server(app, settings.socketio)
register_ws_interfaces(socketio, services)
services.scheduler_service.start(socketio)

View File

@@ -8,12 +8,22 @@ from .device_auth import (
RefreshTokenRequest,
RefreshTokenRequestBuilder,
)
from .operator_auth import (
OperatorLoginRequest,
OperatorMFAVerificationRequest,
build_login_request,
build_mfa_request,
)
__all__ = [
"DeviceAuthRequest",
"DeviceAuthRequestBuilder",
"RefreshTokenRequest",
"RefreshTokenRequestBuilder",
"OperatorLoginRequest",
"OperatorMFAVerificationRequest",
"build_login_request",
"build_mfa_request",
]
try: # pragma: no cover - optional dependency shim

View File

@@ -0,0 +1,72 @@
"""Builders for operator authentication payloads."""
from __future__ import annotations
import hashlib
from dataclasses import dataclass
from typing import Mapping
@dataclass(frozen=True, slots=True)
class OperatorLoginRequest:
    """Normalized operator login credentials."""

    # Operator account name, whitespace-stripped by the builder.
    username: str
    # Lowercase hex SHA-512 digest of the operator's password.
    password_sha512: str
@dataclass(frozen=True, slots=True)
class OperatorMFAVerificationRequest:
    """Normalized MFA verification payload."""

    # Opaque token identifying the in-flight MFA challenge.
    pending_token: str
    # Digits-only verification code extracted from user input.
    code: str
def _sha512_hex(raw: str) -> str:
digest = hashlib.sha512()
digest.update(raw.encode("utf-8"))
return digest.hexdigest()
def build_login_request(payload: Mapping[str, object]) -> OperatorLoginRequest:
    """Validate and normalize the login *payload*.

    Accepts either a pre-computed ``password_sha512`` (normalized to
    lowercase hex) or a plaintext ``password`` that is hashed here.
    Raises ``ValueError`` when the username or password is missing.
    """
    username = str(payload.get("username") or "").strip()
    if not username:
        raise ValueError("username is required")
    supplied_hash = str(payload.get("password_sha512") or "").strip().lower()
    if supplied_hash:
        return OperatorLoginRequest(username=username, password_sha512=supplied_hash)
    plaintext = payload.get("password")
    if not isinstance(plaintext, str) or not plaintext:
        raise ValueError("password is required")
    return OperatorLoginRequest(username=username, password_sha512=_sha512_hex(plaintext))
def build_mfa_request(payload: Mapping[str, object]) -> OperatorMFAVerificationRequest:
    """Validate and normalize the MFA verification *payload*.

    Non-digit characters are stripped from the submitted code; at least
    six digits must remain (longer codes are passed through unchanged).
    """
    pending_token = str(payload.get("pending_token") or "").strip()
    if not pending_token:
        raise ValueError("pending_token is required")
    code_digits = "".join(ch for ch in str(payload.get("code") or "").strip() if ch.isdigit())
    if len(code_digits) < 6:
        raise ValueError("code must contain 6 digits")
    return OperatorMFAVerificationRequest(pending_token=pending_token, code=code_digits)
__all__ = [
"OperatorLoginRequest",
"OperatorMFAVerificationRequest",
"build_login_request",
"build_mfa_request",
]

View File

@@ -91,7 +91,12 @@ def _resolve_project_root() -> Path:
candidate = os.getenv("BOREALIS_ROOT")
if candidate:
return Path(candidate).expanduser().resolve()
return Path(__file__).resolve().parents[2]
# ``environment.py`` lives under ``Data/Engine/config``. The project
# root is three levels above this module (the repository checkout). The
# previous implementation only walked up two levels which incorrectly
# treated ``Data/`` as the root, breaking all filesystem discovery logic
# that expects peers such as ``Data/Server`` to be available.
return Path(__file__).resolve().parents[3]
def _resolve_database_path(project_root: Path) -> Path:
@@ -114,10 +119,19 @@ def _resolve_static_root(project_root: Path) -> Path:
candidates = (
project_root / "Engine" / "web-interface" / "build",
project_root / "Engine" / "web-interface" / "dist",
project_root / "Engine" / "web-interface",
project_root / "Data" / "Engine" / "WebUI" / "build",
project_root / "Data" / "Engine" / "WebUI",
project_root / "Server" / "web-interface" / "build",
project_root / "Server" / "web-interface",
project_root / "Server" / "WebUI" / "build",
project_root / "Server" / "WebUI",
project_root / "Data" / "Server" / "web-interface" / "build",
project_root / "Data" / "Server" / "web-interface",
project_root / "Data" / "Server" / "WebUI" / "build",
project_root / "Data" / "Server" / "WebUI",
project_root / "Data" / "WebUI" / "build",
project_root / "Data" / "WebUI",
)
for path in candidates:
resolved = path.resolve()

View File

@@ -26,6 +26,11 @@ from .github import ( # noqa: F401
GitHubTokenStatus,
RepoHeadSnapshot,
)
from .operator import ( # noqa: F401
OperatorAccount,
OperatorLoginSuccess,
OperatorMFAChallenge,
)
__all__ = [
"AccessTokenClaims",
@@ -45,5 +50,8 @@ __all__ = [
"GitHubRepoRef",
"GitHubTokenStatus",
"RepoHeadSnapshot",
"OperatorAccount",
"OperatorLoginSuccess",
"OperatorMFAChallenge",
"sanitize_service_context",
]

View File

@@ -18,6 +18,7 @@ __all__ = [
"AccessTokenClaims",
"DeviceAuthContext",
"sanitize_service_context",
"normalize_guid",
]
@@ -73,6 +74,12 @@ class DeviceGuid:
return self.value
def normalize_guid(value: Optional[str]) -> str:
"""Expose GUID normalization for administrative helpers."""
return _normalize_guid(value)
@dataclass(frozen=True, slots=True)
class DeviceFingerprint:
"""Normalized TLS key fingerprint associated with a device."""

View File

@@ -0,0 +1,28 @@
"""Domain objects for saved device list views."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List
__all__ = ["DeviceListView"]
@dataclass(frozen=True, slots=True)
class DeviceListView:
    """Immutable snapshot of a saved device-list view (columns + filters)."""

    # Primary key of the saved view row.
    id: int
    name: str
    # Ordered column identifiers shown in the view.
    columns: List[str]
    # Arbitrary filter spec; shape is defined by the client.
    filters: Dict[str, object]
    # Unix timestamps.
    created_at: int
    updated_at: int

    def to_dict(self) -> Dict[str, object]:
        """Serialize the view for JSON API responses."""
        return {
            "id": self.id,
            "name": self.name,
            "columns": self.columns,
            "filters": self.filters,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }

View File

@@ -0,0 +1,323 @@
"""Device domain helpers mirroring the legacy server payloads."""
from __future__ import annotations
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Dict, List, Mapping, Optional, Sequence
from Data.Engine.domain.device_auth import normalize_guid
__all__ = [
"DEVICE_TABLE_COLUMNS",
"DEVICE_TABLE",
"DeviceSnapshot",
"assemble_device_snapshot",
"row_to_device_dict",
"serialize_device_json",
"clean_device_str",
"coerce_int",
"ts_to_iso",
"device_column_sql",
"ts_to_human",
]
# SQLite table holding device inventory rows.
DEVICE_TABLE = "devices"

# JSON-encoded columns that deserialize to lists; values are the defaults
# used when a stored payload is missing or malformed.
DEVICE_JSON_LIST_FIELDS: Mapping[str, List[Any]] = {
    "memory": [],
    "network": [],
    "software": [],
    "storage": [],
}

# JSON-encoded columns that deserialize to objects.
DEVICE_JSON_OBJECT_FIELDS: Mapping[str, Dict[str, Any]] = {
    "cpu": {},
}

# Column order mirrors the legacy server schema; ``row_to_device_dict`` and
# ``device_column_sql`` rely on this exact ordering.
DEVICE_TABLE_COLUMNS: Sequence[str] = (
    "guid",
    "hostname",
    "description",
    "created_at",
    "agent_hash",
    "memory",
    "network",
    "software",
    "storage",
    "cpu",
    "device_type",
    "domain",
    "external_ip",
    "internal_ip",
    "last_reboot",
    "last_seen",
    "last_user",
    "operating_system",
    "uptime",
    "agent_id",
    "ansible_ee_ver",
    "connection_type",
    "connection_endpoint",
    "ssl_key_fingerprint",
    "token_version",
    "status",
    "key_added_at",
)
@dataclass(frozen=True)
class DeviceSnapshot:
    """Fully-normalized device record mirroring the legacy server payload.

    Field names match the keys produced by ``assemble_device_snapshot``;
    ``details`` and ``summary`` are pre-built nested dicts for API output.
    """

    hostname: str
    description: str
    created_at: int
    created_at_iso: str
    agent_hash: str
    agent_guid: str
    guid: str
    memory: List[Dict[str, Any]]
    network: List[Dict[str, Any]]
    software: List[Dict[str, Any]]
    storage: List[Dict[str, Any]]
    cpu: Dict[str, Any]
    device_type: str
    domain: str
    external_ip: str
    internal_ip: str
    last_reboot: str
    last_seen: int
    last_seen_iso: str
    last_user: str
    operating_system: str
    uptime: int
    agent_id: str
    ansible_ee_ver: str
    connection_type: str
    connection_endpoint: str
    ssl_key_fingerprint: str
    token_version: int
    status: str
    key_added_at: str
    details: Dict[str, Any]
    summary: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the snapshot into the legacy JSON payload shape."""
        return {
            "hostname": self.hostname,
            "description": self.description,
            "created_at": self.created_at,
            "created_at_iso": self.created_at_iso,
            "agent_hash": self.agent_hash,
            "agent_guid": self.agent_guid,
            "guid": self.guid,
            "memory": self.memory,
            "network": self.network,
            "software": self.software,
            "storage": self.storage,
            "cpu": self.cpu,
            "device_type": self.device_type,
            "domain": self.domain,
            "external_ip": self.external_ip,
            "internal_ip": self.internal_ip,
            "last_reboot": self.last_reboot,
            "last_seen": self.last_seen,
            "last_seen_iso": self.last_seen_iso,
            "last_user": self.last_user,
            "operating_system": self.operating_system,
            "uptime": self.uptime,
            "agent_id": self.agent_id,
            "ansible_ee_ver": self.ansible_ee_ver,
            "connection_type": self.connection_type,
            "connection_endpoint": self.connection_endpoint,
            "ssl_key_fingerprint": self.ssl_key_fingerprint,
            "token_version": self.token_version,
            "status": self.status,
            "key_added_at": self.key_added_at,
            "details": self.details,
            "summary": self.summary,
        }
def ts_to_iso(ts: Optional[int]) -> str:
    """Convert a Unix timestamp to an ISO-8601 UTC string ('' when falsy/invalid)."""
    if not ts:
        return ""
    try:
        moment = datetime.fromtimestamp(int(ts), timezone.utc)
    except Exception:
        return ""
    return moment.isoformat()
def _ts_to_human(ts: Optional[int]) -> str:
if not ts:
return ""
try:
return datetime.utcfromtimestamp(int(ts)).strftime("%Y-%m-%d %H:%M:%S")
except Exception:
return ""
def _parse_device_json(raw: Optional[str], default: Any) -> Any:
if raw is None:
return json.loads(json.dumps(default)) if isinstance(default, (list, dict)) else default
try:
data = json.loads(raw)
except Exception:
data = None
if isinstance(default, list):
if isinstance(data, list):
return data
return []
if isinstance(default, dict):
if isinstance(data, dict):
return data
return {}
return default
def serialize_device_json(value: Any, default: Any) -> str:
    """Serialize *value* to a JSON string, substituting *default* when *value*
    is ``None`` or not a list/dict; falls back to an empty container literal
    if even *default* cannot be serialized."""
    payload = value if isinstance(value, (list, dict)) else default
    try:
        return json.dumps(payload)
    except Exception:
        try:
            return json.dumps(default)
        except Exception:
            return "{}" if isinstance(default, dict) else "[]"
def clean_device_str(value: Any) -> Optional[str]:
    """Coerce *value* to a whitespace-stripped string.

    Returns ``None`` for ``None``, for blank results, and for values whose
    ``str()`` conversion raises.
    """
    if value is None:
        return None
    if isinstance(value, str):
        text = value
    else:
        try:
            text = str(value)
        except Exception:
            return None
    stripped = text.strip()
    return stripped or None
def coerce_int(value: Any) -> Optional[int]:
    """Best-effort integer conversion (via ``float`` so ``"3.7"`` works).

    Returns ``None`` for ``None``, blank strings, and unconvertible values.
    """
    if value is None:
        return None
    if isinstance(value, str) and not value.strip():
        return None
    try:
        return int(float(value))
    except (ValueError, TypeError):
        return None
def row_to_device_dict(row: Sequence[Any], columns: Sequence[str]) -> Dict[str, Any]:
    """Pair a result *row* with *columns*, truncating to the shorter sequence."""
    # zip() stops at the shorter iterable, matching the original min() bound.
    return dict(zip(columns, row))
def assemble_device_snapshot(record: Mapping[str, Any]) -> Dict[str, Any]:
    """Normalize a raw DB *record* into the legacy device payload dict.

    Produces a flat payload plus nested ``summary`` and ``details`` dicts;
    key names and structure intentionally mirror the legacy server output,
    so field order and duplication (e.g. ``agent_guid`` == ``guid``) are
    preserved on purpose.
    """
    # Normalize scalar text fields; missing values collapse to "".
    hostname = clean_device_str(record.get("hostname")) or ""
    description = clean_device_str(record.get("description")) or ""
    agent_hash = clean_device_str(record.get("agent_hash")) or ""
    raw_guid = clean_device_str(record.get("guid"))
    normalized_guid = normalize_guid(raw_guid)
    # Numeric fields default to 0 when absent or unparseable.
    created_ts = coerce_int(record.get("created_at")) or 0
    last_seen_ts = coerce_int(record.get("last_seen")) or 0
    uptime_val = coerce_int(record.get("uptime")) or 0
    token_version = coerce_int(record.get("token_version")) or 0
    # JSON columns: decode each list field with its safe default.
    parsed_lists = {
        key: _parse_device_json(record.get(key), default)
        for key, default in DEVICE_JSON_LIST_FIELDS.items()
    }
    cpu_obj = _parse_device_json(record.get("cpu"), DEVICE_JSON_OBJECT_FIELDS["cpu"])
    # Compact per-device summary used by list views.
    summary: Dict[str, Any] = {
        "hostname": hostname,
        "description": description,
        "agent_hash": agent_hash,
        "agent_guid": normalized_guid or "",
        "agent_id": clean_device_str(record.get("agent_id")) or "",
        "device_type": clean_device_str(record.get("device_type")) or "",
        "domain": clean_device_str(record.get("domain")) or "",
        "external_ip": clean_device_str(record.get("external_ip")) or "",
        "internal_ip": clean_device_str(record.get("internal_ip")) or "",
        "last_reboot": clean_device_str(record.get("last_reboot")) or "",
        "last_seen": last_seen_ts,
        "last_user": clean_device_str(record.get("last_user")) or "",
        "operating_system": clean_device_str(record.get("operating_system")) or "",
        "uptime": uptime_val,
        # Duplicate key kept for legacy clients that read "uptime_sec".
        "uptime_sec": uptime_val,
        "ansible_ee_ver": clean_device_str(record.get("ansible_ee_ver")) or "",
        "connection_type": clean_device_str(record.get("connection_type")) or "",
        "connection_endpoint": clean_device_str(record.get("connection_endpoint")) or "",
        "ssl_key_fingerprint": clean_device_str(record.get("ssl_key_fingerprint")) or "",
        "status": clean_device_str(record.get("status")) or "",
        "token_version": token_version,
        "key_added_at": clean_device_str(record.get("key_added_at")) or "",
        "created_at": created_ts,
        "created": ts_to_human(created_ts),
    }
    # Expanded view: raw inventory lists plus a copy of the summary.
    details = {
        "memory": parsed_lists["memory"],
        "network": parsed_lists["network"],
        "software": parsed_lists["software"],
        "storage": parsed_lists["storage"],
        "cpu": cpu_obj,
        "summary": dict(summary),
    }
    payload: Dict[str, Any] = {
        "hostname": hostname,
        "description": description,
        "created_at": created_ts,
        "created_at_iso": ts_to_iso(created_ts),
        "agent_hash": agent_hash,
        "agent_guid": summary.get("agent_guid", ""),
        "guid": summary.get("agent_guid", ""),
        "memory": parsed_lists["memory"],
        "network": parsed_lists["network"],
        "software": parsed_lists["software"],
        "storage": parsed_lists["storage"],
        "cpu": cpu_obj,
        "device_type": summary.get("device_type", ""),
        "domain": summary.get("domain", ""),
        "external_ip": summary.get("external_ip", ""),
        "internal_ip": summary.get("internal_ip", ""),
        "last_reboot": summary.get("last_reboot", ""),
        "last_seen": last_seen_ts,
        "last_seen_iso": ts_to_iso(last_seen_ts),
        "last_user": summary.get("last_user", ""),
        "operating_system": summary.get("operating_system", ""),
        "uptime": uptime_val,
        "agent_id": summary.get("agent_id", ""),
        "ansible_ee_ver": summary.get("ansible_ee_ver", ""),
        "connection_type": summary.get("connection_type", ""),
        "connection_endpoint": summary.get("connection_endpoint", ""),
        "ssl_key_fingerprint": summary.get("ssl_key_fingerprint", ""),
        "token_version": summary.get("token_version", 0),
        "status": summary.get("status", ""),
        "key_added_at": summary.get("key_added_at", ""),
        "details": details,
        "summary": summary,
    }
    return payload
def device_column_sql(alias: Optional[str] = None) -> str:
    """Return the comma-joined device column list, optionally alias-qualified."""
    prefix = f"{alias}." if alias else ""
    return ", ".join(prefix + col for col in DEVICE_TABLE_COLUMNS)
def ts_to_human(ts: Optional[int]) -> str:
    """Public wrapper around :func:`_ts_to_human` ('YYYY-MM-DD HH:MM:SS' UTC)."""
    return _ts_to_human(ts)

View File

@@ -0,0 +1,206 @@
"""Administrative enrollment domain models."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Mapping, Optional
from Data.Engine.domain.device_auth import DeviceGuid, normalize_guid
__all__ = [
"EnrollmentCodeRecord",
"DeviceApprovalRecord",
"HostnameConflict",
]
def _parse_iso8601(value: Optional[str]) -> Optional[datetime]:
if not value:
return None
raw = str(value).strip()
if not raw:
return None
try:
dt = datetime.fromisoformat(raw)
except Exception as exc: # pragma: no cover - defensive parsing
raise ValueError(f"invalid ISO8601 timestamp: {raw}") from exc
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt.astimezone(timezone.utc)
def _isoformat(value: Optional[datetime]) -> Optional[str]:
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc).isoformat()
@dataclass(frozen=True, slots=True)
class EnrollmentCodeRecord:
    """Installer code metadata exposed to administrative clients."""

    record_id: str
    code: str
    # All datetimes are timezone-aware UTC (see _parse_iso8601).
    expires_at: datetime
    max_uses: int
    use_count: int
    created_by_user_id: Optional[str]
    used_at: Optional[datetime]
    used_by_guid: Optional[DeviceGuid]
    last_used_at: Optional[datetime]

    @classmethod
    def from_row(cls, row: Mapping[str, Any]) -> "EnrollmentCodeRecord":
        """Build a record from a DB row mapping; raises ValueError on bad rows."""
        record_id = str(row.get("id") or "").strip()
        code = str(row.get("code") or "").strip()
        if not record_id or not code:
            raise ValueError("invalid enrollment install code record")
        used_by = row.get("used_by_guid")
        used_by_guid = DeviceGuid(str(used_by)) if used_by else None
        return cls(
            record_id=record_id,
            code=code,
            # Missing expiry defaults to "now", i.e. effectively expired.
            expires_at=_parse_iso8601(row.get("expires_at")) or datetime.now(tz=timezone.utc),
            max_uses=int(row.get("max_uses") or 1),
            use_count=int(row.get("use_count") or 0),
            created_by_user_id=str(row.get("created_by_user_id") or "").strip() or None,
            used_at=_parse_iso8601(row.get("used_at")),
            used_by_guid=used_by_guid,
            last_used_at=_parse_iso8601(row.get("last_used_at")),
        )

    def status(self, *, now: Optional[datetime] = None) -> str:
        """Return 'used', 'expired', or 'active' relative to *now* (default: current UTC)."""
        reference = now or datetime.now(tz=timezone.utc)
        if self.use_count >= self.max_uses:
            return "used"
        if self.expires_at <= reference:
            return "expired"
        return "active"

    def to_dict(self) -> dict[str, Any]:
        """Serialize the record for JSON API responses."""
        return {
            "id": self.record_id,
            "code": self.code,
            "expires_at": _isoformat(self.expires_at),
            "max_uses": self.max_uses,
            "use_count": self.use_count,
            "created_by_user_id": self.created_by_user_id,
            "used_at": _isoformat(self.used_at),
            "used_by_guid": self.used_by_guid.value if self.used_by_guid else None,
            "last_used_at": _isoformat(self.last_used_at),
            "status": self.status(),
        }
@dataclass(frozen=True, slots=True)
class HostnameConflict:
    """Existing device details colliding with a pending approval."""

    guid: Optional[str]
    ssl_key_fingerprint: Optional[str]
    site_id: Optional[int]
    site_name: str
    # True when the claimed fingerprint matches the existing device's.
    fingerprint_match: bool
    # True when the operator must explicitly resolve the conflict.
    requires_prompt: bool

    def to_dict(self) -> dict[str, Any]:
        """Serialize the conflict for JSON API responses."""
        return {
            "guid": self.guid,
            "ssl_key_fingerprint": self.ssl_key_fingerprint,
            "site_id": self.site_id,
            "site_name": self.site_name,
            "fingerprint_match": self.fingerprint_match,
            "requires_prompt": self.requires_prompt,
        }
@dataclass(frozen=True, slots=True)
class DeviceApprovalRecord:
    """Administrative projection of a device approval entry."""

    record_id: str
    reference: str
    # Lowercased status, e.g. "pending"/"approved"/"denied".
    status: str
    claimed_hostname: str
    # Lowercase hex TLS key fingerprint claimed by the enrolling device.
    claimed_fingerprint: str
    created_at: datetime
    updated_at: datetime
    enrollment_code_id: Optional[str]
    guid: Optional[str]
    approved_by_user_id: Optional[str]
    approved_by_username: Optional[str]
    client_nonce: str
    server_nonce: str
    hostname_conflict: Optional[HostnameConflict]
    alternate_hostname: Optional[str]
    conflict_requires_prompt: bool
    fingerprint_match: bool

    @classmethod
    def from_row(
        cls,
        row: Mapping[str, Any],
        *,
        conflict: Optional[HostnameConflict] = None,
        alternate_hostname: Optional[str] = None,
        fingerprint_match: bool = False,
        requires_prompt: bool = False,
    ) -> "DeviceApprovalRecord":
        """Build a record from a DB row plus optional conflict context.

        Raises ValueError when any of the required identifying columns
        (id, approval_reference, hostname, fingerprint) are missing.
        """
        record_id = str(row.get("id") or "").strip()
        reference = str(row.get("approval_reference") or "").strip()
        hostname = str(row.get("hostname_claimed") or "").strip()
        fingerprint = str(row.get("ssl_key_fingerprint_claimed") or "").strip().lower()
        if not record_id or not reference or not hostname or not fingerprint:
            raise ValueError("invalid device approval record")
        guid_raw = normalize_guid(row.get("guid")) or None
        return cls(
            record_id=record_id,
            reference=reference,
            status=str(row.get("status") or "pending").strip().lower(),
            claimed_hostname=hostname,
            claimed_fingerprint=fingerprint,
            # Missing timestamps fall back to "now" rather than failing.
            created_at=_parse_iso8601(row.get("created_at")) or datetime.now(tz=timezone.utc),
            updated_at=_parse_iso8601(row.get("updated_at")) or datetime.now(tz=timezone.utc),
            enrollment_code_id=str(row.get("enrollment_code_id") or "").strip() or None,
            guid=guid_raw,
            approved_by_user_id=str(row.get("approved_by_user_id") or "").strip() or None,
            approved_by_username=str(row.get("approved_by_username") or "").strip() or None,
            client_nonce=str(row.get("client_nonce") or "").strip(),
            server_nonce=str(row.get("server_nonce") or "").strip(),
            hostname_conflict=conflict,
            alternate_hostname=alternate_hostname,
            conflict_requires_prompt=requires_prompt,
            fingerprint_match=fingerprint_match,
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize the approval for JSON API responses; optional conflict
        fields are included only when present."""
        payload: dict[str, Any] = {
            "id": self.record_id,
            "approval_reference": self.reference,
            "status": self.status,
            "hostname_claimed": self.claimed_hostname,
            "ssl_key_fingerprint_claimed": self.claimed_fingerprint,
            "created_at": _isoformat(self.created_at),
            "updated_at": _isoformat(self.updated_at),
            "enrollment_code_id": self.enrollment_code_id,
            "guid": self.guid,
            "approved_by_user_id": self.approved_by_user_id,
            "approved_by_username": self.approved_by_username,
            "client_nonce": self.client_nonce,
            "server_nonce": self.server_nonce,
            "conflict_requires_prompt": self.conflict_requires_prompt,
            "fingerprint_match": self.fingerprint_match,
        }
        if self.hostname_conflict is not None:
            payload["hostname_conflict"] = self.hostname_conflict.to_dict()
        if self.alternate_hostname:
            payload["alternate_hostname"] = self.alternate_hostname
        return payload

View File

@@ -0,0 +1,51 @@
"""Domain models for operator authentication."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal, Optional
@dataclass(frozen=True, slots=True)
class OperatorAccount:
    """Snapshot of an operator account stored in SQLite."""

    username: str
    display_name: str
    # Lowercase hex SHA-512 digest of the stored password.
    password_sha512: str
    role: str
    # Unix timestamps.
    last_login: int
    created_at: int
    updated_at: int
    mfa_enabled: bool
    # TOTP secret; None when MFA has never been provisioned.
    mfa_secret: Optional[str]
@dataclass(frozen=True, slots=True)
class OperatorLoginSuccess:
    """Successful login payload for the caller."""

    username: str
    role: str
    # Session/bearer token issued on successful authentication.
    token: str
@dataclass(frozen=True, slots=True)
class OperatorMFAChallenge:
    """Details describing an in-progress MFA challenge."""

    username: str
    role: str
    # "setup" when the user must enroll a new secret, "verify" otherwise.
    stage: Literal["setup", "verify"]
    # Opaque token the client echoes back with the TOTP code.
    pending_token: str
    # Unix timestamp after which the challenge is invalid.
    expires_at: int
    # Setup-stage only: freshly generated secret and provisioning helpers.
    secret: Optional[str] = None
    otpauth_url: Optional[str] = None
    qr_image: Optional[str] = None
__all__ = [
"OperatorAccount",
"OperatorLoginSuccess",
"OperatorMFAChallenge",
]

View File

@@ -0,0 +1,43 @@
"""Domain models for operator site management."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Optional
__all__ = ["SiteSummary", "SiteDeviceMapping"]
@dataclass(frozen=True, slots=True)
class SiteSummary:
    """Representation of a site record including device counts."""

    id: int
    name: str
    description: str
    # Unix timestamp of site creation.
    created_at: int
    # Number of devices currently assigned to this site.
    device_count: int

    def to_dict(self) -> Dict[str, object]:
        """Serialize the summary for JSON API responses."""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "created_at": self.created_at,
            "device_count": self.device_count,
        }
@dataclass(frozen=True, slots=True)
class SiteDeviceMapping:
    """Mapping entry describing which site a device belongs to."""

    hostname: str
    # None when the device is not assigned to any site.
    site_id: Optional[int]
    site_name: str

    def to_dict(self) -> Dict[str, object]:
        """Serialize the mapping for JSON responses.

        NOTE(review): ``hostname`` is intentionally omitted here —
        presumably the caller keys a response dict by hostname; confirm
        against the consuming endpoint before adding it.
        """
        return {
            "site_id": self.site_id,
            "site_name": self.site_name,
        }

View File

@@ -6,16 +6,40 @@ from flask import Flask
from Data.Engine.services.container import EngineServiceContainer
from . import admin, agents, enrollment, github, health, job_management, tokens
from . import (
admin,
agent,
agents,
auth,
enrollment,
github,
health,
job_management,
tokens,
users,
sites,
devices,
credentials,
assemblies,
server_info,
)
_REGISTRARS = (
health.register,
agent.register,
agents.register,
enrollment.register,
tokens.register,
job_management.register,
github.register,
auth.register,
admin.register,
users.register,
sites.register,
devices.register,
credentials.register,
assemblies.register,
server_info.register,
)

View File

@@ -1,8 +1,8 @@
"""Administrative HTTP interface placeholders for the Engine."""
"""Administrative HTTP endpoints for the Borealis Engine."""
from __future__ import annotations
from flask import Blueprint, Flask
from flask import Blueprint, Flask, current_app, jsonify, request, session
from Data.Engine.services.container import EngineServiceContainer
@@ -11,13 +11,163 @@ blueprint = Blueprint("engine_admin", __name__, url_prefix="/api/admin")
def register(app: Flask, _services: EngineServiceContainer) -> None:
"""Attach administrative routes to *app*.
Concrete endpoints will be migrated in subsequent phases.
"""
"""Attach administrative routes to *app*."""
if "engine_admin" not in app.blueprints:
app.register_blueprint(blueprint)
def _services() -> EngineServiceContainer:
    """Fetch the service container installed on the app at bootstrap."""
    services = current_app.extensions.get("engine_services")
    if services is None:  # pragma: no cover - defensive
        raise RuntimeError("engine services not initialized")
    return services


def _admin_service():
    """Shortcut to the enrollment administration service."""
    return _services().enrollment_admin_service


def _require_admin():
    """Return an error response tuple unless the session is an admin.

    Returns None when the caller is authenticated with the 'admin' role;
    otherwise a (json, status) pair: 401 when unauthenticated, 403 when
    authenticated but not an admin.
    """
    username = session.get("username")
    role = (session.get("role") or "").strip().lower()
    if not isinstance(username, str) or not username:
        return jsonify({"error": "not_authenticated"}), 401
    if role != "admin":
        return jsonify({"error": "forbidden"}), 403
    return None
@blueprint.route("/enrollment-codes", methods=["GET"])
def list_enrollment_codes() -> object:
    """List installer codes, optionally filtered by ?status=. Admin only."""
    guard = _require_admin()
    if guard:
        return guard
    status = request.args.get("status")
    records = _admin_service().list_install_codes(status=status)
    return jsonify({"codes": [record.to_dict() for record in records]})
@blueprint.route("/enrollment-codes", methods=["POST"])
def create_enrollment_code() -> object:
    """Create a new installer code. Admin only.

    Accepts ``ttl_hours`` (legacy alias ``ttl``, default 1) and
    ``max_uses`` (legacy alias ``allowed_uses``, default 2); non-numeric
    values fall back to the defaults. Returns 201 with the new record, or
    400 when the service rejects the TTL.
    """
    guard = _require_admin()
    if guard:
        return guard
    payload = request.get_json(silent=True) or {}
    ttl_value = payload.get("ttl_hours")
    if ttl_value is None:
        ttl_value = payload.get("ttl") or 1
    try:
        ttl_hours = int(ttl_value)
    except (TypeError, ValueError):
        ttl_hours = 1
    max_uses_value = payload.get("max_uses")
    if max_uses_value is None:
        max_uses_value = payload.get("allowed_uses", 2)
    try:
        max_uses = int(max_uses_value)
    except (TypeError, ValueError):
        max_uses = 2
    creator = session.get("username") if isinstance(session.get("username"), str) else None
    try:
        record = _admin_service().create_install_code(
            ttl_hours=ttl_hours,
            max_uses=max_uses,
            created_by=creator,
        )
    except ValueError as exc:
        # Only the known validation error is translated; others propagate.
        if str(exc) == "invalid_ttl":
            return jsonify({"error": "invalid_ttl"}), 400
        raise
    response = jsonify(record.to_dict())
    response.status_code = 201
    return response
@blueprint.route("/enrollment-codes/<code_id>", methods=["DELETE"])
def delete_enrollment_code(code_id: str) -> object:
    """Delete an installer code by id; 404 when unknown. Admin only."""
    guard = _require_admin()
    if guard:
        return guard
    if not _admin_service().delete_install_code(code_id):
        return jsonify({"error": "not_found"}), 404
    return jsonify({"status": "deleted"})
@blueprint.route("/device-approvals", methods=["GET"])
def list_device_approvals() -> object:
    """List device approvals, optionally filtered by ?status=. Admin only."""
    guard = _require_admin()
    if guard:
        return guard
    status = request.args.get("status")
    records = _admin_service().list_device_approvals(status=status)
    return jsonify({"approvals": [record.to_dict() for record in records]})
@blueprint.route("/device-approvals/<approval_id>/approve", methods=["POST"])
def approve_device_approval(approval_id: str) -> object:
    """Approve a pending device approval. Admin only.

    Optional JSON fields: ``guid`` (assign explicitly) and
    ``conflict_resolution`` (legacy alias ``resolution``).  Error mapping:
    404 unknown id, 409 not-pending / unresolved conflict, 400 bad guid.
    """
    guard = _require_admin()
    if guard:
        return guard
    payload = request.get_json(silent=True) or {}
    guid = payload.get("guid")
    resolution_raw = payload.get("conflict_resolution") or payload.get("resolution")
    resolution = resolution_raw.strip().lower() if isinstance(resolution_raw, str) else None
    actor = session.get("username") if isinstance(session.get("username"), str) else None
    try:
        result = _admin_service().approve_device_approval(
            approval_id,
            actor=actor,
            guid=guid,
            conflict_resolution=resolution,
        )
    except LookupError:
        return jsonify({"error": "not_found"}), 404
    except ValueError as exc:
        code = str(exc)
        if code == "approval_not_pending":
            return jsonify({"error": "approval_not_pending"}), 409
        if code == "conflict_resolution_required":
            return jsonify({"error": "conflict_resolution_required"}), 409
        if code == "invalid_guid":
            return jsonify({"error": "invalid_guid"}), 400
        raise
    response = jsonify(result.to_dict())
    response.status_code = 200
    return response
@blueprint.route("/device-approvals/<approval_id>/deny", methods=["POST"])
def deny_device_approval(approval_id: str) -> object:
    """Deny a pending device approval; 404 unknown, 409 not pending. Admin only."""
    guard = _require_admin()
    if guard:
        return guard
    actor = session.get("username") if isinstance(session.get("username"), str) else None
    try:
        result = _admin_service().deny_device_approval(approval_id, actor=actor)
    except LookupError:
        return jsonify({"error": "not_found"}), 404
    except ValueError as exc:
        if str(exc) == "approval_not_pending":
            return jsonify({"error": "approval_not_pending"}), 409
        raise
    return jsonify(result.to_dict())
__all__ = ["register", "blueprint"]

View File

@@ -0,0 +1,148 @@
"""Agent REST endpoints for device communication."""
from __future__ import annotations
import math
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, cast
from flask import Blueprint, Flask, current_app, g, jsonify, request
from Data.Engine.builders.device_auth import DeviceAuthRequestBuilder
from Data.Engine.domain.device_auth import DeviceAuthContext, DeviceAuthFailure
from Data.Engine.services.container import EngineServiceContainer
from Data.Engine.services.devices.device_inventory_service import (
DeviceDetailsError,
DeviceHeartbeatError,
)
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
blueprint = Blueprint("engine_agent", __name__)
F = TypeVar("F", bound=Callable[..., Any])
def _services() -> EngineServiceContainer:
    """Fetch the service container installed on the app at bootstrap."""
    return cast(EngineServiceContainer, current_app.extensions["engine_services"])
def require_device_auth(func: F) -> F:
    """Decorator enforcing device (DPoP) authentication on agent routes.

    Builds a DeviceAuthRequest from the Authorization header, HTTP method,
    request URL, the X-Borealis-Agent-Context header, and the DPoP proof,
    then authenticates it.  On failure the DeviceAuthFailure is rendered as
    a JSON error with its HTTP status (plus Retry-After when rate-limited).
    On success the DeviceAuthContext is exposed as ``g.device_auth`` for the
    duration of the wrapped view and removed afterwards.
    """
    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any):
        services = _services()
        builder = (
            DeviceAuthRequestBuilder()
            .with_authorization(request.headers.get("Authorization"))
            .with_http_method(request.method)
            .with_htu(request.url)
            .with_service_context(request.headers.get(AGENT_CONTEXT_HEADER))
            .with_dpop_proof(request.headers.get("DPoP"))
        )
        try:
            auth_request = builder.build()
            context = services.device_auth.authenticate(auth_request, path=request.path)
        except DeviceAuthFailure as exc:
            payload = exc.to_dict()
            response = jsonify(payload)
            if exc.retry_after is not None:
                # Round up so clients never retry early.
                response.headers["Retry-After"] = str(int(math.ceil(exc.retry_after)))
            return response, exc.http_status
        g.device_auth = context
        try:
            return func(*args, **kwargs)
        finally:
            # Never leak the auth context into subsequent request handling.
            g.pop("device_auth", None)
    return cast(F, wrapper)
def register(app: Flask, _services: EngineServiceContainer) -> None:
    """Attach the agent blueprint to *app* (idempotent)."""
    if "engine_agent" not in app.blueprints:
        app.register_blueprint(blueprint)
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
@require_device_auth
def heartbeat() -> Any:
    """Record a device heartbeat and tell the agent when to poll next.

    Error mapping: 404 unregistered device, 409 storage conflict, 500
    (logged) for anything else the inventory service reports.
    """
    services = _services()
    payload = request.get_json(force=True, silent=True) or {}
    context = cast(DeviceAuthContext, g.device_auth)
    try:
        services.device_inventory.record_heartbeat(context=context, payload=payload)
    except DeviceHeartbeatError as exc:
        error_payload = {"error": exc.code}
        if exc.code == "device_not_registered":
            return jsonify(error_payload), 404
        if exc.code == "storage_conflict":
            return jsonify(error_payload), 409
        current_app.logger.exception(
            "device-heartbeat-error guid=%s code=%s", context.identity.guid.value, exc.code
        )
        return jsonify(error_payload), 500
    return jsonify({"status": "ok", "poll_after_ms": 15000})
@blueprint.route("/api/agent/script/request", methods=["POST"])
@require_device_auth
def script_request() -> Any:
    """Report script-execution status and the Ed25519 signing key to the agent.

    Quarantined devices are told to back off (60s vs 30s poll interval).
    The signing key is omitted when the signer is unavailable rather than
    failing the request.
    """
    services = _services()
    context = cast(DeviceAuthContext, g.device_auth)
    signing_key: Optional[str] = None
    signer = services.script_signer
    if signer is not None:
        try:
            signing_key = signer.public_base64_spki()
        except Exception as exc:  # pragma: no cover - defensive logging
            current_app.logger.warning("script-signer-unavailable: %s", exc)
    status = "quarantined" if context.is_quarantined else "idle"
    poll_after = 60000 if context.is_quarantined else 30000
    response = {
        "status": status,
        "poll_after_ms": poll_after,
        "sig_alg": "ed25519",
    }
    if signing_key:
        response["signing_key"] = signing_key
    return jsonify(response)
@blueprint.route("/api/agent/details", methods=["POST"])
@require_device_auth
def save_details() -> Any:
    """Persist the agent-reported inventory details for the device.

    Error mapping: 400 invalid payload, 403 fingerprint/guid mismatch,
    404 unregistered device, 500 (logged) otherwise.
    """
    services = _services()
    payload = request.get_json(force=True, silent=True) or {}
    context = cast(DeviceAuthContext, g.device_auth)
    try:
        services.device_inventory.save_agent_details(context=context, payload=payload)
    except DeviceDetailsError as exc:
        error_payload = {"error": exc.code}
        if exc.code == "invalid_payload":
            return jsonify(error_payload), 400
        if exc.code in {"fingerprint_mismatch", "guid_mismatch"}:
            return jsonify(error_payload), 403
        if exc.code == "device_not_registered":
            return jsonify(error_payload), 404
        current_app.logger.exception(
            "device-details-error guid=%s code=%s", context.identity.guid.value, exc.code
        )
        return jsonify(error_payload), 500
    return jsonify({"status": "ok"})
# Public API of this module: the registration hook, the blueprint itself,
# the three route handlers, and the device-auth decorator.
__all__ = [
    "register",
    "blueprint",
    "heartbeat",
    "script_request",
    "save_details",
    "require_device_auth",
]

View File

@@ -0,0 +1,182 @@
"""HTTP endpoints for assembly management."""
from __future__ import annotations
from flask import Blueprint, Flask, current_app, jsonify, request
from Data.Engine.services.container import EngineServiceContainer
blueprint = Blueprint("engine_assemblies", __name__)
def register(app: Flask, _services: EngineServiceContainer) -> None:
    """Attach the assemblies blueprint to *app*, skipping duplicates."""
    if "engine_assemblies" in app.blueprints:
        return
    app.register_blueprint(blueprint)
def _services() -> EngineServiceContainer:
    """Fetch the engine service container installed on the current app."""
    container = current_app.extensions.get("engine_services")
    if container is not None:
        return container
    raise RuntimeError("engine services not initialized")  # pragma: no cover - defensive
def _assembly_service():
    """Shortcut to the assembly service from the service container."""
    return _services().assembly_service
# Code-to-message table for service-raised validation errors; replaces the
# original repetitive if-chain (every branch returned the same 400 status).
_VALUE_ERROR_MESSAGES = {
    "invalid_island": "invalid island",
    "path_required": "path required",
    "invalid_kind": "invalid kind",
    "invalid_destination": "invalid destination",
    "invalid_path": "invalid path",
    "cannot_delete_root": "cannot delete root",
}


def _value_error_response(exc: ValueError):
    """Map a service-raised ``ValueError`` code onto a 400 JSON response.

    Unknown codes fall through to the raw code text, or a generic message
    when the exception carried no text — identical to the old behavior.
    """
    code = str(exc)
    message = _VALUE_ERROR_MESSAGES.get(code, code or "invalid request")
    return jsonify({"error": message}), 400
def _not_found_response(exc: FileNotFoundError):
    """Translate a ``FileNotFoundError`` code into a 404 JSON response."""
    code = str(exc)
    if code == "file_not_found":
        message = "file not found"
    elif code == "folder_not_found":
        message = "folder not found"
    else:
        message = "not found"
    return jsonify({"error": message}), 404
@blueprint.route("/api/assembly/list", methods=["GET"])
def list_assemblies() -> object:
    """List assembly items for the requested island."""
    island = (request.args.get("island") or "").strip()
    try:
        listing = _assembly_service().list_items(island)
    except ValueError as err:
        return _value_error_response(err)
    return jsonify(listing.to_dict())
@blueprint.route("/api/assembly/load", methods=["GET"])
def load_assembly() -> object:
    """Load one assembly item identified by island and relative path."""
    island = (request.args.get("island") or "").strip()
    rel_path = (request.args.get("path") or "").strip()
    try:
        loaded = _assembly_service().load_item(island, rel_path)
    except ValueError as err:
        return _value_error_response(err)
    except FileNotFoundError as err:
        return _not_found_response(err)
    return jsonify(loaded.to_dict())
@blueprint.route("/api/assembly/create", methods=["POST"])
def create_assembly() -> object:
    """Create a new assembly file or folder from the JSON body."""
    body = request.get_json(silent=True) or {}
    island = (body.get("island") or "").strip()
    kind = (body.get("kind") or "").strip().lower()
    rel_path = (body.get("path") or "").strip()
    content = body.get("content")
    item_type = body.get("type")
    try:
        created = _assembly_service().create_item(
            island,
            kind=kind,
            rel_path=rel_path,
            content=content,
            # Non-string "type" values are treated as absent.
            item_type=item_type if isinstance(item_type, str) else None,
        )
    except ValueError as err:
        return _value_error_response(err)
    return jsonify(created.to_dict())
@blueprint.route("/api/assembly/edit", methods=["POST"])
def edit_assembly() -> object:
    """Overwrite the content/type of an existing assembly item."""
    body = request.get_json(silent=True) or {}
    island = (body.get("island") or "").strip()
    rel_path = (body.get("path") or "").strip()
    content = body.get("content")
    item_type = body.get("type")
    try:
        edited = _assembly_service().edit_item(
            island,
            rel_path=rel_path,
            content=content,
            item_type=item_type if isinstance(item_type, str) else None,
        )
    except ValueError as err:
        return _value_error_response(err)
    except FileNotFoundError as err:
        return _not_found_response(err)
    return jsonify(edited.to_dict())
@blueprint.route("/api/assembly/rename", methods=["POST"])
def rename_assembly() -> object:
    """Rename an assembly file or folder in place."""
    body = request.get_json(silent=True) or {}
    island = (body.get("island") or "").strip()
    kind = (body.get("kind") or "").strip().lower()
    rel_path = (body.get("path") or "").strip()
    new_name = (body.get("new_name") or "").strip()
    item_type = body.get("type")
    try:
        renamed = _assembly_service().rename_item(
            island,
            kind=kind,
            rel_path=rel_path,
            new_name=new_name,
            item_type=item_type if isinstance(item_type, str) else None,
        )
    except ValueError as err:
        return _value_error_response(err)
    except FileNotFoundError as err:
        return _not_found_response(err)
    return jsonify(renamed.to_dict())
@blueprint.route("/api/assembly/move", methods=["POST"])
def move_assembly() -> object:
    """Move an assembly item to a new relative path."""
    body = request.get_json(silent=True) or {}
    island = (body.get("island") or "").strip()
    rel_path = (body.get("path") or "").strip()
    new_path = (body.get("new_path") or "").strip()
    kind = (body.get("kind") or "").strip().lower()
    try:
        moved = _assembly_service().move_item(
            island,
            rel_path=rel_path,
            new_path=new_path,
            kind=kind,
        )
    except ValueError as err:
        return _value_error_response(err)
    except FileNotFoundError as err:
        return _not_found_response(err)
    return jsonify(moved.to_dict())
@blueprint.route("/api/assembly/delete", methods=["POST"])
def delete_assembly() -> object:
    """Delete an assembly file or folder."""
    body = request.get_json(silent=True) or {}
    island = (body.get("island") or "").strip()
    rel_path = (body.get("path") or "").strip()
    kind = (body.get("kind") or "").strip().lower()
    try:
        removed = _assembly_service().delete_item(
            island,
            rel_path=rel_path,
            kind=kind,
        )
    except ValueError as err:
        return _value_error_response(err)
    except FileNotFoundError as err:
        return _not_found_response(err)
    return jsonify(removed.to_dict())
__all__ = ["register", "blueprint"]

View File

@@ -0,0 +1,195 @@
"""Operator authentication HTTP endpoints."""
from __future__ import annotations
from typing import Any, Dict
from flask import Blueprint, Flask, current_app, jsonify, request, session
from Data.Engine.builders import build_login_request, build_mfa_request
from Data.Engine.domain import OperatorLoginSuccess, OperatorMFAChallenge
from Data.Engine.services.auth import (
InvalidCredentialsError,
InvalidMFACodeError,
MFAUnavailableError,
MFASessionError,
OperatorAuthService,
)
from Data.Engine.services.container import EngineServiceContainer
def _service(container: EngineServiceContainer) -> OperatorAuthService:
    """Return the operator-auth service held by *container*."""
    return container.operator_auth_service
def register(app: Flask, services: EngineServiceContainer) -> None:
    """Register operator auth routes (login, logout, me, MFA verify) on *app*.

    The route handlers close over *services*; the Flask session tracks the
    signed-in operator plus any pending MFA challenge.
    """
    bp = Blueprint("auth", __name__)

    @bp.route("/api/auth/login", methods=["POST"])
    def login() -> Any:
        payload = request.get_json(silent=True) or {}
        try:
            login_request = build_login_request(payload)
        except ValueError as exc:
            return jsonify({"error": str(exc)}), 400
        service = _service(services)
        try:
            result = service.authenticate(login_request)
        except InvalidCredentialsError:
            return jsonify({"error": "invalid username or password"}), 401
        except MFAUnavailableError as exc:
            current_app.logger.error("mfa unavailable: %s", exc)
            return jsonify({"error": str(exc)}), 500
        # Drop any stale identity before deciding the outcome.
        session.pop("username", None)
        session.pop("role", None)
        if isinstance(result, OperatorLoginSuccess):
            # Fully authenticated: establish the session and auth cookie.
            session.pop("mfa_pending", None)
            session["username"] = result.username
            session["role"] = result.role or "User"
            response = jsonify(
                {"status": "ok", "username": result.username, "role": result.role, "token": result.token}
            )
            _set_auth_cookie(response, result.token)
            return response
        # Otherwise the service returned an MFA challenge; park it in the
        # session until /api/auth/mfa/verify completes it.
        challenge = result
        session["mfa_pending"] = {
            "username": challenge.username,
            "role": challenge.role,
            "stage": challenge.stage,
            "token": challenge.pending_token,
            "expires": challenge.expires_at,
            "secret": challenge.secret,
        }
        session.modified = True
        # NOTE: rebinds the local `payload` — the parsed request body is no
        # longer needed from this point on.
        payload: Dict[str, Any] = {
            "status": "mfa_required",
            "stage": challenge.stage,
            "pending_token": challenge.pending_token,
            "username": challenge.username,
            "role": challenge.role,
        }
        if challenge.stage == "setup":
            # First-time setup includes the shared secret / provisioning QR.
            if challenge.secret:
                payload["secret"] = challenge.secret
            if challenge.otpauth_url:
                payload["otpauth_url"] = challenge.otpauth_url
            if challenge.qr_image:
                payload["qr_image"] = challenge.qr_image
        return jsonify(payload)

    @bp.route("/api/auth/logout", methods=["POST"])
    def logout() -> Any:
        session.clear()
        response = jsonify({"status": "ok"})
        # Expire the auth cookie immediately.
        _set_auth_cookie(response, "", expires=0)
        return response

    @bp.route("/api/auth/me", methods=["GET"])
    def me() -> Any:
        service = _service(services)
        account = None
        username = session.get("username")
        if isinstance(username, str) and username:
            account = service.fetch_account(username)
        if account is None:
            # Fall back to the auth cookie, then the Authorization header.
            token = request.cookies.get("borealis_auth", "")
            if not token:
                auth_header = request.headers.get("Authorization", "")
                if auth_header.lower().startswith("bearer "):
                    token = auth_header.split(None, 1)[1]
            account = service.resolve_token(token)
            if account is not None:
                # Re-hydrate the session from the resolved token.
                session["username"] = account.username
                session["role"] = account.role or "User"
        if account is None:
            return jsonify({"error": "not_authenticated"}), 401
        payload = {
            "username": account.username,
            "display_name": account.display_name or account.username,
            "role": account.role,
        }
        return jsonify(payload)

    @bp.route("/api/auth/mfa/verify", methods=["POST"])
    def verify_mfa() -> Any:
        pending = session.get("mfa_pending")
        if not isinstance(pending, dict):
            return jsonify({"error": "mfa_pending"}), 401
        try:
            request_payload = build_mfa_request(request.get_json(silent=True) or {})
        except ValueError as exc:
            return jsonify({"error": str(exc)}), 400
        # Rebuild the challenge object from the session snapshot.
        challenge = OperatorMFAChallenge(
            username=str(pending.get("username") or ""),
            role=str(pending.get("role") or "User"),
            stage=str(pending.get("stage") or "verify"),
            pending_token=str(pending.get("token") or ""),
            expires_at=int(pending.get("expires") or 0),
            secret=str(pending.get("secret") or "") or None,
        )
        service = _service(services)
        try:
            result = service.verify_mfa(challenge, request_payload)
        except MFASessionError as exc:
            error_key = str(exc)
            status = 401 if error_key != "mfa_not_configured" else 403
            # Collapse unrecognized session errors to a generic key.
            if error_key not in {"expired", "invalid_session", "mfa_not_configured"}:
                error_key = "invalid_session"
            session.pop("mfa_pending", None)
            return jsonify({"error": error_key}), status
        except InvalidMFACodeError as exc:
            return jsonify({"error": str(exc) or "invalid_code"}), 401
        except MFAUnavailableError as exc:
            current_app.logger.error("mfa unavailable: %s", exc)
            return jsonify({"error": str(exc)}), 500
        except InvalidCredentialsError:
            session.pop("mfa_pending", None)
            return jsonify({"error": "invalid username or password"}), 401
        # Challenge passed: promote the pending login to a full session.
        session.pop("mfa_pending", None)
        session["username"] = result.username
        session["role"] = result.role or "User"
        payload = {
            "status": "ok",
            "username": result.username,
            "role": result.role,
            "token": result.token,
        }
        response = jsonify(payload)
        _set_auth_cookie(response, result.token)
        return response

    app.register_blueprint(bp)
def _set_auth_cookie(response, value: str, *, expires: int | None = None) -> None:
    """Set (or clear, with ``expires=0``) the ``borealis_auth`` token cookie.

    SameSite/secure/domain are mirrored from the Flask session-cookie
    configuration so both cookies travel under the same policy.
    """
    same_site = current_app.config.get("SESSION_COOKIE_SAMESITE", "Lax")
    secure = bool(current_app.config.get("SESSION_COOKIE_SECURE", False))
    domain = current_app.config.get("SESSION_COOKIE_DOMAIN", None)
    response.set_cookie(
        "borealis_auth",
        value,
        # NOTE(review): httponly=False leaves the token readable by page
        # scripts — presumably so the WebUI can send it as a bearer token.
        # Confirm this is intentional.
        httponly=False,
        samesite=same_site,
        secure=secure,
        domain=domain,
        path="/",
        expires=expires,
    )
__all__ = ["register"]

View File

@@ -0,0 +1,70 @@
from __future__ import annotations
from flask import Blueprint, Flask, current_app, jsonify, request, session
from Data.Engine.services.container import EngineServiceContainer
blueprint = Blueprint("engine_credentials", __name__)
def register(app: Flask, _services: EngineServiceContainer) -> None:
    """Attach the credentials blueprint to *app*, skipping duplicates."""
    if "engine_credentials" in app.blueprints:
        return
    app.register_blueprint(blueprint)
def _services() -> EngineServiceContainer:
    """Fetch the engine service container installed on the current app."""
    container = current_app.extensions.get("engine_services")
    if container is not None:
        return container
    raise RuntimeError("engine services not initialized")  # pragma: no cover - defensive
def _credentials_service():
    """Shortcut to the credential service from the service container."""
    return _services().credential_service
def _require_admin():
    """Return an error response unless the session belongs to an admin.

    Returns None when the caller may proceed.
    """
    username = session.get("username")
    if not isinstance(username, str) or not username:
        return jsonify({"error": "not_authenticated"}), 401
    role = (session.get("role") or "").strip().lower()
    if role != "admin":
        return jsonify({"error": "forbidden"}), 403
    return None
@blueprint.route("/api/credentials", methods=["GET"])
def list_credentials() -> object:
    """List stored credentials, optionally filtered by site/connection type."""
    guard = _require_admin()
    if guard:
        return guard
    raw_site_id = request.args.get("site_id")
    connection_type = (request.args.get("connection_type") or "").strip() or None
    site_id = None
    if raw_site_id not in (None, ""):
        try:
            site_id = int(raw_site_id)
        except (TypeError, ValueError):
            # Unparsable site ids silently disable the filter.
            site_id = None
    records = _credentials_service().list_credentials(
        site_id=site_id,
        connection_type=connection_type,
    )
    return jsonify({"credentials": records})
@blueprint.route("/api/credentials", methods=["POST"])
def create_credential() -> object:  # pragma: no cover - placeholder
    # Creation is not implemented yet; reserved for a future release.
    return jsonify({"error": "not implemented"}), 501
@blueprint.route("/api/credentials/<int:credential_id>", methods=["GET", "PUT", "DELETE"])
def credential_detail(credential_id: int) -> object:  # pragma: no cover - placeholder
    """Placeholder detail endpoint for a single credential.

    Every method currently answers 501. The original branched on
    GET/DELETE/other only to return the identical response from each
    branch, so the dead branching has been collapsed.
    """
    return jsonify({"error": "not implemented"}), 501
__all__ = ["register", "blueprint"]

View File

@@ -0,0 +1,319 @@
from __future__ import annotations
from ipaddress import ip_address
from flask import Blueprint, Flask, current_app, jsonify, request, session
from Data.Engine.services.container import EngineServiceContainer
from Data.Engine.services.devices import DeviceDescriptionError, RemoteDeviceError
blueprint = Blueprint("engine_devices", __name__)
def register(app: Flask, _services: EngineServiceContainer) -> None:
    """Attach the devices blueprint to *app*, skipping duplicates."""
    if "engine_devices" in app.blueprints:
        return
    app.register_blueprint(blueprint)
def _services() -> EngineServiceContainer:
    """Fetch the engine service container installed on the current app."""
    container = current_app.extensions.get("engine_services")
    if container is not None:
        return container
    raise RuntimeError("engine services not initialized")  # pragma: no cover - defensive
def _inventory():
    """Shortcut to the device inventory service."""
    return _services().device_inventory
def _views():
    """Shortcut to the device-list view service."""
    return _services().device_view_service
def _require_admin():
    """Return an error response unless the session belongs to an admin.

    Returns None when the caller may proceed.
    """
    username = session.get("username")
    if not isinstance(username, str) or not username:
        return jsonify({"error": "not_authenticated"}), 401
    role = (session.get("role") or "").strip().lower()
    if role != "admin":
        return jsonify({"error": "forbidden"}), 403
    return None
def _is_internal_request(req: request) -> bool:
remote = (req.remote_addr or "").strip()
if not remote:
return False
try:
return ip_address(remote).is_loopback
except ValueError:
return remote in {"localhost"}
@blueprint.route("/api/devices", methods=["GET"])
def list_devices() -> object:
    """Return every known device."""
    return jsonify({"devices": _inventory().list_devices()})
@blueprint.route("/api/devices/<guid>", methods=["GET"])
def get_device_by_guid(guid: str) -> object:
    """Fetch one device by GUID, or 404 when unknown."""
    device = _inventory().get_device_by_guid(guid)
    if device:
        return jsonify(device)
    return jsonify({"error": "not found"}), 404
@blueprint.route("/api/device/description/<hostname>", methods=["POST"])
def set_device_description(hostname: str) -> object:
    """Update the free-form description of a device."""
    body = request.get_json(silent=True) or {}
    try:
        _inventory().update_device_description(hostname, body.get("description"))
    except DeviceDescriptionError as err:
        if err.code == "invalid_hostname":
            return jsonify({"error": "invalid hostname"}), 400
        if err.code == "not_found":
            return jsonify({"error": "not found"}), 404
        current_app.logger.exception(
            "device-description-error host=%s code=%s", hostname, err.code
        )
        return jsonify({"error": "internal error"}), 500
    return jsonify({"status": "ok"})
@blueprint.route("/api/agent_devices", methods=["GET"])
def list_agent_devices() -> object:
    """Admin-only list of agent-managed devices."""
    guard = _require_admin()
    if guard:
        return guard
    return jsonify({"devices": _inventory().list_agent_devices()})
@blueprint.route("/api/ssh_devices", methods=["GET", "POST"])
def ssh_devices() -> object:
    """List or create SSH-managed remote devices."""
    return _remote_devices_endpoint("ssh")
@blueprint.route("/api/winrm_devices", methods=["GET", "POST"])
def winrm_devices() -> object:
    """List or create WinRM-managed remote devices."""
    return _remote_devices_endpoint("winrm")
@blueprint.route("/api/ssh_devices/<hostname>", methods=["PUT", "DELETE"])
def ssh_device_detail(hostname: str) -> object:
    """Update or delete one SSH-managed remote device."""
    return _remote_device_detail("ssh", hostname)
@blueprint.route("/api/winrm_devices/<hostname>", methods=["PUT", "DELETE"])
def winrm_device_detail(hostname: str) -> object:
    """Update or delete one WinRM-managed remote device."""
    return _remote_device_detail("winrm", hostname)
@blueprint.route("/api/agent/hash_list", methods=["GET"])
def agent_hash_list() -> object:
    """Loopback-only endpoint exposing agent hash records."""
    if not _is_internal_request(request):
        remote_addr = (request.remote_addr or "unknown").strip() or "unknown"
        current_app.logger.warning(
            "/api/agent/hash_list denied non-local request from %s", remote_addr
        )
        return jsonify({"error": "forbidden"}), 403
    try:
        records = _inventory().collect_agent_hash_records()
    except Exception as err:  # pragma: no cover - defensive logging
        current_app.logger.exception("/api/agent/hash_list error: %s", err)
        return jsonify({"error": "internal error"}), 500
    return jsonify({"agents": records})
@blueprint.route("/api/device_list_views", methods=["GET"])
def list_device_list_views() -> object:
    """List all saved device-list views."""
    saved = _views().list_views()
    return jsonify({"views": [entry.to_dict() for entry in saved]})
@blueprint.route("/api/device_list_views/<int:view_id>", methods=["GET"])
def get_device_list_view(view_id: int) -> object:
    """Fetch a single saved view, or 404 when unknown."""
    found = _views().get_view(view_id)
    if not found:
        return jsonify({"error": "not found"}), 404
    return jsonify(found.to_dict())
@blueprint.route("/api/device_list_views", methods=["POST"])
def create_device_list_view() -> object:
    """Create a named device-list view with columns and filters."""
    body = request.get_json(silent=True) or {}
    name = (body.get("name") or "").strip()
    columns = body.get("columns") or []
    filters = body.get("filters") or {}
    if not name:
        return jsonify({"error": "name is required"}), 400
    # "Default View" is reserved for the built-in view.
    if name.lower() == "default view":
        return jsonify({"error": "reserved name"}), 400
    if not isinstance(columns, list) or not all(isinstance(col, str) for col in columns):
        return jsonify({"error": "columns must be a list of strings"}), 400
    if not isinstance(filters, dict):
        return jsonify({"error": "filters must be an object"}), 400
    try:
        created = _views().create_view(name, columns, filters)
    except ValueError as err:
        if str(err) == "duplicate":
            return jsonify({"error": "name already exists"}), 409
        raise
    response = jsonify(created.to_dict())
    response.status_code = 201
    return response
@blueprint.route("/api/device_list_views/<int:view_id>", methods=["PUT"])
def update_device_list_view(view_id: int) -> object:
    """Partially update a saved device-list view (name/columns/filters).

    Only keys present in the JSON body are validated and applied; an empty
    update set is rejected with 400.
    """
    payload = request.get_json(silent=True) or {}
    updates: dict = {}
    if "name" in payload:
        name_val = payload.get("name")
        if name_val is None:
            return jsonify({"error": "name cannot be empty"}), 400
        normalized = (str(name_val) or "").strip()
        if not normalized:
            return jsonify({"error": "name cannot be empty"}), 400
        # "Default View" is reserved for the built-in view.
        if normalized.lower() == "default view":
            return jsonify({"error": "reserved name"}), 400
        updates["name"] = normalized
    if "columns" in payload:
        columns_val = payload.get("columns")
        if not isinstance(columns_val, list) or not all(isinstance(x, str) for x in columns_val):
            return jsonify({"error": "columns must be a list of strings"}), 400
        updates["columns"] = columns_val
    if "filters" in payload:
        filters_val = payload.get("filters")
        if filters_val is not None and not isinstance(filters_val, dict):
            return jsonify({"error": "filters must be an object"}), 400
        if filters_val is not None:
            updates["filters"] = filters_val
    if not updates:
        return jsonify({"error": "no fields to update"}), 400
    try:
        view = _views().update_view(
            view_id,
            name=updates.get("name"),
            columns=updates.get("columns"),
            filters=updates.get("filters"),
        )
    except ValueError as exc:
        # The service signals validation failures via coded messages.
        code = str(exc)
        if code == "duplicate":
            return jsonify({"error": "name already exists"}), 409
        if code == "missing_name":
            return jsonify({"error": "name cannot be empty"}), 400
        if code == "reserved":
            return jsonify({"error": "reserved name"}), 400
        return jsonify({"error": "invalid payload"}), 400
    except LookupError:
        return jsonify({"error": "not found"}), 404
    return jsonify(view.to_dict())
@blueprint.route("/api/device_list_views/<int:view_id>", methods=["DELETE"])
def delete_device_list_view(view_id: int) -> object:
    """Delete a saved view, or 404 when it does not exist."""
    removed = _views().delete_view(view_id)
    if not removed:
        return jsonify({"error": "not found"}), 404
    return jsonify({"status": "ok"})
def _remote_devices_endpoint(connection_type: str) -> object:
    """Admin-only list/create endpoint shared by SSH and WinRM devices.

    GET returns all remote devices of *connection_type*; POST upserts one
    from the JSON body (hostname and address are required).

    Fix: the original precomputed ``status = 409`` for ``address_required``
    yet that branch actually returns 400, so the computation was dead and
    misleading; the error mapping is now explicit.
    """
    guard = _require_admin()
    if guard:
        return guard
    if request.method == "GET":
        devices = _inventory().list_remote_devices(connection_type)
        return jsonify({"devices": devices})
    payload = request.get_json(silent=True) or {}
    hostname = (payload.get("hostname") or "").strip()
    # Accept several aliases for the connection endpoint.
    address = (
        payload.get("address")
        or payload.get("connection_endpoint")
        or payload.get("endpoint")
        or payload.get("host")
    )
    description = payload.get("description")
    os_hint = payload.get("operating_system") or payload.get("os")
    if not hostname:
        return jsonify({"error": "hostname is required"}), 400
    if not (address or "").strip():
        return jsonify({"error": "address is required"}), 400
    try:
        device = _inventory().upsert_remote_device(
            connection_type,
            hostname,
            address,
            description,
            os_hint,
            ensure_existing_type=None,
        )
    except RemoteDeviceError as exc:
        if exc.code == "conflict":
            return jsonify({"error": str(exc)}), 409
        if exc.code == "address_required":
            return jsonify({"error": "address is required"}), 400
        return jsonify({"error": str(exc)}), 500
    return jsonify({"device": device}), 201
def _remote_device_detail(connection_type: str, hostname: str) -> object:
    """Admin-only update/delete endpoint shared by SSH and WinRM devices.

    DELETE removes the device; PUT applies a partial update (at least one
    of address/description/OS must be present in the JSON body).
    """
    guard = _require_admin()
    if guard:
        return guard
    normalized_host = (hostname or "").strip()
    if not normalized_host:
        return jsonify({"error": "invalid hostname"}), 400
    if request.method == "DELETE":
        try:
            _inventory().delete_remote_device(connection_type, normalized_host)
        except RemoteDeviceError as exc:
            if exc.code == "not_found":
                return jsonify({"error": "device not found"}), 404
            if exc.code == "invalid_hostname":
                return jsonify({"error": "invalid hostname"}), 400
            return jsonify({"error": str(exc)}), 500
        return jsonify({"status": "ok"})
    # PUT: partial update — at least one recognized field must be present.
    payload = request.get_json(silent=True) or {}
    address = (
        payload.get("address")
        or payload.get("connection_endpoint")
        or payload.get("endpoint")
    )
    description = payload.get("description")
    os_hint = payload.get("operating_system") or payload.get("os")
    if address is None and description is None and os_hint is None:
        return jsonify({"error": "no fields to update"}), 400
    try:
        device = _inventory().upsert_remote_device(
            connection_type,
            normalized_host,
            address if address is not None else "",
            description,
            os_hint,
            # Require that an existing record already has this connection
            # type — prevents flipping a device between ssh/winrm here.
            ensure_existing_type=connection_type,
        )
    except RemoteDeviceError as exc:
        if exc.code == "not_found":
            return jsonify({"error": "device not found"}), 404
        if exc.code == "address_required":
            return jsonify({"error": "address is required"}), 400
        return jsonify({"error": str(exc)}), 500
    return jsonify({"device": device})
__all__ = ["register", "blueprint"]

View File

@@ -0,0 +1,53 @@
"""Server metadata endpoints."""
from __future__ import annotations
from datetime import datetime, timezone
from flask import Blueprint, Flask, jsonify
from Data.Engine.services.container import EngineServiceContainer
blueprint = Blueprint("engine_server_info", __name__)
def register(app: Flask, _services: EngineServiceContainer) -> None:
    """Attach the server-info blueprint to *app*, skipping duplicates."""
    if "engine_server_info" in app.blueprints:
        return
    app.register_blueprint(blueprint)
@blueprint.route("/api/server/time", methods=["GET"])
def server_time() -> object:
    """Report the server's current local/UTC time plus a human display string."""
    now_local = datetime.now().astimezone()
    now_utc = datetime.now(timezone.utc)
    tzinfo = now_local.tzinfo
    offset = tzinfo.utcoffset(now_local) if tzinfo else None

    def _ordinal(n: int) -> str:
        # English ordinal suffix; 11th-13th are special-cased.
        if 11 <= (n % 100) <= 13:
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
        return f"{n}{suffix}"

    month = now_local.strftime("%B")
    day_disp = _ordinal(now_local.day)
    year = now_local.strftime("%Y")
    hour24 = now_local.hour
    # 12-hour clock: 0 o'clock renders as 12.
    hour12 = hour24 % 12 or 12
    minute = now_local.minute
    ampm = "AM" if hour24 < 12 else "PM"
    display = f"{month} {day_disp} {year} @ {hour12}:{minute:02d}{ampm}"
    payload = {
        "epoch": int(now_local.timestamp()),
        "iso": now_local.isoformat(),
        "utc_iso": now_utc.isoformat().replace("+00:00", "Z"),
        "timezone": str(tzinfo) if tzinfo else "",
        "offset_seconds": int(offset.total_seconds()) if offset else 0,
        "display": display,
    }
    return jsonify(payload)
__all__ = ["register", "blueprint"]

View File

@@ -0,0 +1,112 @@
from __future__ import annotations
from flask import Blueprint, Flask, current_app, jsonify, request
from Data.Engine.services.container import EngineServiceContainer
blueprint = Blueprint("engine_sites", __name__)
def register(app: Flask, _services: EngineServiceContainer) -> None:
    """Attach the sites blueprint to *app*, skipping duplicates."""
    if "engine_sites" in app.blueprints:
        return
    app.register_blueprint(blueprint)
def _services() -> EngineServiceContainer:
    """Fetch the engine service container installed on the current app."""
    container = current_app.extensions.get("engine_services")
    if container is not None:
        return container
    raise RuntimeError("engine services not initialized")  # pragma: no cover - defensive
def _site_service():
    """Shortcut to the site service from the service container."""
    return _services().site_service
@blueprint.route("/api/sites", methods=["GET"])
def list_sites() -> object:
    """Return all sites."""
    rows = _site_service().list_sites()
    return jsonify({"sites": [row.to_dict() for row in rows]})
@blueprint.route("/api/sites", methods=["POST"])
def create_site() -> object:
    """Create a site; 400 on a missing name, 409 on a duplicate."""
    body = request.get_json(silent=True) or {}
    name = body.get("name")
    description = body.get("description")
    try:
        record = _site_service().create_site(name or "", description or "")
    except ValueError as err:
        reason = str(err)
        if reason == "missing_name":
            return jsonify({"error": "name is required"}), 400
        if reason == "duplicate":
            return jsonify({"error": "name already exists"}), 409
        raise
    response = jsonify(record.to_dict())
    response.status_code = 201
    return response
@blueprint.route("/api/sites/delete", methods=["POST"])
def delete_sites() -> object:
    """Bulk-delete sites by id list; reports how many were removed."""
    body = request.get_json(silent=True) or {}
    ids = body.get("ids") or []
    if not isinstance(ids, list):
        return jsonify({"error": "ids must be a list"}), 400
    deleted = _site_service().delete_sites(ids)
    return jsonify({"status": "ok", "deleted": deleted})
@blueprint.route("/api/sites/device_map", methods=["POST" if False else "GET"][0:1] or ["GET"])
def sites_device_map() -> object:
    """Map hostnames to their site assignments.

    An optional comma-separated ``hostnames`` query parameter restricts the
    mapping; blank entries are ignored. The manual split/strip/append loop
    was replaced by an equivalent comprehension (an empty parameter still
    yields an empty filter, so no separate guard is needed).
    """
    host_param = (request.args.get("hostnames") or "").strip()
    filter_set = [part.strip() for part in host_param.split(",") if part.strip()]
    mapping = _site_service().map_devices(filter_set or None)
    return jsonify({"mapping": {hostname: entry.to_dict() for hostname, entry in mapping.items()}})
@blueprint.route("/api/sites/assign", methods=["POST"])
def assign_devices_to_site() -> object:
    """Assign a batch of hostnames to a site."""
    body = request.get_json(silent=True) or {}
    site_id = body.get("site_id")
    hostnames = body.get("hostnames") or []
    if not isinstance(hostnames, list):
        return jsonify({"error": "hostnames must be a list of strings"}), 400
    try:
        _site_service().assign_devices(site_id, hostnames)
    except ValueError as err:
        reason = str(err)
        if reason == "invalid_site_id":
            return jsonify({"error": "invalid site_id"}), 400
        if reason == "invalid_hostnames":
            return jsonify({"error": "hostnames must be a list of strings"}), 400
        raise
    except LookupError:
        return jsonify({"error": "site not found"}), 404
    return jsonify({"status": "ok"})
@blueprint.route("/api/sites/rename", methods=["POST"])
def rename_site() -> object:
    """Rename a site; validates the new name and reports collisions."""
    body = request.get_json(silent=True) or {}
    site_id = body.get("id")
    new_name = body.get("new_name") or ""
    try:
        record = _site_service().rename_site(site_id, new_name)
    except ValueError as err:
        reason = str(err)
        if reason == "missing_name":
            return jsonify({"error": "new_name is required"}), 400
        if reason == "duplicate":
            return jsonify({"error": "name already exists"}), 409
        raise
    except LookupError:
        return jsonify({"error": "site not found"}), 404
    return jsonify(record.to_dict())
__all__ = ["register", "blueprint"]

View File

@@ -0,0 +1,185 @@
"""HTTP endpoints for operator account management."""
from __future__ import annotations
from flask import Blueprint, Flask, jsonify, request, session
from Data.Engine.services.auth import (
AccountNotFoundError,
CannotModifySelfError,
InvalidPasswordHashError,
InvalidRoleError,
LastAdminError,
LastUserError,
OperatorAccountService,
UsernameAlreadyExistsError,
)
from Data.Engine.services.container import EngineServiceContainer
blueprint = Blueprint("engine_users", __name__)
def register(app: Flask, services: EngineServiceContainer) -> None:
    """Register the users blueprint, stashing *services* on the blueprint.

    Note: this module keeps its service container as a blueprint attribute
    instead of reading ``current_app.extensions`` like the sibling modules.
    """
    blueprint.services = services  # type: ignore[attr-defined]
    app.register_blueprint(blueprint)
def _services() -> EngineServiceContainer:
    """Return the container stashed on the blueprint by ``register``."""
    container = getattr(blueprint, "services", None)
    if container is not None:
        return container
    raise RuntimeError("user blueprint not initialized")  # pragma: no cover - defensive
def _accounts() -> OperatorAccountService:
    """Shortcut to the operator account service."""
    return _services().operator_account_service
def _require_admin():
    """Return an error response unless the session belongs to an admin.

    Returns None when the caller may proceed.
    """
    username = session.get("username")
    if not isinstance(username, str) or not username:
        return jsonify({"error": "not_authenticated"}), 401
    role = (session.get("role") or "").strip().lower()
    if role != "admin":
        return jsonify({"error": "forbidden"}), 403
    return None
def _format_user(record) -> dict[str, object]:
return {
"username": record.username,
"display_name": record.display_name,
"role": record.role,
"last_login": record.last_login,
"created_at": record.created_at,
"updated_at": record.updated_at,
"mfa_enabled": 1 if record.mfa_enabled else 0,
}
@blueprint.route("/api/users", methods=["GET"])
def list_users() -> object:
    """Admin-only list of operator accounts."""
    guard = _require_admin()
    if guard:
        return guard
    accounts = _accounts().list_accounts()
    return jsonify({"users": [_format_user(account) for account in accounts]})
@blueprint.route("/api/users", methods=["POST"])
def create_user() -> object:
    """Create an operator account from a pre-hashed password payload."""
    guard = _require_admin()
    if guard:
        return guard
    body = request.get_json(silent=True) or {}
    username = str(body.get("username") or "").strip()
    password_sha512 = str(body.get("password_sha512") or "").strip()
    role = str(body.get("role") or "User")
    # Display name defaults to the username itself.
    display_name = str(body.get("display_name") or username)
    try:
        _accounts().create_account(
            username=username,
            password_sha512=password_sha512,
            role=role,
            display_name=display_name,
        )
    except UsernameAlreadyExistsError as err:
        return jsonify({"error": str(err)}), 409
    except (InvalidPasswordHashError, InvalidRoleError) as err:
        return jsonify({"error": str(err)}), 400
    return jsonify({"status": "ok"})
@blueprint.route("/api/users/<username>", methods=["DELETE"])
def delete_user(username: str) -> object:
    """Delete an operator account (admin only).

    Refuses self-deletion and removal of the last user or last admin (400);
    unknown accounts yield 404. The original repeated three verbatim-
    identical ``except ...: return 400`` clauses; they are merged into one
    tuple clause with identical behavior.
    """
    guard = _require_admin()
    if guard:
        return guard
    actor = session.get("username") if isinstance(session.get("username"), str) else None
    try:
        _accounts().delete_account(username, actor=actor)
    except (CannotModifySelfError, LastUserError, LastAdminError) as exc:
        return jsonify({"error": str(exc)}), 400
    except AccountNotFoundError as exc:
        return jsonify({"error": str(exc)}), 404
    return jsonify({"status": "ok"})
@blueprint.route("/api/users/<username>/reset_password", methods=["POST"])
def reset_password(username: str) -> object:
    """Set a new password hash for *username* (admin only)."""
    guard = _require_admin()
    if guard:
        return guard
    body = request.get_json(silent=True) or {}
    password_sha512 = str(body.get("password_sha512") or "").strip()
    try:
        _accounts().reset_password(username, password_sha512)
    except InvalidPasswordHashError as err:
        return jsonify({"error": str(err)}), 400
    except AccountNotFoundError as err:
        return jsonify({"error": str(err)}), 404
    return jsonify({"status": "ok"})
@blueprint.route("/api/users/<username>/role", methods=["POST"])
def change_role(username: str) -> object:
    """Change a user's role (admin only).

    If the acting admin changes their own role, the cached role in the
    session is refreshed so later requests see the new permissions.
    """
    guard = _require_admin()
    if guard:
        return guard
    payload = request.get_json(silent=True) or {}
    role = str(payload.get("role") or "").strip()
    actor = session.get("username") if isinstance(session.get("username"), str) else None
    try:
        record = _accounts().change_role(username, role, actor=actor)
    except InvalidRoleError as exc:
        return jsonify({"error": str(exc)}), 400
    except LastAdminError as exc:
        # Demoting the only remaining admin would lock everyone out.
        return jsonify({"error": str(exc)}), 400
    except AccountNotFoundError as exc:
        return jsonify({"error": str(exc)}), 404
    # Self-change: keep the session's role in sync with the new record.
    if actor and actor.strip().lower() == username.strip().lower():
        session["role"] = record.role
    return jsonify({"status": "ok"})
@blueprint.route("/api/users/<username>/mfa", methods=["POST"])
def update_mfa(username: str) -> object:
    """Enable/disable MFA for a user, optionally rotating the secret (admin only)."""
    denied = _require_admin()
    if denied:
        return denied
    body = request.get_json(silent=True) or {}
    mfa_on = bool(body.get("enabled", False))
    rotate = bool(body.get("reset_secret", False))
    try:
        _accounts().update_mfa(username, enabled=mfa_on, reset_secret=rotate)
    except AccountNotFoundError as exc:
        return jsonify({"error": str(exc)}), 404
    raw_actor = session.get("username")
    actor = raw_actor if isinstance(raw_actor, str) else None
    # An admin disabling their own MFA clears any half-finished MFA challenge.
    if actor and not mfa_on and actor.strip().lower() == username.strip().lower():
        session.pop("mfa_pending", None)
    return jsonify({"status": "ok"})
__all__ = ["register", "blueprint"]

View File

@@ -9,7 +9,7 @@ from .connection import (
connection_factory,
connection_scope,
)
from .migrations import apply_all
from .migrations import apply_all, ensure_default_admin
__all__ = [
"SQLiteConnectionFactory",
@@ -18,14 +18,20 @@ __all__ = [
"connection_factory",
"connection_scope",
"apply_all",
"ensure_default_admin",
]
try: # pragma: no cover - optional dependency shim
from .device_repository import SQLiteDeviceRepository
from .enrollment_repository import SQLiteEnrollmentRepository
from .device_inventory_repository import SQLiteDeviceInventoryRepository
from .device_view_repository import SQLiteDeviceViewRepository
from .credential_repository import SQLiteCredentialRepository
from .github_repository import SQLiteGitHubRepository
from .job_repository import SQLiteJobRepository
from .site_repository import SQLiteSiteRepository
from .token_repository import SQLiteRefreshTokenRepository
from .user_repository import SQLiteUserRepository
except ModuleNotFoundError as exc: # pragma: no cover - triggered when auth deps missing
def _missing_repo(*_args: object, **_kwargs: object) -> None:
raise ModuleNotFoundError(
@@ -34,8 +40,12 @@ except ModuleNotFoundError as exc: # pragma: no cover - triggered when auth dep
SQLiteDeviceRepository = _missing_repo # type: ignore[assignment]
SQLiteEnrollmentRepository = _missing_repo # type: ignore[assignment]
SQLiteDeviceInventoryRepository = _missing_repo # type: ignore[assignment]
SQLiteDeviceViewRepository = _missing_repo # type: ignore[assignment]
SQLiteCredentialRepository = _missing_repo # type: ignore[assignment]
SQLiteGitHubRepository = _missing_repo # type: ignore[assignment]
SQLiteJobRepository = _missing_repo # type: ignore[assignment]
SQLiteSiteRepository = _missing_repo # type: ignore[assignment]
SQLiteRefreshTokenRepository = _missing_repo # type: ignore[assignment]
else:
__all__ += [
@@ -43,5 +53,10 @@ else:
"SQLiteRefreshTokenRepository",
"SQLiteJobRepository",
"SQLiteEnrollmentRepository",
"SQLiteDeviceInventoryRepository",
"SQLiteDeviceViewRepository",
"SQLiteCredentialRepository",
"SQLiteGitHubRepository",
"SQLiteUserRepository",
"SQLiteSiteRepository",
]

View File

@@ -0,0 +1,103 @@
"""SQLite access for operator credential metadata."""
from __future__ import annotations
import json
import logging
import sqlite3
from contextlib import closing
from typing import Dict, List, Optional
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
__all__ = ["SQLiteCredentialRepository"]
class SQLiteCredentialRepository:
    """Read-only SQLite access to operator credential metadata.

    Encrypted secrets are never decrypted here; callers only learn whether
    an encrypted value is present (the ``has_*`` flags).
    """

    def __init__(
        self,
        connection_factory: SQLiteConnectionFactory,
        *,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        # Factory producing short-lived sqlite3 connections per call.
        self._connections = connection_factory
        self._log = logger or logging.getLogger("borealis.engine.repositories.credentials")

    def list_credentials(
        self,
        *,
        site_id: Optional[int] = None,
        connection_type: Optional[str] = None,
    ) -> List[Dict[str, object]]:
        """Return credential metadata rows, newest schema first by name.

        Optional filters: *site_id* (exact match) and *connection_type*
        (case-insensitive). Results are sorted case-insensitively by name.
        """
        sql = """
            SELECT c.id,
                   c.name,
                   c.description,
                   c.credential_type,
                   c.connection_type,
                   c.username,
                   c.site_id,
                   s.name AS site_name,
                   c.become_method,
                   c.become_username,
                   c.metadata_json,
                   c.created_at,
                   c.updated_at,
                   c.password_encrypted,
                   c.private_key_encrypted,
                   c.private_key_passphrase_encrypted,
                   c.become_password_encrypted
              FROM credentials c
              LEFT JOIN sites s ON s.id = c.site_id
        """
        clauses: List[str] = []
        params: List[object] = []
        if site_id is not None:
            clauses.append("c.site_id = ?")
            params.append(site_id)
        if connection_type:
            clauses.append("LOWER(c.connection_type) = LOWER(?)")
            params.append(connection_type)
        if clauses:
            sql += " WHERE " + " AND ".join(clauses)
        sql += " ORDER BY LOWER(c.name) ASC"
        with closing(self._connections()) as conn:
            # Row factory enables name-based column access below.
            conn.row_factory = sqlite3.Row  # type: ignore[attr-defined]
            cur = conn.cursor()
            cur.execute(sql, params)
            rows = cur.fetchall()
        results: List[Dict[str, object]] = []
        for row in rows:
            metadata_json = row["metadata_json"] if "metadata_json" in row.keys() else None
            metadata = {}
            if metadata_json:
                # Malformed or non-dict metadata falls back to {} silently.
                try:
                    candidate = json.loads(metadata_json)
                    if isinstance(candidate, dict):
                        metadata = candidate
                except Exception:
                    metadata = {}
            results.append(
                {
                    "id": row["id"],
                    "name": row["name"],
                    "description": row["description"] or "",
                    "credential_type": row["credential_type"] or "machine",
                    "connection_type": row["connection_type"] or "ssh",
                    "site_id": row["site_id"],
                    "site_name": row["site_name"],
                    "username": row["username"] or "",
                    "become_method": row["become_method"] or "",
                    "become_username": row["become_username"] or "",
                    "metadata": metadata,
                    "created_at": int(row["created_at"] or 0),
                    "updated_at": int(row["updated_at"] or 0),
                    # Presence flags only — ciphertext never leaves this layer.
                    "has_password": bool(row["password_encrypted"]),
                    "has_private_key": bool(row["private_key_encrypted"]),
                    "has_private_key_passphrase": bool(row["private_key_passphrase_encrypted"]),
                    "has_become_password": bool(row["become_password_encrypted"]),
                }
            )
        return results

View File

@@ -0,0 +1,338 @@
"""Device inventory operations backed by SQLite."""
from __future__ import annotations
import logging
import sqlite3
import time
import uuid
from contextlib import closing
from typing import Any, Dict, List, Optional, Tuple
from Data.Engine.domain.devices import (
DEVICE_TABLE,
DEVICE_TABLE_COLUMNS,
assemble_device_snapshot,
clean_device_str,
coerce_int,
device_column_sql,
row_to_device_dict,
serialize_device_json,
)
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
__all__ = ["SQLiteDeviceInventoryRepository"]
class SQLiteDeviceInventoryRepository:
    """SQLite-backed CRUD for the device inventory and device-key tables."""

    def __init__(
        self,
        connection_factory: SQLiteConnectionFactory,
        *,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        # Factory producing short-lived sqlite3 connections per call.
        self._connections = connection_factory
        self._log = logger or logging.getLogger("borealis.engine.repositories.device_inventory")

    def fetch_devices(
        self,
        *,
        connection_type: Optional[str] = None,
        hostname: Optional[str] = None,
        only_agents: bool = False,
    ) -> List[Dict[str, Any]]:
        """Return device snapshots joined with their (optional) site.

        Filters: *connection_type* / *hostname* match case-insensitively;
        *only_agents* keeps rows with no connection_type (agent devices).
        Each snapshot gains ``site_*`` fields and a derived ``status`` that
        is "Online" when last_seen is within the past 300 seconds.
        """
        sql = f"""
            SELECT {device_column_sql('d')}, s.id, s.name, s.description
              FROM {DEVICE_TABLE} d
              LEFT JOIN device_sites ds ON ds.device_hostname = d.hostname
              LEFT JOIN sites s ON s.id = ds.site_id
        """
        clauses: List[str] = []
        params: List[Any] = []
        if connection_type:
            clauses.append("LOWER(d.connection_type) = LOWER(?)")
            params.append(connection_type)
        if hostname:
            clauses.append("LOWER(d.hostname) = LOWER(?)")
            params.append(hostname.lower())
        if only_agents:
            clauses.append("(d.connection_type IS NULL OR TRIM(d.connection_type) = '')")
        if clauses:
            sql += " WHERE " + " AND ".join(clauses)
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(sql, params)
            rows = cur.fetchall()
        now = time.time()
        devices: List[Dict[str, Any]] = []
        for row in rows:
            # Device columns come first in the SELECT; site columns trail.
            core = row[: len(DEVICE_TABLE_COLUMNS)]
            site_id, site_name, site_description = row[len(DEVICE_TABLE_COLUMNS) :]
            record = row_to_device_dict(core, DEVICE_TABLE_COLUMNS)
            snapshot = assemble_device_snapshot(record)
            summary = snapshot.get("summary", {})
            last_seen = snapshot.get("last_seen") or 0
            status = "Offline"
            try:
                # 300 s heartbeat window decides online/offline.
                if last_seen and (now - float(last_seen)) <= 300:
                    status = "Online"
            except Exception:
                pass
            devices.append(
                {
                    **snapshot,
                    "site_id": site_id,
                    "site_name": site_name or "",
                    "site_description": site_description or "",
                    "status": status,
                }
            )
        return devices

    def load_snapshot(self, *, hostname: Optional[str] = None, guid: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Load one device snapshot by hostname (preferred) or GUID.

        Returns ``None`` when neither key is given or no row matches.
        """
        if not hostname and not guid:
            return None
        sql = None
        params: Tuple[Any, ...]
        if hostname:
            sql = f"SELECT {device_column_sql()} FROM {DEVICE_TABLE} WHERE hostname = ?"
            params = (hostname,)
        else:
            sql = f"SELECT {device_column_sql()} FROM {DEVICE_TABLE} WHERE LOWER(guid) = LOWER(?)"
            params = (guid,)
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(sql, params)
            row = cur.fetchone()
        if not row:
            return None
        record = row_to_device_dict(row, DEVICE_TABLE_COLUMNS)
        return assemble_device_snapshot(record)

    def upsert_device(
        self,
        hostname: str,
        description: Optional[str],
        merged_details: Dict[str, Any],
        created_at: Optional[int],
        *,
        agent_hash: Optional[str] = None,
        guid: Optional[str] = None,
    ) -> None:
        """Insert or update a device row keyed by hostname.

        The ON CONFLICT clause preserves existing non-empty values via
        COALESCE/NULLIF so partial updates never blank out known fields.
        NOTE: the parameter list below must stay in the same order as the
        column list in the INSERT — both are positional.
        """
        if not hostname:
            return
        column_values = self._extract_device_columns(merged_details or {})
        normalized_description = description if description is not None else ""
        try:
            normalized_description = str(normalized_description)
        except Exception:
            normalized_description = ""
        normalized_hash = clean_device_str(agent_hash) or None
        normalized_guid = clean_device_str(guid) or None
        created_ts = coerce_int(created_at) or int(time.time())
        sql = f"""
            INSERT INTO {DEVICE_TABLE}(
                hostname,
                description,
                created_at,
                agent_hash,
                guid,
                memory,
                network,
                software,
                storage,
                cpu,
                device_type,
                domain,
                external_ip,
                internal_ip,
                last_reboot,
                last_seen,
                last_user,
                operating_system,
                uptime,
                agent_id,
                ansible_ee_ver,
                connection_type,
                connection_endpoint,
                ssl_key_fingerprint,
                token_version,
                status,
                key_added_at
            ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
            ON CONFLICT(hostname) DO UPDATE SET
                description=excluded.description,
                created_at=COALESCE({DEVICE_TABLE}.created_at, excluded.created_at),
                agent_hash=COALESCE(NULLIF(excluded.agent_hash, ''), {DEVICE_TABLE}.agent_hash),
                guid=COALESCE(NULLIF(excluded.guid, ''), {DEVICE_TABLE}.guid),
                memory=excluded.memory,
                network=excluded.network,
                software=excluded.software,
                storage=excluded.storage,
                cpu=excluded.cpu,
                device_type=COALESCE(NULLIF(excluded.device_type, ''), {DEVICE_TABLE}.device_type),
                domain=COALESCE(NULLIF(excluded.domain, ''), {DEVICE_TABLE}.domain),
                external_ip=COALESCE(NULLIF(excluded.external_ip, ''), {DEVICE_TABLE}.external_ip),
                internal_ip=COALESCE(NULLIF(excluded.internal_ip, ''), {DEVICE_TABLE}.internal_ip),
                last_reboot=COALESCE(NULLIF(excluded.last_reboot, ''), {DEVICE_TABLE}.last_reboot),
                last_seen=COALESCE(NULLIF(excluded.last_seen, 0), {DEVICE_TABLE}.last_seen),
                last_user=COALESCE(NULLIF(excluded.last_user, ''), {DEVICE_TABLE}.last_user),
                operating_system=COALESCE(NULLIF(excluded.operating_system, ''), {DEVICE_TABLE}.operating_system),
                uptime=COALESCE(NULLIF(excluded.uptime, 0), {DEVICE_TABLE}.uptime),
                agent_id=COALESCE(NULLIF(excluded.agent_id, ''), {DEVICE_TABLE}.agent_id),
                ansible_ee_ver=COALESCE(NULLIF(excluded.ansible_ee_ver, ''), {DEVICE_TABLE}.ansible_ee_ver),
                connection_type=COALESCE(NULLIF(excluded.connection_type, ''), {DEVICE_TABLE}.connection_type),
                connection_endpoint=COALESCE(NULLIF(excluded.connection_endpoint, ''), {DEVICE_TABLE}.connection_endpoint),
                ssl_key_fingerprint=COALESCE(NULLIF(excluded.ssl_key_fingerprint, ''), {DEVICE_TABLE}.ssl_key_fingerprint),
                token_version=COALESCE(NULLIF(excluded.token_version, 0), {DEVICE_TABLE}.token_version),
                status=COALESCE(NULLIF(excluded.status, ''), {DEVICE_TABLE}.status),
                key_added_at=COALESCE(NULLIF(excluded.key_added_at, ''), {DEVICE_TABLE}.key_added_at)
        """
        params: List[Any] = [
            hostname,
            normalized_description,
            created_ts,
            normalized_hash,
            normalized_guid,
            column_values.get("memory"),
            column_values.get("network"),
            column_values.get("software"),
            column_values.get("storage"),
            column_values.get("cpu"),
            column_values.get("device_type"),
            column_values.get("domain"),
            column_values.get("external_ip"),
            column_values.get("internal_ip"),
            column_values.get("last_reboot"),
            column_values.get("last_seen"),
            column_values.get("last_user"),
            column_values.get("operating_system"),
            column_values.get("uptime"),
            column_values.get("agent_id"),
            column_values.get("ansible_ee_ver"),
            column_values.get("connection_type"),
            column_values.get("connection_endpoint"),
            column_values.get("ssl_key_fingerprint"),
            column_values.get("token_version"),
            column_values.get("status"),
            column_values.get("key_added_at"),
        ]
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(sql, params)
            conn.commit()

    def delete_device_by_hostname(self, hostname: str) -> None:
        """Remove a device and its site assignment in one transaction."""
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute("DELETE FROM device_sites WHERE device_hostname = ?", (hostname,))
            cur.execute(f"DELETE FROM {DEVICE_TABLE} WHERE hostname = ?", (hostname,))
            conn.commit()

    def record_device_fingerprint(self, guid: Optional[str], fingerprint: Optional[str], added_at: str) -> None:
        """Record an SSL key fingerprint for a device GUID.

        Inserts the key history row, retires any other active keys for the
        GUID, then mirrors the fingerprint onto the devices row. No-op when
        guid or fingerprint is blank.
        """
        normalized_guid = clean_device_str(guid)
        normalized_fp = clean_device_str(fingerprint)
        if not normalized_guid or not normalized_fp:
            return
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            # History row; INSERT OR IGNORE keeps re-announcements idempotent.
            cur.execute(
                """
                INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
                VALUES (?, ?, ?, ?)
                """,
                (str(uuid.uuid4()), normalized_guid, normalized_fp.lower(), added_at),
            )
            # Retire every other still-active key for this GUID.
            cur.execute(
                """
                UPDATE device_keys
                   SET retired_at = ?
                 WHERE guid = ?
                   AND ssl_key_fingerprint != ?
                   AND retired_at IS NULL
                """,
                (added_at, normalized_guid, normalized_fp.lower()),
            )
            # Keep the devices row in sync with the newest fingerprint.
            cur.execute(
                """
                UPDATE devices
                   SET ssl_key_fingerprint = COALESCE(LOWER(?), ssl_key_fingerprint),
                       key_added_at = COALESCE(key_added_at, ?)
                 WHERE LOWER(guid) = LOWER(?)
                """,
                (normalized_fp, added_at, normalized_guid),
            )
            conn.commit()

    def _extract_device_columns(self, details: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten a merged-details payload into devices-table column values.

        Pulls each column from the first non-empty of several known summary
        aliases; JSON-typed columns are serialized with safe defaults.
        """
        summary = details.get("summary") or {}
        payload: Dict[str, Any] = {}
        for field in ("memory", "network", "software", "storage"):
            payload[field] = serialize_device_json(details.get(field), [])
        payload["cpu"] = serialize_device_json(summary.get("cpu") or details.get("cpu"), {})
        payload["device_type"] = clean_device_str(
            summary.get("device_type")
            or summary.get("type")
            or summary.get("device_class")
        )
        payload["domain"] = clean_device_str(
            summary.get("domain") or summary.get("domain_name")
        )
        payload["external_ip"] = clean_device_str(
            summary.get("external_ip") or summary.get("public_ip")
        )
        payload["internal_ip"] = clean_device_str(
            summary.get("internal_ip") or summary.get("private_ip")
        )
        payload["last_reboot"] = clean_device_str(
            summary.get("last_reboot") or summary.get("last_boot")
        )
        payload["last_seen"] = coerce_int(
            summary.get("last_seen") or summary.get("last_seen_epoch")
        )
        payload["last_user"] = clean_device_str(
            summary.get("last_user")
            or summary.get("last_user_name")
            or summary.get("logged_in_user")
            or summary.get("username")
            or summary.get("user")
        )
        payload["operating_system"] = clean_device_str(
            summary.get("operating_system")
            or summary.get("agent_operating_system")
            or summary.get("os")
        )
        uptime_value = (
            summary.get("uptime_sec")
            or summary.get("uptime_seconds")
            or summary.get("uptime")
        )
        payload["uptime"] = coerce_int(uptime_value)
        payload["agent_id"] = clean_device_str(summary.get("agent_id"))
        payload["ansible_ee_ver"] = clean_device_str(summary.get("ansible_ee_ver"))
        payload["connection_type"] = clean_device_str(
            summary.get("connection_type") or summary.get("remote_type")
        )
        payload["connection_endpoint"] = clean_device_str(
            summary.get("connection_endpoint")
            or summary.get("endpoint")
            or summary.get("connection_address")
            or summary.get("address")
            or summary.get("external_ip")
            or summary.get("internal_ip")
        )
        payload["ssl_key_fingerprint"] = clean_device_str(
            summary.get("ssl_key_fingerprint")
        )
        payload["token_version"] = coerce_int(summary.get("token_version")) or 0
        payload["status"] = clean_device_str(summary.get("status"))
        payload["key_added_at"] = clean_device_str(summary.get("key_added_at"))
        return payload

View File

@@ -0,0 +1,143 @@
"""SQLite persistence for device list views."""
from __future__ import annotations
import json
import logging
import sqlite3
import time
from contextlib import closing
from typing import Dict, Iterable, List, Optional
from Data.Engine.domain.device_views import DeviceListView
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
__all__ = ["SQLiteDeviceViewRepository"]
class SQLiteDeviceViewRepository:
    """SQLite persistence for saved device-list views.

    Raises ``ValueError("duplicate")`` for unique-name violations and
    ``LookupError("not_found")`` for missing rows, so the service layer
    can map them to HTTP statuses without importing sqlite3.
    """

    def __init__(
        self,
        connection_factory: SQLiteConnectionFactory,
        *,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        # Factory producing short-lived sqlite3 connections per call.
        self._connections = connection_factory
        self._log = logger or logging.getLogger("borealis.engine.repositories.device_views")

    def list_views(self) -> List[DeviceListView]:
        """Return all saved views sorted case-insensitively by name."""
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                "SELECT id, name, columns_json, filters_json, created_at, updated_at\n"
                "  FROM device_list_views ORDER BY name COLLATE NOCASE ASC"
            )
            rows = cur.fetchall()
        return [self._row_to_view(row) for row in rows]

    def get_view(self, view_id: int) -> Optional[DeviceListView]:
        """Fetch one view by id, or ``None`` when it does not exist."""
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute(
                "SELECT id, name, columns_json, filters_json, created_at, updated_at\n"
                "  FROM device_list_views WHERE id = ?",
                (view_id,),
            )
            row = cur.fetchone()
        return self._row_to_view(row) if row else None

    def create_view(self, name: str, columns: List[str], filters: Dict[str, object]) -> DeviceListView:
        """Insert a view and return the persisted record.

        Raises ``ValueError("duplicate")`` when *name* is already taken.
        """
        now = int(time.time())
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            try:
                cur.execute(
                    "INSERT INTO device_list_views(name, columns_json, filters_json, created_at, updated_at)\n"
                    "VALUES (?, ?, ?, ?, ?)",
                    (name, json.dumps(columns), json.dumps(filters), now, now),
                )
            except sqlite3.IntegrityError as exc:
                # UNIQUE(name) violation surfaces as a domain-level error.
                raise ValueError("duplicate") from exc
            view_id = cur.lastrowid
            conn.commit()
            # Re-read so the caller gets exactly what was stored.
            cur.execute(
                "SELECT id, name, columns_json, filters_json, created_at, updated_at FROM device_list_views WHERE id = ?",
                (view_id,),
            )
            row = cur.fetchone()
        if not row:
            raise RuntimeError("view missing after insert")
        return self._row_to_view(row)

    def update_view(
        self,
        view_id: int,
        *,
        name: Optional[str] = None,
        columns: Optional[List[str]] = None,
        filters: Optional[Dict[str, object]] = None,
    ) -> DeviceListView:
        """Update any subset of name/columns/filters and return the new row.

        ``updated_at`` is always bumped. Raises ``ValueError("duplicate")``
        on a name collision and ``LookupError("not_found")`` when the id
        does not exist.
        """
        fields: List[str] = []
        params: List[object] = []
        if name is not None:
            fields.append("name = ?")
            params.append(name)
        if columns is not None:
            fields.append("columns_json = ?")
            params.append(json.dumps(columns))
        if filters is not None:
            fields.append("filters_json = ?")
            params.append(json.dumps(filters))
        fields.append("updated_at = ?")
        params.append(int(time.time()))
        params.append(view_id)
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            try:
                cur.execute(
                    f"UPDATE device_list_views SET {', '.join(fields)} WHERE id = ?",
                    params,
                )
            except sqlite3.IntegrityError as exc:
                raise ValueError("duplicate") from exc
            if cur.rowcount == 0:
                raise LookupError("not_found")
            conn.commit()
            cur.execute(
                "SELECT id, name, columns_json, filters_json, created_at, updated_at FROM device_list_views WHERE id = ?",
                (view_id,),
            )
            row = cur.fetchone()
        if not row:
            raise LookupError("not_found")
        return self._row_to_view(row)

    def delete_view(self, view_id: int) -> bool:
        """Delete a view; returns ``True`` when a row was removed."""
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute("DELETE FROM device_list_views WHERE id = ?", (view_id,))
            deleted = cur.rowcount
            conn.commit()
        return bool(deleted)

    def _row_to_view(self, row: Optional[Iterable[object]]) -> DeviceListView:
        """Convert a raw row tuple into a DeviceListView.

        Malformed JSON columns fall back to empty list/dict rather than
        failing the whole listing.
        """
        if row is None:
            raise ValueError("row required")
        view_id, name, columns_json, filters_json, created_at, updated_at = row
        try:
            columns = json.loads(columns_json or "[]")
        except Exception:
            columns = []
        try:
            filters = json.loads(filters_json or "{}")
        except Exception:
            filters = {}
        return DeviceListView(
            id=int(view_id),
            name=str(name or ""),
            columns=list(columns) if isinstance(columns, list) else [],
            filters=dict(filters) if isinstance(filters, dict) else {},
            created_at=int(created_at or 0),
            updated_at=int(updated_at or 0),
        )

View File

@@ -5,14 +5,19 @@ from __future__ import annotations
import logging
from contextlib import closing
from datetime import datetime, timezone
from typing import Optional
from typing import Any, List, Optional, Tuple
from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid
from Data.Engine.domain.device_auth import DeviceFingerprint, DeviceGuid, normalize_guid
from Data.Engine.domain.device_enrollment import (
EnrollmentApproval,
EnrollmentApprovalStatus,
EnrollmentCode,
)
from Data.Engine.domain.enrollment_admin import (
DeviceApprovalRecord,
EnrollmentCodeRecord,
HostnameConflict,
)
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
__all__ = ["SQLiteEnrollmentRepository"]
@@ -122,6 +127,158 @@ class SQLiteEnrollmentRepository:
self._log.warning("invalid enrollment code record for id=%s: %s", record_value, exc)
return None
def list_install_codes(
    self,
    *,
    status: Optional[str] = None,
    now: Optional[datetime] = None,
) -> List[EnrollmentCodeRecord]:
    """List enrollment install codes, optionally filtered by lifecycle state.

    *status* may be ``"active"`` (uses remain and not yet expired),
    ``"expired"`` (uses remain but past expires_at), or ``"used"``
    (use_count has reached max_uses); any other value returns all codes.
    *now* pins the expiry reference time (defaults to current UTC).
    Rows that fail domain validation are logged and skipped.
    """
    reference = now or datetime.now(tz=timezone.utc)
    status_filter = (status or "").strip().lower()
    params: List[str] = []
    sql = """
        SELECT id,
               code,
               expires_at,
               created_by_user_id,
               used_at,
               used_by_guid,
               max_uses,
               use_count,
               last_used_at
          FROM enrollment_install_codes
    """
    if status_filter in {"active", "expired", "used"}:
        sql += " WHERE "
        if status_filter == "active":
            sql += "use_count < max_uses AND expires_at > ?"
            params.append(self._isoformat(reference))
        elif status_filter == "expired":
            sql += "use_count < max_uses AND expires_at <= ?"
            params.append(self._isoformat(reference))
        else:  # used
            sql += "use_count >= max_uses"
    sql += " ORDER BY expires_at ASC"
    rows: List[EnrollmentCodeRecord] = []
    with closing(self._connections()) as conn:
        cur = conn.cursor()
        cur.execute(sql, params)
        for raw in cur.fetchall():
            record = {
                "id": raw[0],
                "code": raw[1],
                "expires_at": raw[2],
                "created_by_user_id": raw[3],
                "used_at": raw[4],
                "used_by_guid": raw[5],
                "max_uses": raw[6],
                "use_count": raw[7],
                "last_used_at": raw[8],
            }
            try:
                rows.append(EnrollmentCodeRecord.from_row(record))
            except Exception as exc:  # pragma: no cover - defensive logging
                self._log.warning("invalid enrollment install code row id=%s: %s", record.get("id"), exc)
    return rows
def get_install_code_record(self, record_id: str) -> Optional[EnrollmentCodeRecord]:
    """Fetch one install code by id.

    Returns ``None`` for a blank id, a missing row, or a row that fails
    domain validation (the latter is logged).
    """
    identifier = (record_id or "").strip()
    if not identifier:
        return None
    with closing(self._connections()) as conn:
        cur = conn.cursor()
        cur.execute(
            """
            SELECT id,
                   code,
                   expires_at,
                   created_by_user_id,
                   used_at,
                   used_by_guid,
                   max_uses,
                   use_count,
                   last_used_at
              FROM enrollment_install_codes
             WHERE id = ?
            """,
            (identifier,),
        )
        row = cur.fetchone()
    if not row:
        return None
    payload = {
        "id": row[0],
        "code": row[1],
        "expires_at": row[2],
        "created_by_user_id": row[3],
        "used_at": row[4],
        "used_by_guid": row[5],
        "max_uses": row[6],
        "use_count": row[7],
        "last_used_at": row[8],
    }
    try:
        return EnrollmentCodeRecord.from_row(payload)
    except Exception as exc:  # pragma: no cover - defensive logging
        self._log.warning("invalid enrollment install code record id=%s: %s", identifier, exc)
        return None
def insert_install_code(
    self,
    *,
    record_id: str,
    code: str,
    expires_at: datetime,
    created_by: Optional[str],
    max_uses: int,
) -> EnrollmentCodeRecord:
    """Persist a new install code with use_count=0 and return the stored record.

    Raises ``RuntimeError`` if the row cannot be re-read after insert.
    """
    expires_iso = self._isoformat(expires_at)
    with closing(self._connections()) as conn:
        cur = conn.cursor()
        cur.execute(
            """
            INSERT INTO enrollment_install_codes (
                id,
                code,
                expires_at,
                created_by_user_id,
                max_uses,
                use_count
            ) VALUES (?, ?, ?, ?, ?, 0)
            """,
            (record_id, code, expires_iso, created_by, max_uses),
        )
        conn.commit()
    # Read back through the normal accessor so validation applies.
    record = self.get_install_code_record(record_id)
    if record is None:
        raise RuntimeError("failed to load install code after insert")
    return record
def delete_install_code_if_unused(self, record_id: str) -> bool:
    """Delete an install code only while it is still unused.

    The ``use_count = 0`` predicate makes the check-and-delete atomic.
    Returns ``True`` when a row was removed.
    """
    identifier = (record_id or "").strip()
    if not identifier:
        return False
    with closing(self._connections()) as conn:
        cur = conn.cursor()
        cur.execute(
            "DELETE FROM enrollment_install_codes WHERE id = ? AND use_count = 0",
            (identifier,),
        )
        deleted = cur.rowcount > 0
        conn.commit()
    return deleted
def update_install_code_usage(
self,
record_id: str,
@@ -165,6 +322,100 @@ class SQLiteEnrollmentRepository:
# ------------------------------------------------------------------
# Device approvals
# ------------------------------------------------------------------
def list_device_approvals(
    self,
    *,
    status: Optional[str] = None,
) -> List[DeviceApprovalRecord]:
    """List device enrollment approvals, oldest first.

    *status* filters case-insensitively unless it is blank, ``"all"`` or
    ``"*"``. Each record is annotated with hostname-conflict information
    (and a suggested alternate hostname when an operator prompt is
    required). Rows that fail domain validation are logged and skipped.
    """
    status_filter = (status or "").strip().lower()
    params: List[str] = []
    # The users join accepts approved_by_user_id stored as either a user id
    # or a username (both text forms appear in the data).
    sql = """
        SELECT
            da.id,
            da.approval_reference,
            da.guid,
            da.hostname_claimed,
            da.ssl_key_fingerprint_claimed,
            da.enrollment_code_id,
            da.status,
            da.client_nonce,
            da.server_nonce,
            da.created_at,
            da.updated_at,
            da.approved_by_user_id,
            u.username AS approved_by_username
        FROM device_approvals AS da
        LEFT JOIN users AS u
            ON (
                CAST(da.approved_by_user_id AS TEXT) = CAST(u.id AS TEXT)
                OR LOWER(da.approved_by_user_id) = LOWER(u.username)
            )
    """
    if status_filter and status_filter not in {"all", "*"}:
        sql += " WHERE LOWER(da.status) = ?"
        params.append(status_filter)
    sql += " ORDER BY da.created_at ASC"
    approvals: List[DeviceApprovalRecord] = []
    with closing(self._connections()) as conn:
        cur = conn.cursor()
        cur.execute(sql, params)
        rows = cur.fetchall()
        for raw in rows:
            record = {
                "id": raw[0],
                "approval_reference": raw[1],
                "guid": raw[2],
                "hostname_claimed": raw[3],
                "ssl_key_fingerprint_claimed": raw[4],
                "enrollment_code_id": raw[5],
                "status": raw[6],
                "client_nonce": raw[7],
                "server_nonce": raw[8],
                "created_at": raw[9],
                "updated_at": raw[10],
                "approved_by_user_id": raw[11],
                "approved_by_username": raw[12],
            }
            # Does the claimed hostname collide with an existing device?
            conflict, fingerprint_match, requires_prompt = self._compute_hostname_conflict(
                conn,
                record.get("hostname_claimed"),
                record.get("guid"),
                record.get("ssl_key_fingerprint_claimed") or "",
            )
            alternate = None
            if conflict and requires_prompt:
                alternate = self._suggest_alternate_hostname(
                    conn,
                    record.get("hostname_claimed"),
                    record.get("guid"),
                )
            try:
                approvals.append(
                    DeviceApprovalRecord.from_row(
                        record,
                        conflict=conflict,
                        alternate_hostname=alternate,
                        fingerprint_match=fingerprint_match,
                        requires_prompt=requires_prompt,
                    )
                )
            except Exception as exc:  # pragma: no cover - defensive logging
                self._log.warning(
                    "invalid device approval record id=%s: %s",
                    record.get("id"),
                    exc,
                )
    return approvals
def fetch_device_approval_by_reference(self, reference: str) -> Optional[EnrollmentApproval]:
"""Load a device approval using its operator-visible reference."""
@@ -376,6 +627,98 @@ class SQLiteEnrollmentRepository:
)
return None
def _compute_hostname_conflict(
    self,
    conn,
    hostname: Optional[str],
    pending_guid: Optional[str],
    claimed_fp: str,
) -> Tuple[Optional[HostnameConflict], bool, bool]:
    """Check whether *hostname* collides with an already-enrolled device.

    Returns ``(conflict, fingerprint_match, requires_prompt)``:
    no conflict when the hostname is unused or already belongs to the same
    GUID; otherwise a :class:`HostnameConflict` describing the existing
    device. A matching SSL key fingerprint suppresses the operator prompt.
    Database errors are logged and treated as "no conflict".
    """
    normalized_host = (hostname or "").strip()
    if not normalized_host:
        return None, False, False
    try:
        cur = conn.cursor()
        cur.execute(
            """
            SELECT d.guid,
                   d.ssl_key_fingerprint,
                   ds.site_id,
                   s.name
              FROM devices AS d
              LEFT JOIN device_sites AS ds ON ds.device_hostname = d.hostname
              LEFT JOIN sites AS s ON s.id = ds.site_id
             WHERE d.hostname = ?
            """,
            (normalized_host,),
        )
        row = cur.fetchone()
    except Exception as exc:  # pragma: no cover - defensive logging
        self._log.warning("failed to inspect hostname conflict for %s: %s", normalized_host, exc)
        return None, False, False
    if not row:
        return None, False, False
    existing_guid = normalize_guid(row[0])
    pending_norm = normalize_guid(pending_guid)
    # Same device re-enrolling under its own hostname is not a conflict.
    if existing_guid and pending_norm and existing_guid == pending_norm:
        return None, False, False
    stored_fp = (row[1] or "").strip().lower()
    claimed_fp_normalized = (claimed_fp or "").strip().lower()
    fingerprint_match = bool(stored_fp and claimed_fp_normalized and stored_fp == claimed_fp_normalized)
    site_id = None
    if row[2] is not None:
        try:
            site_id = int(row[2])
        except (TypeError, ValueError):  # pragma: no cover - defensive
            site_id = None
    site_name = str(row[3] or "").strip()
    # A matching fingerprint means the same hardware; no operator prompt.
    requires_prompt = not fingerprint_match
    conflict = HostnameConflict(
        guid=existing_guid or None,
        ssl_key_fingerprint=stored_fp or None,
        site_id=site_id,
        site_name=site_name,
        fingerprint_match=fingerprint_match,
        requires_prompt=requires_prompt,
    )
    return conflict, fingerprint_match, requires_prompt
def _suggest_alternate_hostname(
    self,
    conn,
    hostname: Optional[str],
    pending_guid: Optional[str],
) -> Optional[str]:
    """Suggest a free hostname by appending ``-1``, ``-2``, … to *hostname*.

    A hostname already owned by *pending_guid* counts as free. The base is
    truncated to 253 characters (DNS name length limit).
    """
    base = (hostname or "").strip()
    if not base:
        return None
    base = base[:253]
    candidate = base
    pending_norm = normalize_guid(pending_guid)
    suffix = 1
    cur = conn.cursor()
    while True:
        cur.execute("SELECT guid FROM devices WHERE hostname = ?", (candidate,))
        row = cur.fetchone()
        if not row:
            return candidate
        existing_guid = normalize_guid(row[0])
        if pending_norm and existing_guid == pending_norm:
            return candidate
        candidate = f"{base}-{suffix}"
        suffix += 1
        if suffix > 50:
            # NOTE(review): after 50 attempts this falls back to the pending
            # GUID string as the "hostname" suggestion — presumably to
            # guarantee uniqueness; confirm callers expect a GUID here.
            return pending_norm or candidate
@staticmethod
def _isoformat(value: datetime) -> str:
if value.tzinfo is None:

View File

@@ -15,6 +15,10 @@ from typing import List, Optional, Sequence, Tuple
DEVICE_TABLE = "devices"
_DEFAULT_ADMIN_USERNAME = "admin"
_DEFAULT_ADMIN_PASSWORD_SHA512 = (
"e6c83b282aeb2e022844595721cc00bbda47cb24537c1779f9bb84f04039e1676e6ba8573e588da1052510e3aa0a32a9e55879ae22b0c2d62136fc0a3e85f8bb"
)
def apply_all(conn: sqlite3.Connection) -> None:
@@ -27,9 +31,14 @@ def apply_all(conn: sqlite3.Connection) -> None:
_ensure_refresh_token_table(conn)
_ensure_install_code_table(conn)
_ensure_device_approval_table(conn)
_ensure_device_list_views_table(conn)
_ensure_sites_tables(conn)
_ensure_credentials_table(conn)
_ensure_github_token_table(conn)
_ensure_scheduled_jobs_table(conn)
_ensure_scheduled_job_run_tables(conn)
_ensure_users_table(conn)
_ensure_default_admin(conn)
conn.commit()
@@ -227,6 +236,73 @@ def _ensure_device_approval_table(conn: sqlite3.Connection) -> None:
)
def _ensure_device_list_views_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS device_list_views (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
columns_json TEXT NOT NULL,
filters_json TEXT,
created_at INTEGER,
updated_at INTEGER
)
"""
)
def _ensure_sites_tables(conn: sqlite3.Connection) -> None:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS sites (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
description TEXT,
created_at INTEGER
)
"""
)
cur.execute(
"""
CREATE TABLE IF NOT EXISTS device_sites (
device_hostname TEXT UNIQUE NOT NULL,
site_id INTEGER NOT NULL,
assigned_at INTEGER,
FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE CASCADE
)
"""
)
def _ensure_credentials_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS credentials (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
description TEXT,
site_id INTEGER,
credential_type TEXT NOT NULL DEFAULT 'machine',
connection_type TEXT NOT NULL DEFAULT 'ssh',
username TEXT,
password_encrypted BLOB,
private_key_encrypted BLOB,
private_key_passphrase_encrypted BLOB,
become_method TEXT,
become_username TEXT,
become_password_encrypted BLOB,
metadata_json TEXT,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
FOREIGN KEY(site_id) REFERENCES sites(id) ON DELETE SET NULL
)
"""
)
def _ensure_github_token_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor()
cur.execute(
@@ -504,4 +580,86 @@ def _normalized_guid(value: Optional[str]) -> str:
return ""
return str(value).strip()
__all__ = ["apply_all"]
def _ensure_users_table(conn: sqlite3.Connection) -> None:
    """Create the operator ``users`` table and backfill the MFA columns."""
    cur = conn.cursor()
    cur.execute(
        """
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT UNIQUE NOT NULL,
            display_name TEXT,
            password_sha512 TEXT NOT NULL,
            role TEXT NOT NULL DEFAULT 'Admin',
            last_login INTEGER,
            created_at INTEGER,
            updated_at INTEGER,
            mfa_enabled INTEGER NOT NULL DEFAULT 0,
            mfa_secret TEXT
        )
        """
    )
    # Databases created before the MFA feature lack these columns; add them
    # in place so upgraded deployments keep working.
    try:
        cur.execute("PRAGMA table_info(users)")
        present = {row[1] for row in cur.fetchall()}
        backfills = (
            ("mfa_enabled", "ALTER TABLE users ADD COLUMN mfa_enabled INTEGER NOT NULL DEFAULT 0"),
            ("mfa_secret", "ALTER TABLE users ADD COLUMN mfa_secret TEXT"),
        )
        for column, ddl in backfills:
            if column not in present:
                cur.execute(ddl)
    except sqlite3.Error:
        # Aligning the schema is best-effort; older deployments may lack
        # ALTER TABLE privileges but can continue using existing columns.
        pass
def _ensure_default_admin(conn: sqlite3.Connection) -> None:
    """Seed or promote the built-in administrator when no admin exists.

    Does nothing while any account with role 'admin' (case-insensitive)
    exists. Otherwise, if the reserved username already exists it is
    promoted to Admin; if not, a fresh Admin row is inserted.
    """
    cur = conn.cursor()
    cur.execute("SELECT COUNT(*) FROM users WHERE LOWER(role)='admin'")
    admin_row = cur.fetchone()
    if admin_row and (admin_row[0] or 0):
        # At least one admin already present; nothing to do.
        return
    now = int(datetime.now(timezone.utc).timestamp())
    cur.execute(
        "SELECT COUNT(*) FROM users WHERE LOWER(username)=LOWER(?)",
        (_DEFAULT_ADMIN_USERNAME,),
    )
    name_row = cur.fetchone()
    has_reserved_user = bool(name_row and (name_row[0] or 0))
    if has_reserved_user:
        # The reserved username exists but lacks admin rights: promote it.
        cur.execute(
            """
            UPDATE users
            SET role='Admin',
                updated_at=?
            WHERE LOWER(username)=LOWER(?)
              AND LOWER(role)!='admin'
            """,
            (now, _DEFAULT_ADMIN_USERNAME),
        )
        return
    cur.execute(
        """
        INSERT INTO users (
            username, display_name, password_sha512, role,
            last_login, created_at, updated_at, mfa_enabled, mfa_secret
        ) VALUES (?, ?, ?, 'Admin', 0, ?, ?, 0, NULL)
        """,
        (
            _DEFAULT_ADMIN_USERNAME,
            "Administrator",
            _DEFAULT_ADMIN_PASSWORD_SHA512,
            now,
            now,
        ),
    )
def ensure_default_admin(conn: sqlite3.Connection) -> None:
    """Guarantee that at least one admin account exists.

    Public entry point used by the Engine bootstrap: it first makes sure the
    ``users`` table (including the MFA columns) exists, then seeds or
    promotes the built-in administrator, and commits the result.
    """
    # Table must exist before the admin-seeding queries run.
    _ensure_users_table(conn)
    _ensure_default_admin(conn)
    conn.commit()
__all__ = ["apply_all", "ensure_default_admin"]

View File

@@ -0,0 +1,189 @@
"""SQLite persistence for site management."""
from __future__ import annotations
import logging
import sqlite3
import time
from contextlib import closing
from typing import Dict, Iterable, List, Optional, Sequence
from Data.Engine.domain.sites import SiteDeviceMapping, SiteSummary
from Data.Engine.repositories.sqlite.connection import SQLiteConnectionFactory
__all__ = ["SQLiteSiteRepository"]
class SQLiteSiteRepository:
    """Repository exposing site CRUD and device assignment helpers.

    Every public method opens a fresh connection from the injected factory
    and closes it via :func:`contextlib.closing`; writes commit explicitly.
    Domain errors are signalled as ``ValueError("duplicate")`` (unique-name
    clash) and ``LookupError("not_found")`` (missing site id).
    """

    def __init__(
        self,
        connection_factory: SQLiteConnectionFactory,
        *,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        # Factory is invoked once per operation; connections are short-lived.
        self._connections = connection_factory
        self._log = logger or logging.getLogger("borealis.engine.repositories.sites")

    def list_sites(self) -> List[SiteSummary]:
        """Return all sites ordered by name with their device counts."""
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            # LEFT JOIN against an aggregated subquery so sites with no
            # assigned devices still appear with device_count = 0.
            cur.execute(
                """
                SELECT s.id, s.name, s.description, s.created_at,
                COALESCE(ds.cnt, 0) AS device_count
                FROM sites s
                LEFT JOIN (
                    SELECT site_id, COUNT(*) AS cnt
                    FROM device_sites
                    GROUP BY site_id
                ) ds
                ON ds.site_id = s.id
                ORDER BY LOWER(s.name) ASC
                """
            )
            rows = cur.fetchall()
            return [self._row_to_site(row) for row in rows]

    def create_site(self, name: str, description: str) -> SiteSummary:
        """Insert a new site.

        Raises:
            ValueError: ``"duplicate"`` when the name already exists.
            RuntimeError: if the freshly inserted row cannot be re-read.
        """
        now = int(time.time())
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            try:
                cur.execute(
                    "INSERT INTO sites(name, description, created_at) VALUES (?, ?, ?)",
                    (name, description, now),
                )
            except sqlite3.IntegrityError as exc:
                # UNIQUE(name) violation surfaced as a domain-level error.
                raise ValueError("duplicate") from exc
            site_id = cur.lastrowid
            conn.commit()
            # Re-read the stored row; a brand-new site has zero devices.
            cur.execute(
                "SELECT id, name, description, created_at, 0 FROM sites WHERE id = ?",
                (site_id,),
            )
            row = cur.fetchone()
            if not row:
                raise RuntimeError("site not found after insert")
            return self._row_to_site(row)

    def delete_sites(self, ids: Sequence[int]) -> int:
        """Delete the given sites and their device links; return the count
        of deleted ``sites`` rows (0 for an empty id list)."""
        if not ids:
            return 0
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            placeholders = ",".join("?" for _ in ids)
            try:
                # Remove assignments first so no orphaned links survive even
                # when foreign-key enforcement is not enabled.
                cur.execute(
                    f"DELETE FROM device_sites WHERE site_id IN ({placeholders})",
                    tuple(ids),
                )
                cur.execute(
                    f"DELETE FROM sites WHERE id IN ({placeholders})",
                    tuple(ids),
                )
            except sqlite3.DatabaseError as exc:
                conn.rollback()
                raise
            # rowcount reflects the last statement (the sites DELETE).
            deleted = cur.rowcount
            conn.commit()
            return deleted

    def rename_site(self, site_id: int, new_name: str) -> SiteSummary:
        """Rename a site and return its refreshed summary.

        Raises:
            ValueError: ``"duplicate"`` when the new name already exists.
            LookupError: ``"not_found"`` when the site id does not exist.
        """
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            try:
                cur.execute("UPDATE sites SET name = ? WHERE id = ?", (new_name, site_id))
            except sqlite3.IntegrityError as exc:
                raise ValueError("duplicate") from exc
            if cur.rowcount == 0:
                # No row touched means the id does not exist.
                raise LookupError("not_found")
            conn.commit()
            # Re-read with the aggregated device count for the response.
            cur.execute(
                """
                SELECT s.id, s.name, s.description, s.created_at,
                COALESCE(ds.cnt, 0) AS device_count
                FROM sites s
                LEFT JOIN (
                    SELECT site_id, COUNT(*) AS cnt
                    FROM device_sites
                    GROUP BY site_id
                ) ds
                ON ds.site_id = s.id
                WHERE s.id = ?
                """,
                (site_id,),
            )
            row = cur.fetchone()
            if not row:
                raise LookupError("not_found")
            return self._row_to_site(row)

    def map_devices(self, hostnames: Optional[Iterable[str]] = None) -> Dict[str, SiteDeviceMapping]:
        """Map device hostnames to their assigned site.

        When *hostnames* is given, only those (stripped, non-empty) names are
        looked up; otherwise every assignment is returned. Devices without an
        assignment are simply absent from the result.
        """
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            if hostnames:
                normalized = [hn.strip() for hn in hostnames if hn and hn.strip()]
                if not normalized:
                    return {}
                placeholders = ",".join("?" for _ in normalized)
                cur.execute(
                    f"""
                    SELECT ds.device_hostname, s.id, s.name
                    FROM device_sites ds
                    INNER JOIN sites s ON s.id = ds.site_id
                    WHERE ds.device_hostname IN ({placeholders})
                    """,
                    tuple(normalized),
                )
            else:
                cur.execute(
                    """
                    SELECT ds.device_hostname, s.id, s.name
                    FROM device_sites ds
                    INNER JOIN sites s ON s.id = ds.site_id
                    """
                )
            rows = cur.fetchall()
            mapping: Dict[str, SiteDeviceMapping] = {}
            for hostname, site_id, site_name in rows:
                mapping[str(hostname)] = SiteDeviceMapping(
                    hostname=str(hostname),
                    site_id=int(site_id) if site_id is not None else None,
                    site_name=str(site_name or ""),
                )
            return mapping

    def assign_devices(self, site_id: int, hostnames: Sequence[str]) -> None:
        """Assign the given hostnames to *site_id* (upsert per hostname).

        Raises:
            LookupError: ``"not_found"`` when the site id does not exist.
        """
        now = int(time.time())
        normalized = [hn.strip() for hn in hostnames if isinstance(hn, str) and hn.strip()]
        if not normalized:
            return
        with closing(self._connections()) as conn:
            cur = conn.cursor()
            cur.execute("SELECT 1 FROM sites WHERE id = ?", (site_id,))
            if not cur.fetchone():
                raise LookupError("not_found")
            for hostname in normalized:
                # UNIQUE(device_hostname) makes this a move when the device
                # is already assigned elsewhere.
                cur.execute(
                    """
                    INSERT INTO device_sites(device_hostname, site_id, assigned_at)
                    VALUES (?, ?, ?)
                    ON CONFLICT(device_hostname)
                    DO UPDATE SET site_id = excluded.site_id,
                                  assigned_at = excluded.assigned_at
                    """,
                    (hostname, site_id, now),
                )
            conn.commit()

    def _row_to_site(self, row: Sequence[object]) -> SiteSummary:
        """Convert a positional row (id, name, description, created_at,
        device_count) into a :class:`SiteSummary`."""
        return SiteSummary(
            id=int(row[0]),
            name=str(row[1] or ""),
            description=str(row[2] or ""),
            created_at=int(row[3] or 0),
            device_count=int(row[4] or 0),
        )

View File

@@ -0,0 +1,340 @@
"""SQLite repository for operator accounts."""
from __future__ import annotations
import logging
import sqlite3
from dataclasses import dataclass
from typing import Iterable, Optional
from Data.Engine.domain import OperatorAccount
from .connection import SQLiteConnectionFactory
@dataclass(frozen=True, slots=True)
class _UserRow:
    """Raw ``users`` row as selected by the repository queries.

    Field order must match the SELECT column order exactly because rows are
    splatted positionally via ``_UserRow(*row)``.
    """

    id: str
    username: str
    display_name: str
    password_sha512: str  # hex digest; normalized to lowercase downstream
    role: str
    last_login: int  # epoch seconds, 0 when never logged in
    created_at: int
    updated_at: int
    mfa_enabled: int  # SQLite integer flag (0/1)
    mfa_secret: str  # empty string when unset
class SQLiteUserRepository:
    """Expose CRUD helpers for operator accounts stored in SQLite.

    Each method opens a fresh connection from the injected factory and
    closes it in a ``finally`` block. Read helpers swallow ``sqlite3.Error``
    and return a neutral value (None / [] / 0) after logging, so callers
    never see database exceptions from lookups.
    """

    def __init__(
        self,
        connection_factory: SQLiteConnectionFactory,
        *,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        self._connection_factory = connection_factory
        self._log = logger or logging.getLogger("borealis.engine.repositories.users")

    def fetch_by_username(self, username: str) -> Optional[OperatorAccount]:
        """Return the account for *username* (case-insensitive) or None."""
        conn = self._connection_factory()
        try:
            cur = conn.cursor()
            # COALESCE keeps the positional _UserRow splat free of NULLs.
            cur.execute(
                """
                SELECT
                    id,
                    username,
                    display_name,
                    COALESCE(password_sha512, '') as password_sha512,
                    COALESCE(role, 'User') as role,
                    COALESCE(last_login, 0) as last_login,
                    COALESCE(created_at, 0) as created_at,
                    COALESCE(updated_at, 0) as updated_at,
                    COALESCE(mfa_enabled, 0) as mfa_enabled,
                    COALESCE(mfa_secret, '') as mfa_secret
                FROM users
                WHERE LOWER(username) = LOWER(?)
                """,
                (username,),
            )
            row = cur.fetchone()
            if not row:
                return None
            record = _UserRow(*row)
            return _row_to_account(record)
        except sqlite3.Error as exc:  # pragma: no cover - defensive
            self._log.error("failed to load user %s: %s", username, exc)
            return None
        finally:
            conn.close()

    def resolve_identifier(self, username: str) -> Optional[str]:
        """Return the stored id (as text) for *username*, or None."""
        normalized = (username or "").strip()
        if not normalized:
            return None
        conn = self._connection_factory()
        try:
            cur = conn.cursor()
            cur.execute(
                "SELECT id FROM users WHERE LOWER(username) = LOWER(?)",
                (normalized,),
            )
            row = cur.fetchone()
            if not row:
                return None
            return str(row[0]) if row[0] is not None else None
        except sqlite3.Error as exc:  # pragma: no cover - defensive
            self._log.error("failed to resolve identifier for %s: %s", username, exc)
            return None
        finally:
            conn.close()

    def username_for_identifier(self, identifier: str) -> Optional[str]:
        """Resolve *identifier* — a user id or a username — to the canonical
        stored username, or None when unknown/blank."""
        token = (identifier or "").strip()
        if not token:
            return None
        conn = self._connection_factory()
        try:
            cur = conn.cursor()
            # Accept either the numeric id (compared as text) or the name.
            cur.execute(
                """
                SELECT username
                FROM users
                WHERE CAST(id AS TEXT) = ?
                OR LOWER(username) = LOWER(?)
                LIMIT 1
                """,
                (token, token),
            )
            row = cur.fetchone()
            if not row:
                return None
            username = str(row[0] or "").strip()
            return username or None
        except sqlite3.Error as exc:  # pragma: no cover - defensive
            self._log.error("failed to resolve username for %s: %s", identifier, exc)
            return None
        finally:
            conn.close()

    def list_accounts(self) -> list[OperatorAccount]:
        """Return every operator account ordered by username."""
        conn = self._connection_factory()
        try:
            cur = conn.cursor()
            cur.execute(
                """
                SELECT
                    id,
                    username,
                    display_name,
                    COALESCE(password_sha512, '') as password_sha512,
                    COALESCE(role, 'User') as role,
                    COALESCE(last_login, 0) as last_login,
                    COALESCE(created_at, 0) as created_at,
                    COALESCE(updated_at, 0) as updated_at,
                    COALESCE(mfa_enabled, 0) as mfa_enabled,
                    COALESCE(mfa_secret, '') as mfa_secret
                FROM users
                ORDER BY LOWER(username) ASC
                """
            )
            rows = [_UserRow(*row) for row in cur.fetchall()]
            return [_row_to_account(row) for row in rows]
        except sqlite3.Error as exc:  # pragma: no cover - defensive
            self._log.error("failed to enumerate users: %s", exc)
            return []
        finally:
            conn.close()

    def create_account(
        self,
        *,
        username: str,
        display_name: str,
        password_sha512: str,
        role: str,
        timestamp: int,
    ) -> None:
        """Insert a new account. Unlike the read helpers this does NOT
        swallow sqlite3 errors — a duplicate username propagates as
        ``sqlite3.IntegrityError`` to the caller."""
        conn = self._connection_factory()
        try:
            cur = conn.cursor()
            cur.execute(
                """
                INSERT INTO users (
                    username,
                    display_name,
                    password_sha512,
                    role,
                    created_at,
                    updated_at
                ) VALUES (?, ?, ?, ?, ?, ?)
                """,
                (username, display_name, password_sha512, role, timestamp, timestamp),
            )
            conn.commit()
        finally:
            conn.close()

    def delete_account(self, username: str) -> bool:
        """Delete by username (case-insensitive); True when a row was removed."""
        conn = self._connection_factory()
        try:
            cur = conn.cursor()
            cur.execute("DELETE FROM users WHERE LOWER(username) = LOWER(?)", (username,))
            deleted = cur.rowcount > 0
            conn.commit()
            return deleted
        except sqlite3.Error as exc:  # pragma: no cover - defensive
            self._log.error("failed to delete user %s: %s", username, exc)
            return False
        finally:
            conn.close()

    def update_password(self, username: str, password_sha512: str, *, timestamp: int) -> bool:
        """Replace the stored password hash; True when a row was updated."""
        conn = self._connection_factory()
        try:
            cur = conn.cursor()
            cur.execute(
                """
                UPDATE users
                SET password_sha512 = ?,
                    updated_at = ?
                WHERE LOWER(username) = LOWER(?)
                """,
                (password_sha512, timestamp, username),
            )
            conn.commit()
            return cur.rowcount > 0
        except sqlite3.Error as exc:  # pragma: no cover - defensive
            self._log.error("failed to update password for %s: %s", username, exc)
            return False
        finally:
            conn.close()

    def update_role(self, username: str, role: str, *, timestamp: int) -> bool:
        """Change the account's role; True when a row was updated."""
        conn = self._connection_factory()
        try:
            cur = conn.cursor()
            cur.execute(
                """
                UPDATE users
                SET role = ?,
                    updated_at = ?
                WHERE LOWER(username) = LOWER(?)
                """,
                (role, timestamp, username),
            )
            conn.commit()
            return cur.rowcount > 0
        except sqlite3.Error as exc:  # pragma: no cover - defensive
            self._log.error("failed to update role for %s: %s", username, exc)
            return False
        finally:
            conn.close()

    def update_mfa(
        self,
        username: str,
        *,
        enabled: bool,
        reset_secret: bool,
        timestamp: int,
    ) -> bool:
        """Toggle MFA and optionally clear the stored secret.

        The UPDATE statement is assembled from fixed string fragments only
        (no user input is interpolated), so string-building is safe here.
        """
        conn = self._connection_factory()
        try:
            cur = conn.cursor()
            secret_clause = "mfa_secret = NULL" if reset_secret else None
            assignments: list[str] = ["mfa_enabled = ?", "updated_at = ?"]
            params: list[object] = [1 if enabled else 0, timestamp]
            if secret_clause is not None:
                assignments.append(secret_clause)
            query = "UPDATE users SET " + ", ".join(assignments) + " WHERE LOWER(username) = LOWER(?)"
            params.append(username)
            cur.execute(query, tuple(params))
            conn.commit()
            return cur.rowcount > 0
        except sqlite3.Error as exc:  # pragma: no cover - defensive
            self._log.error("failed to update MFA for %s: %s", username, exc)
            return False
        finally:
            conn.close()

    def count_accounts(self) -> int:
        """Total number of operator accounts."""
        return self._scalar("SELECT COUNT(*) FROM users", ())

    def count_admins(self) -> int:
        """Number of accounts whose role is 'admin' (case-insensitive)."""
        return self._scalar("SELECT COUNT(*) FROM users WHERE LOWER(role) = 'admin'", ())

    def _scalar(self, query: str, params: Iterable[object]) -> int:
        """Run *query* and return its first column as int (0 on error)."""
        conn = self._connection_factory()
        try:
            cur = conn.cursor()
            cur.execute(query, tuple(params))
            row = cur.fetchone()
            if not row:
                return 0
            return int(row[0] or 0)
        except sqlite3.Error as exc:  # pragma: no cover - defensive
            self._log.error("scalar query failed: %s", exc)
            return 0
        finally:
            conn.close()

    def update_last_login(self, username: str, timestamp: int) -> None:
        """Record a successful login (best effort; failures only warn)."""
        conn = self._connection_factory()
        try:
            cur = conn.cursor()
            cur.execute(
                """
                UPDATE users
                SET last_login = ?,
                    updated_at = ?
                WHERE LOWER(username) = LOWER(?)
                """,
                (timestamp, timestamp, username),
            )
            conn.commit()
        except sqlite3.Error as exc:  # pragma: no cover - defensive
            self._log.warning("failed to update last_login for %s: %s", username, exc)
        finally:
            conn.close()

    def store_mfa_secret(self, username: str, secret: str, *, timestamp: int) -> None:
        """Persist a new TOTP secret (best effort; failures only warn)."""
        conn = self._connection_factory()
        try:
            cur = conn.cursor()
            cur.execute(
                """
                UPDATE users
                SET mfa_secret = ?,
                    updated_at = ?
                WHERE LOWER(username) = LOWER(?)
                """,
                (secret, timestamp, username),
            )
            conn.commit()
        except sqlite3.Error as exc:  # pragma: no cover - defensive
            self._log.warning("failed to persist MFA secret for %s: %s", username, exc)
        finally:
            conn.close()
__all__ = ["SQLiteUserRepository"]
def _row_to_account(record: _UserRow) -> OperatorAccount:
    """Convert a raw ``users`` row into an :class:`OperatorAccount`.

    Normalizations: the password hash is lowercased, the username doubles
    as the display name when none is stored, missing timestamps collapse
    to 0, and an empty MFA secret becomes ``None``.
    """
    fields = {
        "username": record.username,
        "display_name": record.display_name or record.username,
        "password_sha512": (record.password_sha512 or "").lower(),
        "role": record.role or "User",
        "last_login": int(record.last_login or 0),
        "created_at": int(record.created_at or 0),
        "updated_at": int(record.updated_at or 0),
        "mfa_enabled": bool(record.mfa_enabled),
        "mfa_secret": (record.mfa_secret or "") or None,
    }
    return OperatorAccount(**fields)

View File

@@ -9,3 +9,5 @@ requests
# Auth & security
PyJWT[crypto]
cryptography
pyotp
qrcode

View File

@@ -23,6 +23,15 @@ __all__ = [
"SchedulerService",
"GitHubService",
"GitHubTokenPayload",
"EnrollmentAdminService",
"SiteService",
"DeviceInventoryService",
"DeviceViewService",
"CredentialService",
"AssemblyService",
"AssemblyListing",
"AssemblyLoadResult",
"AssemblyMutationResult",
]
_LAZY_TARGETS: Dict[str, Tuple[str, str]] = {
@@ -43,6 +52,39 @@ _LAZY_TARGETS: Dict[str, Tuple[str, str]] = {
"SchedulerService": ("Data.Engine.services.jobs.scheduler_service", "SchedulerService"),
"GitHubService": ("Data.Engine.services.github.github_service", "GitHubService"),
"GitHubTokenPayload": ("Data.Engine.services.github.github_service", "GitHubTokenPayload"),
"EnrollmentAdminService": (
"Data.Engine.services.enrollment.admin_service",
"EnrollmentAdminService",
),
"SiteService": ("Data.Engine.services.sites.site_service", "SiteService"),
"DeviceInventoryService": (
"Data.Engine.services.devices.device_inventory_service",
"DeviceInventoryService",
),
"DeviceViewService": (
"Data.Engine.services.devices.device_view_service",
"DeviceViewService",
),
"CredentialService": (
"Data.Engine.services.credentials.credential_service",
"CredentialService",
),
"AssemblyService": (
"Data.Engine.services.assemblies.assembly_service",
"AssemblyService",
),
"AssemblyListing": (
"Data.Engine.services.assemblies.assembly_service",
"AssemblyListing",
),
"AssemblyLoadResult": (
"Data.Engine.services.assemblies.assembly_service",
"AssemblyLoadResult",
),
"AssemblyMutationResult": (
"Data.Engine.services.assemblies.assembly_service",
"AssemblyMutationResult",
),
}

View File

@@ -0,0 +1,10 @@
"""Assembly management services."""
from .assembly_service import AssemblyService, AssemblyMutationResult, AssemblyLoadResult, AssemblyListing
__all__ = [
"AssemblyService",
"AssemblyMutationResult",
"AssemblyLoadResult",
"AssemblyListing",
]

View File

@@ -0,0 +1,715 @@
"""Filesystem-backed assembly management service."""
from __future__ import annotations
import base64
import json
import logging
import os
import re
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
__all__ = [
"AssemblyService",
"AssemblyListing",
"AssemblyLoadResult",
"AssemblyMutationResult",
]
@dataclass(frozen=True, slots=True)
class AssemblyListing:
"""Listing payload for an assembly island."""
root: Path
items: List[Dict[str, Any]]
folders: List[str]
def to_dict(self) -> dict[str, Any]:
return {
"root": str(self.root),
"items": self.items,
"folders": self.folders,
}
@dataclass(frozen=True, slots=True)
class AssemblyLoadResult:
"""Container describing a loaded assembly artifact."""
payload: Dict[str, Any]
def to_dict(self) -> Dict[str, Any]:
return dict(self.payload)
@dataclass(frozen=True, slots=True)
class AssemblyMutationResult:
"""Mutation acknowledgement for create/edit/rename operations."""
status: str = "ok"
rel_path: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
payload: Dict[str, Any] = {"status": self.status}
if self.rel_path:
payload["rel_path"] = self.rel_path
return payload
class AssemblyService:
"""Provide CRUD helpers for workflow/script/ansible assemblies."""
_ISLAND_DIR_MAP = {
"workflows": "Workflows",
"workflow": "Workflows",
"scripts": "Scripts",
"script": "Scripts",
"ansible": "Ansible_Playbooks",
"ansible_playbooks": "Ansible_Playbooks",
"ansible-playbooks": "Ansible_Playbooks",
"playbooks": "Ansible_Playbooks",
}
_SCRIPT_EXTENSIONS = (".json", ".ps1", ".bat", ".sh")
_ANSIBLE_EXTENSIONS = (".json", ".yml")
    def __init__(self, *, root: Path, logger: Optional[logging.Logger] = None) -> None:
        """Bind the service to *root* and eagerly create the directory tree.

        Directory creation is best-effort: a failure is only logged here and
        surfaces later from the first operation that needs the path.
        """
        self._root = root.resolve()
        self._log = logger or logging.getLogger("borealis.engine.services.assemblies")
        try:
            self._root.mkdir(parents=True, exist_ok=True)
        except Exception as exc:  # pragma: no cover - defensive logging
            self._log.warning("failed to ensure assemblies root %s: %s", self._root, exc)
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
    def list_items(self, island: str) -> AssemblyListing:
        """Enumerate assemblies and sub-folders below an island root.

        ``island`` selects one of three roots (workflows / scripts / ansible
        playbooks). Each branch walks the tree, records every sub-folder as a
        POSIX-style relative path, and emits one metadata dict per matching
        file; results are sorted newest-first by modification time.

        NOTE(review): the three branches differ only in extension filter and
        per-item fields — candidates for consolidation later.

        Raises:
            ValueError: when *island* is not a recognized alias.
        """
        root = self._resolve_island_root(island)
        root.mkdir(parents=True, exist_ok=True)
        items: List[Dict[str, Any]] = []
        folders: List[str] = []
        isl = (island or "").strip().lower()
        if isl in {"workflows", "workflow"}:
            for dirpath, dirnames, filenames in os.walk(root):
                rel_root = os.path.relpath(dirpath, root)
                if rel_root != ".":
                    folders.append(rel_root.replace(os.sep, "/"))
                for fname in filenames:
                    # Workflows are stored exclusively as JSON documents.
                    if not fname.lower().endswith(".json"):
                        continue
                    abs_path = Path(dirpath) / fname
                    rel_path = abs_path.relative_to(root).as_posix()
                    try:
                        mtime = abs_path.stat().st_mtime
                    except OSError:
                        mtime = 0.0
                    obj = self._safe_read_json(abs_path)
                    tab = self._extract_tab_name(obj)
                    items.append(
                        {
                            "file_name": fname,
                            "rel_path": rel_path,
                            "type": "workflow",
                            "tab_name": tab,
                            "last_edited": time.strftime(
                                "%Y-%m-%dT%H:%M:%S", time.localtime(mtime)
                            ),
                            "last_edited_epoch": mtime,
                        }
                    )
        elif isl in {"scripts", "script"}:
            for dirpath, dirnames, filenames in os.walk(root):
                rel_root = os.path.relpath(dirpath, root)
                if rel_root != ".":
                    folders.append(rel_root.replace(os.sep, "/"))
                for fname in filenames:
                    if not fname.lower().endswith(self._SCRIPT_EXTENSIONS):
                        continue
                    abs_path = Path(dirpath) / fname
                    rel_path = abs_path.relative_to(root).as_posix()
                    try:
                        mtime = abs_path.stat().st_mtime
                    except OSError:
                        # Unreadable stat: surface the item with epoch 0.
                        mtime = 0.0
                    script_type = self._detect_script_type(abs_path)
                    doc = self._load_assembly_document(abs_path, "scripts", script_type)
                    items.append(
                        {
                            "file_name": fname,
                            "rel_path": rel_path,
                            "type": doc.get("type", script_type),
                            "name": doc.get("name"),
                            "category": doc.get("category"),
                            "description": doc.get("description"),
                            "last_edited": time.strftime(
                                "%Y-%m-%dT%H:%M:%S", time.localtime(mtime)
                            ),
                            "last_edited_epoch": mtime,
                        }
                    )
        elif isl in {
            "ansible",
            "ansible_playbooks",
            "ansible-playbooks",
            "playbooks",
        }:
            for dirpath, dirnames, filenames in os.walk(root):
                rel_root = os.path.relpath(dirpath, root)
                if rel_root != ".":
                    folders.append(rel_root.replace(os.sep, "/"))
                for fname in filenames:
                    if not fname.lower().endswith(self._ANSIBLE_EXTENSIONS):
                        continue
                    abs_path = Path(dirpath) / fname
                    rel_path = abs_path.relative_to(root).as_posix()
                    try:
                        mtime = abs_path.stat().st_mtime
                    except OSError:
                        mtime = 0.0
                    script_type = self._detect_script_type(abs_path)
                    doc = self._load_assembly_document(abs_path, "ansible", script_type)
                    items.append(
                        {
                            "file_name": fname,
                            "rel_path": rel_path,
                            "type": doc.get("type", "ansible"),
                            "name": doc.get("name"),
                            "category": doc.get("category"),
                            "description": doc.get("description"),
                            "last_edited": time.strftime(
                                "%Y-%m-%dT%H:%M:%S", time.localtime(mtime)
                            ),
                            "last_edited_epoch": mtime,
                        }
                    )
        else:
            raise ValueError("invalid_island")
        # Newest first; items with unreadable mtimes (0.0) sink to the end.
        items.sort(key=lambda entry: entry.get("last_edited_epoch", 0.0), reverse=True)
        return AssemblyListing(root=root, items=items, folders=folders)
    def load_item(self, island: str, rel_path: str) -> AssemblyLoadResult:
        """Load a single assembly file.

        Workflows are returned as their raw JSON document; other islands
        wrap the normalized assembly plus the decoded script text under
        ``assembly`` / ``content``.

        Raises:
            FileNotFoundError: when *rel_path* does not name a file.
            ValueError: when the island alias or the path is invalid.
        """
        root, abs_path, _ = self._resolve_assembly_path(island, rel_path)
        if not abs_path.is_file():
            raise FileNotFoundError("file_not_found")
        isl = (island or "").strip().lower()
        if isl in {"workflows", "workflow"}:
            payload = self._safe_read_json(abs_path)
            return AssemblyLoadResult(payload=payload)
        doc = self._load_assembly_document(abs_path, island)
        rel = abs_path.relative_to(root).as_posix()
        payload = {
            "file_name": abs_path.name,
            "rel_path": rel,
            "type": doc.get("type"),
            "assembly": doc,
            # "content" duplicates the decoded script for legacy consumers.
            "content": doc.get("script"),
        }
        return AssemblyLoadResult(payload=payload)
    def create_item(
        self,
        island: str,
        *,
        kind: str,
        rel_path: str,
        content: Any,
        item_type: Optional[str] = None,
    ) -> AssemblyMutationResult:
        """Create a folder or assembly file below an island root.

        ``kind`` is "folder" or "file". Files are persisted as normalized
        JSON documents; a missing extension gets the island default.

        Raises:
            ValueError: for an empty path, unknown *kind*, or bad island.
        """
        root, abs_path, rel_norm = self._resolve_assembly_path(island, rel_path)
        if not rel_norm:
            raise ValueError("path_required")
        normalized_kind = (kind or "").strip().lower()
        if normalized_kind == "folder":
            abs_path.mkdir(parents=True, exist_ok=True)
            return AssemblyMutationResult()
        if normalized_kind != "file":
            raise ValueError("invalid_kind")
        target_path = abs_path
        if not target_path.suffix:
            # Append the island's canonical extension when none was given.
            target_path = target_path.with_suffix(
                self._default_ext_for_island(island, item_type or "")
            )
        target_path.parent.mkdir(parents=True, exist_ok=True)
        isl = (island or "").strip().lower()
        if isl in {"workflows", "workflow"}:
            payload = self._ensure_workflow_document(content)
            base_name = target_path.stem
            # Default the tab label to the file's base name.
            payload.setdefault("tab_name", base_name)
            self._write_json(target_path, payload)
        else:
            document = self._normalize_assembly_document(
                content,
                self._default_type_for_island(island, item_type or ""),
                target_path.stem,
            )
            self._write_json(target_path, self._prepare_assembly_storage(document))
        rel_new = target_path.relative_to(root).as_posix()
        return AssemblyMutationResult(rel_path=rel_new)
    def edit_item(
        self,
        island: str,
        *,
        rel_path: str,
        content: Any,
        item_type: Optional[str] = None,
    ) -> AssemblyMutationResult:
        """Overwrite an existing assembly file with new content.

        If the stored file had no extension, the canonical one is appended,
        the new file written, and the old extension-less file removed.

        Raises:
            FileNotFoundError: when *rel_path* does not exist.
            ValueError: when the island alias or the path is invalid.
        """
        root, abs_path, _ = self._resolve_assembly_path(island, rel_path)
        if not abs_path.exists():
            raise FileNotFoundError("file_not_found")
        target_path = abs_path
        if not target_path.suffix:
            target_path = target_path.with_suffix(
                self._default_ext_for_island(island, item_type or "")
            )
        isl = (island or "").strip().lower()
        if isl in {"workflows", "workflow"}:
            payload = self._ensure_workflow_document(content)
            self._write_json(target_path, payload)
        else:
            document = self._normalize_assembly_document(
                content,
                self._default_type_for_island(island, item_type or ""),
                target_path.stem,
            )
            self._write_json(target_path, self._prepare_assembly_storage(document))
        if target_path != abs_path and abs_path.exists():
            # The payload now lives under the suffixed name; drop the
            # original extension-less file (best effort).
            try:
                abs_path.unlink()
            except OSError:  # pragma: no cover - best effort cleanup
                pass
        rel_new = target_path.relative_to(root).as_posix()
        return AssemblyMutationResult(rel_path=rel_new)
def rename_item(
self,
island: str,
*,
kind: str,
rel_path: str,
new_name: str,
item_type: Optional[str] = None,
) -> AssemblyMutationResult:
root, old_path, _ = self._resolve_assembly_path(island, rel_path)
normalized_kind = (kind or "").strip().lower()
if normalized_kind not in {"file", "folder"}:
raise ValueError("invalid_kind")
if normalized_kind == "folder":
if not old_path.is_dir():
raise FileNotFoundError("folder_not_found")
destination = old_path.parent / new_name
else:
if not old_path.is_file():
raise FileNotFoundError("file_not_found")
candidate = Path(new_name)
if not candidate.suffix:
candidate = candidate.with_suffix(
self._default_ext_for_island(island, item_type or "")
)
destination = old_path.parent / candidate.name
destination = destination.resolve()
if not str(destination).startswith(str(root)):
raise ValueError("invalid_destination")
old_path.rename(destination)
isl = (island or "").strip().lower()
if normalized_kind == "file" and isl in {"workflows", "workflow"}:
try:
obj = self._safe_read_json(destination)
base_name = destination.stem
for key in ["tabName", "tab_name", "name", "title"]:
if key in obj:
obj[key] = base_name
obj.setdefault("tab_name", base_name)
self._write_json(destination, obj)
except Exception: # pragma: no cover - best effort update
self._log.debug("failed to normalize workflow metadata for %s", destination)
rel_new = destination.relative_to(root).as_posix()
return AssemblyMutationResult(rel_path=rel_new)
    def move_item(
        self,
        island: str,
        *,
        rel_path: str,
        new_path: str,
        kind: Optional[str] = None,
    ) -> AssemblyMutationResult:
        """Move a file or folder to a new relative path in the same island.

        Raises:
            FileNotFoundError: when the source entry does not exist.
            ValueError: when the island alias or either path is invalid.
        """
        root, old_path, _ = self._resolve_assembly_path(island, rel_path)
        _, dest_path, _ = self._resolve_assembly_path(island, new_path)
        normalized_kind = (kind or "").strip().lower()
        if normalized_kind == "folder":
            if not old_path.is_dir():
                raise FileNotFoundError("folder_not_found")
        else:
            # Any non-"folder" kind (including None) is treated as a file.
            if not old_path.exists():
                raise FileNotFoundError("file_not_found")
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        # shutil.move handles directory trees and cross-filesystem moves.
        shutil.move(str(old_path), str(dest_path))
        return AssemblyMutationResult()
def delete_item(
self,
island: str,
*,
rel_path: str,
kind: str,
) -> AssemblyMutationResult:
_, abs_path, rel_norm = self._resolve_assembly_path(island, rel_path)
if not rel_norm:
raise ValueError("cannot_delete_root")
normalized_kind = (kind or "").strip().lower()
if normalized_kind == "folder":
if not abs_path.is_dir():
raise FileNotFoundError("folder_not_found")
shutil.rmtree(abs_path)
elif normalized_kind == "file":
if not abs_path.is_file():
raise FileNotFoundError("file_not_found")
abs_path.unlink()
else:
raise ValueError("invalid_kind")
return AssemblyMutationResult()
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _resolve_island_root(self, island: str) -> Path:
key = (island or "").strip().lower()
subdir = self._ISLAND_DIR_MAP.get(key)
if not subdir:
raise ValueError("invalid_island")
root = (self._root / subdir).resolve()
root.mkdir(parents=True, exist_ok=True)
return root
def _resolve_assembly_path(self, island: str, rel_path: str) -> Tuple[Path, Path, str]:
root = self._resolve_island_root(island)
rel_norm = self._normalize_relpath(rel_path)
abs_path = (root / rel_norm).resolve()
if not str(abs_path).startswith(str(root)):
raise ValueError("invalid_path")
return root, abs_path, rel_norm
@staticmethod
def _normalize_relpath(value: str) -> str:
return (value or "").replace("\\", "/").strip("/")
@staticmethod
def _default_ext_for_island(island: str, item_type: str) -> str:
isl = (island or "").strip().lower()
if isl in {"workflows", "workflow"}:
return ".json"
if isl in {"ansible", "ansible_playbooks", "ansible-playbooks", "playbooks"}:
return ".json"
if isl in {"scripts", "script"}:
return ".json"
typ = (item_type or "").strip().lower()
if typ in {"bash", "batch", "powershell"}:
return ".json"
return ".json"
@staticmethod
def _default_type_for_island(island: str, item_type: str) -> str:
isl = (island or "").strip().lower()
if isl in {"ansible", "ansible_playbooks", "ansible-playbooks", "playbooks"}:
return "ansible"
typ = (item_type or "").strip().lower()
if typ in {"powershell", "batch", "bash", "ansible"}:
return typ
return "powershell"
@staticmethod
def _empty_assembly_document(default_type: str) -> Dict[str, Any]:
return {
"version": 1,
"name": "",
"description": "",
"category": "application" if default_type.lower() == "ansible" else "script",
"type": default_type or "powershell",
"script": "",
"timeout_seconds": 3600,
"sites": {"mode": "all", "values": []},
"variables": [],
"files": [],
}
@staticmethod
def _decode_base64_text(value: Any) -> Optional[str]:
if not isinstance(value, str):
return None
stripped = value.strip()
if not stripped:
return ""
try:
cleaned = re.sub(r"\s+", "", stripped)
except Exception:
cleaned = stripped
try:
decoded = base64.b64decode(cleaned, validate=True)
except Exception:
return None
try:
return decoded.decode("utf-8")
except Exception:
return decoded.decode("utf-8", errors="replace")
    def _decode_script_content(self, value: Any, encoding_hint: str = "") -> str:
        """Return script text from *value*, decoding base64 when applicable.

        The hint ("base64"/"b64"/"base-64") is honored first; even without a
        hint the text is probed as base64 so legacy payloads still decode.
        Plain text that merely *looks like* valid base64 is therefore also
        decoded — callers rely on this probing. Non-strings yield ``""``, and
        CRLF line endings are normalized to LF on every path.
        """
        encoding = (encoding_hint or "").strip().lower()
        if isinstance(value, str):
            if encoding in {"base64", "b64", "base-64"}:
                decoded = self._decode_base64_text(value)
                if decoded is not None:
                    return decoded.replace("\r\n", "\n")
            # No (or unrecognized) hint: probe anyway for legacy documents.
            decoded = self._decode_base64_text(value)
            if decoded is not None:
                return decoded.replace("\r\n", "\n")
            return value.replace("\r\n", "\n")
        return ""
@staticmethod
def _encode_script_content(script_text: Any) -> str:
if not isinstance(script_text, str):
if script_text is None:
script_text = ""
else:
script_text = str(script_text)
normalized = script_text.replace("\r\n", "\n")
if not normalized:
return ""
encoded = base64.b64encode(normalized.encode("utf-8"))
return encoded.decode("ascii")
def _prepare_assembly_storage(self, document: Dict[str, Any]) -> Dict[str, Any]:
stored: Dict[str, Any] = {}
for key, value in (document or {}).items():
if key == "script":
stored[key] = self._encode_script_content(value)
else:
stored[key] = value
stored["script_encoding"] = "base64"
return stored
def _normalize_assembly_document(
    self,
    obj: Any,
    default_type: str,
    base_name: str,
) -> Dict[str, Any]:
    """Coerce an arbitrary payload into a well-formed assembly document.

    Starts from the empty template for *default_type* and selectively copies
    validated fields from *obj*; malformed or missing fields silently keep
    their template defaults. *base_name* supplies the fallback display name.
    """
    doc = self._empty_assembly_document(default_type)
    if not isinstance(obj, dict):
        obj = {}
    base = (base_name or "assembly").strip()
    doc["name"] = str(obj.get("name") or obj.get("display_name") or base)
    doc["description"] = str(obj.get("description") or "")
    # Only the two known categories are accepted; anything else keeps the default.
    category = str(obj.get("category") or doc["category"]).strip().lower()
    if category in {"script", "application"}:
        doc["category"] = category
    typ = str(obj.get("type") or obj.get("script_type") or default_type or "powershell").strip().lower()
    if typ in {"powershell", "batch", "bash", "ansible"}:
        doc["type"] = typ
    # Script body may arrive as a list of lines, a plain string, or legacy "content".
    script_val = obj.get("script")
    content_val = obj.get("content")
    script_lines = obj.get("script_lines")
    if isinstance(script_lines, list):
        try:
            doc["script"] = "\n".join(str(line) for line in script_lines)
        except Exception:
            doc["script"] = ""
    elif isinstance(script_val, str):
        doc["script"] = script_val
    elif isinstance(content_val, str):
        doc["script"] = content_val
    encoding_hint = str(
        obj.get("script_encoding") or obj.get("scriptEncoding") or ""
    ).strip().lower()
    doc["script"] = self._decode_script_content(doc.get("script"), encoding_hint)
    if encoding_hint in {"base64", "b64", "base-64"}:
        doc["script_encoding"] = "base64"
    else:
        # No explicit hint: probe the raw value; if it decodes as strict base64
        # we record that, otherwise the script is treated as plain text.
        probe_source = ""
        if isinstance(script_val, str) and script_val:
            probe_source = script_val
        elif isinstance(content_val, str) and content_val:
            probe_source = content_val
        decoded_probe = self._decode_base64_text(probe_source) if probe_source else None
        if decoded_probe is not None:
            doc["script_encoding"] = "base64"
            doc["script"] = decoded_probe.replace("\r\n", "\n")
        else:
            doc["script_encoding"] = "plain"
    timeout_val = obj.get("timeout_seconds", obj.get("timeout"))
    if timeout_val is not None:
        try:
            # Clamp negatives to 0; non-numeric values keep the default.
            doc["timeout_seconds"] = max(0, int(timeout_val))
        except Exception:
            pass
    # Site targeting: default to "all" unless explicit values imply "specific".
    sites = obj.get("sites") if isinstance(obj.get("sites"), dict) else {}
    values = sites.get("values") if isinstance(sites.get("values"), list) else []
    mode = str(sites.get("mode") or ("specific" if values else "all")).strip().lower()
    if mode not in {"all", "specific"}:
        mode = "all"
    doc["sites"] = {
        "mode": mode,
        "values": [
            str(v).strip()
            for v in values
            if isinstance(v, (str, int, float)) and str(v).strip()
        ],
    }
    # Variables: keep only dict entries with a non-empty name; unknown types
    # degrade to "string".
    vars_in = obj.get("variables") if isinstance(obj.get("variables"), list) else []
    doc_vars: List[Dict[str, Any]] = []
    for entry in vars_in:
        if not isinstance(entry, dict):
            continue
        name = str(entry.get("name") or entry.get("key") or "").strip()
        if not name:
            continue
        vtype = str(entry.get("type") or "string").strip().lower()
        if vtype not in {"string", "number", "boolean", "credential"}:
            vtype = "string"
        default_val = entry.get("default", entry.get("default_value"))
        doc_vars.append(
            {
                "name": name,
                "label": str(entry.get("label") or ""),
                "type": vtype,
                "default": default_val,
                "required": bool(entry.get("required")),
                "description": str(entry.get("description") or ""),
            }
        )
    doc["variables"] = doc_vars
    # Attached files: require a file name and a string data payload.
    files_in = obj.get("files") if isinstance(obj.get("files"), list) else []
    doc_files: List[Dict[str, Any]] = []
    for record in files_in:
        if not isinstance(record, dict):
            continue
        fname = record.get("file_name") or record.get("name")
        data = record.get("data")
        if not fname or not isinstance(data, str):
            continue
        size_val = record.get("size")
        try:
            size_int = int(size_val)
        except Exception:
            size_int = 0
        doc_files.append(
            {
                "file_name": str(fname),
                "size": size_int,
                "mime_type": str(record.get("mime_type") or record.get("mimeType") or ""),
                "data": data,
            }
        )
    doc["files"] = doc_files
    try:
        doc["version"] = int(obj.get("version") or doc["version"])
    except Exception:
        pass
    return doc
def _load_assembly_document(
    self,
    abs_path: Path,
    island: str,
    type_hint: str = "",
) -> Dict[str, Any]:
    """Load an assembly from disk, normalizing JSON or wrapping raw scripts.

    ``*.json`` files are parsed and run through the normalizer; any other
    file is treated as a raw script body of the island's default type.
    Read failures degrade to an empty script rather than raising.
    """
    base_name = abs_path.stem
    default_type = self._default_type_for_island(island, type_hint)
    if abs_path.suffix.lower() == ".json":
        data = self._safe_read_json(abs_path)
        return self._normalize_assembly_document(data, default_type, base_name)
    try:
        content = abs_path.read_text(encoding="utf-8", errors="replace")
    except Exception:
        # Unreadable file: fall through with an empty script body.
        content = ""
    document = self._empty_assembly_document(default_type)
    document["name"] = base_name
    document["script"] = (content or "").replace("\r\n", "\n")
    if default_type == "ansible":
        document["category"] = "application"
    return document
@staticmethod
def _safe_read_json(path: Path) -> Dict[str, Any]:
try:
return json.loads(path.read_text(encoding="utf-8"))
except Exception:
return {}
@staticmethod
def _extract_tab_name(obj: Dict[str, Any]) -> str:
if not isinstance(obj, dict):
return ""
for key in ["tabName", "tab_name", "name", "title"]:
value = obj.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return ""
def _detect_script_type(self, path: Path) -> str:
lower = path.name.lower()
if lower.endswith(".json") and path.is_file():
obj = self._safe_read_json(path)
if isinstance(obj, dict):
typ = str(
obj.get("type") or obj.get("script_type") or ""
).strip().lower()
if typ in {"powershell", "batch", "bash", "ansible"}:
return typ
return "powershell"
if lower.endswith(".yml"):
return "ansible"
if lower.endswith(".ps1"):
return "powershell"
if lower.endswith(".bat"):
return "batch"
if lower.endswith(".sh"):
return "bash"
return "unknown"
@staticmethod
def _ensure_workflow_document(content: Any) -> Dict[str, Any]:
payload = content
if isinstance(payload, str):
try:
payload = json.loads(payload)
except Exception:
payload = {}
if not isinstance(payload, dict):
payload = {}
return payload
@staticmethod
def _write_json(path: Path, payload: Dict[str, Any]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(payload, indent=2), encoding="utf-8")

View File

@@ -11,6 +11,26 @@ from .token_service import (
TokenRefreshErrorCode,
TokenService,
)
from .operator_account_service import (
AccountNotFoundError,
CannotModifySelfError,
InvalidPasswordHashError,
InvalidRoleError,
LastAdminError,
LastUserError,
OperatorAccountError,
OperatorAccountRecord,
OperatorAccountService,
UsernameAlreadyExistsError,
)
from .operator_auth_service import (
InvalidCredentialsError,
InvalidMFACodeError,
MFAUnavailableError,
MFASessionError,
OperatorAuthError,
OperatorAuthService,
)
__all__ = [
"DeviceAuthService",
@@ -24,4 +44,20 @@ __all__ = [
"TokenRefreshError",
"TokenRefreshErrorCode",
"TokenService",
"OperatorAccountService",
"OperatorAccountError",
"OperatorAccountRecord",
"UsernameAlreadyExistsError",
"AccountNotFoundError",
"LastAdminError",
"LastUserError",
"CannotModifySelfError",
"InvalidRoleError",
"InvalidPasswordHashError",
"OperatorAuthService",
"OperatorAuthError",
"InvalidCredentialsError",
"InvalidMFACodeError",
"MFAUnavailableError",
"MFASessionError",
]

View File

@@ -0,0 +1,211 @@
"""Operator account management service."""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
from typing import Optional
from Data.Engine.domain import OperatorAccount
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
class OperatorAccountError(Exception):
    """Base class for operator account management failures; all account-service errors derive from it."""
class UsernameAlreadyExistsError(OperatorAccountError):
    """Raised when creating an operator whose username already exists (mapped from a sqlite IntegrityError)."""
class AccountNotFoundError(OperatorAccountError):
    """Raised when the requested operator account cannot be located (or the username is blank)."""
class LastAdminError(OperatorAccountError):
    """Raised when an operation would demote or delete the last remaining admin account."""
class LastUserError(OperatorAccountError):
    """Raised when a delete would remove the final operator account."""
class CannotModifySelfError(OperatorAccountError):
    """Raised when the acting operator attempts to delete their own account."""
class InvalidRoleError(OperatorAccountError):
    """Raised when a role value is not one of the supported roles ("User"/"Admin")."""
class InvalidPasswordHashError(OperatorAccountError):
    """Raised when a password hash is malformed (not a 128-char SHA-512 hex digest) or missing."""
@dataclass(frozen=True, slots=True)
class OperatorAccountRecord:
    """Immutable, API-safe projection of an operator account (no password hash or MFA secret)."""

    username: str
    display_name: str   # falls back to username when blank
    role: str           # "User" or "Admin"
    last_login: int     # unix timestamp; 0 when never logged in
    created_at: int     # unix timestamp
    updated_at: int     # unix timestamp
    mfa_enabled: bool
class OperatorAccountService:
    """High-level operations for managing operator accounts.

    Wraps the SQLite user repository and enforces invariants such as
    "never delete the last user" and "never demote or delete the last admin".
    """

    def __init__(
        self,
        repository: SQLiteUserRepository,
        *,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        self._repository = repository
        self._log = logger or logging.getLogger("borealis.engine.services.operator_accounts")

    def list_accounts(self) -> list[OperatorAccountRecord]:
        """Return every operator account as an API-safe record."""
        return [_to_record(account) for account in self._repository.list_accounts()]

    def create_account(
        self,
        *,
        username: str,
        password_sha512: str,
        role: str,
        display_name: Optional[str] = None,
    ) -> OperatorAccountRecord:
        """Create an operator account and return its record.

        Raises InvalidRoleError, InvalidPasswordHashError, or
        UsernameAlreadyExistsError (mapped from sqlite IntegrityError).
        """
        normalized_role = self._normalize_role(role)
        username = (username or "").strip()
        # Hashes are normalized to lowercase hex before storage/comparison.
        password_sha512 = (password_sha512 or "").strip().lower()
        display_name = (display_name or username or "").strip()
        if not username or not password_sha512:
            raise InvalidPasswordHashError("username and password are required")
        if len(password_sha512) != 128:  # SHA-512 hex digest is 128 chars
            raise InvalidPasswordHashError("password hash must be 128 hex characters")
        now = int(time.time())
        try:
            self._repository.create_account(
                username=username,
                display_name=display_name or username,
                password_sha512=password_sha512,
                role=normalized_role,
                timestamp=now,
            )
        except Exception as exc:  # pragma: no cover - sqlite integrity errors are deterministic
            import sqlite3
            if isinstance(exc, sqlite3.IntegrityError):
                raise UsernameAlreadyExistsError("username already exists") from exc
            raise
        account = self._repository.fetch_by_username(username)
        if not account:  # pragma: no cover - sanity guard
            raise AccountNotFoundError("account creation failed")
        return _to_record(account)

    def delete_account(self, username: str, *, actor: Optional[str] = None) -> None:
        """Delete *username*, refusing self-deletion and last-user/last-admin removal."""
        username = (username or "").strip()
        if not username:
            raise AccountNotFoundError("invalid username")
        if actor and actor.strip().lower() == username.lower():
            raise CannotModifySelfError("cannot delete yourself")
        total_accounts = self._repository.count_accounts()
        if total_accounts <= 1:
            raise LastUserError("cannot delete the last user")
        target = self._repository.fetch_by_username(username)
        if not target:
            raise AccountNotFoundError("user not found")
        if target.role.lower() == "admin" and self._repository.count_admins() <= 1:
            raise LastAdminError("cannot delete the last admin")
        if not self._repository.delete_account(username):
            raise AccountNotFoundError("user not found")

    def reset_password(self, username: str, password_sha512: str) -> None:
        """Replace the stored SHA-512 password hash for *username*."""
        username = (username or "").strip()
        password_sha512 = (password_sha512 or "").strip().lower()
        if len(password_sha512) != 128:
            raise InvalidPasswordHashError("invalid password hash")
        now = int(time.time())
        if not self._repository.update_password(username, password_sha512, timestamp=now):
            raise AccountNotFoundError("user not found")

    def change_role(self, username: str, role: str, *, actor: Optional[str] = None) -> OperatorAccountRecord:
        """Change *username*'s role, protecting the last remaining admin."""
        username = (username or "").strip()
        normalized_role = self._normalize_role(role)
        account = self._repository.fetch_by_username(username)
        if not account:
            raise AccountNotFoundError("user not found")
        if account.role.lower() == "admin" and normalized_role.lower() != "admin":
            if self._repository.count_admins() <= 1:
                raise LastAdminError("cannot demote the last admin")
        now = int(time.time())
        if not self._repository.update_role(username, normalized_role, timestamp=now):
            raise AccountNotFoundError("user not found")
        updated = self._repository.fetch_by_username(username)
        if not updated:  # pragma: no cover - guard
            raise AccountNotFoundError("user not found")
        record = _to_record(updated)
        if actor and actor.strip().lower() == username.lower():
            # Log when the caller changed their own role.
            self._log.info("actor-role-updated", extra={"username": username, "role": record.role})
        return record

    def update_mfa(self, username: str, *, enabled: bool, reset_secret: bool) -> None:
        """Toggle MFA for *username*, optionally discarding the stored secret."""
        username = (username or "").strip()
        if not username:
            raise AccountNotFoundError("invalid username")
        now = int(time.time())
        if not self._repository.update_mfa(username, enabled=enabled, reset_secret=reset_secret, timestamp=now):
            raise AccountNotFoundError("user not found")

    def fetch_account(self, username: str) -> Optional[OperatorAccountRecord]:
        """Return the record for *username*, or None when it does not exist."""
        account = self._repository.fetch_by_username(username)
        return _to_record(account) if account else None

    def _normalize_role(self, role: str) -> str:
        # Title-case the input ("admin" -> "Admin"), defaulting to "User",
        # and restrict it to the two supported roles.
        normalized = (role or "").strip().title() or "User"
        if normalized not in {"User", "Admin"}:
            raise InvalidRoleError("invalid role")
        return normalized
def _to_record(account: OperatorAccount) -> OperatorAccountRecord:
    """Project a domain OperatorAccount onto the API-safe service record.

    Blank display names fall back to the username, missing roles default to
    "User", and timestamps are coerced to ints (0 when unset).
    """
    fields = {
        "username": account.username,
        "display_name": account.display_name or account.username,
        "role": account.role or "User",
        "last_login": int(account.last_login or 0),
        "created_at": int(account.created_at or 0),
        "updated_at": int(account.updated_at or 0),
        "mfa_enabled": bool(account.mfa_enabled),
    }
    return OperatorAccountRecord(**fields)
__all__ = [
"OperatorAccountService",
"OperatorAccountError",
"UsernameAlreadyExistsError",
"AccountNotFoundError",
"LastAdminError",
"LastUserError",
"CannotModifySelfError",
"InvalidRoleError",
"InvalidPasswordHashError",
"OperatorAccountRecord",
]

View File

@@ -0,0 +1,236 @@
"""Operator authentication service."""
from __future__ import annotations
import base64
import io
import logging
import os
import time
import uuid
from typing import Optional
try: # pragma: no cover - optional dependencies mirror legacy server behaviour
import pyotp # type: ignore
except Exception: # pragma: no cover - gracefully degrade when unavailable
pyotp = None # type: ignore
try: # pragma: no cover - optional dependency
import qrcode # type: ignore
except Exception: # pragma: no cover - gracefully degrade when unavailable
qrcode = None # type: ignore
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from Data.Engine.builders.operator_auth import (
OperatorLoginRequest,
OperatorMFAVerificationRequest,
)
from Data.Engine.domain import (
OperatorAccount,
OperatorLoginSuccess,
OperatorMFAChallenge,
)
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
class OperatorAuthError(Exception):
    """Base class for operator authentication errors; all auth-service failures derive from it."""
class InvalidCredentialsError(OperatorAuthError):
    """Raised when username or password-hash verification fails."""
class MFAUnavailableError(OperatorAuthError):
    """Raised when MFA functionality is requested but the optional pyotp dependency is missing."""
class InvalidMFACodeError(OperatorAuthError):
    """Raised when the submitted TOTP code fails verification."""
class MFASessionError(OperatorAuthError):
    """Raised when the MFA challenge state is invalid: wrong pending token, expired, or unconfigured."""
class OperatorAuthService:
    """Authenticate operator accounts and manage MFA challenges.

    Password checks compare SHA-512 hex digests against the repository;
    session tokens are itsdangerous-signed; TOTP/QR support degrades when
    the optional pyotp/qrcode packages are unavailable.
    """

    def __init__(
        self,
        repository: SQLiteUserRepository,
        *,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        self._repository = repository
        self._log = logger or logging.getLogger("borealis.engine.services.operator_auth")

    def authenticate(
        self, request: OperatorLoginRequest
    ) -> OperatorLoginSuccess | OperatorMFAChallenge:
        """Verify credentials and return a login success or an MFA challenge.

        Raises InvalidCredentialsError when the username or password hash
        does not match.
        """
        account = self._repository.fetch_by_username(request.username)
        if not account:
            raise InvalidCredentialsError("invalid username or password")
        if not self._password_matches(account, request.password_sha512):
            raise InvalidCredentialsError("invalid username or password")
        if not account.mfa_enabled:
            return self._finalize_login(account)
        # MFA enabled: "verify" when a secret is already stored, else "setup".
        stage = "verify" if account.mfa_secret else "setup"
        return self._build_mfa_challenge(account, stage)

    def verify_mfa(
        self,
        challenge: OperatorMFAChallenge,
        request: OperatorMFAVerificationRequest,
    ) -> OperatorLoginSuccess:
        """Complete an MFA challenge and finish the login.

        Raises MFASessionError for token/expiry/configuration problems and
        InvalidMFACodeError when the TOTP code is wrong.
        """
        now = int(time.time())
        if challenge.pending_token != request.pending_token:
            raise MFASessionError("invalid_session")
        if challenge.expires_at < now:
            raise MFASessionError("expired")
        if challenge.stage == "setup":
            secret = (challenge.secret or "").strip()
            if not secret:
                raise MFASessionError("mfa_not_configured")
            totp = self._totp_for_secret(secret)
            # valid_window=1 accepts the adjacent 30s steps for clock skew.
            if not totp.verify(request.code, valid_window=1):
                raise InvalidMFACodeError("invalid_code")
            # Persist the new secret only after the first successful verify.
            self._repository.store_mfa_secret(challenge.username, secret, timestamp=now)
        else:
            account = self._repository.fetch_by_username(challenge.username)
            if not account or not account.mfa_secret:
                raise MFASessionError("mfa_not_configured")
            totp = self._totp_for_secret(account.mfa_secret)
            if not totp.verify(request.code, valid_window=1):
                raise InvalidMFACodeError("invalid_code")
        # Re-fetch so the finalized login reflects current account state.
        account = self._repository.fetch_by_username(challenge.username)
        if not account:
            raise InvalidCredentialsError("invalid username or password")
        return self._finalize_login(account)

    def issue_token(self, username: str, role: str) -> str:
        """Mint a signed session token embedding username, role, and issue time."""
        serializer = self._token_serializer()
        payload = {"u": username, "r": role or "User", "ts": int(time.time())}
        return serializer.dumps(payload)

    def resolve_token(self, token: str, *, max_age: int = 30 * 24 * 3600) -> Optional[OperatorAccount]:
        """Return the account associated with *token* if it is valid.

        *max_age* caps token age (default 30 days); bad or expired
        signatures yield None rather than raising.
        """
        token = (token or "").strip()
        if not token:
            return None
        serializer = self._token_serializer()
        try:
            payload = serializer.loads(token, max_age=max_age)
        except (BadSignature, SignatureExpired):
            return None
        username = str(payload.get("u") or "").strip()
        if not username:
            return None
        return self._repository.fetch_by_username(username)

    def fetch_account(self, username: str) -> Optional[OperatorAccount]:
        """Return the operator account for *username* if it exists."""
        username = (username or "").strip()
        if not username:
            return None
        return self._repository.fetch_by_username(username)

    def _finalize_login(self, account: OperatorAccount) -> OperatorLoginSuccess:
        # Record the login time, then issue a fresh session token.
        now = int(time.time())
        self._repository.update_last_login(account.username, now)
        token = self.issue_token(account.username, account.role)
        return OperatorLoginSuccess(username=account.username, role=account.role, token=token)

    def _password_matches(self, account: OperatorAccount, provided_hash: str) -> bool:
        # Case-insensitive comparison of SHA-512 hex digests; empty values fail.
        # NOTE(review): plain == is not constant-time; consider hmac.compare_digest.
        expected = (account.password_sha512 or "").strip().lower()
        candidate = (provided_hash or "").strip().lower()
        return bool(expected and candidate and expected == candidate)

    def _build_mfa_challenge(
        self,
        account: OperatorAccount,
        stage: str,
    ) -> OperatorMFAChallenge:
        """Create a short-lived MFA challenge for *account* at *stage*."""
        now = int(time.time())
        pending_token = uuid.uuid4().hex
        secret = None
        otpauth_url = None
        qr_image = None
        if stage == "setup":
            # First-time enrollment: mint a secret plus provisioning URL/QR.
            secret = self._generate_totp_secret()
            otpauth_url = self._totp_provisioning_uri(secret, account.username)
            qr_image = self._totp_qr_data_uri(otpauth_url) if otpauth_url else None
        return OperatorMFAChallenge(
            username=account.username,
            role=account.role,
            stage="verify" if stage == "verify" else "setup",
            pending_token=pending_token,
            expires_at=now + 300,  # challenges expire after five minutes
            secret=secret,
            otpauth_url=otpauth_url,
            qr_image=qr_image,
        )

    def _token_serializer(self) -> URLSafeTimedSerializer:
        # SECURITY(review): falls back to a hard-coded secret when the env var
        # is unset, which would make tokens forgeable — confirm deployments
        # always set BOREALIS_FLASK_SECRET_KEY.
        secret = os.getenv("BOREALIS_FLASK_SECRET_KEY") or "change-me"
        return URLSafeTimedSerializer(secret, salt="borealis-auth")

    def _generate_totp_secret(self) -> str:
        """Return a fresh random base32 TOTP secret (requires pyotp)."""
        if not pyotp:
            raise MFAUnavailableError("pyotp is not installed; MFA unavailable")
        return pyotp.random_base32()  # type: ignore[no-any-return]

    def _totp_for_secret(self, secret: str):
        """Build a 6-digit/30s TOTP verifier for *secret* (requires pyotp)."""
        if not pyotp:
            raise MFAUnavailableError("pyotp is not installed; MFA unavailable")
        normalized = secret.replace(" ", "").strip().upper()
        if not normalized:
            raise MFASessionError("mfa_not_configured")
        return pyotp.TOTP(normalized, digits=6, interval=30)

    def _totp_provisioning_uri(self, secret: str, username: str) -> Optional[str]:
        """Return the otpauth:// enrollment URI for *secret*, or None on failure."""
        try:
            totp = self._totp_for_secret(secret)
        except OperatorAuthError:
            return None
        issuer = os.getenv("BOREALIS_MFA_ISSUER", "Borealis")
        try:
            return totp.provisioning_uri(name=username, issuer_name=issuer)
        except Exception:  # pragma: no cover - defensive
            return None

    def _totp_qr_data_uri(self, payload: str) -> Optional[str]:
        # Render the otpauth URL as an inline PNG data URI; None when the
        # qrcode package is missing or rendering fails.
        if not payload or qrcode is None:
            return None
        try:
            img = qrcode.make(payload, box_size=6, border=4)
            buf = io.BytesIO()
            img.save(buf, format="PNG")
            encoded = base64.b64encode(buf.getvalue()).decode("ascii")
            return f"data:image/png;base64,{encoded}"
        except Exception:  # pragma: no cover - defensive
            self._log.warning("failed to generate MFA QR code", exc_info=True)
            return None
__all__ = [
"OperatorAuthService",
"OperatorAuthError",
"InvalidCredentialsError",
"MFAUnavailableError",
"InvalidMFACodeError",
"MFASessionError",
]

View File

@@ -13,25 +13,38 @@ from Data.Engine.integrations.github import GitHubArtifactProvider
from Data.Engine.repositories.sqlite import (
SQLiteConnectionFactory,
SQLiteDeviceRepository,
SQLiteDeviceInventoryRepository,
SQLiteDeviceViewRepository,
SQLiteCredentialRepository,
SQLiteEnrollmentRepository,
SQLiteGitHubRepository,
SQLiteJobRepository,
SQLiteRefreshTokenRepository,
SQLiteSiteRepository,
SQLiteUserRepository,
)
from Data.Engine.services.auth import (
DeviceAuthService,
DPoPValidator,
OperatorAccountService,
OperatorAuthService,
JWTService,
TokenService,
load_jwt_service,
)
from Data.Engine.services.crypto.signing import ScriptSigner, load_signer
from Data.Engine.services.enrollment import EnrollmentService
from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService
from Data.Engine.services.enrollment.nonce_cache import NonceCache
from Data.Engine.services.devices import DeviceInventoryService
from Data.Engine.services.devices import DeviceViewService
from Data.Engine.services.credentials import CredentialService
from Data.Engine.services.github import GitHubService
from Data.Engine.services.jobs import SchedulerService
from Data.Engine.services.rate_limit import SlidingWindowRateLimiter
from Data.Engine.services.realtime import AgentRealtimeService
from Data.Engine.services.sites import SiteService
from Data.Engine.services.assemblies import AssemblyService
__all__ = ["EngineServiceContainer", "build_service_container"]
@@ -39,13 +52,22 @@ __all__ = ["EngineServiceContainer", "build_service_container"]
@dataclass(frozen=True, slots=True)
class EngineServiceContainer:
    """Frozen aggregate of every wired Engine service.

    Built once by build_service_container and attached to the Flask app
    (app.extensions["engine_services"]) for HTTP/WS interfaces to consume.
    """

    device_auth: DeviceAuthService
    device_inventory: DeviceInventoryService
    device_view_service: DeviceViewService
    credential_service: CredentialService
    token_service: TokenService
    enrollment_service: EnrollmentService
    enrollment_admin_service: EnrollmentAdminService
    site_service: SiteService
    jwt_service: JWTService
    dpop_validator: DPoPValidator
    agent_realtime: AgentRealtimeService
    scheduler_service: SchedulerService
    github_service: GitHubService
    operator_auth_service: OperatorAuthService
    operator_account_service: OperatorAccountService
    assembly_service: AssemblyService
    script_signer: Optional[ScriptSigner]  # may be None when no signer is configured
def build_service_container(
@@ -57,10 +79,21 @@ def build_service_container(
log = logger or logging.getLogger("borealis.engine.services")
device_repo = SQLiteDeviceRepository(db_factory, logger=log.getChild("devices"))
device_inventory_repo = SQLiteDeviceInventoryRepository(
db_factory, logger=log.getChild("devices.inventory")
)
device_view_repo = SQLiteDeviceViewRepository(
db_factory, logger=log.getChild("devices.views")
)
credential_repo = SQLiteCredentialRepository(
db_factory, logger=log.getChild("credentials.repo")
)
token_repo = SQLiteRefreshTokenRepository(db_factory, logger=log.getChild("tokens"))
enrollment_repo = SQLiteEnrollmentRepository(db_factory, logger=log.getChild("enrollment"))
job_repo = SQLiteJobRepository(db_factory, logger=log.getChild("jobs"))
github_repo = SQLiteGitHubRepository(db_factory, logger=log.getChild("github_repo"))
site_repo = SQLiteSiteRepository(db_factory, logger=log.getChild("sites.repo"))
user_repo = SQLiteUserRepository(db_factory, logger=log.getChild("users"))
jwt_service = load_jwt_service()
dpop_validator = DPoPValidator()
@@ -74,6 +107,8 @@ def build_service_container(
logger=log.getChild("token_service"),
)
script_signer = _load_script_signer(log)
enrollment_service = EnrollmentService(
device_repository=device_repo,
enrollment_repository=enrollment_repo,
@@ -83,10 +118,16 @@ def build_service_container(
ip_rate_limiter=SlidingWindowRateLimiter(),
fingerprint_rate_limiter=SlidingWindowRateLimiter(),
nonce_cache=NonceCache(),
script_signer=_load_script_signer(log),
script_signer=script_signer,
logger=log.getChild("enrollment"),
)
enrollment_admin_service = EnrollmentAdminService(
repository=enrollment_repo,
user_repository=user_repo,
logger=log.getChild("enrollment_admin"),
)
device_auth = DeviceAuthService(
device_repository=device_repo,
jwt_service=jwt_service,
@@ -106,6 +147,36 @@ def build_service_container(
logger=log.getChild("scheduler"),
)
operator_auth_service = OperatorAuthService(
repository=user_repo,
logger=log.getChild("operator_auth"),
)
operator_account_service = OperatorAccountService(
repository=user_repo,
logger=log.getChild("operator_accounts"),
)
device_inventory = DeviceInventoryService(
repository=device_inventory_repo,
logger=log.getChild("device_inventory"),
)
device_view_service = DeviceViewService(
repository=device_view_repo,
logger=log.getChild("device_views"),
)
credential_service = CredentialService(
repository=credential_repo,
logger=log.getChild("credentials"),
)
site_service = SiteService(
repository=site_repo,
logger=log.getChild("sites"),
)
assembly_service = AssemblyService(
root=settings.project_root / "Assemblies",
logger=log.getChild("assemblies"),
)
github_provider = GitHubArtifactProvider(
cache_file=settings.github.cache_file,
default_repo=settings.github.default_repo,
@@ -124,11 +195,20 @@ def build_service_container(
device_auth=device_auth,
token_service=token_service,
enrollment_service=enrollment_service,
enrollment_admin_service=enrollment_admin_service,
jwt_service=jwt_service,
dpop_validator=dpop_validator,
agent_realtime=agent_realtime,
scheduler_service=scheduler_service,
github_service=github_service,
operator_auth_service=operator_auth_service,
operator_account_service=operator_account_service,
device_inventory=device_inventory,
device_view_service=device_view_service,
credential_service=credential_service,
site_service=site_service,
assembly_service=assembly_service,
script_signer=script_signer,
)

View File

@@ -0,0 +1,3 @@
from .credential_service import CredentialService
__all__ = ["CredentialService"]

View File

@@ -0,0 +1,29 @@
"""Expose read access to stored credentials."""
from __future__ import annotations
import logging
from typing import List, Optional
from Data.Engine.repositories.sqlite.credential_repository import SQLiteCredentialRepository
__all__ = ["CredentialService"]
class CredentialService:
    """Read-only facade over stored credentials, delegating to the SQLite repository."""

    def __init__(
        self,
        repository: SQLiteCredentialRepository,
        *,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        self._repo = repository
        self._log = logger or logging.getLogger("borealis.engine.services.credentials")

    def list_credentials(
        self,
        *,
        site_id: Optional[int] = None,
        connection_type: Optional[str] = None,
    ) -> List[dict]:
        """Return stored credentials, optionally filtered by site and/or connection type."""
        return self._repo.list_credentials(site_id=site_id, connection_type=connection_type)

View File

@@ -0,0 +1,15 @@
from .device_inventory_service import (
DeviceDescriptionError,
DeviceDetailsError,
DeviceInventoryService,
RemoteDeviceError,
)
from .device_view_service import DeviceViewService
__all__ = [
"DeviceInventoryService",
"RemoteDeviceError",
"DeviceViewService",
"DeviceDetailsError",
"DeviceDescriptionError",
]

View File

@@ -0,0 +1,511 @@
"""Mirrors the legacy device inventory HTTP behaviour."""
from __future__ import annotations
import json
import logging
import sqlite3
import time
from datetime import datetime, timezone
from collections.abc import Mapping
from typing import Any, Dict, List, Optional
from Data.Engine.repositories.sqlite.device_inventory_repository import (
SQLiteDeviceInventoryRepository,
)
from Data.Engine.domain.device_auth import DeviceAuthContext, normalize_guid
from Data.Engine.domain.devices import clean_device_str, coerce_int, ts_to_human
__all__ = [
"DeviceInventoryService",
"RemoteDeviceError",
"DeviceHeartbeatError",
"DeviceDetailsError",
"DeviceDescriptionError",
]
class RemoteDeviceError(Exception):
    """Remote-device operation failure carrying a machine-readable error code."""

    def __init__(self, code: str, message: Optional[str] = None) -> None:
        # The human-readable message falls back to the code itself.
        super().__init__(message or code)
        self.code = code
class DeviceHeartbeatError(Exception):
    """Device heartbeat failure carrying a machine-readable error code."""

    def __init__(self, code: str, message: Optional[str] = None) -> None:
        # The human-readable message falls back to the code itself.
        super().__init__(message or code)
        self.code = code
class DeviceDetailsError(Exception):
    """Device-details lookup/update failure carrying a machine-readable error code."""

    def __init__(self, code: str, message: Optional[str] = None) -> None:
        # The human-readable message falls back to the code itself.
        super().__init__(message or code)
        self.code = code
class DeviceDescriptionError(Exception):
    """Device-description update failure carrying a machine-readable error code."""

    def __init__(self, code: str, message: Optional[str] = None) -> None:
        # The human-readable message falls back to the code itself.
        super().__init__(message or code)
        self.code = code
class DeviceInventoryService:
def __init__(
    self,
    repository: SQLiteDeviceInventoryRepository,
    *,
    logger: Optional[logging.Logger] = None,
) -> None:
    """Wrap *repository* with inventory-level operations; creates a default logger when none is supplied."""
    self._repo = repository
    self._log = logger or logging.getLogger("borealis.engine.services.devices")
def list_devices(self) -> List[Dict[str, object]]:
    """Return every known device row from the repository."""
    return self._repo.fetch_devices()
def list_agent_devices(self) -> List[Dict[str, object]]:
    """Return only devices backed by an agent (repository filters via only_agents)."""
    return self._repo.fetch_devices(only_agents=True)
def list_remote_devices(self, connection_type: str) -> List[Dict[str, object]]:
    """Return devices filtered to the given connection type (e.g. "ssh" or "winrm")."""
    return self._repo.fetch_devices(connection_type=connection_type)
def get_device_by_guid(self, guid: str) -> Optional[Dict[str, object]]:
    """Resolve a device row by agent GUID via its snapshot hostname.

    Returns None when no snapshot exists for the GUID or no device row
    matches the snapshot's hostname.
    """
    snapshot = self._repo.load_snapshot(guid=guid)
    if not snapshot:
        return None
    matches = self._repo.fetch_devices(hostname=snapshot.get("hostname"))
    if not matches:
        return None
    return matches[0]
def collect_agent_hash_records(self) -> List[Dict[str, object]]:
    """Collect per-agent hash records, de-duplicating devices across identities.

    Each device is keyed under its agent id, agent guid, and hostname
    (lower-cased, prefixed). Devices whose keys overlap an earlier record
    are merged into it, with the earlier non-empty values winning.
    """
    records: List[Dict[str, object]] = []
    key_to_index: Dict[str, int] = {}  # identity key -> index into records
    for device in self._repo.fetch_devices():
        summary = device.get("summary", {}) if isinstance(device, dict) else {}
        agent_id = (summary.get("agent_id") or "").strip()
        agent_guid = (summary.get("agent_guid") or "").strip()
        hostname = (summary.get("hostname") or device.get("hostname") or "").strip()
        agent_hash = (summary.get("agent_hash") or device.get("agent_hash") or "").strip()
        # Build every identity key this device can be matched under.
        keys: List[str] = []
        if agent_id:
            keys.append(f"id:{agent_id.lower()}")
        if agent_guid:
            keys.append(f"guid:{agent_guid.lower()}")
        if hostname:
            keys.append(f"host:{hostname.lower()}")
        payload = {
            "agent_id": agent_id or None,
            "agent_guid": agent_guid or None,
            "hostname": hostname or None,
            "agent_hash": agent_hash or None,
            "source": "database",
        }
        if not keys:
            # Device has no usable identity: keep it as a standalone record.
            records.append(payload)
            continue
        existing_index = None
        for key in keys:
            if key in key_to_index:
                existing_index = key_to_index[key]
                break
        if existing_index is None:
            existing_index = len(records)
            records.append(payload)
            for key in keys:
                key_to_index[key] = existing_index
            continue
        # Merge into the previously-seen record, filling only missing fields.
        # NOTE(review): the merged device's extra keys are not registered in
        # key_to_index — confirm that is intentional.
        merged = records[existing_index]
        for key in ("agent_id", "agent_guid", "hostname", "agent_hash"):
            if not merged.get(key) and payload.get(key):
                merged[key] = payload[key]
    return records
    def upsert_remote_device(
        self,
        connection_type: str,
        hostname: str,
        address: Optional[str],
        description: Optional[str],
        os_hint: Optional[str],
        *,
        ensure_existing_type: Optional[str],
    ) -> Dict[str, object]:
        """Create or update a remote (SSH/WinRM) device record.

        Parameters:
            connection_type: e.g. "ssh" or "winrm"; normalised to lower case.
            hostname: target host; must be non-empty after stripping.
            address: connection endpoint; falls back to the stored endpoint.
            description: new description, or ``None`` to keep the stored one.
            os_hint: OS label, or ``None`` to keep the stored one.
            ensure_existing_type: when set, the device must already exist with
                exactly this connection type (edit flows); when ``None``, an
                existing device with a different type is a conflict.

        Returns:
            The freshly re-fetched device row after the upsert.

        Raises:
            RemoteDeviceError: for validation failures, type conflicts or
                storage errors.
        """
        normalized_type = (connection_type or "").strip().lower()
        if not normalized_type:
            raise RemoteDeviceError("invalid_type", "connection type required")
        normalized_host = (hostname or "").strip()
        if not normalized_host:
            raise RemoteDeviceError("invalid_hostname", "hostname is required")
        existing = self._repo.load_snapshot(hostname=normalized_host)
        existing_type = (existing or {}).get("summary", {}).get("connection_type") or ""
        existing_type = existing_type.strip().lower()
        if ensure_existing_type and existing_type != ensure_existing_type.lower():
            raise RemoteDeviceError("not_found", "device not found")
        if ensure_existing_type is None and existing_type and existing_type != normalized_type:
            raise RemoteDeviceError("conflict", "device already exists with different connection type")
        # Preserve the original creation timestamp across updates.
        created_ts = None
        if existing:
            created_ts = existing.get("summary", {}).get("created_at")
        endpoint = (address or "").strip() or (existing or {}).get("summary", {}).get("connection_endpoint") or ""
        if not endpoint:
            raise RemoteDeviceError("address_required", "address is required")
        description_val = description if description is not None else (existing or {}).get("summary", {}).get("description")
        os_value = os_hint or (existing or {}).get("summary", {}).get("operating_system")
        os_value = (os_value or "").strip()
        device_type_label = "SSH Remote" if normalized_type == "ssh" else "WinRM Remote"
        # Remote devices have no live agent, so the endpoint doubles as both
        # internal and external address and last_seen starts at 0.
        summary_payload = {
            "connection_type": normalized_type,
            "connection_endpoint": endpoint,
            "internal_ip": endpoint,
            "external_ip": endpoint,
            "device_type": device_type_label,
            "operating_system": os_value or "",
            "last_seen": 0,
            "description": (description_val or ""),
        }
        try:
            self._repo.upsert_device(
                normalized_host,
                description_val,
                {"summary": summary_payload},
                created_ts,
            )
        except sqlite3.DatabaseError as exc:  # type: ignore[name-defined]
            raise RemoteDeviceError("storage_error", str(exc)) from exc
        except Exception as exc:  # pragma: no cover - defensive
            raise RemoteDeviceError("storage_error", str(exc)) from exc
        devices = self._repo.fetch_devices(hostname=normalized_host)
        if not devices:
            raise RemoteDeviceError("reload_failed", "failed to load device after upsert")
        return devices[0]
def delete_remote_device(self, connection_type: str, hostname: str) -> None:
normalized_host = (hostname or "").strip()
if not normalized_host:
raise RemoteDeviceError("invalid_hostname", "invalid hostname")
existing = self._repo.load_snapshot(hostname=normalized_host)
if not existing:
raise RemoteDeviceError("not_found", "device not found")
existing_type = (existing.get("summary", {}) or {}).get("connection_type") or ""
if (existing_type or "").strip().lower() != (connection_type or "").strip().lower():
raise RemoteDeviceError("not_found", "device not found")
self._repo.delete_device_by_hostname(normalized_host)
# ------------------------------------------------------------------
# Agent heartbeats
# ------------------------------------------------------------------
    def record_heartbeat(
        self,
        *,
        context: DeviceAuthContext,
        payload: Mapping[str, Any],
    ) -> None:
        """Persist an agent heartbeat for the authenticated device.

        Updates ``last_seen`` and merges any hostname/metrics/inventory data
        carried by the heartbeat into the stored device snapshot.

        Parameters:
            context: authenticated device identity (source of the GUID).
            payload: heartbeat body; may contain ``hostname``, ``metrics``
                and ``inventory`` sections.

        Raises:
            DeviceHeartbeatError: when the device is unknown or the merged
                snapshot cannot be persisted.
        """
        guid = context.identity.guid.value
        snapshot = self._repo.load_snapshot(guid=guid)
        if not snapshot:
            raise DeviceHeartbeatError("device_not_registered", "device not registered")
        summary = dict(snapshot.get("summary") or {})
        details = dict(snapshot.get("details") or {})
        now_ts = int(time.time())
        summary["last_seen"] = now_ts
        summary["agent_guid"] = guid
        # Hostname precedence: explicit payload value, then metrics, then the
        # previously stored hostname; a GUID-derived placeholder as last resort.
        existing_hostname = clean_device_str(summary.get("hostname")) or clean_device_str(
            snapshot.get("hostname")
        )
        incoming_hostname = clean_device_str(payload.get("hostname"))
        raw_metrics = payload.get("metrics")
        metrics = raw_metrics if isinstance(raw_metrics, Mapping) else {}
        metrics_hostname = clean_device_str(metrics.get("hostname")) if metrics else None
        hostname = incoming_hostname or metrics_hostname or existing_hostname
        if not hostname:
            hostname = f"RECOVERED-{guid[:12]}"
        summary["hostname"] = hostname
        # Metrics fields overwrite the summary only when present and clean.
        if metrics:
            last_user = metrics.get("last_user") or metrics.get("username") or metrics.get("user")
            if last_user:
                cleaned_user = clean_device_str(last_user)
                if cleaned_user:
                    summary["last_user"] = cleaned_user
            operating_system = metrics.get("operating_system")
            if operating_system:
                cleaned_os = clean_device_str(operating_system)
                if cleaned_os:
                    summary["operating_system"] = cleaned_os
            uptime = metrics.get("uptime")
            if uptime is not None:
                coerced = coerce_int(uptime)
                if coerced is not None:
                    summary["uptime"] = coerced
            agent_id = metrics.get("agent_id")
            if agent_id:
                cleaned_agent = clean_device_str(agent_id)
                if cleaned_agent:
                    summary["agent_id"] = cleaned_agent
        # Top-level network/type hints likewise only overwrite when non-empty.
        for field in ("external_ip", "internal_ip", "device_type"):
            value = payload.get(field)
            cleaned = clean_device_str(value)
            if cleaned:
                summary[field] = cleaned
        summary.setdefault("description", summary.get("description") or "")
        # Keep the earliest known creation timestamp; default to "now" for
        # records that never had one.
        created_at = coerce_int(summary.get("created_at"))
        if created_at is None:
            created_at = coerce_int(snapshot.get("created_at"))
        if created_at is None:
            created_at = now_ts
        summary["created_at"] = created_at
        # Inventory sections replace stored data only when the heartbeat
        # carries a value of the expected shape; otherwise keep what we had.
        raw_inventory = payload.get("inventory")
        inventory = raw_inventory if isinstance(raw_inventory, Mapping) else {}
        memory = inventory.get("memory") if isinstance(inventory.get("memory"), list) else details.get("memory")
        network = inventory.get("network") if isinstance(inventory.get("network"), list) else details.get("network")
        software = (
            inventory.get("software") if isinstance(inventory.get("software"), list) else details.get("software")
        )
        storage = inventory.get("storage") if isinstance(inventory.get("storage"), list) else details.get("storage")
        cpu = inventory.get("cpu") if isinstance(inventory.get("cpu"), Mapping) else details.get("cpu")
        merged_details: Dict[str, Any] = {
            "summary": summary,
            "memory": memory,
            "network": network,
            "software": software,
            "storage": storage,
            "cpu": cpu,
        }
        try:
            self._repo.upsert_device(
                summary["hostname"],
                summary.get("description"),
                merged_details,
                summary.get("created_at"),
                agent_hash=clean_device_str(summary.get("agent_hash")),
                guid=guid,
            )
        except sqlite3.IntegrityError as exc:
            # Most likely a hostname/GUID uniqueness collision.
            self._log.warning(
                "device-heartbeat-conflict guid=%s hostname=%s error=%s",
                guid,
                summary["hostname"],
                exc,
            )
            raise DeviceHeartbeatError("storage_conflict", str(exc)) from exc
        except Exception as exc:  # pragma: no cover - defensive
            self._log.exception(
                "device-heartbeat-failure guid=%s hostname=%s",
                guid,
                summary["hostname"],
                exc_info=exc,
            )
            raise DeviceHeartbeatError("storage_error", "failed to persist heartbeat") from exc
# ------------------------------------------------------------------
# Agent details
# ------------------------------------------------------------------
    @staticmethod
    def _is_empty(value: Any) -> bool:
        # Equality-based "no data" check: None and empty str/list/dict count
        # as empty; 0 and False do NOT (neither compares equal to any member).
        return value in (None, "", [], {})
    @classmethod
    def _deep_merge_preserve(cls, prev: Dict[str, Any], incoming: Dict[str, Any]) -> Dict[str, Any]:
        """Merge *incoming* over *prev* without losing previously stored data.

        Rules:
          * nested mappings are merged recursively;
          * lists replace the stored value only when non-empty;
          * scalars replace only when not "empty" (see ``_is_empty``), so an
            upload that omits a field never erases known data.

        Returns a new dict; *prev* is not mutated at the top level.
        """
        merged: Dict[str, Any] = dict(prev or {})
        for key, value in (incoming or {}).items():
            if isinstance(value, Mapping):
                existing = merged.get(key)
                if not isinstance(existing, Mapping):
                    existing = {}
                merged[key] = cls._deep_merge_preserve(dict(existing), dict(value))
            elif isinstance(value, list):
                if value:
                    merged[key] = value
            else:
                if cls._is_empty(value):
                    continue
                merged[key] = value
        return merged
    def save_agent_details(
        self,
        *,
        context: DeviceAuthContext,
        payload: Mapping[str, Any],
    ) -> None:
        """Persist a full agent-details upload for the authenticated device.

        Validates the payload, guards against GUID/fingerprint mismatches,
        deep-merges the incoming details over the stored snapshot (never
        erasing known data with empty fields), and records the device's key
        fingerprint.

        Raises:
            DeviceDetailsError: for invalid payloads, identity mismatches or
                storage failures.
        """
        hostname = clean_device_str(payload.get("hostname"))
        details_raw = payload.get("details")
        agent_id = clean_device_str(payload.get("agent_id"))
        agent_hash = clean_device_str(payload.get("agent_hash"))
        if not isinstance(details_raw, Mapping):
            raise DeviceDetailsError("invalid_payload", "details object required")
        details_dict: Dict[str, Any]
        # Round-trip through JSON to deep-copy and shed exotic mapping types.
        try:
            details_dict = json.loads(json.dumps(details_raw))
        except Exception:
            details_dict = dict(details_raw)
        incoming_summary = dict(details_dict.get("summary") or {})
        if not hostname:
            hostname = clean_device_str(incoming_summary.get("hostname"))
        if not hostname:
            raise DeviceDetailsError("invalid_payload", "hostname required")
        snapshot = self._repo.load_snapshot(hostname=hostname)
        if not snapshot:
            snapshot = {}
        previous_details = snapshot.get("details")
        if isinstance(previous_details, Mapping):
            try:
                prev_details = json.loads(json.dumps(previous_details))
            except Exception:
                prev_details = dict(previous_details)
        else:
            prev_details = {}
        prev_summary = dict(prev_details.get("summary") or {})
        # Identity guards: any stored GUID/fingerprint must match the
        # authenticated caller before the upload is accepted.
        existing_guid = clean_device_str(snapshot.get("guid") or snapshot.get("summary", {}).get("agent_guid"))
        normalized_existing_guid = normalize_guid(existing_guid)
        auth_guid = context.identity.guid.value
        if normalized_existing_guid and normalized_existing_guid != auth_guid:
            raise DeviceDetailsError("guid_mismatch", "device guid mismatch")
        fingerprint = context.identity.fingerprint.value.lower()
        stored_fp = clean_device_str(snapshot.get("summary", {}).get("ssl_key_fingerprint"))
        if stored_fp and stored_fp.lower() != fingerprint:
            raise DeviceDetailsError("fingerprint_mismatch", "device fingerprint mismatch")
        # Stamp identity fields onto the incoming summary before merging.
        incoming_summary.setdefault("hostname", hostname)
        if agent_id and not incoming_summary.get("agent_id"):
            incoming_summary["agent_id"] = agent_id
        if agent_hash:
            incoming_summary["agent_hash"] = agent_hash
        incoming_summary["agent_guid"] = auth_guid
        if fingerprint:
            incoming_summary["ssl_key_fingerprint"] = fingerprint
        if not incoming_summary.get("last_seen") and prev_summary.get("last_seen"):
            incoming_summary["last_seen"] = prev_summary.get("last_seen")
        details_dict["summary"] = incoming_summary
        merged_details = self._deep_merge_preserve(prev_details, details_dict)
        merged_summary = merged_details.setdefault("summary", {})
        if not merged_summary.get("last_user") and prev_summary.get("last_user"):
            merged_summary["last_user"] = prev_summary.get("last_user")
        # Creation metadata: keep the earliest known timestamp.
        created_at = coerce_int(merged_summary.get("created_at"))
        if created_at is None:
            created_at = coerce_int(snapshot.get("created_at"))
        if created_at is None:
            created_at = int(time.time())
        merged_summary["created_at"] = created_at
        if not merged_summary.get("created"):
            merged_summary["created"] = ts_to_human(created_at)
        if fingerprint:
            merged_summary["ssl_key_fingerprint"] = fingerprint
        if not merged_summary.get("key_added_at"):
            merged_summary["key_added_at"] = datetime.now(timezone.utc).isoformat()
        if merged_summary.get("token_version") is None:
            merged_summary["token_version"] = 1
        if not merged_summary.get("status") and snapshot.get("summary", {}).get("status"):
            merged_summary["status"] = snapshot.get("summary", {}).get("status")
        # Normalise the various uptime spellings so all three keys agree.
        uptime_val = merged_summary.get("uptime")
        if merged_summary.get("uptime_sec") is None and uptime_val is not None:
            coerced = coerce_int(uptime_val)
            if coerced is not None:
                merged_summary["uptime_sec"] = coerced
                merged_summary.setdefault("uptime_seconds", coerced)
        if merged_summary.get("uptime_seconds") is None and merged_summary.get("uptime_sec") is not None:
            merged_summary["uptime_seconds"] = merged_summary.get("uptime_sec")
        description = clean_device_str(merged_summary.get("description"))
        existing_description = snapshot.get("description") if snapshot else ""
        description_to_store = description if description is not None else (existing_description or "")
        existing_hash = clean_device_str(snapshot.get("agent_hash") or snapshot.get("summary", {}).get("agent_hash"))
        effective_hash = agent_hash or existing_hash
        try:
            self._repo.upsert_device(
                hostname,
                description_to_store,
                merged_details,
                created_at,
                agent_hash=effective_hash,
                guid=auth_guid,
            )
        except sqlite3.DatabaseError as exc:
            raise DeviceDetailsError("storage_error", str(exc)) from exc
        # Remember this key fingerprint for future authentications.
        added_at = merged_summary.get("key_added_at") or datetime.now(timezone.utc).isoformat()
        self._repo.record_device_fingerprint(auth_guid, fingerprint, added_at)
# ------------------------------------------------------------------
# Description management
# ------------------------------------------------------------------
    def update_device_description(self, hostname: str, description: Optional[str]) -> None:
        """Update the free-form description stored for a device.

        The description is written both to the device row and into the
        snapshot's ``summary`` so the two stay consistent.

        Raises:
            DeviceDescriptionError: for a blank hostname, an unknown device
                or a storage failure.
        """
        normalized_host = clean_device_str(hostname)
        if not normalized_host:
            raise DeviceDescriptionError("invalid_hostname", "invalid hostname")
        snapshot = self._repo.load_snapshot(hostname=normalized_host)
        if not snapshot:
            raise DeviceDescriptionError("not_found", "device not found")
        details = snapshot.get("details")
        # Deep-copy stored details via JSON so the snapshot is never mutated.
        if isinstance(details, Mapping):
            try:
                existing = json.loads(json.dumps(details))
            except Exception:
                existing = dict(details)
        else:
            existing = {}
        summary = dict(existing.get("summary") or {})
        summary["description"] = description or ""
        existing["summary"] = summary
        # Preserve creation time and identity fields across the rewrite.
        created_at = coerce_int(summary.get("created_at"))
        if created_at is None:
            created_at = coerce_int(snapshot.get("created_at"))
        if created_at is None:
            created_at = int(time.time())
        agent_hash = clean_device_str(summary.get("agent_hash") or snapshot.get("agent_hash"))
        guid = clean_device_str(summary.get("agent_guid") or snapshot.get("guid"))
        try:
            self._repo.upsert_device(
                normalized_host,
                description or (snapshot.get("description") or ""),
                existing,
                created_at,
                agent_hash=agent_hash,
                guid=guid,
            )
        except sqlite3.DatabaseError as exc:
            raise DeviceDescriptionError("storage_error", str(exc)) from exc

View File

@@ -0,0 +1,73 @@
"""Service exposing CRUD for saved device list views."""
from __future__ import annotations
import logging
from typing import List, Optional
from Data.Engine.domain.device_views import DeviceListView
from Data.Engine.repositories.sqlite.device_view_repository import SQLiteDeviceViewRepository
__all__ = ["DeviceViewService"]
class DeviceViewService:
    """CRUD facade over saved device list views.

    The reserved name "Default View" can never be created or set via rename,
    and every mutating operation validates its inputs before touching the
    repository.
    """

    def __init__(
        self,
        repository: SQLiteDeviceViewRepository,
        *,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        self._repo = repository
        self._log = logger or logging.getLogger("borealis.engine.services.device_views")

    def list_views(self) -> List[DeviceListView]:
        """Return every saved view."""
        return self._repo.list_views()

    def get_view(self, view_id: int) -> Optional[DeviceListView]:
        """Return a single view by id, or ``None`` when it does not exist."""
        return self._repo.get_view(view_id)

    def create_view(self, name: str, columns: List[str], filters: dict) -> DeviceListView:
        """Create a new view after validating a non-empty, non-reserved name."""
        trimmed = (name or "").strip()
        if not trimmed:
            raise ValueError("missing_name")
        if trimmed.lower() == "default view":
            raise ValueError("reserved")
        return self._repo.create_view(trimmed, list(columns), dict(filters))

    def update_view(
        self,
        view_id: int,
        *,
        name: Optional[str] = None,
        columns: Optional[List[str]] = None,
        filters: Optional[dict] = None,
    ) -> DeviceListView:
        """Apply a partial update; at least one field must be provided.

        Raises:
            ValueError: ``missing_name``/``reserved`` for bad names,
                ``invalid_columns``/``invalid_filters`` for malformed values,
                ``no_fields`` when nothing was supplied.
        """
        changes: dict = {}
        if name is not None:
            trimmed = (name or "").strip()
            if not trimmed:
                raise ValueError("missing_name")
            if trimmed.lower() == "default view":
                raise ValueError("reserved")
            changes["name"] = trimmed
        if columns is not None:
            if not isinstance(columns, list) or not all(isinstance(col, str) for col in columns):
                raise ValueError("invalid_columns")
            changes["columns"] = list(columns)
        if filters is not None:
            if not isinstance(filters, dict):
                raise ValueError("invalid_filters")
            changes["filters"] = dict(filters)
        if not changes:
            raise ValueError("no_fields")
        return self._repo.update_view(
            view_id,
            name=changes.get("name"),
            columns=changes.get("columns"),
            filters=changes.get("filters"),
        )

    def delete_view(self, view_id: int) -> bool:
        """Delete a view; returns True when a row was removed."""
        return self._repo.delete_view(view_id)

View File

@@ -2,20 +2,54 @@
from __future__ import annotations
from .enrollment_service import (
EnrollmentRequestResult,
EnrollmentService,
EnrollmentStatus,
EnrollmentTokenBundle,
PollingResult,
)
from Data.Engine.domain.device_enrollment import EnrollmentValidationError
from importlib import import_module
from typing import Any
# Public API of the enrollment package.  The previous revision listed
# "EnrollmentRequestResult" and "EnrollmentValidationError" twice; duplicate
# entries in __all__ are harmless at import time but confuse tooling and
# `import *` consumers, so each name now appears exactly once.
__all__ = [
    "EnrollmentAdminService",
    "EnrollmentRequestResult",
    "EnrollmentService",
    "EnrollmentStatus",
    "EnrollmentTokenBundle",
    "EnrollmentValidationError",
    "PollingResult",
]
# Lazy-attribute table: maps exported name -> (module path, attribute name).
# Consumed by the module-level __getattr__ (PEP 562) below so importing this
# package does not have to eagerly import every submodule.
_LAZY: dict[str, tuple[str, str]] = {
    "EnrollmentService": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentService"),
    "EnrollmentRequestResult": (
        "Data.Engine.services.enrollment.enrollment_service",
        "EnrollmentRequestResult",
    ),
    "EnrollmentStatus": ("Data.Engine.services.enrollment.enrollment_service", "EnrollmentStatus"),
    "EnrollmentTokenBundle": (
        "Data.Engine.services.enrollment.enrollment_service",
        "EnrollmentTokenBundle",
    ),
    "PollingResult": ("Data.Engine.services.enrollment.enrollment_service", "PollingResult"),
    "EnrollmentValidationError": (
        "Data.Engine.domain.device_enrollment",
        "EnrollmentValidationError",
    ),
    "EnrollmentAdminService": (
        "Data.Engine.services.enrollment.admin_service",
        "EnrollmentAdminService",
    ),
}
def __getattr__(name: str) -> Any:
    """PEP 562 hook: import exported names lazily from the ``_LAZY`` table."""
    try:
        module_name, attribute = _LAZY[name]
    except KeyError as exc:  # pragma: no cover - defensive
        raise AttributeError(name) from exc
    resolved = getattr(import_module(module_name), attribute)
    # Cache on the module so subsequent lookups bypass this hook entirely.
    globals()[name] = resolved
    return resolved
def __dir__() -> list[str]:  # pragma: no cover - interactive helper
    """Expose the exported names to ``dir()`` and shell completion."""
    names = set(__all__)
    return sorted(names)

View File

@@ -0,0 +1,245 @@
"""Administrative helpers for enrollment workflows."""
from __future__ import annotations
import logging
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from typing import Callable, List, Optional
from dataclasses import dataclass
from Data.Engine.domain.device_auth import DeviceGuid, normalize_guid
from Data.Engine.domain.device_enrollment import EnrollmentApprovalStatus
from Data.Engine.domain.enrollment_admin import DeviceApprovalRecord, EnrollmentCodeRecord
from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
__all__ = ["EnrollmentAdminService", "DeviceApprovalActionResult"]
@dataclass(frozen=True, slots=True)
class DeviceApprovalActionResult:
"""Outcome metadata returned after mutating an approval."""
status: str
conflict_resolution: Optional[str] = None
def to_dict(self) -> dict[str, str]:
payload = {"status": self.status}
if self.conflict_resolution:
payload["conflict_resolution"] = self.conflict_resolution
return payload
class EnrollmentAdminService:
    """Expose administrative enrollment operations."""

    # TTLs (in hours) an operator is allowed to choose for an install code.
    _VALID_TTL_HOURS = {1, 3, 6, 12, 24}

    def __init__(
        self,
        *,
        repository: SQLiteEnrollmentRepository,
        user_repository: SQLiteUserRepository,
        logger: Optional[logging.Logger] = None,
        clock: Optional[Callable[[], datetime]] = None,
    ) -> None:
        """Wire the service.

        Parameters:
            repository: enrollment persistence (install codes + approvals).
            user_repository: used to resolve operator identifiers for audit.
            logger: optional logger; a module-scoped default otherwise.
            clock: injectable UTC "now" provider (tests pass a fixed clock).
        """
        self._repository = repository
        self._users = user_repository
        self._log = logger or logging.getLogger("borealis.engine.services.enrollment_admin")
        self._clock = clock or (lambda: datetime.now(tz=timezone.utc))

    # ------------------------------------------------------------------
    # Enrollment install codes
    # ------------------------------------------------------------------
    def list_install_codes(self, *, status: Optional[str] = None) -> List[EnrollmentCodeRecord]:
        """Return install codes, optionally filtered by status, as of "now"."""
        return self._repository.list_install_codes(status=status, now=self._clock())

    def create_install_code(
        self,
        *,
        ttl_hours: int,
        max_uses: int,
        created_by: Optional[str],
    ) -> EnrollmentCodeRecord:
        """Create a new install code with a bounded TTL and use count.

        Raises:
            ValueError: when ``ttl_hours`` is not one of ``_VALID_TTL_HOURS``.
        """
        if ttl_hours not in self._VALID_TTL_HOURS:
            raise ValueError("invalid_ttl")
        normalized_max = self._normalize_max_uses(max_uses)
        now = self._clock()
        expires_at = now + timedelta(hours=ttl_hours)
        record_id = str(uuid.uuid4())
        code = self._generate_install_code()
        # Resolve the creator to a stable user identifier when possible; fall
        # back to the raw (stripped) value so the audit trail is never empty.
        created_by_identifier = None
        if created_by:
            created_by_identifier = self._users.resolve_identifier(created_by)
            if not created_by_identifier:
                created_by_identifier = created_by.strip() or None
        record = self._repository.insert_install_code(
            record_id=record_id,
            code=code,
            expires_at=expires_at,
            created_by=created_by_identifier,
            max_uses=normalized_max,
        )
        self._log.info(
            "install code created id=%s ttl=%sh max_uses=%s",
            record.record_id,
            ttl_hours,
            normalized_max,
        )
        return record

    def delete_install_code(self, record_id: str) -> bool:
        """Delete an install code if it has never been used; True on success."""
        deleted = self._repository.delete_install_code_if_unused(record_id)
        if deleted:
            self._log.info("install code deleted id=%s", record_id)
        return deleted

    # ------------------------------------------------------------------
    # Device approvals
    # ------------------------------------------------------------------
    def list_device_approvals(self, *, status: Optional[str] = None) -> List[DeviceApprovalRecord]:
        """Return device approval records, optionally filtered by status."""
        return self._repository.list_device_approvals(status=status)

    def approve_device_approval(
        self,
        record_id: str,
        *,
        actor: Optional[str],
        guid: Optional[str] = None,
        conflict_resolution: Optional[str] = None,
    ) -> DeviceApprovalActionResult:
        """Approve a pending device enrollment (see ``_set_device_approval_status``)."""
        return self._set_device_approval_status(
            record_id,
            EnrollmentApprovalStatus.APPROVED,
            actor=actor,
            guid=guid,
            conflict_resolution=conflict_resolution,
        )

    def deny_device_approval(
        self,
        record_id: str,
        *,
        actor: Optional[str],
    ) -> DeviceApprovalActionResult:
        """Deny a pending device enrollment."""
        return self._set_device_approval_status(
            record_id,
            EnrollmentApprovalStatus.DENIED,
            actor=actor,
            guid=None,
            conflict_resolution=None,
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _generate_install_code() -> str:
        # 32 hex chars grouped as XXXX-XXXX-... for readability.
        raw = secrets.token_hex(16).upper()
        return "-".join(raw[i : i + 4] for i in range(0, len(raw), 4))

    @staticmethod
    def _normalize_max_uses(value: int) -> int:
        # Clamp to [1, 10]; anything unparsable defaults to 2.
        try:
            count = int(value)
        except Exception:
            count = 2
        if count < 1:
            return 1
        if count > 10:
            return 10
        return count

    def _set_device_approval_status(
        self,
        record_id: str,
        status: EnrollmentApprovalStatus,
        *,
        actor: Optional[str],
        guid: Optional[str],
        conflict_resolution: Optional[str],
    ) -> DeviceApprovalActionResult:
        """Transition a pending approval to APPROVED or DENIED.

        On approval, hostname conflicts with an existing device are resolved:
        a matching key fingerprint auto-merges into the existing GUID; an
        explicit "overwrite" adopts the existing GUID; "coexist" keeps both;
        any other conflicting case is rejected.

        Raises:
            LookupError: unknown approval id.
            ValueError: approval not pending, invalid GUID, or a hostname
                conflict that requires an explicit resolution.
        """
        approval = self._repository.fetch_device_approval(record_id)
        if approval is None:
            raise LookupError("not_found")
        if approval.status is not EnrollmentApprovalStatus.PENDING:
            raise ValueError("approval_not_pending")
        normalized_guid = normalize_guid(guid) or (approval.guid.value if approval.guid else "")
        resolution_normalized = (conflict_resolution or "").strip().lower() or None
        fingerprint_match = False
        conflict_guid: Optional[str] = None
        if status is EnrollmentApprovalStatus.APPROVED:
            # Re-list pending approvals to obtain the conflict metadata the
            # repository computes (hostname collisions with existing devices).
            pending_records = self._repository.list_device_approvals(status="pending")
            current_record = next(
                (record for record in pending_records if record.record_id == approval.record_id),
                None,
            )
            conflict = current_record.hostname_conflict if current_record else None
            if conflict:
                conflict_guid = normalize_guid(conflict.guid)
                fingerprint_match = bool(conflict.fingerprint_match)
                if fingerprint_match:
                    # Same key fingerprint: safe to merge into the existing GUID.
                    normalized_guid = conflict_guid or normalized_guid or ""
                    if resolution_normalized is None:
                        resolution_normalized = "auto_merge_fingerprint"
                elif resolution_normalized == "overwrite":
                    normalized_guid = conflict_guid or normalized_guid or ""
                elif resolution_normalized == "coexist":
                    pass
                else:
                    # Different key, no explicit operator decision: refuse.
                    raise ValueError("conflict_resolution_required")
        if normalized_guid:
            try:
                guid_value = DeviceGuid(normalized_guid)
            except ValueError as exc:
                raise ValueError("invalid_guid") from exc
        else:
            guid_value = None
        # Attribute the action to a resolvable user identifier when possible.
        actor_identifier = None
        if actor:
            actor_identifier = self._users.resolve_identifier(actor)
            if not actor_identifier:
                actor_identifier = actor.strip() or None
        if not actor_identifier:
            actor_identifier = "system"
        self._repository.update_device_approval_status(
            approval.record_id,
            status=status,
            updated_at=self._clock(),
            approved_by=actor_identifier,
            guid=guid_value,
        )
        if status is EnrollmentApprovalStatus.APPROVED:
            self._log.info(
                "device approval %s approved resolution=%s guid=%s",
                approval.record_id,
                resolution_normalized or "",
                guid_value.value if guid_value else normalized_guid or "",
            )
        else:
            self._log.info("device approval %s denied", approval.record_id)
        return DeviceApprovalActionResult(
            status=status.value,
            conflict_resolution=resolution_normalized,
        )

View File

@@ -0,0 +1,3 @@
from .site_service import SiteService
__all__ = ["SiteService"]

View File

@@ -0,0 +1,73 @@
"""Site management service that mirrors the legacy Flask behaviour."""
from __future__ import annotations
import logging
from typing import Dict, Iterable, List, Optional
from Data.Engine.domain.sites import SiteDeviceMapping, SiteSummary
from Data.Engine.repositories.sqlite.site_repository import SQLiteSiteRepository
__all__ = ["SiteService"]
class SiteService:
    """Site management service that mirrors the legacy Flask behaviour."""

    def __init__(self, repository: SQLiteSiteRepository, *, logger: Optional[logging.Logger] = None) -> None:
        self._repo = repository
        self._log = logger or logging.getLogger("borealis.engine.services.sites")

    def list_sites(self) -> List[SiteSummary]:
        """Return every site known to the repository."""
        return self._repo.list_sites()

    def create_site(self, name: str, description: str) -> SiteSummary:
        """Create a site.

        Raises:
            ValueError: ``missing_name`` for a blank name, ``duplicate`` when
                the repository reports a name collision.
        """
        site_name = (name or "").strip()
        site_description = (description or "").strip()
        if not site_name:
            raise ValueError("missing_name")
        try:
            return self._repo.create_site(site_name, site_description)
        except ValueError as exc:
            if str(exc) == "duplicate":
                raise ValueError("duplicate") from exc
            raise

    def delete_sites(self, ids: Iterable[int]) -> int:
        """Delete sites by id; non-numeric ids are silently skipped.

        Returns the number of rows the repository removed (0 when no valid
        ids were supplied).
        """
        numeric_ids = []
        for raw_id in ids:
            try:
                numeric_ids.append(int(raw_id))
            except Exception:
                continue
        if not numeric_ids:
            return 0
        return self._repo.delete_sites(tuple(numeric_ids))

    def rename_site(self, site_id: int, new_name: str) -> SiteSummary:
        """Rename a site.

        Raises:
            ValueError: ``missing_name`` for a blank name, ``duplicate`` when
                the repository reports a name collision.
        """
        site_name = (new_name or "").strip()
        if not site_name:
            raise ValueError("missing_name")
        try:
            return self._repo.rename_site(int(site_id), site_name)
        except ValueError as exc:
            if str(exc) == "duplicate":
                raise ValueError("duplicate") from exc
            raise

    def map_devices(self, hostnames: Optional[Iterable[str]] = None) -> Dict[str, SiteDeviceMapping]:
        """Return the hostname -> site mapping for *hostnames* (or all devices)."""
        return self._repo.map_devices(hostnames)

    def assign_devices(self, site_id: int, hostnames: Iterable[str]) -> None:
        """Assign the given hostnames to a site.

        Raises:
            ValueError: ``invalid_site_id`` for a non-numeric id,
                ``invalid_hostnames`` when no usable hostname was supplied.
            LookupError: ``not_found`` when the site does not exist.
        """
        try:
            numeric_id = int(site_id)
        except Exception as exc:
            raise ValueError("invalid_site_id") from exc
        cleaned = [hn for hn in hostnames if isinstance(hn, str) and hn.strip()]
        if not cleaned:
            raise ValueError("invalid_hostnames")
        try:
            self._repo.assign_devices(numeric_id, cleaned)
        except LookupError as exc:
            if str(exc) == "not_found":
                raise LookupError("not_found") from exc
            raise

View File

@@ -2,6 +2,8 @@
from __future__ import annotations
from pathlib import Path
from Data.Engine.config.environment import load_environment
@@ -42,3 +44,48 @@ def test_static_root_env_override(tmp_path, monkeypatch):
monkeypatch.delenv("BOREALIS_STATIC_ROOT", raising=False)
monkeypatch.delenv("BOREALIS_ROOT", raising=False)
def test_static_root_falls_back_to_legacy_source(tmp_path, monkeypatch):
    """Legacy WebUI source should be served when no build assets exist."""
    webui_source = tmp_path.joinpath("Data", "Server", "WebUI")
    webui_source.mkdir(parents=True)
    (webui_source / "index.html").write_text("<html></html>", encoding="utf-8")
    monkeypatch.delenv("BOREALIS_STATIC_ROOT", raising=False)
    monkeypatch.setenv("BOREALIS_ROOT", str(tmp_path))
    settings = load_environment()
    assert settings.flask.static_root == webui_source.resolve()
    monkeypatch.delenv("BOREALIS_ROOT", raising=False)
def test_static_root_considers_runtime_copy(tmp_path, monkeypatch):
    """Runtime Server/WebUI copies should be considered when Data assets are missing."""
    runtime_copy = tmp_path.joinpath("Server", "WebUI")
    runtime_copy.mkdir(parents=True)
    (runtime_copy / "index.html").write_text("runtime", encoding="utf-8")
    monkeypatch.delenv("BOREALIS_STATIC_ROOT", raising=False)
    monkeypatch.setenv("BOREALIS_ROOT", str(tmp_path))
    settings = load_environment()
    assert settings.flask.static_root == runtime_copy.resolve()
    monkeypatch.delenv("BOREALIS_ROOT", raising=False)
def test_resolve_project_root_defaults_to_repository(monkeypatch):
    """The project root should resolve to the repository checkout."""
    monkeypatch.delenv("BOREALIS_ROOT", raising=False)
    from Data.Engine.config import environment as env_module

    repo_root = Path(env_module.__file__).resolve().parents[3]
    assert env_module._resolve_project_root() == repo_root

View File

@@ -0,0 +1,122 @@
import base64
import sqlite3
from datetime import datetime, timezone
import pytest
from Data.Engine.repositories.sqlite import connection as sqlite_connection
from Data.Engine.repositories.sqlite import migrations as sqlite_migrations
from Data.Engine.repositories.sqlite.enrollment_repository import SQLiteEnrollmentRepository
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
from Data.Engine.services.enrollment.admin_service import EnrollmentAdminService
def _build_service(tmp_path):
    """Build an EnrollmentAdminService over a fresh, fully-migrated SQLite DB.

    Returns a ``(service, connection_factory, fixed_now)`` triple; the service
    clock is pinned to ``fixed_now`` so expiry checks stay deterministic.
    """
    db_path = tmp_path / "admin.db"
    conn = sqlite3.connect(db_path)
    sqlite_migrations.apply_all(conn)
    conn.close()
    factory = sqlite_connection.connection_factory(db_path)
    enrollment_repo = SQLiteEnrollmentRepository(factory)
    user_repo = SQLiteUserRepository(factory)
    fixed_now = datetime(2024, 1, 1, tzinfo=timezone.utc)
    service = EnrollmentAdminService(
        repository=enrollment_repo,
        user_repository=user_repo,
        clock=lambda: fixed_now,
    )
    return service, factory, fixed_now
def test_create_and_list_install_codes(tmp_path):
    """Install codes can be created, listed, validated and deleted."""
    service, factory, fixed_now = _build_service(tmp_path)
    created = service.create_install_code(ttl_hours=3, max_uses=5, created_by="admin")
    assert created.code
    assert created.max_uses == 5
    assert created.status(now=fixed_now) == "active"
    listed = service.list_install_codes()
    assert any(item.record_id == created.record_id for item in listed)
    # A TTL outside the allowed set must be rejected.
    with pytest.raises(ValueError):
        service.create_install_code(ttl_hours=2, max_uses=1, created_by=None)
    # Deletion succeeds and the code no longer appears in listings.
    assert service.delete_install_code(created.record_id) is True
    assert all(item.record_id != created.record_id for item in service.list_install_codes())
def test_list_device_approvals_includes_conflict(tmp_path):
    """A pending approval claiming an existing hostname surfaces conflict data.

    Seeds a site, an active device and a pending approval that claims the same
    hostname *and* the same key fingerprint, then checks the service reports a
    fingerprint-matching conflict that does not require an operator prompt.
    """
    service, factory, fixed_now = _build_service(tmp_path)
    conn = factory()
    cur = conn.cursor()
    cur.execute(
        "INSERT INTO sites (name, description, created_at) VALUES (?, ?, ?)",
        ("HQ", "Primary site", int(fixed_now.timestamp())),
    )
    site_id = cur.lastrowid
    # Existing active device with the hostname/fingerprint the approval claims.
    cur.execute(
        """
        INSERT INTO devices (guid, hostname, created_at, last_seen, ssl_key_fingerprint, status)
        VALUES (?, ?, ?, ?, ?, 'active')
        """,
        ("11111111-1111-1111-1111-111111111111", "agent-one", int(fixed_now.timestamp()), int(fixed_now.timestamp()), "abc123",),
    )
    cur.execute(
        "INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)",
        ("agent-one", site_id, int(fixed_now.timestamp())),
    )
    now_iso = fixed_now.isoformat()
    # Pending approval for the same hostname with a matching fingerprint.
    cur.execute(
        """
        INSERT INTO device_approvals (
            id,
            approval_reference,
            guid,
            hostname_claimed,
            ssl_key_fingerprint_claimed,
            enrollment_code_id,
            status,
            client_nonce,
            server_nonce,
            created_at,
            updated_at,
            approved_by_user_id,
            agent_pubkey_der
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        (
            "approval-1",
            "REF123",
            None,
            "agent-one",
            "abc123",
            "code-1",
            "pending",
            base64.b64encode(b"client").decode(),
            base64.b64encode(b"server").decode(),
            now_iso,
            now_iso,
            None,
            b"pubkey",
        ),
    )
    conn.commit()
    conn.close()
    approvals = service.list_device_approvals()
    assert len(approvals) == 1
    record = approvals[0]
    assert record.hostname_conflict is not None
    # Matching fingerprints mean an auto-merge: no operator prompt needed.
    assert record.hostname_conflict.fingerprint_match is True
    assert record.conflict_requires_prompt is False

View File

@@ -0,0 +1,353 @@
import base64
import sqlite3
from datetime import datetime, timezone
from .test_http_auth import _login, prepared_app
def test_enrollment_codes_require_authentication(prepared_app):
    """Anonymous requests to the admin enrollment-code API are rejected."""
    response = prepared_app.test_client().get("/api/admin/enrollment-codes")
    assert response.status_code == 401
def test_enrollment_code_workflow(prepared_app):
    """A logged-in operator can create, list and delete an enrollment code."""
    client = prepared_app.test_client()
    _login(client)
    create_resp = client.post(
        "/api/admin/enrollment-codes", json={"ttl_hours": 3, "max_uses": 4}
    )
    assert create_resp.status_code == 201
    created = create_resp.get_json()
    assert created["max_uses"] == 4
    assert created["status"] == "active"
    list_resp = client.get("/api/admin/enrollment-codes")
    assert list_resp.status_code == 200
    listed = list_resp.get_json().get("codes", [])
    assert any(entry["id"] == created["id"] for entry in listed)
    delete_resp = client.delete(f"/api/admin/enrollment-codes/{created['id']}")
    assert delete_resp.status_code == 200
def test_device_approvals_listing(prepared_app, engine_settings):
    """Seed a pending approval whose claimed fingerprint matches an existing
    device, then verify the admin listing surfaces the conflict metadata."""
    client = prepared_app.test_client()
    _login(client)
    # Seed the DB directly: a site, an active device with fingerprint
    # "deadbeef", the device->site link, and a pending approval claiming
    # the same hostname + fingerprint.
    conn = sqlite3.connect(engine_settings.database.path)
    cur = conn.cursor()
    now = datetime.now(tz=timezone.utc)
    cur.execute(
        "INSERT INTO sites (name, description, created_at) VALUES (?, ?, ?)",
        ("HQ", "Primary", int(now.timestamp())),
    )
    site_id = cur.lastrowid
    cur.execute(
        """
        INSERT INTO devices (guid, hostname, created_at, last_seen, ssl_key_fingerprint, status)
        VALUES (?, ?, ?, ?, ?, 'active')
        """,
        (
            "22222222-2222-2222-2222-222222222222",
            "approval-host",
            int(now.timestamp()),
            int(now.timestamp()),
            "deadbeef",
        ),
    )
    cur.execute(
        "INSERT INTO device_sites (device_hostname, site_id, assigned_at) VALUES (?, ?, ?)",
        ("approval-host", site_id, int(now.timestamp())),
    )
    now_iso = now.isoformat()
    cur.execute(
        """
        INSERT INTO device_approvals (
            id,
            approval_reference,
            guid,
            hostname_claimed,
            ssl_key_fingerprint_claimed,
            enrollment_code_id,
            status,
            client_nonce,
            server_nonce,
            created_at,
            updated_at,
            approved_by_user_id,
            agent_pubkey_der
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        (
            "approval-http",
            "REFHTTP",
            None,
            "approval-host",
            "deadbeef",  # same fingerprint as the seeded device -> match
            "code-http",
            "pending",
            base64.b64encode(b"client").decode(),
            base64.b64encode(b"server").decode(),
            now_iso,
            now_iso,
            None,
            b"pub",
        ),
    )
    conn.commit()
    conn.close()
    resp = client.get("/api/admin/device-approvals")
    assert resp.status_code == 200
    body = resp.get_json()
    approvals = body.get("approvals", [])
    assert any(a["id"] == "approval-http" for a in approvals)
    record = next(a for a in approvals if a["id"] == "approval-http")
    # The matching fingerprint must be flagged in the conflict payload.
    assert record.get("hostname_conflict", {}).get("fingerprint_match") is True
def test_device_approval_requires_resolution(prepared_app, engine_settings):
    """An approval whose claimed fingerprint differs from the existing device
    must be rejected until the operator supplies a conflict resolution, and
    approving twice must fail with approval_not_pending."""
    client = prepared_app.test_client()
    _login(client)
    now = datetime.now(tz=timezone.utc)
    conn = sqlite3.connect(engine_settings.database.path)
    cur = conn.cursor()
    # Existing device with fingerprint "existingfp"; the approval below
    # claims the same hostname but fingerprint "newfinger" -> real conflict.
    cur.execute(
        """
        INSERT INTO devices (
            guid,
            hostname,
            created_at,
            last_seen,
            ssl_key_fingerprint,
            status
        ) VALUES (?, ?, ?, ?, ?, 'active')
        """,
        (
            "33333333-3333-3333-3333-333333333333",
            "conflict-host",
            int(now.timestamp()),
            int(now.timestamp()),
            "existingfp",
        ),
    )
    now_iso = now.isoformat()
    cur.execute(
        """
        INSERT INTO device_approvals (
            id,
            approval_reference,
            guid,
            hostname_claimed,
            ssl_key_fingerprint_claimed,
            enrollment_code_id,
            status,
            client_nonce,
            server_nonce,
            created_at,
            updated_at,
            approved_by_user_id,
            agent_pubkey_der
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        (
            "approval-conflict",
            "REF-CONFLICT",
            None,
            "conflict-host",
            "newfinger",  # differs from the device's "existingfp"
            "code-conflict",
            "pending",
            base64.b64encode(b"client").decode(),
            base64.b64encode(b"server").decode(),
            now_iso,
            now_iso,
            None,
            b"pub",
        ),
    )
    conn.commit()
    conn.close()
    # Without an explicit resolution the approval is refused.
    resp = client.post("/api/admin/device-approvals/approval-conflict/approve", json={})
    assert resp.status_code == 409
    assert resp.get_json().get("error") == "conflict_resolution_required"
    # Overwrite resolution succeeds.
    resp = client.post(
        "/api/admin/device-approvals/approval-conflict/approve",
        json={"conflict_resolution": "overwrite"},
    )
    assert resp.status_code == 200
    body = resp.get_json()
    assert body == {"status": "approved", "conflict_resolution": "overwrite"}
    # The approval row is updated with the device GUID and the approver.
    conn = sqlite3.connect(engine_settings.database.path)
    cur = conn.cursor()
    cur.execute(
        "SELECT status, guid, approved_by_user_id FROM device_approvals WHERE id = ?",
        ("approval-conflict",),
    )
    row = cur.fetchone()
    conn.close()
    assert row[0] == "approved"
    assert row[1] == "33333333-3333-3333-3333-333333333333"
    assert row[2]
    # A second approval attempt is rejected because it is no longer pending.
    resp = client.post(
        "/api/admin/device-approvals/approval-conflict/approve",
        json={"conflict_resolution": "overwrite"},
    )
    assert resp.status_code == 409
    assert resp.get_json().get("error") == "approval_not_pending"
def test_device_approval_auto_merge(prepared_app, engine_settings):
    """When the claimed fingerprint equals the existing device's, approval
    without an explicit resolution auto-merges onto that device."""
    client = prepared_app.test_client()
    _login(client)
    now = datetime.now(tz=timezone.utc)
    conn = sqlite3.connect(engine_settings.database.path)
    cur = conn.cursor()
    # Existing device and pending approval share fingerprint "deadbeef".
    cur.execute(
        """
        INSERT INTO devices (
            guid,
            hostname,
            created_at,
            last_seen,
            ssl_key_fingerprint,
            status
        ) VALUES (?, ?, ?, ?, ?, 'active')
        """,
        (
            "44444444-4444-4444-4444-444444444444",
            "merge-host",
            int(now.timestamp()),
            int(now.timestamp()),
            "deadbeef",
        ),
    )
    now_iso = now.isoformat()
    cur.execute(
        """
        INSERT INTO device_approvals (
            id,
            approval_reference,
            guid,
            hostname_claimed,
            ssl_key_fingerprint_claimed,
            enrollment_code_id,
            status,
            client_nonce,
            server_nonce,
            created_at,
            updated_at,
            approved_by_user_id,
            agent_pubkey_der
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        (
            "approval-merge",
            "REF-MERGE",
            None,
            "merge-host",
            "deadbeef",  # matching fingerprint enables the auto-merge path
            "code-merge",
            "pending",
            base64.b64encode(b"client").decode(),
            base64.b64encode(b"server").decode(),
            now_iso,
            now_iso,
            None,
            b"pub",
        ),
    )
    conn.commit()
    conn.close()
    # No conflict_resolution supplied: the server should auto-merge.
    resp = client.post("/api/admin/device-approvals/approval-merge/approve", json={})
    assert resp.status_code == 200
    body = resp.get_json()
    assert body.get("status") == "approved"
    assert body.get("conflict_resolution") == "auto_merge_fingerprint"
    # The approval is linked to the existing device's GUID.
    conn = sqlite3.connect(engine_settings.database.path)
    cur = conn.cursor()
    cur.execute(
        "SELECT guid, status FROM device_approvals WHERE id = ?",
        ("approval-merge",),
    )
    row = cur.fetchone()
    conn.close()
    assert row[1] == "approved"
    assert row[0] == "44444444-4444-4444-4444-444444444444"
def test_device_approval_deny(prepared_app, engine_settings):
    """Denying a pending approval returns {"status": "denied"} and persists
    the denied status in the device_approvals table."""
    client = prepared_app.test_client()
    _login(client)
    now = datetime.now(tz=timezone.utc)
    conn = sqlite3.connect(engine_settings.database.path)
    cur = conn.cursor()
    now_iso = now.isoformat()
    # No matching device row is needed here — denial has no merge semantics.
    cur.execute(
        """
        INSERT INTO device_approvals (
            id,
            approval_reference,
            guid,
            hostname_claimed,
            ssl_key_fingerprint_claimed,
            enrollment_code_id,
            status,
            client_nonce,
            server_nonce,
            created_at,
            updated_at,
            approved_by_user_id,
            agent_pubkey_der
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        (
            "approval-deny",
            "REF-DENY",
            None,
            "deny-host",
            "cafebabe",
            "code-deny",
            "pending",
            base64.b64encode(b"client").decode(),
            base64.b64encode(b"server").decode(),
            now_iso,
            now_iso,
            None,
            b"pub",
        ),
    )
    conn.commit()
    conn.close()
    resp = client.post("/api/admin/device-approvals/approval-deny/deny", json={})
    assert resp.status_code == 200
    assert resp.get_json() == {"status": "denied"}
    conn = sqlite3.connect(engine_settings.database.path)
    cur = conn.cursor()
    cur.execute(
        "SELECT status FROM device_approvals WHERE id = ?",
        ("approval-deny",),
    )
    row = cur.fetchone()
    conn.close()
    assert row[0] == "denied"

View File

@@ -0,0 +1,385 @@
import pytest
pytest.importorskip("jwt")
import json
import sqlite3
import time
from datetime import datetime, timezone
from pathlib import Path
from Data.Engine.config.environment import (
DatabaseSettings,
EngineSettings,
FlaskSettings,
GitHubSettings,
ServerSettings,
SocketIOSettings,
)
from Data.Engine.domain.device_auth import (
AccessTokenClaims,
DeviceAuthContext,
DeviceFingerprint,
DeviceGuid,
DeviceIdentity,
DeviceStatus,
)
from Data.Engine.interfaces.http import register_http_interfaces
from Data.Engine.repositories.sqlite import connection as sqlite_connection
from Data.Engine.repositories.sqlite import migrations as sqlite_migrations
from Data.Engine.server import create_app
from Data.Engine.services.container import build_service_container
@pytest.fixture()
def engine_settings(tmp_path: Path) -> EngineSettings:
    """Per-test EngineSettings rooted inside pytest's tmp_path.

    Creates a static root with a stub index.html (so the app can serve the
    WebUI root) and points the SQLite database at a fresh per-test file.
    apply_migrations is False: migrations are run by the prepared_app fixture.
    """
    project_root = tmp_path
    static_root = project_root / "static"
    static_root.mkdir()
    (static_root / "index.html").write_text("<html></html>", encoding="utf-8")
    database_path = project_root / "database.db"
    return EngineSettings(
        project_root=project_root,
        debug=False,
        database=DatabaseSettings(path=database_path, apply_migrations=False),
        flask=FlaskSettings(
            secret_key="test-key",
            static_root=static_root,
            cors_allowed_origins=("https://localhost",),
        ),
        socketio=SocketIOSettings(cors_allowed_origins=("https://localhost",)),
        server=ServerSettings(host="127.0.0.1", port=5000),
        github=GitHubSettings(
            default_repo="owner/repo",
            default_branch="main",
            refresh_interval_seconds=60,
            cache_root=project_root / "cache",
        ),
    )
@pytest.fixture()
def prepared_app(engine_settings: EngineSettings):
    """Fully wired Flask app: migrated SQLite DB, services, HTTP routes.

    TESTING mode is enabled so Flask propagates exceptions to the test client.
    """
    settings = engine_settings
    settings.github.cache_root.mkdir(exist_ok=True, parents=True)
    db_factory = sqlite_connection.connection_factory(settings.database.path)
    # Run migrations explicitly — engine_settings disables auto-migration.
    with sqlite_connection.connection_scope(settings.database.path) as conn:
        sqlite_migrations.apply_all(conn)
    app = create_app(settings, db_factory=db_factory)
    services = build_service_container(settings, db_factory=db_factory)
    app.extensions["engine_services"] = services
    register_http_interfaces(app, services)
    app.config.update(TESTING=True)
    return app
def _insert_device(app, guid: str, fingerprint: str, hostname: str) -> None:
    """Insert an 'active' device row directly into the app's SQLite database.

    The fingerprint is lowercased before storage — NOTE(review): presumably
    the auth layer normalizes fingerprints to lowercase; confirm against it.
    """
    db_path = Path(app.config["ENGINE_DATABASE_PATH"])
    now = int(time.time())
    with sqlite3.connect(db_path) as conn:
        conn.execute(
            """
            INSERT INTO devices (
                guid,
                hostname,
                created_at,
                last_seen,
                ssl_key_fingerprint,
                token_version,
                status,
                key_added_at
            ) VALUES (?, ?, ?, ?, ?, ?, 'active', ?)
            """,
            (
                guid,
                hostname,
                now,
                now,
                fingerprint.lower(),
                1,
                datetime.now(timezone.utc).isoformat(),
            ),
        )
        conn.commit()
def _build_context(guid: str, fingerprint: str, *, status: DeviceStatus = DeviceStatus.ACTIVE) -> DeviceAuthContext:
    """Build a DeviceAuthContext with freshly minted, 10-minute claims.

    Used to stub out device authentication in the HTTP tests.
    """
    issued = int(time.time())
    device_guid = DeviceGuid(guid)
    device_fingerprint = DeviceFingerprint(fingerprint)
    claims = AccessTokenClaims(
        subject="device",
        guid=device_guid,
        fingerprint=device_fingerprint,
        token_version=1,
        issued_at=issued,
        not_before=issued,
        expires_at=issued + 600,
        raw={"sub": "device"},
    )
    return DeviceAuthContext(
        identity=DeviceIdentity(device_guid, device_fingerprint),
        access_token="token",
        claims=claims,
        status=status,
        service_context="SYSTEM",
    )
def test_heartbeat_updates_device(prepared_app, monkeypatch):
    """A heartbeat persists metrics/inventory and bumps last_seen.

    Device auth is stubbed so the request is attributed to the pre-inserted
    device without real token verification.
    """
    client = prepared_app.test_client()
    guid = "DE305D54-75B4-431B-ADB2-EB6B9E546014"
    fingerprint = "aa:bb:cc"
    hostname = "device-heartbeat"
    _insert_device(prepared_app, guid, fingerprint, hostname)
    services = prepared_app.extensions["engine_services"]
    context = _build_context(guid, fingerprint)
    monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context)
    payload = {
        "hostname": hostname,
        "inventory": {"memory": [{"total": "16GB"}], "cpu": {"cores": 8}},
        "metrics": {"operating_system": "Windows", "last_user": "Admin", "uptime": 120},
        "external_ip": "1.2.3.4",
    }
    start = int(time.time())  # lower bound for the last_seen assertion below
    resp = client.post(
        "/api/agent/heartbeat",
        json=payload,
        headers={"Authorization": "Bearer token"},
    )
    assert resp.status_code == 200
    body = resp.get_json()
    assert body == {"status": "ok", "poll_after_ms": 15000}
    # Verify persistence directly in SQLite; inventory is stored as JSON text.
    db_path = Path(prepared_app.config["ENGINE_DATABASE_PATH"])
    with sqlite3.connect(db_path) as conn:
        row = conn.execute(
            "SELECT last_seen, external_ip, memory, cpu FROM devices WHERE guid = ?",
            (guid,),
        ).fetchone()
    assert row is not None
    last_seen, external_ip, memory_json, cpu_json = row
    assert last_seen >= start
    assert external_ip == "1.2.3.4"
    assert json.loads(memory_json)[0]["total"] == "16GB"
    assert json.loads(cpu_json)["cores"] == 8
def test_heartbeat_returns_404_when_device_missing(prepared_app, monkeypatch):
    """A heartbeat for an unregistered GUID reports device_not_registered."""
    client = prepared_app.test_client()
    ctx = _build_context("9E295C27-8339-40C8-AD1A-6ED95C164A4A", "11:22:33")
    services = prepared_app.extensions["engine_services"]
    monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: ctx)
    response = client.post(
        "/api/agent/heartbeat",
        json={"hostname": "missing-device"},
        headers={"Authorization": "Bearer token"},
    )
    assert response.status_code == 404
    assert response.get_json() == {"error": "device_not_registered"}
def test_script_request_reports_status_and_signing_key(prepared_app, monkeypatch):
    """Script polling returns idle + signing key for active devices, and a
    longer poll interval with 'quarantined' status for quarantined ones."""
    client = prepared_app.test_client()
    guid = "2F8D76C0-38D4-4700-B247-3E90C03A67D7"
    fingerprint = "44:55:66"
    hostname = "device-script"
    _insert_device(prepared_app, guid, fingerprint, hostname)
    services = prepared_app.extensions["engine_services"]
    context = _build_context(guid, fingerprint)
    monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context)
    class DummySigner:
        # Stand-in for the script signer: only the public-key accessor is used.
        def public_base64_spki(self) -> str:
            return "PUBKEY"
    # NOTE(review): object.__setattr__ bypasses the container's __setattr__ —
    # presumably it is a frozen dataclass; confirm.
    object.__setattr__(services, "script_signer", DummySigner())
    resp = client.post(
        "/api/agent/script/request",
        json={"guid": guid},
        headers={"Authorization": "Bearer token"},
    )
    assert resp.status_code == 200
    body = resp.get_json()
    assert body == {
        "status": "idle",
        "poll_after_ms": 30000,
        "sig_alg": "ed25519",
        "signing_key": "PUBKEY",
    }
    # Quarantined devices still get 200 but are told to back off (60s poll).
    quarantined_context = _build_context(guid, fingerprint, status=DeviceStatus.QUARANTINED)
    monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: quarantined_context)
    resp = client.post(
        "/api/agent/script/request",
        json={},
        headers={"Authorization": "Bearer token"},
    )
    assert resp.status_code == 200
    assert resp.get_json()["status"] == "quarantined"
    assert resp.get_json()["poll_after_ms"] == 60000
def test_agent_details_persists_inventory(prepared_app, monkeypatch):
    """Posting agent details persists the inventory to SQLite and the same
    data is surfaced (parsed from JSON) by the /api/devices listing."""
    client = prepared_app.test_client()
    guid = "5C9D76E4-4C5A-4A5D-9B5D-1C2E3F4A5B6C"
    fingerprint = "aa:bb:cc:dd"
    hostname = "device-details"
    _insert_device(prepared_app, guid, fingerprint, hostname)
    services = prepared_app.extensions["engine_services"]
    context = _build_context(guid, fingerprint)
    monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: context)
    payload = {
        "hostname": hostname,
        "agent_id": "AGENT-01",
        "agent_hash": "hash-value",
        "details": {
            "summary": {
                "hostname": hostname,
                "device_type": "Laptop",
                "last_user": "BUNNY-LAB\\nicole.rappe",
                "operating_system": "Windows 11",
                "description": "Primary workstation",
                "last_reboot": "2025-10-01 10:00:00",
                "uptime": 3600,
            },
            "memory": [{"slot": "DIMM0", "capacity": 17179869184}],
            "storage": [{"model": "NVMe", "size": 512}],
            "network": [{"adapter": "Ethernet", "ips": ["192.168.1.50"]}],
            "software": [{"name": "Borealis Agent", "version": "2.0"}],
            "cpu": {"name": "Intel Core i7", "logical_cores": 8, "base_clock_ghz": 3.4},
        },
    }
    resp = client.post(
        "/api/agent/details",
        json=payload,
        headers={"Authorization": "Bearer token"},
    )
    assert resp.status_code == 200
    assert resp.get_json() == {"status": "ok"}
    # Raw persistence check: inventory columns are stored as JSON text.
    db_path = Path(prepared_app.config["ENGINE_DATABASE_PATH"])
    with sqlite3.connect(db_path) as conn:
        row = conn.execute(
            """
            SELECT device_type, last_user, memory, storage, network, description
            FROM devices
            WHERE guid = ?
            """,
            (guid,),
        ).fetchone()
    assert row is not None
    device_type, last_user, memory_json, storage_json, network_json, description = row
    assert device_type == "Laptop"
    assert last_user == "BUNNY-LAB\\nicole.rappe"
    assert description == "Primary workstation"
    assert json.loads(memory_json)[0]["capacity"] == 17179869184
    assert json.loads(storage_json)[0]["model"] == "NVMe"
    assert json.loads(network_json)[0]["ips"][0] == "192.168.1.50"
    # API-level check: the listing echoes summary and parsed detail sections.
    resp = client.get("/api/devices")
    assert resp.status_code == 200
    listing = resp.get_json()
    device = next((dev for dev in listing.get("devices", []) if dev["hostname"] == hostname), None)
    assert device is not None
    summary = device["summary"]
    details = device["details"]
    assert summary["device_type"] == "Laptop"
    assert summary["last_user"] == "BUNNY-LAB\\nicole.rappe"
    assert summary["created"]
    assert summary.get("uptime_sec") == 3600
    assert details["summary"]["device_type"] == "Laptop"
    assert details["summary"]["last_reboot"] == "2025-10-01 10:00:00"
    assert details["summary"]["created"] == summary["created"]
    assert details["software"][0]["name"] == "Borealis Agent"
    assert device["storage"][0]["model"] == "NVMe"
    assert device["memory"][0]["capacity"] == 17179869184
    assert device["cpu"]["name"] == "Intel Core i7"
def test_heartbeat_preserves_last_user_from_details(prepared_app, monkeypatch):
    """A heartbeat without user info must not clobber last_user from details."""
    client = prepared_app.test_client()
    guid = "7E8F90A1-B2C3-4D5E-8F90-A1B2C3D4E5F6"
    fingerprint = "11:22:33:44"
    hostname = "device-preserve"
    _insert_device(prepared_app, guid, fingerprint, hostname)
    services = prepared_app.extensions["engine_services"]
    auth_context = _build_context(guid, fingerprint)
    monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: auth_context)
    auth_headers = {"Authorization": "Bearer token"}

    # First report last_user via the details endpoint...
    client.post(
        "/api/agent/details",
        json={
            "hostname": hostname,
            "details": {
                "summary": {"hostname": hostname, "last_user": "BUNNY-LAB\\nicole.rappe"}
            },
        },
        headers=auth_headers,
    )
    # ...then heartbeat with metrics that carry no user information.
    client.post(
        "/api/agent/heartbeat",
        json={"hostname": hostname, "metrics": {"uptime": 120}},
        headers=auth_headers,
    )

    db_path = Path(prepared_app.config["ENGINE_DATABASE_PATH"])
    with sqlite3.connect(db_path) as conn:
        row = conn.execute(
            "SELECT last_user FROM devices WHERE guid = ?",
            (guid,),
        ).fetchone()
    assert row is not None
    assert row[0] == "BUNNY-LAB\\nicole.rappe"
def test_heartbeat_uses_username_when_last_user_missing(prepared_app, monkeypatch):
    """When metrics omit last_user, the username metric populates last_user."""
    client = prepared_app.test_client()
    guid = "802A4E5F-1B2C-4D5E-8F90-A1B2C3D4E5F7"
    fingerprint = "55:66:77:88"
    hostname = "device-username"
    _insert_device(prepared_app, guid, fingerprint, hostname)
    services = prepared_app.extensions["engine_services"]
    auth_context = _build_context(guid, fingerprint)
    monkeypatch.setattr(services.device_auth, "authenticate", lambda request, path: auth_context)

    response = client.post(
        "/api/agent/heartbeat",
        json={"hostname": hostname, "metrics": {"username": "BUNNY-LAB\\alice.smith"}},
        headers={"Authorization": "Bearer token"},
    )
    assert response.status_code == 200

    db_path = Path(prepared_app.config["ENGINE_DATABASE_PATH"])
    with sqlite3.connect(db_path) as conn:
        row = conn.execute(
            "SELECT last_user FROM devices WHERE guid = ?",
            (guid,),
        ).fetchone()
    assert row is not None
    assert row[0] == "BUNNY-LAB\\alice.smith"

View File

@@ -0,0 +1,86 @@
import pytest
pytest.importorskip("flask")
from .test_http_auth import _login, prepared_app
def test_assembly_crud_flow(prepared_app, engine_settings):
    """Walk the full assembly lifecycle for the 'scripts' island:
    create folder -> create file -> list -> load -> rename -> move -> delete."""
    client = prepared_app.test_client()
    _login(client)
    # Folder creation.
    resp = client.post(
        "/api/assembly/create",
        json={"island": "scripts", "kind": "folder", "path": "Utilities"},
    )
    assert resp.status_code == 200
    # File creation inside the folder; server assigns a .json rel_path.
    resp = client.post(
        "/api/assembly/create",
        json={
            "island": "scripts",
            "kind": "file",
            "path": "Utilities/sample",
            "content": {"name": "Sample", "script": "Write-Output 'Hello'", "type": "powershell"},
        },
    )
    assert resp.status_code == 200
    body = resp.get_json()
    rel_path = body.get("rel_path")
    assert rel_path and rel_path.endswith(".json")
    # The new file appears in the island listing.
    resp = client.get("/api/assembly/list?island=scripts")
    assert resp.status_code == 200
    listing = resp.get_json()
    assert any(item["rel_path"] == rel_path for item in listing.get("items", []))
    # Loading returns the stored assembly content.
    resp = client.get(f"/api/assembly/load?island=scripts&path={rel_path}")
    assert resp.status_code == 200
    loaded = resp.get_json()
    assert loaded.get("assembly", {}).get("name") == "Sample"
    # Rename keeps the .json extension.
    resp = client.post(
        "/api/assembly/rename",
        json={
            "island": "scripts",
            "kind": "file",
            "path": rel_path,
            "new_name": "renamed",
        },
    )
    assert resp.status_code == 200
    renamed_rel = resp.get_json().get("rel_path")
    assert renamed_rel and renamed_rel.endswith(".json")
    # Move into a nested folder.
    resp = client.post(
        "/api/assembly/move",
        json={
            "island": "scripts",
            "path": renamed_rel,
            "new_path": "Utilities/Nested/renamed.json",
            "kind": "file",
        },
    )
    assert resp.status_code == 200
    # Delete and confirm it no longer appears in the listing.
    resp = client.post(
        "/api/assembly/delete",
        json={
            "island": "scripts",
            "path": "Utilities/Nested/renamed.json",
            "kind": "file",
        },
    )
    assert resp.status_code == 200
    resp = client.get("/api/assembly/list?island=scripts")
    remaining = resp.get_json().get("items", [])
    assert all(item["rel_path"] != "Utilities/Nested/renamed.json" for item in remaining)
def test_server_time_endpoint(prepared_app):
    """The public server-time endpoint exposes every documented field.

    Fix: replace ``set([...])`` with a set literal (flake8-comprehensions
    C405) — same semantics, cheaper and idiomatic. ``issubset`` deliberately
    tolerates extra keys the endpoint may add later.
    """
    client = prepared_app.test_client()
    resp = client.get("/api/server/time")
    assert resp.status_code == 200
    body = resp.get_json()
    required = {"epoch", "iso", "utc_iso", "timezone", "offset_seconds", "display"}
    assert required.issubset(body)

View File

@@ -0,0 +1,121 @@
import hashlib
from pathlib import Path
import pytest
pytest.importorskip("flask")
pytest.importorskip("jwt")
from Data.Engine.config.environment import (
DatabaseSettings,
EngineSettings,
FlaskSettings,
GitHubSettings,
ServerSettings,
SocketIOSettings,
)
from Data.Engine.interfaces.http import register_http_interfaces
from Data.Engine.repositories.sqlite import connection as sqlite_connection
from Data.Engine.repositories.sqlite import migrations as sqlite_migrations
from Data.Engine.server import create_app
from Data.Engine.services.container import build_service_container
@pytest.fixture()
def engine_settings(tmp_path: Path) -> EngineSettings:
    """Per-test EngineSettings rooted in tmp_path (auth/HTTP test module).

    Mirrors the fixture used by the agent HTTP tests: stub static root with
    an index.html, fresh SQLite file, migrations deferred to prepared_app.
    """
    project_root = tmp_path
    static_root = project_root / "static"
    static_root.mkdir()
    (static_root / "index.html").write_text("<html></html>", encoding="utf-8")
    database_path = project_root / "database.db"
    return EngineSettings(
        project_root=project_root,
        debug=False,
        database=DatabaseSettings(path=database_path, apply_migrations=False),
        flask=FlaskSettings(
            secret_key="test-key",
            static_root=static_root,
            cors_allowed_origins=("https://localhost",),
        ),
        socketio=SocketIOSettings(cors_allowed_origins=("https://localhost",)),
        server=ServerSettings(host="127.0.0.1", port=5000),
        github=GitHubSettings(
            default_repo="owner/repo",
            default_branch="main",
            refresh_interval_seconds=60,
            cache_root=project_root / "cache",
        ),
    )
@pytest.fixture()
def prepared_app(engine_settings: EngineSettings):
    """Flask app with migrated DB and registered HTTP routes, TESTING on.

    Shared (via import) with sibling HTTP test modules.
    """
    settings = engine_settings
    settings.github.cache_root.mkdir(exist_ok=True, parents=True)
    db_factory = sqlite_connection.connection_factory(settings.database.path)
    # Migrations are applied here because the settings disable auto-migration.
    with sqlite_connection.connection_scope(settings.database.path) as conn:
        sqlite_migrations.apply_all(conn)
    app = create_app(settings, db_factory=db_factory)
    services = build_service_container(settings, db_factory=db_factory)
    app.extensions["engine_services"] = services
    register_http_interfaces(app, services)
    app.config.update(TESTING=True)
    return app
def _login(client) -> dict:
    """Authenticate as the default admin and return the login response body."""
    digest = hashlib.sha512("Password".encode()).hexdigest()
    resp = client.post(
        "/api/auth/login",
        json={"username": "admin", "password_sha512": digest},
    )
    assert resp.status_code == 200
    body = resp.get_json()
    assert isinstance(body, dict)
    return body
def test_auth_me_returns_session_user(prepared_app):
    """After login, /api/auth/me reflects the session's admin identity."""
    client = prepared_app.test_client()
    _login(client)
    resp = client.get("/api/auth/me")
    assert resp.status_code == 200
    expected = {
        "username": "admin",
        "display_name": "admin",
        "role": "Admin",
    }
    assert resp.get_json() == expected
def test_auth_me_uses_token_when_session_missing(prepared_app):
    """/api/auth/me falls back to the borealis_auth cookie token when no
    Flask session cookie is present.

    Fix: Werkzeug 2.3 removed the ``server_name`` parameter from the test
    client's ``set_cookie`` (key/value are now the leading arguments), so the
    legacy keyword call raises TypeError on current versions. Try the legacy
    form first and fall back to the modern signature for compatibility with
    both Werkzeug generations.
    """
    client = prepared_app.test_client()
    login_data = _login(client)
    token = login_data.get("token")
    assert token
    # Fresh client: carries only the auth-token cookie, no session.
    other_client = prepared_app.test_client()
    try:
        other_client.set_cookie(server_name="localhost", key="borealis_auth", value=token)
    except TypeError:  # Werkzeug >= 2.3: set_cookie(key, value, *, domain=...)
        other_client.set_cookie("borealis_auth", token, domain="localhost")
    resp = other_client.get("/api/auth/me")
    assert resp.status_code == 200
    body = resp.get_json()
    assert body == {
        "username": "admin",
        "display_name": "admin",
        "role": "Admin",
    }
def test_auth_me_requires_authentication(prepared_app):
    """Without any credentials, /api/auth/me responds 401 not_authenticated."""
    anonymous = prepared_app.test_client()
    resp = anonymous.get("/api/auth/me")
    assert resp.status_code == 401
    assert resp.get_json() == {"error": "not_authenticated"}

View File

@@ -0,0 +1,151 @@
from datetime import datetime, timezone
import sqlite3
import time
import pytest
pytest.importorskip("flask")
from .test_http_auth import _login, prepared_app, engine_settings
def _ensure_admin_session(client):
    """Readability alias: log the test client in as the default admin."""
    _login(client)
def test_sites_crud_flow(prepared_app):
    """Exercise create / list / assign / map / rename / delete for sites."""
    client = prepared_app.test_client()
    _ensure_admin_session(client)

    empty = client.get("/api/sites")
    assert empty.status_code == 200
    assert empty.get_json() == {"sites": []}

    create = client.post("/api/sites", json={"name": "HQ", "description": "Primary"})
    assert create.status_code == 201
    created = create.get_json()
    assert created["name"] == "HQ"

    assert len(client.get("/api/sites").get_json()["sites"]) == 1

    assign = client.post("/api/sites/assign", json={"site_id": created["id"], "hostnames": ["device-1"]})
    assert assign.status_code == 200

    mapping = client.get("/api/sites/device_map?hostnames=device-1").get_json()["mapping"]
    assert mapping["device-1"]["site_id"] == created["id"]

    rename = client.post("/api/sites/rename", json={"id": created["id"], "new_name": "Main"})
    assert rename.status_code == 200
    assert rename.get_json()["name"] == "Main"

    delete = client.post("/api/sites/delete", json={"ids": [created["id"]]})
    assert delete.status_code == 200
    assert delete.get_json()["deleted"] == 1
def test_devices_listing(prepared_app, engine_settings):
    """A device row inserted directly into SQLite shows up in /api/devices."""
    client = prepared_app.test_client()
    _ensure_admin_session(client)
    now = datetime.now(tz=timezone.utc)
    conn = sqlite3.connect(engine_settings.database.path)
    cur = conn.cursor()
    cur.execute(
        """
        INSERT INTO devices (
            guid,
            hostname,
            description,
            created_at,
            agent_hash,
            last_seen,
            connection_type,
            connection_endpoint
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """,
        (
            "11111111-1111-1111-1111-111111111111",
            "test-device",
            "Test Device",
            int(now.timestamp()),
            "hashvalue",
            int(now.timestamp()),
            "",
            "",
        ),
    )
    conn.commit()
    conn.close()
    resp = client.get("/api/devices")
    assert resp.status_code == 200
    devices = resp.get_json()["devices"]
    assert any(device["hostname"] == "test-device" for device in devices)
def test_agent_hash_list_requires_local_request(prepared_app):
    """The agent hash list is reachable from loopback only."""
    client = prepared_app.test_client()
    _ensure_admin_session(client)

    remote = client.get("/api/agent/hash_list", environ_overrides={"REMOTE_ADDR": "203.0.113.5"})
    assert remote.status_code == 403

    local = client.get("/api/agent/hash_list", environ_overrides={"REMOTE_ADDR": "127.0.0.1"})
    assert local.status_code == 200
    assert local.get_json() == {"agents": []}
def test_credentials_list_requires_admin(prepared_app):
    """Credentials listing is 401 anonymously and empty once authenticated."""
    client = prepared_app.test_client()
    assert client.get("/api/credentials").status_code == 401
    _ensure_admin_session(client)
    authed = client.get("/api/credentials")
    assert authed.status_code == 200
    assert authed.get_json() == {"credentials": []}
def test_device_description_update(prepared_app, engine_settings):
    """POSTing a description for a hostname persists it to the devices row.

    NOTE(review): unlike the other tests in this module, no admin session is
    established before the POST — presumably this endpoint does not require
    authentication; confirm that is intentional.
    """
    client = prepared_app.test_client()
    hostname = "device-desc"
    guid = "A3D3F1E5-9B8C-4C6F-80F1-4D5E6F7A8B9C"
    now = int(time.time())
    conn = sqlite3.connect(engine_settings.database.path)
    cur = conn.cursor()
    cur.execute(
        """
        INSERT INTO devices (
            guid,
            hostname,
            description,
            created_at,
            last_seen
        ) VALUES (?, ?, '', ?, ?)
        """,
        (guid, hostname, now, now),
    )
    conn.commit()
    conn.close()
    resp = client.post(
        f"/api/device/description/{hostname}",
        json={"description": "Primary workstation"},
    )
    assert resp.status_code == 200
    assert resp.get_json() == {"status": "ok"}
    # Confirm persistence directly in SQLite.
    conn = sqlite3.connect(engine_settings.database.path)
    row = conn.execute(
        "SELECT description FROM devices WHERE hostname = ?",
        (hostname,),
    ).fetchone()
    conn.close()
    assert row is not None
    assert row[0] == "Primary workstation"

View File

@@ -0,0 +1,120 @@
"""HTTP integration tests for operator account endpoints."""
from __future__ import annotations
import hashlib
from .test_http_auth import _login, prepared_app
def test_list_users_requires_authentication(prepared_app):
    """GET /api/users without a session yields 401."""
    anonymous = prepared_app.test_client()
    assert anonymous.get("/api/users").status_code == 401
def test_list_users_returns_accounts(prepared_app):
    """An authenticated listing includes the seeded admin account."""
    client = prepared_app.test_client()
    _login(client)
    resp = client.get("/api/users")
    assert resp.status_code == 200
    body = resp.get_json()
    assert isinstance(body, dict)
    assert "users" in body
    usernames = [account["username"] for account in body["users"]]
    assert "admin" in usernames
def test_create_user_validates_payload(prepared_app):
    """User creation rejects incomplete payloads and duplicate usernames."""
    client = prepared_app.test_client()
    _login(client)

    # Missing password hash -> validation failure.
    assert client.post("/api/users", json={"username": "bob"}).status_code == 400

    body = {
        "username": "bob",
        "password_sha512": hashlib.sha512(b"pw").hexdigest(),
        "role": "User",
    }
    assert client.post("/api/users", json=body).status_code == 200
    # Re-submitting the same username must conflict.
    assert client.post("/api/users", json=body).status_code == 409
def test_delete_user_handles_edge_cases(prepared_app):
    """The sole account cannot be deleted; a secondary account can."""
    client = prepared_app.test_client()
    _login(client)

    # "admin" is the only user at this point, so deletion is refused.
    assert client.delete("/api/users/admin").status_code == 400

    client.post(
        "/api/users",
        json={
            "username": "alice",
            "password_sha512": hashlib.sha512(b"pw").hexdigest(),
            "role": "User",
        },
    )
    assert client.delete("/api/users/alice").status_code == 200
def test_delete_user_prevents_self_deletion(prepared_app):
    """Even with another account present, the session user cannot delete itself."""
    client = prepared_app.test_client()
    _login(client)
    client.post(
        "/api/users",
        json={
            "username": "charlie",
            "password_sha512": hashlib.sha512(b"pw").hexdigest(),
            "role": "User",
        },
    )
    assert client.delete("/api/users/admin").status_code == 400
def test_change_role_updates_session(prepared_app):
    """Role changes succeed for others but not for the last remaining admin."""
    client = prepared_app.test_client()
    _login(client)
    client.post(
        "/api/users",
        json={
            "username": "backup",
            "password_sha512": hashlib.sha512(b"pw").hexdigest(),
            "role": "Admin",
        },
    )
    # Demoting the secondary admin is allowed...
    assert client.post("/api/users/backup/role", json={"role": "User"}).status_code == 200
    # ...but demoting the now-sole admin is rejected.
    assert client.post("/api/users/admin/role", json={"role": "User"}).status_code == 400
def test_reset_password_requires_valid_hash(prepared_app):
    """Password resets require a well-formed SHA-512 hex digest."""
    client = prepared_app.test_client()
    _login(client)
    rejected = client.post("/api/users/admin/reset_password", json={"password_sha512": "abc"})
    assert rejected.status_code == 400
    accepted = client.post(
        "/api/users/admin/reset_password",
        json={"password_sha512": hashlib.sha512(b"new").hexdigest()},
    )
    assert accepted.status_code == 200
def test_update_mfa_returns_not_found_for_unknown_user(prepared_app):
    """Toggling MFA on a nonexistent account yields 404."""
    client = prepared_app.test_client()
    _login(client)
    resp = client.post("/api/users/missing/mfa", json={"enabled": True})
    assert resp.status_code == 404

View File

@@ -0,0 +1,191 @@
"""Tests for the operator account management service."""
from __future__ import annotations
import hashlib
import sqlite3
from pathlib import Path
from typing import Callable
import pytest
pytest.importorskip("jwt")
from Data.Engine.repositories.sqlite.connection import connection_factory
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
from Data.Engine.services.auth.operator_account_service import (
AccountNotFoundError,
CannotModifySelfError,
InvalidPasswordHashError,
InvalidRoleError,
LastAdminError,
LastUserError,
OperatorAccountService,
UsernameAlreadyExistsError,
)
def _prepare_db(path: Path) -> Callable[[], sqlite3.Connection]:
    """Create a minimal ``users`` table at *path* and return a connection factory.

    The schema mirrors only the columns the repository under test reads/writes.
    """
    conn = sqlite3.connect(path)
    conn.execute(
        """
        CREATE TABLE users (
            id TEXT PRIMARY KEY,
            username TEXT UNIQUE,
            display_name TEXT,
            password_sha512 TEXT,
            role TEXT,
            last_login INTEGER,
            created_at INTEGER,
            updated_at INTEGER,
            mfa_enabled INTEGER,
            mfa_secret TEXT
        )
        """
    )
    conn.commit()
    conn.close()
    return connection_factory(path)
def _insert_user(
    factory: Callable[[], sqlite3.Connection],
    *,
    user_id: str,
    username: str,
    password_hash: str,
    role: str = "Admin",
    mfa_enabled: int = 0,
    mfa_secret: str = "",
) -> None:
    """Insert a user row directly, bypassing the service under test.

    display_name is set equal to username; timestamps are zeroed.
    """
    conn = factory()
    conn.execute(
        """
        INSERT INTO users (
            id, username, display_name, password_sha512, role,
            last_login, created_at, updated_at, mfa_enabled, mfa_secret
        ) VALUES (?, ?, ?, ?, ?, 0, 0, 0, ?, ?)
        """,
        (user_id, username, username, password_hash, role, mfa_enabled, mfa_secret),
    )
    conn.commit()
    conn.close()
def _service(factory: Callable[[], sqlite3.Connection]) -> OperatorAccountService:
    """Wrap the connection factory in a repository and the service under test."""
    return OperatorAccountService(SQLiteUserRepository(factory))
def test_list_accounts_returns_users(tmp_path):
    """list_accounts surfaces every stored user with its role."""
    factory = _prepare_db(tmp_path / "users.db")
    _insert_user(
        factory,
        user_id="1",
        username="admin",
        password_hash=hashlib.sha512(b"password").hexdigest(),
    )
    accounts = _service(factory).list_accounts()
    assert len(accounts) == 1
    first = accounts[0]
    assert first.username == "admin"
    assert first.role == "Admin"
def test_create_account_enforces_uniqueness(tmp_path):
    """Creating a second account with the same username raises."""
    service = _service(_prepare_db(tmp_path / "users.db"))
    digest = hashlib.sha512(b"pw").hexdigest()
    service.create_account(username="admin", password_sha512=digest, role="Admin")
    with pytest.raises(UsernameAlreadyExistsError):
        service.create_account(username="admin", password_sha512=digest, role="Admin")
def test_create_account_validates_password_hash(tmp_path):
    """Anything that is not a SHA-512 hex digest is rejected up front."""
    service = _service(_prepare_db(tmp_path / "users.db"))
    with pytest.raises(InvalidPasswordHashError):
        service.create_account(username="user", password_sha512="abc", role="User")
def test_delete_account_protects_last_user(tmp_path):
    """Deleting the only remaining account is refused with LastUserError."""
    factory = _prepare_db(tmp_path / "users.db")
    _insert_user(
        factory,
        user_id="1",
        username="admin",
        password_hash=hashlib.sha512(b"pw").hexdigest(),
    )
    with pytest.raises(LastUserError):
        _service(factory).delete_account("admin")
def test_delete_account_prevents_self_deletion(tmp_path):
    """An operator may not delete the account they are acting as."""
    factory = _prepare_db(tmp_path / "users.db")
    digest = hashlib.sha512(b"pw").hexdigest()
    _insert_user(factory, user_id="1", username="admin", password_hash=digest)
    _insert_user(factory, user_id="2", username="user", password_hash=digest, role="User")
    with pytest.raises(CannotModifySelfError):
        _service(factory).delete_account("admin", actor="admin")
def test_delete_account_prevents_last_admin_removal(tmp_path):
    """Removing the only admin is refused even when other (non-admin) users remain."""
    factory = _prepare_db(tmp_path / "users.db")
    digest = hashlib.sha512(b"pw").hexdigest()
    _insert_user(factory, user_id="1", username="admin", password_hash=digest)
    _insert_user(factory, user_id="2", username="user", password_hash=digest, role="User")
    with pytest.raises(LastAdminError):
        _service(factory).delete_account("admin")
def test_change_role_demotes_only_when_valid(tmp_path):
    """Demotion works while a second admin exists; last-admin demotion and bogus roles fail."""
    factory = _prepare_db(tmp_path / "users.db")
    digest = hashlib.sha512(b"pw").hexdigest()
    _insert_user(factory, user_id="1", username="admin", password_hash=digest)
    _insert_user(factory, user_id="2", username="backup", password_hash=digest)
    service = _service(factory)
    service.change_role("backup", "User")  # one admin still left -> allowed
    with pytest.raises(LastAdminError):
        service.change_role("admin", "User")
    with pytest.raises(InvalidRoleError):
        service.change_role("admin", "invalid")
def test_reset_password_validates_hash(tmp_path):
    """reset_password rejects malformed digests but accepts a proper SHA-512 hex."""
    factory = _prepare_db(tmp_path / "users.db")
    _insert_user(
        factory,
        user_id="1",
        username="admin",
        password_hash=hashlib.sha512(b"pw").hexdigest(),
    )
    service = _service(factory)
    with pytest.raises(InvalidPasswordHashError):
        service.reset_password("admin", "abc")
    service.reset_password("admin", hashlib.sha512(b"new").hexdigest())
def test_update_mfa_raises_for_unknown_user(tmp_path):
    """Toggling MFA for a username that does not exist raises AccountNotFoundError."""
    service = _service(_prepare_db(tmp_path / "users.db"))
    with pytest.raises(AccountNotFoundError):
        service.update_mfa("missing", enabled=True, reset_secret=False)

View File

@@ -0,0 +1,63 @@
"""Tests for operator authentication builders."""
from __future__ import annotations
import pytest
from Data.Engine.builders import (
OperatorLoginRequest,
OperatorMFAVerificationRequest,
build_login_request,
build_mfa_request,
)
def test_build_login_request_uses_explicit_hash():
    """A payload already carrying a SHA-512 digest is passed through verbatim."""
    request = build_login_request({"username": "Admin", "password_sha512": "abc123"})
    assert isinstance(request, OperatorLoginRequest)
    assert (request.username, request.password_sha512) == ("Admin", "abc123")
def test_build_login_request_hashes_plain_password():
    """A plain-text password is hashed by the builder rather than kept as-is."""
    request = build_login_request({"username": "user", "password": "secret"})
    assert isinstance(request, OperatorLoginRequest)
    assert request.username == "user"
    digest = request.password_sha512
    assert digest
    assert digest != "secret"
@pytest.mark.parametrize(
    "payload",
    [
        {"password": "secret"},  # username missing entirely
        {"username": ""},        # blank username
        {"username": "user"},    # no password material at all
    ],
)
def test_build_login_request_validation(payload):
    """Incomplete credential payloads are rejected with ValueError."""
    with pytest.raises(ValueError):
        build_login_request(payload)
def test_build_mfa_request_normalizes_code():
    """Spaces and dashes are stripped from the submitted TOTP code."""
    request = build_mfa_request({"pending_token": "token", "code": "12 34-56"})
    assert isinstance(request, OperatorMFAVerificationRequest)
    assert request.pending_token == "token"
    assert request.code == "123456"
def test_build_mfa_request_requires_token_and_code():
    """Both a pending token and a plausible code are mandatory."""
    with pytest.raises(ValueError):
        build_mfa_request({"code": "123"})  # token absent
    with pytest.raises(ValueError):
        build_mfa_request({"pending_token": "token", "code": "12"})  # code too short

View File

@@ -0,0 +1,197 @@
"""Tests for the operator authentication service."""
from __future__ import annotations
import hashlib
import sqlite3
from pathlib import Path
from typing import Callable
import pytest
pyotp = pytest.importorskip("pyotp")
from Data.Engine.builders import (
OperatorLoginRequest,
OperatorMFAVerificationRequest,
)
from Data.Engine.repositories.sqlite.connection import connection_factory
from Data.Engine.repositories.sqlite.user_repository import SQLiteUserRepository
from Data.Engine.services.auth.operator_auth_service import (
InvalidCredentialsError,
InvalidMFACodeError,
OperatorAuthService,
)
def _prepare_db(path: Path) -> Callable[[], sqlite3.Connection]:
    """Lay down the ``users`` table at *path* and return a connection factory.

    Note: unlike the account-service fixture, ``username`` carries no UNIQUE
    constraint here — uniqueness is not under test in this module.
    """
    schema = """
        CREATE TABLE users (
            id TEXT PRIMARY KEY,
            username TEXT,
            display_name TEXT,
            password_sha512 TEXT,
            role TEXT,
            last_login INTEGER,
            created_at INTEGER,
            updated_at INTEGER,
            mfa_enabled INTEGER,
            mfa_secret TEXT
        )
        """
    setup = sqlite3.connect(path)
    setup.execute(schema)
    setup.commit()
    setup.close()
    return connection_factory(path)
def _insert_user(
    factory: Callable[[], sqlite3.Connection],
    *,
    user_id: str,
    username: str,
    password_hash: str,
    role: str = "Admin",
    mfa_enabled: int = 0,
    mfa_secret: str = "",
) -> None:
    """Seed a single user row; display name mirrors the username, timestamps are zero."""
    values = (user_id, username, username, password_hash, role, mfa_enabled, mfa_secret)
    db = factory()
    db.execute(
        """
        INSERT INTO users (
            id, username, display_name, password_sha512, role,
            last_login, created_at, updated_at, mfa_enabled, mfa_secret
        ) VALUES (?, ?, ?, ?, ?, 0, 0, 0, ?, ?)
        """,
        values,
    )
    db.commit()
    db.close()
def test_authenticate_success_updates_last_login(tmp_path):
    """A successful login returns the operator and stamps ``users.last_login``."""
    factory = _prepare_db(tmp_path / "auth.db")
    digest = hashlib.sha512(b"password").hexdigest()
    _insert_user(factory, user_id="1", username="admin", password_hash=digest)
    service = OperatorAuthService(SQLiteUserRepository(factory))
    outcome = service.authenticate(
        OperatorLoginRequest(username="admin", password_sha512=digest)
    )
    assert outcome.username == "admin"
    conn = factory()
    (last_login,) = conn.execute(
        "SELECT last_login FROM users WHERE username=?", ("admin",)
    ).fetchone()
    conn.close()
    assert last_login > 0
def test_authenticate_invalid_credentials(tmp_path):
    """Unknown usernames are rejected with InvalidCredentialsError."""
    repo = SQLiteUserRepository(_prepare_db(tmp_path / "auth.db"))
    attempt = OperatorLoginRequest(username="missing", password_sha512="abc")
    with pytest.raises(InvalidCredentialsError):
        OperatorAuthService(repo).authenticate(attempt)
def test_mfa_verify_flow(tmp_path):
    """Login with MFA enabled yields a 'verify' challenge that a valid TOTP completes."""
    factory = _prepare_db(tmp_path / "auth.db")
    shared_secret = pyotp.random_base32()
    digest = hashlib.sha512(b"password").hexdigest()
    _insert_user(
        factory,
        user_id="1",
        username="admin",
        password_hash=digest,
        mfa_enabled=1,
        mfa_secret=shared_secret,
    )
    service = OperatorAuthService(SQLiteUserRepository(factory))
    challenge = service.authenticate(
        OperatorLoginRequest(username="admin", password_sha512=digest)
    )
    assert challenge.stage == "verify"
    verification = OperatorMFAVerificationRequest(
        pending_token=challenge.pending_token,
        code=pyotp.TOTP(shared_secret).now(),
    )
    assert service.verify_mfa(challenge, verification).username == "admin"
def test_mfa_setup_flow_persists_secret(tmp_path):
    """MFA enabled with an empty secret triggers the setup stage and stores the new secret."""
    factory = _prepare_db(tmp_path / "auth.db")
    digest = hashlib.sha512(b"password").hexdigest()
    _insert_user(
        factory,
        user_id="1",
        username="admin",
        password_hash=digest,
        mfa_enabled=1,
        mfa_secret="",
    )
    service = OperatorAuthService(SQLiteUserRepository(factory))
    challenge = service.authenticate(
        OperatorLoginRequest(username="admin", password_sha512=digest)
    )
    assert challenge.stage == "setup"
    assert challenge.secret  # freshly generated secret offered to the operator
    verification = OperatorMFAVerificationRequest(
        pending_token=challenge.pending_token,
        code=pyotp.TOTP(challenge.secret).now(),
    )
    assert service.verify_mfa(challenge, verification).username == "admin"
    # The secret must now be persisted on the user row.
    conn = factory()
    stored = conn.execute(
        "SELECT mfa_secret FROM users WHERE username=?", ("admin",)
    ).fetchone()[0]
    conn.close()
    assert stored
def test_mfa_invalid_code_raises(tmp_path):
    """A wrong TOTP code fails verification with InvalidMFACodeError."""
    factory = _prepare_db(tmp_path / "auth.db")
    digest = hashlib.sha512(b"password").hexdigest()
    _insert_user(
        factory,
        user_id="1",
        username="admin",
        password_hash=digest,
        mfa_enabled=1,
        mfa_secret=pyotp.random_base32(),
    )
    service = OperatorAuthService(SQLiteUserRepository(factory))
    challenge = service.authenticate(
        OperatorLoginRequest(username="admin", password_sha512=digest)
    )
    bogus = OperatorMFAVerificationRequest(
        pending_token=challenge.pending_token,
        code="000000",
    )
    with pytest.raises(InvalidMFACodeError):
        service.verify_mfa(challenge, bogus)

View File

@@ -1,3 +1,4 @@
import hashlib
import sqlite3
import unittest
@@ -24,6 +25,56 @@ class MigrationTests(unittest.TestCase):
self.assertIn("scheduled_jobs", tables)
self.assertIn("scheduled_job_runs", tables)
self.assertIn("github_token", tables)
self.assertIn("users", tables)
cursor.execute(
"SELECT username, role, password_sha512 FROM users WHERE LOWER(username)=LOWER(?)",
("admin",),
)
row = cursor.fetchone()
self.assertIsNotNone(row)
if row:
self.assertEqual(row[0], "admin")
self.assertEqual(row[1].lower(), "admin")
self.assertEqual(row[2], hashlib.sha512(b"Password").hexdigest())
finally:
conn.close()
def test_ensure_default_admin_promotes_existing_user(self) -> None:
    """An existing 'admin' row keeps its password but has its role promoted."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute(
            """
            CREATE TABLE users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                display_name TEXT,
                password_sha512 TEXT,
                role TEXT,
                last_login INTEGER,
                created_at INTEGER,
                updated_at INTEGER,
                mfa_enabled INTEGER DEFAULT 0,
                mfa_secret TEXT
            )
            """
        )
        conn.execute(
            "INSERT INTO users (username, display_name, password_sha512, role) VALUES (?, ?, ?, ?)",
            ("admin", "Custom", "hash", "user"),
        )
        conn.commit()
        migrations.ensure_default_admin(conn)
        row = conn.execute(
            "SELECT role, password_sha512 FROM users WHERE LOWER(username)=LOWER(?)",
            ("admin",),
        ).fetchone()
        # Role is promoted to admin while the original password hash is preserved.
        self.assertEqual(row[0].lower(), "admin")
        self.assertEqual(row[1], "hash")
    finally:
        conn.close()