Coverage for mcpgateway / utils / ssl_context_cache.py: 100%

99 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""SSL context caching utilities for ContextForge services. 

3 

4This module provides caching for SSL contexts to avoid repeatedly creating 

5them for the same CA certificates, improving performance for services that 

6make many SSL connections. 

7""" 

8 

9# Standard 

10from datetime import datetime, timedelta 

11import hashlib 

12import logging 

13import os 

14import ssl 

15import tempfile 

16 

17logger = logging.getLogger(__name__) 

18 

19# Cache for SSL contexts keyed by SSL parameter hash 

20_ssl_context_cache: dict[str, ssl.SSLContext] = {} 

21_ssl_context_cache_timestamps: dict[str, datetime] = {} 

22 

23_SSL_CONTEXT_CACHE_MAX_SIZE = int(os.getenv("SSL_CONTEXT_CACHE_MAX_SIZE", "100")) 

24_SSL_CONTEXT_CACHE_TTL = os.getenv("SSL_CONTEXT_CACHE_TTL") 

25if _SSL_CONTEXT_CACHE_TTL is not None and _SSL_CONTEXT_CACHE_TTL.strip() != "": 

26 try: 

27 _SSL_CONTEXT_CACHE_TTL = int(_SSL_CONTEXT_CACHE_TTL) 

28 except ValueError: 

29 raise ValueError("SSL_CONTEXT_CACHE_TTL must be an integer number of seconds") 

30else: 

31 _SSL_CONTEXT_CACHE_TTL = None 

32 

33 

def _is_expired(cache_key: str) -> bool:
    """Report whether the cached entry for ``cache_key`` has outlived the TTL.

    An entry never expires when no TTL is configured, or when no creation
    timestamp was recorded for the key.

    Args:
        cache_key: Key of the cache entry to inspect.

    Returns:
        True when the entry is older than the configured TTL and should be
        rebuilt; False otherwise.
    """
    ttl_seconds = _SSL_CONTEXT_CACHE_TTL
    if ttl_seconds is not None:
        created_at = _ssl_context_cache_timestamps.get(cache_key)
        if created_at is not None:
            # NOTE(review): datetime.now() is wall-clock time, so clock
            # adjustments shift the effective TTL.
            age = datetime.now() - created_at
            return age > timedelta(seconds=ttl_seconds)
    return False

50 

51 

def _is_pem(value: str) -> bool:
    """Decide whether ``value`` is inline PEM text rather than a path on disk.

    Leading whitespace is ignored. Any value whose first non-blank characters
    form a ``-----BEGIN `` armor header is treated as PEM.

    Args:
        value: Either a filesystem path or PEM-encoded content.

    Returns:
        True if ``value`` begins with a PEM header, False otherwise.
    """
    pem_header = "-----BEGIN "
    first_content = value.lstrip()
    return first_content.startswith(pem_header)

62 

63 

def _write_pem_tempfile(content: str, label: str, created: list[tuple[str, str]]) -> str:
    """Persist inline PEM text to a new temporary ``.pem`` file.

    The file is registered in ``created`` before anything is written. The
    caller's cleanup therefore also removes it if the write fails part-way.

    Args:
        content: PEM text to write.
        label: Short description ("cert" or "key") used in cleanup log messages.
        created: Accumulator of ``(label, path)`` pairs the caller must delete.

    Returns:
        Path of the newly written temporary file.
    """
    handle = tempfile.NamedTemporaryFile(mode="w", suffix=".pem", delete=False, encoding="utf-8")
    created.append((label, handle.name))
    # ``with`` closes the handle even when write() raises. A still-open handle
    # would make os.unlink() fail on Windows and leave key material on disk.
    with handle:
        handle.write(content)
    return handle.name


def _load_client_cert_chain(ctx: ssl.SSLContext, client_cert: str, client_key: str) -> None:
    """Load client cert/key into an SSL context, handling both file paths and PEM strings.

    ``ssl.SSLContext.load_cert_chain`` only accepts file paths. When the
    values are inline PEM content (stored in the database), we write them
    to temporary files, load from there, and delete the files immediately
    afterwards, whether or not loading succeeded.

    Args:
        ctx: SSL context to load the client certificate chain into.
        client_cert: Client certificate as a file path or inline PEM string.
        client_key: Client private key as a file path or inline PEM string.
    """
    cert_is_pem = _is_pem(client_cert)
    key_is_pem = _is_pem(client_key)

    if not cert_is_pem and not key_is_pem:
        # Both are file paths — use directly.
        ctx.load_cert_chain(certfile=client_cert, keyfile=client_key)
        return

    # At least one value is inline PEM. Spill it to temp files that exist only
    # for the duration of load_cert_chain().
    temp_files: list[tuple[str, str]] = []
    try:
        cert_path = _write_pem_tempfile(client_cert, "cert", temp_files) if cert_is_pem else client_cert
        key_path = _write_pem_tempfile(client_key, "key", temp_files) if key_is_pem else client_key
        ctx.load_cert_chain(certfile=cert_path, keyfile=key_path)
    finally:
        # Remove temp files immediately after loading (or after a failure).
        for label, path in temp_files:
            try:
                os.unlink(path)
            except OSError:
                logger.debug("Failed to remove temp %s file %s", label, path)

116 

117 

118def get_cached_ssl_context( 

119 ca_certificate: str, 

120 client_cert: str | None = None, 

121 client_key: str | None = None, 

122) -> ssl.SSLContext: 

123 """Get or create cached SSL context for a CA certificate. 

