Coverage for mcpgateway / cache / auth_cache.py: 100%

455 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/cache/auth_cache.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5 

6Authentication Data Cache. 

7 

8This module implements a thread-safe two-tier cache for authentication data. 

9L1 (in-memory) is checked first for lowest latency, with L2 (Redis) as a 

10shared distributed cache. It caches user data, team memberships, and token 

11revocation status to reduce database queries during authentication. 

12 

13Performance Impact: 

14 - Before: 3-4 DB queries per authenticated request 

15 - After: 0-1 DB queries (cache hit) per TTL period 

16 

17Security Considerations: 

18 - Short TTLs for revocation data (30s default) to limit exposure window 

19 - Cache invalidation on token revocation, user update, team change 

20 - JWT payloads are NOT cached (security risk) 

21 - Graceful fallback to DB on cache failure 

22 

23Examples: 

24 >>> from mcpgateway.cache.auth_cache import auth_cache 

25 >>> # Cache is used automatically by get_current_user() 

26 >>> # Manual invalidation after user update: 

27 >>> import asyncio 

28 >>> # asyncio.run(auth_cache.invalidate_user("user@example.com")) 

29""" 

30 

31# Standard 

32import asyncio 

33from dataclasses import dataclass 

34import logging 

35import threading 

36import time 

37from typing import Any, Dict, List, Optional, Set 

38 

# Module-level logger. AuthCache uses it for its initialization summary,
# Redis availability messages, and warnings when best-effort Redis calls fail.
logger = logging.getLogger(__name__)

# Sentinel value to represent "user is not a member" in Redis cache
# This allows distinguishing between "not a member" (cached) and "cache miss"
# set_user_role writes this string to Redis in place of a None role, and
# get_user_role maps it back to "" when reading it.
_NOT_A_MEMBER_SENTINEL = "__NOT_A_MEMBER__"

44 

45 

@dataclass
class CachedAuthContext:
    """Authentication context assembled from one batched database round trip.

    Groups the three facts the authentication path needs so they can be
    cached and invalidated together: the user record, the user's personal
    team, and whether the presented JWT has been revoked.

    Attributes:
        user: Dict of user fields (email, is_admin, is_active, ...), or None.
        personal_team_id: ID of the user's personal team, or None.
        is_token_revoked: True when the JWT has been revoked.

    Examples:
        >>> ctx = CachedAuthContext(
        ...     user={"email": "test@example.com", "is_admin": False},
        ...     personal_team_id="team-123",
        ...     is_token_revoked=False
        ... )
        >>> ctx.is_token_revoked
        False
        >>> CachedAuthContext().user is None
        True
    """

    # Field order defines the positional constructor signature; keep it stable.
    user: Optional[Dict[str, Any]] = None
    personal_team_id: Optional[str] = None
    is_token_revoked: bool = False

71 

72 

@dataclass
class CacheEntry:
    """A cached value paired with the moment it becomes stale.

    ``expiry`` is an absolute ``time.time()`` timestamp (seconds since the
    epoch), not a duration; ``is_expired`` compares it against the current
    wall-clock time.

    Examples:
        >>> import time
        >>> entry = CacheEntry(value={"key": "value"}, expiry=time.time() + 60)
        >>> entry.is_expired()
        False
    """

    value: Any
    expiry: float

    def is_expired(self) -> bool:
        """Report whether the expiry time has been reached.

        Returns:
            bool: True once the current time is at or past ``expiry``,
            False while the entry is still fresh.
        """
        now = time.time()
        return self.expiry <= now

94 

95 

96class AuthCache: 

97 """Thread-safe two-tier authentication cache (L1 in-memory + L2 Redis). 

98 

99 This cache reduces database load during authentication by caching: 

100 - User data (email, is_admin, is_active, etc.) 

101 - Personal team ID for the user 

102 - Token revocation status 

103 

104 Cache lookup checks L1 (in-memory) first for lowest latency, then L2 

105 (Redis) for distributed consistency. Redis hits are written through 

106 to L1 for subsequent requests. 

107 

108 Attributes: 

109 user_ttl: TTL in seconds for user data cache (default: 60) 

110 revocation_ttl: TTL in seconds for revocation cache (default: 30) 

111 team_ttl: TTL in seconds for team cache (default: 60) 

112 

113 Examples: 

114 >>> cache = AuthCache(user_ttl=60, revocation_ttl=30) 

115 >>> cache.stats()["hit_count"] 

116 0 

117 """ 

118 

119 _NOT_CACHED = object() 

120 

121 def __init__( 

122 self, 

123 user_ttl: Optional[int] = None, 

124 revocation_ttl: Optional[int] = None, 

125 team_ttl: Optional[int] = None, 

126 role_ttl: Optional[int] = None, 

127 enabled: Optional[bool] = None, 

128 ): 

129 """Initialize the auth cache. 

130 

131 Args: 

132 user_ttl: TTL for user data cache in seconds (default: from settings or 60) 

133 revocation_ttl: TTL for revocation cache in seconds (default: from settings or 30) 

134 team_ttl: TTL for team cache in seconds (default: from settings or 60) 

135 role_ttl: TTL for role cache in seconds (default: from settings or 60) 

136 enabled: Whether caching is enabled (default: from settings or True) 

137 

138 Examples: 

139 >>> cache = AuthCache(user_ttl=120, revocation_ttl=30) 

140 >>> cache._user_ttl 

141 120 

