Implement server-managed agent update handshake

This commit is contained in:
2025-10-05 16:15:48 -06:00
parent 2b0269c6b0
commit 8f29de86af
4 changed files with 498 additions and 396 deletions

View File

@@ -393,6 +393,151 @@ def api_agent_repo_hash():
_write_service_log('server', f'/api/agent/repo_hash error: {exc}')
return jsonify({"error": "internal error"}), 500
@app.route("/api/agent/update_check", methods=["POST"])
def api_agent_update_check():
data = request.get_json(silent=True) or {}
agent_id = (data.get("agent_id") or "").strip()
if not agent_id:
return jsonify({"error": "agent_id required"}), 400
repo_info = _refresh_default_repo_hash()
repo_sha = (repo_info.get("sha") or "").strip()
if not repo_sha:
payload = {
"error": repo_info.get("error") or "repository hash unavailable",
"repo_source": repo_info.get("source"),
}
return jsonify(payload), 503
registry_info = registered_agents.get(agent_id) or {}
hostname = (registry_info.get("hostname") or "").strip() or None
stored_hash: Optional[str] = (registry_info.get("agent_hash") or "").strip() or None
conn = None
try:
conn = _db_conn()
cur = conn.cursor()
rows = _device_rows_for_agent(cur, agent_id)
except Exception:
rows = []
finally:
if conn:
try:
conn.close()
except Exception:
pass
if rows:
hostname = rows[0].get("hostname") or hostname
for row in rows:
if row.get("matched"):
hostname = row.get("hostname") or hostname
candidate = (row.get("agent_hash") or "").strip()
if not candidate:
summary = row.get("details") or {}
try:
candidate = (summary.get("summary") or {}).get("agent_hash") or ""
except Exception:
candidate = ""
candidate = candidate.strip()
stored_hash = candidate or None
break
if stored_hash is None:
first = rows[0]
candidate = (first.get("agent_hash") or "").strip()
if not candidate:
details = first.get("details") or {}
try:
candidate = (details.get("summary") or {}).get("agent_hash") or ""
except Exception:
candidate = ""
candidate = candidate.strip()
stored_hash = candidate or None
update_available = (not stored_hash) or (stored_hash.strip() != repo_sha)
payload = {
"agent_id": agent_id,
"hostname": hostname,
"repo_hash": repo_sha,
"agent_hash": stored_hash,
"update_available": bool(update_available),
"repo_source": repo_info.get("source"),
}
if repo_info.get("cached") is not None:
payload["cached"] = bool(repo_info.get("cached"))
if repo_info.get("age_seconds") is not None:
payload["age_seconds"] = repo_info.get("age_seconds")
if repo_info.get("error"):
payload["repo_error"] = repo_info.get("error")
return jsonify(payload)
@app.route("/api/agent/agent_hash", methods=["POST"])
def api_agent_agent_hash_post():
data = request.get_json(silent=True) or {}
agent_id = (data.get("agent_id") or "").strip()
agent_hash = (data.get("agent_hash") or "").strip()
if not agent_id or not agent_hash:
return jsonify({"error": "agent_id and agent_hash required"}), 400
conn = None
hostname = None
try:
conn = _db_conn()
cur = conn.cursor()
rows = _device_rows_for_agent(cur, agent_id)
target = None
for row in rows:
if row.get("matched"):
target = row
break
if not target:
if conn:
conn.close()
return jsonify({"status": "ignored"}), 200
hostname = target.get("hostname")
details = target.get("details") or {}
summary = details.setdefault("summary", {})
summary["agent_hash"] = agent_hash
cur.execute(
"UPDATE device_details SET agent_hash=?, details=? WHERE hostname=?",
(agent_hash, json.dumps(details), hostname),
)
conn.commit()
except Exception as exc:
if conn:
try:
conn.rollback()
except Exception:
pass
_write_service_log('server', f'/api/agent/agent_hash error: {exc}')
return jsonify({"error": "internal error"}), 500
finally:
if conn:
try:
conn.close()
except Exception:
pass
normalized_hash = agent_hash
if agent_id in registered_agents:
registered_agents[agent_id]["agent_hash"] = normalized_hash
try:
for aid, rec in registered_agents.items():
if rec.get("hostname") and hostname and rec["hostname"] == hostname:
rec["agent_hash"] = normalized_hash
except Exception:
pass
return jsonify({
"status": "ok",
"agent_id": agent_id,
"hostname": hostname,
"agent_hash": agent_hash,
})
# ---------------------------------------------
# Server Time Endpoint
# ---------------------------------------------
@@ -2887,6 +3032,57 @@ def load_agents_from_db():
load_agents_from_db()
def _extract_hostname_from_agent(agent_id: str) -> Optional[str]:
try:
agent_id = (agent_id or "").strip()
if not agent_id:
return None
lower = agent_id.lower()
marker = "-agent"
idx = lower.find(marker)
if idx <= 0:
return None
return agent_id[:idx]
except Exception:
return None
def _device_rows_for_agent(cur, agent_id: str) -> List[Dict[str, Any]]:
results: List[Dict[str, Any]] = []
normalized_id = (agent_id or "").strip()
if not normalized_id:
return results
base_host = _extract_hostname_from_agent(normalized_id)
if not base_host:
return results
try:
cur.execute(
"SELECT hostname, agent_hash, details FROM device_details WHERE LOWER(hostname) = ?",
(base_host.lower(),),
)
rows = cur.fetchall()
except Exception:
return results
for hostname, agent_hash, details_json in rows or []:
try:
details = json.loads(details_json or "{}")
except Exception:
details = {}
summary = details.get("summary") or {}
summary_agent = (summary.get("agent_id") or "").strip()
matched = summary_agent.lower() == normalized_id.lower() if summary_agent else False
results.append(
{
"hostname": hostname,
"agent_hash": (agent_hash or "").strip(),
"details": details,
"summary_agent_id": summary_agent,
"matched": matched,
}
)
return results
@app.route("/api/agents")
def get_agents():
"""Return agents with collector activity indicator."""
@@ -3085,7 +3281,7 @@ def get_device_details(hostname: str):
conn = _db_conn()
cur = conn.cursor()
cur.execute(
"SELECT details, description, created_at FROM device_details WHERE hostname = ?",
"SELECT details, description, created_at, agent_hash FROM device_details WHERE hostname = ?",
(hostname,),
)
row = cur.fetchone()
@@ -3097,6 +3293,7 @@ def get_device_details(hostname: str):
details = {}
description = row[1] if len(row) > 1 else ""
created_at = int(row[2] or 0) if len(row) > 2 else 0
agent_hash = (row[3] or "").strip() if len(row) > 3 else ""
if description:
details.setdefault("summary", {})["description"] = description
# Ensure created string exists from created_at
@@ -3106,6 +3303,13 @@ def get_device_details(hostname: str):
details.setdefault('summary', {})['created'] = datetime.utcfromtimestamp(created_at).strftime('%Y-%m-%d %H:%M:%S')
except Exception:
pass
if agent_hash:
try:
details.setdefault('summary', {})
if not details['summary'].get('agent_hash'):
details['summary']['agent_hash'] = agent_hash
except Exception:
pass
return jsonify(details)
except Exception:
pass
@@ -3140,28 +3344,12 @@ def set_device_description(hostname: str):
@app.route("/api/agent/hash/<path:agent_id>", methods=["GET"])
def get_agent_hash(agent_id: str):
"""Return the last known github_repo_hash for a specific agent."""
"""Return the last known repository hash for a specific agent."""
agent_id = (agent_id or "").strip()
if not agent_id:
return jsonify({"error": "invalid agent id"}), 400
def _extract_hostname_from_agent(agent: str) -> Optional[str]:
try:
agent = (agent or "").strip()
if not agent:
return None
lower = agent.lower()
marker = "-agent"
if marker not in lower:
return None
idx = lower.index(marker)
if idx <= 0:
return None
return agent[:idx]
except Exception:
return None
# Prefer the in-memory registry (updated on every heartbeat/details post).
info = registered_agents.get(agent_id) or {}
candidate = (info.get("agent_hash") or "").strip()
@@ -3178,87 +3366,66 @@ def get_agent_hash(agent_id: str):
try:
conn = _db_conn()
cur = conn.cursor()
parsed_hostname = _extract_hostname_from_agent(agent_id)
host_candidates = []
if hostname:
host_candidates.append(hostname)
if parsed_hostname and not any(
parsed_hostname.lower() == (h or "").lower() for h in host_candidates
):
host_candidates.append(parsed_hostname)
def _load_row_for_host(host_value: str):
if not host_value:
return None, None
try:
cur.execute(
"SELECT agent_hash, details, hostname FROM device_details WHERE LOWER(hostname) = ?",
(host_value.lower(),),
)
return cur.fetchone()
except Exception:
return None
row = None
matched_hostname = hostname
for host_value in host_candidates:
fetched = _load_row_for_host(host_value)
if fetched:
row = fetched
matched_hostname = fetched[2] or matched_hostname or host_value
break
if not row:
# No hostname available or found; scan for a matching agent_id in the JSON payload.
cur.execute("SELECT hostname, agent_hash, details FROM device_details")
for host, db_hash, details_json in cur.fetchall():
rows = _device_rows_for_agent(cur, agent_id)
if rows:
if not hostname:
hostname = rows[0].get("hostname") or hostname
for row in rows:
if row.get("matched"):
normalized_hash = (row.get("agent_hash") or "").strip()
if not normalized_hash:
details = row.get("details") or {}
try:
normalized_hash = ((details.get("summary") or {}).get("agent_hash") or "").strip()
except Exception:
normalized_hash = ""
if normalized_hash:
effective_hostname = row.get("hostname") or hostname
return jsonify({
"agent_id": agent_id,
"agent_hash": normalized_hash,
"hostname": effective_hostname,
"source": "database",
})
first = rows[0]
fallback_hash = (first.get("agent_hash") or "").strip()
if not fallback_hash:
details = first.get("details") or {}
try:
data = json.loads(details_json or "{}")
fallback_hash = ((details.get("summary") or {}).get("agent_hash") or "").strip()
except Exception:
data = {}
summary = data.get("summary") or {}
summary_agent_id = (summary.get("agent_id") or "").strip()
summary_hostname = (summary.get("hostname") or "").strip()
if summary_agent_id == agent_id:
row = (db_hash, details_json, host)
matched_hostname = host or summary_hostname or matched_hostname
break
if (
not row
and parsed_hostname
and summary_hostname
and summary_hostname.lower() == parsed_hostname.lower()
):
row = (db_hash, details_json, host)
matched_hostname = host or summary_hostname
conn.close()
fallback_hash = ""
if fallback_hash:
effective_hostname = first.get("hostname") or hostname
return jsonify({
"agent_id": agent_id,
"agent_hash": fallback_hash,
"hostname": effective_hostname,
"source": "database",
})
if row:
db_hash = (row[0] or "").strip()
effective_hostname = matched_hostname or hostname or parsed_hostname
if db_hash:
return jsonify({
"agent_id": agent_id,
"agent_hash": db_hash,
"hostname": effective_hostname,
"source": "database",
})
# Hash column may be empty if only stored inside details JSON.
# As a final fallback, scan the table for any matching agent_id in case hostname inference failed.
cur.execute("SELECT hostname, agent_hash, details FROM device_details")
for host, db_hash, details_json in cur.fetchall():
try:
details = json.loads(row[1] if len(row) > 1 else "{}")
data = json.loads(details_json or "{}")
except Exception:
details = {}
summary = details.get("summary") or {}
data = {}
summary = data.get("summary") or {}
summary_agent_id = (summary.get("agent_id") or "").strip()
if summary_agent_id != agent_id:
continue
summary_hash = (summary.get("agent_hash") or "").strip()
if summary_hash:
normalized_hash = (db_hash or "").strip() or summary_hash
if normalized_hash:
effective_hostname = host or summary.get("hostname") or hostname
return jsonify({
"agent_id": agent_id,
"agent_hash": summary_hash,
"agent_hash": normalized_hash,
"hostname": effective_hostname,
"source": "database",
})
conn.close()
return jsonify({"error": "agent hash not found"}), 404
except Exception as e: