Coverage for mcpgateway / cache / auth_cache.py: 100%
441 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/cache/auth_cache.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
6Authentication Data Cache.
8This module implements a thread-safe two-tier cache for authentication data.
9L1 (in-memory) is checked first for lowest latency, with L2 (Redis) as a
10shared distributed cache. It caches user data, team memberships, and token
11revocation status to reduce database queries during authentication.
13Performance Impact:
14 - Before: 3-4 DB queries per authenticated request
15 - After: 0-1 DB queries (cache hit) per TTL period
17Security Considerations:
18 - Short TTLs for revocation data (30s default) to limit exposure window
19 - Cache invalidation on token revocation, user update, team change
20 - JWT payloads are NOT cached (security risk)
21 - Graceful fallback to DB on cache failure
23Examples:
24 >>> from mcpgateway.cache.auth_cache import auth_cache
25 >>> # Cache is used automatically by get_current_user()
26 >>> # Manual invalidation after user update:
27 >>> import asyncio
28 >>> # asyncio.run(auth_cache.invalidate_user("user@example.com"))
29"""
31# Standard
32import asyncio
33from dataclasses import dataclass
34import logging
35import threading
36import time
37from typing import Any, Dict, List, Optional, Set
logger: logging.Logger = logging.getLogger(__name__)

# Marker written to Redis in place of a role when the user is known NOT to be a
# member of the team. An explicit marker lets a cached negative answer be told
# apart from a plain cache miss (key absent).
_NOT_A_MEMBER_SENTINEL = "__NOT_A_MEMBER__"
@dataclass
class CachedAuthContext:
    """Authentication state for one user/token pair, as served from the cache.

    One instance bundles the user data, personal team and revocation status
    produced by a single batched database roundtrip, so the three can be cached
    and returned together.

    Attributes:
        user: Dict of user fields (e.g. email, is_admin, is_active), or None.
        personal_team_id: ID of the user's personal team, or None.
        is_token_revoked: True when the JWT has been revoked.

    Examples:
        >>> ctx = CachedAuthContext(
        ...     user={"email": "test@example.com", "is_admin": False},
        ...     personal_team_id="team-123",
        ...     is_token_revoked=False
        ... )
        >>> ctx.is_token_revoked
        False
        >>> revoked = CachedAuthContext(is_token_revoked=True)
        >>> (revoked.user, revoked.personal_team_id)
        (None, None)
    """

    # Every field has a default, so a "revoked only" context can be built by
    # passing just is_token_revoked=True.
    user: Optional[Dict[str, Any]] = None
    personal_team_id: Optional[str] = None
    is_token_revoked: bool = False
@dataclass
class CacheEntry:
    """A cached value paired with the absolute time at which it stops being valid.

    ``expiry`` is compared against ``time.time()``, i.e. seconds since the epoch.

    Examples:
        >>> import time
        >>> entry = CacheEntry(value={"key": "value"}, expiry=time.time() + 60)
        >>> entry.is_expired()
        False
    """

    value: Any
    expiry: float

    def is_expired(self) -> bool:
        """Report whether this entry's expiry time has been reached.

        Returns:
            bool: True once the current time is at or past ``expiry``, False otherwise.
        """
        now = time.time()
        return self.expiry <= now
96class AuthCache:
97 """Thread-safe two-tier authentication cache (L1 in-memory + L2 Redis).
99 This cache reduces database load during authentication by caching:
100 - User data (email, is_admin, is_active, etc.)
101 - Personal team ID for the user
102 - Token revocation status
104 Cache lookup checks L1 (in-memory) first for lowest latency, then L2
105 (Redis) for distributed consistency. Redis hits are written through
106 to L1 for subsequent requests.
108 Attributes:
109 user_ttl: TTL in seconds for user data cache (default: 60)
110 revocation_ttl: TTL in seconds for revocation cache (default: 30)
111 team_ttl: TTL in seconds for team cache (default: 60)
113 Examples:
114 >>> cache = AuthCache(user_ttl=60, revocation_ttl=30)
115 >>> cache.stats()["hit_count"]
116 0
117 """
119 _NOT_CACHED = object()
121 def __init__(
122 self,
123 user_ttl: Optional[int] = None,
124 revocation_ttl: Optional[int] = None,
125 team_ttl: Optional[int] = None,
126 role_ttl: Optional[int] = None,
127 enabled: Optional[bool] = None,
128 ):
129 """Initialize the auth cache.
131 Args:
132 user_ttl: TTL for user data cache in seconds (default: from settings or 60)
133 revocation_ttl: TTL for revocation cache in seconds (default: from settings or 30)
134 team_ttl: TTL for team cache in seconds (default: from settings or 60)
135 role_ttl: TTL for role cache in seconds (default: from settings or 60)
136 enabled: Whether caching is enabled (default: from settings or True)
138 Examples:
139 >>> cache = AuthCache(user_ttl=120, revocation_ttl=30)
140 >>> cache._user_ttl
141 120
142 """
143 # Import settings lazily to avoid circular imports
144 try:
145 # First-Party
146 from mcpgateway.config import settings # pylint: disable=import-outside-toplevel
148 self._user_ttl = user_ttl or getattr(settings, "auth_cache_user_ttl", 60)
149 self._revocation_ttl = revocation_ttl or getattr(settings, "auth_cache_revocation_ttl", 30)
150 self._team_ttl = team_ttl or getattr(settings, "auth_cache_team_ttl", 60)
151 self._role_ttl = role_ttl or getattr(settings, "auth_cache_role_ttl", 60)
152 self._teams_list_ttl = getattr(settings, "auth_cache_teams_ttl", 60)
153 self._teams_list_enabled = getattr(settings, "auth_cache_teams_enabled", True)
154 self._enabled = enabled if enabled is not None else getattr(settings, "auth_cache_enabled", True)
155 self._cache_prefix = getattr(settings, "cache_prefix", "mcpgw:")
156 except ImportError:
157 self._user_ttl = user_ttl or 60
158 self._revocation_ttl = revocation_ttl or 30
159 self._team_ttl = team_ttl or 60
160 self._role_ttl = role_ttl or 60
161 self._teams_list_ttl = 60
162 self._teams_list_enabled = True
163 self._enabled = enabled if enabled is not None else True
164 self._cache_prefix = "mcpgw:"
166 # In-memory cache (fallback when Redis unavailable)
167 self._user_cache: Dict[str, CacheEntry] = {}
168 self._team_cache: Dict[str, CacheEntry] = {}
169 self._revocation_cache: Dict[str, CacheEntry] = {}
170 self._context_cache: Dict[str, CacheEntry] = {}
171 self._role_cache: Dict[str, CacheEntry] = {}
172 self._teams_list_cache: Dict[str, CacheEntry] = {}
174 # Known revoked tokens (fast local lookup)
175 self._revoked_jtis: Set[str] = set()
177 # Thread safety
178 self._lock = threading.Lock()
180 # Redis availability (None = not checked yet)
181 self._redis_checked = False
182 self._redis_available = False
184 # Statistics
185 self._hit_count = 0
186 self._miss_count = 0
187 self._redis_hit_count = 0
188 self._redis_miss_count = 0
190 logger.info(
191 f"AuthCache initialized: enabled={self._enabled}, "
192 f"user_ttl={self._user_ttl}s, revocation_ttl={self._revocation_ttl}s, "
193 f"team_ttl={self._team_ttl}s, role_ttl={self._role_ttl}s, "
194 f"teams_list_enabled={self._teams_list_enabled}, teams_list_ttl={self._teams_list_ttl}s"
195 )
197 def _get_redis_key(self, key_type: str, identifier: str) -> str:
198 """Generate Redis key with proper prefix.
200 Args:
201 key_type: Type of cache entry (user, team, revoke, ctx)
202 identifier: Unique identifier (email, jti, etc.)
204 Returns:
205 Full Redis key with prefix
207 Examples:
208 >>> cache = AuthCache()
209 >>> cache._get_redis_key("user", "test@example.com")
210 'mcpgw:auth:user:test@example.com'
211 """
212 return f"{self._cache_prefix}auth:{key_type}:{identifier}"
214 async def _get_redis_client(self):
215 """Get Redis client if available.
217 Returns:
218 Redis client or None if unavailable
219 """
220 try:
221 # First-Party
222 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
224 client = await get_redis_client()
225 if client and not self._redis_checked:
226 self._redis_checked = True
227 self._redis_available = True
228 logger.debug("AuthCache: Redis client available")
229 return client
230 except Exception as e:
231 if not self._redis_checked:
232 self._redis_checked = True
233 self._redis_available = False
234 logger.debug(f"AuthCache: Redis unavailable, using in-memory cache: {e}")
235 return None
237 async def get_auth_context(
238 self,
239 email: str,
240 jti: Optional[str] = None,
241 ) -> Optional[CachedAuthContext]:
242 """Get cached authentication context.
244 Checks cache for user data, team membership, and revocation status.
245 Returns None on cache miss.
247 Args:
248 email: User email address
249 jti: JWT ID for revocation check (optional)
251 Returns:
252 CachedAuthContext if found in cache, None otherwise
254 Examples:
255 >>> import asyncio
256 >>> cache = AuthCache()
257 >>> result = asyncio.run(cache.get_auth_context("test@example.com"))
258 >>> result is None # Cache miss on fresh cache
259 True
260 """
261 if not self._enabled:
262 return None
264 # Check for known revoked token first (fast local check)
265 if jti and jti in self._revoked_jtis:
266 self._hit_count += 1
267 return CachedAuthContext(is_token_revoked=True)
269 cache_key = f"{email}:{jti or 'no-jti'}"
271 # Check L1 in-memory cache first (no network I/O)
272 entry = self._context_cache.get(cache_key)
273 if entry and not entry.is_expired():
274 self._hit_count += 1
275 return entry.value
277 # Check L2 Redis cache
278 redis = await self._get_redis_client()
279 if redis:
280 try:
281 redis_key = self._get_redis_key("ctx", cache_key)
282 data = await redis.get(redis_key)
283 if data:
284 # Third-Party
285 import orjson # pylint: disable=import-outside-toplevel
287 cached = orjson.loads(data)
288 result = CachedAuthContext(
289 user=cached.get("user"),
290 personal_team_id=cached.get("personal_team_id"),
291 is_token_revoked=cached.get("is_token_revoked", False),
292 )
293 self._hit_count += 1
294 self._redis_hit_count += 1
296 # Write-through: populate L1 from Redis hit
297 ttl = min(self._user_ttl, self._revocation_ttl, self._team_ttl)
298 with self._lock:
299 self._context_cache[cache_key] = CacheEntry(
300 value=result,
301 expiry=time.time() + ttl,
302 )
304 return result
305 self._redis_miss_count += 1
306 except Exception as e:
307 logger.warning(f"AuthCache Redis get failed: {e}")
309 self._miss_count += 1
310 return None
312 async def set_auth_context(
313 self,
314 email: str,
315 jti: Optional[str],
316 context: CachedAuthContext,
317 ) -> None:
318 """Store authentication context in cache.
320 Stores in both Redis (if available) and in-memory cache.
322 Args:
323 email: User email address
324 jti: JWT ID (optional)
325 context: Authentication context to cache
327 Examples:
328 >>> import asyncio
329 >>> cache = AuthCache()
330 >>> ctx = CachedAuthContext(
331 ... user={"email": "test@example.com"},
332 ... personal_team_id="team-1",
333 ... is_token_revoked=False
334 ... )
335 >>> asyncio.run(cache.set_auth_context("test@example.com", "jti-123", ctx))
336 """
337 if not self._enabled:
338 return
340 cache_key = f"{email}:{jti or 'no-jti'}"
342 # Use shortest TTL for combined context
343 ttl = min(self._user_ttl, self._revocation_ttl, self._team_ttl)
345 # Prepare data for serialization
346 data = {
347 "user": context.user,
348 "personal_team_id": context.personal_team_id,
349 "is_token_revoked": context.is_token_revoked,
350 }
352 # Store in Redis
353 redis = await self._get_redis_client()
354 if redis:
355 try:
356 # Third-Party
357 import orjson # pylint: disable=import-outside-toplevel
359 redis_key = self._get_redis_key("ctx", cache_key)
360 await redis.setex(redis_key, ttl, orjson.dumps(data))
361 except Exception as e:
362 logger.warning(f"AuthCache Redis set failed: {e}")
364 # Store in in-memory cache
365 with self._lock:
366 self._context_cache[cache_key] = CacheEntry(
367 value=context,
368 expiry=time.time() + ttl,
369 )
371 async def invalidate_user(self, email: str) -> None:
372 """Invalidate cached data for a user.
374 Call this when user data changes (password, profile, etc.).
376 Args:
377 email: User email to invalidate
379 Examples:
380 >>> import asyncio
381 >>> cache = AuthCache()
382 >>> asyncio.run(cache.invalidate_user("test@example.com"))
383 """
384 logger.debug(f"AuthCache: Invalidating user cache for {email}")
386 # Clear in-memory caches
387 with self._lock:
388 # Clear any context cache entries for this user
389 keys_to_remove = [k for k in self._context_cache if k.startswith(f"{email}:")]
390 for key in keys_to_remove:
391 self._context_cache.pop(key, None)
393 self._user_cache.pop(email, None)
395 # Clear team membership cache entries (keys are email:team_ids)
396 team_keys_to_remove = [k for k in self._team_cache if k.startswith(f"{email}:")]
397 for key in team_keys_to_remove:
398 self._team_cache.pop(key, None)
400 # Clear Redis
401 redis = await self._get_redis_client()
402 if redis:
403 try:
404 # Delete user-specific keys
405 await redis.delete(
406 self._get_redis_key("user", email),
407 self._get_redis_key("team", email),
408 )
409 # Delete context keys (pattern match)
410 pattern = self._get_redis_key("ctx", f"{email}:*")
411 async for key in redis.scan_iter(match=pattern):
412 await redis.delete(key)
413 # Delete membership keys (pattern match)
414 membership_pattern = self._get_redis_key("membership", f"{email}:*")
415 async for key in redis.scan_iter(match=membership_pattern):
416 await redis.delete(key)
418 # Publish invalidation for other workers
419 await redis.publish("mcpgw:auth:invalidate", f"user:{email}")
420 except Exception as e:
421 logger.warning(f"AuthCache Redis invalidate_user failed: {e}")
423 async def invalidate_revocation(self, jti: str) -> None:
424 """Invalidate cache for a revoked token.
426 Call this when a token is revoked.
428 Args:
429 jti: JWT ID of revoked token
431 Examples:
432 >>> import asyncio
433 >>> cache = AuthCache()
434 >>> asyncio.run(cache.invalidate_revocation("jti-123"))
435 """
436 logger.debug(f"AuthCache: Invalidating revocation cache for jti={jti[:8]}...")
438 # Add to local revoked set for fast lookup
439 with self._lock:
440 self._revoked_jtis.add(jti)
441 self._revocation_cache.pop(jti, None)
443 # Clear any context cache entries with this JTI
444 keys_to_remove = [k for k in self._context_cache if k.endswith(f":{jti}")]
445 for key in keys_to_remove:
446 self._context_cache.pop(key, None)
448 # Update Redis
449 redis = await self._get_redis_client()
450 if redis:
451 try:
452 # Mark as revoked in Redis
453 await redis.setex(
454 self._get_redis_key("revoke", jti),
455 86400, # 24 hour expiry for revocation markers
456 "1",
457 )
458 # Add to revoked tokens set
459 await redis.sadd("mcpgw:auth:revoked_tokens", jti)
461 # Delete any cached contexts with this JTI
462 pattern = self._get_redis_key("ctx", f"*:{jti}")
463 async for key in redis.scan_iter(match=pattern):
464 await redis.delete(key)
466 # Publish invalidation for other workers
467 await redis.publish("mcpgw:auth:invalidate", f"revoke:{jti}")
468 except Exception as e:
469 logger.warning(f"AuthCache Redis invalidate_revocation failed: {e}")
471 async def invalidate_team(self, email: str) -> None:
472 """Invalidate team cache for a user.
474 Call this when team membership changes.
476 Args:
477 email: User email whose team changed
479 Examples:
480 >>> import asyncio
481 >>> cache = AuthCache()
482 >>> asyncio.run(cache.invalidate_team("test@example.com"))
483 """
484 logger.debug(f"AuthCache: Invalidating team cache for {email}")
486 # Clear in-memory caches
487 with self._lock:
488 self._team_cache.pop(email, None)
489 # Clear context cache entries for this user
490 keys_to_remove = [k for k in self._context_cache if k.startswith(f"{email}:")]
491 for key in keys_to_remove:
492 self._context_cache.pop(key, None)
494 # Clear Redis
495 redis = await self._get_redis_client()
496 if redis:
497 try:
498 await redis.delete(self._get_redis_key("team", email))
499 # Delete context keys
500 pattern = self._get_redis_key("ctx", f"{email}:*")
501 async for key in redis.scan_iter(match=pattern):
502 await redis.delete(key)
504 # Publish invalidation
505 await redis.publish("mcpgw:auth:invalidate", f"team:{email}")
506 except Exception as e:
507 logger.warning(f"AuthCache Redis invalidate_team failed: {e}")
509 async def get_user_role(self, email: str, team_id: str) -> Optional[str]:
510 """Get cached user role in a team.
512 Returns:
513 - None: Cache miss (caller should check DB)
514 - "": User is not a member of the team (cached negative result)
515 - Role string: User's role in the team (cached)
517 Args:
518 email: User email
519 team_id: Team ID
521 Examples:
522 >>> import asyncio
523 >>> cache = AuthCache()
524 >>> result = asyncio.run(cache.get_user_role("test@example.com", "team-123"))
525 >>> result is None # Cache miss
526 True
527 """
528 if not self._enabled:
529 return None
531 cache_key = f"{email}:{team_id}"
533 # Check L1 in-memory cache first (no network I/O)
534 entry = self._role_cache.get(cache_key)
535 if entry and not entry.is_expired():
536 self._hit_count += 1
537 # Return empty string for None (not a member) to distinguish from cache miss
538 return "" if entry.value is None else entry.value
540 # Check L2 Redis cache
541 redis = await self._get_redis_client()
542 if redis:
543 try:
544 redis_key = self._get_redis_key("role", cache_key)
545 data = await redis.get(redis_key)
546 if data is not None:
547 self._hit_count += 1
548 self._redis_hit_count += 1
549 # Role is stored as plain string, decode it
550 decoded = data.decode() if isinstance(data, bytes) else data
551 # Convert sentinel to empty string (user is not a member)
552 role_value = "" if decoded == _NOT_A_MEMBER_SENTINEL else decoded
554 # Write-through: populate L1 from Redis hit
555 # Store as None for not-a-member to match existing L1 storage format
556 with self._lock:
557 self._role_cache[cache_key] = CacheEntry(
558 value=None if role_value == "" else role_value,
559 expiry=time.time() + self._role_ttl,
560 )
562 return role_value
563 self._redis_miss_count += 1
564 except Exception as e:
565 logger.warning(f"AuthCache Redis get_user_role failed: {e}")
567 self._miss_count += 1
568 return None
570 async def set_user_role(self, email: str, team_id: str, role: Optional[str]) -> None:
571 """Store user role in cache.
573 Args:
574 email: User email
575 team_id: Team ID
576 role: User's role in the team (or None if not a member)
578 Examples:
579 >>> import asyncio
580 >>> cache = AuthCache()
581 >>> asyncio.run(cache.set_user_role("test@example.com", "team-123", "admin"))
582 """
583 if not self._enabled:
584 return
586 cache_key = f"{email}:{team_id}"
587 # Store None as sentinel value to distinguish "not a member" from cache miss
588 role_value = role if role is not None else _NOT_A_MEMBER_SENTINEL
590 # Store in Redis
591 redis = await self._get_redis_client()
592 if redis:
593 try:
594 redis_key = self._get_redis_key("role", cache_key)
595 await redis.setex(redis_key, self._role_ttl, role_value)
596 except Exception as e:
597 logger.warning(f"AuthCache Redis set_user_role failed: {e}")
599 # Store in in-memory cache
600 with self._lock:
601 self._role_cache[cache_key] = CacheEntry(
602 value=role,
603 expiry=time.time() + self._role_ttl,
604 )
606 async def invalidate_user_role(self, email: str, team_id: str) -> None:
607 """Invalidate cached role for a user in a team.
609 Call this when a user's role changes in a team.
611 Args:
612 email: User email
613 team_id: Team ID
615 Examples:
616 >>> import asyncio
617 >>> cache = AuthCache()
618 >>> asyncio.run(cache.invalidate_user_role("test@example.com", "team-123"))
619 """
620 logger.debug(f"AuthCache: Invalidating role cache for {email} in team {team_id}")
622 cache_key = f"{email}:{team_id}"
624 # Clear in-memory cache
625 with self._lock:
626 self._role_cache.pop(cache_key, None)
628 # Clear Redis
629 redis = await self._get_redis_client()
630 if redis:
631 try:
632 await redis.delete(self._get_redis_key("role", cache_key))
633 # Publish invalidation for other workers
634 await redis.publish("mcpgw:auth:invalidate", f"role:{email}:{team_id}")
635 except Exception as e:
636 logger.warning(f"AuthCache Redis invalidate_user_role failed: {e}")
638 async def invalidate_team_roles(self, team_id: str) -> None:
639 """Invalidate all cached roles for a team.
641 Call this when team membership changes significantly (e.g., team deletion).
643 Args:
644 team_id: Team ID
646 Examples:
647 >>> import asyncio
648 >>> cache = AuthCache()
649 >>> asyncio.run(cache.invalidate_team_roles("team-123"))
650 """
651 logger.debug(f"AuthCache: Invalidating all role caches for team {team_id}")
653 # Clear in-memory cache entries for this team
654 with self._lock:
655 keys_to_remove = [k for k in self._role_cache if k.endswith(f":{team_id}")]
656 for key in keys_to_remove:
657 self._role_cache.pop(key, None)
659 # Clear Redis
660 redis = await self._get_redis_client()
661 if redis:
662 try:
663 # Pattern match all role keys for this team
664 pattern = self._get_redis_key("role", f"*:{team_id}")
665 async for key in redis.scan_iter(match=pattern):
666 await redis.delete(key)
667 # Publish invalidation
668 await redis.publish("mcpgw:auth:invalidate", f"team_roles:{team_id}")
669 except Exception as e:
670 logger.warning(f"AuthCache Redis invalidate_team_roles failed: {e}")
672 async def get_user_teams(self, cache_key: str) -> Optional[List[str]]:
673 """Get cached team IDs for a user.
675 The cache_key should be in the format "email:include_personal" to
676 distinguish between calls with different include_personal flags.
678 Returns:
679 - None: Cache miss (caller should query DB)
680 - Empty list: User has no teams (cached result)
681 - List of team IDs: Cached team IDs
683 Args:
684 cache_key: Cache key in format "email:include_personal"
686 Examples:
687 >>> import asyncio
688 >>> cache = AuthCache()
689 >>> result = asyncio.run(cache.get_user_teams("test@example.com:True"))
690 >>> result is None # Cache miss
691 True
692 """
693 if not self._enabled or not self._teams_list_enabled:
694 return None
696 # Check L1 in-memory cache first (no network I/O)
697 entry = self._teams_list_cache.get(cache_key)
698 if entry and not entry.is_expired():
699 self._hit_count += 1
700 return entry.value
702 # Check L2 Redis cache
703 redis = await self._get_redis_client()
704 if redis:
705 try:
706 redis_key = self._get_redis_key("teams", cache_key)
707 data = await redis.get(redis_key)
708 if data is not None:
709 self._hit_count += 1
710 self._redis_hit_count += 1
711 # Third-Party
712 import orjson # pylint: disable=import-outside-toplevel
714 team_ids = orjson.loads(data)
716 # Write-through: populate L1 from Redis hit
717 with self._lock:
718 self._teams_list_cache[cache_key] = CacheEntry(
719 value=team_ids,
720 expiry=time.time() + self._teams_list_ttl,
721 )
723 return team_ids
724 self._redis_miss_count += 1
725 except Exception as e:
726 logger.warning(f"AuthCache Redis get_user_teams failed: {e}")
728 self._miss_count += 1
729 return None
731 async def set_user_teams(self, cache_key: str, team_ids: List[str]) -> None:
732 """Store team IDs for a user in cache.
734 Args:
735 cache_key: Cache key in format "email:include_personal"
736 team_ids: List of team IDs the user belongs to
738 Examples:
739 >>> import asyncio
740 >>> cache = AuthCache()
741 >>> asyncio.run(cache.set_user_teams("test@example.com:True", ["team-1", "team-2"]))
742 """
743 if not self._enabled or not self._teams_list_enabled:
744 return
746 # Store in Redis
747 redis = await self._get_redis_client()
748 if redis:
749 try:
750 # Third-Party
751 import orjson # pylint: disable=import-outside-toplevel
753 redis_key = self._get_redis_key("teams", cache_key)
754 await redis.setex(redis_key, self._teams_list_ttl, orjson.dumps(team_ids))
755 except Exception as e:
756 logger.warning(f"AuthCache Redis set_user_teams failed: {e}")
758 # Store in in-memory cache
759 with self._lock:
760 self._teams_list_cache[cache_key] = CacheEntry(
761 value=team_ids,
762 expiry=time.time() + self._teams_list_ttl,
763 )
765 async def invalidate_user_teams(self, email: str) -> None:
766 """Invalidate cached teams list for a user.
768 Call this when a user's team membership changes (add/remove member,
769 delete team, approve join request).
771 This invalidates both include_personal=True and include_personal=False
772 cache entries for the user.
774 Args:
775 email: User email whose teams cache should be invalidated
777 Examples:
778 >>> import asyncio
779 >>> cache = AuthCache()
780 >>> asyncio.run(cache.invalidate_user_teams("test@example.com"))
781 """
782 logger.debug(f"AuthCache: Invalidating teams list cache for {email}")
784 # Clear in-memory cache entries for this user (both True and False variants)
785 with self._lock:
786 keys_to_remove = [k for k in self._teams_list_cache if k.startswith(f"{email}:")]
787 for key in keys_to_remove:
788 self._teams_list_cache.pop(key, None)
790 # Clear Redis
791 redis = await self._get_redis_client()
792 if redis:
793 try:
794 # Delete both variants
795 await redis.delete(
796 self._get_redis_key("teams", f"{email}:True"),
797 self._get_redis_key("teams", f"{email}:False"),
798 )
799 # Publish invalidation for other workers
800 await redis.publish("mcpgw:auth:invalidate", f"teams:{email}")
801 except Exception as e:
802 logger.warning(f"AuthCache Redis invalidate_user_teams failed: {e}")
804 # =========================================================================
805 # Team Membership Validation Cache
806 # =========================================================================
807 # Used by TokenScopingMiddleware to cache email_team_members lookups.
808 # This prevents repeated DB queries for the same user+teams combination.
810 def get_team_membership_valid_sync(self, user_email: str, team_ids: List[str]) -> Optional[bool]:
811 """Get cached team membership validation result (synchronous).
813 This is the synchronous version used by token_scoping middleware.
814 Returns None on cache miss (caller should check DB).
816 Args:
817 user_email: User email
818 team_ids: List of team IDs to validate membership for
820 Returns:
821 - None: Cache miss (caller should query DB)
822 - True: User is valid member of all teams (cached)
823 - False: User is NOT a valid member of all teams (cached)
825 Examples:
826 >>> cache = AuthCache()
827 >>> result = cache.get_team_membership_valid_sync("test@example.com", ["team-1"])
828 >>> result is None # Cache miss
829 True
830 """
831 if not self._enabled or not team_ids:
832 return None
834 # Create cache key from user + sorted team IDs
835 sorted_teams = ":".join(sorted(team_ids))
836 cache_key = f"{user_email}:{sorted_teams}"
838 # Check in-memory cache only (sync version)
839 entry = self._team_cache.get(cache_key)
840 if entry and not entry.is_expired():
841 self._hit_count += 1
842 return entry.value
844 self._miss_count += 1
845 return None
847 def set_team_membership_valid_sync(self, user_email: str, team_ids: List[str], valid: bool) -> None:
848 """Store team membership validation result in cache (synchronous).
850 Args:
851 user_email: User email
852 team_ids: List of team IDs that were validated
853 valid: Whether user is a valid member of all teams
855 Examples:
856 >>> cache = AuthCache()
857 >>> cache.set_team_membership_valid_sync("test@example.com", ["team-1"], True)
858 """
859 if not self._enabled or not team_ids:
860 return
862 # Create cache key from user + sorted team IDs
863 sorted_teams = ":".join(sorted(team_ids))
864 cache_key = f"{user_email}:{sorted_teams}"
866 # Store in in-memory cache
867 with self._lock:
868 self._team_cache[cache_key] = CacheEntry(
869 value=valid,
870 expiry=time.time() + self._team_ttl,
871 )
873 async def get_team_membership_valid(self, user_email: str, team_ids: List[str]) -> Optional[bool]:
874 """Get cached team membership validation result (async).
876 Returns None on cache miss (caller should check DB).
878 Args:
879 user_email: User email
880 team_ids: List of team IDs to validate membership for
882 Returns:
883 - None: Cache miss (caller should query DB)
884 - True: User is valid member of all teams (cached)
885 - False: User is NOT a valid member of all teams (cached)
887 Examples:
888 >>> import asyncio
889 >>> cache = AuthCache()
890 >>> result = asyncio.run(cache.get_team_membership_valid("test@example.com", ["team-1"]))
891 >>> result is None # Cache miss
892 True
893 """
894 if not self._enabled or not team_ids:
895 return None
897 # Create cache key from user + sorted team IDs
898 sorted_teams = ":".join(sorted(team_ids))
899 cache_key = f"{user_email}:{sorted_teams}"
901 # Check L1 in-memory cache first (no network I/O)
902 entry = self._team_cache.get(cache_key)
903 if entry and not entry.is_expired():
904 self._hit_count += 1
905 return entry.value
907 # Check L2 Redis cache
908 redis = await self._get_redis_client()
909 if redis:
910 try:
911 redis_key = self._get_redis_key("membership", cache_key)
912 data = await redis.get(redis_key)
913 if data is not None:
914 self._hit_count += 1
915 self._redis_hit_count += 1
916 # Stored as "1" for True, "0" for False
917 decoded = data.decode() if isinstance(data, bytes) else data
918 result = decoded == "1"
920 # Write-through: populate L1 from Redis hit
921 with self._lock:
922 self._team_cache[cache_key] = CacheEntry(
923 value=result,
924 expiry=time.time() + self._team_ttl,
925 )
927 return result
928 self._redis_miss_count += 1
929 except Exception as e:
930 logger.warning(f"AuthCache Redis get_team_membership_valid failed: {e}")
932 self._miss_count += 1
933 return None
935 async def set_team_membership_valid(self, user_email: str, team_ids: List[str], valid: bool) -> None:
936 """Store team membership validation result in cache (async).
938 Args:
939 user_email: User email
940 team_ids: List of team IDs that were validated
941 valid: Whether user is a valid member of all teams
943 Examples:
944 >>> import asyncio
945 >>> cache = AuthCache()
946 >>> asyncio.run(cache.set_team_membership_valid("test@example.com", ["team-1"], True))
947 """
948 if not self._enabled or not team_ids:
949 return
951 # Create cache key from user + sorted team IDs
952 sorted_teams = ":".join(sorted(team_ids))
953 cache_key = f"{user_email}:{sorted_teams}"
955 # Store in Redis
956 redis = await self._get_redis_client()
957 if redis:
958 try:
959 redis_key = self._get_redis_key("membership", cache_key)
960 # Store as "1" for True, "0" for False
961 await redis.setex(redis_key, self._team_ttl, "1" if valid else "0")
962 except Exception as e:
963 logger.warning(f"AuthCache Redis set_team_membership_valid failed: {e}")
965 # Store in in-memory cache
966 with self._lock:
967 self._team_cache[cache_key] = CacheEntry(
968 value=valid,
969 expiry=time.time() + self._team_ttl,
970 )
972 async def invalidate_team_membership(self, user_email: str) -> None:
973 """Invalidate team membership cache for a user.
975 Call this when a user's team membership changes (add/remove member,
976 role change, deactivation).
978 Args:
979 user_email: User email whose membership cache should be invalidated
981 Examples:
982 >>> import asyncio
983 >>> cache = AuthCache()
984 >>> asyncio.run(cache.invalidate_team_membership("test@example.com"))
985 """
986 logger.debug(f"AuthCache: Invalidating team membership cache for {user_email}")
988 # Clear in-memory cache entries for this user
989 with self._lock:
990 keys_to_remove = [k for k in self._team_cache if k.startswith(f"{user_email}:")]
991 for key in keys_to_remove:
992 self._team_cache.pop(key, None)
994 # Clear Redis
995 redis = await self._get_redis_client()
996 if redis:
997 try:
998 # Pattern match all membership keys for this user
999 pattern = self._get_redis_key("membership", f"{user_email}:*")
1000 async for key in redis.scan_iter(match=pattern):
1001 await redis.delete(key)
1002 # Publish invalidation for other workers
1003 await redis.publish("mcpgw:auth:invalidate", f"membership:{user_email}")
1004 except Exception as e:
1005 logger.warning(f"AuthCache Redis invalidate_team_membership failed: {e}")
async def is_token_revoked(self, jti: str) -> Optional[bool]:
    """Answer a token-revocation query from the caches only (never the DB).

    Lookup order:
    1. the local set of confirmed revoked JTIs,
    2. the L1 in-memory revocation cache (non-expired entries only),
    3. Redis: the shared ``mcpgw:auth:revoked_tokens`` set, then the
       per-token ``revoke`` key.

    A Redis hit is written back into the local set and the L1 cache.

    Args:
        jti: JWT ID to look up.

    Returns:
        True if the token is known to be revoked, the stored value of a fresh
        L1 entry if one exists, or None when the caches cannot answer (or
        caching is disabled), meaning the caller should check the database.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> cache._revoked_jtis.add("revoked-jti")
        >>> asyncio.run(cache.is_token_revoked("revoked-jti"))
        True
    """
    if not self._enabled:
        return None

    if jti in self._revoked_jtis:
        return True

    cached = self._revocation_cache.get(jti)
    if cached and not cached.is_expired():
        return cached.value

    redis = await self._get_redis_client()
    if not redis:
        return None

    try:
        # `or` short-circuits: the per-token key is consulted only when the
        # shared revoked-token set does not contain the JTI.
        found = (
            await redis.sismember("mcpgw:auth:revoked_tokens", jti)
            or await redis.exists(self._get_redis_key("revoke", jti))
        )
        if found:
            with self._lock:
                # Remember locally for faster future lookups and write through to L1.
                self._revoked_jtis.add(jti)
                self._revocation_cache[jti] = CacheEntry(
                    value=True,
                    expiry=time.time() + self._revocation_ttl,
                )
            return True
    except Exception as e:
        logger.warning(f"AuthCache Redis is_token_revoked failed: {e}")

    return None
async def sync_revoked_tokens(self) -> None:
    """Seed the revocation caches with every revoked JTI stored in the database.

    Intended to be called once during application startup. The database read
    runs in a worker thread via ``asyncio.to_thread``. The resulting JTIs are
    merged into the local revoked-JTI set and, when Redis is reachable and at
    least one JTI was found, added to the shared ``mcpgw:auth:revoked_tokens``
    set. Any ``Exception`` is logged as a warning rather than raised.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.sync_revoked_tokens())
    """
    if not self._enabled:
        return

    try:
        # Third-Party
        from sqlalchemy import select  # pylint: disable=import-outside-toplevel

        # First-Party
        from mcpgateway.db import fresh_db_session, TokenRevocation  # pylint: disable=import-outside-toplevel

        def _fetch_revoked_jtis() -> Set[str]:
            """Read all revoked JTIs from the database.

            Returns:
                Set of revoked JTI strings.
            """
            with fresh_db_session() as session:
                rows = session.execute(select(TokenRevocation.jti))
                return set(row[0] for row in rows)

        revoked = await asyncio.to_thread(_fetch_revoked_jtis)

        with self._lock:
            self._revoked_jtis.update(revoked)

        # Mirror into Redis so other workers see the same revocations.
        redis = await self._get_redis_client()
        if redis and revoked:
            try:
                await redis.sadd("mcpgw:auth:revoked_tokens", *revoked)
            except Exception as e:
                logger.warning(f"AuthCache Redis sync_revoked_tokens failed: {e}")

        logger.info(f"AuthCache: Synced {len(revoked)} revoked tokens to cache")

    except Exception as e:
        logger.warning(f"AuthCache sync_revoked_tokens failed: {e}")
def invalidate_all(self) -> None:
    """Clear this process's in-memory (L1) auth caches.

    Empties the user, team-membership, revocation, context, role and
    teams-list dictionaries while holding the cache lock. Useful in tests or
    after major configuration changes.

    Scope note: this is a local, synchronous operation. It does NOT delete
    anything from Redis (L2) and does not notify other workers, so data still
    held in Redis may be served again on later lookups (for example,
    ``is_token_revoked`` writes Redis hits back into L1). The set of confirmed
    revoked JTIs is deliberately kept, so ``is_token_revoked`` continues to
    report those tokens as revoked.

    Examples:
        >>> cache = AuthCache()
        >>> cache.invalidate_all()
    """
    with self._lock:
        for cache in (
            self._user_cache,
            self._team_cache,
            self._revocation_cache,
            self._context_cache,
            self._role_cache,
            self._teams_list_cache,
        ):
            cache.clear()
        # _revoked_jtis is intentionally NOT cleared: those are confirmed revocations.

    logger.info("AuthCache: All caches invalidated")
def stats(self) -> Dict[str, Any]:
    """Summarise cache counters, cache sizes and TTL configuration.

    Returns:
        Dictionary with in-memory and Redis hit/miss counts and hit rates
        (0.0 when nothing has been counted yet), Redis availability, the
        sizes of the in-memory caches, and the configured TTLs.

    Examples:
        >>> cache = AuthCache()
        >>> stats = cache.stats()
        >>> "hit_count" in stats
        True
    """

    def _rate(hits, misses):
        """Compute a hit ratio, guarding against division by zero.

        Args:
            hits: Number of hits.
            misses: Number of misses.

        Returns:
            hits / (hits + misses), or 0.0 when there were no lookups.
        """
        lookups = hits + misses
        return hits / lookups if lookups > 0 else 0.0

    hits, misses = self._hit_count, self._miss_count
    redis_hits, redis_misses = self._redis_hit_count, self._redis_miss_count

    return {
        "enabled": self._enabled,
        "hit_count": hits,
        "miss_count": misses,
        "hit_rate": _rate(hits, misses),
        "redis_hit_count": redis_hits,
        "redis_miss_count": redis_misses,
        "redis_hit_rate": _rate(redis_hits, redis_misses),
        "redis_available": self._redis_available,
        "revoked_tokens_cached": len(self._revoked_jtis),
        "context_cache_size": len(self._context_cache),
        "role_cache_size": len(self._role_cache),
        "teams_list_cache_size": len(self._teams_list_cache),
        "team_membership_cache_size": len(self._team_cache),
        "user_ttl": self._user_ttl,
        "revocation_ttl": self._revocation_ttl,
        "team_ttl": self._team_ttl,
        "role_ttl": self._role_ttl,
        "teams_list_enabled": self._teams_list_enabled,
        "teams_list_ttl": self._teams_list_ttl,
    }
def reset_stats(self) -> None:
    """Zero the in-memory and Redis hit/miss counters.

    Cached entries and configuration are left untouched.

    Examples:
        >>> cache = AuthCache()
        >>> cache._hit_count = 100
        >>> cache.reset_stats()
        >>> cache._hit_count
        0
    """
    self._hit_count = self._miss_count = 0
    self._redis_hit_count = self._redis_miss_count = 0
# Module-level singleton slot. It starts as None and is filled lazily by
# get_auth_cache() on its first call.
_auth_cache: Optional[AuthCache] = None
def get_auth_cache() -> AuthCache:
    """Return the module-wide AuthCache, constructing it on first use.

    Returns:
        AuthCache: The shared auth cache instance

    Examples:
        >>> cache = get_auth_cache()
        >>> isinstance(cache, AuthCache)
        True
    """
    global _auth_cache  # pylint: disable=global-statement
    if _auth_cache is not None:
        return _auth_cache
    _auth_cache = AuthCache()
    return _auth_cache
# Convenience alias for direct import (`from ... import auth_cache`).
# Evaluated at import time, so importing this module constructs the
# singleton AuthCache immediately.
auth_cache = get_auth_cache()