Improve agent TLS context handling

This commit is contained in:
2025-10-18 13:20:28 -06:00
parent b37beb1a34
commit 1e2f84b886
2 changed files with 120 additions and 16 deletions

View File

@@ -929,25 +929,20 @@ class AgentHttpClient:
pass pass
context = None context = None
bundle_summary = {"count": None, "fingerprint": None}
if isinstance(verify, str) and os.path.isfile(verify): if isinstance(verify, str) and os.path.isfile(verify):
try: bundle_count, bundle_fp = self.key_store.describe_server_certificate()
# Mirror Requests' certificate handling by starting from a bundle_summary = {"count": bundle_count, "fingerprint": bundle_fp}
# default client context (which pre-loads the system context = self.key_store.build_ssl_context()
# certificate stores) and then layering the pinned if context is not None:
# certificate bundle on top. This matches the REST client
# behaviour and ensures self-signed leaf certificates work
# the same way for Socket.IO handshakes.
context = ssl.create_default_context()
context.check_hostname = False
context.load_verify_locations(cafile=verify)
_log_agent( _log_agent(
f"SocketIO TLS alignment created SSLContext from cafile={verify}", "SocketIO TLS alignment created SSLContext from pinned bundle "
f"count={bundle_count} fp={bundle_fp or '<none>'}",
fname="agent.log", fname="agent.log",
) )
except Exception: else:
context = None
_log_agent( _log_agent(
f"SocketIO TLS alignment failed to build context from cafile={verify}", "SocketIO TLS alignment failed to build context from pinned bundle", # noqa: E501
fname="agent.error.log", fname="agent.error.log",
) )
@@ -960,7 +955,8 @@ class AgentHttpClient:
_set_attr(http_iface, "verify_ssl", True) _set_attr(http_iface, "verify_ssl", True)
_reset_cached_session() _reset_cached_session()
_log_agent( _log_agent(
"SocketIO TLS alignment applied dedicated SSLContext to engine/http", "SocketIO TLS alignment applied dedicated SSLContext to engine/http "
f"count={bundle_summary['count']} fp={bundle_summary['fingerprint'] or '<none>'}",
fname="agent.log", fname="agent.log",
) )
return return

View File

@@ -12,7 +12,9 @@ import platform
import stat import stat
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional, Tuple
import ssl
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519 from cryptography.hazmat.primitives.asymmetric import ed25519
@@ -371,6 +373,50 @@ class AgentKeyStore:
def server_certificate_path(self) -> str: def server_certificate_path(self) -> str:
return self._server_certificate_path return self._server_certificate_path
def describe_server_certificate(self) -> Tuple[int, Optional[str]]:
"""Return (certificate_count, sha256_fingerprint_prefix)."""
try:
if not os.path.isfile(self._server_certificate_path):
return 0, None
with open(self._server_certificate_path, "rb") as fh:
pem_data = fh.read()
except Exception:
return 0, None
if not pem_data:
return 0, None
try:
from cryptography import x509 # type: ignore
except Exception:
return 0, None
certs = []
for chunk in pem_data.split(b"-----END CERTIFICATE-----"):
if b"-----BEGIN CERTIFICATE-----" not in chunk:
continue
block = chunk + b"-----END CERTIFICATE-----\n"
try:
cert = x509.load_pem_x509_certificate(block)
except Exception:
continue
certs.append(cert)
if not certs:
return 0, None
try:
first_cert = certs[0]
fingerprint = hashlib.sha256(
first_cert.public_bytes(serialization.Encoding.DER)
).hexdigest()
except Exception:
fingerprint = None
prefix = fingerprint[:12] if fingerprint else None
return len(certs), prefix
def save_server_certificate(self, pem_text: str) -> None: def save_server_certificate(self, pem_text: str) -> None:
if not pem_text: if not pem_text:
return return
@@ -392,6 +438,68 @@ class AgentKeyStore:
return None return None
return None return None
def build_ssl_context(self) -> Optional[ssl.SSLContext]:
if not os.path.isfile(self._server_certificate_path):
return None
try:
context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
except Exception:
try:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
except Exception:
return None
try:
context.check_hostname = True
except Exception:
pass
try:
context.verify_mode = ssl.CERT_REQUIRED
except Exception:
pass
if hasattr(context, "minimum_version"):
try:
context.minimum_version = ssl.TLSVersion.TLSv1_2
except Exception:
pass
loaded = False
try:
context.load_verify_locations(cafile=self._server_certificate_path)
loaded = True
except Exception:
pass
if not loaded:
try:
with open(self._server_certificate_path, "r", encoding="utf-8") as fh:
pem_text = fh.read()
if pem_text:
context.load_verify_locations(cadata=pem_text)
loaded = True
except Exception:
loaded = False
if not loaded:
return None
verify_flag = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", None)
if verify_flag is not None:
try:
context.verify_flags |= verify_flag # type: ignore[attr-defined]
except Exception:
pass
try:
context.load_default_certs()
except Exception:
pass
return context
def save_server_signing_key(self, value: str) -> None: def save_server_signing_key(self, value: str) -> None:
if not value: if not value:
return return