Implement server-managed agent update handshake

This commit is contained in:
2025-10-05 16:15:48 -06:00
parent 2b0269c6b0
commit 8f29de86af
4 changed files with 498 additions and 396 deletions

View File

@@ -9,7 +9,6 @@ import shutil
import string
import asyncio
from pathlib import Path
from typing import Optional
try:
import psutil # type: ignore
@@ -135,166 +134,6 @@ def _project_root():
return os.getcwd()
_AGENT_HASH_CACHE = {
"path": None,
"mtime": None,
"value": None,
"source": None,
"extra": None,
}
def _iter_hash_roots():
    """Generate the folders to search for github_repo_hash.txt, nearest first.

    Starts at the project root (plus its ``Borealis`` sub-folder when that
    exists) and walks outwards through parent folders, also queueing any
    ``Borealis`` folder found next to a visited one. Each folder is yielded
    at most once. Nothing is yielded when the project root is unknown.
    """
    start = _project_root()
    if not start:
        return

    # Some deployments keep the hash file directly under Agent/, others
    # (including the scheduled updater) write it to Agent/Borealis/, so the
    # Borealis sub-folder is queued up front when it is present.
    default_borealis = os.path.join(start, "Borealis")
    pending = [start]
    if os.path.isdir(default_borealis):
        pending.append(default_borealis)

    visited = set()
    # Hard stop: at most 12 queue entries are examined, duplicates included.
    for _ in range(12):
        if not pending:
            break
        folder = pending.pop(0)
        if not folder or folder in visited:
            continue
        visited.add(folder)
        yield folder

        up = os.path.dirname(folder)
        if up and up != folder and up not in visited:
            pending.append(up)

        # Look for an adjacent Borealis/ folder, except underneath the
        # default Borealis folder itself.
        if folder == default_borealis:
            continue
        nested = os.path.join(folder, "Borealis")
        if os.path.isdir(nested) and nested not in visited:
            pending.append(nested)
def _resolve_git_head_hash(root: str) -> Optional[str]:
git_dir = os.path.join(root, ".git")
head_path = os.path.join(git_dir, "HEAD")
if not os.path.isfile(head_path):
return None
try:
with open(head_path, "r", encoding="utf-8") as fh:
head = fh.read().strip()
except Exception:
return None
if not head:
return None
if head.startswith("ref:"):
ref = head.split(" ", 1)[1].strip() if " " in head else head.split(":", 1)[1].strip()
if not ref:
return None
ref_path = os.path.join(git_dir, *ref.split("/"))
if os.path.isfile(ref_path):
try:
with open(ref_path, "r", encoding="utf-8") as rf:
commit = rf.read().strip()
return commit or None
except Exception:
return None
packed_refs = os.path.join(git_dir, "packed-refs")
if os.path.isfile(packed_refs):
try:
with open(packed_refs, "r", encoding="utf-8") as pf:
for line in pf:
line = line.strip()
if not line or line.startswith("#") or line.startswith("^"):
continue
try:
commit, ref_name = line.split(" ", 1)
except ValueError:
continue
if ref_name.strip() == ref:
commit = commit.strip()
return commit or None
except Exception:
return None
return None
# Detached head contains the commit hash directly
commit = head.splitlines()[0].strip()
return commit or None
def _read_agent_hash():
    """Return the agent's repository hash, or None when it cannot be found.

    Lookup order, over the folders produced by ``_iter_hash_roots()``:

    1. the first non-empty ``github_repo_hash.txt``;
    2. the first ``.git/HEAD`` that ``_resolve_git_head_hash()`` can resolve.

    The most recent result is memoised in ``_AGENT_HASH_CACHE`` and reused
    while the source file's modification time (and, for git, the
    ``packed-refs`` modification time) is unchanged.

    This function never raises: any unexpected error resets the cache and
    yields None.
    """
    cache = _AGENT_HASH_CACHE
    empty_entry = {"source": None, "path": None, "mtime": None, "extra": None, "value": None}
    try:
        for root in _iter_hash_roots():
            path = os.path.join(root, 'github_repo_hash.txt')
            if not os.path.isfile(path):
                continue
            mtime = os.path.getmtime(path)
            is_cached = (
                cache.get("source") == "file"
                and cache.get("path") == path
                and cache.get("mtime") == mtime
            )
            if not is_cached:
                with open(path, 'r', encoding='utf-8') as fh:
                    text = fh.read().strip()
                cache.update(
                    {
                        "source": "file",
                        "path": path,
                        "mtime": mtime,
                        "extra": None,
                        "value": text or None,
                    }
                )
            value = cache.get("value")
            if value:
                return value
            # An empty hash file carries no information; keep searching so it
            # cannot shadow a valid file in another root or the git fallback.

        for root in _iter_hash_roots():
            git_dir = os.path.join(root, '.git')
            head_path = os.path.join(git_dir, 'HEAD')
            if not os.path.isfile(head_path):
                continue
            head_mtime = os.path.getmtime(head_path)
            packed_path = os.path.join(git_dir, 'packed-refs')
            packed_mtime = os.path.getmtime(packed_path) if os.path.isfile(packed_path) else None
            is_cached = (
                cache.get("source") == "git"
                and cache.get("path") == head_path
                and cache.get("mtime") == head_mtime
                and cache.get("extra") == packed_mtime
            )
            if not is_cached:
                commit = _resolve_git_head_hash(root)
                cache.update(
                    {
                        "source": "git",
                        "path": head_path,
                        "mtime": head_mtime,
                        "extra": packed_mtime,
                        "value": commit or None,
                    }
                )
            value = cache.get("value")
            if value:
                return value

        cache.update(empty_entry)
        return None
    except Exception:
        # Deliberate catch-all: callers invoke this without their own guard.
        # Reset the whole entry, not just "value"; otherwise the surviving
        # path/mtime keys would let a later call hit the cache and get None.
        cache.update(empty_entry)
        return None
