Coverage for mcpgateway / cache / registry_cache.py: 100%

369 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/cache/registry_cache.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5 

6Registry Data Cache. 

7 

8This module implements a thread-safe cache for registry data (tools, prompts, 

9resources, agents, servers, gateways) with Redis as the primary store and 

10in-memory fallback. It reduces database queries for list endpoints. 

11 

12Performance Impact: 

13 - Before: 1-2 DB queries per list request 

14 - After: 0 DB queries (cache hit) per TTL period 

15 - Expected 95%+ cache hit rate under load 

16 

17Examples: 

18 >>> from mcpgateway.cache.registry_cache import registry_cache 

19 >>> # Cache is used automatically by list endpoints 

20 >>> # Manual invalidation after tool update: 

21 >>> import asyncio 

22 >>> # asyncio.run(registry_cache.invalidate_tools()) 

23""" 

24 

25# Standard 

26import asyncio 

27from dataclasses import dataclass 

28import hashlib 

29import logging 

30import threading 

31import time 

32from typing import Any, Callable, Dict, Optional 

33 

34logger = logging.getLogger(__name__) 

35 

36 

37def _get_cleanup_timeout() -> float: 

38 """Get cleanup timeout from config (lazy import to avoid circular deps). 

39 

40 Returns: 

41 Cleanup timeout in seconds (default: 5.0). 

