Additional Auth Changes

This commit is contained in:
2025-10-17 19:11:23 -06:00
parent 174cea5549
commit a1dc656878
6 changed files with 272 additions and 40 deletions

View File

@@ -56,6 +56,7 @@ Agents establish TLS-secured REST calls to the Flask backend on port 5000 and ke
- The canonical device GUID (persisted to `guid.txt` alongside the key material). - The canonical device GUID (persisted to `guid.txt` alongside the key material).
- A short-lived access token (EdDSA/JWT) and a long-lived refresh token (stored encrypted via DPAPI and hashed server-side). - A short-lived access token (EdDSA/JWT) and a long-lived refresh token (stored encrypted via DPAPI and hashed server-side).
- The server TLS certificate and script-signing public key so the agent can pin both for future sessions. - The server TLS certificate and script-signing public key so the agent can pin both for future sessions.
- Scripts delivered over REST are signed with the server's Ed25519 code-signing key. The agent validates the signature before anything is queued for execution.
- Access tokens are automatically refreshed before expiry. Refresh failures trigger a re-enrollment. - Access tokens are automatically refreshed before expiry. Refresh failures trigger a re-enrollment.
- All REST calls (heartbeat, script polling, device details, service check-in) use these tokens; WebSocket connections include the `Authorization` header as well. - All REST calls (heartbeat, script polling, device details, service check-in) use these tokens; WebSocket connections include the `Authorization` header as well.
- Specify the installer code via `--installer-code <code>`, `BOREALIS_INSTALLER_CODE`, or by adding `"installer_code": "<code>"` to `Agent/Borealis/Settings/agent_settings.json`. - Specify the installer code via `--installer-code <code>`, `BOREALIS_INSTALLER_CODE`, or by adding `"installer_code": "<code>"` to `Agent/Borealis/Settings/agent_settings.json`.
@@ -64,7 +65,7 @@ Agents establish TLS-secured REST calls to the Flask backend on port 5000 and ke
The agent runs in the interactive user session. SYSTEM-level script execution is provided by the ScriptExec SYSTEM role using ephemeral scheduled tasks; no separate supervisor or watchdog is required. The agent runs in the interactive user session. SYSTEM-level script execution is provided by the ScriptExec SYSTEM role using ephemeral scheduled tasks; no separate supervisor or watchdog is required.
### Logging & State ### Logging & State
All runtime logs live under `Logs/<ServiceName>` relative to the project root (`Logs/Agent` for the agent family). The project avoids writing to `%ProgramData%`, `%AppData%`, or other system directories so the entire footprint stays under the Borealis folder. Log rotation is not yet implemented; contributions should consider a built-in retention strategy. Configuration and state currently live alongside the agent code. All runtime logs live under `Logs/<ServiceName>` relative to the project root (`Logs/Agent` for the agent family). Logs rotate daily and adopt the `<service>.log.YYYY-MM-DD` suffix on rollover; nothing is deleted automatically. The project avoids writing to `%ProgramData%`, `%AppData%`, or other system directories so the entire footprint stays under the Borealis folder. Configuration and state currently live alongside the agent code.
## Roles & Extensibility ## Roles & Extensibility
- Roles live under `Data/Agent/Roles/` and are autodiscovered at startup; no changes are needed in `agent.py` when adding new roles. - Roles live under `Data/Agent/Roles/` and are autodiscovered at startup; no changes are needed in `agent.py` when adding new roles.
@@ -204,4 +205,3 @@ This section summarizes what is considered usable vs. experimental today.

View File