# Removed Ansible-based audit path; Python collectors provide details directly.
@@ -938,12 +777,6 @@ def _build_details_fallback() -> dict:
'storage': collect_storage(),
'network': network,
}
try:
agent_hash_value = _read_agent_hash()
if agent_hash_value:
details.setdefault('summary', {})['agent_hash'] = agent_hash_value
except Exception:
pass
return details
@@ -995,12 +828,6 @@ class Role:
# Always post the latest available details (possibly cached)
details_to_send = self._last_details or {'summary': collect_summary(self.ctx.config)}
agent_hash_value = _read_agent_hash()
if agent_hash_value:
try:
details_to_send.setdefault('summary', {})['agent_hash'] = agent_hash_value
except Exception:
pass
get_url = (self.ctx.hooks.get('get_server_url') if isinstance(self.ctx.hooks, dict) else None) or (lambda: 'http://localhost:5000')
url = (get_url() or '').rstrip('/') + '/api/agent/details'
payload = {
@@ -1008,8 +835,6 @@ class Role:
'hostname': details_to_send.get('summary', {}).get('hostname', socket.gethostname()),
'details': details_to_send,
}
if agent_hash_value:
payload['agent_hash'] = agent_hash_value
if aiohttp is not None:
async with aiohttp.ClientSession() as session:
await session.post(url, json=payload, timeout=10)

View File

@@ -393,6 +393,151 @@ def api_agent_repo_hash():
_write_service_log('server', f'/api/agent/repo_hash error: {exc}')
return jsonify({"error": "internal error"}), 500
def _update_check_row_hash(row) -> Optional[str]:
    """Return the agent hash recorded on a device row, or None if it has none.

    The row's ``agent_hash`` value wins; ``details["summary"]["agent_hash"]``
    is consulted only when that value is blank.
    """
    candidate = row.get("agent_hash") or ""
    if not str(candidate).strip():
        try:
            summary = (row.get("details") or {}).get("summary") or {}
            candidate = summary.get("agent_hash") or ""
        except Exception:
            candidate = ""
    # str() guards against a non-string value stored in the details JSON.
    return str(candidate).strip() or None


