Coverage for mcpgateway / cache / auth_cache.py: 100%

441 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/cache/auth_cache.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5 

6Authentication Data Cache. 

7 

8This module implements a thread-safe two-tier cache for authentication data. 

9L1 (in-memory) is checked first for lowest latency, with L2 (Redis) as a 

10shared distributed cache. It caches user data, team memberships, and token 

11revocation status to reduce database queries during authentication. 

12 

13Performance Impact: 

14 - Before: 3-4 DB queries per authenticated request 

15 - After: 0-1 DB queries (cache hit) per TTL period 

16 

17Security Considerations: 

18 - Short TTLs for revocation data (30s default) to limit exposure window 

19 - Cache invalidation on token revocation, user update, team change 

20 - JWT payloads are NOT cached (security risk) 

21 - Graceful fallback to DB on cache failure 

22 

23Examples: 

24 >>> from mcpgateway.cache.auth_cache import auth_cache 

25 >>> # Cache is used automatically by get_current_user() 

26 >>> # Manual invalidation after user update: 

27 >>> import asyncio 

28 >>> # asyncio.run(auth_cache.invalidate_user("user@example.com")) 

29""" 

30 

31# Standard 

32import asyncio 

33from dataclasses import dataclass 

34import logging 

35import threading 

36import time 

37from typing import Any, Dict, List, Optional, Set 

38 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Marker string written to Redis to record "user is not a member".
# A stored marker lets a cached negative result be told apart from a key
# that is simply absent (a cache miss).
_NOT_A_MEMBER_SENTINEL = "__NOT_A_MEMBER__"

44 

45 

@dataclass
class CachedAuthContext:
    """Bundle of authentication data that is cached as a single unit.

    Groups the user record, the user's personal team and the token's
    revocation flag, as retrieved together from one database roundtrip.

    Attributes:
        user: User data dict (email, is_admin, is_active, etc.) or None.
        personal_team_id: ID of the user's personal team, or None.
        is_token_revoked: True when the JWT has been revoked.

    Examples:
        >>> ctx = CachedAuthContext(
        ...     user={"email": "test@example.com", "is_admin": False},
        ...     personal_team_id="team-123",
        ...     is_token_revoked=False
        ... )
        >>> ctx.is_token_revoked
        False
    """

    user: Optional[Dict[str, Any]] = None
    personal_team_id: Optional[str] = None
    is_token_revoked: bool = False

71 

72 

@dataclass
class CacheEntry:
    """A cached value together with the wall-clock time it stops being valid.

    Attributes:
        value: The cached payload (any type).
        expiry: Absolute ``time.time()`` timestamp at which the entry expires.

    Examples:
        >>> import time
        >>> entry = CacheEntry(value={"key": "value"}, expiry=time.time() + 60)
        >>> entry.is_expired()
        False
    """

    value: Any
    expiry: float

    def is_expired(self) -> bool:
        """Report whether the expiry timestamp has been reached.

        Returns:
            bool: True once the current time is at or past ``expiry``.
        """
        return self.expiry <= time.time()

94 

95 

96class AuthCache: 

97 """Thread-safe two-tier authentication cache (L1 in-memory + L2 Redis). 

98 

99 This cache reduces database load during authentication by caching: 

100 - User data (email, is_admin, is_active, etc.) 

101 - Personal team ID for the user 

102 - Token revocation status 

103 

104 Cache lookup checks L1 (in-memory) first for lowest latency, then L2 

105 (Redis) for distributed consistency. Redis hits are written through 

106 to L1 for subsequent requests. 

107 

108 Attributes: 

109 user_ttl: TTL in seconds for user data cache (default: 60) 

110 revocation_ttl: TTL in seconds for revocation cache (default: 30) 

111 team_ttl: TTL in seconds for team cache (default: 60) 

112 

113 Examples: 

114 >>> cache = AuthCache(user_ttl=60, revocation_ttl=30) 

115 >>> cache.stats()["hit_count"] 

116 0 

117 """ 

118 

119 _NOT_CACHED = object() 

120 

121 def __init__( 

122 self, 

123 user_ttl: Optional[int] = None, 

124 revocation_ttl: Optional[int] = None, 

125 team_ttl: Optional[int] = None, 

126 role_ttl: Optional[int] = None, 

127 enabled: Optional[bool] = None, 

128 ): 

129 """Initialize the auth cache. 

130 

131 Args: 

132 user_ttl: TTL for user data cache in seconds (default: from settings or 60) 

133 revocation_ttl: TTL for revocation cache in seconds (default: from settings or 30) 

134 team_ttl: TTL for team cache in seconds (default: from settings or 60) 

135 role_ttl: TTL for role cache in seconds (default: from settings or 60) 

136 enabled: Whether caching is enabled (default: from settings or True) 

137 

138 Examples: 

139 >>> cache = AuthCache(user_ttl=120, revocation_ttl=30) 

140 >>> cache._user_ttl 

141 120 

