Coverage for mcpgateway / cache / registry_cache.py: 100%
369 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/cache/registry_cache.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
6Registry Data Cache.
8This module implements a thread-safe cache for registry data (tools, prompts,
9resources, agents, servers, gateways) with Redis as the primary store and
10in-memory fallback. It reduces database queries for list endpoints.
12Performance Impact:
13 - Before: 1-2 DB queries per list request
14 - After: 0 DB queries (cache hit) per TTL period
15 - Expected 95%+ cache hit rate under load
17Examples:
18 >>> from mcpgateway.cache.registry_cache import registry_cache
19 >>> # Cache is used automatically by list endpoints
20 >>> # Manual invalidation after tool update:
21 >>> import asyncio
22 >>> # asyncio.run(registry_cache.invalidate_tools())
23"""
25# Standard
26import asyncio
27from dataclasses import dataclass
28import hashlib
29import logging
30import threading
31import time
32from typing import Any, Callable, Dict, Optional
# Module-level logger used by RegistryCache and CacheInvalidationSubscriber.
logger = logging.getLogger(__name__)
37def _get_cleanup_timeout() -> float:
38 """Get cleanup timeout from config (lazy import to avoid circular deps).
40 Returns:
41 Cleanup timeout in seconds (default: 5.0).
42 """
43 try:
44 # First-Party
45 from mcpgateway.config import settings # pylint: disable=import-outside-toplevel
47 return settings.mcp_session_pool_cleanup_timeout
48 except Exception:
49 return 5.0
@dataclass
class CacheEntry:
    """A cached value paired with the absolute time at which it stops being valid.

    ``expiry`` is a ``time.time()`` timestamp; from that instant onward the
    entry counts as expired.

    Examples:
        >>> import time
        >>> entry = CacheEntry(value=["item1", "item2"], expiry=time.time() + 60)
        >>> entry.is_expired()
        False
    """

    value: Any
    expiry: float

    def is_expired(self) -> bool:
        """Report whether the expiry timestamp has been reached.

        Returns:
            bool: True once the current ``time.time()`` is at or past ``expiry``.
        """
        now = time.time()
        return self.expiry <= now
@dataclass
class RegistryCacheConfig:
    """Configuration for registry cache TTLs.

    All TTL values are expressed in seconds.

    Attributes:
        enabled: Whether caching is enabled
        tools_ttl: TTL in seconds for tools list cache
        prompts_ttl: TTL in seconds for prompts list cache
        resources_ttl: TTL in seconds for resources list cache
        agents_ttl: TTL in seconds for agents list cache
        servers_ttl: TTL in seconds for servers list cache
        gateways_ttl: TTL in seconds for gateways list cache
        catalog_ttl: TTL in seconds for catalog servers list cache

    Examples:
        >>> config = RegistryCacheConfig()
        >>> config.tools_ttl
        20
    """

    enabled: bool = True
    tools_ttl: int = 20
    prompts_ttl: int = 15
    resources_ttl: int = 15
    agents_ttl: int = 20
    servers_ttl: int = 20
    gateways_ttl: int = 20
    # Catalog entries are kept far longer than the other registry lists.
    catalog_ttl: int = 300
class RegistryCache:
    """Thread-safe registry cache with Redis and in-memory tiers.

    This cache reduces database load for list endpoints by caching:
    - Tools list
    - Prompts list
    - Resources list
    - A2A Agents list
    - Servers list
    - Gateways list
    - Catalog servers list

    The cache uses Redis as the primary store for distributed deployments
    and falls back to in-memory caching when Redis is unavailable. ``set``
    writes to both tiers; ``get`` consults Redis first and then the local
    in-memory copy, which is guarded by a ``threading.Lock``.

    Examples:
        >>> cache = RegistryCache()
        >>> cache.stats()["hit_count"]
        0
    """

    def __init__(self, config: Optional[RegistryCacheConfig] = None):
        """Initialize the registry cache.

        Args:
            config: Cache configuration. When given, its values are used as-is.
                If None, the enabled flag and TTLs are loaded from
                ``mcpgateway.config.settings``, or from ``RegistryCacheConfig``
                defaults when the settings cannot be imported.

        Examples:
            >>> cache = RegistryCache()
            >>> cache._enabled
            True
            >>> RegistryCache(RegistryCacheConfig(tools_ttl=7))._tools_ttl
            7
        """
        cfg = config
        cache_prefix = "mcpgw:"
        # Import settings lazily to avoid circular imports
        try:
            # First-Party
            from mcpgateway.config import settings  # pylint: disable=import-outside-toplevel

            cache_prefix = getattr(settings, "cache_prefix", "mcpgw:")
            # An explicit config always wins; settings only fill in the blanks.
            if cfg is None:
                cfg = RegistryCacheConfig(
                    enabled=getattr(settings, "registry_cache_enabled", True),
                    tools_ttl=getattr(settings, "registry_cache_tools_ttl", 20),
                    prompts_ttl=getattr(settings, "registry_cache_prompts_ttl", 15),
                    resources_ttl=getattr(settings, "registry_cache_resources_ttl", 15),
                    agents_ttl=getattr(settings, "registry_cache_agents_ttl", 20),
                    servers_ttl=getattr(settings, "registry_cache_servers_ttl", 20),
                    gateways_ttl=getattr(settings, "registry_cache_gateways_ttl", 20),
                    catalog_ttl=getattr(settings, "registry_cache_catalog_ttl", 300),
                )
        except ImportError:
            cache_prefix = "mcpgw:"
        if cfg is None:
            cfg = RegistryCacheConfig()

        self._enabled = cfg.enabled
        self._tools_ttl = cfg.tools_ttl
        self._prompts_ttl = cfg.prompts_ttl
        self._resources_ttl = cfg.resources_ttl
        self._agents_ttl = cfg.agents_ttl
        self._servers_ttl = cfg.servers_ttl
        self._gateways_ttl = cfg.gateways_ttl
        self._catalog_ttl = cfg.catalog_ttl
        self._cache_prefix = cache_prefix

        # In-memory cache (fallback when Redis unavailable)
        self._cache: Dict[str, CacheEntry] = {}

        # Thread safety
        self._lock = threading.Lock()

        # Redis availability (None = not checked yet)
        self._redis_checked = False
        self._redis_available = False

        # Statistics
        self._hit_count = 0
        self._miss_count = 0
        self._redis_hit_count = 0
        self._redis_miss_count = 0

        logger.info(
            f"RegistryCache initialized: enabled={self._enabled}, "
            f"tools_ttl={self._tools_ttl}s, prompts_ttl={self._prompts_ttl}s, "
            f"resources_ttl={self._resources_ttl}s, agents_ttl={self._agents_ttl}s, "
            f"catalog_ttl={self._catalog_ttl}s"
        )

    def _get_redis_key(self, cache_type: str, filters_hash: str = "") -> str:
        """Generate Redis key with proper prefix.

        Args:
            cache_type: Type of cache entry (tools, prompts, etc.)
            filters_hash: Hash of filter parameters

        Returns:
            Full Redis key with prefix

        Examples:
            >>> cache = RegistryCache()
            >>> cache._get_redis_key("tools", "abc123")
            'mcpgw:registry:tools:abc123'
        """
        if filters_hash:
            return f"{self._cache_prefix}registry:{cache_type}:{filters_hash}"
        return f"{self._cache_prefix}registry:{cache_type}"

    def hash_filters(self, **kwargs) -> str:
        """Generate a hash from filter parameters.

        Args:
            **kwargs: Filter parameters to hash

        Returns:
            MD5 hash of the filter parameters

        Examples:
            >>> cache = RegistryCache()
            >>> h = cache.hash_filters(include_inactive=False, tags=["api"])
            >>> len(h)
            32
        """
        # Sort keys for consistent hashing
        sorted_items = sorted(kwargs.items())
        filter_str = str(sorted_items)
        return hashlib.md5(filter_str.encode()).hexdigest()  # nosec B324 # noqa: DUO130

    async def _get_redis_client(self):
        """Get Redis client if available.

        The availability flags are only recorded on the first check.

        Returns:
            Redis client or None if unavailable.
        """
        try:
            # First-Party
            from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

            client = await get_redis_client()
            if client and not self._redis_checked:
                self._redis_checked = True
                self._redis_available = True
                logger.debug("RegistryCache: Redis client available")
            return client
        except Exception as e:
            if not self._redis_checked:
                self._redis_checked = True
                self._redis_available = False
                logger.debug(f"RegistryCache: Redis unavailable, using in-memory cache: {e}")
            return None

    async def get(self, cache_type: str, filters_hash: str = "") -> Optional[Any]:
        """Get cached data.

        Redis is consulted first; on a Redis miss or error the local in-memory
        tier is checked. An expired in-memory entry is removed on read.

        Args:
            cache_type: Type of cache (tools, prompts, resources, agents, servers, gateways)
            filters_hash: Hash of filter parameters

        Returns:
            Cached data if found, None otherwise

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> result = asyncio.run(cache.get("tools", "abc123"))
            >>> result is None  # Cache miss on fresh cache
            True
        """
        if not self._enabled:
            return None

        cache_key = self._get_redis_key(cache_type, filters_hash)

        # Try Redis first
        redis = await self._get_redis_client()
        if redis:
            try:
                data = await redis.get(cache_key)
                if data:
                    # Third-Party
                    import orjson  # pylint: disable=import-outside-toplevel

                    self._hit_count += 1
                    self._redis_hit_count += 1
                    return orjson.loads(data)
                self._redis_miss_count += 1
            except Exception as e:
                logger.warning(f"RegistryCache Redis get failed: {e}")

        # Fall back to in-memory cache
        with self._lock:
            entry = self._cache.get(cache_key)
            if entry is not None:
                if not entry.is_expired():
                    self._hit_count += 1
                    return entry.value
                # Evict the stale entry: keys embed a filter hash, so expired
                # entries would otherwise linger until overwritten or invalidated.
                self._cache.pop(cache_key, None)
            self._miss_count += 1
        return None

    async def set(self, cache_type: str, data: Any, filters_hash: str = "", ttl: Optional[int] = None) -> None:
        """Store data in cache.

        Args:
            cache_type: Type of cache (tools, prompts, resources, agents, servers, gateways)
            data: Data to cache (must be JSON-serializable)
            filters_hash: Hash of filter parameters
            ttl: TTL in seconds (uses default for cache_type if not specified)

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.set("tools", [{"id": "1", "name": "tool1"}], "abc123"))
        """
        if not self._enabled:
            return

        # Determine TTL
        if ttl is None:
            ttl_map = {
                "tools": self._tools_ttl,
                "prompts": self._prompts_ttl,
                "resources": self._resources_ttl,
                "agents": self._agents_ttl,
                "servers": self._servers_ttl,
                "gateways": self._gateways_ttl,
                "catalog": self._catalog_ttl,
            }
            # Unknown cache types get a 20-second TTL.
            ttl = ttl_map.get(cache_type, 20)

        cache_key = self._get_redis_key(cache_type, filters_hash)

        # Store in Redis
        redis = await self._get_redis_client()
        if redis:
            try:
                # Third-Party
                import orjson  # pylint: disable=import-outside-toplevel

                await redis.setex(cache_key, ttl, orjson.dumps(data))
            except Exception as e:
                logger.warning(f"RegistryCache Redis set failed: {e}")

        # Store in in-memory cache
        with self._lock:
            self._cache[cache_key] = CacheEntry(value=data, expiry=time.time() + ttl)

    async def invalidate(self, cache_type: str) -> None:
        """Invalidate all cached data for a cache type.

        Clears matching local entries, deletes matching Redis keys, and
        publishes ``registry:{cache_type}`` on ``mcpgw:cache:invalidate`` so
        other workers clear their local copies.

        Args:
            cache_type: Type of cache to invalidate (tools, prompts, etc.)

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate("tools"))
        """
        logger.debug(f"RegistryCache: Invalidating {cache_type} cache")
        prefix = self._get_redis_key(cache_type)

        # Clear in-memory cache
        with self._lock:
            keys_to_remove = [k for k in self._cache if k.startswith(prefix)]
            for key in keys_to_remove:
                self._cache.pop(key, None)

        # Clear Redis
        redis = await self._get_redis_client()
        if redis:
            try:
                pattern = f"{prefix}*"
                async for key in redis.scan_iter(match=pattern):
                    await redis.delete(key)

                # Publish invalidation for other workers
                await redis.publish("mcpgw:cache:invalidate", f"registry:{cache_type}")
            except Exception as e:
                logger.warning(f"RegistryCache Redis invalidate failed: {e}")

    async def invalidate_tools(self) -> None:
        """Invalidate tools cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_tools())
        """
        await self.invalidate("tools")

    async def invalidate_prompts(self) -> None:
        """Invalidate prompts cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_prompts())
        """
        await self.invalidate("prompts")

    async def invalidate_resources(self) -> None:
        """Invalidate resources cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_resources())
        """
        await self.invalidate("resources")

    async def invalidate_agents(self) -> None:
        """Invalidate agents cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_agents())
        """
        await self.invalidate("agents")

    async def invalidate_servers(self) -> None:
        """Invalidate servers cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_servers())
        """
        await self.invalidate("servers")

    async def invalidate_gateways(self) -> None:
        """Invalidate gateways cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_gateways())
        """
        await self.invalidate("gateways")

    async def invalidate_catalog(self) -> None:
        """Invalidate catalog servers cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_catalog())
        """
        await self.invalidate("catalog")

    def invalidate_all(self) -> None:
        """Invalidate all locally cached data synchronously.

        Only the in-memory tier is cleared; Redis keys are left untouched.

        Examples:
            >>> cache = RegistryCache()
            >>> cache.invalidate_all()
        """
        with self._lock:
            self._cache.clear()
        logger.info("RegistryCache: All caches invalidated")

    def stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with hit/miss counts and hit rate

        Examples:
            >>> cache = RegistryCache()
            >>> stats = cache.stats()
            >>> "hit_count" in stats
            True
        """
        total = self._hit_count + self._miss_count
        redis_total = self._redis_hit_count + self._redis_miss_count

        return {
            "enabled": self._enabled,
            "hit_count": self._hit_count,
            "miss_count": self._miss_count,
            "hit_rate": self._hit_count / total if total > 0 else 0.0,
            "redis_hit_count": self._redis_hit_count,
            "redis_miss_count": self._redis_miss_count,
            "redis_hit_rate": self._redis_hit_count / redis_total if redis_total > 0 else 0.0,
            "redis_available": self._redis_available,
            "cache_size": len(self._cache),
            "ttls": {
                "tools": self._tools_ttl,
                "prompts": self._prompts_ttl,
                "resources": self._resources_ttl,
                "agents": self._agents_ttl,
                "servers": self._servers_ttl,
                "gateways": self._gateways_ttl,
                "catalog": self._catalog_ttl,
            },
        }

    def reset_stats(self) -> None:
        """Reset hit/miss counters.

        Examples:
            >>> cache = RegistryCache()
            >>> cache._hit_count = 100
            >>> cache.reset_stats()
            >>> cache._hit_count
            0
        """
        self._hit_count = 0
        self._miss_count = 0
        self._redis_hit_count = 0
        self._redis_miss_count = 0