@app.route("/api/agent/update_check", methods=["POST"])
def api_agent_update_check():
    """Tell an agent whether its recorded hash differs from the repository hash.

    Expects a JSON object with a string ``agent_id``. Responds with:

    * 400 when ``agent_id`` is missing, blank, or not a string;
    * 503 when the repository hash is unavailable;
    * 200 otherwise, with ``agent_id``, ``hostname``, ``repo_hash``,
      ``agent_hash``, ``update_available`` and ``repo_source``, plus
      ``cached``, ``age_seconds`` and ``repo_error`` when the repository
      lookup supplied them.

    The agent hash is taken from the in-memory registry, then overridden by
    the database row matching this agent id; when no hash is known yet, the
    first row for the agent's hostname is used instead.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    raw_agent_id = data.get("agent_id")
    agent_id = raw_agent_id.strip() if isinstance(raw_agent_id, str) else ""
    if not agent_id:
        return jsonify({"error": "agent_id required"}), 400

    repo_info = _refresh_default_repo_hash()
    repo_sha = (repo_info.get("sha") or "").strip()
    if not repo_sha:
        payload = {
            "error": repo_info.get("error") or "repository hash unavailable",
            "repo_source": repo_info.get("source"),
        }
        return jsonify(payload), 503

    registry_info = registered_agents.get(agent_id) or {}
    hostname = (registry_info.get("hostname") or "").strip() or None
    stored_hash: Optional[str] = (registry_info.get("agent_hash") or "").strip() or None

    conn = None
    try:
        conn = _db_conn()
        cur = conn.cursor()
        rows = _device_rows_for_agent(cur, agent_id)
    except Exception as exc:
        # Best effort: fall back to the registry data, but leave a trace.
        _write_service_log('server', f'/api/agent/update_check error: {exc}')
        rows = []
    finally:
        if conn:
            try:
                conn.close()
            except Exception:
                pass

    if rows:
        hostname = rows[0].get("hostname") or hostname
        matched_row = next((row for row in rows if row.get("matched")), None)
        if matched_row is not None:
            hostname = matched_row.get("hostname") or hostname
            # The matched database row overrides the registry value.
            stored_hash = _update_check_row_hash(matched_row)
        if stored_hash is None:
            stored_hash = _update_check_row_hash(rows[0])

    update_available = (not stored_hash) or (stored_hash.strip() != repo_sha)

    payload = {
        "agent_id": agent_id,
        "hostname": hostname,
        "repo_hash": repo_sha,
        "agent_hash": stored_hash,
        "update_available": bool(update_available),
        "repo_source": repo_info.get("source"),
    }
    if repo_info.get("cached") is not None:
        payload["cached"] = bool(repo_info.get("cached"))
    if repo_info.get("age_seconds") is not None:
        payload["age_seconds"] = repo_info.get("age_seconds")
    if repo_info.get("error"):
        payload["repo_error"] = repo_info.get("error")
    return jsonify(payload)
@app.route("/api/agent/agent_hash", methods=["POST"])
def api_agent_agent_hash_post():
    """Record the hash an agent reports for itself.

    Expects a JSON object with string ``agent_id`` and ``agent_hash``.
    Responds with:

    * 400 when either field is missing, blank, or not a string;
    * 200 ``{"status": "ignored"}`` when no device row matches the agent id;
    * 500 when the database update fails (the error is logged);
    * 200 ``{"status": "ok", ...}`` after storing the hash in the matching
      ``device_details`` row (both the ``agent_hash`` column and
      ``details["summary"]["agent_hash"]``) and in the in-memory registry.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    raw_agent_id = data.get("agent_id")
    raw_agent_hash = data.get("agent_hash")
    agent_id = raw_agent_id.strip() if isinstance(raw_agent_id, str) else ""
    agent_hash = raw_agent_hash.strip() if isinstance(raw_agent_hash, str) else ""
    if not agent_id or not agent_hash:
        return jsonify({"error": "agent_id and agent_hash required"}), 400

    conn = None
    hostname = None
    try:
        conn = _db_conn()
        cur = conn.cursor()
        rows = _device_rows_for_agent(cur, agent_id)
        target = next((row for row in rows if row.get("matched")), None)
        if not target:
            # The finally clause below closes the connection.
            return jsonify({"status": "ignored"}), 200

        hostname = target.get("hostname")
        details = target.get("details") or {}
        details.setdefault("summary", {})["agent_hash"] = agent_hash
        cur.execute(
            "UPDATE device_details SET agent_hash=?, details=? WHERE hostname=?",
            (agent_hash, json.dumps(details), hostname),
        )
        conn.commit()
    except Exception as exc:
        if conn:
            try:
                conn.rollback()
            except Exception:
                pass
        _write_service_log('server', f'/api/agent/agent_hash error: {exc}')
        return jsonify({"error": "internal error"}), 500
    finally:
        if conn:
            try:
                conn.close()
            except Exception:
                pass

    if agent_id in registered_agents:
        registered_agents[agent_id]["agent_hash"] = agent_hash
    if hostname:
        try:
            # Also refresh every registry record that shares this hostname.
            for rec in registered_agents.values():
                if rec.get("hostname") and rec["hostname"] == hostname:
                    rec["agent_hash"] = agent_hash
        except Exception:
            # Best effort only: the database already holds the new hash.
            pass

    return jsonify({
        "status": "ok",
        "agent_id": agent_id,
        "hostname": hostname,
        "agent_hash": agent_hash,
    })