142 """ 

143 # Import settings lazily to avoid circular imports 

144 try: 

145 # First-Party 

146 from mcpgateway.config import settings # pylint: disable=import-outside-toplevel 

147 

148 self._user_ttl = user_ttl or getattr(settings, "auth_cache_user_ttl", 60) 

149 self._revocation_ttl = revocation_ttl or getattr(settings, "auth_cache_revocation_ttl", 30) 

150 self._team_ttl = team_ttl or getattr(settings, "auth_cache_team_ttl", 60) 

151 self._role_ttl = role_ttl or getattr(settings, "auth_cache_role_ttl", 60) 

152 self._teams_list_ttl = getattr(settings, "auth_cache_teams_ttl", 60) 

153 self._teams_list_enabled = getattr(settings, "auth_cache_teams_enabled", True) 

154 self._enabled = enabled if enabled is not None else getattr(settings, "auth_cache_enabled", True) 

155 self._cache_prefix = getattr(settings, "cache_prefix", "mcpgw:") 

156 except ImportError: 

157 self._user_ttl = user_ttl or 60 

158 self._revocation_ttl = revocation_ttl or 30 

159 self._team_ttl = team_ttl or 60 

160 self._role_ttl = role_ttl or 60 

161 self._teams_list_ttl = 60 

162 self._teams_list_enabled = True 

163 self._enabled = enabled if enabled is not None else True 

164 self._cache_prefix = "mcpgw:" 

165 

166 # In-memory cache (fallback when Redis unavailable) 

167 self._user_cache: Dict[str, CacheEntry] = {} 

168 self._team_cache: Dict[str, CacheEntry] = {} 

169 self._revocation_cache: Dict[str, CacheEntry] = {} 

170 self._context_cache: Dict[str, CacheEntry] = {} 

171 self._role_cache: Dict[str, CacheEntry] = {} 

172 self._teams_list_cache: Dict[str, CacheEntry] = {} 

173 

174 # Known revoked tokens (fast local lookup) 

175 self._revoked_jtis: Set[str] = set() 

176 

177 # Thread safety 

178 self._lock = threading.Lock() 

179 

180 # Redis availability (None = not checked yet) 

181 self._redis_checked = False 

182 self._redis_available = False 

183 

184 # Statistics 

185 self._hit_count = 0 

186 self._miss_count = 0 

187 self._redis_hit_count = 0 

188 self._redis_miss_count = 0 

189 

190 logger.info( 

191 f"AuthCache initialized: enabled={self._enabled}, " 

192 f"user_ttl={self._user_ttl}s, revocation_ttl={self._revocation_ttl}s, " 

193 f"team_ttl={self._team_ttl}s, role_ttl={self._role_ttl}s, " 

194 f"teams_list_enabled={self._teams_list_enabled}, teams_list_ttl={self._teams_list_ttl}s" 

195 ) 

196 

197 def _get_redis_key(self, key_type: str, identifier: str) -> str: 

198 """Generate Redis key with proper prefix. 

199 

200 Args: 

201 key_type: Type of cache entry (user, team, revoke, ctx) 

202 identifier: Unique identifier (email, jti, etc.) 

203 

204 Returns: 

205 Full Redis key with prefix 

206 

207 Examples: 

208 >>> cache = AuthCache() 

209 >>> cache._get_redis_key("user", "test@example.com") 

210 'mcpgw:auth:user:test@example.com' 

211 """ 

212 return f"{self._cache_prefix}auth:{key_type}:{identifier}" 

213 

214 async def _get_redis_client(self): 

215 """Get Redis client if available. 

216 

217 Returns: 

218 Redis client or None if unavailable 

219 """ 

220 try: 

221 # First-Party 

222 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

223 

224 client = await get_redis_client() 

225 if client and not self._redis_checked: 

226 self._redis_checked = True 

227 self._redis_available = True 

228 logger.debug("AuthCache: Redis client available") 

229 return client 

230 except Exception as e: 

231 if not self._redis_checked: 

232 self._redis_checked = True 

233 self._redis_available = False 

234 logger.debug(f"AuthCache: Redis unavailable, using in-memory cache: {e}") 

235 return None 

236 

    async def get_auth_context(
        self,
        email: str,
        jti: Optional[str] = None,
    ) -> Optional[CachedAuthContext]:
        """Get cached authentication context.

        Checks cache for user data, team membership, and revocation status.
        Returns None on cache miss.

        Lookup order (revocation checks deliberately come first):
          1. Local ``_revoked_jtis`` set.
          2. Redis ``revoke`` marker for the JTI (when a JTI is given and
             Redis is reachable); a hit is promoted into ``_revoked_jtis``.
          3. L1 in-memory context cache.
          4. L2 Redis context cache, written through to L1 on a hit.

        A revoked token yields ``CachedAuthContext(is_token_revoked=True)``
        with no user data. Redis errors are logged and treated as misses.

        Args:
            email: User email address
            jti: JWT ID for revocation check (optional)

        Returns:
            CachedAuthContext if found in cache, None otherwise

        Examples:
            >>> import asyncio
            >>> cache = AuthCache()
            >>> result = asyncio.run(cache.get_auth_context("test@example.com"))
            >>> result is None  # Cache miss on fresh cache
            True
        """
        if not self._enabled:
            return None

        # Check for known revoked token first (fast local check)
        if jti and jti in self._revoked_jtis:
            self._hit_count += 1
            return CachedAuthContext(is_token_revoked=True)

        # Cross-worker revocation check: when a token is revoked on another worker,
        # the revoking worker writes a Redis revocation marker. Check it BEFORE the
        # L1 in-memory cache so that stale L1 entries cannot bypass revocation.
        if jti:
            redis = await self._get_redis_client()
            if redis:
                try:
                    revoke_key = self._get_redis_key("revoke", jti)
                    if await redis.exists(revoke_key):
                        # Promote to local set so subsequent requests skip the Redis call
                        with self._lock:
                            self._revoked_jtis.add(jti)
                            # Evict any stale L1 context entries for this JTI
                            for k in [k for k in self._context_cache if k.endswith(f":{jti}")]:
                                self._context_cache.pop(k, None)
                        self._hit_count += 1
                        return CachedAuthContext(is_token_revoked=True)
                except Exception as exc:
                    # Only the first 8 characters of the JTI are logged.
                    logger.debug(f"AuthCache: Redis revocation check failed for {jti[:8]}: {exc}")

        # Same key format as set_auth_context; tokens without a JTI share "no-jti".
        cache_key = f"{email}:{jti or 'no-jti'}"

        # Check L1 in-memory cache first (no network I/O)
        # NOTE(review): an expired entry is not removed here; it stays in
        # _context_cache until overwritten or invalidated. Confirm whether a
        # periodic cleanup exists elsewhere in this module.
        entry = self._context_cache.get(cache_key)
        if entry and not entry.is_expired():
            self._hit_count += 1
            return entry.value

        # Check L2 Redis cache
        redis = await self._get_redis_client()
        if redis:
            try:
                redis_key = self._get_redis_key("ctx", cache_key)
                data = await redis.get(redis_key)
                if data:
                    # Third-Party
                    import orjson  # pylint: disable=import-outside-toplevel

                    cached = orjson.loads(data)
                    result = CachedAuthContext(
                        user=cached.get("user"),
                        personal_team_id=cached.get("personal_team_id"),
                        is_token_revoked=cached.get("is_token_revoked", False),
                    )
                    self._hit_count += 1
                    self._redis_hit_count += 1

                    # Write-through: populate L1 from Redis hit
                    # (L1 gets a fresh full TTL, not the Redis key's remaining TTL.)
                    ttl = min(self._user_ttl, self._revocation_ttl, self._team_ttl)
                    with self._lock:
                        self._context_cache[cache_key] = CacheEntry(
                            value=result,
                            expiry=time.time() + ttl,
                        )

                    return result
                self._redis_miss_count += 1
            except Exception as e:
                logger.warning(f"AuthCache Redis get failed: {e}")

        self._miss_count += 1
        return None

331 

332 async def set_auth_context( 

333 self, 

334 email: str, 

335 jti: Optional[str], 

336 context: CachedAuthContext, 

337 ) -> None: 

338 """Store authentication context in cache. 

