Improve agent TLS context handling

This commit is contained in:
2025-10-18 13:20:28 -06:00
parent b37beb1a34
commit 1e2f84b886
2 changed files with 120 additions and 16 deletions

View File

@@ -929,25 +929,20 @@ class AgentHttpClient:
pass
context = None
bundle_summary = {"count": None, "fingerprint": None}
if isinstance(verify, str) and os.path.isfile(verify):
try:
# Mirror Requests' certificate handling by starting from a
# default client context (which pre-loads the system
# certificate stores) and then layering the pinned
# certificate bundle on top. This matches the REST client
# behaviour and ensures self-signed leaf certificates work
# the same way for Socket.IO handshakes.
context = ssl.create_default_context()
context.check_hostname = False
context.load_verify_locations(cafile=verify)
bundle_count, bundle_fp = self.key_store.describe_server_certificate()
bundle_summary = {"count": bundle_count, "fingerprint": bundle_fp}
context = self.key_store.build_ssl_context()
if context is not None:
_log_agent(
f"SocketIO TLS alignment created SSLContext from cafile={verify}",
"SocketIO TLS alignment created SSLContext from pinned bundle "
f"count={bundle_count} fp={bundle_fp or '<none>'}",
fname="agent.log",
)
except Exception:
context = None
else:
_log_agent(
f"SocketIO TLS alignment failed to build context from cafile={verify}",
"SocketIO TLS alignment failed to build context from pinned bundle", # noqa: E501
fname="agent.error.log",
)
@@ -960,7 +955,8 @@ class AgentHttpClient:
_set_attr(http_iface, "verify_ssl", True)
_reset_cached_session()
_log_agent(
"SocketIO TLS alignment applied dedicated SSLContext to engine/http",
"SocketIO TLS alignment applied dedicated SSLContext to engine/http "
f"count={bundle_summary['count']} fp={bundle_summary['fingerprint'] or '<none>'}",
fname="agent.log",
)
return

View File

@@ -12,7 +12,9 @@ import platform
import stat
import time
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple
import ssl
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
@@ -371,6 +373,50 @@ class AgentKeyStore:
def server_certificate_path(self) -> str:
return self._server_certificate_path
def describe_server_certificate(self) -> Tuple[int, Optional[str]]:
"""Return (certificate_count, sha256_fingerprint_prefix)."""
try:
if not os.path.isfile(self._server_certificate_path):
return 0, None
with open(self._server_certificate_path, "rb") as fh:
pem_data = fh.read()
except Exception:
return 0, None
if not pem_data:
return 0, None
try:
from cryptography import x509 # type: ignore
except Exception:
return 0, None
certs = []
for chunk in pem_data.split(b"-----END CERTIFICATE-----"):
if b"-----BEGIN CERTIFICATE-----" not in chunk:
continue
block = chunk + b"-----END CERTIFICATE-----\n"
try:
cert = x509.load_pem_x509_certificate(block)
except Exception:
continue
certs.append(cert)
if not certs:
return 0, None
try:
first_cert = certs[0]
fingerprint = hashlib.sha256(
first_cert.public_bytes(serialization.Encoding.DER)
).hexdigest()
except Exception:
fingerprint = None
prefix = fingerprint[:12] if fingerprint else None
return len(certs), prefix
def save_server_certificate(self, pem_text: str) -> None:
if not pem_text:
return
@@ -392,6 +438,68 @@ class AgentKeyStore:
return None
return None
def build_ssl_context(self) -> Optional[ssl.SSLContext]:
if not os.path.isfile(self._server_certificate_path):
return None
try:
context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
except Exception:
try:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
except Exception:
return None
try:
context.check_hostname = True
except Exception:
pass
try:
context.verify_mode = ssl.CERT_REQUIRED
except Exception:
pass
if hasattr(context, "minimum_version"):
try:
context.minimum_version = ssl.TLSVersion.TLSv1_2
except Exception:
pass
loaded = False
try:
context.load_verify_locations(cafile=self._server_certificate_path)
loaded = True
except Exception:
pass
if not loaded:
try:
with open(self._server_certificate_path, "r", encoding="utf-8") as fh:
pem_text = fh.read()
if pem_text:
context.load_verify_locations(cadata=pem_text)
loaded = True
except Exception:
loaded = False
if not loaded:
return None
verify_flag = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", None)
if verify_flag is not None:
try:
context.verify_flags |= verify_flag # type: ignore[attr-defined]
except Exception:
pass
try:
context.load_default_certs()
except Exception:
pass
return context
def save_server_signing_key(self, value: str) -> None:
if not value:
return