42 """ 

43 try: 

44 # First-Party 

45 from mcpgateway.config import settings # pylint: disable=import-outside-toplevel 

46 

47 return settings.mcp_session_pool_cleanup_timeout 

48 except Exception: 

49 return 5.0 

50 

51 

@dataclass
class CacheEntry:
    """A cached value paired with the wall-clock time at which it lapses.

    Examples:
        >>> import time
        >>> entry = CacheEntry(value=["item1", "item2"], expiry=time.time() + 60)
        >>> entry.is_expired()
        False
    """

    # The cached payload; stored as-is, never inspected by this class.
    value: Any
    # Absolute expiry as a ``time.time()`` timestamp (seconds since epoch).
    expiry: float

    def is_expired(self) -> bool:
        """Report whether the entry's expiry time has been reached.

        Returns:
            bool: True once the current time is at or past ``expiry``,
            False while the entry is still fresh.
        """
        now = time.time()
        return self.expiry <= now

73 

74 

@dataclass
class RegistryCacheConfig:
    """Default switches and per-type TTLs for the registry cache.

    All TTL values are expressed in seconds.

    Attributes:
        enabled: Master switch; when False the cache neither reads nor writes.
        tools_ttl: Lifetime of cached tools lists.
        prompts_ttl: Lifetime of cached prompts lists.
        resources_ttl: Lifetime of cached resources lists.
        agents_ttl: Lifetime of cached agents lists.
        servers_ttl: Lifetime of cached servers lists.
        gateways_ttl: Lifetime of cached gateways lists.
        catalog_ttl: Lifetime of cached catalog servers lists.

    Examples:
        >>> config = RegistryCacheConfig()
        >>> config.tools_ttl
        20
    """

    # Field order is part of the positional constructor signature.
    enabled: bool = True
    tools_ttl: int = 20
    prompts_ttl: int = 15
    resources_ttl: int = 15
    agents_ttl: int = 20
    servers_ttl: int = 20
    gateways_ttl: int = 20
    catalog_ttl: int = 300

103 

104 

class RegistryCache:
    """Thread-safe registry cache with Redis and in-memory tiers.

    This cache reduces database load for list endpoints by caching:
        - Tools list
        - Prompts list
        - Resources list
        - A2A Agents list
        - Servers list
        - Gateways list
        - Catalog servers list

    Redis is consulted first so that entries are shared across workers;
    a per-process dictionary guarded by a lock is used as a fallback when
    Redis is unavailable or a Redis call fails.

    Examples:
        >>> cache = RegistryCache()
        >>> cache.stats()["hit_count"]
        0
    """

    def __init__(self, config: "Optional[RegistryCacheConfig]" = None):
        """Initialize the registry cache.

        Args:
            config: Cache configuration. When provided it takes precedence
                over application settings. If None, values are loaded from
                settings, and from ``RegistryCacheConfig`` defaults when
                the settings module cannot be imported.

        Examples:
            >>> cache = RegistryCache()
            >>> cache._enabled
            True
        """
        # Import settings lazily to avoid circular imports
        try:
            # First-Party
            from mcpgateway.config import settings  # pylint: disable=import-outside-toplevel
        except ImportError:
            settings = None

        cache_prefix = "mcpgw:"
        cfg = config
        if settings is not None:
            cache_prefix = getattr(settings, "cache_prefix", "mcpgw:")
            if cfg is None:
                # No explicit config: mirror the application settings,
                # falling back per-field to the documented defaults.
                cfg = RegistryCacheConfig(
                    enabled=getattr(settings, "registry_cache_enabled", True),
                    tools_ttl=getattr(settings, "registry_cache_tools_ttl", 20),
                    prompts_ttl=getattr(settings, "registry_cache_prompts_ttl", 15),
                    resources_ttl=getattr(settings, "registry_cache_resources_ttl", 15),
                    agents_ttl=getattr(settings, "registry_cache_agents_ttl", 20),
                    servers_ttl=getattr(settings, "registry_cache_servers_ttl", 20),
                    gateways_ttl=getattr(settings, "registry_cache_gateways_ttl", 20),
                    catalog_ttl=getattr(settings, "registry_cache_catalog_ttl", 300),
                )
        if cfg is None:
            cfg = RegistryCacheConfig()

        self._enabled = cfg.enabled
        self._tools_ttl = cfg.tools_ttl
        self._prompts_ttl = cfg.prompts_ttl
        self._resources_ttl = cfg.resources_ttl
        self._agents_ttl = cfg.agents_ttl
        self._servers_ttl = cfg.servers_ttl
        self._gateways_ttl = cfg.gateways_ttl
        self._catalog_ttl = cfg.catalog_ttl
        self._cache_prefix = cache_prefix

        # In-memory cache (fallback when Redis unavailable)
        self._cache: Dict[str, CacheEntry] = {}

        # Thread safety
        self._lock = threading.Lock()

        # Redis availability (recorded on the first client lookup)
        self._redis_checked = False
        self._redis_available = False

        # Statistics
        self._hit_count = 0
        self._miss_count = 0
        self._redis_hit_count = 0
        self._redis_miss_count = 0

        logger.info(
            "RegistryCache initialized: enabled=%s, tools_ttl=%ss, prompts_ttl=%ss, resources_ttl=%ss, agents_ttl=%ss, catalog_ttl=%ss",
            self._enabled,
            self._tools_ttl,
            self._prompts_ttl,
            self._resources_ttl,
            self._agents_ttl,
            self._catalog_ttl,
        )

    def _get_redis_key(self, cache_type: str, filters_hash: str = "") -> str:
        """Generate Redis key with proper prefix.

        The same key is used for the in-memory dictionary.

        Args:
            cache_type: Type of cache entry (tools, prompts, etc.)
            filters_hash: Hash of filter parameters; omitted from the key
                when empty.

        Returns:
            Full key in the form ``<prefix>registry:<type>[:<hash>]``.

        Examples:
            >>> cache = RegistryCache()
            >>> cache._get_redis_key("tools", "abc123")
            'mcpgw:registry:tools:abc123'
        """
        if filters_hash:
            return f"{self._cache_prefix}registry:{cache_type}:{filters_hash}"
        return f"{self._cache_prefix}registry:{cache_type}"

    def hash_filters(self, **kwargs) -> str:
        """Generate a hash from filter parameters.

        Args:
            **kwargs: Filter parameters to hash

        Returns:
            MD5 hex digest (32 characters) of the sorted filter parameters.
            The digest is used only as a cache-key component, not for
            security.

        Examples:
            >>> cache = RegistryCache()
            >>> h = cache.hash_filters(include_inactive=False, tags=["api"])
            >>> len(h)
            32
        """
        # Sort keys so that keyword order does not change the digest
        sorted_items = sorted(kwargs.items())
        filter_str = str(sorted_items)
        return hashlib.md5(filter_str.encode()).hexdigest()  # nosec B324 # noqa: DUO130

    async def _get_redis_client(self):
        """Get Redis client if available.

        The outcome of the first lookup is recorded in ``_redis_available``
        and logged once; later lookups do not log again.

        Returns:
            Redis client or None if unavailable.
        """
        try:
            # First-Party
            from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

            client = await get_redis_client()
            if client and not self._redis_checked:
                self._redis_checked = True
                self._redis_available = True
                logger.debug("RegistryCache: Redis client available")
            return client
        except Exception as e:
            if not self._redis_checked:
                self._redis_checked = True
                self._redis_available = False
                logger.debug("RegistryCache: Redis unavailable, using in-memory cache: %s", e)
            return None

    async def get(self, cache_type: str, filters_hash: str = "") -> Optional[Any]:
        """Get cached data.

        Redis is tried first; on a Redis miss or error the in-memory tier
        is consulted. An expired in-memory entry is removed when it is
        encountered so stale entries do not accumulate.

        Args:
            cache_type: Type of cache (tools, prompts, resources, agents, servers, gateways)
            filters_hash: Hash of filter parameters

        Returns:
            Cached data if found, None otherwise (including when the
            cache is disabled).

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> result = asyncio.run(cache.get("tools", "abc123"))
            >>> result is None  # Cache miss on fresh cache
            True
        """
        if not self._enabled:
            return None

        cache_key = self._get_redis_key(cache_type, filters_hash)

        # Try Redis first
        redis = await self._get_redis_client()
        if redis:
            try:
                data = await redis.get(cache_key)
                if data:
                    # Third-Party
                    import orjson  # pylint: disable=import-outside-toplevel

                    self._hit_count += 1
                    self._redis_hit_count += 1
                    return orjson.loads(data)
                self._redis_miss_count += 1
            except Exception as e:
                logger.warning("RegistryCache Redis get failed: %s", e)

        # Fall back to in-memory cache
        with self._lock:
            entry = self._cache.get(cache_key)
            if entry is not None:
                if not entry.is_expired():
                    self._hit_count += 1
                    return entry.value
                # Expired: drop it now rather than keeping it until the
                # next invalidation of this cache type.
                self._cache.pop(cache_key, None)
            self._miss_count += 1
        return None

    async def set(self, cache_type: str, data: Any, filters_hash: str = "", ttl: Optional[int] = None) -> None:
        """Store data in cache.

        The value is written to Redis (when available) and always to the
        in-memory tier, so the local tier can serve reads if Redis fails.

        Args:
            cache_type: Type of cache (tools, prompts, resources, agents, servers, gateways)
            data: Data to cache (must be JSON-serializable)
            filters_hash: Hash of filter parameters
            ttl: TTL in seconds (uses default for cache_type if not specified;
                unknown cache types default to 20 seconds)

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.set("tools", [{"id": "1", "name": "tool1"}], "abc123"))
        """
        if not self._enabled:
            return

        # Determine TTL
        if ttl is None:
            ttl_map = {
                "tools": self._tools_ttl,
                "prompts": self._prompts_ttl,
                "resources": self._resources_ttl,
                "agents": self._agents_ttl,
                "servers": self._servers_ttl,
                "gateways": self._gateways_ttl,
                "catalog": self._catalog_ttl,
            }
            ttl = ttl_map.get(cache_type, 20)

        cache_key = self._get_redis_key(cache_type, filters_hash)

        # Store in Redis
        redis = await self._get_redis_client()
        if redis:
            try:
                # Third-Party
                import orjson  # pylint: disable=import-outside-toplevel

                await redis.setex(cache_key, ttl, orjson.dumps(data))
            except Exception as e:
                logger.warning("RegistryCache Redis set failed: %s", e)

        # Store in in-memory cache
        with self._lock:
            self._cache[cache_key] = CacheEntry(value=data, expiry=time.time() + ttl)

    async def invalidate(self, cache_type: str) -> None:
        """Invalidate all cached data for a cache type.

        Clears matching local entries, deletes matching Redis keys, and
        publishes ``registry:<cache_type>`` on ``mcpgw:cache:invalidate``
        so other workers clear their local tier as well.

        Args:
            cache_type: Type of cache to invalidate (tools, prompts, etc.)

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate("tools"))
        """
        logger.debug("RegistryCache: Invalidating %s cache", cache_type)
        prefix = self._get_redis_key(cache_type)

        # Clear in-memory cache
        with self._lock:
            keys_to_remove = [k for k in self._cache if k.startswith(prefix)]
            for key in keys_to_remove:
                self._cache.pop(key, None)

        # Clear Redis
        redis = await self._get_redis_client()
        if redis:
            try:
                pattern = f"{prefix}*"
                async for key in redis.scan_iter(match=pattern):
                    await redis.delete(key)

                # Publish invalidation for other workers
                await redis.publish("mcpgw:cache:invalidate", f"registry:{cache_type}")
            except Exception as e:
                logger.warning("RegistryCache Redis invalidate failed: %s", e)

    async def invalidate_tools(self) -> None:
        """Invalidate tools cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_tools())
        """
        await self.invalidate("tools")

    async def invalidate_prompts(self) -> None:
        """Invalidate prompts cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_prompts())
        """
        await self.invalidate("prompts")

    async def invalidate_resources(self) -> None:
        """Invalidate resources cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_resources())
        """
        await self.invalidate("resources")

    async def invalidate_agents(self) -> None:
        """Invalidate agents cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_agents())
        """
        await self.invalidate("agents")

    async def invalidate_servers(self) -> None:
        """Invalidate servers cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_servers())
        """
        await self.invalidate("servers")

    async def invalidate_gateways(self) -> None:
        """Invalidate gateways cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_gateways())
        """
        await self.invalidate("gateways")

    async def invalidate_catalog(self) -> None:
        """Invalidate catalog servers cache.

        Examples:
            >>> import asyncio
            >>> cache = RegistryCache()
            >>> asyncio.run(cache.invalidate_catalog())
        """
        await self.invalidate("catalog")

    def invalidate_all(self) -> None:
        """Invalidate all cached data synchronously.

        Only the in-memory tier is cleared; Redis is not touched and no
        invalidation message is published.

        Examples:
            >>> cache = RegistryCache()
            >>> cache.invalidate_all()
        """
        with self._lock:
            self._cache.clear()
        logger.info("RegistryCache: All caches invalidated")

    def stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with hit/miss counts, hit rates (0.0 when there
            have been no lookups), Redis availability, in-memory entry
            count and the configured TTLs.

        Examples:
            >>> cache = RegistryCache()
            >>> stats = cache.stats()
            >>> "hit_count" in stats
            True
        """
        total = self._hit_count + self._miss_count
        redis_total = self._redis_hit_count + self._redis_miss_count

        return {
            "enabled": self._enabled,
            "hit_count": self._hit_count,
            "miss_count": self._miss_count,
            "hit_rate": self._hit_count / total if total > 0 else 0.0,
            "redis_hit_count": self._redis_hit_count,
            "redis_miss_count": self._redis_miss_count,
            "redis_hit_rate": self._redis_hit_count / redis_total if redis_total > 0 else 0.0,
            "redis_available": self._redis_available,
            "cache_size": len(self._cache),
            "ttls": {
                "tools": self._tools_ttl,
                "prompts": self._prompts_ttl,
                "resources": self._resources_ttl,
                "agents": self._agents_ttl,
                "servers": self._servers_ttl,
                "gateways": self._gateways_ttl,
                "catalog": self._catalog_ttl,
            },
        }

    def reset_stats(self) -> None:
        """Reset hit/miss counters.

        Examples:
            >>> cache = RegistryCache()
            >>> cache._hit_count = 100
            >>> cache.reset_stats()
            >>> cache._hit_count
            0
        """
        self._hit_count = 0
        self._miss_count = 0
        self._redis_hit_count = 0
        self._redis_miss_count = 0