339 

340 Stores in both Redis (if available) and in-memory cache. 

341 

342 Args: 

343 email: User email address 

344 jti: JWT ID (optional) 

345 context: Authentication context to cache 

346 

347 Examples: 

348 >>> import asyncio 

349 >>> cache = AuthCache() 

350 >>> ctx = CachedAuthContext( 

351 ... user={"email": "test@example.com"}, 

352 ... personal_team_id="team-1", 

353 ... is_token_revoked=False 

354 ... ) 

355 >>> asyncio.run(cache.set_auth_context("test@example.com", "jti-123", ctx)) 

356 """ 

357 if not self._enabled: 

358 return 

359 

360 cache_key = f"{email}:{jti or 'no-jti'}" 

361 

362 # Use shortest TTL for combined context 

363 ttl = min(self._user_ttl, self._revocation_ttl, self._team_ttl) 

364 

365 # Prepare data for serialization 

366 data = { 

367 "user": context.user, 

368 "personal_team_id": context.personal_team_id, 

369 "is_token_revoked": context.is_token_revoked, 

370 } 

371 

372 # Store in Redis 

373 redis = await self._get_redis_client() 

374 if redis: 

375 try: 

376 # Third-Party 

377 import orjson # pylint: disable=import-outside-toplevel 

378 

379 redis_key = self._get_redis_key("ctx", cache_key) 

380 await redis.setex(redis_key, ttl, orjson.dumps(data)) 

381 except Exception as e: 

382 logger.warning(f"AuthCache Redis set failed: {e}") 

383 

384 # Store in in-memory cache 

385 with self._lock: 

386 self._context_cache[cache_key] = CacheEntry( 

387 value=context, 

388 expiry=time.time() + ttl, 

389 ) 

390 

391 async def invalidate_user(self, email: str) -> None: 

392 """Invalidate cached data for a user. 

393 

394 Call this when user data changes (password, profile, etc.). 

395 

396 Args: 

397 email: User email to invalidate 

398 

399 Examples: 

400 >>> import asyncio 

401 >>> cache = AuthCache() 

402 >>> asyncio.run(cache.invalidate_user("test@example.com")) 

403 """ 

404 logger.debug(f"AuthCache: Invalidating user cache for {email}") 

405 

406 # Clear in-memory caches 

407 with self._lock: 

408 # Clear any context cache entries for this user 

409 keys_to_remove = [k for k in self._context_cache if k.startswith(f"{email}:")] 

410 for key in keys_to_remove: 

411 self._context_cache.pop(key, None) 

412 

413 self._user_cache.pop(email, None) 

414 

415 # Clear team membership cache entries (keys are email:team_ids) 

416 team_keys_to_remove = [k for k in self._team_cache if k.startswith(f"{email}:")] 

417 for key in team_keys_to_remove: 

418 self._team_cache.pop(key, None) 

419 

420 # Clear Redis 

421 redis = await self._get_redis_client() 

422 if redis: 

423 try: 

424 # Delete user-specific keys 

425 await redis.delete( 

426 self._get_redis_key("user", email), 

427 self._get_redis_key("team", email), 

428 ) 

429 # Delete context keys (pattern match) 

430 pattern = self._get_redis_key("ctx", f"{email}:*") 

431 async for key in redis.scan_iter(match=pattern): 

432 await redis.delete(key) 

433 # Delete membership keys (pattern match) 

434 membership_pattern = self._get_redis_key("membership", f"{email}:*") 

435 async for key in redis.scan_iter(match=membership_pattern): 

436 await redis.delete(key) 

437 

438 # Publish invalidation for other workers 

439 await redis.publish("mcpgw:auth:invalidate", f"user:{email}") 

440 except Exception as e: 

441 logger.warning(f"AuthCache Redis invalidate_user failed: {e}") 

442 

443 async def invalidate_revocation(self, jti: str) -> None: 

444 """Invalidate cache for a revoked token. 

445 

446 Call this when a token is revoked. 

447 

448 Args: 

449 jti: JWT ID of revoked token 

450 