@@ -20,7 +20,7 @@ import datetime
import shutil import shutil
import string import string
import ssl import ssl
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, List
import requests import requests
try: try:
@@ -32,6 +32,9 @@ import aiohttp
import socketio import socketio
from security import AgentKeyStore from security import AgentKeyStore
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
# Centralized logging helpers (Agent) # Centralized logging helpers (Agent)
def _agent_logs_root() -> str: def _agent_logs_root() -> str:
try: try:
@@ -49,9 +52,8 @@ def _rotate_daily(path: str):
dt = _dt.datetime.fromtimestamp(mtime) dt = _dt.datetime.fromtimestamp(mtime)
today = _dt.datetime.now().date() today = _dt.datetime.now().date()
if dt.date() != today: if dt.date() != today:
base, ext = os.path.splitext(path)
suffix = dt.strftime('%Y-%m-%d') suffix = dt.strftime('%Y-%m-%d')
newp = f"{base}.{suffix}{ext}" newp = f"{path}.{suffix}"
try: try:
os.replace(path, newp) os.replace(path, newp)
except Exception: except Exception:
@@ -241,6 +243,21 @@ def _decode_base64_text(value):
return decoded.decode('utf-8', errors='replace') return decoded.decode('utf-8', errors='replace')
def _decode_base64_bytes(value):
if not isinstance(value, str):
return None
stripped = value.strip()
if not stripped:
return b""
cleaned = ''.join(stripped.split())
if not cleaned:
return b""
try:
return base64.b64decode(cleaned, validate=True)
except Exception:
return None
def _decode_script_payload(content, encoding_hint): def _decode_script_payload(content, encoding_hint):
if isinstance(content, str): if isinstance(content, str):
encoding = str(encoding_hint or '').strip().lower() encoding = str(encoding_hint or '').strip().lower()
@@ -554,6 +571,12 @@ class AgentHttpClient:
if data.get("server_certificate"): if data.get("server_certificate"):
self.key_store.save_server_certificate(data["server_certificate"]) self.key_store.save_server_certificate(data["server_certificate"])
self._configure_verify() self._configure_verify()
signing_key = data.get("signing_key")
if signing_key:
try:
self.store_server_signing_key(signing_key)
except Exception as exc:
_log_agent(f'Unable to persist signing key from enrollment handshake: {exc}', fname='agent.error.log')
if data.get("status") != "pending": if data.get("status") != "pending":
raise RuntimeError(f"Unexpected enrollment status: {data}") raise RuntimeError(f"Unexpected enrollment status: {data}")
approval_reference = data.get("approval_reference") approval_reference = data.get("approval_reference")
@@ -595,6 +618,12 @@ class AgentHttpClient:
if server_cert: if server_cert:
self.key_store.save_server_certificate(server_cert) self.key_store.save_server_certificate(server_cert)
self._configure_verify() self._configure_verify()
signing_key = payload.get("signing_key")
if signing_key:
try:
self.store_server_signing_key(signing_key)
except Exception as exc:
_log_agent(f'Unable to persist signing key from enrollment approval: {exc}', fname='agent.error.log')
guid = payload.get("guid") guid = payload.get("guid")
access_token = payload.get("access_token") access_token = payload.get("access_token")
refresh_token = payload.get("refresh_token") refresh_token = payload.get("refresh_token")
@@ -691,7 +720,8 @@ class AgentHttpClient:
require_auth: bool = True, require_auth: bool = True,
) -> Any: ) -> Any:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, self.post_json, path, payload, require_auth) task = partial(self.post_json, path, payload, require_auth=require_auth)
return await loop.run_in_executor(None, task)
def websocket_base_url(self) -> str: def websocket_base_url(self) -> str:
self.refresh_base_url() self.refresh_base_url()
@@ -1324,6 +1354,47 @@ async def send_heartbeat():
await asyncio.sleep(60) await asyncio.sleep(60)
def _verify_and_store_script_signature(
client: AgentHttpClient,
script_bytes: bytes,
signature_b64: str,
signing_key_hint: Optional[str] = None,
) -> bool:
candidates: List[str] = []
if isinstance(signing_key_hint, str) and signing_key_hint.strip():
candidates.append(signing_key_hint.strip())
stored_key = client.load_server_signing_key()
if stored_key:
key_text = stored_key.strip()
if key_text and key_text not in candidates:
candidates.append(key_text)
for key_b64 in candidates:
try:
key_der = base64.b64decode(key_b64, validate=True)
except Exception:
continue
try:
public_key = serialization.load_der_public_key(key_der)
except Exception:
continue
if not isinstance(public_key, ed25519.Ed25519PublicKey):
continue
try:
signature = base64.b64decode(signature_b64, validate=True)
except Exception:
return False
try:
public_key.verify(signature, script_bytes)
if stored_key and stored_key.strip() != key_b64:
client.store_server_signing_key(key_b64)
elif not stored_key:
client.store_server_signing_key(key_b64)
return True
except Exception:
continue
return False
async def poll_script_requests(): async def poll_script_requests():
await asyncio.sleep(20) await asyncio.sleep(20)
client = http_client() client = http_client()
@@ -1334,9 +1405,33 @@ async def poll_script_requests():
response = await client.async_post_json("/api/agent/script/request", payload, require_auth=True) response = await client.async_post_json("/api/agent/script/request", payload, require_auth=True)
if isinstance(response, dict): if isinstance(response, dict):
signing_key = response.get("signing_key") signing_key = response.get("signing_key")
if signing_key: script_b64 = response.get("script")
client.store_server_signing_key(signing_key) signature_b64 = response.get("signature")
# Placeholder: future script execution handling lives here. sig_alg = (response.get("sig_alg") or "").lower()
if script_b64 and signature_b64:
script_bytes = _decode_base64_bytes(script_b64)
if script_bytes is None:
_log_agent('received script payload with invalid base64 encoding', fname='agent.error.log')
elif sig_alg and sig_alg not in ("ed25519", "eddsa"):
_log_agent(f'unsupported script signature algorithm: {sig_alg}', fname='agent.error.log')
else:
existing_key = client.load_server_signing_key()
key_available = bool(
(isinstance(signing_key, str) and signing_key.strip())
or (isinstance(existing_key, str) and existing_key.strip())
)
if not key_available:
_log_agent('no server signing key available to verify script payload', fname='agent.error.log')
elif _verify_and_store_script_signature(client, script_bytes, signature_b64, signing_key):
_log_agent('received signed script payload (verification succeeded); awaiting executor integration')
else:
_log_agent('rejected script payload due to invalid signature', fname='agent.error.log')
elif signing_key:
# No script content, but we may need to persist updated signing key.
try:
client.store_server_signing_key(signing_key)
except Exception as exc:
_log_agent(f'failed to persist server signing key: {exc}', fname='agent.error.log')
except Exception as exc: except Exception as exc:
_log_agent(f'script request poll failed: {exc}', fname='agent.error.log') _log_agent(f'script request poll failed: {exc}', fname='agent.error.log')
await asyncio.sleep(30) await asyncio.sleep(30)

