Coverage for mcpgateway / cache / global_config_cache.py: 100%

58 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/cache/global_config_cache.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5 

6GlobalConfig In-Memory Cache. 

7 

8This module implements a thread-safe in-memory cache for GlobalConfig with TTL expiration. 

9GlobalConfig is a singleton configuration table that stores passthrough headers settings. 

10Since this data rarely changes (admin configuration), caching it in memory eliminates 

11thousands of redundant database queries under load. 

12 

13Performance Impact: 

14 - Before: 42,000+ DB queries per load test for 0-1 row table 

15 - After: 1 DB query per TTL period (default 60 seconds) 

16 

17Security Considerations: 

18 - Stale config window: Changes take up to TTL to propagate 

19 - Mitigation: Admin UI should call invalidate() after config changes 

20 - Cache poisoning: Not applicable (populated from DB only) 

21 - Information leakage: Not applicable (stores header names only, not values) 

22 

23Examples: 

24 >>> from unittest.mock import Mock, patch 

25 >>> from mcpgateway.cache.global_config_cache import GlobalConfigCache 

26 

27 >>> # Test cache miss and hit 

28 >>> cache = GlobalConfigCache(ttl_seconds=60) 

29 >>> mock_db = Mock() 

30 >>> mock_config = Mock() 

31 >>> mock_config.passthrough_headers = ["Authorization", "X-Request-ID"] 

32 >>> mock_db.query.return_value.first.return_value = mock_config 

33 

34 >>> # First call - cache miss, queries DB 

35 >>> result = cache.get(mock_db) 

36 >>> result.passthrough_headers 

37 ['Authorization', 'X-Request-ID'] 

38 >>> mock_db.query.return_value.first.call_count 

39 1 

40 

41 >>> # Second call - cache hit, no DB query 

42 >>> result = cache.get(mock_db) 

43 >>> mock_db.query.return_value.first.call_count 

44 1 

45 

46 >>> # After invalidation - queries DB again 

47 >>> cache.invalidate() 

48 >>> result = cache.get(mock_db) 

49 >>> mock_db.query.return_value.first.call_count 

50 2 