451 Examples: 

452 >>> import asyncio 

453 >>> cache = AuthCache() 

454 >>> asyncio.run(cache.invalidate_revocation("jti-123")) 

455 """ 

456 logger.debug(f"AuthCache: Invalidating revocation cache for jti={jti[:8]}...") 

457 

458 # Add to local revoked set for fast lookup 

459 with self._lock: 

460 self._revoked_jtis.add(jti) 

461 self._revocation_cache.pop(jti, None) 

462 

463 # Clear any context cache entries with this JTI 

464 keys_to_remove = [k for k in self._context_cache if k.endswith(f":{jti}")] 

465 for key in keys_to_remove: 

466 self._context_cache.pop(key, None) 

467 

468 # Update Redis 

469 redis = await self._get_redis_client() 

470 if redis: 

471 try: 

472 # Mark as revoked in Redis 

473 await redis.setex( 

474 self._get_redis_key("revoke", jti), 

475 86400, # 24 hour expiry for revocation markers 

476 "1", 

477 ) 

478 # Add to revoked tokens set 

479 await redis.sadd("mcpgw:auth:revoked_tokens", jti) 

480 

481 # Delete any cached contexts with this JTI 

482 pattern = self._get_redis_key("ctx", f"*:{jti}") 

483 async for key in redis.scan_iter(match=pattern): 

484 await redis.delete(key) 

485 

486 # Publish invalidation for other workers 

487 await redis.publish("mcpgw:auth:invalidate", f"revoke:{jti}") 

488 except Exception as e: 

489 logger.warning(f"AuthCache Redis invalidate_revocation failed: {e}") 

490 

491 async def invalidate_team(self, email: str) -> None: 

492 """Invalidate team cache for a user. 

493 

494 Call this when team membership changes. 

495 

496 Args: 

497 email: User email whose team changed 

498 

499 Examples: 

500 >>> import asyncio 

501 >>> cache = AuthCache() 

502 >>> asyncio.run(cache.invalidate_team("test@example.com")) 

503 """ 

504 logger.debug(f"AuthCache: Invalidating team cache for {email}") 

505 

506 # Clear in-memory caches 

507 with self._lock: 

508 self._team_cache.pop(email, None) 

509 # Clear context cache entries for this user 

510 keys_to_remove = [k for k in self._context_cache if k.startswith(f"{email}:")] 

511 for key in keys_to_remove: 

512 self._context_cache.pop(key, None) 

513 

514 # Clear Redis 

515 redis = await self._get_redis_client() 

516 if redis: 

517 try: 

518 await redis.delete(self._get_redis_key("team", email)) 

519 # Delete context keys 

520 pattern = self._get_redis_key("ctx", f"{email}:*") 

521 async for key in redis.scan_iter(match=pattern): 

522 await redis.delete(key) 

523 

524 # Publish invalidation 

525 await redis.publish("mcpgw:auth:invalidate", f"team:{email}") 

526 except Exception as e: 

527 logger.warning(f"AuthCache Redis invalidate_team failed: {e}") 

528 

529 async def get_user_role(self, email: str, team_id: str) -> Optional[str]: 

530 """Get cached user role in a team. 

531 

532 Returns: 

533 - None: Cache miss (caller should check DB) 

534 - "": User is not a member of the team (cached negative result) 

535 - Role string: User's role in the team (cached) 

536 

537 Args: 

538 email: User email 

539 team_id: Team ID 

540 

541 Examples: 

542 >>> import asyncio 

543 >>> cache = AuthCache() 

544 >>> result = asyncio.run(cache.get_user_role("test@example.com", "team-123")) 

545 >>> result is None # Cache miss 

546 True 

547 """ 

548 if not self._enabled: 

549 return None 

550 

551 cache_key = f"{email}:{team_id}" 

552 

553 # Check L1 in-memory cache first (no network I/O) 

554 entry = self._role_cache.get(cache_key) 

555 if entry and not entry.is_expired(): 

556 self._hit_count += 1 

557 # Return empty string for None (not a member) to distinguish from cache miss 

558 return "" if entry.value is None else entry.value 

559 

560 # Check L2 Redis cache 

561 redis = await self._get_redis_client() 

562 if redis: 

563 try: 

564 redis_key = self._get_redis_key("role", cache_key) 

565 data = await redis.get(redis_key) 

566 if data is not None: 

567 self._hit_count += 1 

568 self._redis_hit_count += 1 

569 # Role is stored as plain string, decode it 

570 decoded = data.decode() if isinstance(data, bytes) else data 

571 # Convert sentinel to empty string (user is not a member) 

572 role_value = "" if decoded == _NOT_A_MEMBER_SENTINEL else decoded 

573 

574 # Write-through: populate L1 from Redis hit 

575 # Store as None for not-a-member to match existing L1 storage format 

576 with self._lock: 

577 self._role_cache[cache_key] = CacheEntry( 

578 value=None if role_value == "" else role_value, 

579 expiry=time.time() + self._role_ttl, 

580 ) 

581 

582 return role_value 

583 self._redis_miss_count += 1 

584 except Exception as e: 

585 logger.warning(f"AuthCache Redis get_user_role failed: {e}") 

586 

587 self._miss_count += 1 

588 return None 

589 

590 async def set_user_role(self, email: str, team_id: str, role: Optional[str]) -> None: 

591 """Store user role in cache. 

592 

593 Args: 

594 email: User email 

595 team_id: Team ID 

596 role: User's role in the team (or None if not a member) 

597 

598 Examples: 

599 >>> import asyncio 

600 >>> cache = AuthCache() 

601 >>> asyncio.run(cache.set_user_role("test@example.com", "team-123", "admin")) 

