diff --git a/Data/Engine/Unit_Tests/test_devices_api.py b/Data/Engine/Unit_Tests/test_devices_api.py index c77cb1b0..1270c279 100644 --- a/Data/Engine/Unit_Tests/test_devices_api.py +++ b/Data/Engine/Unit_Tests/test_devices_api.py @@ -10,6 +10,7 @@ from __future__ import annotations from typing import Any import pytest +from Data.Engine.auth import jwt_service as jwt_service_module from Data.Engine.integrations import github as github_integration from Data.Engine.services.API.devices import management as device_management @@ -24,6 +25,37 @@ def _client_with_admin_session(harness: EngineTestHarness): return client +def _device_headers() -> dict: + jwt_service = jwt_service_module.load_service() + token = jwt_service.issue_access_token( + "GUID-TEST-0001", + "ff:ff:ff", + 1, + expires_in=900, + ) + return {"Authorization": f"Bearer {token}"} + + +def _patch_repo_call(monkeypatch: pytest.MonkeyPatch, calls: dict) -> None: + class DummyResponse: + def __init__(self, status_code: int, payload: Any): + self.status_code = status_code + self._payload = payload + + def json(self) -> Any: + return self._payload + + request_exception = getattr(github_integration.requests, "RequestException", RuntimeError) + + def fake_get(url: str, headers: Any, timeout: int) -> DummyResponse: + calls["count"] += 1 + if calls["count"] == 1: + return DummyResponse(200, {"commit": {"sha": "abc123"}}) + raise request_exception("network error") + + monkeypatch.setattr(github_integration.requests, "get", fake_get) + + def test_list_devices(engine_harness: EngineTestHarness) -> None: client = engine_harness.app.test_client() response = client.get("/api/devices") @@ -103,25 +135,9 @@ def test_device_list_views_lifecycle(engine_harness: EngineTestHarness) -> None: def test_repo_current_hash_uses_cache(engine_harness: EngineTestHarness, monkeypatch: pytest.MonkeyPatch) -> None: calls = {"count": 0} - class DummyResponse: - def __init__(self, status_code: int, payload: Any): - self.status_code = status_code - self._payload = payload + _patch_repo_call(monkeypatch, calls) - def json(self) -> Any: - return self._payload - - request_exception = getattr(github_integration.requests, "RequestException", RuntimeError) - - def fake_get(url: str, headers: Any, timeout: int) -> DummyResponse: - calls["count"] += 1 - if calls["count"] == 1: - return DummyResponse(200, {"commit": {"sha": "abc123"}}) - raise request_exception("network error") - - monkeypatch.setattr(github_integration.requests, "get", fake_get) - - client = engine_harness.app.test_client() + client = _client_with_admin_session(engine_harness) first = client.get("/api/repo/current_hash?repo=test/test&branch=main") assert first.status_code == 200 assert first.get_json()["sha"] == "abc123" @@ -133,6 +149,21 @@ def test_repo_current_hash_uses_cache(engine_harness: EngineTestHarness, monkeyp assert calls["count"] == 1 +def test_repo_current_hash_allows_device_token(engine_harness: EngineTestHarness, monkeypatch: pytest.MonkeyPatch) -> None: + calls = {"count": 0} + _patch_repo_call(monkeypatch, calls) + + client = engine_harness.app.test_client() + response = client.get( + "/api/repo/current_hash?repo=test/test&branch=main", + headers=_device_headers(), + ) + assert response.status_code == 200 + payload = response.get_json() + assert payload["sha"] == "abc123" + assert calls["count"] == 1 + + def test_agent_hash_list_permissions(engine_harness: EngineTestHarness) -> None: client = engine_harness.app.test_client() forbidden = client.get("/api/agent/hash_list", environ_base={"REMOTE_ADDR": "192.0.2.10"}) diff --git a/Data/Engine/services/API/devices/management.py b/Data/Engine/services/API/devices/management.py index db83c78a..251587d5 100644 --- a/Data/Engine/services/API/devices/management.py +++ b/Data/Engine/services/API/devices/management.py @@ -20,7 +20,7 @@ # - GET /api/sites/device_map (Token Authenticated) - Provides hostname to site assignment mapping data. # - POST /api/sites/assign (Token Authenticated (Admin)) - Assigns a set of devices to a given site. # - POST /api/sites/rename (Token Authenticated (Admin)) - Renames an existing site record. -# - GET /api/repo/current_hash (Token Authenticated) - Fetches the current agent repository hash (with caching). +# - GET /api/repo/current_hash (Device or Token Authenticated) - Fetches the current agent repository hash (with caching). # - GET/POST /api/agent/hash (Device Authenticated) - Retrieves or updates an agent hash record bound to the authenticated device. # - GET /api/agent/hash_list (Token Authenticated (Admin + Loopback)) - Returns stored agent hash metadata for localhost diagnostics. # ====================================================== @@ -42,7 +42,7 @@ from flask import Blueprint, jsonify, request, session, g from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer from ....auth.guid_utils import normalize_guid -from ....auth.device_auth import require_device_auth +from ....auth.device_auth import DeviceAuthError, require_device_auth if TYPE_CHECKING: # pragma: no cover - typing aide from .. import EngineServiceAdapters @@ -419,6 +419,29 @@ class DeviceManagementService: return {"error": "unauthorized"}, 401 return None + def _require_device_or_login(self) -> Optional[Tuple[Dict[str, Any], int]]: + user = self._current_user() + if user: + return None + + manager = getattr(self.adapters, "device_auth_manager", None) + if manager is None: + return {"error": "unauthorized"}, 401 + + try: + ctx = manager.authenticate() + g.device_auth = ctx + return None + except DeviceAuthError as exc: + payload: Dict[str, Any] = {"error": exc.message} + retry_after = getattr(exc, "retry_after", None) + if retry_after: + payload["retry_after"] = retry_after + return payload, getattr(exc, "status_code", 401) or 401 + except Exception: + self.service_log("server", "/api/repo/current_hash auth failure", level="ERROR") + return {"error": "unauthorized"}, 401 + def _require_admin(self) -> Optional[Tuple[Dict[str, Any], int]]: user = self._current_user() if not user: @@ -1765,7 +1788,7 @@ def register_management(app, adapters: "EngineServiceAdapters") -> None: @blueprint.route("/api/repo/current_hash", methods=["GET"]) def _repo_current_hash(): - requirement = service._require_login() + requirement = service._require_device_or_login() if requirement: payload, status = requirement return jsonify(payload), status