mirror of
				https://github.com/bunny-lab-io/Borealis.git
				synced 2025-10-26 17:21:58 -06:00 
			
		
		
		
	feat: secure agent auth and heartbeat endpoints
This commit is contained in:
		
							
								
								
									
										1
									
								
								Data/Server/Modules/agents/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								Data/Server/Modules/agents/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
|  | ||||
							
								
								
									
										114
									
								
								Data/Server/Modules/agents/routes.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								Data/Server/Modules/agents/routes.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,114 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import json | ||||
| import time | ||||
| from typing import Any, Callable, Dict, Optional | ||||
|  | ||||
| from flask import Blueprint, jsonify, request, g | ||||
|  | ||||
| from Modules.auth.device_auth import DeviceAuthManager, require_device_auth | ||||
| from Modules.crypto.signing import ScriptSigner | ||||
|  | ||||
|  | ||||
| def register( | ||||
|     app, | ||||
|     *, | ||||
|     db_conn_factory: Callable[[], Any], | ||||
|     auth_manager: DeviceAuthManager, | ||||
|     log: Callable[[str, str], None], | ||||
|     script_signer: ScriptSigner, | ||||
| ) -> None: | ||||
|     blueprint = Blueprint("agents", __name__) | ||||
|  | ||||
|     def _json_or_none(value) -> Optional[str]: | ||||
|         if value is None: | ||||
|             return None | ||||
|         try: | ||||
|             return json.dumps(value) | ||||
|         except Exception: | ||||
|             return None | ||||
|  | ||||
|     @blueprint.route("/api/agent/heartbeat", methods=["POST"]) | ||||
|     @require_device_auth(auth_manager) | ||||
|     def heartbeat(): | ||||
|         ctx = getattr(g, "device_auth") | ||||
|         payload = request.get_json(force=True, silent=True) or {} | ||||
|  | ||||
|         now_ts = int(time.time()) | ||||
|         updates: Dict[str, Optional[str]] = {"last_seen": now_ts} | ||||
|  | ||||
|         hostname = payload.get("hostname") | ||||
|         if isinstance(hostname, str) and hostname.strip(): | ||||
|             updates["hostname"] = hostname.strip() | ||||
|  | ||||
|         inventory = payload.get("inventory") if isinstance(payload.get("inventory"), dict) else {} | ||||
|         for key in ("memory", "network", "software", "storage", "cpu"): | ||||
|             if key in inventory and inventory[key] is not None: | ||||
|                 encoded = _json_or_none(inventory[key]) | ||||
|                 if encoded is not None: | ||||
|                     updates[key] = encoded | ||||
|  | ||||
|         metrics = payload.get("metrics") if isinstance(payload.get("metrics"), dict) else {} | ||||
|         def _maybe_str(field: str) -> Optional[str]: | ||||
|             val = metrics.get(field) | ||||
|             if isinstance(val, str): | ||||
|                 return val.strip() | ||||
|             return None | ||||
|  | ||||
|         if "last_user" in metrics and metrics["last_user"]: | ||||
|             updates["last_user"] = str(metrics["last_user"]) | ||||
|         if "operating_system" in metrics and metrics["operating_system"]: | ||||
|             updates["operating_system"] = str(metrics["operating_system"]) | ||||
|         if "uptime" in metrics and metrics["uptime"] is not None: | ||||
|             try: | ||||
|                 updates["uptime"] = int(metrics["uptime"]) | ||||
|             except Exception: | ||||
|                 pass | ||||
|         for field in ("external_ip", "internal_ip", "device_type"): | ||||
|             if field in payload and payload[field]: | ||||
|                 updates[field] = str(payload[field]) | ||||
|  | ||||
|         conn = db_conn_factory() | ||||
|         try: | ||||
|             cur = conn.cursor() | ||||
|             columns = ", ".join(f"{col} = ?" for col in updates.keys()) | ||||
|             params = list(updates.values()) | ||||
|             params.append(ctx.guid) | ||||
|             cur.execute( | ||||
|                 f"UPDATE devices SET {columns} WHERE guid = ?", | ||||
|                 params, | ||||
|             ) | ||||
|             if cur.rowcount == 0: | ||||
|                 log("server", f"heartbeat missing device record guid={ctx.guid}") | ||||
|                 return jsonify({"error": "device_not_registered"}), 404 | ||||
|             conn.commit() | ||||
|         finally: | ||||
|             conn.close() | ||||
|  | ||||
|         return jsonify({"status": "ok", "poll_after_ms": 15000}) | ||||
|  | ||||
|     @blueprint.route("/api/agent/script/request", methods=["POST"]) | ||||
|     @require_device_auth(auth_manager) | ||||
|     def script_request(): | ||||
|         ctx = getattr(g, "device_auth") | ||||
|         if ctx.status != "active": | ||||
|             return jsonify( | ||||
|                 { | ||||
|                     "status": "quarantined", | ||||
|                     "poll_after_ms": 60000, | ||||
|                     "sig_alg": "ed25519", | ||||
|                     "signing_key": script_signer.public_base64_spki(), | ||||
|                 } | ||||
|             ) | ||||
|  | ||||
|         # Placeholder: actual dispatch logic will integrate with job scheduler. | ||||
|         return jsonify( | ||||
|             { | ||||
|                 "status": "idle", | ||||
|                 "poll_after_ms": 30000, | ||||
|                 "sig_alg": "ed25519", | ||||
|                 "signing_key": script_signer.public_base64_spki(), | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     app.register_blueprint(blueprint) | ||||
							
								
								
									
										148
									
								
								Data/Server/Modules/auth/device_auth.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										148
									
								
								Data/Server/Modules/auth/device_auth.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,148 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import functools | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Callable, Dict, Optional | ||||
|  | ||||
| import jwt | ||||
| from flask import g, jsonify, request | ||||
|  | ||||
| from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class DeviceAuthContext: | ||||
|     guid: str | ||||
|     ssl_key_fingerprint: str | ||||
|     token_version: int | ||||
|     access_token: str | ||||
|     claims: Dict[str, Any] | ||||
|     dpop_jkt: Optional[str] | ||||
|     status: str | ||||
|  | ||||
|  | ||||
| class DeviceAuthError(Exception): | ||||
|     status_code = 401 | ||||
|     error_code = "unauthorized" | ||||
|  | ||||
|     def __init__(self, message: str = "unauthorized", *, status_code: Optional[int] = None): | ||||
|         super().__init__(message) | ||||
|         if status_code is not None: | ||||
|             self.status_code = status_code | ||||
|         self.message = message | ||||
|  | ||||
|  | ||||
| class DeviceAuthManager: | ||||
|     def __init__( | ||||
|         self, | ||||
|         *, | ||||
|         db_conn_factory: Callable[[], Any], | ||||
|         jwt_service, | ||||
|         dpop_validator: Optional[DPoPValidator], | ||||
|         log: Callable[[str, str], None], | ||||
|     ) -> None: | ||||
|         self._db_conn_factory = db_conn_factory | ||||
|         self._jwt_service = jwt_service | ||||
|         self._dpop_validator = dpop_validator | ||||
|         self._log = log | ||||
|  | ||||
|     def authenticate(self) -> DeviceAuthContext: | ||||
|         auth_header = request.headers.get("Authorization", "") | ||||
|         if not auth_header.startswith("Bearer "): | ||||
|             raise DeviceAuthError("missing_authorization") | ||||
|         token = auth_header[len("Bearer ") :].strip() | ||||
|         if not token: | ||||
|             raise DeviceAuthError("missing_authorization") | ||||
|  | ||||
|         try: | ||||
|             claims = self._jwt_service.decode(token) | ||||
|         except jwt.ExpiredSignatureError: | ||||
|             raise DeviceAuthError("token_expired") | ||||
|         except Exception: | ||||
|             raise DeviceAuthError("invalid_token") | ||||
|  | ||||
|         guid = str(claims.get("guid") or "").strip() | ||||
|         fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip() | ||||
|         token_version = int(claims.get("token_version") or 0) | ||||
|         if not guid or not fingerprint or token_version <= 0: | ||||
|             raise DeviceAuthError("invalid_claims") | ||||
|  | ||||
|         conn = self._db_conn_factory() | ||||
|         try: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 """ | ||||
|                 SELECT guid, ssl_key_fingerprint, token_version, status | ||||
|                   FROM devices | ||||
|                  WHERE guid = ? | ||||
|                 """, | ||||
|                 (guid,), | ||||
|             ) | ||||
|             row = cur.fetchone() | ||||
|         finally: | ||||
|             conn.close() | ||||
|  | ||||
|         if not row: | ||||
|             raise DeviceAuthError("device_not_found", status_code=403) | ||||
|  | ||||
|         db_guid, db_fp, db_token_version, status = row | ||||
|  | ||||
|         if str(db_guid or "").lower() != guid.lower(): | ||||
|             raise DeviceAuthError("device_guid_mismatch", status_code=403) | ||||
|  | ||||
|         db_fp = (db_fp or "").lower().strip() | ||||
|         if db_fp and db_fp != fingerprint: | ||||
|             raise DeviceAuthError("fingerprint_mismatch", status_code=403) | ||||
|  | ||||
|         if db_token_version and db_token_version > token_version: | ||||
|             raise DeviceAuthError("token_version_revoked", status_code=401) | ||||
|  | ||||
|         status_normalized = (status or "active").strip().lower() | ||||
|         allowed_statuses = {"active", "quarantined"} | ||||
|         if status_normalized not in allowed_statuses: | ||||
|             raise DeviceAuthError("device_revoked", status_code=403) | ||||
|         if status_normalized == "quarantined": | ||||
|             self._log("server", f"device {guid} is quarantined; limited access for {request.path}") | ||||
|  | ||||
|         dpop_jkt: Optional[str] = None | ||||
|         dpop_proof = request.headers.get("DPoP") | ||||
|         if dpop_proof: | ||||
|             if not self._dpop_validator: | ||||
|                 raise DeviceAuthError("dpop_not_supported", status_code=400) | ||||
|             try: | ||||
|                 htu = request.url | ||||
|                 dpop_jkt = self._dpop_validator.verify(request.method, htu, dpop_proof, token) | ||||
|             except DPoPReplayError: | ||||
|                 raise DeviceAuthError("dpop_replayed", status_code=400) | ||||
|             except DPoPVerificationError: | ||||
|                 raise DeviceAuthError("dpop_invalid", status_code=400) | ||||
|  | ||||
|         ctx = DeviceAuthContext( | ||||
|             guid=guid, | ||||
|             ssl_key_fingerprint=fingerprint, | ||||
|             token_version=token_version, | ||||
|             access_token=token, | ||||
|             claims=claims, | ||||
|             dpop_jkt=dpop_jkt, | ||||
|             status=status_normalized, | ||||
|         ) | ||||
|         return ctx | ||||
|  | ||||
|  | ||||
| def require_device_auth(manager: DeviceAuthManager): | ||||
|     def decorator(func): | ||||
|         @functools.wraps(func) | ||||
|         def wrapper(*args, **kwargs): | ||||
|             try: | ||||
|                 ctx = manager.authenticate() | ||||
|             except DeviceAuthError as exc: | ||||
|                 response = jsonify({"error": exc.message}) | ||||
|                 response.status_code = exc.status_code | ||||
|                 return response | ||||
|  | ||||
|             g.device_auth = ctx | ||||
|             return func(*args, **kwargs) | ||||
|  | ||||
|         return wrapper | ||||
|  | ||||
|     return decorator | ||||
							
								
								
									
										1
									
								
								Data/Server/Modules/tokens/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								Data/Server/Modules/tokens/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
|  | ||||
							
								
								
									
										125
									
								
								Data/Server/Modules/tokens/routes.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								Data/Server/Modules/tokens/routes.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,125 @@ | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| import hashlib | ||||
| import sqlite3 | ||||
| from datetime import datetime, timezone | ||||
| from typing import Callable | ||||
|  | ||||
| from flask import Blueprint, jsonify, request | ||||
|  | ||||
| from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError | ||||
|  | ||||
|  | ||||
| def register( | ||||
|     app, | ||||
|     *, | ||||
|     db_conn_factory: Callable[[], sqlite3.Connection], | ||||
|     jwt_service, | ||||
|     dpop_validator: DPoPValidator, | ||||
| ) -> None: | ||||
|     blueprint = Blueprint("tokens", __name__) | ||||
|  | ||||
|     def _hash_token(token: str) -> str: | ||||
|         return hashlib.sha256(token.encode("utf-8")).hexdigest() | ||||
|  | ||||
|     def _iso_now() -> str: | ||||
|         return datetime.now(tz=timezone.utc).isoformat() | ||||
|  | ||||
|     def _parse_iso(ts: str) -> datetime: | ||||
|         return datetime.fromisoformat(ts) | ||||
|  | ||||
|     @blueprint.route("/api/agent/token/refresh", methods=["POST"]) | ||||
|     def refresh(): | ||||
|         payload = request.get_json(force=True, silent=True) or {} | ||||
|         guid = str(payload.get("guid") or "").strip() | ||||
|         refresh_token = str(payload.get("refresh_token") or "").strip() | ||||
|  | ||||
|         if not guid or not refresh_token: | ||||
|             return jsonify({"error": "invalid_request"}), 400 | ||||
|  | ||||
|         conn = db_conn_factory() | ||||
|         try: | ||||
|             cur = conn.cursor() | ||||
|             cur.execute( | ||||
|                 """ | ||||
|                 SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at | ||||
|                   FROM refresh_tokens | ||||
|                  WHERE guid = ? | ||||
|                    AND token_hash = ? | ||||
|                 """, | ||||
|                 (guid, _hash_token(refresh_token)), | ||||
|             ) | ||||
|             row = cur.fetchone() | ||||
|             if not row: | ||||
|                 return jsonify({"error": "invalid_refresh_token"}), 401 | ||||
|  | ||||
|             record_id, row_guid, _token_hash, stored_jkt, created_at, expires_at, revoked_at = row | ||||
|             if row_guid != guid: | ||||
|                 return jsonify({"error": "invalid_refresh_token"}), 401 | ||||
|             if revoked_at: | ||||
|                 return jsonify({"error": "refresh_token_revoked"}), 401 | ||||
|             if expires_at: | ||||
|                 try: | ||||
|                     if _parse_iso(expires_at) <= datetime.now(tz=timezone.utc): | ||||
|                         return jsonify({"error": "refresh_token_expired"}), 401 | ||||
|                 except Exception: | ||||
|                     pass | ||||
|  | ||||
|             cur.execute( | ||||
|                 """ | ||||
|                 SELECT guid, ssl_key_fingerprint, token_version, status | ||||
|                   FROM devices | ||||
|                  WHERE guid = ? | ||||
|                 """, | ||||
|                 (guid,), | ||||
|             ) | ||||
|             device_row = cur.fetchone() | ||||
|             if not device_row: | ||||
|                 return jsonify({"error": "device_not_found"}), 404 | ||||
|  | ||||
|             device_guid, fingerprint, token_version, status = device_row | ||||
|             status_norm = (status or "active").strip().lower() | ||||
|             if status_norm in {"revoked", "decommissioned"}: | ||||
|                 return jsonify({"error": "device_revoked"}), 403 | ||||
|  | ||||
|             dpop_proof = request.headers.get("DPoP") | ||||
|             jkt = stored_jkt or "" | ||||
|             if dpop_proof: | ||||
|                 try: | ||||
|                     jkt = dpop_validator.verify(request.method, request.url, dpop_proof, access_token=None) | ||||
|                 except DPoPReplayError: | ||||
|                     return jsonify({"error": "dpop_replayed"}), 400 | ||||
|                 except DPoPVerificationError: | ||||
|                     return jsonify({"error": "dpop_invalid"}), 400 | ||||
|             elif stored_jkt: | ||||
|                 return jsonify({"error": "dpop_required"}), 400 | ||||
|  | ||||
|             new_access_token = jwt_service.issue_access_token( | ||||
|                 guid, | ||||
|                 fingerprint or "", | ||||
|                 token_version or 1, | ||||
|             ) | ||||
|  | ||||
|             cur.execute( | ||||
|                 """ | ||||
|                 UPDATE refresh_tokens | ||||
|                    SET last_used_at = ?, | ||||
|                        dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt) | ||||
|                  WHERE id = ? | ||||
|                 """, | ||||
|                 (_iso_now(), jkt, record_id), | ||||
|             ) | ||||
|             conn.commit() | ||||
|         finally: | ||||
|             conn.close() | ||||
|  | ||||
|         return jsonify( | ||||
|             { | ||||
|                 "access_token": new_access_token, | ||||
|                 "expires_in": 900, | ||||
|                 "token_type": "Bearer", | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     app.register_blueprint(blueprint) | ||||
| @@ -50,10 +50,14 @@ from datetime import datetime, timezone | ||||
|  | ||||
| from Modules import db_migrations | ||||
| from Modules.auth import jwt_service as jwt_service_module | ||||
| from Modules.auth.dpop import DPoPValidator | ||||
| from Modules.auth.device_auth import DeviceAuthManager | ||||
| from Modules.auth.rate_limit import SlidingWindowRateLimiter | ||||
| from Modules.crypto import certificates | ||||
| from Modules.agents import routes as agent_routes | ||||
| from Modules.crypto import certificates, signing | ||||
| from Modules.enrollment import routes as enrollment_routes | ||||
| from Modules.enrollment.nonce_store import NonceCache | ||||
| from Modules.tokens import routes as token_routes | ||||
|  | ||||
| try: | ||||
|     from cryptography.fernet import Fernet  # type: ignore | ||||
| @@ -149,9 +153,12 @@ os.environ.setdefault("BOREALIS_TLS_KEY", TLS_KEY_PATH) | ||||
| os.environ.setdefault("BOREALIS_TLS_BUNDLE", TLS_BUNDLE_PATH) | ||||
|  | ||||
| JWT_SERVICE = jwt_service_module.load_service() | ||||
| SCRIPT_SIGNER = signing.load_signer() | ||||
| IP_RATE_LIMITER = SlidingWindowRateLimiter() | ||||
| FP_RATE_LIMITER = SlidingWindowRateLimiter() | ||||
| ENROLLMENT_NONCE_CACHE = NonceCache() | ||||
| DPOP_VALIDATOR = DPoPValidator() | ||||
| DEVICE_AUTH_MANAGER: Optional[DeviceAuthManager] = None | ||||
|  | ||||
|  | ||||
| def _set_cached_github_token(token: Optional[str]) -> None: | ||||
| @@ -1248,6 +1255,14 @@ def _db_conn(): | ||||
|     return conn | ||||
|  | ||||
|  | ||||
| if DEVICE_AUTH_MANAGER is None: | ||||
|     DEVICE_AUTH_MANAGER = DeviceAuthManager( | ||||
|         db_conn_factory=_db_conn, | ||||
|         jwt_service=JWT_SERVICE, | ||||
|         dpop_validator=DPOP_VALIDATOR, | ||||
|         log=_write_service_log, | ||||
|     ) | ||||
|  | ||||
| def _update_last_login(username: str) -> None: | ||||
|     if not username: | ||||
|         return | ||||
| @@ -4836,6 +4851,21 @@ enrollment_routes.register( | ||||
|     nonce_cache=ENROLLMENT_NONCE_CACHE, | ||||
| ) | ||||
|  | ||||
| token_routes.register( | ||||
|     app, | ||||
|     db_conn_factory=_db_conn, | ||||
|     jwt_service=JWT_SERVICE, | ||||
|     dpop_validator=DPOP_VALIDATOR, | ||||
| ) | ||||
|  | ||||
| agent_routes.register( | ||||
|     app, | ||||
|     db_conn_factory=_db_conn, | ||||
|     auth_manager=DEVICE_AUTH_MANAGER, | ||||
|     log=_write_service_log, | ||||
|     script_signer=SCRIPT_SIGNER, | ||||
| ) | ||||
|  | ||||
|  | ||||
| def ensure_default_admin(): | ||||
|     """Ensure at least one admin user exists. | ||||
|   | ||||
		Reference in New Issue
	
	Block a user