124 

125 Args: 

126 ca_certificate: CA certificate in PEM format (str or bytes) 

127 client_cert: Optional client cert path or PEM for mTLS 

128 client_key: Optional client key path or PEM for mTLS 

129 

130 Returns: 

131 ssl.SSLContext: Configured SSL context 

132 

133 Raises: 

134 ValueError: If only one of client_cert/client_key is provided. 

135 

136 Examples: 

137 The actual `ssl.SSLContext.load_verify_locations()` call requires valid PEM 

138 data; in doctests we mock it to focus on caching behavior. 

139 

140 >>> from unittest.mock import Mock, patch 

141 >>> from mcpgateway.utils.ssl_context_cache import clear_ssl_context_cache, get_cached_ssl_context 

142 >>> clear_ssl_context_cache() 

143 >>> with patch("mcpgateway.utils.ssl_context_cache.ssl.create_default_context") as mock_create: 

144 ... ctx = Mock() 

145 ... mock_create.return_value = ctx 

146 ... a = get_cached_ssl_context("CERTDATA") 

147 ... b = get_cached_ssl_context(b"CERTDATA") # same bytes => same cache entry 

148 ... (a is ctx, b is ctx, mock_create.call_count) 

149 (True, True, 1) 

150 

151 Note: 

152 The function handles bytes, str, and other types (for test mocks). 

153 SSL contexts are cached by the SHA256 hash of the certificate to 

154 avoid repeated expensive SSL setup operations. 

155 """ 

156 # Ensure CA certificate is normalized to bytes for hash calculation 

157 if isinstance(ca_certificate, bytes): 

158 ca_cert_bytes = ca_certificate 

159 elif isinstance(ca_certificate, str): 

160 ca_cert_bytes = ca_certificate.encode() 

161 else: 

162 ca_cert_bytes = str(ca_certificate).encode() 

163 

164 # Client cert/key may be either path-like content or inlined PEM string. 

165 client_cert_value = client_cert or "" 

166 client_key_value = client_key or "" 

167 

168 # Build stable cache key incrementally (avoids delimiter collisions). 

169 key_hash = hashlib.sha256() 

170 key_hash.update(b"ca_cert:") 

171 key_hash.update(ca_cert_bytes) 

172 key_hash.update(b"|client_cert:") 

173 key_hash.update(client_cert_value.encode()) 

174 key_hash.update(b"|client_key:") 

175 key_hash.update(client_key_value.encode()) 

176 

177 cache_key = key_hash.hexdigest() 

178 

179 if cache_key in _ssl_context_cache and not _is_expired(cache_key): 

180 return _ssl_context_cache[cache_key] 

181 

182 # If expired, clear this entry so it is refreshed below 

183 if cache_key in _ssl_context_cache: 

184 _ssl_context_cache.pop(cache_key, None) 

185 _ssl_context_cache_timestamps.pop(cache_key, None) 

186 

187 # Create new SSL context and configure CA cert 

188 ctx = ssl.create_default_context() 

189 ctx.load_verify_locations(cadata=ca_certificate) 

190 

191 # Validate mTLS: require both or neither 

192 if bool(client_cert) != bool(client_key): 

193 raise ValueError("mTLS requires both client_cert and client_key; got only one") 

194 

195 # Load client certificates for mTLS when provided 

196 if client_cert and client_key: 

197 _load_client_cert_chain(ctx, client_cert, client_key) 

198 

199 # Cache entry creation timestamp if TTL is enabled 

200 _ssl_context_cache[cache_key] = ctx 

201 if _SSL_CONTEXT_CACHE_TTL is not None: 

202 _ssl_context_cache_timestamps[cache_key] = datetime.now() 

203 

204 # Evict all cache if size limit exceeded; keep this newly inserted item. 

205 # This avoids growing indefinitely without requiring LRU tracking. 

206 if len(_ssl_context_cache) > _SSL_CONTEXT_CACHE_MAX_SIZE: 

207 current_ctx = _ssl_context_cache.pop(cache_key) 

208 current_ts = _ssl_context_cache_timestamps.pop(cache_key, None) 

209 

210 _ssl_context_cache.clear() 

211 _ssl_context_cache_timestamps.clear() 

212 

213 _ssl_context_cache[cache_key] = current_ctx 

214 if current_ts is not None: 

215 _ssl_context_cache_timestamps[cache_key] = current_ts 

216 

217 return ctx 

218 

219 

def clear_ssl_context_cache() -> None:
    """Drop every cached SSL context together with its creation timestamp.

    Typical reasons to call this:
    - Isolating test cases from one another in fixtures
    - Reacting to a CA certificate rotation
    - Releasing memory when under pressure

    Examples:
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.utils.ssl_context_cache import clear_ssl_context_cache, get_cached_ssl_context
        >>> with patch("mcpgateway.utils.ssl_context_cache.ssl.create_default_context") as mock_create:
        ...     mock_create.return_value = Mock()
        ...     _ = get_cached_ssl_context("CERTDATA")
        ...     clear_ssl_context_cache()
        ...     _ = get_cached_ssl_context("CERTDATA")
        ...     mock_create.call_count
        2
    """
    # Empty both stores in place so every module-level reference stays valid.
    for store in (_ssl_context_cache, _ssl_context_cache_timestamps):
        store.clear()