Refine pinned certificate handling for Socket.IO

This commit is contained in:
2025-10-18 16:54:38 -06:00
parent 1e2f84b886
commit 393f56a398
2 changed files with 127 additions and 55 deletions

View File

@@ -929,15 +929,24 @@ class AgentHttpClient:
pass pass
context = None context = None
bundle_summary = {"count": None, "fingerprint": None} bundle_summary = {"count": None, "fingerprint": None, "layered_default": None}
if isinstance(verify, str) and os.path.isfile(verify): if isinstance(verify, str) and os.path.isfile(verify):
bundle_count, bundle_fp = self.key_store.describe_server_certificate() bundle_count, bundle_fp, layered_default = self.key_store.summarize_server_certificate()
bundle_summary = {"count": bundle_count, "fingerprint": bundle_fp} bundle_summary = {
"count": bundle_count,
"fingerprint": bundle_fp,
"layered_default": layered_default,
}
context = self.key_store.build_ssl_context() context = self.key_store.build_ssl_context()
if context is not None: if context is not None:
if bundle_summary["layered_default"] is None:
bundle_summary["layered_default"] = getattr(
context, "_borealis_layered_default", None
)
_log_agent( _log_agent(
"SocketIO TLS alignment created SSLContext from pinned bundle " "SocketIO TLS alignment created SSLContext from pinned bundle "
f"count={bundle_count} fp={bundle_fp or '<none>'}", f"count={bundle_count} fp={bundle_fp or '<none>'} "
f"layered_default={bundle_summary['layered_default']}",
fname="agent.log", fname="agent.log",
) )
else: else:
@@ -956,7 +965,9 @@ class AgentHttpClient:
_reset_cached_session() _reset_cached_session()
_log_agent( _log_agent(
"SocketIO TLS alignment applied dedicated SSLContext to engine/http " "SocketIO TLS alignment applied dedicated SSLContext to engine/http "
f"count={bundle_summary['count']} fp={bundle_summary['fingerprint'] or '<none>'}", f"count={bundle_summary['count']} "
f"fp={bundle_summary['fingerprint'] or '<none>'} "
f"layered_default={bundle_summary['layered_default']}",
fname="agent.log", fname="agent.log",
) )
return return

View File