# ---------------------------------------------
# Server Time Endpoint
# ---------------------------------------------
@@ -2887,6 +3032,57 @@ def load_agents_from_db():
load_agents_from_db()
def _extract_hostname_from_agent(agent_id: str) -> Optional[str]:
try:
agent_id = (agent_id or "").strip()
if not agent_id:
return None
lower = agent_id.lower()
marker = "-agent"
idx = lower.find(marker)
if idx <= 0:
return None
return agent_id[:idx]
except Exception:
return None
def _device_rows_for_agent(cur, agent_id: str) -> List[Dict[str, Any]]:
    """Load the ``device_details`` rows whose hostname matches an agent id.

    The hostname is derived from ``agent_id`` with
    ``_extract_hostname_from_agent`` and compared case-insensitively.

    Args:
        cur: Database cursor used to run the query.
        agent_id: Agent identifier to look up.

    Returns:
        One dict per row with the keys ``hostname``, ``agent_hash`` (stripped
        string), ``details`` (decoded JSON object, ``{}`` when absent or
        malformed), ``summary_agent_id`` and ``matched`` (True when the row's
        ``details["summary"]["agent_id"]`` equals ``agent_id`` ignoring case).
        The list is empty when the id is blank, no hostname can be derived,
        or the query fails.
    """
    results: List[Dict[str, Any]] = []
    normalized_id = agent_id.strip() if isinstance(agent_id, str) else ""
    if not normalized_id:
        return results
    base_host = _extract_hostname_from_agent(normalized_id)
    if not base_host:
        return results
    try:
        cur.execute(
            "SELECT hostname, agent_hash, details FROM device_details WHERE LOWER(hostname) = ?",
            (base_host.lower(),),
        )
        rows = cur.fetchall()
    except Exception:
        # Best effort: a failed lookup is reported as "no rows".
        return results

    wanted_id = normalized_id.lower()
    for hostname, agent_hash, details_json in rows or []:
        try:
            details = json.loads(details_json or "{}")
        except (TypeError, ValueError, RecursionError):
            details = {}
        # Valid JSON is not necessarily an object (null, list, string, ...).
        if not isinstance(details, dict):
            details = {}
        summary = details.get("summary")
        if not isinstance(summary, dict):
            summary = {}
        raw_summary_agent = summary.get("agent_id")
        summary_agent = raw_summary_agent.strip() if isinstance(raw_summary_agent, str) else ""
        results.append(
            {
                "hostname": hostname,
                "agent_hash": agent_hash.strip() if isinstance(agent_hash, str) else "",
                "details": details,
                "summary_agent_id": summary_agent,
                "matched": bool(summary_agent) and summary_agent.lower() == wanted_id,
            }
        )
    return results