142 """ 

143 # Import settings lazily to avoid circular imports 

144 try: 

145 # First-Party 

146 from mcpgateway.config import settings # pylint: disable=import-outside-toplevel 

147 

148 self._user_ttl = user_ttl or getattr(settings, "auth_cache_user_ttl", 60) 

149 self._revocation_ttl = revocation_ttl or getattr(settings, "auth_cache_revocation_ttl", 30) 

150 self._team_ttl = team_ttl or getattr(settings, "auth_cache_team_ttl", 60) 

151 self._role_ttl = role_ttl or getattr(settings, "auth_cache_role_ttl", 60) 

152 self._teams_list_ttl = getattr(settings, "auth_cache_teams_ttl", 60) 

153 self._teams_list_enabled = getattr(settings, "auth_cache_teams_enabled", True) 

154 self._enabled = enabled if enabled is not None else getattr(settings, "auth_cache_enabled", True) 

155 self._cache_prefix = getattr(settings, "cache_prefix", "mcpgw:") 

156 except ImportError: 

157 self._user_ttl = user_ttl or 60 

158 self._revocation_ttl = revocation_ttl or 30 

159 self._team_ttl = team_ttl or 60 

160 self._role_ttl = role_ttl or 60 

161 self._teams_list_ttl = 60 

162 self._teams_list_enabled = True 

163 self._enabled = enabled if enabled is not None else True 

164 self._cache_prefix = "mcpgw:" 

165 

166 # In-memory cache (fallback when Redis unavailable) 

167 self._user_cache: Dict[str, CacheEntry] = {} 

168 self._team_cache: Dict[str, CacheEntry] = {} 

169 self._revocation_cache: Dict[str, CacheEntry] = {} 

170 self._context_cache: Dict[str, CacheEntry] = {} 

171 self._role_cache: Dict[str, CacheEntry] = {} 

172 self._teams_list_cache: Dict[str, CacheEntry] = {} 

173 

174 # Known revoked tokens (fast local lookup) 

175 self._revoked_jtis: Set[str] = set() 

176 

177 # Thread safety 

178 self._lock = threading.Lock() 

179 

180 # Redis availability (None = not checked yet) 

181 self._redis_checked = False 

182 self._redis_available = False 

183 

184 # Statistics 

185 self._hit_count = 0 

186 self._miss_count = 0 

187 self._redis_hit_count = 0 

188 self._redis_miss_count = 0 

189 

190 logger.info( 

191 f"AuthCache initialized: enabled={self._enabled}, " 

192 f"user_ttl={self._user_ttl}s, revocation_ttl={self._revocation_ttl}s, " 

193 f"team_ttl={self._team_ttl}s, role_ttl={self._role_ttl}s, " 

194 f"teams_list_enabled={self._teams_list_enabled}, teams_list_ttl={self._teams_list_ttl}s" 

195 ) 

196 

197 def _get_redis_key(self, key_type: str, identifier: str) -> str: 

198 """Generate Redis key with proper prefix. 

199 

200 Args: 

201 key_type: Type of cache entry (user, team, revoke, ctx) 

202 identifier: Unique identifier (email, jti, etc.) 

203 

204 Returns: 

205 Full Redis key with prefix 

206 

207 Examples: 

208 >>> cache = AuthCache() 

209 >>> cache._get_redis_key("user", "test@example.com") 

210 'mcpgw:auth:user:test@example.com' 

211 """ 

212 return f"{self._cache_prefix}auth:{key_type}:{identifier}" 

213 

214 async def _get_redis_client(self): 

215 """Get Redis client if available. 

216 

217 Returns: 

218 Redis client or None if unavailable 

219 """ 

220 try: 

221 # First-Party 

222 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

223 

224 client = await get_redis_client() 

225 if client and not self._redis_checked: 

226 self._redis_checked = True 

227 self._redis_available = True 

228 logger.debug("AuthCache: Redis client available") 

229 return client 

230 except Exception as e: 

231 if not self._redis_checked: 

232 self._redis_checked = True 

233 self._redis_available = False 

234 logger.debug(f"AuthCache: Redis unavailable, using in-memory cache: {e}") 

235 return None 

236 

237 async def get_auth_context( 

238 self, 

239 email: str, 

240 jti: Optional[str] = None, 

241 ) -> Optional[CachedAuthContext]: 

242 """Get cached authentication context. 

243 

244 Checks cache for user data, team membership, and revocation status. 

245 Returns None on cache miss. 

246 

247 Args: 

248 email: User email address 

249 jti: JWT ID for revocation check (optional) 

250 

251 Returns: 

252 CachedAuthContext if found in cache, None otherwise 

253 

254 Examples: 

255 >>> import asyncio 

256 >>> cache = AuthCache() 

257 >>> result = asyncio.run(cache.get_auth_context("test@example.com")) 

258 >>> result is None # Cache miss on fresh cache 

259 True 

260 """ 

261 if not self._enabled: 

262 return None 

263 

264 # Check for known revoked token first (fast local check) 

265 if jti and jti in self._revoked_jtis: 

266 self._hit_count += 1 

267 return CachedAuthContext(is_token_revoked=True) 

268 

269 cache_key = f"{email}:{jti or 'no-jti'}" 

270 

271 # Check L1 in-memory cache first (no network I/O) 

272 entry = self._context_cache.get(cache_key) 

273 if entry and not entry.is_expired(): 

274 self._hit_count += 1 

275 return entry.value 

276 

277 # Check L2 Redis cache 

278 redis = await self._get_redis_client() 

279 if redis: 

280 try: 

281 redis_key = self._get_redis_key("ctx", cache_key) 

282 data = await redis.get(redis_key) 

283 if data: 

284 # Third-Party 

285 import orjson # pylint: disable=import-outside-toplevel 

286 

287 cached = orjson.loads(data) 

288 result = CachedAuthContext( 

289 user=cached.get("user"), 

290 personal_team_id=cached.get("personal_team_id"), 

291 is_token_revoked=cached.get("is_token_revoked", False), 

292 ) 

293 self._hit_count += 1 

294 self._redis_hit_count += 1 

295 

296 # Write-through: populate L1 from Redis hit 

297 ttl = min(self._user_ttl, self._revocation_ttl, self._team_ttl) 

298 with self._lock: 

299 self._context_cache[cache_key] = CacheEntry( 

300 value=result, 

301 expiry=time.time() + ttl, 

302 ) 

303 

304 return result 

305 self._redis_miss_count += 1 

306 except Exception as e: 

307 logger.warning(f"AuthCache Redis get failed: {e}") 

308 

309 self._miss_count += 1 

310 return None 

311 

312 async def set_auth_context( 

313 self, 

314 email: str, 

315 jti: Optional[str], 

316 context: CachedAuthContext, 

317 ) -> None: 

318 """Store authentication context in cache. 