@@ -12,13 +12,18 @@ import platform
import stat import stat
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import List, Optional, Tuple
import ssl import ssl
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519 from cryptography.hazmat.primitives.asymmetric import ed25519
try:
from cryptography import x509 # type: ignore
except Exception: # pragma: no cover - optional dependency guard
x509 = None # type: ignore
IS_WINDOWS = platform.system().lower().startswith("win") IS_WINDOWS = platform.system().lower().startswith("win")
try: try:
@@ -376,46 +381,37 @@ class AgentKeyStore:
def describe_server_certificate(self) -> Tuple[int, Optional[str]]: def describe_server_certificate(self) -> Tuple[int, Optional[str]]:
"""Return (certificate_count, sha256_fingerprint_prefix).""" """Return (certificate_count, sha256_fingerprint_prefix)."""
try: count, fingerprint, _ = self.summarize_server_certificate()
if not os.path.isfile(self._server_certificate_path): return count, fingerprint
return 0, None
with open(self._server_certificate_path, "rb") as fh:
pem_data = fh.read()
except Exception:
return 0, None
if not pem_data: def summarize_server_certificate(self) -> Tuple[int, Optional[str], bool]:
return 0, None """Return (certificate_count, fingerprint_prefix, layered_default_trust)."""
try: pem_bytes, certs = self._load_server_certificates()
from cryptography import x509 # type: ignore if not pem_bytes:
except Exception: return 0, None, False
return 0, None
certs = [] fingerprint = None
for chunk in pem_data.split(b"-----END CERTIFICATE-----"): if certs:
if b"-----BEGIN CERTIFICATE-----" not in chunk:
continue
block = chunk + b"-----END CERTIFICATE-----\n"
try: try:
cert = x509.load_pem_x509_certificate(block) first_cert = certs[0]
fingerprint = hashlib.sha256(
first_cert.public_bytes(serialization.Encoding.DER)
).hexdigest()
except Exception: except Exception:
continue fingerprint = None
certs.append(cert) else:
try:
if not certs: pem_text = pem_bytes.decode("utf-8")
return 0, None der_bytes = ssl.PEM_cert_to_DER_cert(pem_text)
fingerprint = hashlib.sha256(der_bytes).hexdigest()
try: except Exception:
first_cert = certs[0] fingerprint = None
fingerprint = hashlib.sha256(
first_cert.public_bytes(serialization.Encoding.DER)
).hexdigest()
except Exception:
fingerprint = None
count = len(certs) if certs else 1
prefix = fingerprint[:12] if fingerprint else None prefix = fingerprint[:12] if fingerprint else None
return len(certs), prefix include_default = self._should_layer_default_trust(certs)
return count, prefix, include_default
def save_server_certificate(self, pem_text: str) -> None: def save_server_certificate(self, pem_text: str) -> None:
if not pem_text: if not pem_text:
@@ -439,14 +435,15 @@ class AgentKeyStore:
return None return None
def build_ssl_context(self) -> Optional[ssl.SSLContext]: def build_ssl_context(self) -> Optional[ssl.SSLContext]:
if not os.path.isfile(self._server_certificate_path): pem_bytes, certs = self._load_server_certificates()
if not pem_bytes:
return None return None
try: try:
context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
except Exception: except Exception:
try: try:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
except Exception: except Exception:
return None return None
@@ -466,26 +463,42 @@ class AgentKeyStore:
except Exception: except Exception:
pass pass
loaded = False pem_text = None
try: try:
context.load_verify_locations(cafile=self._server_certificate_path) pem_text = pem_bytes.decode("utf-8")
loaded = True
except Exception: except Exception:
pass pass
loaded = False
if pem_text:
try:
context.load_verify_locations(cadata=pem_text)
loaded = True
except Exception:
loaded = False
if not loaded: if not loaded:
try: try:
with open(self._server_certificate_path, "r", encoding="utf-8") as fh: context.load_verify_locations(cafile=self._server_certificate_path)
pem_text = fh.read() loaded = True
if pem_text:
context.load_verify_locations(cadata=pem_text)
loaded = True
except Exception: except Exception:
loaded = False loaded = False
if not loaded: if not loaded:
return None return None
include_default = self._should_layer_default_trust(certs)
try:
setattr(context, "_borealis_layered_default", include_default)
except Exception:
pass
if include_default:
try:
context.load_default_certs()
except Exception:
pass
verify_flag = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", None) verify_flag = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", None)
if verify_flag is not None: if verify_flag is not None:
try: try:
@@ -493,13 +506,61 @@ class AgentKeyStore:
except Exception: except Exception:
pass pass
try:
context.load_default_certs()
except Exception:
pass
return context return context
# ------------------------------------------------------------------
# Server certificate helpers (internal)
# ------------------------------------------------------------------
def _load_server_certificates(self) -> Tuple[Optional[bytes], List["x509.Certificate"]]:
try:
if not os.path.isfile(self._server_certificate_path):
return None, []
with open(self._server_certificate_path, "rb") as fh:
pem_bytes = fh.read()
except Exception:
return None, []
if not pem_bytes.strip():
return None, []
if x509 is None:
return pem_bytes, []
terminator = b"-----END CERTIFICATE-----"
certs: List["x509.Certificate"] = []
for chunk in pem_bytes.split(terminator):
if b"-----BEGIN CERTIFICATE-----" not in chunk:
continue
block = chunk + terminator + b"\n"
try:
cert = x509.load_pem_x509_certificate(block)
except Exception:
continue
certs.append(cert)
return pem_bytes, certs
def _should_layer_default_trust(self, certs: List["x509.Certificate"]) -> bool:
if not certs:
return True
try:
first_cert = certs[0]
is_self_issued = first_cert.issuer == first_cert.subject
except Exception:
return True
if not is_self_issued:
return True
try:
basic = first_cert.extensions.get_extension_for_class(x509.BasicConstraints) # type: ignore[attr-defined]
is_ca = bool(basic.value.ca)
except Exception:
is_ca = False
return is_ca
def save_server_signing_key(self, value: str) -> None: def save_server_signing_key(self, value: str) -> None:
if not value: if not value:
return return