Coverage for mcpgateway / auth.py: 100%

448 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/auth.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Shared authentication utilities. 

8 

9This module provides common authentication functions that can be shared 

10across different parts of the application without creating circular imports. 

11""" 

12 

13# Standard 

14import asyncio 

15from datetime import datetime, timezone 

16import hashlib 

17import logging 

18from typing import Any, Dict, Generator, List, Never, Optional 

19import uuid 

20 

21# Third-Party 

22from fastapi import Depends, HTTPException, status 

23from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer 

24from sqlalchemy.orm import Session 

25 

26# First-Party 

27from mcpgateway.config import settings 

28from mcpgateway.db import EmailUser, fresh_db_session, SessionLocal 

29from mcpgateway.plugins.framework import get_plugin_manager, GlobalContext, HttpAuthResolveUserPayload, HttpHeaderPayload, HttpHookType, PluginViolationError 

30from mcpgateway.utils.correlation_id import get_correlation_id 

31from mcpgateway.utils.verify_credentials import verify_jwt_token_cached 

32 

# Security scheme: extracts "Authorization: Bearer <token>" credentials.
# auto_error=False makes the dependency yield None instead of raising when the
# header is missing, so get_current_user can still offer the request to
# plugin-based authentication and raise its own 401 when nothing succeeds.
security = HTTPBearer(auto_error=False)

35 

36 

def _log_auth_event(
    logger: logging.Logger,
    message: str,
    level: int = logging.INFO,
    user_id: Optional[str] = None,
    auth_method: Optional[str] = None,
    auth_success: bool = False,
    security_event: Optional[str] = None,
    security_severity: str = "low",
    **extra_context,
) -> None:
    """Emit an authentication log record enriched with structured fields.

    Every record carries the current correlation ID as ``request_id`` so an
    authentication flow can be traced end to end across log lines.

    Args:
        logger: Logger that receives the record.
        message: Human-readable log message.
        level: Numeric log level (defaults to ``logging.INFO``).
        user_id: User identifier; included only when truthy.
        auth_method: Authentication mechanism used (jwt, api_token, etc.);
            included only when truthy.
        auth_success: Whether the authentication attempt succeeded.
        security_event: Event category; falls back to ``"authentication"``.
        security_severity: Severity label (low, medium, high, critical).
        **extra_context: Additional fields, merged last so they override any
            field set above. Keys must not clash with built-in ``LogRecord``
            attributes (e.g. ``message``, ``args``) or ``logging`` raises
            ``KeyError``.
    """
    # Optional identity fields are only attached when they carry a value.
    optional_fields = {"user_id": user_id, "auth_method": auth_method}
    structured = {
        "request_id": get_correlation_id(),
        "entity_type": "auth",
        "auth_success": auth_success,
        "security_event": security_event or "authentication",
        "security_severity": security_severity,
        **{key: value for key, value in optional_fields.items() if value},
        **extra_context,
    }
    logger.log(level, message, extra=structured)

86 

87 

def _discard_broken_session(session: Session) -> None:
    """Best-effort cleanup for a session whose unit of work raised.

    Attempts a rollback first. If the rollback itself fails (for example
    because the underlying connection is gone), the session's connections are
    invalidated instead. Errors from either step are swallowed so the
    caller's original exception is what propagates.

    Args:
        session: The session to clean up.
    """
    try:
        session.rollback()
        return
    except Exception:  # nosec B110 - fall through to invalidation
        pass
    try:
        session.invalidate()
    except Exception:
        pass  # nosec B110 - Best effort cleanup on connection failure


def get_db() -> Generator[Session, Never, None]:
    """Yield a SQLAlchemy session for the duration of a request (dependency).

    On normal completion the transaction is committed, so read-only work does
    not end in an implicit rollback. If the consumer raises (or the commit
    fails), the session is rolled back, or invalidated when rollback is
    impossible, and the exception is re-raised. The session is always closed.

    Yields:
        Session: SQLAlchemy database session

    Raises:
        Exception: Re-raises any exception after rolling back the transaction.

    Examples:
        >>> db_gen = get_db()
        >>> db = next(db_gen)
        >>> hasattr(db, 'query')
        True
        >>> hasattr(db, 'close')
        True
    """
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except Exception:
        _discard_broken_session(session)
        raise
    finally:
        session.close()

123 

124 

def _get_personal_team_sync(user_email: str) -> Optional[str]:
    """Look up the ID of a user's personal team.

    Opens its own short-lived session via ``fresh_db_session`` so it can be
    executed in a worker thread without blocking the event loop.

    Args:
        user_email: The user's email address.

    Returns:
        The personal team ID, or None if the user has no personal team.
    """
    # Third-Party
    from sqlalchemy import select  # pylint: disable=import-outside-toplevel

    # First-Party
    from mcpgateway.db import EmailTeam, EmailTeamMember  # pylint: disable=import-outside-toplevel

    query = (
        select(EmailTeam)
        .join(EmailTeamMember)
        .where(
            EmailTeamMember.user_email == user_email,
            EmailTeam.is_personal.is_(True),
        )
    )
    with fresh_db_session() as db:
        team = db.execute(query).scalar_one_or_none()
        # Read the ID while the session is still open.
        return team.id if team else None

146 

147 

def _get_user_team_ids_sync(email: str) -> List[str]:
    """Return the IDs of every team the user is an active member of.

    Personal teams are included, mirroring ``user.get_teams()``, which returns
    all active memberships. A fresh DB session is used so this is safe to call
    from a thread pool.

    Args:
        email: User email address

    Returns:
        List of team ID strings
    """
    # Third-Party
    from sqlalchemy import select  # pylint: disable=import-outside-toplevel

    # First-Party
    from mcpgateway.db import EmailTeamMember  # pylint: disable=import-outside-toplevel

    active_membership = (
        EmailTeamMember.user_email == email,
        EmailTeamMember.is_active.is_(True),
    )
    with fresh_db_session() as db:
        rows = db.execute(select(EmailTeamMember.team_id).where(*active_membership))
        # Each row holds a single column: the team ID.
        return [team_id for (team_id,) in rows]

174 

175 

def get_user_team_roles(db, user_email: str) -> Dict[str, str]:
    """Return a {team_id: role} mapping for a user's active team memberships.

    Args:
        db: SQLAlchemy database session.
        user_email: Email address of the user to query memberships for.

    Returns:
        Dict mapping team_id to the user's role in that team.
        Returns empty dict on DB errors (safe default — headers stay masked).
        The failure is logged at WARNING level so it does not go unnoticed.
    """
    try:
        # First-Party
        from mcpgateway.db import EmailTeamMember  # pylint: disable=import-outside-toplevel

        rows = (
            db.query(EmailTeamMember.team_id, EmailTeamMember.role)
            .filter(EmailTeamMember.user_email == user_email, EmailTeamMember.is_active.is_(True))
            .all()
        )
    except Exception as exc:
        # Deliberate best-effort: an empty mapping is the safe default, but the
        # error is logged rather than silently swallowed.
        logging.getLogger(__name__).warning("Failed to load team roles for %s, defaulting to none: %s", user_email, exc)
        return {}
    return {row.team_id: row.role for row in rows}

195 

196 

def _resolve_teams_from_db_sync(email: str, is_admin: bool) -> Optional[List[str]]:
    """Resolve a user's team IDs from synchronous code, using the L1 cache.

    Used by StreamableHTTP transport, which runs in a sync context. Only the
    in-memory L1 layer of ``auth_cache`` is consulted (no network I/O). On a
    miss the database is queried and the result is stored back in L1. Cache
    failures of any kind are ignored.

    Args:
        email: User email address
        is_admin: Whether the user is an admin

    Returns:
        None (admin bypass), [] (no teams), or list of team ID strings
    """
    # Admins are not scoped to teams.
    if is_admin:
        return None

    key = f"{email}:True"

    # Fast path: in-process L1 cache.
    try:
        # First-Party
        from mcpgateway.cache.auth_cache import auth_cache  # pylint: disable=import-outside-toplevel

        cached = auth_cache._teams_list_cache.get(key)  # pylint: disable=protected-access
        if cached and not cached.is_expired():
            auth_cache._hit_count += 1  # pylint: disable=protected-access
            return cached.value
    except Exception:  # nosec B110 - Cache unavailable is non-fatal
        pass

    # Slow path: hit the database, then populate L1 for the next request.
    team_ids = _get_user_team_ids_sync(email)
    try:
        # Standard
        import time  # pylint: disable=import-outside-toplevel

        # First-Party
        from mcpgateway.cache.auth_cache import auth_cache, CacheEntry  # pylint: disable=import-outside-toplevel

        fresh_entry = CacheEntry(
            value=team_ids,
            expiry=time.time() + auth_cache._teams_list_ttl,  # pylint: disable=protected-access
        )
        with auth_cache._lock:  # pylint: disable=protected-access
            auth_cache._teams_list_cache[key] = fresh_entry  # pylint: disable=protected-access
    except Exception:  # nosec B110 - Cache write failure is non-fatal
        pass

    return team_ids

247 

248 

async def _resolve_teams_from_db(email: str, user_info) -> Optional[List[str]]:
    """Resolve team IDs for a session token via the auth cache or database.

    Admins short-circuit to None (admin bypass). For everyone else the auth
    cache is tried first. On a miss, the DB query runs in a worker thread
    and the result is written back to the cache. Cache errors never fail the
    request.

    Args:
        email: User email address
        user_info: User dict or EmailUser instance

    Returns:
        None (admin bypass), [] (no teams), or list of team ID strings
    """
    if isinstance(user_info, dict):
        admin = user_info.get("is_admin", False)
    else:
        admin = getattr(user_info, "is_admin", False)
    if admin:
        return None  # Admin bypass

    key = f"{email}:True"

    try:
        # First-Party
        from mcpgateway.cache.auth_cache import auth_cache  # pylint: disable=import-outside-toplevel

        cached_teams = await auth_cache.get_user_teams(key)
    except Exception:  # nosec B110 - Cache unavailable is non-fatal, fall through to DB
        cached_teams = None
    if cached_teams is not None:
        return cached_teams

    # Cache miss: blocking DB work is pushed off the event loop.
    team_ids = await asyncio.to_thread(_get_user_team_ids_sync, email)

    try:
        # First-Party
        from mcpgateway.cache.auth_cache import auth_cache  # pylint: disable=import-outside-toplevel

        await auth_cache.set_user_teams(key, team_ids)
    except Exception:  # nosec B110 - Cache write failure is non-fatal
        pass

    return team_ids

290 

291 

def normalize_token_teams(payload: Dict[str, Any]) -> Optional[List[str]]:
    """
    Normalize token teams to a canonical form for consistent security checks.

    SECURITY: This is the single source of truth for token team normalization.
    All code paths that read token teams should use this function.

    Rules:
        - "teams" key missing → [] (public-only, secure default)
        - "teams" is null + is_admin=true → None (admin bypass, sees all)
        - "teams" is null + is_admin=false → [] (public-only, no bypass for non-admins)
        - "teams" is [] → [] (explicit public-only)
        - "teams" is [...] → normalized list of string IDs
        - "teams" is any other type (string, object, number) → [] (malformed
          claim, public-only)

    The admin flag is read from top-level ``is_admin`` or, failing that, from
    ``user.is_admin``. List entries may be plain strings or dicts with an
    ``"id"`` key. Dict entries without a truthy ``"id"``, and entries of any
    other type, are dropped.

    Args:
        payload: The JWT payload dict

    Returns:
        None for admin bypass, [] for public-only, or list of normalized team ID strings

    Examples:
        >>> normalize_token_teams({})
        []
        >>> normalize_token_teams({"teams": None, "is_admin": True}) is None
        True
        >>> normalize_token_teams({"teams": None, "user": {"is_admin": True}}) is None
        True
        >>> normalize_token_teams({"teams": None})
        []
        >>> normalize_token_teams({"teams": ["t1", {"id": 2}, {"name": "x"}]})
        ['t1', '2']
        >>> normalize_token_teams({"teams": "t1"})
        []
    """
    # Check if "teams" key exists (distinguishes missing from explicit null)
    if "teams" not in payload:
        # Missing teams key → public-only (secure default)
        return []

    teams = payload["teams"]

    if teams is None:
        # Explicit null - only allow admin bypass if is_admin is true.
        # Check BOTH top-level is_admin AND nested user.is_admin.
        is_admin = payload.get("is_admin", False)
        if not is_admin:
            user_info = payload.get("user", {})
            is_admin = user_info.get("is_admin", False) if isinstance(user_info, dict) else False
        # Admin with explicit null teams → bypass; non-admin → public-only.
        return None if is_admin else []

    # SECURITY: only an array is a valid teams claim. Iterating a string would
    # yield one "team" per character, and a dict would yield its keys.
    if not isinstance(teams, (list, tuple)):
        return []

    # Handle both dict format [{"id": "team1"}] and string format ["team1"]
    normalized: List[str] = []
    for team in teams:
        if isinstance(team, dict):
            team_id = team.get("id")
            if team_id:
                normalized.append(str(team_id))
        elif isinstance(team, str):
            normalized.append(team)
    return normalized

343 

344 

async def get_team_from_token(payload: Dict[str, Any]) -> Optional[str]:
    """
    Extract the primary team ID from an authentication token payload.

    SECURITY: This function uses secure-first defaults:
    - Missing teams key = public-only (no personal team fallback)
    - Empty teams list = public-only (no team access)
    - Malformed teams claim (not a list) = public-only
    - Teams with values = use first normalized team ID

    Team claims are parsed by ``normalize_token_teams``, the single source of
    truth for token team handling. This keeps the result consistent with the
    team list used elsewhere for access checks, and prevents privilege
    escalation where missing or malformed claims could grant unintended team
    access. Admin bypass is handled separately via the ``is_admin`` flag, not
    here.

    Args:
        payload (Dict[str, Any]):
            The token payload. Expected fields:
            - "sub" (str): The user's unique identifier (email).
            - "teams" (List[str | dict], optional): Team IDs, or dicts with an "id" key.

    Returns:
        Optional[str]:
            The first team ID as a string, or `None` if teams is missing,
            null, empty, or malformed.

    Examples:
        >>> import asyncio
        >>> asyncio.run(get_team_from_token({"sub": "user@example.com", "teams": ["team_456"]}))
        'team_456'
        >>> asyncio.run(get_team_from_token({"sub": "user@example.com", "teams": [{"id": "team_789"}]}))
        'team_789'
        >>> asyncio.run(get_team_from_token({"sub": "user@example.com", "teams": []})) is None
        True
        >>> asyncio.run(get_team_from_token({"sub": "user@example.com"})) is None
        True
    """
    teams = normalize_token_teams(payload)
    # None (admin bypass) and [] (public-only) both mean "no single team".
    if not teams:
        return None
    return teams[0]

400 

401 

def _check_token_revoked_sync(jti: str) -> bool:
    """Report whether a JWT ID has an entry in the revocation table.

    Uses its own ``fresh_db_session`` so it can be dispatched to a worker
    thread without blocking the event loop.

    Args:
        jti: The JWT ID to check.

    Returns:
        True if the token is revoked, False otherwise.
    """
    # Third-Party
    from sqlalchemy import select  # pylint: disable=import-outside-toplevel

    # First-Party
    from mcpgateway.db import TokenRevocation  # pylint: disable=import-outside-toplevel

    statement = select(TokenRevocation).where(TokenRevocation.jti == jti)
    with fresh_db_session() as db:
        revocation = db.execute(statement).scalar_one_or_none()
        return revocation is not None

422 

423 

def _lookup_api_token_sync(token_hash: str) -> Optional[Dict[str, Any]]:
    """Look up an active API token by hash, validate it, and record its use.

    Opens its own session via ``fresh_db_session`` so it can run in a thread
    pool without blocking the event loop.

    Args:
        token_hash: SHA256 hash of the API token.

    Returns:
        ``None`` if no active token has this hash; ``{"expired": True}`` if it
        has expired; ``{"revoked": True}`` if its JTI has been revoked;
        otherwise ``{"user_email": ..., "jti": ...}`` after its ``last_used``
        timestamp has been updated and committed.
    """
    with fresh_db_session() as db:
        # Third-Party
        from sqlalchemy import select  # pylint: disable=import-outside-toplevel

        # First-Party
        from mcpgateway.db import EmailApiToken, TokenRevocation, utc_now  # pylint: disable=import-outside-toplevel

        result = db.execute(select(EmailApiToken).where(EmailApiToken.token_hash == token_hash, EmailApiToken.is_active.is_(True)))
        api_token = result.scalar_one_or_none()
        if not api_token:
            return None

        # Check expiration
        expires_at = api_token.expires_at
        if expires_at is not None:
            # Some backends (e.g. SQLite) may hand back naive datetimes even for
            # timezone-aware columns. Comparing naive with aware raises
            # TypeError, so naive values are assumed to be UTC (matching the
            # UTC timestamps this module writes via utc_now()).
            if expires_at.tzinfo is None:
                expires_at = expires_at.replace(tzinfo=timezone.utc)
            if expires_at < datetime.now(timezone.utc):
                return {"expired": True}

        # Check revocation
        revoke_result = db.execute(select(TokenRevocation).where(TokenRevocation.jti == api_token.jti))
        if revoke_result.scalar_one_or_none() is not None:
            return {"revoked": True}

        # Update last_used timestamp
        api_token.last_used = utc_now()
        db.commit()

        return {
            "user_email": api_token.user_email,
            "jti": api_token.jti,
        }

467 

468 

def _is_api_token_jti_sync(jti: str) -> bool:
    """Check if JTI belongs to an API token (legacy fallback) - SYNC version.

    Used for tokens created before auth_provider was added to the payload.
    Called via asyncio.to_thread() to avoid blocking the event loop.

    SECURITY: Fail-closed. If the lookup cannot be completed for any reason
    (a DB error, or the DB layer failing to import), the token is treated as
    an API token to preserve the hard-block policy.

    Args:
        jti: JWT ID to check

    Returns:
        bool: True if JTI exists in email_api_tokens table OR if lookup fails
    """
    try:
        # Imports live inside the try so an import failure also fails closed.
        # Third-Party
        from sqlalchemy import select  # pylint: disable=import-outside-toplevel

        # First-Party
        from mcpgateway.db import EmailApiToken  # pylint: disable=import-outside-toplevel

        with fresh_db_session() as db:
            result = db.execute(select(EmailApiToken.id).where(EmailApiToken.jti == jti).limit(1))
            return result.scalar_one_or_none() is not None
    except Exception as e:
        logging.getLogger(__name__).warning("Legacy API token check failed, failing closed: %s", e)
        return True  # FAIL-CLOSED: treat as API token to preserve hard-block

497 

498 

def _get_user_by_email_sync(email: str) -> Optional[EmailUser]:
    """Fetch a user by email and return a session-independent copy.

    Opens its own session via ``fresh_db_session`` so it can run in a thread
    pool without blocking the event loop.

    Args:
        email: The user's email address.

    Returns:
        EmailUser if found, None otherwise.
    """
    # Columns copied onto the detached EmailUser handed back to the caller.
    copied_fields = (
        "email",
        "password_hash",
        "full_name",
        "is_admin",
        "is_active",
        "auth_provider",
        "password_change_required",
        "email_verified_at",
        "created_at",
        "updated_at",
    )
    with fresh_db_session() as db:
        # Third-Party
        from sqlalchemy import select  # pylint: disable=import-outside-toplevel

        found = db.execute(select(EmailUser).where(EmailUser.email == email)).scalar_one_or_none()
        if not found:
            return None
        # The session closes when this function returns, so build a fresh
        # instance from the loaded attribute values instead of returning the
        # session-bound object.
        return EmailUser(**{field: getattr(found, field) for field in copied_fields})

532 

533 

def _get_auth_context_batched_sync(email: str, jti: Optional[str] = None) -> Dict[str, Any]:
    """Batched auth context lookup in a single DB session.

    Performs, within one ``fresh_db_session`` (and so one worker-thread hop
    for async callers), the lookups that would otherwise each need their own
    ``asyncio.to_thread`` call and session:

    1. User data (same columns as ``_get_user_by_email_sync``)
    2. Personal team ID (same query as ``_get_personal_team_sync``)
    3. All active team memberships (same query as ``_get_user_team_ids_sync``),
       used for session-token team resolution
    4. Token revocation status (same query as ``_check_token_revoked_sync``)

    Queries 2 and 3 run only when the user exists. Query 4 runs only when a
    ``jti`` is supplied.

    This reduces thread pool contention and DB connection overhead.

    Args:
        email: User email address
        jti: JWT ID for revocation check (optional)

    Returns:
        Dict with keys: ``user`` (detached dict of user columns, or None),
        ``personal_team_id`` (str or None), ``is_token_revoked`` (bool),
        ``team_ids`` (list of str)

    Examples:
        >>> # Requires a database; typically run via asyncio.to_thread:
        >>> # result = _get_auth_context_batched_sync("test@example.com", "jti-123")
        >>> # result["is_token_revoked"]  # False if not revoked
    """
    with fresh_db_session() as db:
        # Third-Party
        from sqlalchemy import select  # pylint: disable=import-outside-toplevel

        # First-Party
        from mcpgateway.db import EmailTeam, EmailTeamMember, TokenRevocation  # pylint: disable=import-outside-toplevel

        result = {
            "user": None,
            "personal_team_id": None,
            "is_token_revoked": False,  # nosec B105 - boolean flag, not a password
            "team_ids": [],
        }

        # Query 1: Get user data
        user_result = db.execute(select(EmailUser).where(EmailUser.email == email))
        user = user_result.scalar_one_or_none()

        if user:
            # Detach user data as dict (session will close)
            result["user"] = {
                "email": user.email,
                "password_hash": user.password_hash,
                "full_name": user.full_name,
                "is_admin": user.is_admin,
                "is_active": user.is_active,
                "auth_provider": user.auth_provider,
                "password_change_required": user.password_change_required,
                "email_verified_at": user.email_verified_at,
                "created_at": user.created_at,
                "updated_at": user.updated_at,
            }

            # Query 2: Get personal team (only if user exists)
            team_result = db.execute(
                select(EmailTeam)
                .join(EmailTeamMember)
                .where(
                    EmailTeamMember.user_email == email,
                    EmailTeam.is_personal.is_(True),
                )
            )
            personal_team = team_result.scalar_one_or_none()
            if personal_team:
                result["personal_team_id"] = personal_team.id

            # Query 3: Get all active team memberships (for session token team resolution)
            team_ids_result = db.execute(
                select(EmailTeamMember.team_id).where(
                    EmailTeamMember.user_email == email,
                    EmailTeamMember.is_active.is_(True),
                )
            )
            result["team_ids"] = [row[0] for row in team_ids_result.all()]

        # Query 4: Check token revocation (if JTI provided)
        if jti:
            revoke_result = db.execute(select(TokenRevocation).where(TokenRevocation.jti == jti))
            result["is_token_revoked"] = revoke_result.scalar_one_or_none() is not None

        return result

619 

620 

def _user_from_cached_dict(user_dict: Dict[str, Any]) -> EmailUser:
    """Rebuild a detached EmailUser from a cached user dictionary.

    ``email`` is mandatory (a missing key raises ``KeyError``). Every other
    field falls back to a default when absent from the cache entry. The two
    timestamps default to the current UTC time.

    Args:
        user_dict: User data dictionary from cache

    Returns:
        EmailUser instance (detached from any session)
    """
    fallbacks = {
        "password_hash": "",
        "full_name": None,
        "is_admin": False,
        "is_active": True,
        "auth_provider": "local",
        "password_change_required": False,
        "email_verified_at": None,
        "created_at": datetime.now(timezone.utc),
        "updated_at": datetime.now(timezone.utc),
    }
    fields = {name: user_dict.get(name, default) for name, default in fallbacks.items()}
    return EmailUser(email=user_dict["email"], **fields)

642 

643 

644async def get_current_user( 

645 credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), 

646 request: Optional[object] = None, 

647) -> EmailUser: 

648 """Get current authenticated user from JWT token with revocation checking. 