319 

320 Stores in both Redis (if available) and in-memory cache. 

321 

322 Args: 

323 email: User email address 

324 jti: JWT ID (optional) 

325 context: Authentication context to cache 

326 

327 Examples: 

328 >>> import asyncio 

329 >>> cache = AuthCache() 

330 >>> ctx = CachedAuthContext( 

331 ... user={"email": "test@example.com"}, 

332 ... personal_team_id="team-1", 

333 ... is_token_revoked=False 

334 ... ) 

335 >>> asyncio.run(cache.set_auth_context("test@example.com", "jti-123", ctx)) 

336 """ 

337 if not self._enabled: 

338 return 

339 

340 cache_key = f"{email}:{jti or 'no-jti'}" 

341 

342 # Use shortest TTL for combined context 

343 ttl = min(self._user_ttl, self._revocation_ttl, self._team_ttl) 

344 

345 # Prepare data for serialization 

346 data = { 

347 "user": context.user, 

348 "personal_team_id": context.personal_team_id, 

349 "is_token_revoked": context.is_token_revoked, 

350 } 

351 

352 # Store in Redis 

353 redis = await self._get_redis_client() 

354 if redis: 

355 try: 

356 # Third-Party 

357 import orjson # pylint: disable=import-outside-toplevel 

358 

359 redis_key = self._get_redis_key("ctx", cache_key) 

360 await redis.setex(redis_key, ttl, orjson.dumps(data)) 

361 except Exception as e: 

362 logger.warning(f"AuthCache Redis set failed: {e}") 

363 

364 # Store in in-memory cache 

365 with self._lock: 

366 self._context_cache[cache_key] = CacheEntry( 

367 value=context, 

368 expiry=time.time() + ttl, 

369 ) 

370 

371 async def invalidate_user(self, email: str) -> None: 

372 """Invalidate cached data for a user. 

373 

374 Call this when user data changes (password, profile, etc.). 

375 

376 Args: 

377 email: User email to invalidate 

378 

379 Examples: 

380 >>> import asyncio 

381 >>> cache = AuthCache() 

382 >>> asyncio.run(cache.invalidate_user("test@example.com")) 

383 """ 

384 logger.debug(f"AuthCache: Invalidating user cache for {email}") 

385 

386 # Clear in-memory caches 

387 with self._lock: 

388 # Clear any context cache entries for this user 

389 keys_to_remove = [k for k in self._context_cache if k.startswith(f"{email}:")] 

390 for key in keys_to_remove: 

391 self._context_cache.pop(key, None) 

392 

393 self._user_cache.pop(email, None) 

394 

395 # Clear team membership cache entries (keys are email:team_ids) 

396 team_keys_to_remove = [k for k in self._team_cache if k.startswith(f"{email}:")] 

397 for key in team_keys_to_remove: 

398 self._team_cache.pop(key, None) 

399 

400 # Clear Redis 

401 redis = await self._get_redis_client() 

402 if redis: 

403 try: 

404 # Delete user-specific keys 

405 await redis.delete( 

406 self._get_redis_key("user", email), 

407 self._get_redis_key("team", email), 

408 ) 

409 # Delete context keys (pattern match) 

410 pattern = self._get_redis_key("ctx", f"{email}:*") 

411 async for key in redis.scan_iter(match=pattern): 

412 await redis.delete(key) 

413 # Delete membership keys (pattern match) 

414 membership_pattern = self._get_redis_key("membership", f"{email}:*") 

415 async for key in redis.scan_iter(match=membership_pattern): 

416 await redis.delete(key) 

417 

418 # Publish invalidation for other workers 

419 await redis.publish("mcpgw:auth:invalidate", f"user:{email}") 

420 except Exception as e: 

421 logger.warning(f"AuthCache Redis invalidate_user failed: {e}") 

422 

423 async def invalidate_revocation(self, jti: str) -> None: 

424 """Invalidate cache for a revoked token. 

425 

426 Call this when a token is revoked. 

427 

428 Args: 

429 jti: JWT ID of revoked token 

430 

431 Examples: 

432 >>> import asyncio 

433 >>> cache = AuthCache() 

434 >>> asyncio.run(cache.invalidate_revocation("jti-123")) 

