Coverage for mcpgateway / utils / ssl_context_cache.py: 100%
99 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""SSL context caching utilities for ContextForge services.
4This module provides caching for SSL contexts to avoid repeatedly creating
5them for the same CA certificates, improving performance for services that
6make many SSL connections.
7"""
9# Standard
10from datetime import datetime, timedelta
11import hashlib
12import logging
13import os
14import ssl
15import tempfile
17logger = logging.getLogger(__name__)
19# Cache for SSL contexts keyed by SSL parameter hash
20_ssl_context_cache: dict[str, ssl.SSLContext] = {}
21_ssl_context_cache_timestamps: dict[str, datetime] = {}
23_SSL_CONTEXT_CACHE_MAX_SIZE = int(os.getenv("SSL_CONTEXT_CACHE_MAX_SIZE", "100"))
24_SSL_CONTEXT_CACHE_TTL = os.getenv("SSL_CONTEXT_CACHE_TTL")
25if _SSL_CONTEXT_CACHE_TTL is not None and _SSL_CONTEXT_CACHE_TTL.strip() != "":
26 try:
27 _SSL_CONTEXT_CACHE_TTL = int(_SSL_CONTEXT_CACHE_TTL)
28 except ValueError:
29 raise ValueError("SSL_CONTEXT_CACHE_TTL must be an integer number of seconds")
30else:
31 _SSL_CONTEXT_CACHE_TTL = None
def _is_expired(cache_key: str) -> bool:
    """Tell whether the cached context stored under ``cache_key`` is stale.

    An entry can only become stale when a TTL is configured and a creation
    timestamp was recorded for it; in every other case it counts as fresh.

    Args:
        cache_key: Hash key of the cache entry to inspect.

    Returns:
        True when the entry is older than the configured TTL and must be
        rebuilt, False otherwise.
    """
    ttl_seconds = _SSL_CONTEXT_CACHE_TTL
    if ttl_seconds is not None:
        created_at = _ssl_context_cache_timestamps.get(cache_key)
        if created_at is not None:
            # NOTE(review): naive local wall-clock time, so a system clock or
            # DST change can shorten or lengthen an entry's effective lifetime.
            return datetime.now() - created_at > timedelta(seconds=ttl_seconds)
    return False
52def _is_pem(value: str) -> bool:
53 """Check if a string looks like inline PEM content rather than a file path.
55 Args:
56 value: String to check (file path or PEM content).
58 Returns:
59 True if the string starts with a PEM header.
60 """
61 return value.lstrip().startswith("-----BEGIN ")
def _write_temp_pem(content: str, label: str, created: list[tuple[str, str]]) -> str:
    """Write inline PEM text to a new temporary file and return its path.

    The file is recorded in ``created`` before anything is written, so the
    caller can still remove it if the write fails. It is always closed before
    this function returns or raises.

    Args:
        content: PEM text to persist.
        label: Short description ("cert" or "key") used in cleanup log messages.
        created: Caller-owned accumulator of ``(label, path)`` pairs to delete later.

    Returns:
        Path of the temporary file holding ``content``.
    """
    handle = tempfile.NamedTemporaryFile(mode="w", suffix=".pem", delete=False, encoding="utf-8")
    created.append((label, handle.name))
    with handle:
        handle.write(content)
    return handle.name


def _load_client_cert_chain(ctx: ssl.SSLContext, client_cert: str, client_key: str) -> None:
    """Load client cert/key into an SSL context, handling both file paths and PEM strings.

    ``ssl.SSLContext.load_cert_chain`` only accepts file paths. When the
    values are inline PEM content (stored in the database), we write them
    to temporary files created by ``tempfile.NamedTemporaryFile``, load from
    there, and delete the files again whether or not loading succeeded.

    Args:
        ctx: SSL context to load the client certificate chain into.
        client_cert: Client certificate as a file path or inline PEM string.
        client_key: Client private key as a file path or inline PEM string.
    """
    cert_is_pem = _is_pem(client_cert)
    key_is_pem = _is_pem(client_key)

    if not cert_is_pem and not key_is_pem:
        # Both are file paths — use directly
        ctx.load_cert_chain(certfile=client_cert, keyfile=client_key)
        return

    # At least one value is inline PEM — write temp files, tracking every file
    # created so that cleanup covers partial failures too.
    temp_files: list[tuple[str, str]] = []
    try:
        cert_path = _write_temp_pem(client_cert, "cert", temp_files) if cert_is_pem else client_cert
        key_path = _write_temp_pem(client_key, "key", temp_files) if key_is_pem else client_key
        ctx.load_cert_chain(certfile=cert_path, keyfile=key_path)
    finally:
        # Remove temp files immediately so PEM material (including the private
        # key) does not linger on disk; removal failures are only logged.
        for label, path in temp_files:
            try:
                os.unlink(path)
            except OSError:
                logger.debug("Failed to remove temp %s file %s", label, path)
118def get_cached_ssl_context(
119 ca_certificate: str,
120 client_cert: str | None = None,
121 client_key: str | None = None,
122) -> ssl.SSLContext:
123 """Get or create cached SSL context for a CA certificate.
125 Args:
126 ca_certificate: CA certificate in PEM format (str or bytes)
127 client_cert: Optional client cert path or PEM for mTLS
128 client_key: Optional client key path or PEM for mTLS
130 Returns:
131 ssl.SSLContext: Configured SSL context
133 Raises:
134 ValueError: If only one of client_cert/client_key is provided.
136 Examples:
137 The actual `ssl.SSLContext.load_verify_locations()` call requires valid PEM
138 data; in doctests we mock it to focus on caching behavior.
140 >>> from unittest.mock import Mock, patch
141 >>> from mcpgateway.utils.ssl_context_cache import clear_ssl_context_cache, get_cached_ssl_context
142 >>> clear_ssl_context_cache()
143 >>> with patch("mcpgateway.utils.ssl_context_cache.ssl.create_default_context") as mock_create:
144 ... ctx = Mock()
145 ... mock_create.return_value = ctx
146 ... a = get_cached_ssl_context("CERTDATA")
147 ... b = get_cached_ssl_context(b"CERTDATA") # same bytes => same cache entry
148 ... (a is ctx, b is ctx, mock_create.call_count)
149 (True, True, 1)
151 Note:
152 The function handles bytes, str, and other types (for test mocks).
153 SSL contexts are cached by the SHA256 hash of the certificate to
154 avoid repeated expensive SSL setup operations.
155 """
156 # Ensure CA certificate is normalized to bytes for hash calculation
157 if isinstance(ca_certificate, bytes):
158 ca_cert_bytes = ca_certificate
159 elif isinstance(ca_certificate, str):
160 ca_cert_bytes = ca_certificate.encode()
161 else:
162 ca_cert_bytes = str(ca_certificate).encode()
164 # Client cert/key may be either path-like content or inlined PEM string.
165 client_cert_value = client_cert or ""
166 client_key_value = client_key or ""
168 # Build stable cache key incrementally (avoids delimiter collisions).
169 key_hash = hashlib.sha256()
170 key_hash.update(b"ca_cert:")
171 key_hash.update(ca_cert_bytes)
172 key_hash.update(b"|client_cert:")
173 key_hash.update(client_cert_value.encode())
174 key_hash.update(b"|client_key:")
175 key_hash.update(client_key_value.encode())
177 cache_key = key_hash.hexdigest()
179 if cache_key in _ssl_context_cache and not _is_expired(cache_key):
180 return _ssl_context_cache[cache_key]
182 # If expired, clear this entry so it is refreshed below
183 if cache_key in _ssl_context_cache:
184 _ssl_context_cache.pop(cache_key, None)
185 _ssl_context_cache_timestamps.pop(cache_key, None)
187 # Create new SSL context and configure CA cert
188 ctx = ssl.create_default_context()
189 ctx.load_verify_locations(cadata=ca_certificate)
191 # Validate mTLS: require both or neither
192 if bool(client_cert) != bool(client_key):
193 raise ValueError("mTLS requires both client_cert and client_key; got only one")
195 # Load client certificates for mTLS when provided
196 if client_cert and client_key:
197 _load_client_cert_chain(ctx, client_cert, client_key)
199 # Cache entry creation timestamp if TTL is enabled
200 _ssl_context_cache[cache_key] = ctx
201 if _SSL_CONTEXT_CACHE_TTL is not None:
202 _ssl_context_cache_timestamps[cache_key] = datetime.now()
204 # Evict all cache if size limit exceeded; keep this newly inserted item.
205 # This avoids growing indefinitely without requiring LRU tracking.
206 if len(_ssl_context_cache) > _SSL_CONTEXT_CACHE_MAX_SIZE:
207 current_ctx = _ssl_context_cache.pop(cache_key)
208 current_ts = _ssl_context_cache_timestamps.pop(cache_key, None)
210 _ssl_context_cache.clear()
211 _ssl_context_cache_timestamps.clear()
213 _ssl_context_cache[cache_key] = current_ctx
214 if current_ts is not None:
215 _ssl_context_cache_timestamps[cache_key] = current_ts
217 return ctx
def clear_ssl_context_cache() -> None:
    """Drop every cached SSL context together with its creation timestamp.

    Typical reasons to call this:
    - isolating tests from one another (e.g. in fixtures)
    - picking up a rotated CA certificate
    - releasing memory held by cached contexts

    Both module-level stores are emptied in place, so any existing
    references to them observe the cleared state.

    Examples:
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.utils.ssl_context_cache import clear_ssl_context_cache, get_cached_ssl_context
        >>> with patch("mcpgateway.utils.ssl_context_cache.ssl.create_default_context") as mock_create:
        ...     mock_create.return_value = Mock()
        ...     _ = get_cached_ssl_context("CERTDATA")
        ...     clear_ssl_context_cache()
        ...     _ = get_cached_ssl_context("CERTDATA")
        ...     mock_create.call_count
        2
    """
    for store in (_ssl_context_cache, _ssl_context_cache_timestamps):
        store.clear()