@app.route("/api/agents")
def get_agents():
"""Return agents with collector activity indicator."""
@@ -3085,7 +3281,7 @@ def get_device_details(hostname: str):
conn = _db_conn()
cur = conn.cursor()
cur.execute(
"SELECT details, description, created_at FROM device_details WHERE hostname = ?",
"SELECT details, description, created_at, agent_hash FROM device_details WHERE hostname = ?",
(hostname,),
)
row = cur.fetchone()
@@ -3097,6 +3293,7 @@ def get_device_details(hostname: str):
details = {}
description = row[1] if len(row) > 1 else ""
created_at = int(row[2] or 0) if len(row) > 2 else 0
agent_hash = (row[3] or "").strip() if len(row) > 3 else ""
if description:
details.setdefault("summary", {})["description"] = description
# Ensure created string exists from created_at
@@ -3106,6 +3303,13 @@ def get_device_details(hostname: str):
details.setdefault('summary', {})['created'] = datetime.utcfromtimestamp(created_at).strftime('%Y-%m-%d %H:%M:%S')
except Exception:
pass
if agent_hash:
try:
details.setdefault('summary', {})
if not details['summary'].get('agent_hash'):
details['summary']['agent_hash'] = agent_hash
except Exception:
pass
return jsonify(details)
except Exception:
pass
@@ -3140,28 +3344,12 @@ def set_device_description(hostname: str):
@app.route("/api/agent/hash/<path:agent_id>", methods=["GET"])
def get_agent_hash(agent_id: str):
"""Return the last known github_repo_hash for a specific agent."""
"""Return the last known repository hash for a specific agent."""
agent_id = (agent_id or "").strip()
if not agent_id:
return jsonify({"error": "invalid agent id"}), 400
def _extract_hostname_from_agent(agent: str) -> Optional[str]:
try:
agent = (agent or "").strip()
if not agent:
return None
lower = agent.lower()
marker = "-agent"
if marker not in lower:
return None
idx = lower.index(marker)
if idx <= 0:
return None
return agent[:idx]
except Exception:
return None
# Prefer the in-memory registry (updated on every heartbeat/details post).
info = registered_agents.get(agent_id) or {}
candidate = (info.get("agent_hash") or "").strip()
@@ -3178,87 +3366,66 @@ def get_agent_hash(agent_id: str):
try:
conn = _db_conn()
cur = conn.cursor()
parsed_hostname = _extract_hostname_from_agent(agent_id)
host_candidates = []
if hostname:
host_candidates.append(hostname)
if parsed_hostname and not any(
parsed_hostname.lower() == (h or "").lower() for h in host_candidates
):
host_candidates.append(parsed_hostname)
def _load_row_for_host(host_value: str):
if not host_value:
return None, None
try:
cur.execute(
"SELECT agent_hash, details, hostname FROM device_details WHERE LOWER(hostname) = ?",
(host_value.lower(),),
)
return cur.fetchone()
except Exception:
return None
row = None
matched_hostname = hostname
for host_value in host_candidates:
fetched = _load_row_for_host(host_value)
if fetched:
row = fetched
matched_hostname = fetched[2] or matched_hostname or host_value
break
if not row:
# No hostname available or found; scan for a matching agent_id in the JSON payload.
cur.execute("SELECT hostname, agent_hash, details FROM device_details")
for host, db_hash, details_json in cur.fetchall():
rows = _device_rows_for_agent(cur, agent_id)
if rows:
if not hostname:
hostname = rows[0].get("hostname") or hostname
for row in rows:
if row.get("matched"):
normalized_hash = (row.get("agent_hash") or "").strip()
if not normalized_hash:
details = row.get("details") or {}
try:
normalized_hash = ((details.get("summary") or {}).get("agent_hash") or "").strip()
except Exception:
normalized_hash = ""
if normalized_hash:
effective_hostname = row.get("hostname") or hostname
return jsonify({
"agent_id": agent_id,
"agent_hash": normalized_hash,
"hostname": effective_hostname,
"source": "database",
})
first = rows[0]
fallback_hash = (first.get("agent_hash") or "").strip()
if not fallback_hash:
details = first.get("details") or {}
try:
data = json.loads(details_json or "{}")
fallback_hash = ((details.get("summary") or {}).get("agent_hash") or "").strip()
except Exception:
data = {}
summary = data.get("summary") or {}
summary_agent_id = (summary.get("agent_id") or "").strip()
summary_hostname = (summary.get("hostname") or "").strip()
if summary_agent_id == agent_id:
row = (db_hash, details_json, host)
matched_hostname = host or summary_hostname or matched_hostname
break
if (
not row
and parsed_hostname
and summary_hostname
and summary_hostname.lower() == parsed_hostname.lower()
):
row = (db_hash, details_json, host)
matched_hostname = host or summary_hostname
conn.close()
fallback_hash = ""
if fallback_hash:
effective_hostname = first.get("hostname") or hostname
return jsonify({
"agent_id": agent_id,
"agent_hash": fallback_hash,
"hostname": effective_hostname,
"source": "database",
})
if row:
db_hash = (row[0] or "").strip()
effective_hostname = matched_hostname or hostname or parsed_hostname
if db_hash:
return jsonify({
"agent_id": agent_id,
"agent_hash": db_hash,
"hostname": effective_hostname,
"source": "database",
})
# Hash column may be empty if only stored inside details JSON.
# As a final fallback, scan the table for any matching agent_id in case hostname inference failed.
cur.execute("SELECT hostname, agent_hash, details FROM device_details")
for host, db_hash, details_json in cur.fetchall():
try:
details = json.loads(row[1] if len(row) > 1 else "{}")
data = json.loads(details_json or "{}")
except Exception:
details = {}
summary = details.get("summary") or {}
data = {}
summary = data.get("summary") or {}
summary_agent_id = (summary.get("agent_id") or "").strip()
if summary_agent_id != agent_id:
continue
summary_hash = (summary.get("agent_hash") or "").strip()
if summary_hash:
normalized_hash = (db_hash or "").strip() or summary_hash
if normalized_hash:
effective_hostname = host or summary.get("hostname") or hostname
return jsonify({
"agent_id": agent_id,
"agent_hash": summary_hash,
"agent_hash": normalized_hash,
"hostname": effective_hostname,
"source": "database",
})
conn.close()
return jsonify({"error": "agent hash not found"}), 404
except Exception as e: