Coverage for mcpgateway / cache / tool_lookup_cache.py: 96%
168 statements
« prev ^ index » next    coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Tool lookup cache (tool name -> tool config) with L1 memory + L2 Redis.
4This cache targets the hot-path tool lookup in ToolService.invoke_tool by
5avoiding a DB query per tool invocation. It uses a per-worker in-memory
6cache with TTL and optional Redis backing for distributed deployments.
7"""
9# Future
10from __future__ import annotations
12# Standard
13from collections import OrderedDict
14from dataclasses import dataclass
15import logging
16import threading
17import time
18from typing import Any, Dict, Optional
20# Third-Party
21import orjson
23logger = logging.getLogger(__name__)
@dataclass
class CacheEntry:
    """A cached payload paired with an absolute expiry timestamp.

    Attributes:
        value: The cached payload dict.
        expiry: Unix time (seconds) at or after which the entry is stale.
    """

    value: Dict[str, Any]
    expiry: float

    def is_expired(self) -> bool:
        """Return True if the cache entry has expired.

        Returns:
            True if expired, otherwise False.

        Examples:
            >>> from unittest.mock import patch
            >>> from mcpgateway.cache.tool_lookup_cache import CacheEntry
            >>> with patch("time.time", return_value=1000):
            ...     CacheEntry(value={"k": "v"}, expiry=999).is_expired()
            True
            >>> with patch("time.time", return_value=1000):
            ...     CacheEntry(value={"k": "v"}, expiry=1001).is_expired()
            False
        """
        # Stale exactly when the expiry timestamp is not in the future.
        return self.expiry <= time.time()
class ToolLookupCache:
    """Two-tier cache for tool lookups by name.

    L1: in-memory LRU/TTL per worker.
    L2: Redis (optional, shared across workers).
    """

    def __init__(self) -> None:
        """Initialize cache settings and in-memory structures.

        Settings are read from ``mcpgateway.config`` when importable; on
        ImportError (standalone use) conservative defaults apply and L2
        (Redis) is disabled.

        Examples:
            >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
            >>> cache = ToolLookupCache()
            >>> isinstance(cache.enabled, bool)
            True
        """
        try:
            # First-Party
            from mcpgateway.config import settings  # pylint: disable=import-outside-toplevel

            self._enabled = getattr(settings, "tool_lookup_cache_enabled", True)
            self._ttl_seconds = getattr(settings, "tool_lookup_cache_ttl_seconds", 60)
            self._negative_ttl_seconds = getattr(settings, "tool_lookup_cache_negative_ttl_seconds", 10)
            self._l1_maxsize = getattr(settings, "tool_lookup_cache_l1_maxsize", 10000)
            # L2 requires both the feature flag AND Redis being the configured cache backend.
            self._l2_enabled = getattr(settings, "tool_lookup_cache_l2_enabled", True) and settings.cache_type == "redis"
            self._cache_prefix = getattr(settings, "cache_prefix", "mcpgw:")
        except ImportError:
            self._enabled = True
            self._ttl_seconds = 60
            self._negative_ttl_seconds = 10
            self._l1_maxsize = 10000
            self._l2_enabled = False
            self._cache_prefix = "mcpgw:"

        # L1 store: OrderedDict used as an LRU (most-recently-used at the end).
        self._cache: "OrderedDict[str, CacheEntry]" = OrderedDict()
        self._lock = threading.Lock()

        # One-shot probe flags recording whether Redis was reachable on first L2 access.
        self._redis_checked = False
        self._redis_available = False

        # Hit/miss counters surfaced via stats().
        self._l1_hit_count = 0
        self._l1_miss_count = 0
        self._l2_hit_count = 0
        self._l2_miss_count = 0

        logger.info(
            "ToolLookupCache initialized: enabled=%s l1_max=%s ttl=%ss l2_enabled=%s",
            self._enabled,
            self._l1_maxsize,
            self._ttl_seconds,
            self._l2_enabled,
        )

    @property
    def enabled(self) -> bool:
        """Return True if the cache is enabled.

        Returns:
            True if enabled, otherwise False.

        Examples:
            >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
            >>> ToolLookupCache().enabled in (True, False)
            True
        """
        return self._enabled

    def _redis_key(self, name: str) -> str:
        """Build the Redis key for a tool name.

        Args:
            name: Tool name.

        Returns:
            Redis key for the tool lookup entry.
        """
        return f"{self._cache_prefix}tool_lookup:{name}"

    def _gateway_set_key(self, gateway_id: str) -> str:
        """Build the Redis set key for tools in a gateway.

        Args:
            gateway_id: Gateway ID.

        Returns:
            Redis set key for gateway tool names.
        """
        return f"{self._cache_prefix}tool_lookup:gateway:{gateway_id}"

    async def _get_redis_client(self):
        """Return a Redis client if L2 is enabled and available.

        On the first successful (or failed) connection attempt the
        ``_redis_checked`` / ``_redis_available`` flags are latched for
        reporting in stats(); later failures do not flip them back.

        Returns:
            Redis client instance or None.
        """
        if not self._l2_enabled:
            return None
        try:
            # First-Party
            from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

            client = await get_redis_client()
            if client and not self._redis_checked:
                self._redis_checked = True
                self._redis_available = True
            return client
        except Exception:
            # Best-effort: any Redis failure degrades silently to L1-only.
            if not self._redis_checked:
                self._redis_checked = True
                self._redis_available = False
            return None

    def _get_l1(self, name: str) -> Optional[Dict[str, Any]]:
        """Fetch a cached payload from L1 if present and not expired.

        Args:
            name: Tool name.

        Returns:
            Cached payload dict or None.
        """
        with self._lock:
            entry = self._cache.get(name)
            if entry and not entry.is_expired():
                # LRU: move to end on hit
                self._cache.move_to_end(name)
                self._l1_hit_count += 1
                return entry.value
            if entry:
                # Expired entry: evict eagerly so it frees an LRU slot.
                self._cache.pop(name, None)
            self._l1_miss_count += 1
            return None

    def _set_l1(self, name: str, value: Dict[str, Any], ttl: int) -> None:
        """Store a payload in the L1 cache with TTL.

        Args:
            name: Tool name.
            value: Payload to cache.
            ttl: Time to live in seconds.
        """
        with self._lock:
            if name in self._cache:
                # Remove then re-insert so the key lands at the MRU end.
                self._cache.pop(name, None)
            elif len(self._cache) >= self._l1_maxsize:
                # At capacity: evict the least-recently-used entry (front).
                self._cache.popitem(last=False)
            self._cache[name] = CacheEntry(value=value, expiry=time.time() + ttl)

    async def get(self, name: str) -> Optional[Dict[str, Any]]:
        """Get cached payload for a tool name, checking L1 then L2.

        An L2 (Redis) hit is promoted into L1 with the configured TTL.

        Args:
            name: Tool name.

        Returns:
            Cached payload dict or None.

        Examples:
            >>> import asyncio
            >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
            >>> cache = ToolLookupCache()
            >>> cache._enabled = True
            >>> cache._l2_enabled = False
            >>> asyncio.run(cache.get("missing")) is None
            True
            >>> asyncio.run(cache.set("t1", {"tool": {"name": "t1"}}, ttl=60))
            >>> asyncio.run(cache.get("t1"))["tool"]["name"]
            't1'
        """
        if not self._enabled:
            return None

        cached = self._get_l1(name)
        if cached is not None:
            return cached

        redis = await self._get_redis_client()
        if not redis:
            return None

        try:
            data = await redis.get(self._redis_key(name))
            if data:
                self._l2_hit_count += 1
                payload = orjson.loads(data)
                # Promote to L1 so subsequent lookups skip Redis.
                self._set_l1(name, payload, self._ttl_seconds)
                return payload
            self._l2_miss_count += 1
        except Exception as exc:
            logger.debug("ToolLookupCache Redis get failed: %s", exc)
        return None

    async def set(self, name: str, payload: Dict[str, Any], ttl: Optional[int] = None, gateway_id: Optional[str] = None) -> None:
        """Store a payload in cache and update gateway index if provided.

        Args:
            name: Tool name.
            payload: Payload to cache.
            ttl: Time to live in seconds (defaults to configured TTL).
            gateway_id: Gateway ID for invalidation set tracking.

        Examples:
            >>> import asyncio
            >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
            >>> cache = ToolLookupCache()
            >>> cache._enabled = True
            >>> cache._l2_enabled = False
            >>> asyncio.run(cache.set("t1", {"status": "ok"}, ttl=60))
            >>> asyncio.run(cache.get("t1"))
            {'status': 'ok'}
        """
        if not self._enabled:
            return

        effective_ttl = ttl if ttl is not None else self._ttl_seconds
        self._set_l1(name, payload, effective_ttl)

        redis = await self._get_redis_client()
        if not redis:
            return

        try:
            await redis.setex(self._redis_key(name), effective_ttl, orjson.dumps(payload))
            if gateway_id:
                # Track membership so invalidate_gateway() can find all keys;
                # the set outlives the shortest entry TTL.
                set_key = self._gateway_set_key(gateway_id)
                await redis.sadd(set_key, name)
                await redis.expire(set_key, max(effective_ttl, self._ttl_seconds))
        except Exception as exc:
            logger.debug("ToolLookupCache Redis set failed: %s", exc)

    async def set_negative(self, name: str, status: str) -> None:
        """Store a negative cache entry for a tool name.

        Negative entries use the (shorter) negative TTL so transient
        misses do not linger.

        Args:
            name: Tool name.
            status: Negative status (missing, inactive, offline).

        Examples:
            >>> import asyncio
            >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
            >>> cache = ToolLookupCache()
            >>> cache._enabled = True
            >>> cache._l2_enabled = False
            >>> asyncio.run(cache.set_negative("t1", "missing"))
            >>> asyncio.run(cache.get("t1"))
            {'status': 'missing'}
        """
        payload = {"status": status}
        await self.set(name=name, payload=payload, ttl=self._negative_ttl_seconds)

    async def invalidate(self, name: str, gateway_id: Optional[str] = None) -> None:
        """Invalidate a tool cache entry by name.

        Removes the L1 entry, deletes the Redis key, and publishes an
        invalidation message so other workers can drop their L1 copies.

        Args:
            name: Tool name.
            gateway_id: Gateway ID for invalidation set tracking.

        Examples:
            >>> import asyncio
            >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
            >>> cache = ToolLookupCache()
            >>> cache._enabled = True
            >>> cache._l2_enabled = False
            >>> asyncio.run(cache.set("t1", {"status": "ok"}, ttl=60))
            >>> asyncio.run(cache.invalidate("t1"))
            >>> asyncio.run(cache.get("t1")) is None
            True
        """
        if not self._enabled:
            return

        with self._lock:
            self._cache.pop(name, None)

        redis = await self._get_redis_client()
        if not redis:
            return

        try:
            await redis.delete(self._redis_key(name))
            if gateway_id:
                await redis.srem(self._gateway_set_key(gateway_id), name)
            # NOTE(review): channel hardcodes the "mcpgw:" prefix while keys use
            # self._cache_prefix — presumably a deliberate shared channel; confirm.
            await redis.publish("mcpgw:cache:invalidate", f"tool_lookup:{name}")
        except Exception as exc:
            logger.debug("ToolLookupCache Redis invalidate failed: %s", exc)

    async def invalidate_gateway(self, gateway_id: str) -> None:
        """Invalidate all cached tools for a gateway.

        L1 entries are matched by ``payload["tool"]["gateway_id"]``; L2 keys
        are found via the per-gateway Redis set maintained by set().

        Args:
            gateway_id: Gateway ID.

        Examples:
            >>> import asyncio
            >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
            >>> cache = ToolLookupCache()
            >>> cache._enabled = True
            >>> cache._l2_enabled = False
            >>> asyncio.run(cache.set("t1", {"tool": {"gateway_id": "g1"}}, ttl=60, gateway_id="g1"))
            >>> asyncio.run(cache.set("t2", {"tool": {"gateway_id": "g2"}}, ttl=60, gateway_id="g2"))
            >>> asyncio.run(cache.invalidate_gateway("g1"))
            >>> (asyncio.run(cache.get("t1")) is None, asyncio.run(cache.get("t2")) is None)
            (True, False)
        """
        if not self._enabled:
            return

        # L1 invalidation by gateway_id
        with self._lock:
            to_remove = [name for name, entry in self._cache.items() if entry.value.get("tool", {}).get("gateway_id") == gateway_id]
            for name in to_remove:
                self._cache.pop(name, None)

        redis = await self._get_redis_client()
        if not redis:
            return

        set_key = self._gateway_set_key(gateway_id)
        try:
            tool_names = await redis.smembers(set_key)
            if tool_names:
                # smembers may return bytes or str depending on client decoding.
                keys = [self._redis_key(name.decode() if isinstance(name, bytes) else name) for name in tool_names]
                await redis.delete(*keys)
            await redis.delete(set_key)
            await redis.publish("mcpgw:cache:invalidate", f"tool_lookup:gateway:{gateway_id}")
        except Exception as exc:
            logger.debug("ToolLookupCache Redis invalidate_gateway failed: %s", exc)

    def invalidate_all_local(self) -> None:
        """Clear all L1 cache entries."""
        # Note: L2 is intentionally not cleared here; this is L1 only.
        with self._lock:
            self._cache.clear()

    def stats(self) -> Dict[str, Any]:
        """Return cache hit/miss statistics and configuration.

        Returns:
            Cache stats and settings.

        Examples:
            >>> import asyncio
            >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
            >>> cache = ToolLookupCache()
            >>> cache._enabled = True
            >>> cache._l2_enabled = False
            >>> asyncio.run(cache.get("missing")) is None
            True
            >>> s = cache.stats()
            >>> (s["l1_miss_count"] >= 1, "l1_hit_rate" in s)
            (True, True)
        """
        total_l1 = self._l1_hit_count + self._l1_miss_count
        total_l2 = self._l2_hit_count + self._l2_miss_count
        return {
            "enabled": self._enabled,
            "l1_hit_count": self._l1_hit_count,
            "l1_miss_count": self._l1_miss_count,
            "l1_hit_rate": self._l1_hit_count / total_l1 if total_l1 > 0 else 0.0,
            "l2_hit_count": self._l2_hit_count,
            "l2_miss_count": self._l2_miss_count,
            "l2_hit_rate": self._l2_hit_count / total_l2 if total_l2 > 0 else 0.0,
            "l1_size": len(self._cache),
            "l1_maxsize": self._l1_maxsize,
            "ttl_seconds": self._ttl_seconds,
            "negative_ttl_seconds": self._negative_ttl_seconds,
            "l2_enabled": self._l2_enabled,
            "redis_available": self._redis_available,
        }

    def reset_stats(self) -> None:
        """Reset hit/miss counters.

        Examples:
            >>> import asyncio
            >>> from mcpgateway.cache.tool_lookup_cache import ToolLookupCache
            >>> cache = ToolLookupCache()
            >>> cache._enabled = True
            >>> cache._l2_enabled = False
            >>> asyncio.run(cache.get("missing")) is None
            True
            >>> cache.reset_stats()
            >>> cache.stats()["l1_miss_count"]
            0
        """
        self._l1_hit_count = 0
        self._l1_miss_count = 0
        self._l2_hit_count = 0
        self._l2_miss_count = 0
# Module-level singleton shared by the worker process (one L1 per worker).
tool_lookup_cache = ToolLookupCache()