Coverage for mcpgateway/cache/tool_lookup_cache.py: 96%

168 statements  

coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

  1  # -*- coding: utf-8 -*-
  2  """Tool lookup cache (tool name -> tool config) with L1 memory + L2 Redis.
  3
  4  This cache targets the hot-path tool lookup in ToolService.invoke_tool by
  5  avoiding a DB query per tool invocation. It uses a per-worker in-memory
  6  cache with TTL and optional Redis backing for distributed deployments.
  7  """
  8
  9  # Future
 10  from __future__ import annotations
 11
 12  # Standard
 13  from collections import OrderedDict
 14  from dataclasses import dataclass
 15  import logging
 16  import threading
 17  import time
 18  from typing import Any, Dict, Optional
 19
 20  # Third-Party
 21  import orjson
 22
 23  logger = logging.getLogger(__name__)
 24
 25
 26  @dataclass
 27  class CacheEntry:
 28      """Cache entry with value and expiry timestamp."""
 29
 30      value: Dict[str, Any]
 31      expiry: float
 32
 33      def is_expired(self) -> bool:
 34          """Return True if the cache entry has expired.
 35
 36          Returns:
 37              True if expired, otherwise False.
 38
 39          Examples:
 40              >>> from unittest.mock import patch
 41              >>> from mcpgateway.cache.tool_lookup_cache import CacheEntry
 42              >>> with patch("time.time", return_value=1000):
 43              ...     CacheEntry(value={"k": "v"}, expiry=999).is_expired()
 44              True
 45              >>> with patch("time.time", return_value=1000):
 46              ...     CacheEntry(value={"k": "v"}, expiry=1001).is_expired()
 47              False
 48          """
 49          return time.time() >= self.expiry
 50
 51
 52  class ToolLookupCache:
 53      """Two-tier cache for tool lookups by name.
 54
 55      L1: in-memory LRU/TTL per worker.
 56      L2: Redis (optional, shared across workers).
 57      """
 58
 59      def __init__(self) -> None:
 60          """Initialize cache settings and in-memory structures.
 61
 62          Examples:
 63              >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
 64              >>> cache = ToolLookupCache()
 65              >>> isinstance(cache.enabled, bool)
 66              True
 67          """
 68          try:
 69              # First-Party
 70              from mcpgateway.config import settings  # pylint: disable=import-outside-toplevel
 71
 72              self._enabled = getattr(settings, "tool_lookup_cache_enabled", True)
 73              self._ttl_seconds = getattr(settings, "tool_lookup_cache_ttl_seconds", 60)
 74              self._negative_ttl_seconds = getattr(settings, "tool_lookup_cache_negative_ttl_seconds", 10)
 75              self._l1_maxsize = getattr(settings, "tool_lookup_cache_l1_maxsize", 10000)
 76              self._l2_enabled = getattr(settings, "tool_lookup_cache_l2_enabled", True) and settings.cache_type == "redis"
 77              self._cache_prefix = getattr(settings, "cache_prefix", "mcpgw:")
 78          except ImportError:
 79              self._enabled = True
 80              self._ttl_seconds = 60
 81              self._negative_ttl_seconds = 10
 82              self._l1_maxsize = 10000
 83              self._l2_enabled = False
 84              self._cache_prefix = "mcpgw:"
 85
 86          self._cache: "OrderedDict[str, CacheEntry]" = OrderedDict()
 87          self._lock = threading.Lock()
 88
 89          self._redis_checked = False
 90          self._redis_available = False
 91
 92          self._l1_hit_count = 0
 93          self._l1_miss_count = 0
 94          self._l2_hit_count = 0
 95          self._l2_miss_count = 0
 96
 97          logger.info(
 98              "ToolLookupCache initialized: enabled=%s l1_max=%s ttl=%ss l2_enabled=%s",
 99              self._enabled,
100              self._l1_maxsize,
101              self._ttl_seconds,
102              self._l2_enabled,
103          )
104
105      @property
106      def enabled(self) -> bool:
107          """Return True if the cache is enabled.
108
109          Returns:
110              True if enabled, otherwise False.
111
112          Examples:
113              >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
114              >>> ToolLookupCache().enabled in (True, False)
115              True
116          """
117          return self._enabled
118
119      def _redis_key(self, name: str) -> str:
120          """Build the Redis key for a tool name.
121
122          Args:
123              name: Tool name.
124
125          Returns:
126              Redis key for the tool lookup entry.
127          """
128          return f"{self._cache_prefix}tool_lookup:{name}"
129
130      def _gateway_set_key(self, gateway_id: str) -> str:
131          """Build the Redis set key for tools in a gateway.
132
133          Args:
134              gateway_id: Gateway ID.
135
136          Returns:
137              Redis set key for gateway tool names.
138          """
139          return f"{self._cache_prefix}tool_lookup:gateway:{gateway_id}"
140
141      async def _get_redis_client(self):
142          """Return a Redis client if L2 is enabled and available.
143
144          Returns:
145              Redis client instance or None.
146          """
147          if not self._l2_enabled:
148              return None
149          try:
150              # First-Party
151              from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel
152
153              client = await get_redis_client()
154              if client and not self._redis_checked:
                     [154 ↛ 157]  line 154 didn't jump to line 157 because the condition on line 154 was always true
155                  self._redis_checked = True
156                  self._redis_available = True
157              return client
158          except Exception:
159              if not self._redis_checked:
                     [159 ↛ 162]  line 159 didn't jump to line 162 because the condition on line 159 was always true
160                  self._redis_checked = True
161                  self._redis_available = False
162              return None
163
164      def _get_l1(self, name: str) -> Optional[Dict[str, Any]]:
165          """Fetch a cached payload from L1 if present and not expired.
166
167          Args:
168              name: Tool name.
169
170          Returns:
171              Cached payload dict or None.
172          """
173          with self._lock:
174              entry = self._cache.get(name)
175              if entry and not entry.is_expired():
176                  # LRU: move to end on hit
177                  self._cache.move_to_end(name)
178                  self._l1_hit_count += 1
179                  return entry.value
180              if entry:
181                  self._cache.pop(name, None)
182              self._l1_miss_count += 1
183              return None
184
185      def _set_l1(self, name: str, value: Dict[str, Any], ttl: int) -> None:
186          """Store a payload in the L1 cache with TTL.
187
188          Args:
189              name: Tool name.
190              value: Payload to cache.
191              ttl: Time to live in seconds.
192          """
193          with self._lock:
194              if name in self._cache:
195                  self._cache.pop(name, None)
196              elif len(self._cache) >= self._l1_maxsize:
197                  self._cache.popitem(last=False)
198              self._cache[name] = CacheEntry(value=value, expiry=time.time() + ttl)
199
200      async def get(self, name: str) -> Optional[Dict[str, Any]]:
201          """Get cached payload for a tool name, checking L1 then L2.
202
203          Args:
204              name: Tool name.
205
206          Returns:
207              Cached payload dict or None.
208
209          Examples:
210              >>> import asyncio
211              >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
212              >>> cache = ToolLookupCache()
213              >>> cache._enabled = True
214              >>> cache._l2_enabled = False
215              >>> asyncio.run(cache.get("missing")) is None
216              True
217              >>> asyncio.run(cache.set("t1", {"tool": {"name": "t1"}}, ttl=60))
218              >>> asyncio.run(cache.get("t1"))["tool"]["name"]
219              't1'
220          """
221          if not self._enabled:
222              return None
223
224          cached = self._get_l1(name)
225          if cached is not None:
226              return cached
227
228          redis = await self._get_redis_client()
229          if not redis:
230              return None
231
232          try:
233              data = await redis.get(self._redis_key(name))
234              if data:
235                  self._l2_hit_count += 1
236                  payload = orjson.loads(data)
237                  self._set_l1(name, payload, self._ttl_seconds)
238                  return payload
239              self._l2_miss_count += 1
240          except Exception as exc:
241              logger.debug("ToolLookupCache Redis get failed: %s", exc)
242          return None
243
244      async def set(self, name: str, payload: Dict[str, Any], ttl: Optional[int] = None, gateway_id: Optional[str] = None) -> None:
245          """Store a payload in cache and update gateway index if provided.
246
247          Args:
248              name: Tool name.
249              payload: Payload to cache.
250              ttl: Time to live in seconds (defaults to configured TTL).
251              gateway_id: Gateway ID for invalidation set tracking.
252
253          Examples:
254              >>> import asyncio
255              >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
256              >>> cache = ToolLookupCache()
257              >>> cache._enabled = True
258              >>> cache._l2_enabled = False
259              >>> asyncio.run(cache.set("t1", {"status": "ok"}, ttl=60))
260              >>> asyncio.run(cache.get("t1"))
261              {'status': 'ok'}
262          """
263          if not self._enabled:
264              return
265
266          effective_ttl = ttl if ttl is not None else self._ttl_seconds
267          self._set_l1(name, payload, effective_ttl)
268
269          redis = await self._get_redis_client()
270          if not redis:
271              return
272
273          try:
274              await redis.setex(self._redis_key(name), effective_ttl, orjson.dumps(payload))
275              if gateway_id:
                     [275 ↛ exit]  line 275 didn't return from function 'set' because the condition on line 275 was always true
276                  set_key = self._gateway_set_key(gateway_id)
277                  await redis.sadd(set_key, name)
278                  await redis.expire(set_key, max(effective_ttl, self._ttl_seconds))
279          except Exception as exc:
280              logger.debug("ToolLookupCache Redis set failed: %s", exc)
281
282      async def set_negative(self, name: str, status: str) -> None:
283          """Store a negative cache entry for a tool name.
284
285          Args:
286              name: Tool name.
287              status: Negative status (missing, inactive, offline).
288
289          Examples:
290              >>> import asyncio
291              >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
292              >>> cache = ToolLookupCache()
293              >>> cache._enabled = True
294              >>> cache._l2_enabled = False
295              >>> asyncio.run(cache.set_negative("t1", "missing"))
296              >>> asyncio.run(cache.get("t1"))
297              {'status': 'missing'}
298          """
299          payload = {"status": status}
300          await self.set(name=name, payload=payload, ttl=self._negative_ttl_seconds)
301
302      async def invalidate(self, name: str, gateway_id: Optional[str] = None) -> None:
303          """Invalidate a tool cache entry by name.
304
305          Args:
306              name: Tool name.
307              gateway_id: Gateway ID for invalidation set tracking.
308
309          Examples:
310              >>> import asyncio
311              >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
312              >>> cache = ToolLookupCache()
313              >>> cache._enabled = True
314              >>> cache._l2_enabled = False
315              >>> asyncio.run(cache.set("t1", {"status": "ok"}, ttl=60))
316              >>> asyncio.run(cache.invalidate("t1"))
317              >>> asyncio.run(cache.get("t1")) is None
318              True
319          """
320          if not self._enabled:
321              return
322
323          with self._lock:
324              self._cache.pop(name, None)
325
326          redis = await self._get_redis_client()
327          if not redis:
328              return
329
330          try:
331              await redis.delete(self._redis_key(name))
332              if gateway_id:
                     [332 ↛ 334]  line 332 didn't jump to line 334 because the condition on line 332 was always true
333                  await redis.srem(self._gateway_set_key(gateway_id), name)
334              await redis.publish("mcpgw:cache:invalidate", f"tool_lookup:{name}")
335          except Exception as exc:
336              logger.debug("ToolLookupCache Redis invalidate failed: %s", exc)
337
338      async def invalidate_gateway(self, gateway_id: str) -> None:
339          """Invalidate all cached tools for a gateway.
340
341          Args:
342              gateway_id: Gateway ID.
343
344          Examples:
345              >>> import asyncio
346              >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
347              >>> cache = ToolLookupCache()
348              >>> cache._enabled = True
349              >>> cache._l2_enabled = False
350              >>> asyncio.run(cache.set("t1", {"tool": {"gateway_id": "g1"}}, ttl=60, gateway_id="g1"))
351              >>> asyncio.run(cache.set("t2", {"tool": {"gateway_id": "g2"}}, ttl=60, gateway_id="g2"))
352              >>> asyncio.run(cache.invalidate_gateway("g1"))
353              >>> (asyncio.run(cache.get("t1")) is None, asyncio.run(cache.get("t2")) is None)
354              (True, False)
355          """
356          if not self._enabled:
357              return
358
359          # L1 invalidation by gateway_id
360          with self._lock:
361              to_remove = [name for name, entry in self._cache.items() if entry.value.get("tool", {}).get("gateway_id") == gateway_id]
362              for name in to_remove:
363                  self._cache.pop(name, None)
364
365          redis = await self._get_redis_client()
366          if not redis:
367              return
368
369          set_key = self._gateway_set_key(gateway_id)
370          try:
371              tool_names = await redis.smembers(set_key)
372              if tool_names:
                     [372 ↛ 375]  line 372 didn't jump to line 375 because the condition on line 372 was always true
373                  keys = [self._redis_key(name.decode() if isinstance(name, bytes) else name) for name in tool_names]
374                  await redis.delete(*keys)
375              await redis.delete(set_key)
376              await redis.publish("mcpgw:cache:invalidate", f"tool_lookup:gateway:{gateway_id}")
377          except Exception as exc:
378              logger.debug("ToolLookupCache Redis invalidate_gateway failed: %s", exc)
379
380      def invalidate_all_local(self) -> None:
381          """Clear all L1 cache entries."""
382          # Note: L2 is intentionally not cleared here; this is L1 only.
383          with self._lock:
384              self._cache.clear()
385
386      def stats(self) -> Dict[str, Any]:
387          """Return cache hit/miss statistics and configuration.
388
389          Returns:
390              Cache stats and settings.
391
392          Examples:
393              >>> import asyncio
394              >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
395              >>> cache = ToolLookupCache()
396              >>> cache._enabled = True
397              >>> cache._l2_enabled = False
398              >>> asyncio.run(cache.get("missing")) is None
399              True
400              >>> s = cache.stats()
401              >>> (s["l1_miss_count"] >= 1, "l1_hit_rate" in s)
402              (True, True)
403          """
404          total_l1 = self._l1_hit_count + self._l1_miss_count
405          total_l2 = self._l2_hit_count + self._l2_miss_count
406          return {
407              "enabled": self._enabled,
408              "l1_hit_count": self._l1_hit_count,
409              "l1_miss_count": self._l1_miss_count,
410              "l1_hit_rate": self._l1_hit_count / total_l1 if total_l1 > 0 else 0.0,
411              "l2_hit_count": self._l2_hit_count,
412              "l2_miss_count": self._l2_miss_count,
413              "l2_hit_rate": self._l2_hit_count / total_l2 if total_l2 > 0 else 0.0,
414              "l1_size": len(self._cache),
415              "l1_maxsize": self._l1_maxsize,
416              "ttl_seconds": self._ttl_seconds,
417              "negative_ttl_seconds": self._negative_ttl_seconds,
418              "l2_enabled": self._l2_enabled,
419              "redis_available": self._redis_available,
420          }
421
422      def reset_stats(self) -> None:
423          """Reset hit/miss counters.
424
425          Examples:
426              >>> import asyncio
427              >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
428              >>> cache = ToolLookupCache()
429              >>> cache._enabled = True
430              >>> cache._l2_enabled = False
431              >>> asyncio.run(cache.get("missing")) is None
432              True
433              >>> cache.reset_stats()
434              >>> cache.stats()["l1_miss_count"]
435              0
436          """
437          self._l1_hit_count = 0
438          self._l1_miss_count = 0
439          self._l2_hit_count = 0
440          self._l2_miss_count = 0
441
442
443  tool_lookup_cache = ToolLookupCache()
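
Usage sketch (not part of the covered module): the snippet below shows how a caller such as ToolService.invoke_tool might layer this cache in front of the database, including negative caching and gateway-scoped invalidation. The load_tool_from_db helper, the "g1" gateway id, and the tool field names are illustrative assumptions; the {"tool": {...}} and {"status": ...} payload shapes follow the module's own doctests.

# Illustrative sketch only; load_tool_from_db is a hypothetical stand-in for the real DB query.
import asyncio
from typing import Any, Dict, Optional

from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache


async def load_tool_from_db(name: str) -> Optional[Dict[str, Any]]:
    # Hypothetical DB lookup; returns a tool record or None.
    return {"name": name, "gateway_id": "g1"} if name == "get_time" else None


async def resolve_tool(name: str) -> Optional[Dict[str, Any]]:
    # 1. L1 (per-worker) then L2 (Redis, if enabled) lookup; a hit skips the DB entirely.
    cached = await tool_lookup_cache.get(name)
    if cached is not None:
        if cached.get("status") in ("missing", "inactive", "offline"):
            return None  # Negative entry: short-circuit without touching the DB.
        return cached.get("tool")

    # 2. Cache miss: fall back to the database.
    tool = await load_tool_from_db(name)
    if tool is None:
        # Negative entries use the shorter configured TTL (default 10s).
        await tool_lookup_cache.set_negative(name, "missing")
        return None

    # 3. Positive entry; passing gateway_id indexes the name in the Redis set
    #    so invalidate_gateway(gateway_id) can later drop all of its tools.
    await tool_lookup_cache.set(name, {"tool": tool}, gateway_id=tool.get("gateway_id"))
    return tool


async def main() -> None:
    print(await resolve_tool("get_time"))   # DB load, then cached
    print(await resolve_tool("get_time"))   # L1 hit
    print(await resolve_tool("nope"))       # negative-cached as "missing"
    print(tool_lookup_cache.stats()["l1_hit_count"])
    # On tool update or gateway removal:
    await tool_lookup_cache.invalidate("get_time", gateway_id="g1")
    await tool_lookup_cache.invalidate_gateway("g1")


if __name__ == "__main__":
    asyncio.run(main())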