507 

508 

# Process-wide RegistryCache instance, created on first request.
_registry_cache: Optional[RegistryCache] = None


def get_registry_cache() -> RegistryCache:
    """Return the shared RegistryCache, constructing it on first use.

    Returns:
        RegistryCache: The single module-level cache instance; every call
        returns the same object.

    Examples:
        >>> cache = get_registry_cache()
        >>> isinstance(cache, RegistryCache)
        True
    """
    global _registry_cache  # pylint: disable=global-statement
    if _registry_cache is not None:
        return _registry_cache
    _registry_cache = RegistryCache()
    return _registry_cache


# Module-level handle so callers can ``from ... import registry_cache``.
registry_cache = get_registry_cache()

532 

533 

# Upper bound on the in-memory revoked-JTI set.
#
# Prevents unbounded memory growth if a compromised Redis channel floods
# ``revoke:`` messages.  When the cap is reached new JTIs are still
# processed (cache eviction) but are not added to the local set;
# subsequent ``is_token_revoked()`` / ``get_auth_context()`` calls will
# fall through to the Redis check on L1 cache miss, so revocation is
# still enforced.
_MAX_REVOKED_JTIS = 100_000


class CacheInvalidationSubscriber:
    """Redis pubsub subscriber for cross-worker cache invalidation.

    This class subscribes to both 'mcpgw:cache:invalidate' and
    'mcpgw:auth:invalidate' Redis channels and processes invalidation
    messages from other workers, ensuring local in-memory caches stay
    synchronized in multi-worker deployments.

    Message formats handled:
        - registry:{cache_type} - Invalidate registry cache (tools, prompts, etc.)
        - tool_lookup:{name} - Invalidate specific tool lookup
        - tool_lookup:gateway:{gateway_id} - Invalidate all tools for a gateway
        - admin:{prefix} - Invalidate admin stats cache
        - user:{email} - Invalidate auth user cache
        - revoke:{jti} - Invalidate auth revocation cache
        - team:{email} - Invalidate auth team cache
        - role:{email}:{team_id} - Invalidate auth role cache
        - team_roles:{team_id} - Invalidate all roles for a team
        - teams:{email} - Invalidate auth teams list cache
        - membership:{email} - Invalidate auth team membership cache

    Examples:
        >>> subscriber = CacheInvalidationSubscriber()
        >>> # Start listening in background task:
        >>> # await subscriber.start()
        >>> # Stop when shutting down:
        >>> # await subscriber.stop()
    """

    def __init__(self) -> None:
        """Initialize the cache invalidation subscriber."""
        self._task: Optional[asyncio.Task[None]] = None
        self._stop_event: Optional[asyncio.Event] = None
        self._pubsub: Optional[Any] = None
        self._channels = ["mcpgw:cache:invalidate", "mcpgw:auth:invalidate"]
        self._started = False

    async def start(self) -> None:
        """Start listening for cache invalidation messages.

        Subscribes to the configured Redis channels and spawns a
        background task running the listen loop. If Redis is unavailable
        the call returns without starting anything; if subscription fails
        the partially created pubsub object is closed and discarded.

        Examples:
            >>> import asyncio
            >>> subscriber = CacheInvalidationSubscriber()
            >>> # asyncio.run(subscriber.start())
        """
        if self._started:
            logger.debug("CacheInvalidationSubscriber already started")
            return

        try:
            # First-Party
            from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

            redis = await get_redis_client()
            if not redis:
                logger.info("CacheInvalidationSubscriber: Redis unavailable, skipping cross-worker invalidation")
                return

            self._stop_event = asyncio.Event()
            self._pubsub = redis.pubsub()
            await self._pubsub.subscribe(*self._channels)  # pyright: ignore[reportOptionalMemberAccess]

            self._task = asyncio.create_task(self._listen_loop())
            self._started = True
            logger.info("CacheInvalidationSubscriber started on channels %s", self._channels)

        except Exception as e:
            logger.warning("CacheInvalidationSubscriber failed to start: %s", e)
            # Clean up partially created pubsub to prevent leaks
            # Use timeout to prevent blocking if pubsub doesn't close cleanly
            cleanup_timeout = _get_cleanup_timeout()
            if self._pubsub is not None:
                try:
                    try:
                        await asyncio.wait_for(self._pubsub.aclose(), timeout=cleanup_timeout)
                    except AttributeError:
                        # Pubsub object without ``aclose``: use ``close``.
                        await asyncio.wait_for(self._pubsub.close(), timeout=cleanup_timeout)
                except asyncio.TimeoutError:
                    logger.debug("Pubsub cleanup timed out - proceeding anyway")
                except Exception as cleanup_err:
                    logger.debug("Error during pubsub cleanup: %s", cleanup_err)
                self._pubsub = None

    async def stop(self) -> None:
        """Stop listening for cache invalidation messages.

        Cancels the background task (waiting up to 2 seconds for it to
        finish), then unsubscribes and closes the pubsub object, each
        bounded by the cleanup timeout. A no-op if not started.

        Examples:
            >>> import asyncio
            >>> subscriber = CacheInvalidationSubscriber()
            >>> # asyncio.run(subscriber.stop())
        """
        if not self._started:
            return

        self._started = False

        if self._stop_event:
            self._stop_event.set()

        if self._task:
            self._task.cancel()
            try:
                await asyncio.wait_for(self._task, timeout=2.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass
            self._task = None

        if self._pubsub:
            cleanup_timeout = _get_cleanup_timeout()
            try:
                await asyncio.wait_for(self._pubsub.unsubscribe(*self._channels), timeout=cleanup_timeout)
            except asyncio.TimeoutError:
                logger.debug("Pubsub unsubscribe timed out - proceeding anyway")
            except Exception as e:
                logger.debug("Error unsubscribing from pubsub: %s", e)
            try:
                try:
                    await asyncio.wait_for(self._pubsub.aclose(), timeout=cleanup_timeout)
                except AttributeError:
                    # Pubsub object without ``aclose``: use ``close``.
                    await asyncio.wait_for(self._pubsub.close(), timeout=cleanup_timeout)
            except asyncio.TimeoutError:
                logger.debug("Pubsub close timed out - proceeding anyway")
            except Exception as e:
                logger.debug("Error closing pubsub: %s", e)
            self._pubsub = None

        logger.info("CacheInvalidationSubscriber stopped")

    async def _listen_loop(self) -> None:
        """Background loop that listens for and processes invalidation messages.

        Polls the pubsub object until the subscriber is stopped, decoding
        byte payloads and channel names before dispatching each message.
        Errors on individual messages are logged and the loop continues
        after a short pause.

        Raises:
            asyncio.CancelledError: If the task is cancelled during shutdown.
        """
        logger.debug("CacheInvalidationSubscriber listen loop started")
        try:
            while self._started and not (self._stop_event and self._stop_event.is_set()):
                if self._pubsub is None:
                    break
                try:
                    # Outer wait_for bounds the call even if the client's
                    # own 1s poll timeout is not honoured.
                    message = await asyncio.wait_for(
                        self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0),
                        timeout=2.0,
                    )
                    if message and message.get("type") == "message":
                        data = message.get("data")
                        if isinstance(data, bytes):
                            data = data.decode("utf-8")
                        channel = message.get("channel", "")
                        if isinstance(channel, bytes):
                            channel = channel.decode("utf-8")
                        if data:
                            await self._process_invalidation(data, channel=channel)
                except asyncio.TimeoutError:
                    continue
                except Exception as e:  # pylint: disable=broad-exception-caught
                    logger.debug("CacheInvalidationSubscriber message error: %s", e)
                    await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            logger.debug("CacheInvalidationSubscriber listen loop cancelled")
            raise
        finally:
            logger.debug("CacheInvalidationSubscriber listen loop exited")

    _AUTH_PREFIXES = ("user:", "revoke:", "team_roles:", "teams:", "team:", "role:", "membership:")
    """Message prefixes that belong exclusively to the auth invalidation channel."""

    @staticmethod
    def _entry_gateway_id(entry: Any) -> Optional[Any]:
        """Return the ``gateway_id`` stored in a tool-lookup cache entry.

        Args:
            entry: Cache entry whose ``value`` is a mapping that may hold a
                ``"tool"`` mapping.

        Returns:
            The tool's ``gateway_id``, or None when the payload has no
            tool (key missing, or present with a None/empty value).
        """
        tool = entry.value.get("tool") or {}
        return tool.get("gateway_id")

    async def _process_invalidation(self, message: str, *, channel: str = "") -> None:  # pylint: disable=too-many-branches
        """Process a cache invalidation message.

        Only local in-memory caches are cleared; nothing is written back
        to Redis, so handling a message never triggers another message.
        Failures are logged and swallowed so one bad message cannot stop
        the listen loop.

        Args:
            message: The invalidation message in format 'type:identifier'
            channel: The Redis pubsub channel the message arrived on.
                Used to enforce that auth-prefixed messages are only
                accepted from ``mcpgw:auth:invalidate``.
        """
        logger.debug("CacheInvalidationSubscriber received on %s: %s", channel, message)

        # pylint: disable=protected-access
        # pyright: ignore[reportPrivateUsage]
        # We intentionally access protected members to clear local in-memory caches
        # without triggering another round of Redis pubsub invalidation messages
        try:
            if message.startswith("registry:"):
                # Handle registry cache invalidation (tools, prompts, resources, etc.)
                cache_type = message[len("registry:") :]
                cache = get_registry_cache()
                # Only clear local in-memory cache to avoid infinite loops
                registry_prefix = cache._get_redis_key(cache_type)  # pyright: ignore[reportPrivateUsage]
                with cache._lock:  # pyright: ignore[reportPrivateUsage]
                    removed = self._evict_keys(cache._cache, lambda k: k.startswith(registry_prefix))  # pyright: ignore[reportPrivateUsage]
                logger.debug("CacheInvalidationSubscriber: Cleared local registry:%s cache (%d keys)", cache_type, removed)

            elif message.startswith("tool_lookup:gateway:"):
                # Handle gateway-wide tool lookup invalidation
                gateway_id = message[len("tool_lookup:gateway:") :]
                # First-Party
                from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

                # Only clear local L1 cache
                with tool_lookup_cache._lock:  # pyright: ignore[reportPrivateUsage]
                    # Entries without a tool payload are skipped instead of
                    # raising and aborting the whole invalidation.
                    to_remove = [name for name, entry in tool_lookup_cache._cache.items() if self._entry_gateway_id(entry) == gateway_id]  # pyright: ignore[reportPrivateUsage]
                    for name in to_remove:
                        tool_lookup_cache._cache.pop(name, None)  # pyright: ignore[reportPrivateUsage]
                logger.debug("CacheInvalidationSubscriber: Cleared local tool_lookup for gateway %s (%d keys)", gateway_id, len(to_remove))

            elif message.startswith("tool_lookup:"):
                # Handle specific tool lookup invalidation
                tool_name = message[len("tool_lookup:") :]
                # First-Party
                from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

                # Only clear local L1 cache
                with tool_lookup_cache._lock:  # pyright: ignore[reportPrivateUsage]
                    tool_lookup_cache._cache.pop(tool_name, None)  # pyright: ignore[reportPrivateUsage]
                logger.debug("CacheInvalidationSubscriber: Cleared local tool_lookup:%s", tool_name)

            elif message.startswith("admin:"):
                # Handle admin stats cache invalidation
                prefix = message[len("admin:") :]
                # First-Party
                from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

                # Only clear local in-memory cache
                full_prefix = admin_stats_cache._get_redis_key(prefix)  # pyright: ignore[reportPrivateUsage]
                with admin_stats_cache._lock:  # pyright: ignore[reportPrivateUsage]
                    removed = self._evict_keys(admin_stats_cache._cache, lambda k: k.startswith(full_prefix))  # pyright: ignore[reportPrivateUsage]
                logger.debug("CacheInvalidationSubscriber: Cleared local admin:%s cache (%d keys)", prefix, removed)

            elif message.startswith(self._AUTH_PREFIXES):
                if channel != "mcpgw:auth:invalidate":
                    logger.warning("CacheInvalidationSubscriber: Ignoring auth message on wrong channel %s: %s", channel, message)
                else:
                    self._process_auth_invalidation(message)

            else:
                logger.debug("CacheInvalidationSubscriber: Unknown message format: %s", message)

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("CacheInvalidationSubscriber: Error processing '%s': %s", message, e)

    @staticmethod
    def _evict_keys(cache_dict: dict, predicate: "Callable[[str], bool]") -> int:
        """Remove all keys from *cache_dict* that satisfy *predicate*.

        Must be called while holding the owning cache's ``_lock``.

        Args:
            cache_dict: The dictionary to evict keys from.
            predicate: A callable that returns True for keys to remove.

        Returns:
            Number of keys removed.
        """
        # Collect first: a dict must not change size while being iterated.
        keys = [k for k in cache_dict if predicate(k)]
        for k in keys:
            cache_dict.pop(k, None)
        return len(keys)

    def _process_auth_invalidation(self, message: str) -> None:  # pylint: disable=too-many-branches
        """Dispatch an auth-channel invalidation message to the local auth cache.

        Called from :meth:`_process_invalidation` for messages received on
        ``mcpgw:auth:invalidate``. Messages with an unrecognised prefix
        are ignored.

        Args:
            message: The invalidation message (e.g. ``user:alice@test.com``).
        """
        # pylint: disable=protected-access
        # First-Party
        from mcpgateway.cache.auth_cache import auth_cache  # pylint: disable=import-outside-toplevel

        # Dispatch auth message to the correct handler
        if message.startswith("user:"):
            email = message[len("user:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._context_cache, lambda k: k.startswith(f"{email}:"))  # pyright: ignore[reportPrivateUsage]
                auth_cache._user_cache.pop(email, None)  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._team_cache, lambda k: k.startswith(f"{email}:"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth user cache for %s", email)

        elif message.startswith("revoke:"):
            jti = message[len("revoke:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                if len(auth_cache._revoked_jtis) < _MAX_REVOKED_JTIS:  # pyright: ignore[reportPrivateUsage]
                    auth_cache._revoked_jtis.add(jti)  # pyright: ignore[reportPrivateUsage]
                else:
                    logger.warning("CacheInvalidationSubscriber: _revoked_jtis at cap (%d), skipping add for jti=%s", _MAX_REVOKED_JTIS, jti[:8])
                # Eviction runs even at the cap so cached "not revoked"
                # results for this JTI cannot be served locally.
                auth_cache._revocation_cache.pop(jti, None)  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._context_cache, lambda k: k.endswith(f":{jti}"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth revocation cache for jti=%s", jti[:8])

        elif message.startswith("team_roles:"):
            team_id = message[len("team_roles:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._role_cache, lambda k: k.endswith(f":{team_id}"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth team_roles cache for team %s", team_id)

        elif message.startswith("teams:"):
            email = message[len("teams:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._teams_list_cache, lambda k: k.startswith(f"{email}:"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth teams list cache for %s", email)

        elif message.startswith("team:"):
            email = message[len("team:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                auth_cache._team_cache.pop(email, None)  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._context_cache, lambda k: k.startswith(f"{email}:"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth team cache for %s", email)

        elif message.startswith("role:"):
            cache_key = message[len("role:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                auth_cache._role_cache.pop(cache_key, None)  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth role cache for %s", cache_key)

        elif message.startswith("membership:"):
            user_email = message[len("membership:") :]
            with auth_cache._lock:  # pyright: ignore[reportPrivateUsage]
                self._evict_keys(auth_cache._team_cache, lambda k: k.startswith(f"{user_email}:"))  # pyright: ignore[reportPrivateUsage]
            logger.debug("CacheInvalidationSubscriber: Cleared local auth membership cache for %s", user_email)

876 

877 

# Process-wide CacheInvalidationSubscriber, created on first request.
_cache_invalidation_subscriber: Optional[CacheInvalidationSubscriber] = None


def get_cache_invalidation_subscriber() -> CacheInvalidationSubscriber:
    """Return the shared CacheInvalidationSubscriber, constructing it on first use.

    Returns:
        CacheInvalidationSubscriber: The single module-level instance;
        every call returns the same object.

    Examples:
        >>> subscriber = get_cache_invalidation_subscriber()
        >>> isinstance(subscriber, CacheInvalidationSubscriber)
        True
    """
    global _cache_invalidation_subscriber  # pylint: disable=global-statement
    if _cache_invalidation_subscriber is not None:
        return _cache_invalidation_subscriber
    _cache_invalidation_subscriber = CacheInvalidationSubscriber()
    return _cache_invalidation_subscriber