mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 19:21:58 -06:00
feat: secure agent auth and heartbeat endpoints
This commit is contained in:
1
Data/Server/Modules/agents/__init__.py
Normal file
1
Data/Server/Modules/agents/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
114
Data/Server/Modules/agents/routes.py
Normal file
114
Data/Server/Modules/agents/routes.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
|
||||||
|
from flask import Blueprint, jsonify, request, g
|
||||||
|
|
||||||
|
from Modules.auth.device_auth import DeviceAuthManager, require_device_auth
|
||||||
|
from Modules.crypto.signing import ScriptSigner
|
||||||
|
|
||||||
|
|
||||||
|
def register(
|
||||||
|
app,
|
||||||
|
*,
|
||||||
|
db_conn_factory: Callable[[], Any],
|
||||||
|
auth_manager: DeviceAuthManager,
|
||||||
|
log: Callable[[str, str], None],
|
||||||
|
script_signer: ScriptSigner,
|
||||||
|
) -> None:
|
||||||
|
blueprint = Blueprint("agents", __name__)
|
||||||
|
|
||||||
|
def _json_or_none(value) -> Optional[str]:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return json.dumps(value)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
|
||||||
|
@require_device_auth(auth_manager)
|
||||||
|
def heartbeat():
|
||||||
|
ctx = getattr(g, "device_auth")
|
||||||
|
payload = request.get_json(force=True, silent=True) or {}
|
||||||
|
|
||||||
|
now_ts = int(time.time())
|
||||||
|
updates: Dict[str, Optional[str]] = {"last_seen": now_ts}
|
||||||
|
|
||||||
|
hostname = payload.get("hostname")
|
||||||
|
if isinstance(hostname, str) and hostname.strip():
|
||||||
|
updates["hostname"] = hostname.strip()
|
||||||
|
|
||||||
|
inventory = payload.get("inventory") if isinstance(payload.get("inventory"), dict) else {}
|
||||||
|
for key in ("memory", "network", "software", "storage", "cpu"):
|
||||||
|
if key in inventory and inventory[key] is not None:
|
||||||
|
encoded = _json_or_none(inventory[key])
|
||||||
|
if encoded is not None:
|
||||||
|
updates[key] = encoded
|
||||||
|
|
||||||
|
metrics = payload.get("metrics") if isinstance(payload.get("metrics"), dict) else {}
|
||||||
|
def _maybe_str(field: str) -> Optional[str]:
|
||||||
|
val = metrics.get(field)
|
||||||
|
if isinstance(val, str):
|
||||||
|
return val.strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
if "last_user" in metrics and metrics["last_user"]:
|
||||||
|
updates["last_user"] = str(metrics["last_user"])
|
||||||
|
if "operating_system" in metrics and metrics["operating_system"]:
|
||||||
|
updates["operating_system"] = str(metrics["operating_system"])
|
||||||
|
if "uptime" in metrics and metrics["uptime"] is not None:
|
||||||
|
try:
|
||||||
|
updates["uptime"] = int(metrics["uptime"])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
for field in ("external_ip", "internal_ip", "device_type"):
|
||||||
|
if field in payload and payload[field]:
|
||||||
|
updates[field] = str(payload[field])
|
||||||
|
|
||||||
|
conn = db_conn_factory()
|
||||||
|
try:
|
||||||
|
cur = conn.cursor()
|
||||||
|
columns = ", ".join(f"{col} = ?" for col in updates.keys())
|
||||||
|
params = list(updates.values())
|
||||||
|
params.append(ctx.guid)
|
||||||
|
cur.execute(
|
||||||
|
f"UPDATE devices SET {columns} WHERE guid = ?",
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
if cur.rowcount == 0:
|
||||||
|
log("server", f"heartbeat missing device record guid={ctx.guid}")
|
||||||
|
return jsonify({"error": "device_not_registered"}), 404
|
||||||
|
conn.commit()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return jsonify({"status": "ok", "poll_after_ms": 15000})
|
||||||
|
|
||||||
|
@blueprint.route("/api/agent/script/request", methods=["POST"])
|
||||||
|
@require_device_auth(auth_manager)
|
||||||
|
def script_request():
|
||||||
|
ctx = getattr(g, "device_auth")
|
||||||
|
if ctx.status != "active":
|
||||||
|
return jsonify(
|
||||||
|
{
|
||||||
|
"status": "quarantined",
|
||||||
|
"poll_after_ms": 60000,
|
||||||
|
"sig_alg": "ed25519",
|
||||||
|
"signing_key": script_signer.public_base64_spki(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Placeholder: actual dispatch logic will integrate with job scheduler.
|
||||||
|
return jsonify(
|
||||||
|
{
|
||||||
|
"status": "idle",
|
||||||
|
"poll_after_ms": 30000,
|
||||||
|
"sig_alg": "ed25519",
|
||||||
|
"signing_key": script_signer.public_base64_spki(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
app.register_blueprint(blueprint)
|
||||||
148
Data/Server/Modules/auth/device_auth.py
Normal file
148
Data/Server/Modules/auth/device_auth.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
from flask import g, jsonify, request
|
||||||
|
|
||||||
|
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeviceAuthContext:
|
||||||
|
guid: str
|
||||||
|
ssl_key_fingerprint: str
|
||||||
|
token_version: int
|
||||||
|
access_token: str
|
||||||
|
claims: Dict[str, Any]
|
||||||
|
dpop_jkt: Optional[str]
|
||||||
|
status: str
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceAuthError(Exception):
|
||||||
|
status_code = 401
|
||||||
|
error_code = "unauthorized"
|
||||||
|
|
||||||
|
def __init__(self, message: str = "unauthorized", *, status_code: Optional[int] = None):
|
||||||
|
super().__init__(message)
|
||||||
|
if status_code is not None:
|
||||||
|
self.status_code = status_code
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceAuthManager:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db_conn_factory: Callable[[], Any],
|
||||||
|
jwt_service,
|
||||||
|
dpop_validator: Optional[DPoPValidator],
|
||||||
|
log: Callable[[str, str], None],
|
||||||
|
) -> None:
|
||||||
|
self._db_conn_factory = db_conn_factory
|
||||||
|
self._jwt_service = jwt_service
|
||||||
|
self._dpop_validator = dpop_validator
|
||||||
|
self._log = log
|
||||||
|
|
||||||
|
def authenticate(self) -> DeviceAuthContext:
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header.startswith("Bearer "):
|
||||||
|
raise DeviceAuthError("missing_authorization")
|
||||||
|
token = auth_header[len("Bearer ") :].strip()
|
||||||
|
if not token:
|
||||||
|
raise DeviceAuthError("missing_authorization")
|
||||||
|
|
||||||
|
try:
|
||||||
|
claims = self._jwt_service.decode(token)
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
raise DeviceAuthError("token_expired")
|
||||||
|
except Exception:
|
||||||
|
raise DeviceAuthError("invalid_token")
|
||||||
|
|
||||||
|
guid = str(claims.get("guid") or "").strip()
|
||||||
|
fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip()
|
||||||
|
token_version = int(claims.get("token_version") or 0)
|
||||||
|
if not guid or not fingerprint or token_version <= 0:
|
||||||
|
raise DeviceAuthError("invalid_claims")
|
||||||
|
|
||||||
|
conn = self._db_conn_factory()
|
||||||
|
try:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||||
|
FROM devices
|
||||||
|
WHERE guid = ?
|
||||||
|
""",
|
||||||
|
(guid,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
raise DeviceAuthError("device_not_found", status_code=403)
|
||||||
|
|
||||||
|
db_guid, db_fp, db_token_version, status = row
|
||||||
|
|
||||||
|
if str(db_guid or "").lower() != guid.lower():
|
||||||
|
raise DeviceAuthError("device_guid_mismatch", status_code=403)
|
||||||
|
|
||||||
|
db_fp = (db_fp or "").lower().strip()
|
||||||
|
if db_fp and db_fp != fingerprint:
|
||||||
|
raise DeviceAuthError("fingerprint_mismatch", status_code=403)
|
||||||
|
|
||||||
|
if db_token_version and db_token_version > token_version:
|
||||||
|
raise DeviceAuthError("token_version_revoked", status_code=401)
|
||||||
|
|
||||||
|
status_normalized = (status or "active").strip().lower()
|
||||||
|
allowed_statuses = {"active", "quarantined"}
|
||||||
|
if status_normalized not in allowed_statuses:
|
||||||
|
raise DeviceAuthError("device_revoked", status_code=403)
|
||||||
|
if status_normalized == "quarantined":
|
||||||
|
self._log("server", f"device {guid} is quarantined; limited access for {request.path}")
|
||||||
|
|
||||||
|
dpop_jkt: Optional[str] = None
|
||||||
|
dpop_proof = request.headers.get("DPoP")
|
||||||
|
if dpop_proof:
|
||||||
|
if not self._dpop_validator:
|
||||||
|
raise DeviceAuthError("dpop_not_supported", status_code=400)
|
||||||
|
try:
|
||||||
|
htu = request.url
|
||||||
|
dpop_jkt = self._dpop_validator.verify(request.method, htu, dpop_proof, token)
|
||||||
|
except DPoPReplayError:
|
||||||
|
raise DeviceAuthError("dpop_replayed", status_code=400)
|
||||||
|
except DPoPVerificationError:
|
||||||
|
raise DeviceAuthError("dpop_invalid", status_code=400)
|
||||||
|
|
||||||
|
ctx = DeviceAuthContext(
|
||||||
|
guid=guid,
|
||||||
|
ssl_key_fingerprint=fingerprint,
|
||||||
|
token_version=token_version,
|
||||||
|
access_token=token,
|
||||||
|
claims=claims,
|
||||||
|
dpop_jkt=dpop_jkt,
|
||||||
|
status=status_normalized,
|
||||||
|
)
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
|
def require_device_auth(manager: DeviceAuthManager):
|
||||||
|
def decorator(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
ctx = manager.authenticate()
|
||||||
|
except DeviceAuthError as exc:
|
||||||
|
response = jsonify({"error": exc.message})
|
||||||
|
response.status_code = exc.status_code
|
||||||
|
return response
|
||||||
|
|
||||||
|
g.device_auth = ctx
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
1
Data/Server/Modules/tokens/__init__.py
Normal file
1
Data/Server/Modules/tokens/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
125
Data/Server/Modules/tokens/routes.py
Normal file
125
Data/Server/Modules/tokens/routes.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import sqlite3
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from flask import Blueprint, jsonify, request
|
||||||
|
|
||||||
|
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
|
||||||
|
|
||||||
|
|
||||||
|
def register(
|
||||||
|
app,
|
||||||
|
*,
|
||||||
|
db_conn_factory: Callable[[], sqlite3.Connection],
|
||||||
|
jwt_service,
|
||||||
|
dpop_validator: DPoPValidator,
|
||||||
|
) -> None:
|
||||||
|
blueprint = Blueprint("tokens", __name__)
|
||||||
|
|
||||||
|
def _hash_token(token: str) -> str:
|
||||||
|
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
def _iso_now() -> str:
|
||||||
|
return datetime.now(tz=timezone.utc).isoformat()
|
||||||
|
|
||||||
|
def _parse_iso(ts: str) -> datetime:
|
||||||
|
return datetime.fromisoformat(ts)
|
||||||
|
|
||||||
|
@blueprint.route("/api/agent/token/refresh", methods=["POST"])
|
||||||
|
def refresh():
|
||||||
|
payload = request.get_json(force=True, silent=True) or {}
|
||||||
|
guid = str(payload.get("guid") or "").strip()
|
||||||
|
refresh_token = str(payload.get("refresh_token") or "").strip()
|
||||||
|
|
||||||
|
if not guid or not refresh_token:
|
||||||
|
return jsonify({"error": "invalid_request"}), 400
|
||||||
|
|
||||||
|
conn = db_conn_factory()
|
||||||
|
try:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at
|
||||||
|
FROM refresh_tokens
|
||||||
|
WHERE guid = ?
|
||||||
|
AND token_hash = ?
|
||||||
|
""",
|
||||||
|
(guid, _hash_token(refresh_token)),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
if not row:
|
||||||
|
return jsonify({"error": "invalid_refresh_token"}), 401
|
||||||
|
|
||||||
|
record_id, row_guid, _token_hash, stored_jkt, created_at, expires_at, revoked_at = row
|
||||||
|
if row_guid != guid:
|
||||||
|
return jsonify({"error": "invalid_refresh_token"}), 401
|
||||||
|
if revoked_at:
|
||||||
|
return jsonify({"error": "refresh_token_revoked"}), 401
|
||||||
|
if expires_at:
|
||||||
|
try:
|
||||||
|
if _parse_iso(expires_at) <= datetime.now(tz=timezone.utc):
|
||||||
|
return jsonify({"error": "refresh_token_expired"}), 401
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||||
|
FROM devices
|
||||||
|
WHERE guid = ?
|
||||||
|
""",
|
||||||
|
(guid,),
|
||||||
|
)
|
||||||
|
device_row = cur.fetchone()
|
||||||
|
if not device_row:
|
||||||
|
return jsonify({"error": "device_not_found"}), 404
|
||||||
|
|
||||||
|
device_guid, fingerprint, token_version, status = device_row
|
||||||
|
status_norm = (status or "active").strip().lower()
|
||||||
|
if status_norm in {"revoked", "decommissioned"}:
|
||||||
|
return jsonify({"error": "device_revoked"}), 403
|
||||||
|
|
||||||
|
dpop_proof = request.headers.get("DPoP")
|
||||||
|
jkt = stored_jkt or ""
|
||||||
|
if dpop_proof:
|
||||||
|
try:
|
||||||
|
jkt = dpop_validator.verify(request.method, request.url, dpop_proof, access_token=None)
|
||||||
|
except DPoPReplayError:
|
||||||
|
return jsonify({"error": "dpop_replayed"}), 400
|
||||||
|
except DPoPVerificationError:
|
||||||
|
return jsonify({"error": "dpop_invalid"}), 400
|
||||||
|
elif stored_jkt:
|
||||||
|
return jsonify({"error": "dpop_required"}), 400
|
||||||
|
|
||||||
|
new_access_token = jwt_service.issue_access_token(
|
||||||
|
guid,
|
||||||
|
fingerprint or "",
|
||||||
|
token_version or 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
UPDATE refresh_tokens
|
||||||
|
SET last_used_at = ?,
|
||||||
|
dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt)
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(_iso_now(), jkt, record_id),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return jsonify(
|
||||||
|
{
|
||||||
|
"access_token": new_access_token,
|
||||||
|
"expires_in": 900,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
app.register_blueprint(blueprint)
|
||||||
@@ -50,10 +50,14 @@ from datetime import datetime, timezone
|
|||||||
|
|
||||||
from Modules import db_migrations
|
from Modules import db_migrations
|
||||||
from Modules.auth import jwt_service as jwt_service_module
|
from Modules.auth import jwt_service as jwt_service_module
|
||||||
|
from Modules.auth.dpop import DPoPValidator
|
||||||
|
from Modules.auth.device_auth import DeviceAuthManager
|
||||||
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||||
from Modules.crypto import certificates
|
from Modules.agents import routes as agent_routes
|
||||||
|
from Modules.crypto import certificates, signing
|
||||||
from Modules.enrollment import routes as enrollment_routes
|
from Modules.enrollment import routes as enrollment_routes
|
||||||
from Modules.enrollment.nonce_store import NonceCache
|
from Modules.enrollment.nonce_store import NonceCache
|
||||||
|
from Modules.tokens import routes as token_routes
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from cryptography.fernet import Fernet # type: ignore
|
from cryptography.fernet import Fernet # type: ignore
|
||||||
@@ -149,9 +153,12 @@ os.environ.setdefault("BOREALIS_TLS_KEY", TLS_KEY_PATH)
|
|||||||
os.environ.setdefault("BOREALIS_TLS_BUNDLE", TLS_BUNDLE_PATH)
|
os.environ.setdefault("BOREALIS_TLS_BUNDLE", TLS_BUNDLE_PATH)
|
||||||
|
|
||||||
JWT_SERVICE = jwt_service_module.load_service()
|
JWT_SERVICE = jwt_service_module.load_service()
|
||||||
|
SCRIPT_SIGNER = signing.load_signer()
|
||||||
IP_RATE_LIMITER = SlidingWindowRateLimiter()
|
IP_RATE_LIMITER = SlidingWindowRateLimiter()
|
||||||
FP_RATE_LIMITER = SlidingWindowRateLimiter()
|
FP_RATE_LIMITER = SlidingWindowRateLimiter()
|
||||||
ENROLLMENT_NONCE_CACHE = NonceCache()
|
ENROLLMENT_NONCE_CACHE = NonceCache()
|
||||||
|
DPOP_VALIDATOR = DPoPValidator()
|
||||||
|
DEVICE_AUTH_MANAGER: Optional[DeviceAuthManager] = None
|
||||||
|
|
||||||
|
|
||||||
def _set_cached_github_token(token: Optional[str]) -> None:
|
def _set_cached_github_token(token: Optional[str]) -> None:
|
||||||
@@ -1248,6 +1255,14 @@ def _db_conn():
|
|||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
if DEVICE_AUTH_MANAGER is None:
|
||||||
|
DEVICE_AUTH_MANAGER = DeviceAuthManager(
|
||||||
|
db_conn_factory=_db_conn,
|
||||||
|
jwt_service=JWT_SERVICE,
|
||||||
|
dpop_validator=DPOP_VALIDATOR,
|
||||||
|
log=_write_service_log,
|
||||||
|
)
|
||||||
|
|
||||||
def _update_last_login(username: str) -> None:
|
def _update_last_login(username: str) -> None:
|
||||||
if not username:
|
if not username:
|
||||||
return
|
return
|
||||||
@@ -4836,6 +4851,21 @@ enrollment_routes.register(
|
|||||||
nonce_cache=ENROLLMENT_NONCE_CACHE,
|
nonce_cache=ENROLLMENT_NONCE_CACHE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
token_routes.register(
|
||||||
|
app,
|
||||||
|
db_conn_factory=_db_conn,
|
||||||
|
jwt_service=JWT_SERVICE,
|
||||||
|
dpop_validator=DPOP_VALIDATOR,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_routes.register(
|
||||||
|
app,
|
||||||
|
db_conn_factory=_db_conn,
|
||||||
|
auth_manager=DEVICE_AUTH_MANAGER,
|
||||||
|
log=_write_service_log,
|
||||||
|
script_signer=SCRIPT_SIGNER,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def ensure_default_admin():
|
def ensure_default_admin():
|
||||||
"""Ensure at least one admin user exists.
|
"""Ensure at least one admin user exists.
|
||||||
|
|||||||
Reference in New Issue
Block a user