649 

650 Supports plugin-based custom authentication via HTTP_AUTH_RESOLVE_USER hook. 

651 

652 Args: 

653 credentials: HTTP authorization credentials 

654 request: Optional request object for plugin hooks 

655 

656 Returns: 

657 EmailUser: Authenticated user 

658 

659 Raises: 

660 HTTPException: If authentication fails 

661 """ 

662 logger = logging.getLogger(__name__) 

663 

664 async def _set_auth_method_from_payload(payload: dict) -> None: 

665 """Set request.state.auth_method based on JWT payload. 

666 

667 Args: 

668 payload: Decoded JWT payload 

669 """ 

670 if not request: 

671 return 

672 

673 # NOTE: Cannot use structural check (scopes dict) because email login JWTs 

674 # also have scopes dict (see email_auth.py:160) 

675 user_info = payload.get("user", {}) 

676 auth_provider = user_info.get("auth_provider") 

677 

678 if auth_provider == "api_token": 

679 request.state.auth_method = "api_token" 

680 return 

681 

682 if auth_provider: 

683 # email, oauth, saml, or any other interactive auth provider 

684 request.state.auth_method = "jwt" 

685 return 

686 

687 # Legacy API token fallback: check if JTI exists in API token table 

688 # This handles tokens created before auth_provider was added 

689 jti_for_check = payload.get("jti") 

690 if jti_for_check: 

691 is_legacy_api_token = await asyncio.to_thread(_is_api_token_jti_sync, jti_for_check) 

692 if is_legacy_api_token: 

693 request.state.auth_method = "api_token" 

694 logger.debug(f"Legacy API token detected via DB lookup (JTI: ...{jti_for_check[-8:]})") 

695 else: 

696 request.state.auth_method = "jwt" 

697 else: 

698 # No auth_provider or JTI; default to interactive 

699 request.state.auth_method = "jwt" 

700 

701 # NEW: Custom authentication hook - allows plugins to provide alternative auth 

702 # This hook is invoked BEFORE standard JWT/API token validation 

703 try: 

704 # Get plugin manager singleton 

705 plugin_manager = get_plugin_manager() 

706 

707 if plugin_manager and plugin_manager.has_hooks_for(HttpHookType.HTTP_AUTH_RESOLVE_USER): 

708 # Extract client information 

709 client_host = None 

710 client_port = None 

711 if request and hasattr(request, "client") and request.client: 

712 client_host = request.client.host 

713 client_port = request.client.port 

714 

715 # Serialize credentials for plugin 

716 credentials_dict = None 

717 if credentials: 

718 credentials_dict = { 

719 "scheme": credentials.scheme, 

720 "credentials": credentials.credentials, 

721 } 

722 

723 # Extract headers from request 

724 # Note: Middleware modifies request.scope["headers"], so request.headers 

725 # will automatically reflect any modifications made by HTTP_PRE_REQUEST hooks 

726 headers = {} 

727 if request and hasattr(request, "headers"): 

728 headers = dict(request.headers) 

729 

730 # Get request ID from correlation ID context (set by CorrelationIDMiddleware) 

731 request_id = get_correlation_id() 

732 if not request_id: 

733 # Fallback chain for safety 

734 if request and hasattr(request, "state") and hasattr(request.state, "request_id"): 

735 request_id = request.state.request_id 

736 else: 

737 request_id = uuid.uuid4().hex 

738 logger.debug(f"Generated fallback request ID in get_current_user: {request_id}") 

739 

740 # Get plugin contexts from request state if available 

741 global_context = getattr(request.state, "plugin_global_context", None) if request else None 

742 if not global_context: 

743 # Create global context 

744 global_context = GlobalContext( 

745 request_id=request_id, 

746 server_id=None, 

747 tenant_id=None, 

748 ) 

749 

750 context_table = getattr(request.state, "plugin_context_table", None) if request else None 

751 

752 # Invoke custom auth resolution hook 

753 # violations_as_exceptions=True so PluginViolationError is raised for explicit denials 

754 auth_result, context_table_result = await plugin_manager.invoke_hook( 

755 HttpHookType.HTTP_AUTH_RESOLVE_USER, 

756 payload=HttpAuthResolveUserPayload( 

757 credentials=credentials_dict, 

758 headers=HttpHeaderPayload(root=headers), 

759 client_host=client_host, 

760 client_port=client_port, 

761 ), 

762 global_context=global_context, 

763 local_contexts=context_table, 

764 violations_as_exceptions=True, # Raise PluginViolationError for auth denials 

765 ) 

766 

767 # If plugin successfully authenticated user, return it 

768 if auth_result.modified_payload and isinstance(auth_result.modified_payload, dict): 

769 logger.info("User authenticated via plugin hook") 

770 # Create EmailUser from dict returned by plugin 

771 user_dict = auth_result.modified_payload 

772 user = EmailUser( 

773 email=user_dict.get("email"), 

774 password_hash=user_dict.get("password_hash", ""), 

775 full_name=user_dict.get("full_name"), 

776 is_admin=user_dict.get("is_admin", False), 

777 is_active=user_dict.get("is_active", True), 

778 auth_provider=user_dict.get("auth_provider", "local"), 

779 password_change_required=user_dict.get("password_change_required", False), 

780 email_verified_at=user_dict.get("email_verified_at"), 

781 created_at=user_dict.get("created_at", datetime.now(timezone.utc)), 

782 updated_at=user_dict.get("updated_at", datetime.now(timezone.utc)), 

783 ) 

784 

785 # Store auth_method in request.state so it can be accessed by RBAC middleware 

786 if request and auth_result.metadata: 

787 auth_method = auth_result.metadata.get("auth_method") 

788 if auth_method: 

789 request.state.auth_method = auth_method 

790 logger.debug(f"Stored auth_method '{auth_method}' in request.state") 

791 

792 if request and context_table_result: 

793 request.state.plugin_context_table = context_table_result 

794 

795 if request and global_context: 

796 request.state.plugin_global_context = global_context 

797 

798 if plugin_manager and plugin_manager.config.plugin_settings.include_user_info: 

799 _inject_userinfo_instate(request, user) 

800 

801 return user 

802 # If continue_processing=True (no payload), fall through to standard auth 

803 

804 except PluginViolationError as e: 

805 # Plugin explicitly denied authentication with custom message 

806 logger.warning(f"Authentication denied by plugin: {e.message}") 

807 raise HTTPException( 

808 status_code=status.HTTP_401_UNAUTHORIZED, 

809 detail=e.message, # Use plugin's custom error message 

810 headers={"WWW-Authenticate": "Bearer"}, 

811 ) 

812 except HTTPException: 

813 # Re-raise HTTP exceptions 

814 raise 

815 except Exception as e: 

816 # Log but don't fail on plugin errors - fall back to standard auth 

817 logger.warning(f"HTTP_AUTH_RESOLVE_USER hook failed, falling back to standard auth: {e}") 

818 

819 # EXISTING: Standard authentication (JWT, API tokens) 

820 if not credentials: 

821 logger.warning("No credentials provided") 

822 raise HTTPException( 

823 status_code=status.HTTP_401_UNAUTHORIZED, 

824 detail="Authentication required", 

825 headers={"WWW-Authenticate": "Bearer"}, 

826 ) 

827 

828 logger.debug("Attempting authentication with token: %s...", credentials.credentials[:20]) 

829 email = None 

830 

831 try: 

832 # Try JWT token first using the centralized verify_jwt_token_cached function 

833 logger.debug("Attempting JWT token validation") 

834 payload = await verify_jwt_token_cached(credentials.credentials, request) 

835 

836 logger.debug("JWT token validated successfully") 

837 # Extract user identifier (support both new and legacy token formats) 

838 email = payload.get("sub") 

839 if email is None: 

840 # Try legacy format 

841 email = payload.get("email") 

842 

843 if email is None: 

844 logger.debug("No email/sub found in JWT payload") 

845 raise HTTPException( 

846 status_code=status.HTTP_401_UNAUTHORIZED, 

847 detail="Invalid token", 

848 headers={"WWW-Authenticate": "Bearer"}, 

849 ) 

850 

851 logger.debug("JWT authentication successful for email: %s", email) 

852 

853 # Extract JTI for revocation check 

854 jti = payload.get("jti") 

855 

856 # === AUTH CACHING: Check cache before DB queries === 

857 if settings.auth_cache_enabled: 

858 try: 

859 # First-Party 

860 from mcpgateway.cache.auth_cache import auth_cache, CachedAuthContext # pylint: disable=import-outside-toplevel 

861 

862 cached_ctx = await auth_cache.get_auth_context(email, jti) 

863 if cached_ctx: 

864 logger.debug(f"Auth cache hit for {email}") 

865 

866 # Check revocation from cache 

867 if cached_ctx.is_token_revoked: 

868 raise HTTPException( 

869 status_code=status.HTTP_401_UNAUTHORIZED, 

870 detail="Token has been revoked", 

871 headers={"WWW-Authenticate": "Bearer"}, 

872 ) 

873 

874 # Check user active status from cache 

875 if cached_ctx.user and not cached_ctx.user.get("is_active", True): 

876 raise HTTPException( 

877 status_code=status.HTTP_401_UNAUTHORIZED, 

878 detail="Account disabled", 

879 headers={"WWW-Authenticate": "Bearer"}, 

880 ) 

881 

882 # Resolve teams based on token_use 

883 if request: 

884 token_use = payload.get("token_use") 

885 request.state.token_use = token_use 

886 

887 if token_use == "session": # nosec B105 - Not a password; token_use is a JWT claim type 

888 # Session token: resolve teams from DB/cache 

889 user_info = cached_ctx.user or {"is_admin": False} 

890 teams = await _resolve_teams_from_db(email, user_info) 

891 else: 

892 # API token or legacy: use embedded teams 

893 teams = normalize_token_teams(payload) 

894 

895 request.state.token_teams = teams 

896 

897 # Set team_id: only for single-team API tokens 

898 if teams is None: 

899 request.state.team_id = None 

900 elif len(teams) == 1 and token_use != "session": # nosec B105 

901 request.state.team_id = teams[0] if isinstance(teams[0], str) else teams[0].get("id") 

902 else: 

903 request.state.team_id = None 

904 

905 await _set_auth_method_from_payload(payload) 

906 

907 # Return user from cache 

908 if cached_ctx.user: 

909 # When require_user_in_db is enabled, verify user still exists in DB 

910 # This prevents stale cache from bypassing strict mode 

911 if settings.require_user_in_db: 

912 db_user = await asyncio.to_thread(_get_user_by_email_sync, email) 

913 if db_user is None: 

914 logger.warning( 

915 f"Authentication rejected for {email}: cached user not found in database. " "REQUIRE_USER_IN_DB is enabled.", 

916 extra={"security_event": "user_not_in_db_rejected", "user_id": email}, 

917 ) 

918 raise HTTPException( 

919 status_code=status.HTTP_401_UNAUTHORIZED, 

920 detail="User not found in database", 

921 headers={"WWW-Authenticate": "Bearer"}, 

922 ) 

923 

924 if plugin_manager and plugin_manager.config.plugin_settings.include_user_info: 

925 _inject_userinfo_instate(request, _user_from_cached_dict(cached_ctx.user)) 

926 

927 return _user_from_cached_dict(cached_ctx.user) 

928 

929 # User not in cache but context was (shouldn't happen, but handle it) 

930 logger.debug("Auth context cached but user missing, falling through to DB") 

931 

932 except HTTPException: 

933 raise 

934 except Exception as cache_error: 

935 logger.debug(f"Auth cache check failed, falling through to DB: {cache_error}") 

936 

937 # === BATCHED QUERIES: Single DB call for user + team + revocation === 

938 if settings.auth_cache_batch_queries: 

939 try: 

940 auth_ctx = await asyncio.to_thread(_get_auth_context_batched_sync, email, jti) 

941 

942 # Check revocation 

943 if auth_ctx.get("is_token_revoked"): 

944 raise HTTPException( 

945 status_code=status.HTTP_401_UNAUTHORIZED, 

946 detail="Token has been revoked", 

947 headers={"WWW-Authenticate": "Bearer"}, 

948 ) 

949 

950 # Resolve teams based on token_use 

951 token_use = payload.get("token_use") 

952 if token_use == "session": # nosec B105 - Not a password; token_use is a JWT claim type 

953 # Session token: use team_ids from batched query 

954 user_dict = auth_ctx.get("user") 

955 is_admin = user_dict.get("is_admin", False) if user_dict else False 

956 if is_admin: 

957 teams = None # Admin bypass 

958 else: 

959 teams = auth_ctx.get("team_ids", []) 

960 else: 

961 # API token or legacy: use embedded teams 

962 teams = normalize_token_teams(payload) 

963 

964 # Set team_id: only for single-team API tokens 

965 if teams is None: 

966 team_id = None 

967 elif len(teams) == 1 and token_use != "session": # nosec B105 

968 team_id = teams[0] if isinstance(teams[0], str) else teams[0].get("id") 

969 else: 

970 team_id = None 

971 

972 if request: 

973 request.state.token_teams = teams 

974 request.state.team_id = team_id 

975 request.state.token_use = token_use 

976 await _set_auth_method_from_payload(payload) 

977 

978 # Store in cache for future requests 

979 if settings.auth_cache_enabled: 

980 try: 

981 # First-Party 

982 from mcpgateway.cache.auth_cache import auth_cache, CachedAuthContext # noqa: F811 pylint: disable=import-outside-toplevel 

983 

984 await auth_cache.set_auth_context( 

985 email, 

986 jti, 

987 CachedAuthContext( 

988 user=auth_ctx.get("user"), 

989 personal_team_id=auth_ctx.get("personal_team_id"), 

990 is_token_revoked=auth_ctx.get("is_token_revoked", False), 

991 ), 

992 ) 

993 # Also populate teams-list cache so cached-path requests 

994 # don't need an extra DB query via _resolve_teams_from_db() 

995 if token_use == "session" and teams is not None: # nosec B105 

996 await auth_cache.set_user_teams(f"{email}:True", teams) 

997 except Exception as cache_set_error: 

998 logger.debug(f"Failed to cache auth context: {cache_set_error}") 

999 

1000 # Create user from batched result 

1001 if auth_ctx.get("user"): 

1002 user_dict = auth_ctx["user"] 

1003 if not user_dict.get("is_active", True): 

1004 raise HTTPException( 

1005 status_code=status.HTTP_401_UNAUTHORIZED, 

1006 detail="Account disabled", 

1007 headers={"WWW-Authenticate": "Bearer"}, 

1008 ) 

1009 # Store user for return at end of function 

1010 # We'll check platform admin case and return below 

1011 _batched_user = _user_from_cached_dict(user_dict) 

1012 else: 

1013 _batched_user = None 

1014 

1015 # Handle user not found case 

1016 if _batched_user is None: 

1017 # Check if strict user-in-DB mode is enabled 

1018 if settings.require_user_in_db: 

1019 logger.warning( 

1020 f"Authentication rejected for {email}: user not found in database. " "REQUIRE_USER_IN_DB is enabled.", 

1021 extra={"security_event": "user_not_in_db_rejected", "user_id": email}, 

1022 ) 

1023 raise HTTPException( 

1024 status_code=status.HTTP_401_UNAUTHORIZED, 

1025 detail="User not found in database", 

1026 headers={"WWW-Authenticate": "Bearer"}, 

1027 ) 

1028 

1029 # Platform admin bootstrap (only when REQUIRE_USER_IN_DB=false) 

1030 if email == getattr(settings, "platform_admin_email", "admin@example.com"): 

1031 logger.info( 

1032 f"Platform admin bootstrap authentication for {email}. " "User authenticated via platform admin configuration.", 

1033 extra={"security_event": "platform_admin_bootstrap", "user_id": email}, 

1034 ) 

1035 _batched_user = EmailUser( 

1036 email=email, 

1037 password_hash="", # nosec B106 

1038 full_name=getattr(settings, "platform_admin_full_name", "Platform Administrator"), 

1039 is_admin=True, 

1040 is_active=True, 

1041 auth_provider="local", 

1042 password_change_required=False, 

1043 email_verified_at=datetime.now(timezone.utc), 

1044 created_at=datetime.now(timezone.utc), 

1045 updated_at=datetime.now(timezone.utc), 

1046 ) 

1047 else: 

1048 raise HTTPException( 

1049 status_code=status.HTTP_401_UNAUTHORIZED, 

1050 detail="User not found", 

1051 headers={"WWW-Authenticate": "Bearer"}, 

1052 ) 

1053 

1054 if plugin_manager and plugin_manager.config.plugin_settings.include_user_info: 

1055 _inject_userinfo_instate(request, _batched_user) 

1056 

1057 return _batched_user 

1058 

1059 except HTTPException: 

1060 raise 

1061 except Exception as batch_error: 

1062 logger.warning(f"Batched auth query failed, falling back to individual queries: {batch_error}") 

1063 

1064 # === FALLBACK: Original individual queries (if batching disabled or failed) === 

1065 if jti: 

1066 try: 

1067 is_revoked = await asyncio.to_thread(_check_token_revoked_sync, jti) 

1068 if is_revoked: 

1069 raise HTTPException( 

1070 status_code=status.HTTP_401_UNAUTHORIZED, 

1071 detail="Token has been revoked", 

1072 headers={"WWW-Authenticate": "Bearer"}, 

1073 ) 

1074 except HTTPException: 

1075 raise 

1076 except Exception as revoke_check_error: 

1077 # Log the error but don't fail authentication for admin tokens 

1078 logger.warning(f"Token revocation check failed for JTI {jti}: {revoke_check_error}") 

1079 

1080 # Resolve teams based on token_use 

1081 token_use = payload.get("token_use") 

1082 if token_use == "session": # nosec B105 - Not a password; token_use is a JWT claim type 

1083 # Session token: resolve teams from DB/cache (fallback path — separate query OK) 

1084 user_info = {"is_admin": payload.get("is_admin", False) or payload.get("user", {}).get("is_admin", False)} 

1085 normalized_teams = await _resolve_teams_from_db(email, user_info) 

1086 else: 

1087 # API token or legacy: use embedded teams 

1088 normalized_teams = normalize_token_teams(payload) 

1089 

1090 # Set team_id: only for single-team API tokens 

1091 if normalized_teams is None: 

1092 team_id = None 

1093 elif len(normalized_teams) == 1 and token_use != "session": # nosec B105 

1094 team_id = normalized_teams[0] if isinstance(normalized_teams[0], str) else normalized_teams[0].get("id") 

1095 else: 

1096 team_id = None 

1097 

1098 if request: 

1099 request.state.token_teams = normalized_teams 

1100 request.state.team_id = team_id 

1101 request.state.token_use = token_use 

1102 await _set_auth_method_from_payload(payload) 

1103 

1104 except HTTPException: 

1105 # Re-raise HTTPException from verify_jwt_token (handles expired/invalid tokens) 

1106 raise 

1107 except Exception as jwt_error: 

1108 # JWT validation failed, try database API token 

1109 # Uses fresh DB session via asyncio.to_thread to avoid blocking event loop 

1110 logger.debug("JWT validation failed with error: %s, trying database API token", jwt_error) 

1111 try: 

1112 token_hash = hashlib.sha256(credentials.credentials.encode()).hexdigest() 

1113 logger.debug("Generated token hash: %s", token_hash) 

1114 

1115 # Lookup API token using fresh session in thread pool 

1116 api_token_info = await asyncio.to_thread(_lookup_api_token_sync, token_hash) 

1117 logger.debug(f"Database lookup result: {api_token_info is not None}") 

1118 

1119 if api_token_info: 

1120 # Check for error conditions returned by helper 

1121 if api_token_info.get("expired"): 

1122 raise HTTPException( 

1123 status_code=status.HTTP_401_UNAUTHORIZED, 

1124 detail="API token expired", 

1125 headers={"WWW-Authenticate": "Bearer"}, 

1126 ) 

1127 

1128 if api_token_info.get("revoked"): 

1129 raise HTTPException( 

1130 status_code=status.HTTP_401_UNAUTHORIZED, 

1131 detail="API token has been revoked", 

1132 headers={"WWW-Authenticate": "Bearer"}, 

1133 ) 

1134 

1135 # Use the email from the API token 

1136 email = api_token_info["user_email"] 

1137 logger.debug(f"API token authentication successful for email: {email}") 

1138 

1139 # Set auth_method for database API tokens 

1140 if request: 

1141 request.state.auth_method = "api_token" 

1142 else: 

1143 logger.debug("API token not found in database") 

1144 logger.debug("No valid authentication method found") 

1145 # Neither JWT nor API token worked 

1146 raise HTTPException( 

1147 status_code=status.HTTP_401_UNAUTHORIZED, 

1148 detail="Invalid authentication credentials", 

1149 headers={"WWW-Authenticate": "Bearer"}, 

1150 ) 

1151 except HTTPException: 

1152 # Re-raise HTTP exceptions 

1153 raise 

1154 except Exception as e: 

1155 # Neither JWT nor API token validation worked 

1156 logger.debug(f"Database API token validation failed with exception: {e}") 

1157 raise HTTPException( 

1158 status_code=status.HTTP_401_UNAUTHORIZED, 

1159 detail="Invalid authentication credentials", 

1160 headers={"WWW-Authenticate": "Bearer"}, 

1161 ) 

1162 

1163 # Get user from database using fresh session in thread pool 

1164 user = await asyncio.to_thread(_get_user_by_email_sync, email) 

1165 

1166 if user is None: 

1167 # Check if strict user-in-DB mode is enabled 

1168 if settings.require_user_in_db: 

1169 logger.warning( 

1170 f"Authentication rejected for {email}: user not found in database. " "REQUIRE_USER_IN_DB is enabled.", 

1171 extra={"security_event": "user_not_in_db_rejected", "user_id": email}, 

1172 ) 

1173 raise HTTPException( 

1174 status_code=status.HTTP_401_UNAUTHORIZED, 

1175 detail="User not found in database", 

1176 headers={"WWW-Authenticate": "Bearer"}, 

1177 ) 

1178 

1179 # Platform admin bootstrap (only when REQUIRE_USER_IN_DB=false) 

1180 # If user doesn't exist but token is valid and email matches platform admin, 

1181 # create a virtual admin user object 

1182 if email == getattr(settings, "platform_admin_email", "admin@example.com"): 

1183 logger.info( 

1184 f"Platform admin bootstrap authentication for {email}. " "User authenticated via platform admin configuration.", 

1185 extra={"security_event": "platform_admin_bootstrap", "user_id": email}, 

1186 ) 

1187 # Create a virtual admin user for authentication purposes 

1188 user = EmailUser( 

1189 email=email, 

1190 password_hash="", # nosec B106 - Not used for JWT authentication 

1191 full_name=getattr(settings, "platform_admin_full_name", "Platform Administrator"), 

1192 is_admin=True, 

1193 is_active=True, 

1194 auth_provider="local", 

1195 password_change_required=False, 

1196 email_verified_at=datetime.now(timezone.utc), 

1197 created_at=datetime.now(timezone.utc), 

1198 updated_at=datetime.now(timezone.utc), 

1199 ) 

1200 else: 

1201 raise HTTPException( 

1202 status_code=status.HTTP_401_UNAUTHORIZED, 

1203 detail="User not found", 

1204 headers={"WWW-Authenticate": "Bearer"}, 

1205 ) 

1206 

1207 if not user.is_active: 

1208 raise HTTPException( 

1209 status_code=status.HTTP_401_UNAUTHORIZED, 

1210 detail="Account disabled", 

1211 headers={"WWW-Authenticate": "Bearer"}, 

1212 ) 

1213 

1214 if plugin_manager and plugin_manager.config.plugin_settings.include_user_info: 

1215 _inject_userinfo_instate(request, user) 

1216 

1217 return user 

1218 

1219 

def _inject_userinfo_instate(request: Optional[object] = None, user: Optional[EmailUser] = None) -> None:
    """Copy basic user identity into the request's plugin global context.

    Stores the user's ``email``, ``is_admin`` and ``full_name`` in
    ``global_context.user`` and attaches the context to
    ``request.state.plugin_global_context``. It reuses the context already on the
    request state when present; otherwise it creates a new ``GlobalContext``.

    This function does not consult plugin configuration itself. Callers in this
    module only invoke it when ``plugin_settings.include_user_info`` is enabled.

    Args:
        request: Optional request object whose ``state`` carries the plugin global
            context. When ``None``, or when it has no usable ``state``, the context
            is built but not persisted anywhere.
        user: Authenticated user whose details are copied into the context. When
            falsy, no user details are written.
    """
    logger = logging.getLogger(__name__)

    # Resolve request.state once. The request is typed as a plain object, so it
    # may lack a ``state`` attribute (or carry ``state = None``). Every access
    # below must go through this guarded value instead of ``request.state``.
    state = getattr(request, "state", None) if request else None

    # Get request ID from correlation ID context (set by CorrelationIDMiddleware)
    request_id = get_correlation_id()
    if not request_id:
        # Fallback chain for safety: request-scoped ID first, then a random one
        if state is not None and hasattr(state, "request_id"):
            request_id = state.request_id
        else:
            request_id = uuid.uuid4().hex
            logger.debug("Generated fallback request ID in get_current_user: %s", request_id)

    # Reuse the plugin context already attached to the request, if any
    global_context = getattr(state, "plugin_global_context", None)
    if not global_context:
        global_context = GlobalContext(
            request_id=request_id,
            server_id=None,
            tenant_id=None,
        )

    if user:
        if not global_context.user:
            global_context.user = {}
        global_context.user["email"] = user.email
        global_context.user["is_admin"] = user.is_admin
        global_context.user["full_name"] = user.full_name

    # Persist only when there is a state object to hold the context
    if state is not None:
        state.plugin_global_context = global_context