435 """ 

436 logger.debug(f"AuthCache: Invalidating revocation cache for jti={jti[:8]}...") 

437 

438 # Add to local revoked set for fast lookup 

439 with self._lock: 

440 self._revoked_jtis.add(jti) 

441 self._revocation_cache.pop(jti, None) 

442 

443 # Clear any context cache entries with this JTI 

444 keys_to_remove = [k for k in self._context_cache if k.endswith(f":{jti}")] 

445 for key in keys_to_remove: 

446 self._context_cache.pop(key, None) 

447 

448 # Update Redis 

449 redis = await self._get_redis_client() 

450 if redis: 

451 try: 

452 # Mark as revoked in Redis 

453 await redis.setex( 

454 self._get_redis_key("revoke", jti), 

455 86400, # 24 hour expiry for revocation markers 

456 "1", 

457 ) 

458 # Add to revoked tokens set 

459 await redis.sadd("mcpgw:auth:revoked_tokens", jti) 

460 

461 # Delete any cached contexts with this JTI 

462 pattern = self._get_redis_key("ctx", f"*:{jti}") 

463 async for key in redis.scan_iter(match=pattern): 

464 await redis.delete(key) 

465 

466 # Publish invalidation for other workers 

467 await redis.publish("mcpgw:auth:invalidate", f"revoke:{jti}") 

468 except Exception as e: 

469 logger.warning(f"AuthCache Redis invalidate_revocation failed: {e}") 

470 

471 async def invalidate_team(self, email: str) -> None: 

472 """Invalidate team cache for a user. 

473 

474 Call this when team membership changes. 

475 

476 Args: 

477 email: User email whose team changed 

478 

479 Examples: 

480 >>> import asyncio 

481 >>> cache = AuthCache() 

482 >>> asyncio.run(cache.invalidate_team("test@example.com")) 

483 """ 

484 logger.debug(f"AuthCache: Invalidating team cache for {email}") 

485 

486 # Clear in-memory caches 

487 with self._lock: 

488 self._team_cache.pop(email, None) 

489 # Clear context cache entries for this user 

490 keys_to_remove = [k for k in self._context_cache if k.startswith(f"{email}:")] 

491 for key in keys_to_remove: 

492 self._context_cache.pop(key, None) 

493 

494 # Clear Redis 

495 redis = await self._get_redis_client() 

496 if redis: 

497 try: 

498 await redis.delete(self._get_redis_key("team", email)) 

499 # Delete context keys 

500 pattern = self._get_redis_key("ctx", f"{email}:*") 

501 async for key in redis.scan_iter(match=pattern): 

502 await redis.delete(key) 

503 

504 # Publish invalidation 

505 await redis.publish("mcpgw:auth:invalidate", f"team:{email}") 

506 except Exception as e: 

507 logger.warning(f"AuthCache Redis invalidate_team failed: {e}") 

508 

509 async def get_user_role(self, email: str, team_id: str) -> Optional[str]: 

510 """Get cached user role in a team. 

511 

512 Returns: 

513 - None: Cache miss (caller should check DB) 

514 - "": User is not a member of the team (cached negative result) 

515 - Role string: User's role in the team (cached) 

516 

517 Args: 

518 email: User email 

519 team_id: Team ID 

520 

521 Examples: 

522 >>> import asyncio 

523 >>> cache = AuthCache() 

524 >>> result = asyncio.run(cache.get_user_role("test@example.com", "team-123")) 

525 >>> result is None # Cache miss 

526 True 

527 """ 

528 if not self._enabled: 

529 return None 

530 

531 cache_key = f"{email}:{team_id}" 

532 

533 # Check L1 in-memory cache first (no network I/O) 

534 entry = self._role_cache.get(cache_key) 

535 if entry and not entry.is_expired(): 

536 self._hit_count += 1 

537 # Return empty string for None (not a member) to distinguish from cache miss 

538 return "" if entry.value is None else entry.value 

539 

540 # Check L2 Redis cache 

541 redis = await self._get_redis_client() 

542 if redis: 

543 try: 

544 redis_key = self._get_redis_key("role", cache_key) 

545 data = await redis.get(redis_key) 

546 if data is not None: 

547 self._hit_count += 1 

548 self._redis_hit_count += 1 

549 # Role is stored as plain string, decode it 

550 decoded = data.decode() if isinstance(data, bytes) else data 

551 # Convert sentinel to empty string (user is not a member) 

552 role_value = "" if decoded == _NOT_A_MEMBER_SENTINEL else decoded 

553 

554 # Write-through: populate L1 from Redis hit 

555 # Store as None for not-a-member to match existing L1 storage format 

556 with self._lock: 

557 self._role_cache[cache_key] = CacheEntry( 

558 value=None if role_value == "" else role_value, 

559 expiry=time.time() + self._role_ttl, 

560 ) 

561 

562 return role_value 

563 self._redis_miss_count += 1 

564 except Exception as e: 

565 logger.warning(f"AuthCache Redis get_user_role failed: {e}") 

566 

567 self._miss_count += 1 

568 return None 

569 

570 async def set_user_role(self, email: str, team_id: str, role: Optional[str]) -> None: 

571 """Store user role in cache. 

572 

573 Args: 

574 email: User email 

575 team_id: Team ID 

576 role: User's role in the team (or None if not a member) 

577 

578 Examples: 

579 >>> import asyncio 

580 >>> cache = AuthCache() 

581 >>> asyncio.run(cache.set_user_role("test@example.com", "team-123", "admin")) 

582 """ 

583 if not self._enabled: 

584 return 

585 

586 cache_key = f"{email}:{team_id}" 

587 # Store None as sentinel value to distinguish "not a member" from cache miss 

588 role_value = role if role is not None else _NOT_A_MEMBER_SENTINEL 

589 

590 # Store in Redis 

591 redis = await self._get_redis_client() 

592 if redis: 

593 try: 

594 redis_key = self._get_redis_key("role", cache_key) 

595 await redis.setex(redis_key, self._role_ttl, role_value) 

596 except Exception as e: 

597 logger.warning(f"AuthCache Redis set_user_role failed: {e}") 

598 

599 # Store in in-memory cache 

600 with self._lock: 

601 self._role_cache[cache_key] = CacheEntry( 

602 value=role, 

603 expiry=time.time() + self._role_ttl, 

604 ) 

605 

606 async def invalidate_user_role(self, email: str, team_id: str) -> None: 

607 """Invalidate cached role for a user in a team. 

