Coverage for mcpgateway / cache / auth_cache.py: 100%
455 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/cache/auth_cache.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
6Authentication Data Cache.
8This module implements a thread-safe two-tier cache for authentication data.
9L1 (in-memory) is checked first for lowest latency, with L2 (Redis) as a
10shared distributed cache. It caches user data, team memberships, and token
11revocation status to reduce database queries during authentication.
13Performance Impact:
14 - Before: 3-4 DB queries per authenticated request
15 - After: 0-1 DB queries (cache hit) per TTL period
17Security Considerations:
18 - Short TTLs for revocation data (30s default) to limit exposure window
19 - Cache invalidation on token revocation, user update, team change
20 - JWT payloads are NOT cached (security risk)
21 - Graceful fallback to DB on cache failure
23Examples:
24 >>> from mcpgateway.cache.auth_cache import auth_cache
25 >>> # Cache is used automatically by get_current_user()
26 >>> # Manual invalidation after user update:
27 >>> import asyncio
28 >>> # asyncio.run(auth_cache.invalidate_user("user@example.com"))
29"""
31# Standard
32import asyncio
33from dataclasses import dataclass
34import logging
35import threading
36import time
37from typing import Any, Dict, List, Optional, Set
logger = logging.getLogger(name=__name__)

# Marker written to Redis in place of a role when a user is known NOT to be a
# member of a team. Storing an explicit value lets readers tell a cached
# negative answer apart from a plain cache miss (key absent).
_NOT_A_MEMBER_SENTINEL = "__NOT_A_MEMBER__"
@dataclass
class CachedAuthContext:
    """Authentication state gathered by one batched database lookup.

    Groups the user record, the personal team and the token's revocation flag
    so they can be cached and served together.

    Attributes:
        user: Mapping of user fields (email, is_admin, is_active, ...) or None.
        personal_team_id: ID of the user's personal team, or None.
        is_token_revoked: True when the JWT has been revoked.

    Examples:
        >>> ctx = CachedAuthContext(
        ...     user={"email": "test@example.com", "is_admin": False},
        ...     personal_team_id="team-123",
        ...     is_token_revoked=False
        ... )
        >>> ctx.is_token_revoked
        False
        >>> CachedAuthContext().user is None
        True
    """

    # Field order and defaults form the positional constructor signature.
    user: Optional[Dict[str, Any]] = None
    personal_team_id: Optional[str] = None
    is_token_revoked: bool = False
@dataclass
class CacheEntry:
    """A cached value plus the absolute time at which it stops being valid.

    Attributes:
        value: The cached payload (any type).
        expiry: Expiry moment as a ``time.time()`` timestamp in seconds.

    Examples:
        >>> import time
        >>> entry = CacheEntry(value={"key": "value"}, expiry=time.time() + 60)
        >>> entry.is_expired()
        False
    """

    value: Any
    expiry: float

    def is_expired(self) -> bool:
        """Tell whether the entry's lifetime is over.

        Returns:
            bool: True once the current time has reached ``expiry``, else False.
        """
        now = time.time()
        return self.expiry <= now
96class AuthCache:
97 """Thread-safe two-tier authentication cache (L1 in-memory + L2 Redis).
99 This cache reduces database load during authentication by caching:
100 - User data (email, is_admin, is_active, etc.)
101 - Personal team ID for the user
102 - Token revocation status
104 Cache lookup checks L1 (in-memory) first for lowest latency, then L2
105 (Redis) for distributed consistency. Redis hits are written through
106 to L1 for subsequent requests.
108 Attributes:
109 user_ttl: TTL in seconds for user data cache (default: 60)
110 revocation_ttl: TTL in seconds for revocation cache (default: 30)
111 team_ttl: TTL in seconds for team cache (default: 60)
113 Examples:
114 >>> cache = AuthCache(user_ttl=60, revocation_ttl=30)
115 >>> cache.stats()["hit_count"]
116 0
117 """
119 _NOT_CACHED = object()
def __init__(
    self,
    user_ttl: Optional[int] = None,
    revocation_ttl: Optional[int] = None,
    team_ttl: Optional[int] = None,
    role_ttl: Optional[int] = None,
    enabled: Optional[bool] = None,
):
    """Configure TTLs and feature flags, and create empty in-process stores.

    Explicit arguments win. A falsy TTL argument (None or 0) falls back to
    the matching attribute on ``mcpgateway.config.settings``. When that
    attribute is missing, or settings cannot be imported, the hard-coded
    default applies.

    Args:
        user_ttl: Seconds to keep user data (default: from settings or 60).
        revocation_ttl: Seconds to keep revocation results (default: from settings or 30).
        team_ttl: Seconds to keep team data (default: from settings or 60).
        role_ttl: Seconds to keep team roles (default: from settings or 60).
        enabled: Master switch for caching (default: from settings or True).

    Examples:
        >>> cache = AuthCache(user_ttl=120, revocation_ttl=30)
        >>> cache._user_ttl
        120
    """
    # Settings are imported lazily to avoid a circular import. If the import
    # fails, ``settings`` stays None and every getattr() below yields its default.
    try:
        # First-Party
        from mcpgateway.config import settings  # pylint: disable=import-outside-toplevel
    except ImportError:
        settings = None

    self._user_ttl = user_ttl or getattr(settings, "auth_cache_user_ttl", 60)
    self._revocation_ttl = revocation_ttl or getattr(settings, "auth_cache_revocation_ttl", 30)
    self._team_ttl = team_ttl or getattr(settings, "auth_cache_team_ttl", 60)
    self._role_ttl = role_ttl or getattr(settings, "auth_cache_role_ttl", 60)
    self._teams_list_ttl = getattr(settings, "auth_cache_teams_ttl", 60)
    self._teams_list_enabled = getattr(settings, "auth_cache_teams_enabled", True)
    self._enabled = enabled if enabled is not None else getattr(settings, "auth_cache_enabled", True)
    self._cache_prefix = getattr(settings, "cache_prefix", "mcpgw:")

    # L1 in-process stores (also the fallback when Redis is unavailable).
    self._user_cache: Dict[str, CacheEntry] = {}
    self._team_cache: Dict[str, CacheEntry] = {}
    self._revocation_cache: Dict[str, CacheEntry] = {}
    self._context_cache: Dict[str, CacheEntry] = {}
    self._role_cache: Dict[str, CacheEntry] = {}
    self._teams_list_cache: Dict[str, CacheEntry] = {}

    # JTIs known to be revoked on this worker (checked without any I/O).
    self._revoked_jtis: Set[str] = set()

    # Guards mutation of the stores above.
    self._lock = threading.Lock()

    # Redis probe state: nothing is known until the first client lookup.
    self._redis_checked = False
    self._redis_available = False

    # Hit/miss counters reported by stats().
    self._hit_count = 0
    self._miss_count = 0
    self._redis_hit_count = 0
    self._redis_miss_count = 0

    logger.info(
        f"AuthCache initialized: enabled={self._enabled}, "
        f"user_ttl={self._user_ttl}s, revocation_ttl={self._revocation_ttl}s, "
        f"team_ttl={self._team_ttl}s, role_ttl={self._role_ttl}s, "
        f"teams_list_enabled={self._teams_list_enabled}, teams_list_ttl={self._teams_list_ttl}s"
    )
def _get_redis_key(self, key_type: str, identifier: str) -> str:
    """Build the namespaced Redis key for an auth-cache entry.

    Args:
        key_type: Entry category (user, team, revoke, ctx, role, ...).
        identifier: Entry-specific identifier (email, jti, composite key...).

    Returns:
        ``<prefix>auth:<key_type>:<identifier>``.

    Examples:
        >>> cache = AuthCache()
        >>> cache._get_redis_key("user", "test@example.com")
        'mcpgw:auth:user:test@example.com'
    """
    namespace = f"{self._cache_prefix}auth"
    return f"{namespace}:{key_type}:{identifier}"
async def _get_redis_client(self):
    """Fetch the shared Redis client, or None when it cannot be obtained.

    The first outcome, success or failure, is recorded in
    ``_redis_checked`` / ``_redis_available`` and logged once at debug level.

    Returns:
        Whatever ``get_redis_client()`` returned, or None if obtaining it raised.
    """
    try:
        # First-Party
        from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

        client = await get_redis_client()
        if client and not self._redis_checked:
            self._redis_checked = self._redis_available = True
            logger.debug("AuthCache: Redis client available")
        return client
    except Exception as e:
        if self._redis_checked:
            return None
        self._redis_checked = True
        self._redis_available = False
        logger.debug(f"AuthCache: Redis unavailable, using in-memory cache: {e}")
        return None
async def get_auth_context(
    self,
    email: str,
    jti: Optional[str] = None,
) -> Optional[CachedAuthContext]:
    """Return the cached authentication context for a user/token pair.

    Lookup order:

    1. Local set of revoked JTIs (no I/O).
    2. Redis revocation marker for the JTI. This makes a revocation done on
       another worker override any stale in-process entry.
    3. In-process (L1) context cache.
    4. Redis (L2) context cache; a hit is copied into L1.

    Args:
        email: User email address.
        jti: JWT ID used for the revocation checks (optional).

    Returns:
        A CachedAuthContext on a hit (``is_token_revoked=True`` when the
        token is known to be revoked), or None on a miss or when disabled.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> result = asyncio.run(cache.get_auth_context("test@example.com"))
        >>> result is None  # Cache miss on fresh cache
        True
    """
    if not self._enabled:
        return None

    if jti:
        if jti in self._revoked_jtis:
            self._hit_count += 1
            return CachedAuthContext(is_token_revoked=True)

        # A revoking worker writes a Redis marker. Consult it BEFORE L1 so a
        # stale in-process entry cannot mask the revocation.
        redis = await self._get_redis_client()
        if redis:
            try:
                marker_key = self._get_redis_key("revoke", jti)
                if await redis.exists(marker_key):
                    jti_suffix = f":{jti}"
                    with self._lock:
                        # Remember locally so later requests skip the Redis round trip.
                        self._revoked_jtis.add(jti)
                        stale_keys = [k for k in self._context_cache if k.endswith(jti_suffix)]
                        for stale in stale_keys:
                            self._context_cache.pop(stale, None)
                    self._hit_count += 1
                    return CachedAuthContext(is_token_revoked=True)
            except Exception as exc:
                logger.debug(f"AuthCache: Redis revocation check failed for {jti[:8]}: {exc}")

    cache_key = f"{email}:{jti or 'no-jti'}"

    # L1: in-process, no network I/O.
    local = self._context_cache.get(cache_key)
    if local and not local.is_expired():
        self._hit_count += 1
        return local.value

    # L2: Redis.
    redis = await self._get_redis_client()
    if redis:
        try:
            raw = await redis.get(self._get_redis_key("ctx", cache_key))
            if not raw:
                self._redis_miss_count += 1
            else:
                # Third-Party
                import orjson  # pylint: disable=import-outside-toplevel

                payload = orjson.loads(raw)
                context = CachedAuthContext(
                    user=payload.get("user"),
                    personal_team_id=payload.get("personal_team_id"),
                    is_token_revoked=payload.get("is_token_revoked", False),
                )
                self._hit_count += 1
                self._redis_hit_count += 1

                # Write-through into L1 using the shortest of the component TTLs.
                lifetime = min(self._user_ttl, self._revocation_ttl, self._team_ttl)
                with self._lock:
                    self._context_cache[cache_key] = CacheEntry(
                        value=context,
                        expiry=time.time() + lifetime,
                    )
                return context
        except Exception as e:
            logger.warning(f"AuthCache Redis get failed: {e}")

    self._miss_count += 1
    return None
async def set_auth_context(
    self,
    email: str,
    jti: Optional[str],
    context: CachedAuthContext,
) -> None:
    """Cache an authentication context in Redis (when reachable) and in-process.

    Both copies live for the shortest of the user, revocation and team TTLs,
    because the context combines all three kinds of data.

    Args:
        email: User email address.
        jti: JWT ID (optional); a missing JTI is stored under ``no-jti``.
        context: Authentication context to cache.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> ctx = CachedAuthContext(
        ...     user={"email": "test@example.com"},
        ...     personal_team_id="team-1",
        ...     is_token_revoked=False
        ... )
        >>> asyncio.run(cache.set_auth_context("test@example.com", "jti-123", ctx))
    """
    if not self._enabled:
        return

    cache_key = f"{email}:{jti or 'no-jti'}"
    lifetime = min(self._user_ttl, self._revocation_ttl, self._team_ttl)

    # JSON-friendly view of the context for Redis.
    serializable = dict(
        user=context.user,
        personal_team_id=context.personal_team_id,
        is_token_revoked=context.is_token_revoked,
    )

    redis = await self._get_redis_client()
    if redis:
        try:
            # Third-Party
            import orjson  # pylint: disable=import-outside-toplevel

            await redis.setex(self._get_redis_key("ctx", cache_key), lifetime, orjson.dumps(serializable))
        except Exception as e:
            logger.warning(f"AuthCache Redis set failed: {e}")

    with self._lock:
        self._context_cache[cache_key] = CacheEntry(value=context, expiry=time.time() + lifetime)
async def invalidate_user(self, email: str) -> None:
    """Drop cached data for a user after their record changes.

    Removes, in-process:
    - the user's auth contexts,
    - the user's user-data entry,
    - the user's team-membership results.

    In Redis it removes:
    - the ``user`` and ``team`` keys,
    - every ``ctx`` and ``membership`` key for the user.

    It then publishes a ``user:<email>`` message for other workers.

    Args:
        email: User email to invalidate.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.invalidate_user("test@example.com"))
    """
    logger.debug(f"AuthCache: Invalidating user cache for {email}")

    user_prefix = f"{email}:"
    with self._lock:
        for stale in [k for k in self._context_cache if k.startswith(user_prefix)]:
            self._context_cache.pop(stale, None)

        self._user_cache.pop(email, None)

        # Membership validation results are keyed "email:<team ids>".
        for stale in [k for k in self._team_cache if k.startswith(user_prefix)]:
            self._team_cache.pop(stale, None)

    redis = await self._get_redis_client()
    if not redis:
        return
    try:
        await redis.delete(
            self._get_redis_key("user", email),
            self._get_redis_key("team", email),
        )
        # Context and membership keys embed extra suffixes; match them by pattern.
        for kind in ("ctx", "membership"):
            async for key in redis.scan_iter(match=self._get_redis_key(kind, f"{email}:*")):
                await redis.delete(key)

        await redis.publish("mcpgw:auth:invalidate", f"user:{email}")
    except Exception as e:
        logger.warning(f"AuthCache Redis invalidate_user failed: {e}")
async def invalidate_revocation(self, jti: str) -> None:
    """Record that a token was revoked and purge contexts cached for it.

    The JTI is added to the local revoked set. Its revocation entry and any
    context cached under it are dropped. In Redis:
    - a 24-hour ``revoke`` marker is written,
    - the JTI is added to the ``mcpgw:auth:revoked_tokens`` set,
    - matching ``ctx`` keys are deleted,
    - a ``revoke:<jti>`` message is published.

    Args:
        jti: JWT ID of the revoked token.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.invalidate_revocation("jti-123"))
    """
    logger.debug(f"AuthCache: Invalidating revocation cache for jti={jti[:8]}...")

    jti_suffix = f":{jti}"
    with self._lock:
        self._revoked_jtis.add(jti)
        self._revocation_cache.pop(jti, None)
        for stale in [k for k in self._context_cache if k.endswith(jti_suffix)]:
            self._context_cache.pop(stale, None)

    redis = await self._get_redis_client()
    if not redis:
        return
    marker_ttl_seconds = 86400  # revocation markers expire after 24 hours
    try:
        await redis.setex(self._get_redis_key("revoke", jti), marker_ttl_seconds, "1")
        await redis.sadd("mcpgw:auth:revoked_tokens", jti)

        async for key in redis.scan_iter(match=self._get_redis_key("ctx", f"*:{jti}")):
            await redis.delete(key)

        await redis.publish("mcpgw:auth:invalidate", f"revoke:{jti}")
    except Exception as e:
        logger.warning(f"AuthCache Redis invalidate_revocation failed: {e}")
async def invalidate_team(self, email: str) -> None:
    """Invalidate team-related cache entries for a user.

    Call this when the user's team membership changes. Clears, on this worker
    and in Redis:

    - team-membership validation results for the user (keyed
      ``<email>:<sorted team ids>``; Redis ``membership`` keys),
    - any entry keyed by the bare email (Redis ``team`` key),
    - cached auth contexts for the user (they embed team data).

    Finally publishes a ``team:<email>`` message for other workers.

    Args:
        email: User email whose team changed.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.invalidate_team("test@example.com"))
    """
    logger.debug(f"AuthCache: Invalidating team cache for {email}")

    user_prefix = f"{email}:"
    with self._lock:
        self._team_cache.pop(email, None)
        # Membership results are stored under "email:<team ids>", so popping
        # the bare email alone left them in place after a team change.
        for stale in [k for k in self._team_cache if k.startswith(user_prefix)]:
            self._team_cache.pop(stale, None)
        for stale in [k for k in self._context_cache if k.startswith(user_prefix)]:
            self._context_cache.pop(stale, None)

    redis = await self._get_redis_client()
    if redis:
        try:
            await redis.delete(self._get_redis_key("team", email))
            # Contexts and membership results carry suffixes; match by pattern.
            for kind in ("ctx", "membership"):
                pattern = self._get_redis_key(kind, f"{email}:*")
                async for key in redis.scan_iter(match=pattern):
                    await redis.delete(key)

            await redis.publish("mcpgw:auth:invalidate", f"team:{email}")
        except Exception as e:
            logger.warning(f"AuthCache Redis invalidate_team failed: {e}")
async def get_user_role(self, email: str, team_id: str) -> Optional[str]:
    """Look up a user's cached role within a team.

    The in-process (L1) cache is checked first, then Redis (L2). A Redis hit
    is copied into L1 for ``self._role_ttl`` seconds.

    Args:
        email: User email.
        team_id: Team ID.

    Returns:
        None on a cache miss (caller should check the database), ``""`` when
        the user is cached as not being a member, otherwise the role string.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> result = asyncio.run(cache.get_user_role("test@example.com", "team-123"))
        >>> result is None  # Cache miss
        True
    """
    if not self._enabled:
        return None

    key = f"{email}:{team_id}"

    local = self._role_cache.get(key)
    if local and not local.is_expired():
        self._hit_count += 1
        # L1 keeps "not a member" as None; report "" so it differs from a miss.
        return local.value if local.value is not None else ""

    redis = await self._get_redis_client()
    if redis:
        try:
            raw = await redis.get(self._get_redis_key("role", key))
            if raw is None:
                self._redis_miss_count += 1
            else:
                self._hit_count += 1
                self._redis_hit_count += 1
                text = raw.decode() if isinstance(raw, bytes) else raw
                role_value = text if text != _NOT_A_MEMBER_SENTINEL else ""

                # Write-through using the L1 convention (None == not a member).
                with self._lock:
                    self._role_cache[key] = CacheEntry(
                        value=role_value if role_value != "" else None,
                        expiry=time.time() + self._role_ttl,
                    )
                return role_value
        except Exception as e:
            logger.warning(f"AuthCache Redis get_user_role failed: {e}")

    self._miss_count += 1
    return None
async def set_user_role(self, email: str, team_id: str, role: Optional[str]) -> None:
    """Cache a user's role in a team (or the fact they are not a member).

    Args:
        email: User email.
        team_id: Team ID.
        role: The user's role, or None when the user is not a member.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.set_user_role("test@example.com", "team-123", "admin"))
    """
    if not self._enabled:
        return

    key = f"{email}:{team_id}"

    redis = await self._get_redis_client()
    if redis:
        # Redis cannot hold None, so "not a member" is written as a sentinel.
        stored = _NOT_A_MEMBER_SENTINEL if role is None else role
        try:
            await redis.setex(self._get_redis_key("role", key), self._role_ttl, stored)
        except Exception as e:
            logger.warning(f"AuthCache Redis set_user_role failed: {e}")

    # L1 keeps the raw value (None == not a member).
    with self._lock:
        self._role_cache[key] = CacheEntry(value=role, expiry=time.time() + self._role_ttl)
async def invalidate_user_role(self, email: str, team_id: str) -> None:
    """Forget the cached role of one user in one team.

    Call this after a user's role in a team changes. Removes the in-process
    entry and the Redis key, then publishes ``role:<email>:<team_id>``.

    Args:
        email: User email.
        team_id: Team ID.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.invalidate_user_role("test@example.com", "team-123"))
    """
    logger.debug(f"AuthCache: Invalidating role cache for {email} in team {team_id}")

    key = f"{email}:{team_id}"
    with self._lock:
        self._role_cache.pop(key, None)

    redis = await self._get_redis_client()
    if not redis:
        return
    try:
        await redis.delete(self._get_redis_key("role", key))
        await redis.publish("mcpgw:auth:invalidate", f"role:{email}:{team_id}")
    except Exception as e:
        logger.warning(f"AuthCache Redis invalidate_user_role failed: {e}")
async def invalidate_team_roles(self, team_id: str) -> None:
    """Forget every cached role for a team.

    Call this after sweeping changes to a team (e.g. deletion). Removes all
    in-process role entries ending in ``:<team_id>``, deletes matching Redis
    keys, and publishes ``team_roles:<team_id>``.

    Args:
        team_id: Team ID.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.invalidate_team_roles("team-123"))
    """
    logger.debug(f"AuthCache: Invalidating all role caches for team {team_id}")

    team_suffix = f":{team_id}"
    with self._lock:
        for stale in [k for k in self._role_cache if k.endswith(team_suffix)]:
            self._role_cache.pop(stale, None)

    redis = await self._get_redis_client()
    if not redis:
        return
    try:
        async for key in redis.scan_iter(match=self._get_redis_key("role", f"*:{team_id}")):
            await redis.delete(key)
        await redis.publish("mcpgw:auth:invalidate", f"team_roles:{team_id}")
    except Exception as e:
        logger.warning(f"AuthCache Redis invalidate_team_roles failed: {e}")
async def get_user_teams(self, cache_key: str) -> Optional[List[str]]:
    """Get cached team IDs for a user.

    The cache_key should be in the format "email:include_personal" to
    distinguish between calls with different include_personal flags. L1
    (in-process) is checked first, then Redis; a Redis hit is written
    through to L1.

    Args:
        cache_key: Cache key in format "email:include_personal"

    Returns:
        - None: Cache miss, or caching disabled (caller should query DB)
        - Empty list: User has no teams (cached result)
        - List of team IDs: a fresh copy of the cached IDs. A copy is
          returned so callers that mutate the result cannot corrupt the
          cached entry.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> result = asyncio.run(cache.get_user_teams("test@example.com:True"))
        >>> result is None  # Cache miss
        True
    """
    if not self._enabled or not self._teams_list_enabled:
        return None

    # Check L1 in-memory cache first (no network I/O)
    entry = self._teams_list_cache.get(cache_key)
    if entry and not entry.is_expired():
        self._hit_count += 1
        cached = entry.value
        return None if cached is None else list(cached)

    # Check L2 Redis cache
    redis = await self._get_redis_client()
    if redis:
        try:
            redis_key = self._get_redis_key("teams", cache_key)
            data = await redis.get(redis_key)
            if data is not None:
                self._hit_count += 1
                self._redis_hit_count += 1
                # Third-Party
                import orjson  # pylint: disable=import-outside-toplevel

                team_ids = orjson.loads(data)

                # Write-through: populate L1 from Redis hit
                with self._lock:
                    self._teams_list_cache[cache_key] = CacheEntry(
                        value=team_ids,
                        expiry=time.time() + self._teams_list_ttl,
                    )

                # Return a copy so the L1 entry cannot be mutated via the result.
                return None if team_ids is None else list(team_ids)
            self._redis_miss_count += 1
        except Exception as e:
            logger.warning(f"AuthCache Redis get_user_teams failed: {e}")

    self._miss_count += 1
    return None
async def set_user_teams(self, cache_key: str, team_ids: List[str]) -> None:
    """Store team IDs for a user in cache.

    The IDs are copied into a new list before caching. Without the copy,
    later in-place edits to the caller's list would silently change the
    cached value.

    Args:
        cache_key: Cache key in format "email:include_personal"
        team_ids: Team IDs the user belongs to (any iterable of IDs is accepted)

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.set_user_teams("test@example.com:True", ["team-1", "team-2"]))
    """
    if not self._enabled or not self._teams_list_enabled:
        return

    # Snapshot the caller's data; keep None as-is to preserve best-effort behavior.
    snapshot = None if team_ids is None else list(team_ids)

    # Store in Redis
    redis = await self._get_redis_client()
    if redis:
        try:
            # Third-Party
            import orjson  # pylint: disable=import-outside-toplevel

            redis_key = self._get_redis_key("teams", cache_key)
            await redis.setex(redis_key, self._teams_list_ttl, orjson.dumps(snapshot))
        except Exception as e:
            logger.warning(f"AuthCache Redis set_user_teams failed: {e}")

    # Store in in-memory cache
    with self._lock:
        self._teams_list_cache[cache_key] = CacheEntry(
            value=snapshot,
            expiry=time.time() + self._teams_list_ttl,
        )
async def invalidate_user_teams(self, email: str) -> None:
    """Forget the cached team-ID lists for a user.

    Call this after membership changes (member added/removed, team deleted,
    join request approved). Both the ``include_personal=True`` and
    ``include_personal=False`` variants are dropped. A ``teams:<email>``
    message is then published for other workers.

    Args:
        email: User email whose teams cache should be invalidated.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.invalidate_user_teams("test@example.com"))
    """
    logger.debug(f"AuthCache: Invalidating teams list cache for {email}")

    user_prefix = f"{email}:"
    with self._lock:
        for stale in [k for k in self._teams_list_cache if k.startswith(user_prefix)]:
            self._teams_list_cache.pop(stale, None)

    redis = await self._get_redis_client()
    if not redis:
        return
    try:
        variant_keys = [self._get_redis_key("teams", f"{email}:{flag}") for flag in (True, False)]
        await redis.delete(*variant_keys)
        await redis.publish("mcpgw:auth:invalidate", f"teams:{email}")
    except Exception as e:
        logger.warning(f"AuthCache Redis invalidate_user_teams failed: {e}")
824 # =========================================================================
825 # Team Membership Validation Cache
826 # =========================================================================
827 # Used by TokenScopingMiddleware to cache email_team_members lookups.
828 # This prevents repeated DB queries for the same user+teams combination.
def get_team_membership_valid_sync(self, user_email: str, team_ids: List[str]) -> Optional[bool]:
    """Synchronously read a cached team-membership validation result.

    Used by the token-scoping middleware; only the in-process cache is
    consulted (no Redis access).

    Args:
        user_email: User email.
        team_ids: Team IDs whose membership was validated (order irrelevant).

    Returns:
        None on a miss, when disabled, or when ``team_ids`` is empty.
        Otherwise the cached bool: True if the user belongs to all the
        teams, False if not.

    Examples:
        >>> cache = AuthCache()
        >>> result = cache.get_team_membership_valid_sync("test@example.com", ["team-1"])
        >>> result is None  # Cache miss
        True
    """
    if not (self._enabled and team_ids):
        return None

    # Sorting makes the key independent of the order the IDs were given in.
    key = f"{user_email}:{':'.join(sorted(team_ids))}"

    cached = self._team_cache.get(key)
    if not cached or cached.is_expired():
        self._miss_count += 1
        return None

    self._hit_count += 1
    return cached.value
def set_team_membership_valid_sync(self, user_email: str, team_ids: List[str], valid: bool) -> None:
    """Synchronously cache a team-membership validation result in-process.

    Does nothing when caching is disabled or ``team_ids`` is empty.

    Args:
        user_email: User email.
        team_ids: Team IDs that were validated (order irrelevant).
        valid: Whether the user belongs to all of those teams.

    Examples:
        >>> cache = AuthCache()
        >>> cache.set_team_membership_valid_sync("test@example.com", ["team-1"], True)
    """
    if not (self._enabled and team_ids):
        return

    key = f"{user_email}:{':'.join(sorted(team_ids))}"
    expires_at = time.time() + self._team_ttl
    with self._lock:
        self._team_cache[key] = CacheEntry(value=valid, expiry=expires_at)
async def get_team_membership_valid(self, user_email: str, team_ids: List[str]) -> Optional[bool]:
    """Read a cached team-membership validation result (async, L1 then Redis).

    A Redis hit is copied into the in-process cache for ``self._team_ttl``
    seconds.

    Args:
        user_email: User email.
        team_ids: Team IDs whose membership was validated (order irrelevant).

    Returns:
        None on a miss, when disabled, or when ``team_ids`` is empty.
        Otherwise the cached bool: True if the user belongs to all the
        teams, False if not.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> result = asyncio.run(cache.get_team_membership_valid("test@example.com", ["team-1"]))
        >>> result is None  # Cache miss
        True
    """
    if not (self._enabled and team_ids):
        return None

    key = f"{user_email}:{':'.join(sorted(team_ids))}"

    local = self._team_cache.get(key)
    if local and not local.is_expired():
        self._hit_count += 1
        return local.value

    redis = await self._get_redis_client()
    if redis:
        try:
            raw = await redis.get(self._get_redis_key("membership", key))
            if raw is None:
                self._redis_miss_count += 1
            else:
                self._hit_count += 1
                self._redis_hit_count += 1
                # Redis holds "1" for a valid membership and "0" otherwise.
                flag = raw.decode() if isinstance(raw, bytes) else raw
                is_valid = flag == "1"

                with self._lock:
                    self._team_cache[key] = CacheEntry(value=is_valid, expiry=time.time() + self._team_ttl)
                return is_valid
        except Exception as e:
            logger.warning(f"AuthCache Redis get_team_membership_valid failed: {e}")

    self._miss_count += 1
    return None
async def set_team_membership_valid(self, user_email: str, team_ids: List[str], valid: bool) -> None:
    """Cache a team-membership validation result in Redis and in-process.

    Does nothing when caching is disabled or ``team_ids`` is empty.

    Args:
        user_email: User email.
        team_ids: Team IDs that were validated (order irrelevant).
        valid: Whether the user belongs to all of those teams.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.set_team_membership_valid("test@example.com", ["team-1"], True))
    """
    if not (self._enabled and team_ids):
        return

    key = f"{user_email}:{':'.join(sorted(team_ids))}"

    redis = await self._get_redis_client()
    if redis:
        try:
            # Encoded as "1" (valid) / "0" (not valid) for Redis.
            await redis.setex(self._get_redis_key("membership", key), self._team_ttl, "1" if valid else "0")
        except Exception as e:
            logger.warning(f"AuthCache Redis set_team_membership_valid failed: {e}")

    with self._lock:
        self._team_cache[key] = CacheEntry(value=valid, expiry=time.time() + self._team_ttl)
async def invalidate_team_membership(self, user_email: str) -> None:
    """Forget every cached membership validation result for a user.

    Call this after a user is added to or removed from a team, changes role,
    or is deactivated. Removes in-process entries keyed ``<email>:...``,
    deletes matching Redis ``membership`` keys, and publishes
    ``membership:<email>``.

    Args:
        user_email: User email whose membership cache should be invalidated.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.invalidate_team_membership("test@example.com"))
    """
    logger.debug(f"AuthCache: Invalidating team membership cache for {user_email}")

    user_prefix = f"{user_email}:"
    with self._lock:
        for stale in [k for k in self._team_cache if k.startswith(user_prefix)]:
            self._team_cache.pop(stale, None)

    redis = await self._get_redis_client()
    if not redis:
        return
    try:
        async for key in redis.scan_iter(match=self._get_redis_key("membership", f"{user_email}:*")):
            await redis.delete(key)
        await redis.publish("mcpgw:auth:invalidate", f"membership:{user_email}")
    except Exception as e:
        logger.warning(f"AuthCache Redis invalidate_team_membership failed: {e}")
async def is_token_revoked(self, jti: str) -> Optional[bool]:
    """Look up a token's revocation status in the caches only (no DB access).

    Lookups run in this order:

    1. The local set of confirmed revoked JTIs.
    2. The in-memory revocation cache.
    3. Redis: first the shared ``mcpgw:auth:revoked_tokens`` set, then the
       per-token ``revoke`` key.

    A Redis hit is copied into both local structures.

    Args:
        jti: JWT ID of the token to check

    Returns:
        True if the token is known to be revoked. If a live in-memory
        revocation entry exists, its cached value. Otherwise None, meaning
        unknown, so the caller should consult the database. None is also
        returned when caching is disabled or a Redis error occurs.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> cache._revoked_jtis.add("revoked-jti")
        >>> asyncio.run(cache.is_token_revoked("revoked-jti"))
        True
    """
    if not self._enabled:
        return None

    # Cheapest check first: confirmed revocations held locally.
    if jti in self._revoked_jtis:
        return True

    # L1 revocation cache (ignored once expired).
    cached = self._revocation_cache.get(jti)
    if cached and not cached.is_expired():
        return cached.value

    # L2 Redis.
    client = await self._get_redis_client()
    if not client:
        return None

    try:
        revoked = await client.sismember("mcpgw:auth:revoked_tokens", jti)
        if not revoked:
            revoked = await client.exists(self._get_redis_key("revoke", jti))
        if revoked:
            with self._lock:
                # Remember locally and write through to the L1 revocation cache.
                self._revoked_jtis.add(jti)
                self._revocation_cache[jti] = CacheEntry(
                    value=True,
                    expiry=time.time() + self._revocation_ttl,
                )
            return True
    except Exception as e:
        logger.warning(f"AuthCache Redis is_token_revoked failed: {e}")

    return None
async def sync_revoked_tokens(self) -> None:
    """Sync revoked tokens from database to cache on startup.

    Should be called during application startup to populate the
    revocation cache. Steps:

    1. Load every ``TokenRevocation.jti`` from the database in a worker thread.
    2. Merge the JTIs into the local revoked-JTI set.
    3. When Redis is available, mirror them into the Redis set
       ``mcpgw:auth:revoked_tokens`` in bounded chunks.

    Errors (``Exception`` subclasses) are logged rather than raised.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.sync_revoked_tokens())
    """
    if not self._enabled:
        return

    try:
        # Third-Party
        from sqlalchemy import select  # pylint: disable=import-outside-toplevel

        # First-Party
        from mcpgateway.db import fresh_db_session, TokenRevocation  # pylint: disable=import-outside-toplevel

        def _load_revoked_jtis() -> Set[str]:
            """Load all revoked JTIs from database.

            Returns:
                Set of revoked JTI strings.
            """
            with fresh_db_session() as db:
                result = db.execute(select(TokenRevocation.jti))
                return {row[0] for row in result}

        # The DB query is synchronous; run it off the event loop.
        jtis = await asyncio.to_thread(_load_revoked_jtis)

        with self._lock:
            self._revoked_jtis.update(jtis)

        # Also sync to Redis
        redis = await self._get_redis_client()
        if redis and jtis:
            try:
                # Send in bounded chunks. A single SADD carrying every revoked JTI grows
                # without limit as the revocation table grows, producing one huge request
                # that Redis must execute as a single command.
                chunk_size = 1000
                pending = list(jtis)
                for start in range(0, len(pending), chunk_size):
                    await redis.sadd("mcpgw:auth:revoked_tokens", *pending[start : start + chunk_size])
            except Exception as e:
                logger.warning(f"AuthCache Redis sync_revoked_tokens failed: {e}")

        logger.info(f"AuthCache: Synced {len(jtis)} revoked tokens to cache")

    except Exception as e:
        logger.warning(f"AuthCache sync_revoked_tokens failed: {e}")
def invalidate_all(self) -> None:
    """Drop every in-memory cached auth entry held by this instance.

    Intended for tests or after significant configuration changes. Under
    the lock, this empties the following caches:

    - user
    - team-membership
    - revocation
    - context
    - role
    - teams-list

    The set of confirmed revoked JTIs is deliberately kept. No Redis
    calls are made.

    Examples:
        >>> cache = AuthCache()
        >>> cache.invalidate_all()
    """
    with self._lock:
        for cache_name in (
            "_user_cache",
            "_team_cache",
            "_revocation_cache",
            "_context_cache",
            "_role_cache",
            "_teams_list_cache",
        ):
            getattr(self, cache_name).clear()
        # _revoked_jtis is left untouched: those are confirmed revocations.

    logger.info("AuthCache: All caches invalidated")
def stats(self) -> Dict[str, Any]:
    """Summarise cache effectiveness and configuration.

    Returns:
        Dictionary with the following entries:

        - In-memory and Redis hit/miss counters and their hit rates
          (0.0 when no lookups were recorded).
        - Whether Redis is available.
        - The current sizes of the in-memory caches.
        - The configured TTLs and the teams-list setting.

    Examples:
        >>> cache = AuthCache()
        >>> stats = cache.stats()
        >>> "hit_count" in stats
        True
    """

    def _ratio(hits: int, lookups: int) -> float:
        # Guard against division by zero before any lookups happened.
        return hits / lookups if lookups > 0 else 0.0

    local_lookups = self._hit_count + self._miss_count
    redis_lookups = self._redis_hit_count + self._redis_miss_count

    summary: Dict[str, Any] = {
        "enabled": self._enabled,
        "hit_count": self._hit_count,
        "miss_count": self._miss_count,
        "hit_rate": _ratio(self._hit_count, local_lookups),
        "redis_hit_count": self._redis_hit_count,
        "redis_miss_count": self._redis_miss_count,
        "redis_hit_rate": _ratio(self._redis_hit_count, redis_lookups),
        "redis_available": self._redis_available,
        "revoked_tokens_cached": len(self._revoked_jtis),
        "context_cache_size": len(self._context_cache),
        "role_cache_size": len(self._role_cache),
        "teams_list_cache_size": len(self._teams_list_cache),
        "team_membership_cache_size": len(self._team_cache),
        "user_ttl": self._user_ttl,
        "revocation_ttl": self._revocation_ttl,
        "team_ttl": self._team_ttl,
        "role_ttl": self._role_ttl,
        "teams_list_enabled": self._teams_list_enabled,
        "teams_list_ttl": self._teams_list_ttl,
    }
    return summary
def reset_stats(self) -> None:
    """Zero the in-memory and Redis hit/miss counters.

    Cached data and configuration are left untouched.

    Examples:
        >>> cache = AuthCache()
        >>> cache._hit_count = 100
        >>> cache.reset_stats()
        >>> cache._hit_count
        0
    """
    for counter_name in ("_hit_count", "_miss_count", "_redis_hit_count", "_redis_miss_count"):
        setattr(self, counter_name, 0)
# Global singleton instance; created lazily by get_auth_cache().
_auth_cache: Optional[AuthCache] = None


def get_auth_cache() -> AuthCache:
    """Return the shared AuthCache, constructing it on first use.

    Returns:
        AuthCache: The singleton auth cache instance

    Examples:
        >>> cache = get_auth_cache()
        >>> isinstance(cache, AuthCache)
        True
    """
    global _auth_cache  # pylint: disable=global-statement
    if _auth_cache is not None:
        return _auth_cache
    _auth_cache = AuthCache()
    return _auth_cache


# Convenience alias for direct import; this creates the singleton at module import time.
auth_cache = get_auth_cache()