602 """ 

603 if not self._enabled: 

604 return 

605 

606 cache_key = f"{email}:{team_id}" 

607 # Store None as sentinel value to distinguish "not a member" from cache miss 

608 role_value = role if role is not None else _NOT_A_MEMBER_SENTINEL 

609 

610 # Store in Redis 

611 redis = await self._get_redis_client() 

612 if redis: 

613 try: 

614 redis_key = self._get_redis_key("role", cache_key) 

615 await redis.setex(redis_key, self._role_ttl, role_value) 

616 except Exception as e: 

617 logger.warning(f"AuthCache Redis set_user_role failed: {e}") 

618 

619 # Store in in-memory cache 

620 with self._lock: 

621 self._role_cache[cache_key] = CacheEntry( 

622 value=role, 

623 expiry=time.time() + self._role_ttl, 

624 ) 

625 

626 async def invalidate_user_role(self, email: str, team_id: str) -> None: 

627 """Invalidate cached role for a user in a team. 

628 

629 Call this when a user's role changes in a team. 

630 

631 Args: 

632 email: User email 

633 team_id: Team ID 

634 

635 Examples: 

636 >>> import asyncio 

637 >>> cache = AuthCache() 

638 >>> asyncio.run(cache.invalidate_user_role("test@example.com", "team-123")) 

639 """ 

640 logger.debug(f"AuthCache: Invalidating role cache for {email} in team {team_id}") 

641 

642 cache_key = f"{email}:{team_id}" 

643 

644 # Clear in-memory cache 

645 with self._lock: 

646 self._role_cache.pop(cache_key, None) 

647 

648 # Clear Redis 

649 redis = await self._get_redis_client() 

650 if redis: 

651 try: 

652 await redis.delete(self._get_redis_key("role", cache_key)) 

653 # Publish invalidation for other workers 

654 await redis.publish("mcpgw:auth:invalidate", f"role:{email}:{team_id}") 

655 except Exception as e: 

656 logger.warning(f"AuthCache Redis invalidate_user_role failed: {e}") 

657 

658 async def invalidate_team_roles(self, team_id: str) -> None: 

659 """Invalidate all cached roles for a team. 

660 

661 Call this when team membership changes significantly (e.g., team deletion). 

662 

663 Args: 

664 team_id: Team ID 

665 

666 Examples: 

667 >>> import asyncio 

668 >>> cache = AuthCache() 

669 >>> asyncio.run(cache.invalidate_team_roles("team-123")) 

670 """ 

671 logger.debug(f"AuthCache: Invalidating all role caches for team {team_id}") 

672 

673 # Clear in-memory cache entries for this team 

674 with self._lock: 

675 keys_to_remove = [k for k in self._role_cache if k.endswith(f":{team_id}")] 

676 for key in keys_to_remove: 

677 self._role_cache.pop(key, None) 

678 

679 # Clear Redis 

680 redis = await self._get_redis_client() 

681 if redis: 

682 try: 

683 # Pattern match all role keys for this team 

684 pattern = self._get_redis_key("role", f"*:{team_id}") 

685 async for key in redis.scan_iter(match=pattern): 

686 await redis.delete(key) 

687 # Publish invalidation 

688 await redis.publish("mcpgw:auth:invalidate", f"team_roles:{team_id}") 

689 except Exception as e: 

690 logger.warning(f"AuthCache Redis invalidate_team_roles failed: {e}") 

691 

692 async def get_user_teams(self, cache_key: str) -> Optional[List[str]]: 

693 """Get cached team IDs for a user. 

694 

695 The cache_key should be in the format "email:include_personal" to 

696 distinguish between calls with different include_personal flags. 

697 

698 Returns: 

699 - None: Cache miss (caller should query DB) 

700 - Empty list: User has no teams (cached result) 

701 - List of team IDs: Cached team IDs 

702 

703 Args: 

704 cache_key: Cache key in format "email:include_personal" 

705 

706 Examples: 

707 >>> import asyncio 

708 >>> cache = AuthCache() 

709 >>> result = asyncio.run(cache.get_user_teams("test@example.com:True")) 

710 >>> result is None # Cache miss 

711 True 

712 """ 

713 if not self._enabled or not self._teams_list_enabled: 

714 return None 

715 

716 # Check L1 in-memory cache first (no network I/O) 

717 entry = self._teams_list_cache.get(cache_key) 

718 if entry and not entry.is_expired(): 

719 self._hit_count += 1 

720 return entry.value 

721 

722 # Check L2 Redis cache 

723 redis = await self._get_redis_client() 

724 if redis: 

725 try: 

726 redis_key = self._get_redis_key("teams", cache_key) 

727 data = await redis.get(redis_key) 

728 if data is not None: 

729 self._hit_count += 1 

730 self._redis_hit_count += 1 

731 # Third-Party 

732 import orjson # pylint: disable=import-outside-toplevel 

733 

734 team_ids = orjson.loads(data) 

735 

736 # Write-through: populate L1 from Redis hit 

737 with self._lock: 

738 self._teams_list_cache[cache_key] = CacheEntry( 

739 value=team_ids, 

740 expiry=time.time() + self._teams_list_ttl, 

741 ) 

742 

743 return team_ids 

744 self._redis_miss_count += 1 

745 except Exception as e: 

746 logger.warning(f"AuthCache Redis get_user_teams failed: {e}") 

747 

748 self._miss_count += 1 

749 return None 

750 

751 async def set_user_teams(self, cache_key: str, team_ids: List[str]) -> None: 

752 """Store team IDs for a user in cache. 

753 

754 Args: 

755 cache_key: Cache key in format "email:include_personal" 

756 team_ids: List of team IDs the user belongs to 

757 

758 Examples: 

759 >>> import asyncio 

760 >>> cache = AuthCache() 

761 >>> asyncio.run(cache.set_user_teams("test@example.com:True", ["team-1", "team-2"])) 

