Coverage for mcpgateway / services / oauth_manager.py: 99%
651 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/oauth_manager.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7OAuth 2.0 Manager for ContextForge.
9This module handles OAuth 2.0 authentication flows including:
10- Client Credentials (Machine-to-Machine)
11- Authorization Code (User Delegation)
12"""
14# Standard
15import asyncio
16import base64
17from datetime import datetime, timedelta, timezone
18import hashlib
19import logging
20import secrets
21from typing import Any, Dict, Optional
22from urllib.parse import urlparse
24# Third-Party
25import httpx
26import orjson
27from requests_oauthlib import OAuth2Session
29# First-Party
30from mcpgateway.common.validators import SecurityValidator
31from mcpgateway.config import get_settings
32from mcpgateway.services.encryption_service import decrypt_oauth_config_for_runtime, get_encryption_service
33from mcpgateway.services.http_client_service import get_http_client
34from mcpgateway.utils.redis_client import get_redis_client as _get_shared_redis_client
logger = logging.getLogger(__name__)

# In-process fallback store for pending OAuth states (single-process deployments).
# Keyed by "oauth:state:{gateway_id}:{state}"; each value holds "state", "gateway_id",
# "code_verifier", "app_user_email", "expires_at" (ISO-8601 string) and "used".
_oauth_states: Dict[str, Dict[str, Any]] = {}

# Reverse index for callback handlers that only receive ``state``: state -> gateway_id.
_oauth_state_lookup: Dict[str, str] = {}

# Guards every read/write of the two in-memory maps above.
_state_lock: asyncio.Lock = asyncio.Lock()

# Lifetime of a pending authorization state: five minutes, expressed in seconds.
STATE_TTL_SECONDS: int = 5 * 60

# Cached handle to the shared Redis client, resolved lazily by ``_get_redis_client``.
_redis_client: Optional[Any] = None
_REDIS_INITIALIZED: bool = False
async def _get_redis_client():
    """Return the process-wide Redis client used for distributed OAuth state.

    Resolution happens at most once per process through the centralized Redis
    client factory. Every later call returns the cached outcome, which is
    ``None`` when caching is not configured for Redis or the factory failed.

    Returns:
        Redis client instance or None if unavailable
    """
    global _redis_client, _REDIS_INITIALIZED  # pylint: disable=global-statement

    if not _REDIS_INITIALIZED:
        settings = get_settings()
        resolved = None
        if settings.cache_type == "redis" and settings.redis_url:
            try:
                resolved = await _get_shared_redis_client()
            except Exception as e:
                logger.warning(f"Failed to get Redis client, falling back to in-memory storage: {e}")
                resolved = None
            else:
                if resolved:
                    logger.info("OAuth manager connected to shared Redis client")
        _redis_client = resolved
        # Mark as resolved even on failure so later calls do not retry.
        _REDIS_INITIALIZED = True

    return _redis_client
84class OAuthManager:
85 """Manages OAuth 2.0 authentication flows.
87 Examples:
88 >>> manager = OAuthManager(request_timeout=30, max_retries=3)
89 >>> manager.request_timeout
90 30
91 >>> manager.max_retries
92 3
93 >>> manager.token_storage is None
94 True
95 >>>
96 >>> # Test grant type validation
97 >>> grant_type = "client_credentials"
98 >>> grant_type in ["client_credentials", "authorization_code"]
99 True
100 >>> grant_type = "invalid_grant"
101 >>> grant_type in ["client_credentials", "authorization_code"]
102 False
103 >>>
104 >>> # Test encrypted secret detection heuristic
105 >>> short_secret = "secret123"
106 >>> len(short_secret) > 50
107 False
108 >>> encrypted_secret = "gAAAAABh" + "x" * 60 # Simulated encrypted secret
109 >>> len(encrypted_secret) > 50
110 True
111 >>>
112 >>> # Test scope list handling
113 >>> scopes = ["read", "write"]
114 >>> " ".join(scopes)
115 'read write'
116 >>> empty_scopes = []
117 >>> " ".join(empty_scopes)
118 ''
119 """
121 # Known Microsoft Entra login hosts (global + sovereign clouds).
122 _ENTRA_HOSTS: frozenset[str] = frozenset(
123 {
124 "login.microsoftonline.com",
125 "login.microsoftonline.us",
126 "login.microsoftonline.de",
127 "login.partner.microsoftonline.cn",
128 }
129 )
def __init__(self, request_timeout: int = 30, max_retries: int = 3, token_storage: Optional[Any] = None):
    """Configure HTTP behaviour and optional token persistence for the manager.

    Args:
        request_timeout: Timeout, in seconds, passed to each token endpoint request
        max_retries: Number of attempts made for each token request before giving up
        token_storage: Optional TokenStorageService for storing tokens
    """
    self.token_storage = token_storage
    self.max_retries = max_retries
    self.request_timeout = request_timeout
    self.settings = get_settings()
async def _get_client(self) -> httpx.AsyncClient:
    """Return the shared HTTP client from the gateway's HTTP client service.

    Returns:
        Shared httpx.AsyncClient instance with connection pooling
    """
    shared_client = await get_http_client()
    return shared_client
def _generate_pkce_params(self) -> Dict[str, str]:
    """Create a fresh PKCE verifier/challenge pair (RFC 7636, S256 method).

    Returns:
        Dict containing code_verifier, code_challenge, and code_challenge_method
    """

    def _b64url_nopad(raw: bytes) -> str:
        # RFC 7636 uses base64url without trailing "=" padding.
        return base64.urlsafe_b64encode(raw).decode("utf-8").rstrip("=")

    # 32 random bytes encode to a 43-character verifier (the RFC minimum length).
    verifier = _b64url_nopad(secrets.token_bytes(32))
    challenge = _b64url_nopad(hashlib.sha256(verifier.encode("utf-8")).digest())

    return {
        "code_verifier": verifier,
        "code_challenge": challenge,
        "code_challenge_method": "S256",
    }
async def get_access_token(self, credentials: Dict[str, Any]) -> str:
    """Obtain an access token using the grant type named in ``credentials``.

    ``client_credentials`` and ``password`` grants are fulfilled directly.
    ``authorization_code`` needs interactive user consent, so it is rejected
    here. Any other grant type is unsupported.

    Args:
        credentials: OAuth configuration containing grant_type and other params

    Returns:
        Access token string

    Raises:
        ValueError: If grant type is unsupported
        OAuthError: If token acquisition fails, or if the grant type is
            authorization_code

    Examples:
        Client credentials flow:
        >>> import asyncio
        >>> class TestMgr(OAuthManager):
        ...     async def _client_credentials_flow(self, credentials):
        ...         return 'tok'
        >>> mgr = TestMgr()
        >>> asyncio.run(mgr.get_access_token({'grant_type': 'client_credentials'}))
        'tok'

        Authorization code flow requires interactive completion:
        >>> def _auth_code_requires_consent():
        ...     try:
        ...         asyncio.run(mgr.get_access_token({'grant_type': 'authorization_code'}))
        ...     except OAuthError:
        ...         return True
        >>> _auth_code_requires_consent()
        True

        Unsupported grant type raises ValueError:
        >>> def _unsupported():
        ...     try:
        ...         asyncio.run(mgr.get_access_token({'grant_type': 'bad'}))
        ...     except ValueError:
        ...         return True
        >>> _unsupported()
        True
    """
    grant_type = credentials.get("grant_type")
    logger.debug(f"Getting access token for grant type: {grant_type}")

    if grant_type == "authorization_code":
        raise OAuthError("Authorization code flow requires user consent via /oauth/authorize and does not support client_credentials fallback")

    if grant_type == "client_credentials":
        token = await self._client_credentials_flow(credentials)
    elif grant_type == "password":
        token = await self._password_flow(credentials)
    else:
        raise ValueError(f"Unsupported grant type: {grant_type}")
    return token
@staticmethod
async def _prepare_runtime_credentials(credentials: Dict[str, Any], flow_name: str) -> Dict[str, Any]:
    """Decrypt the sensitive fields of a stored oauth_config for immediate use.

    Decryption is best-effort. If it raises, the error is logged and the stored
    payload is returned unchanged. A non-dict decryption result is also
    discarded in favour of the stored payload.

    Args:
        credentials: Stored oauth_config payload.
        flow_name: Flow label for diagnostic logging.

    Returns:
        Dict[str, Any]: Runtime-ready credentials.
    """
    prepared: Any = credentials
    try:
        encryption = get_encryption_service(get_settings().auth_encryption_secret)
        decrypted = await decrypt_oauth_config_for_runtime(credentials, encryption=encryption)
    except Exception as exc:
        logger.warning("Failed to prepare runtime OAuth credentials for %s flow: %s", flow_name, exc)
    else:
        if isinstance(decrypted, dict):
            prepared = decrypted
    return prepared
async def _client_credentials_flow(self, credentials: Dict[str, Any]) -> str:
    """Machine-to-machine authentication using the client credentials grant.

    Sensitive fields are decrypted first. The token endpoint is then called up
    to ``self.max_retries`` times, sleeping ``2**attempt`` seconds after each
    attempt that fails with an ``httpx.HTTPError``.

    Args:
        credentials: OAuth configuration with client_id, client_secret, token_url
            and optional ``scopes`` (a list, or an already space-joined string)

    Returns:
        Access token string

    Raises:
        OAuthError: If the response carries no usable access_token, or if token
            acquisition fails after all retries
    """
    # Standard
    from urllib.parse import parse_qsl  # pylint: disable=import-outside-toplevel

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "client_credentials")
    client_id = runtime_credentials["client_id"]
    client_secret = runtime_credentials["client_secret"]
    token_url = runtime_credentials["token_url"]
    scopes = runtime_credentials.get("scopes", [])

    # Prepare token request data
    token_data = {
        "grant_type": "client_credentials",
        "client_id": client_id,
        "client_secret": client_secret,
    }
    if scopes:
        token_data["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes

    # Fetch token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # GitHub returns form-encoded responses, not JSON. parse_qsl
            # percent-decodes keys and values ("%2C" -> ",", "+" -> " "), which
            # naive splitting on "&" and "=" did not.
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                token_response = dict(parse_qsl(response.text, keep_blank_values=True))
            else:
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    token_response = {"raw_response": response.text}

            # A non-object JSON body (e.g. a bare string) must not pass via
            # substring matching, and an empty/null token is unusable.
            if not isinstance(token_response, dict) or not token_response.get("access_token"):
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("Successfully obtained access token via client credentials")
            return token_response["access_token"]

        except httpx.HTTPError as e:
            logger.warning(f"Token request attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to obtain access token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Only reached when max_retries <= 0, i.e. the loop body never ran.
    raise OAuthError("Failed to obtain access token after all retry attempts")
async def _password_flow(self, credentials: Dict[str, Any]) -> str:
    """Resource Owner Password Credentials flow (RFC 6749 Section 4.3).

    This flow is used when the application can directly handle the user's
    credentials, such as trusted first-party applications or legacy
    integrations like Keycloak. Sensitive fields are decrypted first. The token
    endpoint is then called up to ``self.max_retries`` times, sleeping
    ``2**attempt`` seconds after each attempt that fails with an
    ``httpx.HTTPError``.

    Args:
        credentials: OAuth configuration with token_url, username, password and
            optional client_id, client_secret and scopes

    Returns:
        Access token string

    Raises:
        OAuthError: If username/password are missing, the response carries no
            usable access_token, or token acquisition fails after all retries
    """
    # Standard
    from urllib.parse import parse_qsl  # pylint: disable=import-outside-toplevel

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "password")
    client_id = runtime_credentials.get("client_id")
    client_secret = runtime_credentials.get("client_secret")
    token_url = runtime_credentials["token_url"]
    username = runtime_credentials.get("username")
    password = runtime_credentials.get("password")
    scopes = runtime_credentials.get("scopes", [])

    if not username or not password:
        raise OAuthError("Username and password are required for password grant type")

    # Prepare token request data
    token_data = {
        "grant_type": "password",
        "username": username,
        "password": password,
    }
    # Add client_id (required by most providers including Keycloak)
    if client_id:
        token_data["client_id"] = client_id
    # Add client_secret if present (some providers require it, others don't)
    if client_secret:
        token_data["client_secret"] = client_secret
    if scopes:
        token_data["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes

    # Fetch token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # Handle both JSON and form-encoded responses. parse_qsl
            # percent-decodes keys and values, which naive splitting did not.
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                token_response = dict(parse_qsl(response.text, keep_blank_values=True))
            else:
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    token_response = {"raw_response": response.text}

            # A non-object JSON body (e.g. a bare string) must not pass via
            # substring matching, and an empty/null token is unusable.
            if not isinstance(token_response, dict) or not token_response.get("access_token"):
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("Successfully obtained access token via password grant")
            return token_response["access_token"]

        except httpx.HTTPError as e:
            logger.warning(f"Token request attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to obtain access token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Only reached when max_retries <= 0, i.e. the loop body never ran.
    raise OAuthError("Failed to obtain access token after all retry attempts")
async def get_authorization_url(self, credentials: Dict[str, Any]) -> Dict[str, str]:
    """Build the provider authorization URL for a user-delegation flow.

    ``OAuth2Session.authorization_url`` generates the CSRF ``state``. The state
    is returned to the caller and is not persisted by this method.

    Args:
        credentials: OAuth configuration with client_id, redirect_uri,
            authorization_url and optional scopes

    Returns:
        Dict containing authorization_url and state
    """
    client_id = credentials["client_id"]
    callback_uri = credentials["redirect_uri"]
    provider_endpoint = credentials["authorization_url"]
    requested_scopes = credentials.get("scopes", [])

    session = OAuth2Session(client_id, redirect_uri=callback_uri, scope=requested_scopes)
    url, csrf_state = session.authorization_url(provider_endpoint)

    logger.info(f"Generated authorization URL for client {client_id}")
    return {"authorization_url": url, "state": csrf_state}
async def exchange_code_for_token(self, credentials: Dict[str, Any], code: str, state: str) -> str:  # pylint: disable=unused-argument
    """Exchange an authorization code for an access token.

    Sensitive fields are decrypted first. The token endpoint is then called up
    to ``self.max_retries`` times, sleeping ``2**attempt`` seconds after each
    attempt that fails with an ``httpx.HTTPError``.

    Args:
        credentials: OAuth configuration with client_id, token_url, redirect_uri
            and optional client_secret (absent for public PKCE-only clients)
        code: Authorization code from callback
        state: State parameter from the callback (accepted but not used here)

    Returns:
        Access token string

    Raises:
        OAuthError: If the response carries no usable access_token, or if the
            exchange fails after all retries
    """
    # Standard
    from urllib.parse import parse_qsl  # pylint: disable=import-outside-toplevel

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "authorization_code_exchange")
    client_id = runtime_credentials["client_id"]
    client_secret = runtime_credentials.get("client_secret")  # Optional for public clients (PKCE-only)
    token_url = runtime_credentials["token_url"]
    redirect_uri = runtime_credentials["redirect_uri"]

    # Prepare token exchange data
    token_data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": client_id,
    }
    # Only include client_secret if present (public clients don't have secrets)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Exchange code for token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # GitHub returns form-encoded responses, not JSON. parse_qsl
            # percent-decodes keys and values, which naive splitting did not.
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                token_response = dict(parse_qsl(response.text, keep_blank_values=True))
            else:
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    token_response = {"raw_response": response.text}

            # A non-object JSON body (e.g. a bare string) must not pass via
            # substring matching, and an empty/null token is unusable.
            if not isinstance(token_response, dict) or not token_response.get("access_token"):
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("Successfully exchanged authorization code for access token")
            return token_response["access_token"]

        except httpx.HTTPError as e:
            logger.warning(f"Token exchange attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Only reached when max_retries <= 0, i.e. the loop body never ran.
    raise OAuthError("Failed to exchange code for token after all retry attempts")
async def initiate_authorization_code_flow(self, gateway_id: str, credentials: Dict[str, Any], app_user_email: str = None) -> Dict[str, str]:
    """Start an Authorization Code + PKCE flow and return where to send the user.

    A PKCE pair and an opaque state are generated. The state, code verifier and
    user email are persisted only when ``token_storage`` is configured.

    Args:
        gateway_id: ID of the gateway being configured
        credentials: OAuth configuration with client_id, authorization_url, etc.
        app_user_email: ContextForge user email to associate with tokens

    Returns:
        Dict containing authorization_url, state and gateway_id
    """
    pkce = self._generate_pkce_params()
    state = self._generate_state(gateway_id, app_user_email)

    if self.token_storage:
        await self._store_authorization_state(
            gateway_id,
            state,
            code_verifier=pkce["code_verifier"],
            app_user_email=app_user_email,
        )

    redirect_target = self._create_authorization_url_with_pkce(
        credentials,
        state,
        pkce["code_challenge"],
        pkce["code_challenge_method"],
    )
    logger.info(f"Generated authorization URL with PKCE for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")

    return {"authorization_url": redirect_target, "state": state, "gateway_id": gateway_id}
async def complete_authorization_code_flow(self, gateway_id: str, code: str, state: str, credentials: Dict[str, Any]) -> Dict[str, Any]:
    """Complete Authorization Code flow with PKCE and store tokens.

    The state is validated first, which also recovers the stored PKCE code
    verifier and requesting user's email. The code is then exchanged. When
    ``token_storage`` is configured, the tokens are persisted for that user.

    Args:
        gateway_id: ID of the gateway
        code: Authorization code from callback
        state: State parameter for CSRF validation
        credentials: OAuth configuration

    Returns:
        Dict with ``success`` (True), ``user_id`` and ``expires_at`` (ISO-8601
        string from the stored token record, or None)

    Raises:
        OAuthError: If state validation fails, a legacy state names a different
            gateway, user context is missing while token storage is configured,
            or token exchange fails
    """
    # Validate state and retrieve code_verifier
    state_data = await self._validate_and_retrieve_state(gateway_id, state)
    if not state_data:
        raise OAuthError("Invalid or expired state parameter - possible replay attack")

    code_verifier = state_data.get("code_verifier")
    app_user_email = state_data.get("app_user_email")

    # Defence-in-depth: if app_user_email is absent from server-side
    # state (e.g. state stored by an older code path), attempt a
    # gateway-mismatch check via the legacy state parser but NEVER
    # extract identity fields from unsigned payloads (CWE-345).
    # Note: the /oauth/callback router rejects pure legacy states
    # before reaching here (allow_legacy_fallback=False), so this
    # block only fires for server-stored states that lack the email.
    if not app_user_email:
        legacy_state_payload = self._extract_legacy_state_payload(state)
        if legacy_state_payload:
            legacy_gateway_id = legacy_state_payload.get("gateway_id")
            if legacy_gateway_id and legacy_gateway_id != gateway_id:
                raise OAuthError("State parameter gateway mismatch")
        # Sanitize before logging, like every other log line in this module (CWE-117).
        if self.token_storage:
            logger.error(
                "User context (app_user_email) missing from OAuth state; refusing to bind tokens (CWE-287). gateway_id=%s",
                SecurityValidator.sanitize_log_message(gateway_id),
            )
            raise OAuthError("User context required for OAuth token storage")
        logger.warning(
            "User context (app_user_email) missing from OAuth state; no token_storage configured — proceeding without binding. gateway_id=%s",
            SecurityValidator.sanitize_log_message(gateway_id),
        )

    # Exchange code for tokens with PKCE code_verifier
    token_response = await self._exchange_code_for_tokens(credentials, code, code_verifier=code_verifier)

    # Extract user information from token response
    user_id = self._extract_user_id(token_response, credentials)

    if not self.token_storage:
        return {"success": True, "user_id": user_id, "expires_at": None}

    # RFC 6749 defines ``scope`` as a space-delimited string, but some providers
    # return a JSON array or an explicit null. Calling .split() on those raised
    # AttributeError, so normalize all shapes to a list of strings.
    raw_scope = token_response.get("scope")
    if isinstance(raw_scope, str):
        scopes = raw_scope.split()
    elif isinstance(raw_scope, (list, tuple)):
        scopes = [str(item) for item in raw_scope]
    else:
        scopes = []

    token_record = await self.token_storage.store_tokens(
        gateway_id=gateway_id,
        user_id=user_id,
        app_user_email=app_user_email,  # User from state
        access_token=token_response["access_token"],
        refresh_token=token_response.get("refresh_token"),
        expires_in=token_response.get("expires_in", self.settings.oauth_default_timeout),
        scopes=scopes,
    )
    expires_at = token_record.expires_at.isoformat() if token_record.expires_at else None
    return {"success": True, "user_id": user_id, "expires_at": expires_at}
async def get_access_token_for_user(self, gateway_id: str, app_user_email: str) -> Optional[str]:
    """Look up the stored access token for one user on one gateway.

    Args:
        gateway_id: ID of the gateway
        app_user_email: ContextForge user email

    Returns:
        Whatever ``token_storage.get_user_token`` returns, or None when no
        token storage is configured
    """
    storage = self.token_storage
    if not storage:
        return None
    return await storage.get_user_token(gateway_id, app_user_email)
def _generate_state(self, _gateway_id: str, _app_user_email: str = None) -> str:
    """Return a fresh, unguessable state value for CSRF protection.

    The value is opaque: nothing about the gateway or the user is embedded in
    it. Both arguments are accepted only for compatibility with older
    embedded-state call sites and are ignored.

    Args:
        _gateway_id: Gateway identifier (ignored).
        _app_user_email: ContextForge user email (ignored).

    Returns:
        Opaque random state token
    """
    # 48 random bytes -> 64 URL-safe base64 characters.
    opaque_token = secrets.token_urlsafe(48)
    return opaque_token
617 @staticmethod
618 def _extract_legacy_state_payload(state: str) -> Optional[Dict[str, Any]]:
619 """Best-effort decode of legacy state payloads used before opaque states.
621 Legacy formats supported:
622 - base64url(payload || signature) where payload is JSON
623 - gateway_id_random suffix format
625 Security: Legacy payloads lack signature verification, so only
626 ``gateway_id`` is returned — never identity-sensitive fields like
627 ``app_user_email`` which could be forged (CWE-345).
629 Args:
630 state: Callback state token to decode.
632 Returns:
633 Dict containing only ``gateway_id`` when format is recognized;
634 otherwise ``None``.
635 """
636 safe_legacy_fields = {"gateway_id"}
638 try:
639 state_raw = base64.urlsafe_b64decode(state.encode())
640 if len(state_raw) <= 32:
641 return None
643 payload_bytes = state_raw[:-32]
644 payload = orjson.loads(payload_bytes)
645 if isinstance(payload, dict):
646 # Only return gateway_id — unsigned payloads must not
647 # carry identity claims.
648 safe = {k: v for k, v in payload.items() if k in safe_legacy_fields}
649 return safe if safe else None
650 except Exception:
651 # Fall back to legacy gateway_id_random format
652 if "_" in state:
653 gateway_id = state.split("_", 1)[0]
654 if gateway_id:
655 return {"gateway_id": gateway_id}
656 return None
async def resolve_gateway_id_from_state(self, state: str, allow_legacy_fallback: bool = True) -> Optional[str]:
    """Find the gateway a callback ``state`` belongs to, without consuming it.

    Sources are consulted in order:
    1. Redis lookup key (cache_type "redis").
    2. The ``OAuthState`` table (cache_type "database").
    3. The in-memory maps, after purging expired entries.
    4. Optionally, the legacy state decoder.

    Backend errors are logged and the next source is tried.

    Args:
        state: OAuth callback state parameter
        allow_legacy_fallback: Whether to decode legacy callback state formats.

    Returns:
        Gateway ID when resolvable, otherwise ``None``.
    """
    cache_type = get_settings().cache_type

    if cache_type == "redis":
        redis = await _get_redis_client()
        if redis:
            try:
                cached = await redis.get(f"oauth:state_lookup:{state}")
                if cached:
                    return cached.decode("utf-8") if isinstance(cached, bytes) else cached
            except Exception as e:
                logger.warning(f"Failed to resolve state gateway in Redis: {e}")

    if cache_type == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            db_gen = get_db()
            db = next(db_gen)
            try:
                row = db.query(OAuthState).filter(OAuthState.state == state).first()
                if row:
                    return row.gateway_id
            finally:
                db_gen.close()
        except Exception as e:
            logger.warning(f"Failed to resolve state gateway in database: {e}")

    async with _state_lock:
        now = datetime.now(timezone.utc)
        stale_keys = [key for key, entry in _oauth_states.items() if datetime.fromisoformat(entry["expires_at"]) < now]
        for stale_key in stale_keys:
            stale_state = _oauth_states.pop(stale_key).get("state")
            if stale_state:
                _oauth_state_lookup.pop(stale_state, None)
        in_memory_gateway = _oauth_state_lookup.get(state)

    if in_memory_gateway:
        return in_memory_gateway

    if not allow_legacy_fallback:
        return None
    legacy = self._extract_legacy_state_payload(state)
    return legacy.get("gateway_id") if legacy else None
async def _store_authorization_state(
    self,
    gateway_id: str,
    state: str,
    code_verifier: str = None,
    app_user_email: str = None,
) -> None:
    """Persist a pending authorization state that expires after STATE_TTL_SECONDS.

    The backend follows ``cache_type``: Redis for "redis", the ``OAuthState``
    table for "database". If that backend is unavailable or raises, the state
    is kept in the in-process maps instead.

    Args:
        gateway_id: ID of the gateway
        state: State parameter to store
        code_verifier: Optional PKCE code verifier (RFC 7636)
        app_user_email: Requesting user email for token association
    """
    expires_at = datetime.now(timezone.utc) + timedelta(seconds=STATE_TTL_SECONDS)
    cache_type = get_settings().cache_type
    state_key = f"oauth:state:{gateway_id}:{state}"
    record = {
        "state": state,
        "gateway_id": gateway_id,
        "code_verifier": code_verifier,
        "app_user_email": app_user_email,
        "expires_at": expires_at.isoformat(),
        "used": False,
    }

    if cache_type == "redis":
        redis = await _get_redis_client()
        if redis:
            try:
                # Both keys share the TTL so Redis expires them on its own.
                await redis.setex(state_key, STATE_TTL_SECONDS, orjson.dumps(record))
                await redis.setex(f"oauth:state_lookup:{state}", STATE_TTL_SECONDS, gateway_id)
                logger.debug(f"Stored OAuth state in Redis for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
                return
            except Exception as e:
                logger.warning(f"Failed to store state in Redis: {e}, falling back")

    if cache_type == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            db_gen = get_db()
            db = next(db_gen)
            try:
                # Opportunistically purge expired rows before inserting.
                db.query(OAuthState).filter(OAuthState.expires_at < datetime.now(timezone.utc)).delete()

                columns = {
                    "gateway_id": gateway_id,
                    "state": state,
                    "code_verifier": code_verifier,
                    "expires_at": expires_at,
                    "used": False,
                }
                # Older schemas may lack the app_user_email column.
                if hasattr(OAuthState, "app_user_email"):
                    columns["app_user_email"] = app_user_email

                db.add(OAuthState(**columns))
                db.commit()
                logger.debug(f"Stored OAuth state in database for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
                return
            finally:
                db_gen.close()
        except Exception as e:
            logger.warning(f"Failed to store state in database: {e}, falling back to memory")

    # In-process fallback (development / single worker).
    async with _state_lock:
        now = datetime.now(timezone.utc)
        stale_keys = [key for key, entry in _oauth_states.items() if datetime.fromisoformat(entry["expires_at"]) < now]
        for stale_key in stale_keys:
            stale_state = _oauth_states.pop(stale_key).get("state")
            if stale_state:
                _oauth_state_lookup.pop(stale_state, None)
            logger.debug(f"Cleaned up expired state: {stale_key[:20]}...")

        _oauth_states[state_key] = record
        _oauth_state_lookup[state] = gateway_id
        logger.debug(f"Stored OAuth state in memory for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
817 async def _validate_authorization_state(self, gateway_id: str, state: str) -> bool:
818 """Validate authorization state parameter and mark as used.
820 Args:
821 gateway_id: ID of the gateway
822 state: State parameter to validate
824 Returns:
825 True if state is valid and not yet used, False otherwise
826 """
827 settings = get_settings()
829 # Try Redis first for distributed storage
830 if settings.cache_type == "redis":
831 redis = await _get_redis_client()
832 if redis:
833 try:
834 state_key = f"oauth:state:{gateway_id}:{state}"
835 lookup_key = f"oauth:state_lookup:{state}"
836 # Get and delete state atomically (single-use)
837 state_json = await redis.getdel(state_key)
838 await redis.delete(lookup_key)
839 if not state_json:
840 logger.warning(f"State not found in Redis for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
841 return False
843 state_data = orjson.loads(state_json)
845 # Parse expires_at as timezone-aware datetime. If the stored value
846 # is naive, assume UTC for compatibility.
847 try:
848 expires_at = datetime.fromisoformat(state_data["expires_at"])
849 except Exception:
850 # Fallback: try parsing without microseconds/offsets
851 expires_at = datetime.strptime(state_data["expires_at"], "%Y-%m-%dT%H:%M:%S")
853 if expires_at.tzinfo is None:
854 # Assume UTC for naive timestamps
855 expires_at = expires_at.replace(tzinfo=timezone.utc)
857 # Check if state has expired
858 if expires_at < datetime.now(timezone.utc):
859 logger.warning(f"State has expired for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
860 return False
862 # Check if state was already used (should not happen with getdel)
863 if state_data.get("used", False):
864 logger.warning(f"State was already used for gateway {SecurityValidator.sanitize_log_message(gateway_id)} - possible replay attack")
865 return False
867 logger.debug(f"Successfully validated OAuth state from Redis for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
868 return True
869 except Exception as e:
870 logger.warning(f"Failed to validate state in Redis: {e}, falling back")
872 # Try database storage for multi-worker deployments
873 if settings.cache_type == "database":
874 try:
875 # First-Party
876 from mcpgateway.db import get_db, OAuthState # pylint: disable=import-outside-toplevel
878 db_gen = get_db()
879 db = next(db_gen)
880 try:
881 # Find the state
882 oauth_state = db.query(OAuthState).filter(OAuthState.gateway_id == gateway_id, OAuthState.state == state).first()
884 if not oauth_state:
885 logger.warning(f"State not found in database for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
886 return False
888 # Check if state has expired
889 # Ensure oauth_state.expires_at is timezone-aware. If naive, assume UTC.
890 expires_at = oauth_state.expires_at
891 if expires_at.tzinfo is None:
892 expires_at = expires_at.replace(tzinfo=timezone.utc)
894 if expires_at < datetime.now(timezone.utc):
895 logger.warning(f"State has expired for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
896 db.delete(oauth_state)
897 db.commit()
898 return False
900 # Check if state was already used
901 if oauth_state.used:
902 logger.warning(f"State has already been used for gateway {SecurityValidator.sanitize_log_message(gateway_id)} - possible replay attack")
903 return False
905 # Mark as used and delete (single-use)
906 db.delete(oauth_state)
907 db.commit()
908 logger.debug(f"Successfully validated OAuth state from database for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
909 return True
910 finally:
911 db_gen.close()
912 except Exception as e:
913 logger.warning(f"Failed to validate state in database: {e}, falling back to memory")
915 # Fallback to in-memory storage for development
916 state_key = f"oauth:state:{gateway_id}:{state}"
917 async with _state_lock:
918 state_data = _oauth_states.get(state_key)
920 # Check if state exists
921 if not state_data:
922 logger.warning(f"State not found in memory for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
923 return False
925 # Parse and normalize expires_at to timezone-aware datetime
926 expires_at = datetime.fromisoformat(state_data["expires_at"])
927 if expires_at.tzinfo is None:
928 expires_at = expires_at.replace(tzinfo=timezone.utc)
930 if expires_at < datetime.now(timezone.utc):
931 logger.warning(f"State has expired for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
932 del _oauth_states[state_key] # Clean up expired state
933 _oauth_state_lookup.pop(state, None)
934 return False
936 # Check if state has already been used (prevent replay)
937 if state_data.get("used", False):
938 logger.warning(f"State has already been used for gateway {SecurityValidator.sanitize_log_message(gateway_id)} - possible replay attack")
939 return False
941 # Mark state as used and remove it (single-use)
942 del _oauth_states[state_key]
943 _oauth_state_lookup.pop(state, None)
944 logger.debug(f"Successfully validated OAuth state from memory for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
945 return True
async def _validate_and_retrieve_state(self, gateway_id: str, state: str) -> Optional[Dict[str, Any]]:
    """Validate an OAuth state and return its stored data (including code_verifier).

    Backends are consulted in order:

    * Redis, when ``cache_type == "redis"`` and a client is available.
    * The database, when ``cache_type == "database"``.
    * The in-process dictionary.

    A backend error is logged and the next backend is tried. States are
    single-use: a successful lookup removes the entry from the backend that
    held it. States that are expired or flagged as ``used`` are rejected.

    Args:
        gateway_id: ID of the gateway
        state: State parameter to validate

    Returns:
        Dict with state data including code_verifier, or None if the state is
        missing, expired, or already used.
    """
    settings = get_settings()

    # Try Redis first
    if settings.cache_type == "redis":
        redis = await _get_redis_client()
        if redis:
            try:
                state_key = f"oauth:state:{gateway_id}:{state}"
                lookup_key = f"oauth:state_lookup:{state}"
                state_json = await redis.getdel(state_key)  # Atomic get+delete (single-use)
                await redis.delete(lookup_key)
                if not state_json:
                    return None

                state_data = orjson.loads(state_json)

                # Parse expires_at; fall back to a plain second-resolution format,
                # and treat naive timestamps as UTC.
                try:
                    expires_at = datetime.fromisoformat(state_data["expires_at"])
                except Exception:
                    expires_at = datetime.strptime(state_data["expires_at"], "%Y-%m-%dT%H:%M:%S")

                if expires_at.tzinfo is None:
                    expires_at = expires_at.replace(tzinfo=timezone.utc)

                if expires_at < datetime.now(timezone.utc):
                    return None

                # Same replay guard as the database path (should not trigger with getdel).
                if state_data.get("used", False):
                    logger.warning(f"State was already used for gateway {SecurityValidator.sanitize_log_message(gateway_id)} - possible replay attack")
                    return None

                return state_data
            except Exception as e:
                logger.warning(f"Failed to validate state in Redis: {e}, falling back")

    # Try database
    if settings.cache_type == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            db_gen = get_db()
            db = next(db_gen)
            try:
                oauth_state = db.query(OAuthState).filter(OAuthState.gateway_id == gateway_id, OAuthState.state == state).first()

                if not oauth_state:
                    return None

                # Check expiration (naive database timestamps are treated as UTC)
                expires_at = oauth_state.expires_at
                if expires_at.tzinfo is None:
                    expires_at = expires_at.replace(tzinfo=timezone.utc)

                if expires_at < datetime.now(timezone.utc):
                    db.delete(oauth_state)
                    db.commit()
                    return None

                # Check if already used
                if oauth_state.used:
                    return None

                # Build state data
                state_data = {
                    "state": oauth_state.state,
                    "gateway_id": oauth_state.gateway_id,
                    "code_verifier": oauth_state.code_verifier,
                    "expires_at": oauth_state.expires_at.isoformat(),
                }
                if hasattr(oauth_state, "app_user_email"):
                    state_data["app_user_email"] = getattr(oauth_state, "app_user_email", None)

                # Consume the state: delete the row so it cannot be replayed
                db.delete(oauth_state)
                db.commit()

                return state_data
            finally:
                db_gen.close()
        except Exception as e:
            logger.warning(f"Failed to validate state in database: {e}")

    # Fallback to in-memory
    state_key = f"oauth:state:{gateway_id}:{state}"
    async with _state_lock:
        state_data = _oauth_states.get(state_key)
        if not state_data:
            return None

        # Check expiration (naive timestamps are treated as UTC)
        expires_at = datetime.fromisoformat(state_data["expires_at"])
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)

        if expires_at < datetime.now(timezone.utc):
            del _oauth_states[state_key]
            _oauth_state_lookup.pop(state, None)
            return None

        # Reject states already flagged as used (prevent replay), matching the database path.
        if state_data.get("used", False):
            logger.warning(f"State has already been used for gateway {SecurityValidator.sanitize_log_message(gateway_id)} - possible replay attack")
            return None

        # Remove from memory (single-use)
        del _oauth_states[state_key]
        _oauth_state_lookup.pop(state, None)
        return state_data
def _create_authorization_url(self, credentials: Dict[str, Any], state: str) -> tuple[str, str]:
    """Build a plain (non-PKCE) authorization URL through requests-oauthlib.

    Args:
        credentials: OAuth configuration. ``client_id``, ``redirect_uri`` and
            ``authorization_url`` are required; ``scopes`` defaults to an
            empty list.
        state: CSRF-protection value to embed in the URL.

    Returns:
        Tuple of (authorization_url, state) as reported by
        ``OAuth2Session.authorization_url``.
    """
    # Required keys are read in a fixed order so a missing one surfaces as a KeyError first.
    required_keys = ("client_id", "redirect_uri", "authorization_url")
    client_id, redirect_uri, endpoint = (credentials[key] for key in required_keys)

    session = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=credentials.get("scopes", []))
    url, returned_state = session.authorization_url(endpoint, state=state)
    return url, returned_state
1082 @staticmethod
1083 def _is_microsoft_entra_v2_endpoint(endpoint_url: Any) -> bool:
1084 """Return True when endpoint matches Microsoft Entra v2 login endpoints.
1086 Args:
1087 endpoint_url: OAuth endpoint URL to check
1089 Returns:
1090 True if the endpoint is a Microsoft Entra v2 OAuth endpoint
1091 """
1092 if not isinstance(endpoint_url, str) or not endpoint_url:
1093 return False
1095 parsed = urlparse(endpoint_url)
1096 host = (parsed.hostname or "").lower()
1097 path = parsed.path.lower()
1099 return host in OAuthManager._ENTRA_HOSTS and "/oauth2/v2.0/" in path
1101 @staticmethod
1102 def _is_enabled_flag(value: Any) -> bool:
1103 """Parse boolean-like config values from oauth_config.
1105 Args:
1106 value: Config value to interpret as boolean
1108 Returns:
1109 True if value represents an enabled/truthy setting
1110 """
1111 if isinstance(value, bool):
1112 return value
1113 if isinstance(value, str):
1114 return value.strip().lower() in {"1", "true", "yes", "on"}
1115 return False
1117 def _should_include_resource_parameter(self, credentials: Dict[str, Any], scopes: Any) -> bool:
1118 """Determine whether RFC 8707 resource should be sent for this request.
1120 Args:
1121 credentials: OAuth configuration containing resource and endpoint URLs
1122 scopes: OAuth scopes for the request
1124 Returns:
1125 True if the resource parameter should be included in the request
1126 """
1127 if not credentials.get("resource"):
1128 return False
1130 if self._is_enabled_flag(credentials.get("omit_resource")):
1131 return False
1133 # Microsoft Entra v2 does not accept legacy resource with v2 scope-based requests.
1134 if scopes and (self._is_microsoft_entra_v2_endpoint(credentials.get("authorization_url")) or self._is_microsoft_entra_v2_endpoint(credentials.get("token_url"))):
1135 logger.info("Omitting OAuth resource parameter for Microsoft Entra v2 scope-based flow")
1136 return False
1138 return True
1140 def _create_authorization_url_with_pkce(self, credentials: Dict[str, Any], state: str, code_challenge: str, code_challenge_method: str) -> str:
1141 """Create authorization URL with PKCE parameters (RFC 7636).
1143 Args:
1144 credentials: OAuth configuration
1145 state: State parameter for CSRF protection
1146 code_challenge: PKCE code challenge
1147 code_challenge_method: PKCE method (S256)
1149 Returns:
1150 Authorization URL string with PKCE parameters
1151 """
1152 # Standard
1153 from urllib.parse import urlencode # pylint: disable=import-outside-toplevel
1155 client_id = credentials["client_id"]
1156 redirect_uri = credentials["redirect_uri"]
1157 authorization_url = credentials["authorization_url"]
1158 scopes = credentials.get("scopes", [])
1160 # Build authorization parameters
1161 params = {"response_type": "code", "client_id": client_id, "redirect_uri": redirect_uri, "state": state, "code_challenge": code_challenge, "code_challenge_method": code_challenge_method}
1163 # Add scopes if present
1164 if scopes:
1165 params["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes
1167 # Add resource parameter for JWT access token (RFC 8707)
1168 # The resource is the MCP server URL, set by oauth_router.py
1169 resource = credentials.get("resource")
1170 if self._should_include_resource_parameter(credentials, scopes):
1171 params["resource"] = resource # urlencode with doseq=True handles lists
1173 # Build full URL (doseq=True handles list values like multiple resource params)
1174 query_string = urlencode(params, doseq=True)
1175 return f"{authorization_url}?{query_string}"
async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]:
    """Exchange authorization code for tokens with PKCE support.

    The token endpoint is called up to ``self.max_retries`` times. Attempts
    that fail with ``httpx.HTTPError`` are retried after sleeping
    ``2**attempt`` seconds; this includes non-2xx statuses surfaced by
    ``raise_for_status``. Form-encoded responses (GitHub's default format)
    are parsed and percent-decoded. Any other body is parsed as JSON.

    Args:
        credentials: OAuth configuration
        code: Authorization code from callback
        code_verifier: Optional PKCE code verifier (RFC 7636)

    Returns:
        Token response dictionary

    Raises:
        OAuthError: If the response lacks an ``access_token`` or every attempt fails
    """
    # Standard
    from urllib.parse import unquote_plus  # pylint: disable=import-outside-toplevel

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "authorization_code_exchange_with_pkce")
    client_id = runtime_credentials["client_id"]
    client_secret = runtime_credentials.get("client_secret")  # Optional for public clients (PKCE-only)
    token_url = runtime_credentials["token_url"]
    redirect_uri = runtime_credentials["redirect_uri"]

    # Prepare token exchange data
    token_data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": client_id,
    }

    # Only include client_secret if present (public clients don't have secrets)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Add PKCE code_verifier if present (RFC 7636)
    if code_verifier:
        token_data["code_verifier"] = code_verifier

    # Add resource parameter to request JWT access token (RFC 8707)
    # The resource identifies the MCP server (resource server), not the OAuth server
    resource = runtime_credentials.get("resource")
    scopes = runtime_credentials.get("scopes", [])
    if self._should_include_resource_parameter(credentials, scopes):
        if isinstance(resource, list):
            # RFC 8707 allows multiple resource parameters - use list of tuples
            form_data: list[tuple[str, str]] = list(token_data.items())
            for r in resource:
                if r:
                    form_data.append(("resource", r))
            token_data = form_data  # type: ignore[assignment]
        else:
            token_data["resource"] = resource

    # Exchange code for token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # GitHub returns form-encoded responses, not JSON
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                # Form bodies percent-encode their values (e.g. "scope=repo%2Cgist"),
                # so decode keys and values; segments without "=" are ignored.
                token_response = {}
                for pair in response.text.split("&"):
                    if "=" in pair:
                        key, value = pair.split("=", 1)
                        token_response[unquote_plus(key)] = unquote_plus(value)
            else:
                # Try JSON response
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    # Fallback: keep the raw body for the error message below
                    token_response = {"raw_response": response.text}

            # A non-object JSON body (list/string) cannot be a token response; guard
            # against "in" doing a substring or element match on it.
            if not isinstance(token_response, dict) or "access_token" not in token_response:
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("Successfully exchanged authorization code for tokens")
            return token_response

        except httpx.HTTPError as e:
            logger.warning(f"Token exchange attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Reached only when max_retries < 1; kept for type safety
    raise OAuthError("Failed to exchange code for token after all retry attempts")
async def refresh_token(self, refresh_token: str, credentials: Dict[str, Any]) -> Dict[str, Any]:
    """Refresh an expired access token using a refresh token.

    The token endpoint is tried up to ``self.max_retries`` times.

    * A 400 or 401 response means the refresh token was rejected and fails
      immediately.
    * Other non-200 statuses and ``httpx.HTTPError`` exceptions are retried,
      sleeping ``2**attempt`` seconds between attempts so a struggling
      provider is not hammered.
    * A 200 body is parsed as JSON. If that fails and the response is declared
      ``application/x-www-form-urlencoded`` (GitHub's default format), it is
      parsed and percent-decoded as a form body instead.

    Args:
        refresh_token: The refresh token to use
        credentials: OAuth configuration including client_id, client_secret, token_url

    Returns:
        Dict containing new access_token, optional refresh_token, and expires_in

    Raises:
        OAuthError: If any of the following occurs:
            * the refresh token or required configuration is missing;
            * the provider rejects the refresh token;
            * the success body cannot be parsed or lacks ``access_token``;
            * every attempt fails.
    """
    # Standard
    from urllib.parse import unquote_plus  # pylint: disable=import-outside-toplevel

    if not refresh_token:
        raise OAuthError("No refresh token available")

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "refresh_token")
    token_url = runtime_credentials.get("token_url")
    if not token_url:
        raise OAuthError("No token URL configured for OAuth provider")

    client_id = runtime_credentials.get("client_id")
    client_secret = runtime_credentials.get("client_secret")

    if not client_id:
        raise OAuthError("No client_id configured for OAuth provider")

    # Prepare token refresh request
    token_data = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
    }

    # Add client_secret if available (some providers require it)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Add resource parameter for JWT access token (RFC 8707)
    # Must be included in refresh requests to maintain JWT token type
    resource = runtime_credentials.get("resource")
    scopes = runtime_credentials.get("scopes", [])
    if self._should_include_resource_parameter(credentials, scopes):
        if isinstance(resource, list):
            # RFC 8707 allows multiple resource parameters - use list of tuples
            form_data: list[tuple[str, str]] = list(token_data.items())
            for r in resource:
                if r:
                    form_data.append(("resource", r))
            token_data = form_data  # type: ignore[assignment]
        else:
            token_data["resource"] = resource

    # Attempt token refresh with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            if response.status_code == 200:
                try:
                    token_response = response.json()
                except ValueError as e:
                    # Some providers (GitHub) answer with a form-encoded body instead of JSON.
                    if "application/x-www-form-urlencoded" not in response.headers.get("content-type", ""):
                        raise OAuthError("Refresh response is not valid JSON") from e
                    token_response = {}
                    for pair in response.text.split("&"):
                        if "=" in pair:
                            key, value = pair.split("=", 1)
                            token_response[unquote_plus(key)] = unquote_plus(value)

                # Validate required fields (a non-object body cannot carry a token)
                if not isinstance(token_response, dict) or "access_token" not in token_response:
                    raise OAuthError("No access_token in refresh response")

                logger.info("Successfully refreshed OAuth token")
                return token_response

            error_text = response.text
            # If we get a 400/401, the refresh token is likely invalid
            if response.status_code in [400, 401]:
                raise OAuthError(f"Refresh token invalid or expired: {error_text}")
            logger.warning(f"Token refresh failed with status {response.status_code}: {error_text}")

        except httpx.HTTPError as e:
            logger.warning(f"Token refresh attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to refresh token after {self.max_retries} attempts: {str(e)}") from e

        # Back off before the next attempt. This covers both transport errors and
        # retryable status codes (previously status failures retried immediately).
        if attempt < self.max_retries - 1:
            await asyncio.sleep(2**attempt)  # Exponential backoff

    raise OAuthError("Failed to refresh token after all retry attempts")
1352 def _extract_user_id(self, token_response: Dict[str, Any], credentials: Dict[str, Any]) -> str:
1353 """Extract user ID from token response.
1355 Args:
1356 token_response: Response from token exchange
1357 credentials: OAuth configuration
1359 Returns:
1360 User ID string
1361 """
1362 # Try to extract user ID from various common fields in token response
1363 # Different OAuth providers use different field names
1365 # Check for 'sub' (subject) - JWT standard
1366 if "sub" in token_response:
1367 return token_response["sub"]
1369 # Check for 'user_id' - common in some OAuth responses
1370 if "user_id" in token_response:
1371 return token_response["user_id"]
1373 # Check for 'id' - also common
1374 if "id" in token_response:
1375 return token_response["id"]
1377 # Fallback to client_id if no user info is available
1378 if credentials.get("client_id"):
1379 return credentials["client_id"]
1381 # Final fallback
1382 return "unknown_user"
class OAuthError(Exception):
    """Base exception for failures raised by the OAuth manager.

    Covers problems such as missing provider configuration, refresh tokens
    the provider rejects, and token-endpoint calls that keep failing.
    Subclasses add extra context (for example the offending ``server_id``).

    Examples:
        >>> try:
        ...     raise OAuthError("Token acquisition failed")
        ... except OAuthError as e:
        ...     str(e)
        'Token acquisition failed'
        >>> issubclass(OAuthError, Exception)
        True
        >>> isinstance(OAuthError("Invalid grant type"), Exception)
        True
    """
class OAuthRequiredError(OAuthError):
    """Signal that a server demands OAuth but the caller is unauthenticated.

    ``server_id`` records which virtual server rejected the request. The
    middleware uses it to build the matching ``WWW-Authenticate`` header.

    Examples:
        >>> err = OAuthRequiredError("auth required", server_id="s1")
        >>> (err.server_id, str(err))
        ('s1', 'auth required')
        >>> OAuthRequiredError("no id given").server_id
        ''
        >>> isinstance(err, OAuthError)
        True
    """

    def __init__(self, message: str, *, server_id: str = "") -> None:
        """Remember ``server_id`` and pass ``message`` to the base exception.

        Args:
            message: Human-readable error description.
            server_id: Virtual-server identifier that triggered the rejection;
                empty string when not supplied.
        """
        self.server_id = server_id
        super().__init__(message)
class OAuthEnforcementUnavailableError(OAuthError):
    """Signal that OAuth enforcement could not be evaluated because infrastructure failed.

    Raised when the database or another backing service needed to read a
    server's ``oauth_enabled`` flag is unavailable. The middleware turns this
    into an HTTP 503, so access fails closed instead of silently letting
    unauthenticated requests through.

    Examples:
        >>> err = OAuthEnforcementUnavailableError("DB down", server_id="s1")
        >>> (err.server_id, str(err))
        ('s1', 'DB down')
        >>> OAuthEnforcementUnavailableError("DB down").server_id
        ''
        >>> isinstance(err, OAuthError)
        True
    """

    def __init__(self, message: str, *, server_id: str = "") -> None:
        """Remember ``server_id`` and pass ``message`` to the base exception.

        Args:
            message: Human-readable error description.
            server_id: Virtual-server identifier whose check could not be
                performed; empty string when not supplied.
        """
        self.server_id = server_id
        super().__init__(message)