608 

609 Call this when a user's role changes in a team. 

610 

611 Args: 

612 email: User email 

613 team_id: Team ID 

614 

615 Examples: 

616 >>> import asyncio 

617 >>> cache = AuthCache() 

618 >>> asyncio.run(cache.invalidate_user_role("test@example.com", "team-123")) 

619 """ 

620 logger.debug(f"AuthCache: Invalidating role cache for {email} in team {team_id}") 

621 

622 cache_key = f"{email}:{team_id}" 

623 

624 # Clear in-memory cache 

625 with self._lock: 

626 self._role_cache.pop(cache_key, None) 

627 

628 # Clear Redis 

629 redis = await self._get_redis_client() 

630 if redis: 

631 try: 

632 await redis.delete(self._get_redis_key("role", cache_key)) 

633 # Publish invalidation for other workers 

634 await redis.publish("mcpgw:auth:invalidate", f"role:{email}:{team_id}") 

635 except Exception as e: 

636 logger.warning(f"AuthCache Redis invalidate_user_role failed: {e}") 

637 

638 async def invalidate_team_roles(self, team_id: str) -> None: 

639 """Invalidate all cached roles for a team. 

640 

641 Call this when team membership changes significantly (e.g., team deletion). 

642 

643 Args: 

644 team_id: Team ID 

645 

646 Examples: 

647 >>> import asyncio 

648 >>> cache = AuthCache() 

649 >>> asyncio.run(cache.invalidate_team_roles("team-123")) 

650 """ 

651 logger.debug(f"AuthCache: Invalidating all role caches for team {team_id}") 

652 

653 # Clear in-memory cache entries for this team 

654 with self._lock: 

655 keys_to_remove = [k for k in self._role_cache if k.endswith(f":{team_id}")] 

656 for key in keys_to_remove: 

657 self._role_cache.pop(key, None) 

658 

659 # Clear Redis 

660 redis = await self._get_redis_client() 

661 if redis: 

662 try: 

663 # Pattern match all role keys for this team 

664 pattern = self._get_redis_key("role", f"*:{team_id}") 

665 async for key in redis.scan_iter(match=pattern): 

666 await redis.delete(key) 

667 # Publish invalidation 

668 await redis.publish("mcpgw:auth:invalidate", f"team_roles:{team_id}") 

669 except Exception as e: 

670 logger.warning(f"AuthCache Redis invalidate_team_roles failed: {e}") 

671 

672 async def get_user_teams(self, cache_key: str) -> Optional[List[str]]: 

673 """Get cached team IDs for a user. 

674 

675 The cache_key should be in the format "email:include_personal" to 

676 distinguish between calls with different include_personal flags. 

677 

678 Returns: 

679 - None: Cache miss (caller should query DB) 

680 - Empty list: User has no teams (cached result) 

681 - List of team IDs: Cached team IDs 

682 

683 Args: 

684 cache_key: Cache key in format "email:include_personal" 

685 

686 Examples: 

687 >>> import asyncio 

688 >>> cache = AuthCache() 

689 >>> result = asyncio.run(cache.get_user_teams("test@example.com:True")) 

690 >>> result is None # Cache miss 

691 True 

692 """ 

693 if not self._enabled or not self._teams_list_enabled: 

694 return None 

695 

696 # Check L1 in-memory cache first (no network I/O) 

697 entry = self._teams_list_cache.get(cache_key) 

698 if entry and not entry.is_expired(): 

699 self._hit_count += 1 

700 return entry.value 

701 

702 # Check L2 Redis cache 

703 redis = await self._get_redis_client() 

704 if redis: 

705 try: 

706 redis_key = self._get_redis_key("teams", cache_key) 

707 data = await redis.get(redis_key) 

708 if data is not None: 

709 self._hit_count += 1 

710 self._redis_hit_count += 1 

711 # Third-Party 

712 import orjson # pylint: disable=import-outside-toplevel 

713 

714 team_ids = orjson.loads(data) 

715 

716 # Write-through: populate L1 from Redis hit 

717 with self._lock: 

718 self._teams_list_cache[cache_key] = CacheEntry( 

719 value=team_ids, 

720 expiry=time.time() + self._teams_list_ttl, 

721 ) 

722 

723 return team_ids 

724 self._redis_miss_count += 1 

725 except Exception as e: 

726 logger.warning(f"AuthCache Redis get_user_teams failed: {e}") 

727 

728 self._miss_count += 1 

729 return None 

730 

731 async def set_user_teams(self, cache_key: str, team_ids: List[str]) -> None: 

732 """Store team IDs for a user in cache. 

733 

734 Args: 

735 cache_key: Cache key in format "email:include_personal" 

736 team_ids: List of team IDs the user belongs to 

737 

738 Examples: 

739 >>> import asyncio 

740 >>> cache = AuthCache() 

741 >>> asyncio.run(cache.set_user_teams("test@example.com:True", ["team-1", "team-2"])) 

