mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-12-17 17:55:48 -07:00
Additional Changes
This commit is contained in:
@@ -818,6 +818,7 @@ class AgentHttpClient:
|
|||||||
self.access_expires_at: Optional[int] = None
|
self.access_expires_at: Optional[int] = None
|
||||||
self._auth_lock = threading.RLock()
|
self._auth_lock = threading.RLock()
|
||||||
self._active_installer_code: Optional[str] = None
|
self._active_installer_code: Optional[str] = None
|
||||||
|
self._cached_ssl_context: Optional[ssl.SSLContext] = None
|
||||||
self.refresh_base_url()
|
self.refresh_base_url()
|
||||||
self._configure_verify()
|
self._configure_verify()
|
||||||
self._reload_tokens_from_disk()
|
self._reload_tokens_from_disk()
|
||||||
@@ -852,11 +853,17 @@ class AgentHttpClient:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def _reload_tokens_from_disk(self) -> None:
|
def _reload_tokens_from_disk(self) -> None:
|
||||||
guid = self.key_store.load_guid()
|
raw_guid = self.key_store.load_guid()
|
||||||
|
normalized_guid = _normalize_agent_guid(raw_guid) if raw_guid else ''
|
||||||
access_token = self.key_store.load_access_token()
|
access_token = self.key_store.load_access_token()
|
||||||
refresh_token = self.key_store.load_refresh_token()
|
refresh_token = self.key_store.load_refresh_token()
|
||||||
access_expiry = self.key_store.get_access_expiry()
|
access_expiry = self.key_store.get_access_expiry()
|
||||||
self.guid = guid if guid else None
|
if normalized_guid and normalized_guid != (raw_guid or ""):
|
||||||
|
try:
|
||||||
|
self.key_store.save_guid(normalized_guid)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self.guid = normalized_guid or None
|
||||||
self.access_token = access_token if access_token else None
|
self.access_token = access_token if access_token else None
|
||||||
self.refresh_token = refresh_token if refresh_token else None
|
self.refresh_token = refresh_token if refresh_token else None
|
||||||
self.access_expires_at = access_expiry if access_expiry else None
|
self.access_expires_at = access_expiry if access_expiry else None
|
||||||
@@ -930,6 +937,7 @@ class AgentHttpClient:
|
|||||||
|
|
||||||
context = None
|
context = None
|
||||||
bundle_summary = {"count": None, "fingerprint": None, "layered_default": None}
|
bundle_summary = {"count": None, "fingerprint": None, "layered_default": None}
|
||||||
|
context = None
|
||||||
if isinstance(verify, str) and os.path.isfile(verify):
|
if isinstance(verify, str) and os.path.isfile(verify):
|
||||||
bundle_count, bundle_fp, layered_default = self.key_store.summarize_server_certificate()
|
bundle_count, bundle_fp, layered_default = self.key_store.summarize_server_certificate()
|
||||||
bundle_summary = {
|
bundle_summary = {
|
||||||
@@ -939,6 +947,7 @@ class AgentHttpClient:
|
|||||||
}
|
}
|
||||||
context = self.key_store.build_ssl_context()
|
context = self.key_store.build_ssl_context()
|
||||||
if context is not None:
|
if context is not None:
|
||||||
|
self._cached_ssl_context = context
|
||||||
if bundle_summary["layered_default"] is None:
|
if bundle_summary["layered_default"] is None:
|
||||||
bundle_summary["layered_default"] = getattr(
|
bundle_summary["layered_default"] = getattr(
|
||||||
context, "_borealis_layered_default", None
|
context, "_borealis_layered_default", None
|
||||||
@@ -975,6 +984,7 @@ class AgentHttpClient:
|
|||||||
# Fall back to boolean verification flags when we either do not
|
# Fall back to boolean verification flags when we either do not
|
||||||
# have a pinned certificate bundle or failed to build a dedicated
|
# have a pinned certificate bundle or failed to build a dedicated
|
||||||
# context for it.
|
# context for it.
|
||||||
|
self._cached_ssl_context = None
|
||||||
verify_flag = False if verify is False else True
|
verify_flag = False if verify is False else True
|
||||||
_set_attr(engine, "ssl_context", None)
|
_set_attr(engine, "ssl_context", None)
|
||||||
_set_attr(engine, "ssl_verify", verify_flag)
|
_set_attr(engine, "ssl_verify", verify_flag)
|
||||||
@@ -994,6 +1004,34 @@ class AgentHttpClient:
|
|||||||
)
|
)
|
||||||
_log_exception_trace("configure_socketio")
|
_log_exception_trace("configure_socketio")
|
||||||
|
|
||||||
|
def socketio_ssl_params(self) -> Dict[str, Any]:
|
||||||
|
verify = getattr(self.session, "verify", True)
|
||||||
|
if isinstance(verify, str) and os.path.isfile(verify):
|
||||||
|
context = self._cached_ssl_context
|
||||||
|
if context is None:
|
||||||
|
context = self.key_store.build_ssl_context()
|
||||||
|
if context is not None:
|
||||||
|
self._cached_ssl_context = context
|
||||||
|
if context is not None:
|
||||||
|
return {"ssl": context}
|
||||||
|
try:
|
||||||
|
fallback = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
|
||||||
|
fallback.load_verify_locations(cafile=verify)
|
||||||
|
self._cached_ssl_context = fallback
|
||||||
|
return {"ssl": fallback}
|
||||||
|
except Exception as exc:
|
||||||
|
self._cached_ssl_context = None
|
||||||
|
_log_agent(
|
||||||
|
f"SocketIO TLS fallback context build failed: {exc}; disabling verification",
|
||||||
|
fname="agent.error.log",
|
||||||
|
)
|
||||||
|
return {"ssl": False}
|
||||||
|
if verify is False:
|
||||||
|
self._cached_ssl_context = None
|
||||||
|
return {"ssl": False}
|
||||||
|
self._cached_ssl_context = None
|
||||||
|
return {}
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Enrollment & token management
|
# Enrollment & token management
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -1221,7 +1259,7 @@ class AgentHttpClient:
|
|||||||
self.store_server_signing_key(signing_key)
|
self.store_server_signing_key(signing_key)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
_log_agent(f'Unable to persist signing key from enrollment approval: {exc}', fname='agent.error.log')
|
_log_agent(f'Unable to persist signing key from enrollment approval: {exc}', fname='agent.error.log')
|
||||||
guid = payload.get("guid")
|
guid = _normalize_agent_guid(payload.get("guid"))
|
||||||
access_token = payload.get("access_token")
|
access_token = payload.get("access_token")
|
||||||
refresh_token = payload.get("refresh_token")
|
refresh_token = payload.get("refresh_token")
|
||||||
expires_in = int(payload.get("expires_in") or 900)
|
expires_in = int(payload.get("expires_in") or 900)
|
||||||
@@ -1233,7 +1271,7 @@ class AgentHttpClient:
|
|||||||
f"expires_in={expires_in}",
|
f"expires_in={expires_in}",
|
||||||
fname="agent.log",
|
fname="agent.log",
|
||||||
)
|
)
|
||||||
self.guid = str(guid).strip()
|
self.guid = guid
|
||||||
self.access_token = access_token.strip()
|
self.access_token = access_token.strip()
|
||||||
self.refresh_token = refresh_token.strip()
|
self.refresh_token = refresh_token.strip()
|
||||||
expiry = int(time.time()) + max(expires_in - 5, 0)
|
expiry = int(time.time()) + max(expires_in - 5, 0)
|
||||||
@@ -2781,8 +2819,16 @@ async def connect_loop():
|
|||||||
headers = client.auth_headers()
|
headers = client.auth_headers()
|
||||||
header_summary = _summarize_headers(headers)
|
header_summary = _summarize_headers(headers)
|
||||||
verify_value = getattr(client.session, 'verify', None)
|
verify_value = getattr(client.session, 'verify', None)
|
||||||
|
ssl_kwargs = client.socketio_ssl_params()
|
||||||
|
ssl_summary: Dict[str, Any] = {}
|
||||||
|
for key, value in ssl_kwargs.items():
|
||||||
|
if isinstance(value, ssl.SSLContext):
|
||||||
|
ssl_summary[key] = "SSLContext"
|
||||||
|
else:
|
||||||
|
ssl_summary[key] = value
|
||||||
_log_agent(
|
_log_agent(
|
||||||
f"connect_loop attempt={attempt} dialing websocket url={url} transports=['websocket'] verify={verify_value!r} headers={header_summary}",
|
f"connect_loop attempt={attempt} dialing websocket url={url} transports=['websocket'] "
|
||||||
|
f"verify={verify_value!r} headers={header_summary} ssl={ssl_summary or '{}'}",
|
||||||
fname='agent.log',
|
fname='agent.log',
|
||||||
)
|
)
|
||||||
print(f"[INFO] Connecting Agent to {url}...")
|
print(f"[INFO] Connecting Agent to {url}...")
|
||||||
@@ -2790,6 +2836,7 @@ async def connect_loop():
|
|||||||
url,
|
url,
|
||||||
transports=['websocket'],
|
transports=['websocket'],
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
**ssl_kwargs,
|
||||||
)
|
)
|
||||||
_log_agent(
|
_log_agent(
|
||||||
f'connect_loop attempt={attempt} sio.connect completed successfully',
|
f'connect_loop attempt={attempt} sio.connect completed successfully',
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from flask import Blueprint, jsonify, request, g
|
|||||||
|
|
||||||
from Modules.auth.device_auth import DeviceAuthManager, require_device_auth
|
from Modules.auth.device_auth import DeviceAuthManager, require_device_auth
|
||||||
from Modules.crypto.signing import ScriptSigner
|
from Modules.crypto.signing import ScriptSigner
|
||||||
|
from Modules.guid_utils import normalize_guid
|
||||||
|
|
||||||
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||||
|
|
||||||
@@ -102,13 +103,36 @@ def register(
|
|||||||
if not updates:
|
if not updates:
|
||||||
return 0
|
return 0
|
||||||
columns = ", ".join(f"{col} = ?" for col in updates.keys())
|
columns = ", ".join(f"{col} = ?" for col in updates.keys())
|
||||||
params = list(updates.values())
|
values = list(updates.values())
|
||||||
params.append(ctx.guid)
|
normalized_guid = normalize_guid(ctx.guid)
|
||||||
|
selected_guid: Optional[str] = None
|
||||||
|
if normalized_guid:
|
||||||
|
cur.execute(
|
||||||
|
"SELECT guid FROM devices WHERE UPPER(guid) = ?",
|
||||||
|
(normalized_guid,),
|
||||||
|
)
|
||||||
|
rows = cur.fetchall()
|
||||||
|
for (stored_guid,) in rows or []:
|
||||||
|
if stored_guid == ctx.guid:
|
||||||
|
selected_guid = stored_guid
|
||||||
|
break
|
||||||
|
if not selected_guid and rows:
|
||||||
|
selected_guid = rows[0][0]
|
||||||
|
target_guid = selected_guid or ctx.guid
|
||||||
cur.execute(
|
cur.execute(
|
||||||
f"UPDATE devices SET {columns} WHERE guid = ?",
|
f"UPDATE devices SET {columns} WHERE guid = ?",
|
||||||
params,
|
values + [target_guid],
|
||||||
)
|
)
|
||||||
return cur.rowcount
|
updated = cur.rowcount
|
||||||
|
if updated > 0 and normalized_guid and target_guid != normalized_guid:
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"UPDATE devices SET guid = ? WHERE guid = ?",
|
||||||
|
(normalized_guid, target_guid),
|
||||||
|
)
|
||||||
|
except sqlite3.IntegrityError:
|
||||||
|
pass
|
||||||
|
return updated
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rowcount = _apply_updates()
|
rowcount = _apply_updates()
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from flask import g, jsonify, request
|
|||||||
|
|
||||||
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
|
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
|
||||||
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||||
|
from Modules.guid_utils import normalize_guid
|
||||||
|
|
||||||
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
AGENT_CONTEXT_HEADER = "X-Borealis-Agent-Context"
|
||||||
|
|
||||||
@@ -87,7 +88,8 @@ class DeviceAuthManager:
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise DeviceAuthError("invalid_token")
|
raise DeviceAuthError("invalid_token")
|
||||||
|
|
||||||
guid = str(claims.get("guid") or "").strip()
|
raw_guid = str(claims.get("guid") or "").strip()
|
||||||
|
guid = normalize_guid(raw_guid)
|
||||||
fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip()
|
fingerprint = str(claims.get("ssl_key_fingerprint") or "").lower().strip()
|
||||||
token_version = int(claims.get("token_version") or 0)
|
token_version = int(claims.get("token_version") or 0)
|
||||||
if not guid or not fingerprint or token_version <= 0:
|
if not guid or not fingerprint or token_version <= 0:
|
||||||
@@ -110,11 +112,19 @@ class DeviceAuthManager:
|
|||||||
"""
|
"""
|
||||||
SELECT guid, ssl_key_fingerprint, token_version, status
|
SELECT guid, ssl_key_fingerprint, token_version, status
|
||||||
FROM devices
|
FROM devices
|
||||||
WHERE guid = ?
|
WHERE UPPER(guid) = ?
|
||||||
""",
|
""",
|
||||||
(guid,),
|
(guid,),
|
||||||
)
|
)
|
||||||
row = cur.fetchone()
|
rows = cur.fetchall()
|
||||||
|
row = None
|
||||||
|
for candidate in rows or []:
|
||||||
|
candidate_guid = normalize_guid(candidate[0])
|
||||||
|
if candidate_guid == guid:
|
||||||
|
row = candidate
|
||||||
|
break
|
||||||
|
if row is None and rows:
|
||||||
|
row = rows[0]
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
row = self._recover_device_record(
|
row = self._recover_device_record(
|
||||||
@@ -125,8 +135,9 @@ class DeviceAuthManager:
|
|||||||
raise DeviceAuthError("device_not_found", status_code=403)
|
raise DeviceAuthError("device_not_found", status_code=403)
|
||||||
|
|
||||||
db_guid, db_fp, db_token_version, status = row
|
db_guid, db_fp, db_token_version, status = row
|
||||||
|
db_guid_normalized = normalize_guid(db_guid)
|
||||||
|
|
||||||
if str(db_guid or "").lower() != guid.lower():
|
if not db_guid_normalized or db_guid_normalized != guid:
|
||||||
raise DeviceAuthError("device_guid_mismatch", status_code=403)
|
raise DeviceAuthError("device_guid_mismatch", status_code=403)
|
||||||
|
|
||||||
db_fp = (db_fp or "").lower().strip()
|
db_fp = (db_fp or "").lower().strip()
|
||||||
@@ -182,7 +193,7 @@ class DeviceAuthManager:
|
|||||||
) -> Optional[tuple]:
|
) -> Optional[tuple]:
|
||||||
"""Attempt to recreate a missing device row for an authenticated token."""
|
"""Attempt to recreate a missing device row for an authenticated token."""
|
||||||
|
|
||||||
guid = (guid or "").strip()
|
guid = normalize_guid(guid)
|
||||||
fingerprint = (fingerprint or "").strip()
|
fingerprint = (fingerprint or "").strip()
|
||||||
if not guid or not fingerprint:
|
if not guid or not fingerprint:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from flask import Blueprint, jsonify, request
|
|||||||
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||||
from Modules.crypto import keys as crypto_keys
|
from Modules.crypto import keys as crypto_keys
|
||||||
from Modules.enrollment.nonce_store import NonceCache
|
from Modules.enrollment.nonce_store import NonceCache
|
||||||
|
from Modules.guid_utils import normalize_guid
|
||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
|
||||||
|
|
||||||
@@ -141,11 +142,11 @@ def register(
|
|||||||
if use_count < max_uses:
|
if use_count < max_uses:
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
guid = str(record.get("used_by_guid") or "").strip()
|
guid = normalize_guid(record.get("used_by_guid"))
|
||||||
if not guid:
|
if not guid:
|
||||||
return False, None
|
return False, None
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"SELECT ssl_key_fingerprint FROM devices WHERE guid = ?",
|
"SELECT ssl_key_fingerprint FROM devices WHERE UPPER(guid) = ?",
|
||||||
(guid,),
|
(guid,),
|
||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
@@ -159,31 +160,36 @@ def register(
|
|||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
|
def _normalize_host(hostname: str, guid: str, cur: sqlite3.Cursor) -> str:
|
||||||
base = (hostname or "").strip() or guid
|
guid_norm = normalize_guid(guid)
|
||||||
|
base = (hostname or "").strip() or guid_norm
|
||||||
base = base[:253]
|
base = base[:253]
|
||||||
candidate = base
|
candidate = base
|
||||||
suffix = 1
|
suffix = 1
|
||||||
while True:
|
while True:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"SELECT guid FROM devices WHERE hostname = ? AND guid != ?",
|
"SELECT guid FROM devices WHERE hostname = ?",
|
||||||
(candidate, guid),
|
(candidate,),
|
||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
if not row:
|
if not row:
|
||||||
return candidate
|
return candidate
|
||||||
|
existing_guid = normalize_guid(row[0])
|
||||||
|
if existing_guid == guid_norm:
|
||||||
|
return candidate
|
||||||
candidate = f"{base}-{suffix}"
|
candidate = f"{base}-{suffix}"
|
||||||
suffix += 1
|
suffix += 1
|
||||||
if suffix > 50:
|
if suffix > 50:
|
||||||
return f"{guid}"
|
return guid_norm
|
||||||
|
|
||||||
def _store_device_key(cur: sqlite3.Cursor, guid: str, fingerprint: str) -> None:
|
def _store_device_key(cur: sqlite3.Cursor, guid: str, fingerprint: str) -> None:
|
||||||
|
guid_norm = normalize_guid(guid)
|
||||||
added_at = _iso(_now())
|
added_at = _iso(_now())
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
|
INSERT OR IGNORE INTO device_keys (id, guid, ssl_key_fingerprint, added_at)
|
||||||
VALUES (?, ?, ?, ?)
|
VALUES (?, ?, ?, ?)
|
||||||
""",
|
""",
|
||||||
(str(uuid.uuid4()), guid, fingerprint, added_at),
|
(str(uuid.uuid4()), guid_norm, fingerprint, added_at),
|
||||||
)
|
)
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
@@ -193,17 +199,18 @@ def register(
|
|||||||
AND ssl_key_fingerprint != ?
|
AND ssl_key_fingerprint != ?
|
||||||
AND retired_at IS NULL
|
AND retired_at IS NULL
|
||||||
""",
|
""",
|
||||||
(_iso(_now()), guid, fingerprint),
|
(_iso(_now()), guid_norm, fingerprint),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
|
def _ensure_device_record(cur: sqlite3.Cursor, guid: str, hostname: str, fingerprint: str) -> Dict[str, Any]:
|
||||||
|
guid_norm = normalize_guid(guid)
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
|
SELECT guid, hostname, token_version, status, ssl_key_fingerprint, key_added_at
|
||||||
FROM devices
|
FROM devices
|
||||||
WHERE guid = ?
|
WHERE UPPER(guid) = ?
|
||||||
""",
|
""",
|
||||||
(guid,),
|
(guid_norm,),
|
||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
if row:
|
if row:
|
||||||
@@ -216,12 +223,13 @@ def register(
|
|||||||
"key_added_at",
|
"key_added_at",
|
||||||
]
|
]
|
||||||
record = dict(zip(keys, row))
|
record = dict(zip(keys, row))
|
||||||
|
record["guid"] = normalize_guid(record.get("guid"))
|
||||||
stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower()
|
stored_fp = (record.get("ssl_key_fingerprint") or "").strip().lower()
|
||||||
new_fp = (fingerprint or "").strip().lower()
|
new_fp = (fingerprint or "").strip().lower()
|
||||||
if not stored_fp and new_fp:
|
if not stored_fp and new_fp:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
|
"UPDATE devices SET ssl_key_fingerprint = ?, key_added_at = ? WHERE guid = ?",
|
||||||
(fingerprint, _iso(_now()), guid),
|
(fingerprint, _iso(_now()), record["guid"]),
|
||||||
)
|
)
|
||||||
record["ssl_key_fingerprint"] = fingerprint
|
record["ssl_key_fingerprint"] = fingerprint
|
||||||
elif new_fp and stored_fp != new_fp:
|
elif new_fp and stored_fp != new_fp:
|
||||||
@@ -240,7 +248,7 @@ def register(
|
|||||||
status = 'active'
|
status = 'active'
|
||||||
WHERE guid = ?
|
WHERE guid = ?
|
||||||
""",
|
""",
|
||||||
(fingerprint, now_iso, new_version, guid),
|
(fingerprint, now_iso, new_version, record["guid"]),
|
||||||
)
|
)
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
@@ -249,7 +257,7 @@ def register(
|
|||||||
WHERE guid = ?
|
WHERE guid = ?
|
||||||
AND revoked_at IS NULL
|
AND revoked_at IS NULL
|
||||||
""",
|
""",
|
||||||
(now_iso, guid),
|
(now_iso, record["guid"]),
|
||||||
)
|
)
|
||||||
record["ssl_key_fingerprint"] = fingerprint
|
record["ssl_key_fingerprint"] = fingerprint
|
||||||
record["token_version"] = new_version
|
record["token_version"] = new_version
|
||||||
@@ -257,7 +265,7 @@ def register(
|
|||||||
record["key_added_at"] = now_iso
|
record["key_added_at"] = now_iso
|
||||||
return record
|
return record
|
||||||
|
|
||||||
resolved_hostname = _normalize_host(hostname, guid, cur)
|
resolved_hostname = _normalize_host(hostname, guid_norm, cur)
|
||||||
created_at = int(time.time())
|
created_at = int(time.time())
|
||||||
key_added_at = _iso(_now())
|
key_added_at = _iso(_now())
|
||||||
cur.execute(
|
cur.execute(
|
||||||
@@ -269,7 +277,7 @@ def register(
|
|||||||
VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
|
VALUES (?, ?, ?, ?, ?, 1, 'active', ?)
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
guid,
|
guid_norm,
|
||||||
resolved_hostname,
|
resolved_hostname,
|
||||||
created_at,
|
created_at,
|
||||||
created_at,
|
created_at,
|
||||||
@@ -278,7 +286,7 @@ def register(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"guid": guid,
|
"guid": guid_norm,
|
||||||
"hostname": resolved_hostname,
|
"hostname": resolved_hostname,
|
||||||
"token_version": 1,
|
"token_version": 1,
|
||||||
"status": "active",
|
"status": "active",
|
||||||
@@ -620,7 +628,7 @@ def register(
|
|||||||
return jsonify({"error": "proof_replayed"}), 409
|
return jsonify({"error": "proof_replayed"}), 409
|
||||||
|
|
||||||
# Finalize enrollment
|
# Finalize enrollment
|
||||||
effective_guid = guid or str(uuid.uuid4())
|
effective_guid = normalize_guid(guid) if guid else normalize_guid(str(uuid.uuid4()))
|
||||||
now_iso = _iso(_now())
|
now_iso = _iso(_now())
|
||||||
|
|
||||||
device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint)
|
device_record = _ensure_device_record(cur, effective_guid, hostname_claimed, fingerprint)
|
||||||
|
|||||||
26
Data/Server/Modules/guid_utils.py
Normal file
26
Data/Server/Modules/guid_utils.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import string
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_guid(value: Optional[str]) -> str:
|
||||||
|
"""
|
||||||
|
Canonicalize GUID strings so the server treats different casings/formats uniformly.
|
||||||
|
"""
|
||||||
|
candidate = (value or "").strip()
|
||||||
|
if not candidate:
|
||||||
|
return ""
|
||||||
|
candidate = candidate.strip("{}")
|
||||||
|
try:
|
||||||
|
return str(uuid.UUID(candidate)).upper()
|
||||||
|
except Exception:
|
||||||
|
cleaned = "".join(ch for ch in candidate if ch in string.hexdigits or ch == "-")
|
||||||
|
cleaned = cleaned.strip("-")
|
||||||
|
if cleaned:
|
||||||
|
try:
|
||||||
|
return str(uuid.UUID(cleaned)).upper()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return candidate.upper()
|
||||||
@@ -102,6 +102,7 @@ from Modules.auth.device_auth import DeviceAuthManager, require_device_auth
|
|||||||
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||||
from Modules.agents import routes as agent_routes
|
from Modules.agents import routes as agent_routes
|
||||||
from Modules.crypto import certificates, signing
|
from Modules.crypto import certificates, signing
|
||||||
|
from Modules.guid_utils import normalize_guid
|
||||||
from Modules.enrollment import routes as enrollment_routes
|
from Modules.enrollment import routes as enrollment_routes
|
||||||
from Modules.enrollment.nonce_store import NonceCache
|
from Modules.enrollment.nonce_store import NonceCache
|
||||||
from Modules.tokens import routes as token_routes
|
from Modules.tokens import routes as token_routes
|
||||||
@@ -6063,22 +6064,7 @@ def _persist_last_seen(hostname: str, last_seen: int, agent_id: str = None):
|
|||||||
|
|
||||||
|
|
||||||
def _normalize_guid(value: Optional[str]) -> str:
|
def _normalize_guid(value: Optional[str]) -> str:
|
||||||
candidate = (value or "").strip()
|
return normalize_guid(value)
|
||||||
if not candidate:
|
|
||||||
return ""
|
|
||||||
candidate = candidate.replace("{", "").replace("}", "")
|
|
||||||
try:
|
|
||||||
upper = candidate.upper()
|
|
||||||
if upper.count("-") == 4 and len(upper) == 36:
|
|
||||||
return upper
|
|
||||||
if len(candidate) == 32 and all(c in "0123456789abcdefABCDEF" for c in candidate):
|
|
||||||
grouped = "-".join(
|
|
||||||
[candidate[0:8], candidate[8:12], candidate[12:16], candidate[16:20], candidate[20:32]]
|
|
||||||
)
|
|
||||||
return grouped.upper()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return candidate.upper()
|
|
||||||
|
|
||||||
|
|
||||||
def load_agents_from_db():
|
def load_agents_from_db():
|
||||||
|
|||||||
Reference in New Issue
Block a user