View File

@@ -8,6 +8,7 @@ import jwt
from flask import g, jsonify, request from flask import g, jsonify, request
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
from Modules.auth.rate_limit import SlidingWindowRateLimiter
@dataclass @dataclass
@@ -25,11 +26,18 @@ class DeviceAuthError(Exception):
status_code = 401 status_code = 401
error_code = "unauthorized" error_code = "unauthorized"
def __init__(self, message: str = "unauthorized", *, status_code: Optional[int] = None): def __init__(
self,
message: str = "unauthorized",
*,
status_code: Optional[int] = None,
retry_after: Optional[float] = None,
):
super().__init__(message) super().__init__(message)
if status_code is not None: if status_code is not None:
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.retry_after = retry_after
class DeviceAuthManager: class DeviceAuthManager:
@@ -40,11 +48,13 @@ class DeviceAuthManager:
jwt_service, jwt_service,
dpop_validator: Optional[DPoPValidator], dpop_validator: Optional[DPoPValidator],
log: Callable[[str, str], None], log: Callable[[str, str], None],
rate_limiter: Optional[SlidingWindowRateLimiter] = None,
) -> None: ) -> None:
self._db_conn_factory = db_conn_factory self._db_conn_factory = db_conn_factory
self._jwt_service = jwt_service self._jwt_service = jwt_service
self._dpop_validator = dpop_validator self._dpop_validator = dpop_validator
self._log = log self._log = log
self._rate_limiter = rate_limiter
def authenticate(self) -> DeviceAuthContext: def authenticate(self) -> DeviceAuthContext:
auth_header = request.headers.get("Authorization", "") auth_header = request.headers.get("Authorization", "")
@@ -67,6 +77,15 @@ class DeviceAuthManager:
if not guid or not fingerprint or token_version <= 0: if not guid or not fingerprint or token_version <= 0:
raise DeviceAuthError("invalid_claims") raise DeviceAuthError("invalid_claims")
if self._rate_limiter:
decision = self._rate_limiter.check(f"fp:{fingerprint}", 60, 60.0)
if not decision.allowed:
raise DeviceAuthError(
"rate_limited",
status_code=429,
retry_after=decision.retry_after,
)
conn = self._db_conn_factory() conn = self._db_conn_factory()
try: try:
cur = conn.cursor() cur = conn.cursor()
@@ -138,6 +157,12 @@ def require_device_auth(manager: DeviceAuthManager):
except DeviceAuthError as exc: except DeviceAuthError as exc:
response = jsonify({"error": exc.message}) response = jsonify({"error": exc.message})
response.status_code = exc.status_code response.status_code = exc.status_code
retry_after = getattr(exc, "retry_after", None)
if retry_after:
try:
response.headers["Retry-After"] = str(max(1, int(retry_after)))
except Exception:
response.headers["Retry-After"] = "1"
return response return response
g.device_auth = ctx g.device_auth = ctx

View File

@@ -26,6 +26,7 @@ def register(
ip_rate_limiter: SlidingWindowRateLimiter, ip_rate_limiter: SlidingWindowRateLimiter,
fp_rate_limiter: SlidingWindowRateLimiter, fp_rate_limiter: SlidingWindowRateLimiter,
nonce_cache: NonceCache, nonce_cache: NonceCache,
script_signer,
) -> None: ) -> None:
blueprint = Blueprint("enrollment", __name__) blueprint = Blueprint("enrollment", __name__)
@@ -42,6 +43,14 @@ def register(
addr = request.remote_addr or "unknown" addr = request.remote_addr or "unknown"
return addr.strip() return addr.strip()
def _signing_key_b64() -> str:
if not script_signer:
return ""
try:
return script_signer.public_base64_spki()
except Exception:
return ""
def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float): def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
decision = limiter.check(key, limit, window_s) decision = limiter.check(key, limit, window_s)
if not decision.allowed: if not decision.allowed:
@@ -312,6 +321,7 @@ def register(
"server_nonce": server_nonce_b64, "server_nonce": server_nonce_b64,
"poll_after_ms": 3000, "poll_after_ms": 3000,
"server_certificate": _load_tls_bundle(tls_bundle_path), "server_certificate": _load_tls_bundle(tls_bundle_path),
"signing_key": _signing_key_b64(),
} }
log("server", f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}") log("server", f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}")
return jsonify(response) return jsonify(response)
@@ -466,6 +476,7 @@ def register(
"refresh_token": refresh_info["token"], "refresh_token": refresh_info["token"],
"token_type": "Bearer", "token_type": "Bearer",
"server_certificate": _load_tls_bundle(tls_bundle_path), "server_certificate": _load_tls_bundle(tls_bundle_path),
"signing_key": _signing_key_b64(),
} }
) )

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone from datetime import datetime, timedelta, timezone
from typing import Callable from typing import Callable
import eventlet import eventlet
@@ -25,7 +25,9 @@ def start_prune_job(
def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None]) -> None: def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None]) -> None:
now_iso = datetime.now(tz=timezone.utc).isoformat() now = datetime.now(tz=timezone.utc)
now_iso = now.isoformat()
stale_before = (now - timedelta(hours=24)).isoformat()
conn = db_conn_factory() conn = db_conn_factory()
try: try:
cur = conn.cursor() cur = conn.cursor()
@@ -55,7 +57,7 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
OR created_at < ? OR created_at < ?
) )
""", """,
(now_iso, now_iso, now_iso), (now_iso, now_iso, stale_before),
) )
approvals_marked = cur.rowcount or 0 approvals_marked = cur.rowcount or 0

View File

@@ -27,7 +27,7 @@ from eventlet import tpool
import requests import requests
import re import re
import base64 import base64
from flask import Flask, request, jsonify, Response, send_from_directory, make_response, session from flask import Flask, request, jsonify, Response, send_from_directory, make_response, session, g
from flask_socketio import SocketIO, emit, join_room from flask_socketio import SocketIO, emit, join_room
from flask_cors import CORS from flask_cors import CORS
from werkzeug.middleware.proxy_fix import ProxyFix from werkzeug.middleware.proxy_fix import ProxyFix
@@ -51,7 +51,7 @@ from datetime import datetime, timezone
from Modules import db_migrations from Modules import db_migrations
from Modules.auth import jwt_service as jwt_service_module from Modules.auth import jwt_service as jwt_service_module
from Modules.auth.dpop import DPoPValidator from Modules.auth.dpop import DPoPValidator
from Modules.auth.device_auth import DeviceAuthManager from Modules.auth.device_auth import DeviceAuthManager, require_device_auth
from Modules.auth.rate_limit import SlidingWindowRateLimiter from Modules.auth.rate_limit import SlidingWindowRateLimiter
from Modules.agents import routes as agent_routes from Modules.agents import routes as agent_routes
from Modules.crypto import certificates, signing from Modules.crypto import certificates, signing
@@ -95,9 +95,8 @@ def _rotate_daily(path: str):
dt = _dt.datetime.fromtimestamp(mtime) dt = _dt.datetime.fromtimestamp(mtime)
today = _dt.datetime.now().date() today = _dt.datetime.now().date()
if dt.date() != today: if dt.date() != today:
base, ext = os.path.splitext(path)
suffix = dt.strftime('%Y-%m-%d') suffix = dt.strftime('%Y-%m-%d')
newp = f"{base}.{suffix}{ext}" newp = f"{path}.{suffix}"
try: try:
os.replace(path, newp) os.replace(path, newp)
except Exception: except Exception:
@@ -158,6 +157,7 @@ JWT_SERVICE = jwt_service_module.load_service()
SCRIPT_SIGNER = signing.load_signer() SCRIPT_SIGNER = signing.load_signer()
IP_RATE_LIMITER = SlidingWindowRateLimiter() IP_RATE_LIMITER = SlidingWindowRateLimiter()
FP_RATE_LIMITER = SlidingWindowRateLimiter() FP_RATE_LIMITER = SlidingWindowRateLimiter()
AUTH_RATE_LIMITER = SlidingWindowRateLimiter()
ENROLLMENT_NONCE_CACHE = NonceCache() ENROLLMENT_NONCE_CACHE = NonceCache()
DPOP_VALIDATOR = DPoPValidator() DPOP_VALIDATOR = DPoPValidator()
DEVICE_AUTH_MANAGER: Optional[DeviceAuthManager] = None DEVICE_AUTH_MANAGER: Optional[DeviceAuthManager] = None
@@ -1263,6 +1263,7 @@ if DEVICE_AUTH_MANAGER is None:
jwt_service=JWT_SERVICE, jwt_service=JWT_SERVICE,
dpop_validator=DPOP_VALIDATOR, dpop_validator=DPOP_VALIDATOR,
log=_write_service_log, log=_write_service_log,
rate_limiter=AUTH_RATE_LIMITER,
) )
def _update_last_login(username: str) -> None: def _update_last_login(username: str) -> None:
@@ -4851,6 +4852,7 @@ enrollment_routes.register(
ip_rate_limiter=IP_RATE_LIMITER, ip_rate_limiter=IP_RATE_LIMITER,
fp_rate_limiter=FP_RATE_LIMITER, fp_rate_limiter=FP_RATE_LIMITER,
nonce_cache=ENROLLMENT_NONCE_CACHE, nonce_cache=ENROLLMENT_NONCE_CACHE,
script_signer=SCRIPT_SIGNER,
) )
token_routes.register( token_routes.register(
@@ -6254,6 +6256,7 @@ def _deep_merge_preserve(prev: dict, incoming: dict) -> dict:
# Endpoint: /api/agent/details — methods POST. # Endpoint: /api/agent/details — methods POST.
@app.route("/api/agent/details", methods=["POST"]) @app.route("/api/agent/details", methods=["POST"])
@require_device_auth(DEVICE_AUTH_MANAGER)
def save_agent_details(): def save_agent_details():
data = request.get_json(silent=True) or {} data = request.get_json(silent=True) or {}
hostname = data.get("hostname") hostname = data.get("hostname")
@@ -6264,11 +6267,9 @@ def save_agent_details():
agent_hash = agent_hash.strip() or None agent_hash = agent_hash.strip() or None
else: else:
agent_hash = None agent_hash = None
agent_guid = data.get("agent_guid") ctx = getattr(g, "device_auth")
if isinstance(agent_guid, str): auth_guid = _normalize_guid(ctx.guid)
agent_guid = agent_guid.strip() or None fingerprint = (ctx.ssl_key_fingerprint or "").strip()
else:
agent_guid = None
if not hostname and isinstance(details, dict): if not hostname and isinstance(details, dict):
hostname = (details.get("summary") or {}).get("hostname") hostname = (details.get("summary") or {}).get("hostname")
if not hostname or not isinstance(details, dict): if not hostname or not isinstance(details, dict):
@@ -6285,6 +6286,13 @@ def save_agent_details():
created_at = int(snapshot.get("created_at") or 0) created_at = int(snapshot.get("created_at") or 0)
existing_guid = (snapshot.get("agent_guid") or "").strip() or None existing_guid = (snapshot.get("agent_guid") or "").strip() or None
existing_agent_hash = (snapshot.get("agent_hash") or "").strip() or None existing_agent_hash = (snapshot.get("agent_hash") or "").strip() or None
db_fp = (snapshot.get("ssl_key_fingerprint") or "").strip().lower()
if db_fp and fingerprint and db_fp != fingerprint.lower():
return jsonify({"error": "fingerprint_mismatch"}), 403
normalized_existing_guid = _normalize_guid(existing_guid) if existing_guid else None
if normalized_existing_guid and auth_guid and normalized_existing_guid != auth_guid:
return jsonify({"error": "guid_mismatch"}), 403
# Ensure summary exists and attach hostname/agent_id if missing # Ensure summary exists and attach hostname/agent_id if missing
incoming_summary = details.setdefault("summary", {}) incoming_summary = details.setdefault("summary", {})
@@ -6300,10 +6308,12 @@ def save_agent_details():
incoming_summary["agent_hash"] = agent_hash incoming_summary["agent_hash"] = agent_hash
except Exception: except Exception:
pass pass
effective_guid = agent_guid or existing_guid effective_guid = auth_guid or existing_guid
normalized_effective_guid = _normalize_guid(effective_guid) if effective_guid else None normalized_effective_guid = auth_guid or normalized_existing_guid
if normalized_effective_guid: if normalized_effective_guid:
incoming_summary["agent_guid"] = normalized_effective_guid incoming_summary["agent_guid"] = normalized_effective_guid
if fingerprint:
incoming_summary.setdefault("ssl_key_fingerprint", fingerprint)
# Preserve last_seen if incoming omitted it # Preserve last_seen if incoming omitted it
if not incoming_summary.get("last_seen"): if not incoming_summary.get("last_seen"):
@@ -6366,6 +6376,24 @@ def save_agent_details():
agent_hash=agent_hash or existing_agent_hash, agent_hash=agent_hash or existing_agent_hash,
guid=normalized_effective_guid, guid=normalized_effective_guid,
) )
if normalized_effective_guid and fingerprint:
now_iso = datetime.now(timezone.utc).isoformat()
cur.execute(
"""
UPDATE devices
SET ssl_key_fingerprint = ?,
key_added_at = COALESCE(key_added_at, ?)
WHERE guid = ?
""",
(fingerprint, now_iso, normalized_effective_guid),
)
cur.execute(
"""
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
VALUES (?, ?, ?, ?)
""",
(str(uuid.uuid4()), normalized_effective_guid, fingerprint, now_iso),
)
conn.commit() conn.commit()
conn.close() conn.close()
@@ -7149,16 +7177,28 @@ def _service_acct_set(conn, agent_id: str, username: str, plaintext_password: st
# Endpoint: /api/agent/checkin — methods POST. # Endpoint: /api/agent/checkin — methods POST.
@app.route('/api/agent/checkin', methods=['POST']) @app.route('/api/agent/checkin', methods=['POST'])
@require_device_auth(DEVICE_AUTH_MANAGER)
def api_agent_checkin(): def api_agent_checkin():
payload = request.get_json(silent=True) or {} payload = request.get_json(silent=True) or {}
agent_id = (payload.get('agent_id') or '').strip() agent_id = (payload.get('agent_id') or '').strip()
if not agent_id: if not agent_id:
return jsonify({'error': 'agent_id required'}), 400 return jsonify({'error': 'agent_id required'}), 400
ctx = getattr(g, "device_auth")
auth_guid = _normalize_guid(ctx.guid)
fingerprint = (ctx.ssl_key_fingerprint or "").strip()
raw_username = (payload.get('username') or '').strip() raw_username = (payload.get('username') or '').strip()
username = raw_username or DEFAULT_SERVICE_ACCOUNT username = raw_username or DEFAULT_SERVICE_ACCOUNT
if username in LEGACY_SERVICE_ACCOUNTS: if username in LEGACY_SERVICE_ACCOUNTS:
username = DEFAULT_SERVICE_ACCOUNT username = DEFAULT_SERVICE_ACCOUNT
hostname = (payload.get('hostname') or '').strip() hostname = (payload.get('hostname') or '').strip()
reg = registered_agents.get(agent_id) or {}
reg_guid = _normalize_guid(reg.get("agent_guid") or "")
if reg_guid and auth_guid and reg_guid != auth_guid:
return jsonify({'error': 'guid_mismatch'}), 403
conn = None
try: try:
conn = _db_conn() conn = _db_conn()
row = _service_acct_get(conn, agent_id) row = _service_acct_get(conn, agent_id)
@@ -7189,38 +7229,92 @@ def api_agent_checkin():
'password': plain, 'password': plain,
'last_rotated_utc': row[3] or _now_iso_utc(), 'last_rotated_utc': row[3] or _now_iso_utc(),
} }
conn.close()
_ansible_log_server(f"[checkin] return creds agent_id={agent_id} user={out['username']}") now_ts = int(time.time())
try: try:
if hostname: if hostname:
_persist_last_seen(hostname, int(time.time()), agent_id) _persist_last_seen(hostname, now_ts, agent_id)
except Exception: except Exception:
pass pass
agent_guid = _ensure_agent_guid(agent_id, hostname or None)
if agent_guid and agent_id: try:
rec = registered_agents.setdefault(agent_id, {}) cur = conn.cursor()
rec['agent_guid'] = agent_guid if auth_guid:
else: cur.execute(
agent_guid = agent_guid or '' """
return jsonify({ UPDATE devices
'username': out['username'], SET agent_id = COALESCE(?, agent_id),
'password': out['password'], ssl_key_fingerprint = COALESCE(?, ssl_key_fingerprint),
'policy': { 'force_rotation_minutes': 43200 }, last_seen = ?
'agent_guid': agent_guid or None, WHERE guid = ?
}) """,
(agent_id or None, fingerprint or None, now_ts, auth_guid),
)
if cur.rowcount == 0 and hostname:
cur.execute(
"""
UPDATE devices
SET guid = ?,
agent_id = COALESCE(?, agent_id),
ssl_key_fingerprint = COALESCE(?, ssl_key_fingerprint),
last_seen = ?
WHERE hostname = ?
""",
(auth_guid, agent_id or None, fingerprint or None, now_ts, hostname),
)
if fingerprint:
cur.execute(
"""
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
VALUES (?, ?, ?, ?)
""",
(str(uuid.uuid4()), auth_guid, fingerprint, datetime.now(timezone.utc).isoformat()),
)
conn.commit()
except Exception as exc:
_write_service_log("server", f"device update during checkin failed: {exc}")
registered = registered_agents.setdefault(agent_id, {})
if auth_guid:
registered["agent_guid"] = auth_guid
_ansible_log_server(f"[checkin] return creds agent_id={agent_id} user={out['username']}")
return jsonify(
{
'username': out['username'],
'password': out['password'],
'policy': {'force_rotation_minutes': 43200},
'agent_guid': auth_guid or None,
}
)
except Exception as e: except Exception as e:
_ansible_log_server(f"[checkin] error agent_id={agent_id} err={e}") _ansible_log_server(f"[checkin] error agent_id={agent_id} err={e}")
return jsonify({'error': str(e)}), 500 return jsonify({'error': str(e)}), 500
finally:
if conn:
try:
conn.close()
except Exception:
pass
# Endpoint: /api/agent/service-account/rotate — methods POST. # Endpoint: /api/agent/service-account/rotate — methods POST.
@app.route('/api/agent/service-account/rotate', methods=['POST']) @app.route('/api/agent/service-account/rotate', methods=['POST'])
@require_device_auth(DEVICE_AUTH_MANAGER)
def api_agent_service_account_rotate(): def api_agent_service_account_rotate():
payload = request.get_json(silent=True) or {} payload = request.get_json(silent=True) or {}
agent_id = (payload.get('agent_id') or '').strip() agent_id = (payload.get('agent_id') or '').strip()
if not agent_id: if not agent_id:
return jsonify({'error': 'agent_id required'}), 400 return jsonify({'error': 'agent_id required'}), 400
ctx = getattr(g, "device_auth")
auth_guid = _normalize_guid(ctx.guid)
reg = registered_agents.get(agent_id) or {}
reg_guid = _normalize_guid(reg.get("agent_guid") or "")
if reg_guid and auth_guid and reg_guid != auth_guid:
return jsonify({'error': 'guid_mismatch'}), 403
requested_username = (payload.get('username') or '').strip() requested_username = (payload.get('username') or '').strip()
try: try:
conn = _db_conn() conn = _db_conn()
@@ -7234,7 +7328,12 @@ def api_agent_service_account_rotate():
_ansible_log_server(f"[rotate] upgrading legacy service user for agent_id={agent_id}") _ansible_log_server(f"[rotate] upgrading legacy service user for agent_id={agent_id}")
pw_new = _gen_strong_password() pw_new = _gen_strong_password()
out = _service_acct_set(conn, agent_id, user_eff, pw_new) out = _service_acct_set(conn, agent_id, user_eff, pw_new)
conn.close() try:
registered = registered_agents.setdefault(agent_id, {})
if auth_guid:
registered["agent_guid"] = auth_guid
finally:
conn.close()
_ansible_log_server(f"[rotate] rotated agent_id={agent_id} user={out['username']} at={out['last_rotated_utc']}") _ansible_log_server(f"[rotate] rotated agent_id={agent_id} user={out['username']} at={out['last_rotated_utc']}")
return jsonify({ return jsonify({
'username': out['username'], 'username': out['username'],