762 """ 

763 if not self._enabled or not self._teams_list_enabled: 

764 return 

765 

766 # Store in Redis 

767 redis = await self._get_redis_client() 

768 if redis: 

769 try: 

770 # Third-Party 

771 import orjson # pylint: disable=import-outside-toplevel 

772 

773 redis_key = self._get_redis_key("teams", cache_key) 

774 await redis.setex(redis_key, self._teams_list_ttl, orjson.dumps(team_ids)) 

775 except Exception as e: 

776 logger.warning(f"AuthCache Redis set_user_teams failed: {e}") 

777 

778 # Store in in-memory cache 

779 with self._lock: 

780 self._teams_list_cache[cache_key] = CacheEntry( 

781 value=team_ids, 

782 expiry=time.time() + self._teams_list_ttl, 

783 ) 

784 

785 async def invalidate_user_teams(self, email: str) -> None: 

786 """Invalidate cached teams list for a user. 

787 

788 Call this when a user's team membership changes (add/remove member, 

789 delete team, approve join request). 

790 

791 This invalidates both include_personal=True and include_personal=False 

792 cache entries for the user. 

793 

794 Args: 

795 email: User email whose teams cache should be invalidated 

796 

797 Examples: 

798 >>> import asyncio 

799 >>> cache = AuthCache() 

800 >>> asyncio.run(cache.invalidate_user_teams("test@example.com")) 

801 """ 

802 logger.debug(f"AuthCache: Invalidating teams list cache for {email}") 

803 

804 # Clear in-memory cache entries for this user (both True and False variants) 

805 with self._lock: 

806 keys_to_remove = [k for k in self._teams_list_cache if k.startswith(f"{email}:")] 

807 for key in keys_to_remove: 

808 self._teams_list_cache.pop(key, None) 

809 

810 # Clear Redis 

811 redis = await self._get_redis_client() 

812 if redis: 

813 try: 

814 # Delete both variants 

815 await redis.delete( 

816 self._get_redis_key("teams", f"{email}:True"), 

817 self._get_redis_key("teams", f"{email}:False"), 

818 ) 

819 # Publish invalidation for other workers 

820 await redis.publish("mcpgw:auth:invalidate", f"teams:{email}") 

821 except Exception as e: 

822 logger.warning(f"AuthCache Redis invalidate_user_teams failed: {e}") 

823 

824 # ========================================================================= 

825 # Team Membership Validation Cache 

826 # ========================================================================= 

827 # Used by TokenScopingMiddleware to cache email_team_members lookups. 

828 # This prevents repeated DB queries for the same user+teams combination. 

829 

830 def get_team_membership_valid_sync(self, user_email: str, team_ids: List[str]) -> Optional[bool]: 

831 """Get cached team membership validation result (synchronous). 

832 

833 This is the synchronous version used by token_scoping middleware. 

834 Returns None on cache miss (caller should check DB). 

835 

836 Args: 

837 user_email: User email 

838 team_ids: List of team IDs to validate membership for 

839 

840 Returns: 

841 - None: Cache miss (caller should query DB) 

842 - True: User is valid member of all teams (cached) 

843 - False: User is NOT a valid member of all teams (cached) 

844 

845 Examples: 

846 >>> cache = AuthCache() 

847 >>> result = cache.get_team_membership_valid_sync("test@example.com", ["team-1"]) 

848 >>> result is None # Cache miss 

849 True 

850 """ 

851 if not self._enabled or not team_ids: 

852 return None 

853 

854 # Create cache key from user + sorted team IDs 

855 sorted_teams = ":".join(sorted(team_ids)) 

856 cache_key = f"{user_email}:{sorted_teams}" 

857 

858 # Check in-memory cache only (sync version) 

859 entry = self._team_cache.get(cache_key) 

860 if entry and not entry.is_expired(): 

861 self._hit_count += 1 

862 return entry.value 

863 

864 self._miss_count += 1 

865 return None 

866 

867 def set_team_membership_valid_sync(self, user_email: str, team_ids: List[str], valid: bool) -> None: 

868 """Store team membership validation result in cache (synchronous). 

869 

870 Args: 

871 user_email: User email 

872 team_ids: List of team IDs that were validated 

873 valid: Whether user is a valid member of all teams 

874 

875 Examples: 

876 >>> cache = AuthCache() 

877 >>> cache.set_team_membership_valid_sync("test@example.com", ["team-1"], True) 