# Process-wide RegistryCache; populated by the first get_registry_cache() call.
_registry_cache: Optional[RegistryCache] = None


def get_registry_cache() -> RegistryCache:
    """Return the shared RegistryCache, constructing it on the first call.

    Returns:
        RegistryCache: The singleton registry cache instance

    Examples:
        >>> cache = get_registry_cache()
        >>> isinstance(cache, RegistryCache)
        True
    """
    global _registry_cache  # pylint: disable=global-statement
    if _registry_cache is not None:
        return _registry_cache
    _registry_cache = RegistryCache()
    return _registry_cache


# Convenience alias for direct import (instantiated at import time).
registry_cache = get_registry_cache()
# Upper bound on the in-memory revoked-JTI set.
#
# Prevents unbounded memory growth if a compromised Redis channel floods
# ``revoke:`` messages. When the cap is reached, new JTIs are still
# processed (their revocation-cache and auth-context entries are evicted)
# but are not added to the local set. Subsequent ``is_token_revoked()`` /
# ``get_auth_context()`` calls are expected to fall through to the Redis
# check on an L1 cache miss, so revocation is still enforced.
# NOTE(review): that fall-through lives in auth_cache, which is not visible
# here — confirm it still holds when changing this cap.
_MAX_REVOKED_JTIS = 100_000
class CacheInvalidationSubscriber:
    """Redis pubsub subscriber for cross-worker cache invalidation.

    This class subscribes to both 'mcpgw:cache:invalidate' and
    'mcpgw:auth:invalidate' Redis channels and processes invalidation
    messages from other workers, ensuring local in-memory caches stay
    synchronized in multi-worker deployments.

    Message formats handled:
    - registry:{cache_type} - Invalidate registry cache (tools, prompts, etc.)
    - tool_lookup:{name} - Invalidate specific tool lookup
    - tool_lookup:gateway:{gateway_id} - Invalidate all tools for a gateway
    - admin:{prefix} - Invalidate admin stats cache
    - user:{email} - Invalidate auth user cache
    - revoke:{jti} - Invalidate auth revocation cache
    - team:{email} - Invalidate auth team cache
    - role:{email}:{team_id} - Invalidate auth role cache
    - team_roles:{team_id} - Invalidate all roles for a team
    - teams:{email} - Invalidate auth teams list cache
    - membership:{email} - Invalidate auth team membership cache

    Auth-prefixed messages are only honoured when they arrive on
    'mcpgw:auth:invalidate'; on any other channel they are logged and ignored.

    Examples:
        >>> subscriber = CacheInvalidationSubscriber()
        >>> # Start listening in background task:
        >>> # await subscriber.start()
        >>> # Stop when shutting down:
        >>> # await subscriber.stop()
    """

    def __init__(self) -> None:
        """Initialize the cache invalidation subscriber."""
        self._task: Optional[asyncio.Task[None]] = None
        self._stop_event: Optional[asyncio.Event] = None
        self._pubsub: Optional[Any] = None
        self._channels = ["mcpgw:cache:invalidate", "mcpgw:auth:invalidate"]
        self._started = False

    async def start(self) -> None:
        """Start listening for cache invalidation messages.

        This creates a background task that subscribes to the Redis
        channels and processes invalidation messages. If Redis is not
        available the subscriber stays stopped; if subscription fails, any
        partially created pubsub object is closed.

        Examples:
            >>> import asyncio
            >>> subscriber = CacheInvalidationSubscriber()
            >>> # asyncio.run(subscriber.start())
        """
        if self._started:
            logger.debug("CacheInvalidationSubscriber already started")
            return

        try:
            # First-Party
            from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

            redis = await get_redis_client()
            if not redis:
                logger.info("CacheInvalidationSubscriber: Redis unavailable, skipping cross-worker invalidation")
                return

            self._stop_event = asyncio.Event()
            self._pubsub = redis.pubsub()
            await self._pubsub.subscribe(*self._channels)  # pyright: ignore[reportOptionalMemberAccess]

            self._task = asyncio.create_task(self._listen_loop())
            self._started = True
            logger.info("CacheInvalidationSubscriber started on channels %s", self._channels)

        except Exception as e:
            logger.warning("CacheInvalidationSubscriber failed to start: %s", e)
            # Clean up partially created pubsub to prevent leaks
            # Use timeout to prevent blocking if pubsub doesn't close cleanly
            cleanup_timeout = _get_cleanup_timeout()
            if self._pubsub is not None:
                try:
                    try:
                        await asyncio.wait_for(self._pubsub.aclose(), timeout=cleanup_timeout)
                    except AttributeError:
                        # Older clients expose close() instead of aclose().
                        await asyncio.wait_for(self._pubsub.close(), timeout=cleanup_timeout)
                except asyncio.TimeoutError:
                    logger.debug("Pubsub cleanup timed out - proceeding anyway")
                except Exception as cleanup_err:
                    logger.debug("Error during pubsub cleanup: %s", cleanup_err)
                self._pubsub = None

    async def stop(self) -> None:
        """Stop listening for cache invalidation messages.

        This cancels the background task and cleans up resources. Pubsub
        unsubscribe/close always runs, even if the listen task ended with an
        unexpected error.

        Examples:
            >>> import asyncio
            >>> subscriber = CacheInvalidationSubscriber()
            >>> # asyncio.run(subscriber.stop())
        """
        if not self._started:
            return

        self._started = False

        if self._stop_event:
            self._stop_event.set()

        if self._task:
            self._task.cancel()
            try:
                await asyncio.wait_for(self._task, timeout=2.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass
            except Exception as e:  # pylint: disable=broad-exception-caught
                # The listen loop died with an unexpected error; record it but
                # keep going so the pubsub connection below is still released.
                logger.debug("CacheInvalidationSubscriber listen task ended with error: %s", e)
            self._task = None

        if self._pubsub:
            cleanup_timeout = _get_cleanup_timeout()
            try:
                await asyncio.wait_for(self._pubsub.unsubscribe(*self._channels), timeout=cleanup_timeout)
            except asyncio.TimeoutError:
                logger.debug("Pubsub unsubscribe timed out - proceeding anyway")
            except Exception as e:
                logger.debug("Error unsubscribing from pubsub: %s", e)
            try:
                try:
                    await asyncio.wait_for(self._pubsub.aclose(), timeout=cleanup_timeout)
                except AttributeError:
                    # Older clients expose close() instead of aclose().
                    await asyncio.wait_for(self._pubsub.close(), timeout=cleanup_timeout)
            except asyncio.TimeoutError:
                logger.debug("Pubsub close timed out - proceeding anyway")
            except Exception as e:
                logger.debug("Error closing pubsub: %s", e)
            self._pubsub = None

        logger.info("CacheInvalidationSubscriber stopped")

    async def _listen_loop(self) -> None:
        """Background loop that listens for and processes invalidation messages.

        Polls the pubsub with a 1s read timeout (and a 2s outer guard) so the
        stop flag/event is re-checked regularly.

        Raises:
            asyncio.CancelledError: If the task is cancelled during shutdown.
        """
        logger.debug("CacheInvalidationSubscriber listen loop started")
        try:
            while self._started and not (self._stop_event and self._stop_event.is_set()):
                if self._pubsub is None:
                    break
                try:
                    message = await asyncio.wait_for(
                        self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0),
                        timeout=2.0,
                    )
                    if message and message.get("type") == "message":
                        data = message.get("data")
                        if isinstance(data, bytes):
                            data = data.decode("utf-8")
                        channel = message.get("channel", "")
                        if isinstance(channel, bytes):
                            channel = channel.decode("utf-8")
                        if data:
                            await self._process_invalidation(data, channel=channel)
                except asyncio.TimeoutError:
                    continue
                except Exception as e:  # pylint: disable=broad-exception-caught
                    logger.debug("CacheInvalidationSubscriber message error: %s", e)
                    # Brief back-off so a persistently failing read does not spin.
                    await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            logger.debug("CacheInvalidationSubscriber listen loop cancelled")
            raise
        finally:
            logger.debug("CacheInvalidationSubscriber listen loop exited")

    # Message prefixes that belong exclusively to the auth invalidation channel.
    _AUTH_PREFIXES = ("user:", "revoke:", "team_roles:", "teams:", "team:", "role:", "membership:")

    async def _process_invalidation(self, message: str, *, channel: str = "") -> None:  # pylint: disable=too-many-branches
        """Process a cache invalidation message.

        Only local in-memory caches are touched, so handling a message never
        publishes a new one (which would loop between workers). Errors are
        logged and swallowed.

        Args:
            message: The invalidation message in format 'type:identifier'
            channel: The Redis pubsub channel the message arrived on.
                Used to enforce that auth-prefixed messages are only
                accepted from ``mcpgw:auth:invalidate``.
        """
        logger.debug("CacheInvalidationSubscriber received on %s: %s", channel, message)

        # pylint: disable=protected-access
        # pyright: ignore[reportPrivateUsage]
        # We intentionally access protected members to clear local in-memory caches
        # without triggering another round of Redis pubsub invalidation messages
        try:
            if message.startswith("registry:"):
                # Handle registry cache invalidation (tools, prompts, resources, etc.)
                cache_type = message[len("registry:") :]
                cache = get_registry_cache()
                # Only clear local in-memory cache to avoid infinite loops
                prefix = cache._get_redis_key(cache_type)  # pyright: ignore[reportPrivateUsage]
                with cache._lock:  # pyright: ignore[reportPrivateUsage]
                    removed = self._evict_keys(cache._cache, lambda k: k.startswith(prefix))  # pyright: ignore[reportPrivateUsage]
                logger.debug("CacheInvalidationSubscriber: Cleared local registry:%s cache (%d keys)", cache_type, removed)

            elif message.startswith("tool_lookup:gateway:"):
                # Handle gateway-wide tool lookup invalidation
                gateway_id = message[len("tool_lookup:gateway:") :]
                # First-Party
                from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

                # Only clear local L1 cache
                with tool_lookup_cache._lock:  # pyright: ignore[reportPrivateUsage]
                    to_remove = [
                        name
                        for name, entry in tool_lookup_cache._cache.items()  # pyright: ignore[reportPrivateUsage]
                        if self._entry_gateway_id(entry.value) == gateway_id
                    ]
                    for name in to_remove:
                        tool_lookup_cache._cache.pop(name, None)  # pyright: ignore[reportPrivateUsage]
                logger.debug("CacheInvalidationSubscriber: Cleared local tool_lookup for gateway %s (%d keys)", gateway_id, len(to_remove))

            elif message.startswith("tool_lookup:"):
                # Handle specific tool lookup invalidation
                tool_name = message[len("tool_lookup:") :]
                # First-Party
                from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

                # Only clear local L1 cache
                with tool_lookup_cache._lock:  # pyright: ignore[reportPrivateUsage]
                    tool_lookup_cache._cache.pop(tool_name, None)  # pyright: ignore[reportPrivateUsage]
                logger.debug("CacheInvalidationSubscriber: Cleared local tool_lookup:%s", tool_name)

            elif message.startswith("admin:"):
                # Handle admin stats cache invalidation
                prefix = message[len("admin:") :]
                # First-Party
                from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

                # Only clear local in-memory cache
                full_prefix = admin_stats_cache._get_redis_key(prefix)  # pyright: ignore[reportPrivateUsage]
                with admin_stats_cache._lock:  # pyright: ignore[reportPrivateUsage]
                    removed = self._evict_keys(admin_stats_cache._cache, lambda k: k.startswith(full_prefix))  # pyright: ignore[reportPrivateUsage]
                logger.debug("CacheInvalidationSubscriber: Cleared local admin:%s cache (%d keys)", prefix, removed)

            elif message.startswith(self._AUTH_PREFIXES):
                if channel != "mcpgw:auth:invalidate":
                    logger.warning("CacheInvalidationSubscriber: Ignoring auth message on wrong channel %s: %s", channel, message)
                else:
                    self._process_auth_invalidation(message)

            else:
                logger.debug("CacheInvalidationSubscriber: Unknown message format: %s", message)

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("CacheInvalidationSubscriber: Error processing '%s': %s", message, e)

    @staticmethod
    def _entry_gateway_id(value: Any) -> Any:
        """Return ``value["tool"]["gateway_id"]`` from a tool-lookup payload, or None.

        Tolerates payloads that are not mappings, or whose ``"tool"`` member is
        missing/None, so one malformed entry cannot abort a gateway-wide
        invalidation (which would leave every other stale entry in place).

        Args:
            value: The cached tool-lookup payload (expected to be a dict).

        Returns:
            The gateway id when present, otherwise None.
        """
        try:
            tool = value.get("tool") or {}
            return tool.get("gateway_id")
        except AttributeError:
            return None

    @staticmethod
    def _evict_keys(cache_dict: dict, predicate: "Callable[[str], bool]") -> int:
        """Remove all keys from *cache_dict* that satisfy *predicate*.

        Must be called while holding the owning cache's ``_lock``.

        Args:
            cache_dict: The dictionary to evict keys from.
            predicate: A callable that returns True for keys to remove.

        Returns:
            Number of keys removed.
        """
        # Materialize the key list first: the dict cannot shrink while iterated.
        keys = [k for k in cache_dict if predicate(k)]
        for k in keys:
            cache_dict.pop(k, None)
        return len(keys)

    def _process_auth_invalidation(self, message: str) -> None:  # pylint: disable=too-many-branches
        """Dispatch an auth-channel invalidation message to the local auth cache.

        Called from :meth:`_process_invalidation` for messages received on
        ``mcpgw:auth:invalidate``.

        Args:
            message: The invalidation message (e.g. ``user:alice@test.com``).
        """
        # pylint: disable=protected-access
        # First-Party
        from mcpgateway.cache.auth_cache import auth_cache  # pylint: disable=import-outside-toplevel

        # Dispatch auth message to the correct handler
        if message.startswith("user:"):
            email = message[len("user:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._context_cache, lambda k: k.startswith(f"{email}:"))  # pyright: ignore[reportPrivateUsage]
                auth_cache._user_cache.pop(email, None)  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._team_cache, lambda k: k.startswith(f"{email}:"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth user cache for %s", email)

        elif message.startswith("revoke:"):
            jti = message[len("revoke:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                # Cap the revoked-JTI set; evictions below still happen at the cap.
                if len(auth_cache._revoked_jtis) < _MAX_REVOKED_JTIS:  # pyright: ignore[reportPrivateUsage]
                    auth_cache._revoked_jtis.add(jti)  # pyright: ignore[reportPrivateUsage]
                else:
                    logger.warning("CacheInvalidationSubscriber: _revoked_jtis at cap (%d), skipping add for jti=%s", _MAX_REVOKED_JTIS, jti[:8])
                auth_cache._revocation_cache.pop(jti, None)  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._context_cache, lambda k: k.endswith(f":{jti}"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth revocation cache for jti=%s", jti[:8])

        elif message.startswith("team_roles:"):
            team_id = message[len("team_roles:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._role_cache, lambda k: k.endswith(f":{team_id}"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth team_roles cache for team %s", team_id)

        elif message.startswith("teams:"):
            email = message[len("teams:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._teams_list_cache, lambda k: k.startswith(f"{email}:"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth teams list cache for %s", email)

        elif message.startswith("team:"):
            email = message[len("team:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                auth_cache._team_cache.pop(email, None)  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._context_cache, lambda k: k.startswith(f"{email}:"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth team cache for %s", email)

        elif message.startswith("role:"):
            cache_key = message[len("role:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                auth_cache._role_cache.pop(cache_key, None)  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth role cache for %s", cache_key)

        elif message.startswith("membership:"):
            user_email = message[len("membership:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._team_cache, lambda k: k.startswith(f"{user_email}:"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth membership cache for %s", user_email)

        else:
            # A prefix listed in _AUTH_PREFIXES without a handler above.
            logger.debug("CacheInvalidationSubscriber: Unhandled auth message: %s", message)
# Process-wide CacheInvalidationSubscriber; populated on the first
# get_cache_invalidation_subscriber() call.
_cache_invalidation_subscriber: Optional[CacheInvalidationSubscriber] = None


def get_cache_invalidation_subscriber() -> CacheInvalidationSubscriber:
    """Return the shared CacheInvalidationSubscriber, constructing it on first use.

    Returns:
        CacheInvalidationSubscriber: The singleton instance

    Examples:
        >>> subscriber = get_cache_invalidation_subscriber()
        >>> isinstance(subscriber, CacheInvalidationSubscriber)
        True
    """
    global _cache_invalidation_subscriber  # pylint: disable=global-statement
    if _cache_invalidation_subscriber is not None:
        return _cache_invalidation_subscriber
    _cache_invalidation_subscriber = CacheInvalidationSubscriber()
    return _cache_invalidation_subscriber