mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 15:21:57 -06:00
feat: secure agent auth and heartbeat endpoints
This commit is contained in:
1
Data/Server/Modules/agents/__init__.py
Normal file
1
Data/Server/Modules/agents/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
114
Data/Server/Modules/agents/routes.py
Normal file
114
Data/Server/Modules/agents/routes.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from flask import Blueprint, jsonify, request, g
|
||||
|
||||
from Modules.auth.device_auth import DeviceAuthManager, require_device_auth
|
||||
from Modules.crypto.signing import ScriptSigner
|
||||
|
||||
|
||||
def register(
|
||||
app,
|
||||
*,
|
||||
db_conn_factory: Callable[[], Any],
|
||||
auth_manager: DeviceAuthManager,
|
||||
log: Callable[[str, str], None],
|
||||
script_signer: ScriptSigner,
|
||||
) -> None:
|
||||
blueprint = Blueprint("agents", __name__)
|
||||
|
||||
def _json_or_none(value) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return json.dumps(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@blueprint.route("/api/agent/heartbeat", methods=["POST"])
|
||||
@require_device_auth(auth_manager)
|
||||
def heartbeat():
|
||||
ctx = getattr(g, "device_auth")
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
|
||||
now_ts = int(time.time())
|
||||
updates: Dict[str, Optional[str]] = {"last_seen": now_ts}
|
||||
|
||||
hostname = payload.get("hostname")
|
||||
if isinstance(hostname, str) and hostname.strip():
|
||||
updates["hostname"] = hostname.strip()
|
||||
|
||||
inventory = payload.get("inventory") if isinstance(payload.get("inventory"), dict) else {}
|
||||
for key in ("memory", "network", "software", "storage", "cpu"):
|
||||
if key in inventory and inventory[key] is not None:
|
||||
encoded = _json_or_none(inventory[key])
|
||||
if encoded is not None:
|
||||
updates[key] = encoded
|
||||
|
||||
metrics = payload.get("metrics") if isinstance(payload.get("metrics"), dict) else {}
|
||||
def _maybe_str(field: str) -> Optional[str]:
|
||||
val = metrics.get(field)
|
||||
if isinstance(val, str):
|
||||
return val.strip()
|
||||
return None
|
||||
|
||||
if "last_user" in metrics and metrics["last_user"]:
|
||||
updates["last_user"] = str(metrics["last_user"])
|
||||
if "operating_system" in metrics and metrics["operating_system"]:
|
||||
updates["operating_system"] = str(metrics["operating_system"])
|
||||
if "uptime" in metrics and metrics["uptime"] is not None:
|
||||
try:
|
||||
updates["uptime"] = int(metrics["uptime"])
|
||||
except Exception:
|
||||
pass
|
||||
for field in ("external_ip", "internal_ip", "device_type"):
|
||||
if field in payload and payload[field]:
|
||||
updates[field] = str(payload[field])
|
||||
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
columns = ", ".join(f"{col} = ?" for col in updates.keys())
|
||||
params = list(updates.values())
|
||||
params.append(ctx.guid)
|
||||
cur.execute(
|
||||
f"UPDATE devices SET {columns} WHERE guid = ?",
|
||||
params,
|
||||
)
|
||||
if cur.rowcount == 0:
|
||||
log("server", f"heartbeat missing device record guid={ctx.guid}")
|
||||
return jsonify({"error": "device_not_registered"}), 404
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return jsonify({"status": "ok", "poll_after_ms": 15000})
|
||||
|
||||
@blueprint.route("/api/agent/script/request", methods=["POST"])
|
||||
@require_device_auth(auth_manager)
|
||||
def script_request():
|
||||
ctx = getattr(g, "device_auth")
|
||||
if ctx.status != "active":
|
||||
return jsonify(
|
||||
{
|
||||
"status": "quarantined",
|
||||
"poll_after_ms": 60000,
|
||||
"sig_alg": "ed25519",
|
||||
"signing_key": script_signer.public_base64_spki(),
|
||||
}
|
||||
)
|
||||
|
||||
# Placeholder: actual dispatch logic will integrate with job scheduler.
|
||||
return jsonify(
|
||||
{
|
||||
"status": "idle",
|
||||
"poll_after_ms": 30000,
|
||||
"sig_alg": "ed25519",
|
||||
"signing_key": script_signer.public_base64_spki(),
|
||||
}
|
||||
)
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
148
Data/Server/Modules/auth/device_auth.py
Normal file
148
Data/Server/Modules/auth/device_auth.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from flask import g, jsonify, request
|
||||
|
||||
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceAuthContext:
|
||||
guid: str
|
||||
ssl_key_fingerprint: str
|
||||
token_version: int
|
||||
access_token: str
|
||||
claims: Dict[str, Any]
|
||||
dpop_jkt: Optional[str]
|
||||
status: str
|
||||
|
||||
|
||||
class DeviceAuthError(Exception):
|
||||
status_code = 401
|
||||
error_code = "unauthorized"
|
||||
|
||||
def __init__(self, message: str = "unauthorized", *, status_code: Optional[int] = None):
|
||||
super().__init__(message)
|
||||
if status_code is not None:
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
|
||||
|
||||
class DeviceAuthManager:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
db_conn_factory: Callable[[], Any],
|
||||
jwt_service,
|
||||
dpop_validator: Optional[DPoPValidator],
|
||||
log: Callable[[str, str], None],
|
||||
) -> None:
|
||||
self._db_conn_factory = db_conn_factory
|
||||
self._jwt_service = jwt_service
|
||||
self._dpop_validator = dpop_validator
|
||||
self._log = log
|
||||
|
||||
def authenticate(self) -> DeviceAuthContext:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise DeviceAuthError("missing_authorization")
|
||||
token = auth_header[len("Bearer ") :].strip()
|
||||
if not token:
|
||||
raise DeviceAuthError("missing_authorization")
|
||||
|
||||
try:
|
||||
claims = self._jwt_service.decode(token)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise DeviceAuthError("token_expired")
|
||||
except Exception:
|
||||
raise DeviceAuthError("invalid_token")
|
||||
|
||||
guid = str(claims.get("guid") or "").strip()
|
||||
fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip()
|
||||
token_version = int(claims.get("token_version") or 0)
|
||||
if not guid or not fingerprint or token_version <= 0:
|
||||
raise DeviceAuthError("invalid_claims")
|
||||
|
||||
conn = self._db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(guid,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if not row:
|
||||
raise DeviceAuthError("device_not_found", status_code=403)
|
||||
|
||||
db_guid, db_fp, db_token_version, status = row
|
||||
|
||||
if str(db_guid or "").lower() != guid.lower():
|
||||
raise DeviceAuthError("device_guid_mismatch", status_code=403)
|
||||
|
||||
db_fp = (db_fp or "").lower().strip()
|
||||
if db_fp and db_fp != fingerprint:
|
||||
raise DeviceAuthError("fingerprint_mismatch", status_code=403)
|
||||
|
||||
if db_token_version and db_token_version > token_version:
|
||||
raise DeviceAuthError("token_version_revoked", status_code=401)
|
||||
|
||||
status_normalized = (status or "active").strip().lower()
|
||||
allowed_statuses = {"active", "quarantined"}
|
||||
if status_normalized not in allowed_statuses:
|
||||
raise DeviceAuthError("device_revoked", status_code=403)
|
||||
if status_normalized == "quarantined":
|
||||
self._log("server", f"device {guid} is quarantined; limited access for {request.path}")
|
||||
|
||||
dpop_jkt: Optional[str] = None
|
||||
dpop_proof = request.headers.get("DPoP")
|
||||
if dpop_proof:
|
||||
if not self._dpop_validator:
|
||||
raise DeviceAuthError("dpop_not_supported", status_code=400)
|
||||
try:
|
||||
htu = request.url
|
||||
dpop_jkt = self._dpop_validator.verify(request.method, htu, dpop_proof, token)
|
||||
except DPoPReplayError:
|
||||
raise DeviceAuthError("dpop_replayed", status_code=400)
|
||||
except DPoPVerificationError:
|
||||
raise DeviceAuthError("dpop_invalid", status_code=400)
|
||||
|
||||
ctx = DeviceAuthContext(
|
||||
guid=guid,
|
||||
ssl_key_fingerprint=fingerprint,
|
||||
token_version=token_version,
|
||||
access_token=token,
|
||||
claims=claims,
|
||||
dpop_jkt=dpop_jkt,
|
||||
status=status_normalized,
|
||||
)
|
||||
return ctx
|
||||
|
||||
|
||||
def require_device_auth(manager: DeviceAuthManager):
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
ctx = manager.authenticate()
|
||||
except DeviceAuthError as exc:
|
||||
response = jsonify({"error": exc.message})
|
||||
response.status_code = exc.status_code
|
||||
return response
|
||||
|
||||
g.device_auth = ctx
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
1
Data/Server/Modules/tokens/__init__.py
Normal file
1
Data/Server/Modules/tokens/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
125
Data/Server/Modules/tokens/routes.py
Normal file
125
Data/Server/Modules/tokens/routes.py
Normal file
@@ -0,0 +1,125 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable
|
||||
|
||||
from flask import Blueprint, jsonify, request
|
||||
|
||||
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
|
||||
|
||||
|
||||
def register(
|
||||
app,
|
||||
*,
|
||||
db_conn_factory: Callable[[], sqlite3.Connection],
|
||||
jwt_service,
|
||||
dpop_validator: DPoPValidator,
|
||||
) -> None:
|
||||
blueprint = Blueprint("tokens", __name__)
|
||||
|
||||
def _hash_token(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
def _iso_now() -> str:
|
||||
return datetime.now(tz=timezone.utc).isoformat()
|
||||
|
||||
def _parse_iso(ts: str) -> datetime:
|
||||
return datetime.fromisoformat(ts)
|
||||
|
||||
@blueprint.route("/api/agent/token/refresh", methods=["POST"])
|
||||
def refresh():
|
||||
payload = request.get_json(force=True, silent=True) or {}
|
||||
guid = str(payload.get("guid") or "").strip()
|
||||
refresh_token = str(payload.get("refresh_token") or "").strip()
|
||||
|
||||
if not guid or not refresh_token:
|
||||
return jsonify({"error": "invalid_request"}), 400
|
||||
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, guid, token_hash, dpop_jkt, created_at, expires_at, revoked_at
|
||||
FROM refresh_tokens
|
||||
WHERE guid = ?
|
||||
AND token_hash = ?
|
||||
""",
|
||||
(guid, _hash_token(refresh_token)),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return jsonify({"error": "invalid_refresh_token"}), 401
|
||||
|
||||
record_id, row_guid, _token_hash, stored_jkt, created_at, expires_at, revoked_at = row
|
||||
if row_guid != guid:
|
||||
return jsonify({"error": "invalid_refresh_token"}), 401
|
||||
if revoked_at:
|
||||
return jsonify({"error": "refresh_token_revoked"}), 401
|
||||
if expires_at:
|
||||
try:
|
||||
if _parse_iso(expires_at) <= datetime.now(tz=timezone.utc):
|
||||
return jsonify({"error": "refresh_token_expired"}), 401
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||
FROM devices
|
||||
WHERE guid = ?
|
||||
""",
|
||||
(guid,),
|
||||
)
|
||||
device_row = cur.fetchone()
|
||||
if not device_row:
|
||||
return jsonify({"error": "device_not_found"}), 404
|
||||
|
||||
device_guid, fingerprint, token_version, status = device_row
|
||||
status_norm = (status or "active").strip().lower()
|
||||
if status_norm in {"revoked", "decommissioned"}:
|
||||
return jsonify({"error": "device_revoked"}), 403
|
||||
|
||||
dpop_proof = request.headers.get("DPoP")
|
||||
jkt = stored_jkt or ""
|
||||
if dpop_proof:
|
||||
try:
|
||||
jkt = dpop_validator.verify(request.method, request.url, dpop_proof, access_token=None)
|
||||
except DPoPReplayError:
|
||||
return jsonify({"error": "dpop_replayed"}), 400
|
||||
except DPoPVerificationError:
|
||||
return jsonify({"error": "dpop_invalid"}), 400
|
||||
elif stored_jkt:
|
||||
return jsonify({"error": "dpop_required"}), 400
|
||||
|
||||
new_access_token = jwt_service.issue_access_token(
|
||||
guid,
|
||||
fingerprint or "",
|
||||
token_version or 1,
|
||||
)
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET last_used_at = ?,
|
||||
dpop_jkt = COALESCE(NULLIF(?, ''), dpop_jkt)
|
||||
WHERE id = ?
|
||||
""",
|
||||
(_iso_now(), jkt, record_id),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"access_token": new_access_token,
|
||||
"expires_in": 900,
|
||||
"token_type": "Bearer",
|
||||
}
|
||||
)
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
@@ -50,10 +50,14 @@ from datetime import datetime, timezone
|
||||
|
||||
from Modules import db_migrations
|
||||
from Modules.auth import jwt_service as jwt_service_module
|
||||
from Modules.auth.dpop import DPoPValidator
|
||||
from Modules.auth.device_auth import DeviceAuthManager
|
||||
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||
from Modules.crypto import certificates
|
||||
from Modules.agents import routes as agent_routes
|
||||
from Modules.crypto import certificates, signing
|
||||
from Modules.enrollment import routes as enrollment_routes
|
||||
from Modules.enrollment.nonce_store import NonceCache
|
||||
from Modules.tokens import routes as token_routes
|
||||
|
||||
try:
|
||||
from cryptography.fernet import Fernet # type: ignore
|
||||
@@ -149,9 +153,12 @@ os.environ.setdefault("BOREALIS_TLS_KEY", TLS_KEY_PATH)
|
||||
os.environ.setdefault("BOREALIS_TLS_BUNDLE", TLS_BUNDLE_PATH)
|
||||
|
||||
JWT_SERVICE = jwt_service_module.load_service()
|
||||
SCRIPT_SIGNER = signing.load_signer()
|
||||
IP_RATE_LIMITER = SlidingWindowRateLimiter()
|
||||
FP_RATE_LIMITER = SlidingWindowRateLimiter()
|
||||
ENROLLMENT_NONCE_CACHE = NonceCache()
|
||||
DPOP_VALIDATOR = DPoPValidator()
|
||||
DEVICE_AUTH_MANAGER: Optional[DeviceAuthManager] = None
|
||||
|
||||
|
||||
def _set_cached_github_token(token: Optional[str]) -> None:
|
||||
@@ -1248,6 +1255,14 @@ def _db_conn():
|
||||
return conn
|
||||
|
||||
|
||||
if DEVICE_AUTH_MANAGER is None:
|
||||
DEVICE_AUTH_MANAGER = DeviceAuthManager(
|
||||
db_conn_factory=_db_conn,
|
||||
jwt_service=JWT_SERVICE,
|
||||
dpop_validator=DPOP_VALIDATOR,
|
||||
log=_write_service_log,
|
||||
)
|
||||
|
||||
def _update_last_login(username: str) -> None:
|
||||
if not username:
|
||||
return
|
||||
@@ -4836,6 +4851,21 @@ enrollment_routes.register(
|
||||
nonce_cache=ENROLLMENT_NONCE_CACHE,
|
||||
)
|
||||
|
||||
token_routes.register(
|
||||
app,
|
||||
db_conn_factory=_db_conn,
|
||||
jwt_service=JWT_SERVICE,
|
||||
dpop_validator=DPOP_VALIDATOR,
|
||||
)
|
||||
|
||||
agent_routes.register(
|
||||
app,
|
||||
db_conn_factory=_db_conn,
|
||||
auth_manager=DEVICE_AUTH_MANAGER,
|
||||
log=_write_service_log,
|
||||
script_signer=SCRIPT_SIGNER,
|
||||
)
|
||||
|
||||
|
||||
def ensure_default_admin():
|
||||
"""Ensure at least one admin user exists.
|
||||
|
||||
Reference in New Issue
Block a user