Coverage for mcpgateway / cache / registry_cache.py: 100%
370 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/cache/registry_cache.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
6Registry Data Cache.
8This module implements a thread-safe cache for registry data (tools, prompts,
9resources, agents, servers, gateways) with Redis as the primary store and
10in-memory fallback. It reduces database queries for list endpoints.
12Performance Impact:
13 - Before: 1-2 DB queries per list request
14 - After: 0 DB queries (cache hit) per TTL period
15 - Expected 95%+ cache hit rate under load
17Examples:
18 >>> from mcpgateway.cache.registry_cache import registry_cache
19 >>> # Cache is used automatically by list endpoints
20 >>> # Manual invalidation after tool update:
21 >>> import asyncio
22 >>> # asyncio.run(registry_cache.invalidate_tools())
23"""
25# Standard
26import asyncio
27from dataclasses import dataclass
28import hashlib
29import logging
30import threading
31import time
32from typing import Any, Callable, Dict, Optional
# Module-level logger, namespaced to this module per stdlib convention.
logger = logging.getLogger(__name__)
37def _get_cleanup_timeout() -> float:
38 """Get cleanup timeout from config (lazy import to avoid circular deps).
40 Returns:
41 Cleanup timeout in seconds (default: 5.0).
42 """
43 try:
44 # First-Party
45 from mcpgateway.config import settings # pylint: disable=import-outside-toplevel
47 return settings.mcp_session_pool_cleanup_timeout
48 except Exception:
49 return 5.0
@dataclass
class CacheEntry:
    """A cached value paired with its absolute expiry timestamp.

    Examples:
        >>> import time
        >>> entry = CacheEntry(value=["item1", "item2"], expiry=time.time() + 60)
        >>> entry.is_expired()
        False
    """

    value: Any
    expiry: float

    def is_expired(self) -> bool:
        """Report whether this entry's TTL has elapsed.

        Returns:
            bool: True if the entry has expired, False otherwise.
        """
        remaining = self.expiry - time.time()
        return remaining <= 0
@dataclass
class RegistryCacheConfig:
    """TTL and enablement configuration for the registry cache.

    Attributes:
        enabled: Whether caching is enabled
        tools_ttl: TTL in seconds for tools list cache
        prompts_ttl: TTL in seconds for prompts list cache
        resources_ttl: TTL in seconds for resources list cache
        agents_ttl: TTL in seconds for agents list cache
        servers_ttl: TTL in seconds for servers list cache
        gateways_ttl: TTL in seconds for gateways list cache
        catalog_ttl: TTL in seconds for catalog servers list cache

    Examples:
        >>> config = RegistryCacheConfig()
        >>> config.tools_ttl
        20
    """

    enabled: bool = True
    tools_ttl: int = 20
    prompts_ttl: int = 15
    resources_ttl: int = 15
    agents_ttl: int = 20
    servers_ttl: int = 20
    gateways_ttl: int = 20
    catalog_ttl: int = 300
class RegistryCache:
    """Thread-safe registry cache with Redis and in-memory tiers.

    This cache reduces database load for list endpoints by caching:
    - Tools list
    - Prompts list
    - Resources list
    - A2A Agents list
    - Servers list
    - Gateways list
    - Catalog servers list

    The cache uses Redis as the primary store for distributed deployments
    and falls back to in-memory caching when Redis is unavailable.

    All shared mutable state — the in-memory cache dictionary AND the
    hit/miss statistics counters — is guarded by a single threading.Lock
    so concurrent callers see consistent values.

    Examples:
        >>> cache = RegistryCache()
        >>> cache.stats()["hit_count"]
        0
    """

    def __init__(self, config: Optional[RegistryCacheConfig] = None):
        """Initialize the registry cache.

        Args:
            config: Cache configuration. If None, loads from settings.
                The explicit config is only consulted when application
                settings cannot be imported (e.g. in isolated tests).

        Examples:
            >>> cache = RegistryCache()
            >>> cache._enabled
            True
        """
        # Import settings lazily to avoid circular imports
        try:
            # First-Party
            from mcpgateway.config import settings  # pylint: disable=import-outside-toplevel

            self._enabled = getattr(settings, "registry_cache_enabled", True)
            self._tools_ttl = getattr(settings, "registry_cache_tools_ttl", 20)
            self._prompts_ttl = getattr(settings, "registry_cache_prompts_ttl", 15)
            self._resources_ttl = getattr(settings, "registry_cache_resources_ttl", 15)
            self._agents_ttl = getattr(settings, "registry_cache_agents_ttl", 20)
            self._servers_ttl = getattr(settings, "registry_cache_servers_ttl", 20)
            self._gateways_ttl = getattr(settings, "registry_cache_gateways_ttl", 20)
            self._catalog_ttl = getattr(settings, "registry_cache_catalog_ttl", 300)
            self._cache_prefix = getattr(settings, "cache_prefix", "mcpgw:")
        except ImportError:
            cfg = config or RegistryCacheConfig()
            self._enabled = cfg.enabled
            self._tools_ttl = cfg.tools_ttl
            self._prompts_ttl = cfg.prompts_ttl
            self._resources_ttl = cfg.resources_ttl
            self._agents_ttl = cfg.agents_ttl
            self._servers_ttl = cfg.servers_ttl
            self._gateways_ttl = cfg.gateways_ttl
            self._catalog_ttl = cfg.catalog_ttl
            self._cache_prefix = "mcpgw:"

        # In-memory cache (fallback when Redis unavailable)
        self._cache: Dict[str, CacheEntry] = {}

        # Thread safety: guards both _cache and the statistics counters below.
        self._lock = threading.Lock()

        # Redis availability flags; both start False and are set on the
        # first connection attempt in _get_redis_client().
        self._redis_checked = False
        self._redis_available = False

        # Statistics — read and written only while holding self._lock.
        self._hit_count = 0
        self._miss_count = 0
        self._redis_hit_count = 0
        self._redis_miss_count = 0

        logger.info(
            f"RegistryCache initialized: enabled={self._enabled}, "
            f"tools_ttl={self._tools_ttl}s, prompts_ttl={self._prompts_ttl}s, "
            f"resources_ttl={self._resources_ttl}s, agents_ttl={self._agents_ttl}s, "
            f"catalog_ttl={self._catalog_ttl}s"
        )

    def _get_redis_key(self, cache_type: str, filters_hash: str = "") -> str:
        """Generate Redis key with proper prefix.

        Args:
            cache_type: Type of cache entry (tools, prompts, etc.)
            filters_hash: Hash of filter parameters

        Returns:
            Full Redis key with prefix

        Examples:
            >>> cache = RegistryCache()
            >>> cache._get_redis_key("tools", "abc123")
            'mcpgw:registry:tools:abc123'
        """
        if filters_hash:
            return f"{self._cache_prefix}registry:{cache_type}:{filters_hash}"
        return f"{self._cache_prefix}registry:{cache_type}"

    def hash_filters(self, **kwargs) -> str:
        """Generate a hash from filter parameters.

        Args:
            **kwargs: Filter parameters to hash

        Returns:
            MD5 hash of the filter parameters

        Examples:
            >>> cache = RegistryCache()
            >>> h = cache.hash_filters(include_inactive=False, tags=["api"])
            >>> len(h)
            32
        """
        # Sort keys for consistent hashing
        sorted_items = sorted(kwargs.items())
        filter_str = str(sorted_items)
        # MD5 is used only as a cache-key fingerprint, not for security.
        return hashlib.md5(filter_str.encode()).hexdigest()  # nosec B324 # noqa: DUO130

    async def _get_redis_client(self):
        """Get Redis client if available.

        Returns:
            Redis client or None if unavailable.
        """
        try:
            # First-Party
            from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

            client = await get_redis_client()
            if client and not self._redis_checked:
                self._redis_checked = True
                self._redis_available = True
                logger.debug("RegistryCache: Redis client available")
            return client
        except Exception as e:
            # Log the unavailability only once, then fall back silently.
            if not self._redis_checked:
                self._redis_checked = True
                self._redis_available = False
                logger.debug(f"RegistryCache: Redis unavailable, using in-memory cache: {e}")
            return None

    async def get(self, cache_type: str, filters_hash: str = "") -> Optional[Any]:
        """Get cached data.

        Args:
            cache_type: Type of cache (tools, prompts, resources, agents, servers, gateways)
            filters_hash: Hash of filter parameters

        Returns:
            Cached data if found, None otherwise

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> result = asyncio.run(cache.get("tools", "abc123"))
            >>> result is None  # Cache miss on fresh cache
            True
        """
        if not self._enabled:
            return None

        cache_key = self._get_redis_key(cache_type, filters_hash)

        # Try Redis first
        redis = await self._get_redis_client()
        if redis:
            try:
                data = await redis.get(cache_key)
                if data:
                    # Third-Party
                    import orjson  # pylint: disable=import-outside-toplevel

                    # Counters are shared state; update them under the lock
                    # to keep the class's thread-safety guarantee.
                    with self._lock:
                        self._hit_count += 1
                        self._redis_hit_count += 1
                    return orjson.loads(data)
                with self._lock:
                    self._redis_miss_count += 1
            except Exception as e:
                logger.warning(f"RegistryCache Redis get failed: {e}")

        # Fall back to in-memory cache
        with self._lock:
            entry = self._cache.get(cache_key)
            if entry and not entry.is_expired():
                self._hit_count += 1
                return entry.value
            self._miss_count += 1

        return None

    async def set(self, cache_type: str, data: Any, filters_hash: str = "", ttl: Optional[int] = None) -> None:
        """Store data in cache.

        Args:
            cache_type: Type of cache (tools, prompts, resources, agents, servers, gateways)
            data: Data to cache (must be JSON-serializable)
            filters_hash: Hash of filter parameters
            ttl: TTL in seconds (uses default for cache_type if not specified)

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.set("tools", [{"id": "1", "name": "tool1"}], "abc123"))
        """
        if not self._enabled:
            return

        # Determine TTL
        if ttl is None:
            ttl_map = {
                "tools": self._tools_ttl,
                "prompts": self._prompts_ttl,
                "resources": self._resources_ttl,
                "agents": self._agents_ttl,
                "servers": self._servers_ttl,
                "gateways": self._gateways_ttl,
                "catalog": self._catalog_ttl,
            }
            ttl = ttl_map.get(cache_type, 20)

        cache_key = self._get_redis_key(cache_type, filters_hash)

        # Store in Redis (best-effort; failures degrade to in-memory only)
        redis = await self._get_redis_client()
        if redis:
            try:
                # Third-Party
                import orjson  # pylint: disable=import-outside-toplevel

                await redis.setex(cache_key, ttl, orjson.dumps(data))
            except Exception as e:
                logger.warning(f"RegistryCache Redis set failed: {e}")

        # Store in in-memory cache
        with self._lock:
            self._cache[cache_key] = CacheEntry(value=data, expiry=time.time() + ttl)

    async def invalidate(self, cache_type: str) -> None:
        """Invalidate all cached data for a cache type.

        Args:
            cache_type: Type of cache to invalidate (tools, prompts, etc.)

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate("tools"))
        """
        logger.debug(f"RegistryCache: Invalidating {cache_type} cache")
        prefix = self._get_redis_key(cache_type)

        # Clear in-memory cache
        with self._lock:
            keys_to_remove = [k for k in self._cache if k.startswith(prefix)]
            for key in keys_to_remove:
                self._cache.pop(key, None)

        # Clear Redis
        redis = await self._get_redis_client()
        if redis:
            try:
                pattern = f"{prefix}*"
                async for key in redis.scan_iter(match=pattern):
                    await redis.delete(key)

                # Publish invalidation for other workers
                await redis.publish("mcpgw:cache:invalidate", f"registry:{cache_type}")
            except Exception as e:
                logger.warning(f"RegistryCache Redis invalidate failed: {e}")

    async def invalidate_tools(self) -> None:
        """Invalidate tools cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_tools())
        """
        await self.invalidate("tools")

    async def invalidate_prompts(self) -> None:
        """Invalidate prompts cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_prompts())
        """
        await self.invalidate("prompts")

    async def invalidate_resources(self) -> None:
        """Invalidate resources cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_resources())
        """
        await self.invalidate("resources")

    async def invalidate_agents(self) -> None:
        """Invalidate agents cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_agents())
        """
        await self.invalidate("agents")

    async def invalidate_servers(self) -> None:
        """Invalidate servers cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_servers())
        """
        await self.invalidate("servers")

    async def invalidate_gateways(self) -> None:
        """Invalidate gateways cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_gateways())
        """
        await self.invalidate("gateways")

    async def invalidate_catalog(self) -> None:
        """Invalidate catalog servers cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_catalog())
        """
        await self.invalidate("catalog")

    def invalidate_all(self) -> None:
        """Invalidate all cached data synchronously.

        Note: this clears only the local in-memory tier; Redis entries
        expire by TTL or via the async invalidate() methods.

        Examples:
            >>> cache = RegistryCache()
            >>> cache.invalidate_all()
        """
        with self._lock:
            self._cache.clear()
        logger.info("RegistryCache: All caches invalidated")

    def stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with hit/miss counts and hit rate

        Examples:
            >>> cache = RegistryCache()
            >>> stats = cache.stats()
            >>> "hit_count" in stats
            True
        """
        # Snapshot counters under the lock so the reported values are
        # mutually consistent even under concurrent updates.
        with self._lock:
            hits = self._hit_count
            misses = self._miss_count
            redis_hits = self._redis_hit_count
            redis_misses = self._redis_miss_count
            cache_size = len(self._cache)

        total = hits + misses
        redis_total = redis_hits + redis_misses

        return {
            "enabled": self._enabled,
            "hit_count": hits,
            "miss_count": misses,
            "hit_rate": hits / total if total > 0 else 0.0,
            "redis_hit_count": redis_hits,
            "redis_miss_count": redis_misses,
            "redis_hit_rate": redis_hits / redis_total if redis_total > 0 else 0.0,
            "redis_available": self._redis_available,
            "cache_size": cache_size,
            "ttls": {
                "tools": self._tools_ttl,
                "prompts": self._prompts_ttl,
                "resources": self._resources_ttl,
                "agents": self._agents_ttl,
                "servers": self._servers_ttl,
                "gateways": self._gateways_ttl,
                "catalog": self._catalog_ttl,
            },
        }

    def reset_stats(self) -> None:
        """Reset hit/miss counters.

        Examples:
            >>> cache = RegistryCache()
            >>> cache._hit_count = 100
            >>> cache.reset_stats()
            >>> cache._hit_count
            0
        """
        with self._lock:
            self._hit_count = 0
            self._miss_count = 0
            self._redis_hit_count = 0
            self._redis_miss_count = 0