51""" 

52 

53# Standard 

54import logging 

55import threading 

56import time 

57 

# Module logger. The standard library logger is used directly (rather than a
# gateway logging service) to avoid circular imports with services.
logger = logging.getLogger(name=__name__)

60 

61 

class GlobalConfigCache:
    """
    Thread-safe in-memory cache for GlobalConfig with TTL.

    This cache stores the GlobalConfig singleton to avoid repeated database queries.
    GlobalConfig contains passthrough headers configuration that rarely changes.

    Attributes:
        ttl_seconds: Time-to-live in seconds before cache refresh
        _cache: Cached GlobalConfig object, None (table is empty), or the
            _NOT_CACHED sentinel (nothing loaded yet, or invalidated)
        _expiry: time.time() timestamp after which the cached value is stale
        _lock: Threading lock serializing refreshes and invalidation
        _hit_count: Lookups served from the cache (best-effort counter)
        _miss_count: Lookups that queried the database

    Examples:
        >>> from unittest.mock import Mock
        >>> cache = GlobalConfigCache(ttl_seconds=60)
        >>> mock_db = Mock()
        >>> mock_db.query.return_value.first.return_value = None

        >>> # Returns None when no GlobalConfig exists
        >>> cache.get(mock_db) is None
        True

        >>> # get_passthrough_headers returns default when no config
        >>> cache.get_passthrough_headers(mock_db, ["Default-Header"])
        ['Default-Header']
    """

    # Sentinel value to distinguish "not cached" from "cached None"
    _NOT_CACHED = object()

    def __init__(self, ttl_seconds: int = 60):
        """
        Initialize the GlobalConfig cache.

        Args:
            ttl_seconds: Time-to-live in seconds (default: 60).
                        After this duration, the cache will refresh from DB.

        Examples:
            >>> cache = GlobalConfigCache(ttl_seconds=30)
            >>> cache.ttl_seconds
            30
        """
        self.ttl_seconds = ttl_seconds
        self._cache = self._NOT_CACHED  # Use sentinel to distinguish from cached None
        self._expiry: float = 0
        self._lock = threading.Lock()
        self._hit_count = 0
        self._miss_count = 0

    def get(self, db):
        """
        Get GlobalConfig from cache or database.

        Uses a double-checked locking pattern for thread safety with minimal
        lock contention on the hot path (cache hit): a hit is served without
        taking the lock; a miss or expiry takes the lock, re-checks, and only
        then queries the database.

        Args:
            db: SQLAlchemy database session

        Returns:
            GlobalConfig object or None if not configured

        Examples:
            >>> from unittest.mock import Mock
            >>> cache = GlobalConfigCache(ttl_seconds=60)
            >>> mock_db = Mock()
            >>> mock_config = Mock()
            >>> mock_db.query.return_value.first.return_value = mock_config
            >>> cache.get(mock_db) is mock_config
            True
        """
        # Fast path: cache hit (no lock needed for read).
        # Snapshot the cached value into a local and return that local. Re-reading
        # self._cache after the check could observe a concurrent invalidate() and
        # hand the _NOT_CACHED sentinel to the caller. The sentinel comparison
        # (rather than a None check) lets an empty GlobalConfig table be cached.
        cached = self._cache
        if cached is not self._NOT_CACHED and time.time() < self._expiry:
            # Unlocked increment: counts may be lost under contention (stats only).
            self._hit_count += 1
            return cached

        # Slow path: cache miss or expired - acquire lock
        with self._lock:
            # Re-read the clock: we may have waited behind another thread's refresh.
            now = time.time()

            # Double-check after acquiring lock (another thread may have refreshed)
            cached = self._cache
            if cached is not self._NOT_CACHED and now < self._expiry:
                self._hit_count += 1
                return cached

            # Import here to avoid circular imports
            # First-Party
            from mcpgateway.db import GlobalConfig  # pylint: disable=import-outside-toplevel

            # Refresh from database.
            # NOTE(review): the cached object is loaded through `db` and handed to
            # later callers that hold other sessions; confirm its attributes remain
            # readable once the loading session is closed/committed.
            config = db.query(GlobalConfig).first()
            self._cache = config
            self._expiry = now + self.ttl_seconds
            self._miss_count += 1

            if config is not None:
                logger.debug("GlobalConfig cache refreshed (TTL: %ss)", self.ttl_seconds)
            else:
                logger.debug("GlobalConfig not found in database, using defaults")

            return config

    def get_passthrough_headers(self, db, default: list[str]) -> list[str]:
        """
        Get passthrough headers based on PASSTHROUGH_HEADERS_SOURCE setting.

        Supports three modes:
        - "env": Environment variable always wins (ignore database)
        - "db": Database wins if configured, fallback to env (default, backward compatible)
        - "merge": Union of env and database headers (DB overrides for duplicates)

        Any other value of ``settings.passthrough_headers_source`` is handled like "db".
        The returned list is always a fresh copy, so callers may mutate it without
        affecting the cached GlobalConfig or the caller-supplied ``default``.

        Args:
            db: SQLAlchemy database session
            default: Default headers from environment variable (settings.default_passthrough_headers).
                A falsy value (None or empty) means "no default headers".

        Returns:
            List of allowed passthrough header names (never None)

        Examples:
            >>> from unittest.mock import Mock, patch
            >>> cache = GlobalConfigCache(ttl_seconds=60)
            >>> mock_db = Mock()

            >>> # "db" mode (default): When no config exists, returns default
            >>> mock_db.query.return_value.first.return_value = None
            >>> cache.invalidate()  # Clear any cached value
            >>> with patch("mcpgateway.config.settings") as mock_settings:
            ...     mock_settings.passthrough_headers_source = "db"
            ...     cache.get_passthrough_headers(mock_db, ["X-Default"])
            ['X-Default']

            >>> # "env" mode: Always returns default, ignores database
            >>> mock_config = Mock()
            >>> mock_config.passthrough_headers = ["Authorization"]
            >>> mock_db.query.return_value.first.return_value = mock_config
            >>> cache.invalidate()
            >>> with patch("mcpgateway.config.settings") as mock_settings:
            ...     mock_settings.passthrough_headers_source = "env"
            ...     cache.get_passthrough_headers(mock_db, ["X-Default"])
            ['X-Default']

            >>> # "merge" mode: Combines both sources
            >>> cache.invalidate()
            >>> with patch("mcpgateway.config.settings") as mock_settings:
            ...     mock_settings.passthrough_headers_source = "merge"
            ...     result = cache.get_passthrough_headers(mock_db, ["X-Default"])
            ...     "X-Default" in result and "Authorization" in result
            True
        """
        # Import here to avoid circular imports
        # First-Party
        from mcpgateway.config import settings  # pylint: disable=import-outside-toplevel

        source = settings.passthrough_headers_source
        # Fresh copy so callers never alias the caller-supplied default list.
        env_default = list(default or [])

        if source == "env":
            # Environment always wins - don't query database at all
            logger.debug("Passthrough headers source=env: using environment variable only")
            return env_default

        config = self.get(db)
        db_list = (config.passthrough_headers or []) if config is not None else []

        if source == "merge":
            # Union of both sources, preserving original casing.
            # Lowercase keys deduplicate; original casing is kept as the value.
            env_headers = {h.lower(): h for h in env_default}
            db_headers = {h.lower(): h for h in db_list}
            # DB values override env for same header (handles case differences)
            merged = {**env_headers, **db_headers}
            result = list(merged.values())
            logger.debug("Passthrough headers source=merge: combined %d headers from env and db", len(result))
            return result

        # Default "db" mode - current behavior for backward compatibility
        if db_list:
            logger.debug("Passthrough headers source=db: using database configuration")
            # Copy: the list belongs to the cached object shared by all callers.
            return list(db_list)
        logger.debug("Passthrough headers source=db: no database config, using environment default")
        return env_default

    def invalidate(self) -> None:
        """
        Invalidate the cache, forcing a refresh on next access.

        Call this method after updating GlobalConfig in the database
        to ensure changes propagate immediately.

        Examples:
            >>> cache = GlobalConfigCache(ttl_seconds=60)
            >>> cache._expiry = time.time() + 1000  # Set future expiry
            >>> cache.invalidate()
            >>> cache._expiry
            0
        """
        with self._lock:
            self._cache = self._NOT_CACHED
            self._expiry = 0
        logger.info("GlobalConfig cache invalidated")

    def stats(self) -> dict:
        """
        Get cache statistics.

        Returns:
            Dictionary with hit_count, miss_count, hit_rate, ttl_seconds and
            is_cached (True when a value is loaded and not yet expired)

        Examples:
            >>> cache = GlobalConfigCache(ttl_seconds=60)
            >>> cache._hit_count = 90
            >>> cache._miss_count = 10
            >>> stats = cache.stats()
            >>> stats["hit_rate"]
            0.9
        """
        # Snapshot counters once so hit_rate agrees with the reported counts.
        hits = self._hit_count
        misses = self._miss_count
        total = hits + misses
        return {
            "hit_count": hits,
            "miss_count": misses,
            "hit_rate": hits / total if total > 0 else 0.0,
            "ttl_seconds": self.ttl_seconds,
            "is_cached": self._cache is not self._NOT_CACHED and time.time() < self._expiry,
        }

286 

287 

# Process-wide shared cache instance (60s TTL, from the constructor default).
# Callers should use this object instead of building their own caches, so the
# whole process shares one cached GlobalConfig and one invalidate() point.
global_config_cache: GlobalConfigCache = GlobalConfigCache()