878 """ 

879 if not self._enabled or not team_ids: 

880 return 

881 

882 # Create cache key from user + sorted team IDs 

883 sorted_teams = ":".join(sorted(team_ids)) 

884 cache_key = f"{user_email}:{sorted_teams}" 

885 

886 # Store in in-memory cache 

887 with self._lock: 

888 self._team_cache[cache_key] = CacheEntry( 

889 value=valid, 

890 expiry=time.time() + self._team_ttl, 

891 ) 

892 

async def get_team_membership_valid(self, user_email: str, team_ids: List[str]) -> Optional[bool]:
    """Look up a cached team-membership validation result (async).

    The in-process L1 cache is consulted first. On an L1 miss the Redis L2
    cache is queried, and a Redis hit is copied back into L1.

    Args:
        user_email: Email of the user whose membership is being checked.
        team_ids: Team IDs the user must belong to. Order does not matter;
            the IDs are sorted when the cache key is built.

    Returns:
        - None: Cache miss (caller should query DB)
        - True: User is valid member of all teams (cached)
        - False: User is NOT a valid member of all teams (cached)

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> result = asyncio.run(cache.get_team_membership_valid("test@example.com", ["team-1"]))
        >>> result is None  # Cache miss
        True
    """
    if not (self._enabled and team_ids):
        return None

    # Key format "<email>:<team1>:<team2>..." with sorted team IDs, so the
    # same set of teams always maps to the same key.
    cache_key = f"{user_email}:{':'.join(sorted(team_ids))}"

    # L1: in-memory lookup, no network I/O.
    local_entry = self._team_cache.get(cache_key)
    if local_entry and not local_entry.is_expired():
        self._hit_count += 1
        return local_entry.value

    # L2: Redis lookup, skipped when no client is available.
    client = await self._get_redis_client()
    if client:
        try:
            raw = await client.get(self._get_redis_key("membership", cache_key))
            if raw is None:
                self._redis_miss_count += 1
            else:
                self._hit_count += 1
                self._redis_hit_count += 1
                # Values are stored as "1" (valid) / "0" (invalid); anything
                # other than "1" is treated as invalid.
                text = raw.decode() if isinstance(raw, bytes) else raw
                is_valid = text == "1"
                # Write-through so the next lookup is served from L1.
                with self._lock:
                    self._team_cache[cache_key] = CacheEntry(
                        value=is_valid,
                        expiry=time.time() + self._team_ttl,
                    )
                return is_valid
        except Exception as exc:
            logger.warning(f"AuthCache Redis get_team_membership_valid failed: {exc}")

    self._miss_count += 1
    return None

954 

async def set_team_membership_valid(self, user_email: str, team_ids: List[str], valid: bool) -> None:
    """Record a team-membership validation result in both cache tiers (async).

    The result is written to Redis (L2) with a TTL of ``self._team_ttl``
    seconds, then to the in-process L1 cache. A Redis failure is logged as a
    warning and does not prevent the L1 write.

    Args:
        user_email: Email of the user that was validated.
        team_ids: Team IDs that were validated (order-insensitive).
        valid: Whether the user is a valid member of all the teams.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.set_team_membership_valid("test@example.com", ["team-1"], True))
    """
    if not self._enabled or not team_ids:
        return

    # Same key shape as the lookup side: "<email>:<sorted team ids joined by ':'>".
    cache_key = f"{user_email}:{':'.join(sorted(team_ids))}"

    # L2 (Redis): booleans are encoded as "1" / "0".
    client = await self._get_redis_client()
    if client:
        try:
            encoded = "1" if valid else "0"
            await client.setex(self._get_redis_key("membership", cache_key), self._team_ttl, encoded)
        except Exception as exc:
            logger.warning(f"AuthCache Redis set_team_membership_valid failed: {exc}")

    # L1 (in-process) copy.
    with self._lock:
        self._team_cache[cache_key] = CacheEntry(
            value=valid,
            expiry=time.time() + self._team_ttl,
        )

991 

async def invalidate_team_membership(self, user_email: str) -> None:
    """Invalidate team membership cache for a user.

    Call this when a user's team membership changes (add/remove member,
    role change, deactivation). The method does three things:

    - removes every cached membership result for the user from the
      in-memory (L1) cache;
    - deletes the matching Redis (L2) keys;
    - publishes ``membership:<email>`` on the ``mcpgw:auth:invalidate``
      channel.

    Errors raised by the Redis commands are logged as warnings, not raised.

    Args:
        user_email: User email whose membership cache should be invalidated.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.invalidate_team_membership("test@example.com"))
    """
    logger.debug(f"AuthCache: Invalidating team membership cache for {user_email}")

    # Membership keys have the form "<email>:<sorted team ids>", so every
    # entry for this user shares the "<email>:" prefix.
    user_prefix = f"{user_email}:"

    # Clear in-memory cache entries for this user
    with self._lock:
        stale_keys = [k for k in self._team_cache if k.startswith(user_prefix)]
        for key in stale_keys:
            self._team_cache.pop(key, None)

    # Clear Redis
    redis = await self._get_redis_client()
    if redis:
        try:
            # SCAN's MATCH argument is a glob pattern. "*" and "?" are legal in
            # an email local part and would otherwise act as wildcards, matching
            # other users' keys. "[" or a backslash could stop this user's own
            # keys from matching at all. Backslash-escape them so the email
            # is matched literally.
            # NOTE(review): assumes _get_redis_key embeds the identifier verbatim
            # (the original already relied on the trailing "*" staying a wildcard).
            literal_email = "".join("\\" + ch if ch in "\\*?[]" else ch for ch in user_email)
            pattern = self._get_redis_key("membership", f"{literal_email}:*")
            # Collect the matches, then delete them in one round trip instead of
            # one DEL per key.
            matched_keys = [key async for key in redis.scan_iter(match=pattern)]
            if matched_keys:
                await redis.delete(*matched_keys)
            # Publish invalidation for other workers
            await redis.publish("mcpgw:auth:invalidate", f"membership:{user_email}")
        except Exception as e:
            logger.warning(f"AuthCache Redis invalidate_team_membership failed: {e}")

1026 

async def is_token_revoked(self, jti: str) -> Optional[bool]:
    """Report whether a token is known to be revoked, using caches only.

    Lookup order:

    1. the local set of confirmed revoked JTIs;
    2. the unexpired L1 revocation-cache entry;
    3. Redis: first the shared ``mcpgw:auth:revoked_tokens`` set, then the
       per-token ``revoke`` key.

    A positive Redis answer is copied into both local structures.

    Args:
        jti: JWT ID to check.

    Returns:
        ``True`` if the token is known to be revoked. If an unexpired L1
        entry exists, its cached value is returned. Otherwise ``None``
        (unknown — the caller should check the DB). ``None`` is also
        returned when the cache is disabled or a Redis command raises.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> cache._revoked_jtis.add("revoked-jti")
        >>> asyncio.run(cache.is_token_revoked("revoked-jti"))
        True
    """
    if not self._enabled:
        return None

    # Fastest path: confirmed revocations held in process memory.
    if jti in self._revoked_jtis:
        return True

    cached = self._revocation_cache.get(jti)
    if cached and not cached.is_expired():
        return cached.value

    client = await self._get_redis_client()
    if not client:
        return None

    def _remember_revoked() -> None:
        # Record the Redis answer locally so later checks skip the network.
        with self._lock:
            self._revoked_jtis.add(jti)
            self._revocation_cache[jti] = CacheEntry(
                value=True,
                expiry=time.time() + self._revocation_ttl,
            )

    try:
        # The per-token key is only queried when the shared set says "no".
        in_revoked_set = await client.sismember("mcpgw:auth:revoked_tokens", jti)
        if in_revoked_set or await client.exists(self._get_redis_key("revoke", jti)):
            _remember_revoked()
            return True
    except Exception as exc:
        logger.warning(f"AuthCache Redis is_token_revoked failed: {exc}")

    return None

1087 

async def sync_revoked_tokens(self) -> None:
    """Seed the revocation caches from the database.

    Intended to run during application startup. Every JTI in the
    ``TokenRevocation`` table is added to the in-memory revoked set. When a
    Redis client is available, the JTIs are also added to the
    ``mcpgw:auth:revoked_tokens`` Redis set. Failures are logged as warnings
    and never raised.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.sync_revoked_tokens())
    """
    if not self._enabled:
        return

    try:
        # Third-Party
        from sqlalchemy import select  # pylint: disable=import-outside-toplevel

        # First-Party
        from mcpgateway.db import fresh_db_session, TokenRevocation  # pylint: disable=import-outside-toplevel

        def _fetch_revoked_jtis() -> Set[str]:
            """Return every revoked JTI stored in the database.

            Returns:
                Set of revoked JTI strings.
            """
            with fresh_db_session() as session:
                rows = session.execute(select(TokenRevocation.jti))
                return {record[0] for record in rows}

        # The DB query is synchronous, so run it in a worker thread.
        revoked = await asyncio.to_thread(_fetch_revoked_jtis)

        with self._lock:
            self._revoked_jtis.update(revoked)

        # Mirror the JTIs into the shared Redis set as well.
        client = await self._get_redis_client()
        if client and revoked:
            try:
                await client.sadd("mcpgw:auth:revoked_tokens", *revoked)
            except Exception as exc:
                logger.warning(f"AuthCache Redis sync_revoked_tokens failed: {exc}")

        logger.info(f"AuthCache: Synced {len(revoked)} revoked tokens to cache")

    except Exception as exc:
        logger.warning(f"AuthCache sync_revoked_tokens failed: {exc}")

1136 

def invalidate_all(self) -> None:
    """Drop every entry from the in-memory (L1) caches.

    Clears the following caches while holding the lock:

    - user
    - team-membership
    - revocation
    - context
    - role
    - teams-list

    The set of confirmed revoked JTIs is deliberately kept. This method
    does not touch Redis (L2). Intended for tests or major configuration
    changes.

    Examples:
        >>> cache = AuthCache()
        >>> cache.invalidate_all()
    """
    with self._lock:
        for attr_name in (
            "_user_cache",
            "_team_cache",
            "_revocation_cache",
            "_context_cache",
            "_role_cache",
            "_teams_list_cache",
        ):
            getattr(self, attr_name).clear()
        # _revoked_jtis is intentionally left intact: those are confirmed revocations.

    logger.info("AuthCache: All caches invalidated")

1156 

def stats(self) -> Dict[str, Any]:
    """Summarise cache activity and configuration.

    Returns:
        Dictionary with:

        - overall and Redis-specific hit/miss counters;
        - hit rates (0.0 when no lookups have been counted);
        - the sizes of the in-memory caches;
        - the configured TTLs and flags.

    Examples:
        >>> cache = AuthCache()
        >>> stats = cache.stats()
        >>> "hit_count" in stats
        True
    """

    def _ratio(hits: int, misses: int) -> float:
        # Fraction of lookups that hit; 0.0 avoids division by zero.
        lookups = hits + misses
        return hits / lookups if lookups > 0 else 0.0

    hits, misses = self._hit_count, self._miss_count
    redis_hits, redis_misses = self._redis_hit_count, self._redis_miss_count

    return {
        "enabled": self._enabled,
        "hit_count": hits,
        "miss_count": misses,
        "hit_rate": _ratio(hits, misses),
        "redis_hit_count": redis_hits,
        "redis_miss_count": redis_misses,
        "redis_hit_rate": _ratio(redis_hits, redis_misses),
        "redis_available": self._redis_available,
        "revoked_tokens_cached": len(self._revoked_jtis),
        "context_cache_size": len(self._context_cache),
        "role_cache_size": len(self._role_cache),
        "teams_list_cache_size": len(self._teams_list_cache),
        "team_membership_cache_size": len(self._team_cache),
        "user_ttl": self._user_ttl,
        "revocation_ttl": self._revocation_ttl,
        "team_ttl": self._team_ttl,
        "role_ttl": self._role_ttl,
        "teams_list_enabled": self._teams_list_enabled,
        "teams_list_ttl": self._teams_list_ttl,
    }

1193 

def reset_stats(self) -> None:
    """Zero all hit/miss counters, both overall and Redis-specific.

    Examples:
        >>> cache = AuthCache()
        >>> cache._hit_count = 100
        >>> cache.reset_stats()
        >>> cache._hit_count
        0
    """
    self._hit_count = self._miss_count = 0
    self._redis_hit_count = self._redis_miss_count = 0

1208 

1209 

# Global singleton instance. Starts as None and is created on the first call
# to get_auth_cache(), which assigns it via a `global` statement.
_auth_cache: Optional[AuthCache] = None

1212 

1213 

def get_auth_cache() -> AuthCache:
    """Return the shared AuthCache, creating it on first use.

    Returns:
        AuthCache: The singleton auth cache instance

    Examples:
        >>> cache = get_auth_cache()
        >>> isinstance(cache, AuthCache)
        True
    """
    global _auth_cache  # pylint: disable=global-statement
    if _auth_cache is not None:
        return _auth_cache
    # First call: build the instance and remember it for later calls.
    _auth_cache = AuthCache()
    return _auth_cache

1229 

1230 

# Convenience alias for direct import
# (`from mcpgateway.cache.auth_cache import auth_cache`). Evaluated at module
# import time, so this eagerly creates the singleton via get_auth_cache().
auth_cache = get_auth_cache()