# Global singleton instance
# Lazily created by get_registry_cache(); not part of the public API.
_registry_cache: Optional[RegistryCache] = None
def get_registry_cache() -> RegistryCache:
    """Return the process-wide RegistryCache, creating it on first use.

    Returns:
        RegistryCache: The singleton registry cache instance

    Examples:
        >>> cache = get_registry_cache()
        >>> isinstance(cache, RegistryCache)
        True
    """
    global _registry_cache  # pylint: disable=global-statement
    cache = _registry_cache
    if cache is None:
        cache = RegistryCache()
        _registry_cache = cache
    return cache
# Convenience alias for direct import
# NOTE: this instantiates the singleton eagerly at module import time.
registry_cache = get_registry_cache()


_MAX_REVOKED_JTIS = 100_000
"""Upper bound on the in-memory revoked-JTI set.

Prevents unbounded memory growth if a compromised Redis channel floods
``revoke:`` messages. When the cap is reached new JTIs are still
processed (cache eviction) but are not added to the local set;
subsequent ``is_token_revoked()`` / ``get_auth_context()`` calls will
fall through to the Redis check on L1 cache miss, so revocation is
still enforced.
"""
class CacheInvalidationSubscriber:
    """Redis pubsub subscriber for cross-worker cache invalidation.

    This class subscribes to both 'mcpgw:cache:invalidate' and
    'mcpgw:auth:invalidate' Redis channels and processes invalidation
    messages from other workers, ensuring local in-memory caches stay
    synchronized in multi-worker deployments.

    Message formats handled:
    - registry:{cache_type} - Invalidate registry cache (tools, prompts, etc.)
    - tool_lookup:{name} - Invalidate specific tool lookup
    - tool_lookup:gateway:{gateway_id} - Invalidate all tools for a gateway
    - admin:{prefix} - Invalidate admin stats cache
    - user:{email} - Invalidate auth user cache
    - revoke:{jti} - Invalidate auth revocation cache
    - team:{email} - Invalidate auth team cache
    - role:{email}:{team_id} - Invalidate auth role cache
    - team_roles:{team_id} - Invalidate all roles for a team
    - teams:{email} - Invalidate auth teams list cache
    - membership:{email} - Invalidate auth team membership cache

    Examples:
        >>> subscriber = CacheInvalidationSubscriber()
        >>> # Start listening in background task:
        >>> # await subscriber.start()
        >>> # Stop when shutting down:
        >>> # await subscriber.stop()
    """

    def __init__(self) -> None:
        """Initialize the cache invalidation subscriber."""
        # Background listener task; set by start(), cleared by stop().
        self._task: Optional[asyncio.Task[None]] = None
        # Cooperative shutdown signal checked by the listen loop.
        self._stop_event: Optional[asyncio.Event] = None
        # Redis pubsub handle; typed Any because the concrete client
        # class depends on the installed redis library version.
        self._pubsub: Optional[Any] = None
        self._channels = ["mcpgw:cache:invalidate", "mcpgw:auth:invalidate"]
        self._started = False

    async def start(self) -> None:
        """Start listening for cache invalidation messages.

        This creates a background task that subscribes to the Redis
        channel and processes invalidation messages.

        Examples:
            >>> import asyncio
            >>> subscriber = CacheInvalidationSubscriber()
            >>> # asyncio.run(subscriber.start())
        """
        if self._started:
            logger.debug("CacheInvalidationSubscriber already started")
            return

        try:
            # First-Party
            from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

            redis = await get_redis_client()
            if not redis:
                # No Redis means single-worker semantics; local caches
                # are authoritative, so there is nothing to subscribe to.
                logger.info("CacheInvalidationSubscriber: Redis unavailable, skipping cross-worker invalidation")
                return

            self._stop_event = asyncio.Event()
            self._pubsub = redis.pubsub()
            await self._pubsub.subscribe(*self._channels)  # pyright: ignore[reportOptionalMemberAccess]

            self._task = asyncio.create_task(self._listen_loop())
            self._started = True
            logger.info("CacheInvalidationSubscriber started on channels %s", self._channels)

        except Exception as e:
            logger.warning("CacheInvalidationSubscriber failed to start: %s", e)
            # Clean up partially created pubsub to prevent leaks
            # Use timeout to prevent blocking if pubsub doesn't close cleanly
            cleanup_timeout = _get_cleanup_timeout()
            if self._pubsub is not None:
                try:
                    try:
                        await asyncio.wait_for(self._pubsub.aclose(), timeout=cleanup_timeout)
                    except AttributeError:
                        # Fallback for clients that expose close() instead of aclose().
                        await asyncio.wait_for(self._pubsub.close(), timeout=cleanup_timeout)
                except asyncio.TimeoutError:
                    logger.debug("Pubsub cleanup timed out - proceeding anyway")
                except Exception as cleanup_err:
                    logger.debug("Error during pubsub cleanup: %s", cleanup_err)
                self._pubsub = None

    async def stop(self) -> None:
        """Stop listening for cache invalidation messages.

        This cancels the background task and cleans up resources.

        Examples:
            >>> import asyncio
            >>> subscriber = CacheInvalidationSubscriber()
            >>> # asyncio.run(subscriber.stop())
        """
        if not self._started:
            return

        # Flip the flag first so the listen loop exits on its next check.
        self._started = False

        if self._stop_event:
            self._stop_event.set()

        if self._task:
            self._task.cancel()
            try:
                # Bounded wait: never block shutdown on a stuck task.
                await asyncio.wait_for(self._task, timeout=2.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass
            self._task = None

        if self._pubsub:
            cleanup_timeout = _get_cleanup_timeout()
            # Unsubscribe first, then close; each step is best-effort
            # with its own timeout so one failure never skips the next.
            try:
                await asyncio.wait_for(self._pubsub.unsubscribe(*self._channels), timeout=cleanup_timeout)
            except asyncio.TimeoutError:
                logger.debug("Pubsub unsubscribe timed out - proceeding anyway")
            except Exception as e:
                logger.debug("Error unsubscribing from pubsub: %s", e)
            try:
                try:
                    await asyncio.wait_for(self._pubsub.aclose(), timeout=cleanup_timeout)
                except AttributeError:
                    # Fallback for clients that expose close() instead of aclose().
                    await asyncio.wait_for(self._pubsub.close(), timeout=cleanup_timeout)
            except asyncio.TimeoutError:
                logger.debug("Pubsub close timed out - proceeding anyway")
            except Exception as e:
                logger.debug("Error closing pubsub: %s", e)
            self._pubsub = None

        logger.info("CacheInvalidationSubscriber stopped")

    async def _listen_loop(self) -> None:
        """Background loop that listens for and processes invalidation messages.

        Raises:
            asyncio.CancelledError: If the task is cancelled during shutdown.
        """
        logger.debug("CacheInvalidationSubscriber listen loop started")
        try:
            while self._started and not (self._stop_event and self._stop_event.is_set()):
                if self._pubsub is None:
                    break
                try:
                    # Outer wait_for (2.0s) bounds the call even if the
                    # client's own timeout=1.0 is not honored, so the loop
                    # re-checks the stop flags at least every 2 seconds.
                    message = await asyncio.wait_for(
                        self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0),
                        timeout=2.0,
                    )
                    if message and message.get("type") == "message":
                        data = message.get("data")
                        # Payload and channel may arrive as bytes depending
                        # on the client's decode_responses setting.
                        if isinstance(data, bytes):
                            data = data.decode("utf-8")
                        channel = message.get("channel", "")
                        if isinstance(channel, bytes):
                            channel = channel.decode("utf-8")
                        if data:
                            await self._process_invalidation(data, channel=channel)
                except asyncio.TimeoutError:
                    # No message in this window; loop and re-check stop flags.
                    continue
                except Exception as e:  # pylint: disable=broad-exception-caught
                    # Brief sleep avoids a tight error loop on persistent failures.
                    logger.debug("CacheInvalidationSubscriber message error: %s", e)
                    await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            logger.debug("CacheInvalidationSubscriber listen loop cancelled")
            raise
        finally:
            logger.debug("CacheInvalidationSubscriber listen loop exited")

    _AUTH_PREFIXES = ("user:", "revoke:", "team_roles:", "teams:", "team:", "role:", "membership:")
    """Message prefixes that belong exclusively to the auth invalidation channel."""

    async def _process_invalidation(self, message: str, *, channel: str = "") -> None:  # pylint: disable=too-many-branches
        """Process a cache invalidation message.

        Args:
            message: The invalidation message in format 'type:identifier'
            channel: The Redis pubsub channel the message arrived on.
                Used to enforce that auth-prefixed messages are only
                accepted from ``mcpgw:auth:invalidate``.
        """
        logger.debug("CacheInvalidationSubscriber received on %s: %s", channel, message)

        # pylint: disable=protected-access
        # pyright: ignore[reportPrivateUsage]
        # We intentionally access protected members to clear local in-memory caches
        # without triggering another round of Redis pubsub invalidation messages
        try:
            if message.startswith("registry:"):
                # Handle registry cache invalidation (tools, prompts, resources, etc.)
                cache_type = message[len("registry:") :]
                cache = get_registry_cache()
                # Only clear local in-memory cache to avoid infinite loops
                prefix = cache._get_redis_key(cache_type)  # pyright: ignore[reportPrivateUsage]
                with cache._lock:  # pyright: ignore[reportPrivateUsage]
                    keys_to_remove = [k for k in cache._cache if k.startswith(prefix)]  # pyright: ignore[reportPrivateUsage]
                    for key in keys_to_remove:
                        cache._cache.pop(key, None)  # pyright: ignore[reportPrivateUsage]
                logger.debug("CacheInvalidationSubscriber: Cleared local registry:%s cache (%d keys)", cache_type, len(keys_to_remove))

            elif message.startswith("tool_lookup:gateway:"):
                # Handle gateway-wide tool lookup invalidation
                # NOTE: must be checked before the plain "tool_lookup:" branch,
                # which would otherwise match this message too.
                gateway_id = message[len("tool_lookup:gateway:") :]
                # First-Party
                from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

                # Only clear local L1 cache
                with tool_lookup_cache._lock:  # pyright: ignore[reportPrivateUsage]
                    to_remove = [name for name, entry in tool_lookup_cache._cache.items() if entry.value.get("tool", {}).get("gateway_id") == gateway_id]  # pyright: ignore[reportPrivateUsage]
                    for name in to_remove:
                        tool_lookup_cache._cache.pop(name, None)  # pyright: ignore[reportPrivateUsage]
                logger.debug("CacheInvalidationSubscriber: Cleared local tool_lookup for gateway %s (%d keys)", gateway_id, len(to_remove))

            elif message.startswith("tool_lookup:"):
                # Handle specific tool lookup invalidation
                tool_name = message[len("tool_lookup:") :]
                # First-Party
                from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

                # Only clear local L1 cache
                with tool_lookup_cache._lock:  # pyright: ignore[reportPrivateUsage]
                    tool_lookup_cache._cache.pop(tool_name, None)  # pyright: ignore[reportPrivateUsage]
                logger.debug("CacheInvalidationSubscriber: Cleared local tool_lookup:%s", tool_name)

            elif message.startswith("admin:"):
                # Handle admin stats cache invalidation
                prefix = message[len("admin:") :]
                # First-Party
                from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

                # Only clear local in-memory cache
                full_prefix = admin_stats_cache._get_redis_key(prefix)  # pyright: ignore[reportPrivateUsage]
                with admin_stats_cache._lock:  # pyright: ignore[reportPrivateUsage]
                    keys_to_remove = [k for k in admin_stats_cache._cache if k.startswith(full_prefix)]  # pyright: ignore[reportPrivateUsage]
                    for key in keys_to_remove:
                        admin_stats_cache._cache.pop(key, None)  # pyright: ignore[reportPrivateUsage]
                logger.debug("CacheInvalidationSubscriber: Cleared local admin:%s cache (%d keys)", prefix, len(keys_to_remove))

            elif message.startswith(self._AUTH_PREFIXES):
                # Auth messages are only trusted from the auth channel; a
                # matching prefix on the generic channel is dropped.
                if channel != "mcpgw:auth:invalidate":
                    logger.warning("CacheInvalidationSubscriber: Ignoring auth message on wrong channel %s: %s", channel, message)
                else:
                    self._process_auth_invalidation(message)

            else:
                logger.debug("CacheInvalidationSubscriber: Unknown message format: %s", message)

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("CacheInvalidationSubscriber: Error processing '%s': %s", message, e)

    @staticmethod
    def _evict_keys(cache_dict: dict, predicate: "Callable[[str], bool]") -> int:
        """Remove all keys from *cache_dict* that satisfy *predicate*.

        Must be called while holding the owning cache's ``_lock``.

        Args:
            cache_dict: The dictionary to evict keys from.
            predicate: A callable that returns True for keys to remove.

        Returns:
            Number of keys removed.
        """
        # Materialize the key list first: removing while iterating the
        # dict directly would raise RuntimeError.
        keys = [k for k in cache_dict if predicate(k)]
        for k in keys:
            cache_dict.pop(k, None)
        return len(keys)

    def _process_auth_invalidation(self, message: str) -> None:  # pylint: disable=too-many-branches
        """Dispatch an auth-channel invalidation message to the local auth cache.

        Called from :meth:`_process_invalidation` for messages received on
        ``mcpgw:auth:invalidate``.

        Args:
            message: The invalidation message (e.g. ``user:alice@test.com``).
        """
        # pylint: disable=protected-access
        # First-Party
        from mcpgateway.cache.auth_cache import auth_cache  # pylint: disable=import-outside-toplevel

        # Dispatch auth message to the correct handler
        if message.startswith("user:"):
            email = message[len("user:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._context_cache, lambda k: k.startswith(f"{email}:"))  # pyright: ignore[reportPrivateUsage]
                auth_cache._user_cache.pop(email, None)  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._team_cache, lambda k: k.startswith(f"{email}:"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth user cache for %s", email)

        elif message.startswith("revoke:"):
            jti = message[len("revoke:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                # Cap the local set; see _MAX_REVOKED_JTIS for why overflow
                # is safe (Redis remains the source of truth).
                if len(auth_cache._revoked_jtis) < _MAX_REVOKED_JTIS:  # pyright: ignore[reportPrivateUsage]
                    auth_cache._revoked_jtis.add(jti)  # pyright: ignore[reportPrivateUsage]
                else:
                    logger.warning("CacheInvalidationSubscriber: _revoked_jtis at cap (%d), skipping add for jti=%s", _MAX_REVOKED_JTIS, jti[:8])
                auth_cache._revocation_cache.pop(jti, None)  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._context_cache, lambda k: k.endswith(f":{jti}"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth revocation cache for jti=%s", jti[:8])

        elif message.startswith("team_roles:"):
            team_id = message[len("team_roles:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._role_cache, lambda k: k.endswith(f":{team_id}"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth team_roles cache for team %s", team_id)

        elif message.startswith("teams:"):
            email = message[len("teams:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._teams_list_cache, lambda k: k.startswith(f"{email}:"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth teams list cache for %s", email)

        elif message.startswith("team:"):
            email = message[len("team:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                auth_cache._team_cache.pop(email, None)  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._context_cache, lambda k: k.startswith(f"{email}:"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth team cache for %s", email)

        elif message.startswith("role:"):
            cache_key = message[len("role:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                auth_cache._role_cache.pop(cache_key, None)  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth role cache for %s", cache_key)

        elif message.startswith("membership:"):
            user_email = message[len("membership:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._team_cache, lambda k: k.startswith(f"{user_email}:"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth membership cache for %s", user_email)
# Global singleton for cache invalidation subscriber
# Lazily created by get_cache_invalidation_subscriber(); not public API.
_cache_invalidation_subscriber: Optional[CacheInvalidationSubscriber] = None
def get_cache_invalidation_subscriber() -> CacheInvalidationSubscriber:
    """Return the process-wide CacheInvalidationSubscriber, creating it lazily.

    Returns:
        CacheInvalidationSubscriber: The singleton instance

    Examples:
        >>> subscriber = get_cache_invalidation_subscriber()
        >>> isinstance(subscriber, CacheInvalidationSubscriber)
        True
    """
    global _cache_invalidation_subscriber  # pylint: disable=global-statement
    subscriber = _cache_invalidation_subscriber
    if subscriber is None:
        subscriber = CacheInvalidationSubscriber()
        _cache_invalidation_subscriber = subscriber
    return subscriber