742 """ 

743 if not self._enabled or not self._teams_list_enabled: 

744 return 

745 

746 # Store in Redis 

747 redis = await self._get_redis_client() 

748 if redis: 

749 try: 

750 # Third-Party 

751 import orjson # pylint: disable=import-outside-toplevel 

752 

753 redis_key = self._get_redis_key("teams", cache_key) 

754 await redis.setex(redis_key, self._teams_list_ttl, orjson.dumps(team_ids)) 

755 except Exception as e: 

756 logger.warning(f"AuthCache Redis set_user_teams failed: {e}") 

757 

758 # Store in in-memory cache 

759 with self._lock: 

760 self._teams_list_cache[cache_key] = CacheEntry( 

761 value=team_ids, 

762 expiry=time.time() + self._teams_list_ttl, 

763 ) 

764 

765 async def invalidate_user_teams(self, email: str) -> None: 

766 """Invalidate cached teams list for a user. 

767 

768 Call this when a user's team membership changes (add/remove member, 

769 delete team, approve join request). 

770 

771 This invalidates both include_personal=True and include_personal=False 

772 cache entries for the user. 

773 

774 Args: 

775 email: User email whose teams cache should be invalidated 

776 

777 Examples: 

778 >>> import asyncio 

779 >>> cache = AuthCache() 

780 >>> asyncio.run(cache.invalidate_user_teams("test@example.com")) 

781 """ 

782 logger.debug(f"AuthCache: Invalidating teams list cache for {email}") 

783 

784 # Clear in-memory cache entries for this user (both True and False variants) 

785 with self._lock: 

786 keys_to_remove = [k for k in self._teams_list_cache if k.startswith(f"{email}:")] 

787 for key in keys_to_remove: 

788 self._teams_list_cache.pop(key, None) 

789 

790 # Clear Redis 

791 redis = await self._get_redis_client() 

792 if redis: 

793 try: 

794 # Delete both variants 

795 await redis.delete( 

796 self._get_redis_key("teams", f"{email}:True"), 

797 self._get_redis_key("teams", f"{email}:False"), 

798 ) 

799 # Publish invalidation for other workers 

800 await redis.publish("mcpgw:auth:invalidate", f"teams:{email}") 

801 except Exception as e: 

802 logger.warning(f"AuthCache Redis invalidate_user_teams failed: {e}") 

803 

804 # ========================================================================= 

805 # Team Membership Validation Cache 

806 # ========================================================================= 

807 # Used by TokenScopingMiddleware to cache email_team_members lookups. 

808 # This prevents repeated DB queries for the same user+teams combination. 

809 

810 def get_team_membership_valid_sync(self, user_email: str, team_ids: List[str]) -> Optional[bool]: 

811 """Get cached team membership validation result (synchronous). 

812 

813 This is the synchronous version used by token_scoping middleware. 

814 Returns None on cache miss (caller should check DB). 

815 

816 Args: 

817 user_email: User email 

818 team_ids: List of team IDs to validate membership for 

819 

820 Returns: 

821 - None: Cache miss (caller should query DB) 

822 - True: User is valid member of all teams (cached) 

823 - False: User is NOT a valid member of all teams (cached) 

824 

825 Examples: 

826 >>> cache = AuthCache() 

827 >>> result = cache.get_team_membership_valid_sync("test@example.com", ["team-1"]) 

828 >>> result is None # Cache miss 

829 True 

830 """ 

831 if not self._enabled or not team_ids: 

832 return None 

833 

834 # Create cache key from user + sorted team IDs 

835 sorted_teams = ":".join(sorted(team_ids)) 

836 cache_key = f"{user_email}:{sorted_teams}" 

837 

838 # Check in-memory cache only (sync version) 

839 entry = self._team_cache.get(cache_key) 

840 if entry and not entry.is_expired(): 

841 self._hit_count += 1 

842 return entry.value 

843 

844 self._miss_count += 1 

845 return None 

846 

847 def set_team_membership_valid_sync(self, user_email: str, team_ids: List[str], valid: bool) -> None: 

848 """Store team membership validation result in cache (synchronous). 

849 

850 Args: 

851 user_email: User email 

852 team_ids: List of team IDs that were validated 

853 valid: Whether user is a valid member of all teams 

854 

855 Examples: 

856 >>> cache = AuthCache() 

857 >>> cache.set_team_membership_valid_sync("test@example.com", ["team-1"], True) 

858 """ 

859 if not self._enabled or not team_ids: 

860 return 

861 

862 # Create cache key from user + sorted team IDs 

863 sorted_teams = ":".join(sorted(team_ids)) 

864 cache_key = f"{user_email}:{sorted_teams}" 

865 

866 # Store in in-memory cache 

867 with self._lock: 

868 self._team_cache[cache_key] = CacheEntry( 

869 value=valid, 

870 expiry=time.time() + self._team_ttl, 

871 ) 

872 

873 async def get_team_membership_valid(self, user_email: str, team_ids: List[str]) -> Optional[bool]: 

874 """Get cached team membership validation result (async). 

875 

876 Returns None on cache miss (caller should check DB). 

877 

878 Args: 

879 user_email: User email 

880 team_ids: List of team IDs to validate membership for 

881 

882 Returns: 

883 - None: Cache miss (caller should query DB) 

884 - True: User is valid member of all teams (cached) 

885 - False: User is NOT a valid member of all teams (cached) 

886 

887 Examples: 

888 >>> import asyncio 

889 >>> cache = AuthCache() 

890 >>> result = asyncio.run(cache.get_team_membership_valid("test@example.com", ["team-1"])) 

891 >>> result is None # Cache miss 

892 True 

893 """ 

894 if not self._enabled or not team_ids: 

895 return None 

896 

897 # Create cache key from user + sorted team IDs 

898 sorted_teams = ":".join(sorted(team_ids)) 

899 cache_key = f"{user_email}:{sorted_teams}" 

900 

901 # Check L1 in-memory cache first (no network I/O) 

902 entry = self._team_cache.get(cache_key) 

903 if entry and not entry.is_expired(): 

904 self._hit_count += 1 

905 return entry.value 

906 

907 # Check L2 Redis cache 

908 redis = await self._get_redis_client() 

909 if redis: 

910 try: 

911 redis_key = self._get_redis_key("membership", cache_key) 

912 data = await redis.get(redis_key) 

913 if data is not None: 

914 self._hit_count += 1 

915 self._redis_hit_count += 1 

916 # Stored as "1" for True, "0" for False 

917 decoded = data.decode() if isinstance(data, bytes) else data 

918 result = decoded == "1" 

919 

920 # Write-through: populate L1 from Redis hit 

921 with self._lock: 

922 self._team_cache[cache_key] = CacheEntry( 

923 value=result, 

924 expiry=time.time() + self._team_ttl, 

925 ) 

926 

927 return result 

928 self._redis_miss_count += 1 

929 except Exception as e: 

930 logger.warning(f"AuthCache Redis get_team_membership_valid failed: {e}") 

931 

932 self._miss_count += 1 

933 return None 

934 

async def set_team_membership_valid(self, user_email: str, team_ids: List[str], valid: bool) -> None:
    """Cache the outcome of a team membership validation (async).

    The result is written to Redis first (when a client is available) and
    then to the in-memory cache, both using the team TTL. Nothing is stored
    when the cache is disabled or ``team_ids`` is empty. A Redis failure is
    logged as a warning and does not prevent the in-memory write.

    Args:
        user_email: User email
        team_ids: List of team IDs that were validated
        valid: Whether user is a valid member of all teams

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.set_team_membership_valid("test@example.com", ["team-1"], True))
    """
    if not (self._enabled and team_ids):
        return

    # Key is the user followed by the team IDs in sorted order, so the same
    # set of teams always maps to the same entry.
    membership_key = f"{user_email}:{':'.join(sorted(team_ids))}"

    # Shared tier: Redis (best effort)
    client = await self._get_redis_client()
    if client:
        # Booleans are serialized as "1" (True) / "0" (False)
        flag = "1" if valid else "0"
        try:
            await client.setex(self._get_redis_key("membership", membership_key), self._team_ttl, flag)
        except Exception as exc:
            logger.warning(f"AuthCache Redis set_team_membership_valid failed: {exc}")

    # Local tier: in-memory cache
    fresh_entry = CacheEntry(value=valid, expiry=time.time() + self._team_ttl)
    with self._lock:
        self._team_cache[membership_key] = fresh_entry

971 

async def invalidate_team_membership(self, user_email: str) -> None:
    """Drop every cached team membership result for one user.

    Intended to be called whenever a user's team membership changes
    (add/remove member, role change, deactivation). In-memory entries whose
    key starts with ``"<user_email>:"`` are removed, then the matching Redis
    keys are deleted and an invalidation message is published. Redis
    failures are logged as warnings and otherwise ignored.

    Args:
        user_email: User email whose membership cache should be invalidated

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.invalidate_team_membership("test@example.com"))
    """
    logger.debug(f"AuthCache: Invalidating team membership cache for {user_email}")

    # Local tier: collect matching keys first, then remove them, so the dict
    # is never mutated while it is being iterated.
    user_prefix = f"{user_email}:"
    with self._lock:
        stale_keys = [name for name in self._team_cache if name.startswith(user_prefix)]
        for name in stale_keys:
            self._team_cache.pop(name, None)

    # Shared tier: Redis
    client = await self._get_redis_client()
    if not client:
        return

    try:
        # Glob pattern covering all membership keys of this user
        pattern = self._get_redis_key("membership", f"{user_email}:*")
        async for redis_key in client.scan_iter(match=pattern):
            await client.delete(redis_key)
        # Publish invalidation for other workers
        await client.publish("mcpgw:auth:invalidate", f"membership:{user_email}")
    except Exception as exc:
        logger.warning(f"AuthCache Redis invalidate_team_membership failed: {exc}")

1006 

async def is_token_revoked(self, jti: str) -> Optional[bool]:
    """Look up the revocation status of a token in the caches only.

    Lookup order: the local set of known-revoked JTIs, the in-memory
    revocation cache, then Redis (the revoked-tokens set followed by the
    per-token revocation key). A Redis hit is copied into both local
    structures. The database is never queried here; ``None`` tells the
    caller to check it.

    Args:
        jti: JWT ID to check

    Returns:
        True if revoked, False if not revoked, None if unknown

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> cache._revoked_jtis.add("revoked-jti")
        >>> asyncio.run(cache.is_token_revoked("revoked-jti"))
        True
    """
    if not self._enabled:
        return None

    # Cheapest check: local set of confirmed revocations
    if jti in self._revoked_jtis:
        return True

    # L1: in-memory revocation cache
    cached = self._revocation_cache.get(jti)
    if cached and not cached.is_expired():
        return cached.value

    # L2: Redis
    client = await self._get_redis_client()
    if not client:
        return None

    try:
        # Short-circuit: the per-token key is only queried when the JTI is
        # not in the shared revoked-tokens set.
        revoked = await client.sismember("mcpgw:auth:revoked_tokens", jti) or await client.exists(self._get_redis_key("revoke", jti))
        if revoked:
            with self._lock:
                # Remember locally for faster future lookups
                self._revoked_jtis.add(jti)
                # Write-through into the L1 revocation cache
                self._revocation_cache[jti] = CacheEntry(
                    value=True,
                    expiry=time.time() + self._revocation_ttl,
                )
            return True
    except Exception as exc:
        logger.warning(f"AuthCache Redis is_token_revoked failed: {exc}")

    return None

1067 

async def sync_revoked_tokens(self) -> None:
    """Load revoked token IDs from the database into the caches.

    Meant to run during application startup. All revoked JTIs are read from
    the database in a worker thread, merged into the local revoked set and,
    when Redis is available and there is at least one JTI, added to the
    shared revoked-tokens set. Any failure is logged as a warning instead of
    being raised. Does nothing when the cache is disabled.

    Examples:
        >>> import asyncio
        >>> cache = AuthCache()
        >>> asyncio.run(cache.sync_revoked_tokens())
    """
    if not self._enabled:
        return

    try:
        # Third-Party
        from sqlalchemy import select  # pylint: disable=import-outside-toplevel

        # First-Party
        from mcpgateway.db import fresh_db_session, TokenRevocation  # pylint: disable=import-outside-toplevel

        def _fetch_revoked_jtis() -> Set[str]:
            """Read every revoked JTI from the database.

            Returns:
                Set of revoked JTI strings.
            """
            with fresh_db_session() as session:
                rows = session.execute(select(TokenRevocation.jti))
                return set(row[0] for row in rows)

        # The query runs in a worker thread via asyncio.to_thread
        revoked = await asyncio.to_thread(_fetch_revoked_jtis)

        with self._lock:
            self._revoked_jtis.update(revoked)

        # Mirror into Redis (best effort)
        client = await self._get_redis_client()
        if client and revoked:
            try:
                await client.sadd("mcpgw:auth:revoked_tokens", *revoked)
            except Exception as exc:
                logger.warning(f"AuthCache Redis sync_revoked_tokens failed: {exc}")

        logger.info(f"AuthCache: Synced {len(revoked)} revoked tokens to cache")

    except Exception as exc:
        logger.warning(f"AuthCache sync_revoked_tokens failed: {exc}")

1116 

def invalidate_all(self) -> None:
    """Empty every in-memory cache held by this instance.

    Useful in tests or after a major configuration change. The set of
    confirmed revoked JTIs is intentionally preserved.

    Examples:
        >>> cache = AuthCache()
        >>> cache.invalidate_all()
    """
    with self._lock:
        # _revoked_jtis is deliberately absent: those are confirmed revocations
        for store in (
            self._user_cache,
            self._team_cache,
            self._revocation_cache,
            self._context_cache,
            self._role_cache,
            self._teams_list_cache,
        ):
            store.clear()

    logger.info("AuthCache: All caches invalidated")

1136 

def stats(self) -> Dict[str, Any]:
    """Report cache counters, sizes and configuration.

    Hit rates are computed as hits divided by hits plus misses and are
    reported as 0.0 when no lookups have been counted.

    Returns:
        Dictionary with hit/miss counts and hit rate

    Examples:
        >>> cache = AuthCache()
        >>> stats = cache.stats()
        >>> "hit_count" in stats
        True
    """
    hits, misses = self._hit_count, self._miss_count
    redis_hits, redis_misses = self._redis_hit_count, self._redis_miss_count

    lookups = hits + misses
    redis_lookups = redis_hits + redis_misses

    # Guard against division by zero before any lookup has been counted
    hit_rate = hits / lookups if lookups > 0 else 0.0
    redis_hit_rate = redis_hits / redis_lookups if redis_lookups > 0 else 0.0

    return {
        "enabled": self._enabled,
        "hit_count": hits,
        "miss_count": misses,
        "hit_rate": hit_rate,
        "redis_hit_count": redis_hits,
        "redis_miss_count": redis_misses,
        "redis_hit_rate": redis_hit_rate,
        "redis_available": self._redis_available,
        "revoked_tokens_cached": len(self._revoked_jtis),
        "context_cache_size": len(self._context_cache),
        "role_cache_size": len(self._role_cache),
        "teams_list_cache_size": len(self._teams_list_cache),
        "team_membership_cache_size": len(self._team_cache),
        "user_ttl": self._user_ttl,
        "revocation_ttl": self._revocation_ttl,
        "team_ttl": self._team_ttl,
        "role_ttl": self._role_ttl,
        "teams_list_enabled": self._teams_list_enabled,
        "teams_list_ttl": self._teams_list_ttl,
    }

1173 

def reset_stats(self) -> None:
    """Zero the overall and Redis hit/miss counters.

    Examples:
        >>> cache = AuthCache()
        >>> cache._hit_count = 100
        >>> cache.reset_stats()
        >>> cache._hit_count
        0
    """
    # Chained assignment binds the targets left to right
    self._hit_count = self._miss_count = self._redis_hit_count = self._redis_miss_count = 0

1188 

1189 

# Module-wide singleton; created on first call to get_auth_cache()
_auth_cache: Optional[AuthCache] = None


def get_auth_cache() -> AuthCache:
    """Return the shared AuthCache, creating it on first use.

    Returns:
        AuthCache: The singleton auth cache instance

    Examples:
        >>> cache = get_auth_cache()
        >>> isinstance(cache, AuthCache)
        True
    """
    global _auth_cache  # pylint: disable=global-statement
    instance = _auth_cache
    if instance is None:
        instance = AuthCache()
        _auth_cache = instance
    return instance


# Ready-made instance so callers can import `auth_cache` directly
auth_cache = get_auth_cache()