Coverage for mcpgateway / services / oauth_manager.py: 99%
647 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/oauth_manager.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7OAuth 2.0 Manager for ContextForge.
9This module handles OAuth 2.0 authentication flows including:
10- Client Credentials (Machine-to-Machine)
11- Authorization Code (User Delegation)
12"""
14# Standard
15import asyncio
16import base64
17from datetime import datetime, timedelta, timezone
18import hashlib
19import logging
20import secrets
21from typing import Any, Dict, Optional
22from urllib.parse import urlparse
24# Third-Party
25import httpx
26import orjson
27from requests_oauthlib import OAuth2Session
29# First-Party
30from mcpgateway.config import get_settings
31from mcpgateway.services.encryption_service import decrypt_oauth_config_for_runtime, get_encryption_service
32from mcpgateway.services.http_client_service import get_http_client
33from mcpgateway.utils.redis_client import get_redis_client as _get_shared_redis_client
35logger = logging.getLogger(__name__)
37# In-memory storage for OAuth states with expiration (fallback for single-process)
38# Format: {state_key: {"state": state, "gateway_id": gateway_id, "expires_at": datetime}}
39_oauth_states: Dict[str, Dict[str, Any]] = {}
40# Reverse lookup for callback handlers that only receive state.
41# Format: {state: gateway_id}
42_oauth_state_lookup: Dict[str, str] = {}
43# Lock for thread-safe state operations
44_state_lock = asyncio.Lock()
46# State TTL in seconds (5 minutes)
47STATE_TTL_SECONDS = 300
49# Redis client for distributed state storage (uses shared factory)
50_redis_client: Optional[Any] = None
51_REDIS_INITIALIZED = False
async def _get_redis_client():
    """Return the shared Redis client used for distributed OAuth state, if any.

    Resolution happens once per process. The first call inspects settings and asks the
    centralized Redis factory for a client. Every later call returns the cached result,
    which may be ``None`` when Redis is not configured or could not be obtained.

    Returns:
        Redis client instance or None if unavailable
    """
    global _redis_client, _REDIS_INITIALIZED  # pylint: disable=global-statement

    if not _REDIS_INITIALIZED:
        _redis_client = None
        settings = get_settings()
        wants_redis = settings.cache_type == "redis" and settings.redis_url
        if wants_redis:
            try:
                _redis_client = await _get_shared_redis_client()
                if _redis_client:
                    logger.info("OAuth manager connected to shared Redis client")
            except Exception as e:
                logger.warning(f"Failed to get Redis client, falling back to in-memory storage: {e}")
                _redis_client = None
        _REDIS_INITIALIZED = True

    return _redis_client
class OAuthManager:
    """Manages OAuth 2.0 authentication flows.

    ``get_access_token`` serves the non-interactive grants (``client_credentials`` and
    ``password``). The ``authorization_code`` grant is driven through
    ``initiate_authorization_code_flow`` and ``complete_authorization_code_flow``, using
    PKCE (RFC 7636). Pending authorization states are persisted in Redis, the database,
    or process memory, depending on the configured ``cache_type``.

    Examples:
        >>> manager = OAuthManager(request_timeout=30, max_retries=3)
        >>> manager.request_timeout
        30
        >>> manager.max_retries
        3
        >>> manager.token_storage is None
        True
        >>>
        >>> # Test grant type validation
        >>> grant_type = "client_credentials"
        >>> grant_type in ["client_credentials", "authorization_code"]
        True
        >>> grant_type = "invalid_grant"
        >>> grant_type in ["client_credentials", "authorization_code"]
        False
        >>>
        >>> # Test encrypted secret detection heuristic
        >>> short_secret = "secret123"
        >>> len(short_secret) > 50
        False
        >>> encrypted_secret = "gAAAAABh" + "x" * 60  # Simulated encrypted secret
        >>> len(encrypted_secret) > 50
        True
        >>>
        >>> # Test scope list handling
        >>> scopes = ["read", "write"]
        >>> " ".join(scopes)
        'read write'
        >>> empty_scopes = []
        >>> " ".join(empty_scopes)
        ''
    """

    # Known Microsoft Entra login hosts (global + sovereign clouds).
    # NOTE(review): not referenced in the visible part of this class; presumably used by
    # Entra-specific handling elsewhere in the file — confirm.
    _ENTRA_HOSTS: frozenset[str] = frozenset(
        {
            "login.microsoftonline.com",
            "login.microsoftonline.us",
            "login.microsoftonline.de",
            "login.partner.microsoftonline.cn",
        }
    )
130 def __init__(self, request_timeout: int = 30, max_retries: int = 3, token_storage: Optional[Any] = None):
131 """Initialize OAuth Manager.
133 Args:
134 request_timeout: Timeout for OAuth requests in seconds
135 max_retries: Maximum number of retry attempts for token requests
136 token_storage: Optional TokenStorageService for storing tokens
137 """
138 self.request_timeout = request_timeout
139 self.max_retries = max_retries
140 self.token_storage = token_storage
141 self.settings = get_settings()
143 async def _get_client(self) -> httpx.AsyncClient:
144 """Get the shared singleton HTTP client.
146 Returns:
147 Shared httpx.AsyncClient instance with connection pooling
148 """
149 return await get_http_client()
151 def _generate_pkce_params(self) -> Dict[str, str]:
152 """Generate PKCE parameters for OAuth Authorization Code flow (RFC 7636).
154 Returns:
155 Dict containing code_verifier, code_challenge, and code_challenge_method
156 """
157 # Generate code_verifier: 43-128 character random string
158 code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=")
160 # Generate code_challenge: base64url(SHA256(code_verifier))
161 code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest()).decode("utf-8").rstrip("=")
163 return {"code_verifier": code_verifier, "code_challenge": code_challenge, "code_challenge_method": "S256"}
165 async def get_access_token(self, credentials: Dict[str, Any]) -> str:
166 """Get access token based on grant type.
168 Args:
169 credentials: OAuth configuration containing grant_type and other params
171 Returns:
172 Access token string
174 Raises:
175 ValueError: If grant type is unsupported
176 OAuthError: If token acquisition fails
178 Examples:
179 Client credentials flow:
180 >>> import asyncio
181 >>> class TestMgr(OAuthManager):
182 ... async def _client_credentials_flow(self, credentials):
183 ... return 'tok'
184 >>> mgr = TestMgr()
185 >>> asyncio.run(mgr.get_access_token({'grant_type': 'client_credentials'}))
186 'tok'
188 Authorization code flow requires interactive completion:
189 >>> def _auth_code_requires_consent():
190 ... try:
191 ... asyncio.run(mgr.get_access_token({'grant_type': 'authorization_code'}))
192 ... except OAuthError:
193 ... return True
194 >>> _auth_code_requires_consent()
195 True
197 Unsupported grant type raises ValueError:
198 >>> def _unsupported():
199 ... try:
200 ... asyncio.run(mgr.get_access_token({'grant_type': 'bad'}))
201 ... except ValueError:
202 ... return True
203 >>> _unsupported()
204 True
205 """
206 grant_type = credentials.get("grant_type")
207 logger.debug(f"Getting access token for grant type: {grant_type}")
209 if grant_type == "client_credentials":
210 return await self._client_credentials_flow(credentials)
211 if grant_type == "password":
212 return await self._password_flow(credentials)
213 if grant_type == "authorization_code":
214 raise OAuthError("Authorization code flow requires user consent via /oauth/authorize and does not support client_credentials fallback")
215 raise ValueError(f"Unsupported grant type: {grant_type}")
217 @staticmethod
218 async def _prepare_runtime_credentials(credentials: Dict[str, Any], flow_name: str) -> Dict[str, Any]:
219 """Return runtime-ready oauth credentials with sensitive fields decrypted.
221 Args:
222 credentials: Stored oauth_config payload.
223 flow_name: Flow label for diagnostic logging.
225 Returns:
226 Dict[str, Any]: Runtime-ready credentials.
227 """
228 try:
229 settings = get_settings()
230 encryption = get_encryption_service(settings.auth_encryption_secret)
231 runtime_credentials = await decrypt_oauth_config_for_runtime(credentials, encryption=encryption)
232 if isinstance(runtime_credentials, dict):
233 return runtime_credentials
234 except Exception as exc:
235 logger.warning("Failed to prepare runtime OAuth credentials for %s flow: %s", flow_name, exc)
236 return credentials
238 async def _client_credentials_flow(self, credentials: Dict[str, Any]) -> str:
239 """Machine-to-machine authentication using client credentials.
241 Args:
242 credentials: OAuth configuration with client_id, client_secret, token_url
244 Returns:
245 Access token string
247 Raises:
248 OAuthError: If token acquisition fails after all retries
249 """
250 runtime_credentials = await self._prepare_runtime_credentials(credentials, "client_credentials")
251 client_id = runtime_credentials["client_id"]
252 client_secret = runtime_credentials["client_secret"]
253 token_url = runtime_credentials["token_url"]
254 scopes = runtime_credentials.get("scopes", [])
256 # Prepare token request data
257 token_data = {
258 "grant_type": "client_credentials",
259 "client_id": client_id,
260 "client_secret": client_secret,
261 }
263 if scopes:
264 token_data["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes
266 # Fetch token with retries
267 for attempt in range(self.max_retries):
268 try:
269 client = await self._get_client()
270 response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
271 response.raise_for_status()
273 # GitHub returns form-encoded responses, not JSON
274 content_type = response.headers.get("content-type", "")
275 if "application/x-www-form-urlencoded" in content_type:
276 # Parse form-encoded response
277 text_response = response.text
278 token_response = {}
279 for pair in text_response.split("&"):
280 if "=" in pair:
281 key, value = pair.split("=", 1)
282 token_response[key] = value
283 else:
284 # Try JSON response
285 try:
286 token_response = response.json()
287 except Exception as e:
288 logger.warning(f"Failed to parse JSON response: {e}")
289 # Fallback to text parsing
290 text_response = response.text
291 token_response = {"raw_response": text_response}
293 if "access_token" not in token_response:
294 raise OAuthError(f"No access_token in response: {token_response}")
296 logger.info("""Successfully obtained access token via client credentials""")
297 return token_response["access_token"]
299 except httpx.HTTPError as e:
300 logger.warning(f"Token request attempt {attempt + 1} failed: {str(e)}")
301 if attempt == self.max_retries - 1:
302 raise OAuthError(f"Failed to obtain access token after {self.max_retries} attempts: {str(e)}")
303 await asyncio.sleep(2**attempt) # Exponential backoff
305 # This should never be reached due to the exception above, but needed for type safety
306 raise OAuthError("Failed to obtain access token after all retry attempts")
308 async def _password_flow(self, credentials: Dict[str, Any]) -> str:
309 """Resource Owner Password Credentials flow (RFC 6749 Section 4.3).
311 This flow is used when the application can directly handle the user's credentials,
312 such as with trusted first-party applications or legacy integrations like Keycloak.
314 Args:
315 credentials: OAuth configuration with client_id, optional client_secret, token_url, username, password
317 Returns:
318 Access token string
320 Raises:
321 OAuthError: If token acquisition fails after all retries
322 """
323 runtime_credentials = await self._prepare_runtime_credentials(credentials, "password")
324 client_id = runtime_credentials.get("client_id")
325 client_secret = runtime_credentials.get("client_secret")
326 token_url = runtime_credentials["token_url"]
327 username = runtime_credentials.get("username")
328 password = runtime_credentials.get("password")
329 scopes = runtime_credentials.get("scopes", [])
331 if not username or not password:
332 raise OAuthError("Username and password are required for password grant type")
334 # Prepare token request data
335 token_data = {
336 "grant_type": "password",
337 "username": username,
338 "password": password,
339 }
341 # Add client_id (required by most providers including Keycloak)
342 if client_id:
343 token_data["client_id"] = client_id
345 # Add client_secret if present (some providers require it, others don't)
346 if client_secret:
347 token_data["client_secret"] = client_secret
349 if scopes:
350 token_data["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes
352 # Fetch token with retries
353 for attempt in range(self.max_retries):
354 try:
355 client = await self._get_client()
356 response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
357 response.raise_for_status()
359 # Handle both JSON and form-encoded responses
360 content_type = response.headers.get("content-type", "")
361 if "application/x-www-form-urlencoded" in content_type:
362 # Parse form-encoded response
363 text_response = response.text
364 token_response = {}
365 for pair in text_response.split("&"):
366 if "=" in pair:
367 key, value = pair.split("=", 1)
368 token_response[key] = value
369 else:
370 # Try JSON response
371 try:
372 token_response = response.json()
373 except Exception as e:
374 logger.warning(f"Failed to parse JSON response: {e}")
375 # Fallback to text parsing
376 text_response = response.text
377 token_response = {"raw_response": text_response}
379 if "access_token" not in token_response:
380 raise OAuthError(f"No access_token in response: {token_response}")
382 logger.info("Successfully obtained access token via password grant")
383 return token_response["access_token"]
385 except httpx.HTTPError as e:
386 logger.warning(f"Token request attempt {attempt + 1} failed: {str(e)}")
387 if attempt == self.max_retries - 1:
388 raise OAuthError(f"Failed to obtain access token after {self.max_retries} attempts: {str(e)}")
389 await asyncio.sleep(2**attempt) # Exponential backoff
391 # This should never be reached due to the exception above, but needed for type safety
392 raise OAuthError("Failed to obtain access token after all retry attempts")
394 async def get_authorization_url(self, credentials: Dict[str, Any]) -> Dict[str, str]:
395 """Get authorization URL for user delegation flow.
397 Args:
398 credentials: OAuth configuration with client_id, authorization_url, etc.
400 Returns:
401 Dict containing authorization_url and state
402 """
403 client_id = credentials["client_id"]
404 redirect_uri = credentials["redirect_uri"]
405 authorization_url = credentials["authorization_url"]
406 scopes = credentials.get("scopes", [])
408 # Create OAuth2 session
409 oauth = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=scopes)
411 # Generate authorization URL with state for CSRF protection
412 auth_url, state = oauth.authorization_url(authorization_url)
414 logger.info(f"Generated authorization URL for client {client_id}")
416 return {"authorization_url": auth_url, "state": state}
418 async def exchange_code_for_token(self, credentials: Dict[str, Any], code: str, state: str) -> str: # pylint: disable=unused-argument
419 """Exchange authorization code for access token.
421 Args:
422 credentials: OAuth configuration
423 code: Authorization code from callback
424 state: State parameter for CSRF validation
426 Returns:
427 Access token string
429 Raises:
430 OAuthError: If token exchange fails
431 """
432 runtime_credentials = await self._prepare_runtime_credentials(credentials, "authorization_code_exchange")
433 client_id = runtime_credentials["client_id"]
434 client_secret = runtime_credentials.get("client_secret") # Optional for public clients (PKCE-only)
435 token_url = runtime_credentials["token_url"]
436 redirect_uri = runtime_credentials["redirect_uri"]
438 # Prepare token exchange data
439 token_data = {
440 "grant_type": "authorization_code",
441 "code": code,
442 "redirect_uri": redirect_uri,
443 "client_id": client_id,
444 }
446 # Only include client_secret if present (public clients don't have secrets)
447 if client_secret:
448 token_data["client_secret"] = client_secret
450 # Exchange code for token with retries
451 for attempt in range(self.max_retries):
452 try:
453 client = await self._get_client()
454 response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
455 response.raise_for_status()
457 # GitHub returns form-encoded responses, not JSON
458 content_type = response.headers.get("content-type", "")
459 if "application/x-www-form-urlencoded" in content_type:
460 # Parse form-encoded response
461 text_response = response.text
462 token_response = {}
463 for pair in text_response.split("&"):
464 if "=" in pair:
465 key, value = pair.split("=", 1)
466 token_response[key] = value
467 else:
468 # Try JSON response
469 try:
470 token_response = response.json()
471 except Exception as e:
472 logger.warning(f"Failed to parse JSON response: {e}")
473 # Fallback to text parsing
474 text_response = response.text
475 token_response = {"raw_response": text_response}
477 if "access_token" not in token_response:
478 raise OAuthError(f"No access_token in response: {token_response}")
480 logger.info("""Successfully exchanged authorization code for access token""")
481 return token_response["access_token"]
483 except httpx.HTTPError as e:
484 logger.warning(f"Token exchange attempt {attempt + 1} failed: {str(e)}")
485 if attempt == self.max_retries - 1:
486 raise OAuthError(f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}")
487 await asyncio.sleep(2**attempt) # Exponential backoff
489 # This should never be reached due to the exception above, but needed for type safety
490 raise OAuthError("Failed to exchange code for token after all retry attempts")
492 async def initiate_authorization_code_flow(self, gateway_id: str, credentials: Dict[str, Any], app_user_email: str = None) -> Dict[str, str]:
493 """Initiate Authorization Code flow with PKCE and return authorization URL.
495 Args:
496 gateway_id: ID of the gateway being configured
497 credentials: OAuth configuration with client_id, authorization_url, etc.
498 app_user_email: ContextForge user email to associate with tokens
500 Returns:
501 Dict containing authorization_url and state
502 """
504 # Generate PKCE parameters (RFC 7636)
505 pkce_params = self._generate_pkce_params()
507 # Generate state parameter with user context for CSRF protection
508 state = self._generate_state(gateway_id, app_user_email)
510 # Store state with code_verifier in session/cache for validation
511 if self.token_storage:
512 await self._store_authorization_state(
513 gateway_id,
514 state,
515 code_verifier=pkce_params["code_verifier"],
516 app_user_email=app_user_email,
517 )
519 # Generate authorization URL with PKCE
520 auth_url = self._create_authorization_url_with_pkce(credentials, state, pkce_params["code_challenge"], pkce_params["code_challenge_method"])
522 logger.info(f"Generated authorization URL with PKCE for gateway {gateway_id}")
524 return {"authorization_url": auth_url, "state": state, "gateway_id": gateway_id}
526 async def complete_authorization_code_flow(self, gateway_id: str, code: str, state: str, credentials: Dict[str, Any]) -> Dict[str, Any]:
527 """Complete Authorization Code flow with PKCE and store tokens.
529 Args:
530 gateway_id: ID of the gateway
531 code: Authorization code from callback
532 state: State parameter for CSRF validation
533 credentials: OAuth configuration
535 Returns:
536 Dict containing success status, user_id, and expiration info
538 Raises:
539 OAuthError: If state validation fails or token exchange fails
540 """
541 # Validate state and retrieve code_verifier
542 state_data = await self._validate_and_retrieve_state(gateway_id, state)
543 if not state_data:
544 raise OAuthError("Invalid or expired state parameter - possible replay attack")
546 code_verifier = state_data.get("code_verifier")
547 app_user_email = state_data.get("app_user_email")
549 # Backward compatibility for in-flight legacy states that embedded user context.
550 if not app_user_email:
551 legacy_state_payload = self._extract_legacy_state_payload(state)
552 if legacy_state_payload:
553 legacy_gateway_id = legacy_state_payload.get("gateway_id")
554 if legacy_gateway_id and legacy_gateway_id != gateway_id:
555 raise OAuthError("State parameter gateway mismatch")
556 app_user_email = legacy_state_payload.get("app_user_email")
558 # Exchange code for tokens with PKCE code_verifier
559 token_response = await self._exchange_code_for_tokens(credentials, code, code_verifier=code_verifier)
561 # Extract user information from token response
562 user_id = self._extract_user_id(token_response, credentials)
564 # Store tokens if storage service is available
565 if self.token_storage:
566 if not app_user_email:
567 raise OAuthError("User context required for OAuth token storage")
569 token_record = await self.token_storage.store_tokens(
570 gateway_id=gateway_id,
571 user_id=user_id,
572 app_user_email=app_user_email, # User from state
573 access_token=token_response["access_token"],
574 refresh_token=token_response.get("refresh_token"),
575 expires_in=token_response.get("expires_in", self.settings.oauth_default_timeout),
576 scopes=token_response.get("scope", "").split(),
577 )
579 return {"success": True, "user_id": user_id, "expires_at": token_record.expires_at.isoformat() if token_record.expires_at else None}
580 return {"success": True, "user_id": user_id, "expires_at": None}
582 async def get_access_token_for_user(self, gateway_id: str, app_user_email: str) -> Optional[str]:
583 """Get valid access token for a specific user.
585 Args:
586 gateway_id: ID of the gateway
587 app_user_email: ContextForge user email
589 Returns:
590 Valid access token or None if not available
591 """
592 if self.token_storage:
593 return await self.token_storage.get_user_token(gateway_id, app_user_email)
594 return None
596 def _generate_state(self, _gateway_id: str, _app_user_email: str = None) -> str:
597 """Generate an opaque state token for CSRF protection.
599 Args:
600 _gateway_id: Gateway identifier (reserved for compatibility with
601 prior embedded-state call sites).
602 _app_user_email: ContextForge user email (reserved for
603 compatibility with prior embedded-state call sites).
605 Returns:
606 Opaque random state token
607 """
608 return secrets.token_urlsafe(48)
610 @staticmethod
611 def _extract_legacy_state_payload(state: str) -> Optional[Dict[str, Any]]:
612 """Best-effort decode of legacy state payloads used before opaque states.
614 Legacy formats supported:
615 - base64url(payload || signature) where payload is JSON
616 - gateway_id_random suffix format
618 Args:
619 state: Callback state token to decode.
621 Returns:
622 Decoded legacy payload when format is recognized; otherwise ``None``.
623 """
624 try:
625 state_raw = base64.urlsafe_b64decode(state.encode())
626 if len(state_raw) <= 32:
627 return None
629 payload_bytes = state_raw[:-32]
630 payload = orjson.loads(payload_bytes)
631 if isinstance(payload, dict):
632 return payload
633 except Exception:
634 # Fall back to legacy gateway_id_random format
635 if "_" in state:
636 gateway_id = state.split("_", 1)[0]
637 if gateway_id:
638 return {"gateway_id": gateway_id}
639 return None
641 async def resolve_gateway_id_from_state(self, state: str, allow_legacy_fallback: bool = True) -> Optional[str]:
642 """Resolve gateway ID for a callback state token without consuming it.
644 Args:
645 state: OAuth callback state parameter
646 allow_legacy_fallback: Whether to decode legacy callback state formats.
648 Returns:
649 Gateway ID when resolvable, otherwise ``None``.
650 """
651 settings = get_settings()
653 if settings.cache_type == "redis":
654 redis = await _get_redis_client()
655 if redis:
656 try:
657 lookup_key = f"oauth:state_lookup:{state}"
658 gateway_id = await redis.get(lookup_key)
659 if gateway_id:
660 if isinstance(gateway_id, bytes):
661 gateway_id = gateway_id.decode("utf-8")
662 return gateway_id
663 except Exception as e:
664 logger.warning(f"Failed to resolve state gateway in Redis: {e}")
666 if settings.cache_type == "database":
667 try:
668 # First-Party
669 from mcpgateway.db import get_db, OAuthState # pylint: disable=import-outside-toplevel
671 db_gen = get_db()
672 db = next(db_gen)
673 try:
674 oauth_state = db.query(OAuthState).filter(OAuthState.state == state).first()
675 if oauth_state:
676 return oauth_state.gateway_id
677 finally:
678 db_gen.close()
679 except Exception as e:
680 logger.warning(f"Failed to resolve state gateway in database: {e}")
682 async with _state_lock:
683 now = datetime.now(timezone.utc)
684 expired_keys = [key for key, data in _oauth_states.items() if datetime.fromisoformat(data["expires_at"]) < now]
685 for key in expired_keys:
686 expired_state = _oauth_states[key].get("state")
687 del _oauth_states[key]
688 if expired_state:
689 _oauth_state_lookup.pop(expired_state, None)
690 gateway_id = _oauth_state_lookup.get(state)
691 if gateway_id:
692 return gateway_id
694 if allow_legacy_fallback:
695 legacy_payload = self._extract_legacy_state_payload(state)
696 if legacy_payload:
697 return legacy_payload.get("gateway_id")
698 return None
    async def _store_authorization_state(
        self,
        gateway_id: str,
        state: str,
        code_verifier: Optional[str] = None,
        app_user_email: Optional[str] = None,
    ) -> None:
        """Store authorization state for validation with TTL.

        The backend is chosen from ``cache_type``:

        - "redis": the state record and a ``state -> gateway_id`` lookup key are written
          with a ``STATE_TTL_SECONDS`` expiry.
        - "database": expired ``OAuthState`` rows are deleted, then a new row is added.
        - Otherwise, or when the selected backend fails: the module-level in-memory maps.

        A Redis failure skips the database branch (``cache_type`` is "redis") and lands
        in memory.

        Args:
            gateway_id: ID of the gateway
            state: State parameter to store
            code_verifier: Optional PKCE code verifier (RFC 7636)
            app_user_email: Requesting user email for token association
        """
        expires_at = datetime.now(timezone.utc) + timedelta(seconds=STATE_TTL_SECONDS)
        settings = get_settings()

        # Try Redis first for distributed storage
        if settings.cache_type == "redis":
            redis = await _get_redis_client()
            if redis:
                try:
                    state_key = f"oauth:state:{gateway_id}:{state}"
                    lookup_key = f"oauth:state_lookup:{state}"
                    state_data = {
                        "state": state,
                        "gateway_id": gateway_id,
                        "code_verifier": code_verifier,
                        "app_user_email": app_user_email,
                        "expires_at": expires_at.isoformat(),
                        "used": False,
                    }
                    # Store in Redis with TTL. The two writes are separate calls (not atomic).
                    # If the second fails, the in-memory fallback below also records the state.
                    await redis.setex(state_key, STATE_TTL_SECONDS, orjson.dumps(state_data))
                    await redis.setex(lookup_key, STATE_TTL_SECONDS, gateway_id)
                    logger.debug(f"Stored OAuth state in Redis for gateway {gateway_id}")
                    return
                except Exception as e:
                    logger.warning(f"Failed to store state in Redis: {e}, falling back")

        # Try database storage for multi-worker deployments
        # NOTE(review): these DB calls are synchronous (no await) and run on the event loop.
        if settings.cache_type == "database":
            try:
                # First-Party
                from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

                db_gen = get_db()
                db = next(db_gen)
                try:
                    # Clean up expired states first
                    # NOTE(review): the validation path treats stored expires_at as possibly
                    # naive; confirm this aware comparison behaves as intended on the DB in use.
                    db.query(OAuthState).filter(OAuthState.expires_at < datetime.now(timezone.utc)).delete()

                    # Store new state with code_verifier
                    oauth_state_kwargs = {
                        "gateway_id": gateway_id,
                        "state": state,
                        "code_verifier": code_verifier,
                        "expires_at": expires_at,
                        "used": False,
                    }
                    # Only pass app_user_email when the model defines it. Presumably older
                    # schemas lack the column — confirm.
                    if hasattr(OAuthState, "app_user_email"):
                        oauth_state_kwargs["app_user_email"] = app_user_email

                    oauth_state = OAuthState(**oauth_state_kwargs)
                    db.add(oauth_state)
                    db.commit()
                    logger.debug(f"Stored OAuth state in database for gateway {gateway_id}")
                    return
                finally:
                    db_gen.close()
            except Exception as e:
                logger.warning(f"Failed to store state in database: {e}, falling back to memory")

        # Fallback to in-memory storage for development
        async with _state_lock:
            # Clean up expired states first
            now = datetime.now(timezone.utc)
            state_key = f"oauth:state:{gateway_id}:{state}"
            state_data = {
                "state": state,
                "gateway_id": gateway_id,
                "code_verifier": code_verifier,
                "app_user_email": app_user_email,
                "expires_at": expires_at.isoformat(),
                "used": False,
            }
            expired_states = [key for key, data in _oauth_states.items() if datetime.fromisoformat(data["expires_at"]) < now]
            for key in expired_states:
                expired_state_value = _oauth_states[key].get("state")
                del _oauth_states[key]
                # Keep the reverse lookup in sync with the primary map.
                if expired_state_value:
                    _oauth_state_lookup.pop(expired_state_value, None)
                logger.debug(f"Cleaned up expired state: {key[:20]}...")

            # Store the new state with expiration
            _oauth_states[state_key] = state_data
            _oauth_state_lookup[state] = gateway_id
            logger.debug(f"Stored OAuth state in memory for gateway {gateway_id}")
    async def _validate_authorization_state(self, gateway_id: str, state: str) -> bool:
        """Validate authorization state parameter and mark as used.

        The state is single-use. On every backend it is removed when validation
        succeeds. The backend is chosen from ``cache_type``, in the same order as
        ``_store_authorization_state``: Redis, then database, then in-memory. A failing
        backend falls through to the in-memory maps.

        Args:
            gateway_id: ID of the gateway
            state: State parameter to validate

        Returns:
            True if state is valid and not yet used, False otherwise
        """
        settings = get_settings()

        # Try Redis first for distributed storage
        if settings.cache_type == "redis":
            redis = await _get_redis_client()
            if redis:
                try:
                    state_key = f"oauth:state:{gateway_id}:{state}"
                    lookup_key = f"oauth:state_lookup:{state}"
                    # Get and delete state atomically (single-use)
                    state_json = await redis.getdel(state_key)
                    # The reverse lookup is dropped whether or not the state was found.
                    await redis.delete(lookup_key)
                    if not state_json:
                        logger.warning(f"State not found in Redis for gateway {gateway_id}")
                        return False

                    state_data = orjson.loads(state_json)

                    # Parse expires_at as timezone-aware datetime. If the stored value
                    # is naive, assume UTC for compatibility.
                    try:
                        expires_at = datetime.fromisoformat(state_data["expires_at"])
                    except Exception:
                        # Fallback: try parsing without microseconds/offsets
                        expires_at = datetime.strptime(state_data["expires_at"], "%Y-%m-%dT%H:%M:%S")

                    if expires_at.tzinfo is None:
                        # Assume UTC for naive timestamps
                        expires_at = expires_at.replace(tzinfo=timezone.utc)

                    # Check if state has expired
                    if expires_at < datetime.now(timezone.utc):
                        logger.warning(f"State has expired for gateway {gateway_id}")
                        return False

                    # Check if state was already used (should not happen with getdel)
                    if state_data.get("used", False):
                        logger.warning(f"State was already used for gateway {gateway_id} - possible replay attack")
                        return False

                    logger.debug(f"Successfully validated OAuth state from Redis for gateway {gateway_id}")
                    return True
                except Exception as e:
                    # Falls through to the in-memory check. The database branch is skipped
                    # because cache_type is "redis".
                    logger.warning(f"Failed to validate state in Redis: {e}, falling back")

        # Try database storage for multi-worker deployments
        # NOTE(review): these DB calls are synchronous (no await) and run on the event loop.
        if settings.cache_type == "database":
            try:
                # First-Party
                from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

                db_gen = get_db()
                db = next(db_gen)
                try:
                    # Find the state
                    oauth_state = db.query(OAuthState).filter(OAuthState.gateway_id == gateway_id, OAuthState.state == state).first()

                    if not oauth_state:
                        logger.warning(f"State not found in database for gateway {gateway_id}")
                        return False

                    # Check if state has expired
                    # Ensure oauth_state.expires_at is timezone-aware. If naive, assume UTC.
                    expires_at = oauth_state.expires_at
                    if expires_at.tzinfo is None:
                        expires_at = expires_at.replace(tzinfo=timezone.utc)

                    if expires_at < datetime.now(timezone.utc):
                        logger.warning(f"State has expired for gateway {gateway_id}")
                        db.delete(oauth_state)
                        db.commit()
                        return False

                    # Check if state was already used (the row is left in place in this case)
                    if oauth_state.used:
                        logger.warning(f"State has already been used for gateway {gateway_id} - possible replay attack")
                        return False

                    # Mark as used and delete (single-use)
                    db.delete(oauth_state)
                    db.commit()
                    logger.debug(f"Successfully validated OAuth state from database for gateway {gateway_id}")
                    return True
                finally:
                    db_gen.close()
            except Exception as e:
                logger.warning(f"Failed to validate state in database: {e}, falling back to memory")

        # Fallback to in-memory storage for development
        state_key = f"oauth:state:{gateway_id}:{state}"
        async with _state_lock:
            state_data = _oauth_states.get(state_key)

            # Check if state exists
            if not state_data:
                logger.warning(f"State not found in memory for gateway {gateway_id}")
                return False

            # Parse and normalize expires_at to timezone-aware datetime
            expires_at = datetime.fromisoformat(state_data["expires_at"])
            if expires_at.tzinfo is None:
                expires_at = expires_at.replace(tzinfo=timezone.utc)

            if expires_at < datetime.now(timezone.utc):
                logger.warning(f"State has expired for gateway {gateway_id}")
                del _oauth_states[state_key]  # Clean up expired state
                _oauth_state_lookup.pop(state, None)
                return False

            # Check if state has already been used (prevent replay). Entries are deleted on
            # success below, so this check is defensive.
            if state_data.get("used", False):
                logger.warning(f"State has already been used for gateway {gateway_id} - possible replay attack")
                return False

            # Mark state as used and remove it (single-use)
            del _oauth_states[state_key]
            _oauth_state_lookup.pop(state, None)
            logger.debug(f"Successfully validated OAuth state from memory for gateway {gateway_id}")
            return True
async def _validate_and_retrieve_state(self, gateway_id: str, state: str) -> Optional[Dict[str, Any]]:
    """Validate a single-use OAuth state and return its stored data (including code_verifier).

    The backend is chosen from ``settings.cache_type``:
    - Redis: read with an atomic ``GETDEL``.
    - Database: read from the ``OAuthState`` table.
    - In-process ``_oauth_states`` dict: the fallback.

    If Redis or the database raises, the error is logged and the in-memory
    store is consulted instead. A valid state is deleted from its backend
    before being returned, so it cannot be redeemed twice. Expired entries
    found in the database or in memory are also deleted.

    Args:
        gateway_id: ID of the gateway
        state: State parameter to validate

    Returns:
        Dict with state data including code_verifier, or None if the state is
        unknown, expired, or already marked as used.
    """

    def _as_aware_utc(value: Any) -> datetime:
        """Coerce a stored expiry (datetime or ISO-8601 string) to a timezone-aware datetime."""
        if isinstance(value, datetime):
            parsed = value
        else:
            try:
                parsed = datetime.fromisoformat(value)
            except Exception:
                # Fallback for the plain "YYYY-MM-DDTHH:MM:SS" form
                parsed = datetime.strptime(value, "%Y-%m-%dT%H:%M:%S")
        if parsed.tzinfo is None:
            # Naive timestamps are assumed to be UTC
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    settings = get_settings()

    # Try Redis first
    if settings.cache_type == "redis":
        redis = await _get_redis_client()
        if redis:
            try:
                state_key = f"oauth:state:{gateway_id}:{state}"
                lookup_key = f"oauth:state_lookup:{state}"
                state_json = await redis.getdel(state_key)  # Atomic get+delete
                await redis.delete(lookup_key)
                if not state_json:
                    return None

                state_data = orjson.loads(state_json)
                if _as_aware_utc(state_data["expires_at"]) < datetime.now(timezone.utc):
                    return None

                return state_data
            except Exception as e:
                logger.warning(f"Failed to validate state in Redis: {e}, falling back")

    # Try database
    if settings.cache_type == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            db_gen = get_db()
            db = next(db_gen)
            try:
                oauth_state = db.query(OAuthState).filter(OAuthState.gateway_id == gateway_id, OAuthState.state == state).first()
                if not oauth_state:
                    return None

                if _as_aware_utc(oauth_state.expires_at) < datetime.now(timezone.utc):
                    db.delete(oauth_state)
                    db.commit()
                    return None

                # Refuse replays of a state already marked as used
                if oauth_state.used:
                    return None

                state_data = {
                    "state": oauth_state.state,
                    "gateway_id": oauth_state.gateway_id,
                    "code_verifier": oauth_state.code_verifier,
                    "expires_at": oauth_state.expires_at.isoformat(),
                }
                if hasattr(oauth_state, "app_user_email"):
                    state_data["app_user_email"] = getattr(oauth_state, "app_user_email", None)

                # Single-use: delete before handing the data back
                db.delete(oauth_state)
                db.commit()

                return state_data
            finally:
                db_gen.close()
        except Exception as e:
            logger.warning(f"Failed to validate state in database: {e}")

    # Fallback to in-memory
    state_key = f"oauth:state:{gateway_id}:{state}"
    async with _state_lock:
        state_data = _oauth_states.get(state_key)
        if not state_data:
            return None

        if _as_aware_utc(state_data["expires_at"]) < datetime.now(timezone.utc):
            del _oauth_states[state_key]
            _oauth_state_lookup.pop(state, None)
            return None

        # Replay protection, matching the database branch and _validate_state
        if state_data.get("used", False):
            return None

        # Remove from memory (single-use)
        del _oauth_states[state_key]
        _oauth_state_lookup.pop(state, None)
        return state_data
def _create_authorization_url(self, credentials: Dict[str, Any], state: str) -> tuple[str, str]:
    """Build a (non-PKCE) provider authorization URL through ``OAuth2Session``.

    Args:
        credentials: OAuth configuration. ``client_id``, ``redirect_uri`` and
            ``authorization_url`` are required; ``scopes`` is optional (defaults to []).
        state: CSRF-protection value to embed in the URL

    Returns:
        Tuple of (authorization_url, state) as returned by
        ``OAuth2Session.authorization_url``.
    """
    client_id, redirect_uri, endpoint = (
        credentials["client_id"],
        credentials["redirect_uri"],
        credentials["authorization_url"],
    )
    requested_scopes = credentials.get("scopes", [])

    session = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=requested_scopes)
    # The state is passed through so the provider echoes it back on the callback
    url, returned_state = session.authorization_url(endpoint, state=state)
    return url, returned_state
1065 @staticmethod
1066 def _is_microsoft_entra_v2_endpoint(endpoint_url: Any) -> bool:
1067 """Return True when endpoint matches Microsoft Entra v2 login endpoints.
1069 Args:
1070 endpoint_url: OAuth endpoint URL to check
1072 Returns:
1073 True if the endpoint is a Microsoft Entra v2 OAuth endpoint
1074 """
1075 if not isinstance(endpoint_url, str) or not endpoint_url:
1076 return False
1078 parsed = urlparse(endpoint_url)
1079 host = (parsed.hostname or "").lower()
1080 path = parsed.path.lower()
1082 return host in OAuthManager._ENTRA_HOSTS and "/oauth2/v2.0/" in path
1084 @staticmethod
1085 def _is_enabled_flag(value: Any) -> bool:
1086 """Parse boolean-like config values from oauth_config.
1088 Args:
1089 value: Config value to interpret as boolean
1091 Returns:
1092 True if value represents an enabled/truthy setting
1093 """
1094 if isinstance(value, bool):
1095 return value
1096 if isinstance(value, str):
1097 return value.strip().lower() in {"1", "true", "yes", "on"}
1098 return False
1100 def _should_include_resource_parameter(self, credentials: Dict[str, Any], scopes: Any) -> bool:
1101 """Determine whether RFC 8707 resource should be sent for this request.
1103 Args:
1104 credentials: OAuth configuration containing resource and endpoint URLs
1105 scopes: OAuth scopes for the request
1107 Returns:
1108 True if the resource parameter should be included in the request
1109 """
1110 if not credentials.get("resource"):
1111 return False
1113 if self._is_enabled_flag(credentials.get("omit_resource")):
1114 return False
1116 # Microsoft Entra v2 does not accept legacy resource with v2 scope-based requests.
1117 if scopes and (self._is_microsoft_entra_v2_endpoint(credentials.get("authorization_url")) or self._is_microsoft_entra_v2_endpoint(credentials.get("token_url"))):
1118 logger.info("Omitting OAuth resource parameter for Microsoft Entra v2 scope-based flow")
1119 return False
1121 return True
1123 def _create_authorization_url_with_pkce(self, credentials: Dict[str, Any], state: str, code_challenge: str, code_challenge_method: str) -> str:
1124 """Create authorization URL with PKCE parameters (RFC 7636).
1126 Args:
1127 credentials: OAuth configuration
1128 state: State parameter for CSRF protection
1129 code_challenge: PKCE code challenge
1130 code_challenge_method: PKCE method (S256)
1132 Returns:
1133 Authorization URL string with PKCE parameters
1134 """
1135 # Standard
1136 from urllib.parse import urlencode # pylint: disable=import-outside-toplevel
1138 client_id = credentials["client_id"]
1139 redirect_uri = credentials["redirect_uri"]
1140 authorization_url = credentials["authorization_url"]
1141 scopes = credentials.get("scopes", [])
1143 # Build authorization parameters
1144 params = {"response_type": "code", "client_id": client_id, "redirect_uri": redirect_uri, "state": state, "code_challenge": code_challenge, "code_challenge_method": code_challenge_method}
1146 # Add scopes if present
1147 if scopes:
1148 params["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes
1150 # Add resource parameter for JWT access token (RFC 8707)
1151 # The resource is the MCP server URL, set by oauth_router.py
1152 resource = credentials.get("resource")
1153 if self._should_include_resource_parameter(credentials, scopes):
1154 params["resource"] = resource # urlencode with doseq=True handles lists
1156 # Build full URL (doseq=True handles list values like multiple resource params)
1157 query_string = urlencode(params, doseq=True)
1158 return f"{authorization_url}?{query_string}"
async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]:
    """Exchange authorization code for tokens with PKCE support.

    Retry policy:
    - Transport errors and retryable HTTP statuses (5xx, 408, 429) are retried
      with exponential backoff, up to ``self.max_retries`` attempts.
    - Other 4xx responses fail immediately. An authorization code can only be
      redeemed once, so resending a rejected exchange cannot succeed.

    Args:
        credentials: OAuth configuration
        code: Authorization code from callback
        code_verifier: Optional PKCE code verifier (RFC 7636)

    Returns:
        Token response dictionary. Form-encoded bodies (e.g. GitHub) are
        parsed with percent-decoding. JSON is used otherwise.

    Raises:
        OAuthError: If token exchange fails or the response carries no usable access_token
    """
    # Standard
    from urllib.parse import parse_qsl  # pylint: disable=import-outside-toplevel

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "authorization_code_exchange_with_pkce")
    client_id = runtime_credentials["client_id"]
    client_secret = runtime_credentials.get("client_secret")  # Optional for public clients (PKCE-only)
    token_url = runtime_credentials["token_url"]
    redirect_uri = runtime_credentials["redirect_uri"]

    # Prepare token exchange data
    token_data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": client_id,
    }

    # Only include client_secret if present (public clients don't have secrets)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Add PKCE code_verifier if present (RFC 7636)
    if code_verifier:
        token_data["code_verifier"] = code_verifier

    # Add resource parameter to request JWT access token (RFC 8707)
    # The resource identifies the MCP server (resource server), not the OAuth server
    resource = runtime_credentials.get("resource")
    scopes = runtime_credentials.get("scopes", [])
    if self._should_include_resource_parameter(credentials, scopes):
        if isinstance(resource, list):
            # RFC 8707 allows multiple resource parameters - use list of tuples
            form_data: list[tuple[str, str]] = list(token_data.items())
            for r in resource:
                if r:
                    form_data.append(("resource", r))
            token_data = form_data  # type: ignore[assignment]
        else:
            token_data["resource"] = resource

    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()
        except httpx.HTTPError as e:
            logger.warning(f"Token exchange attempt {attempt + 1} failed: {str(e)}")
            status_code = e.response.status_code if isinstance(e, httpx.HTTPStatusError) else None
            if status_code is not None and 400 <= status_code < 500 and status_code not in (408, 429):
                # The provider rejected the request itself (e.g. invalid_grant); retrying cannot help.
                raise OAuthError(f"Token exchange rejected with status {status_code}: {e.response.text}") from e
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff
            continue

        token_response: Any
        content_type = response.headers.get("content-type", "")
        if "application/x-www-form-urlencoded" in content_type:
            # GitHub returns form-encoded responses; parse_qsl percent-decodes values (e.g. scope=repo%2Cgist)
            token_response = dict(parse_qsl(response.text, keep_blank_values=True))
        else:
            try:
                token_response = response.json()
            except Exception as e:
                logger.warning(f"Failed to parse JSON response: {e}")
                token_response = {"raw_response": response.text}

        if not isinstance(token_response, dict) or not token_response.get("access_token"):
            raise OAuthError(f"No access_token in response: {token_response}")

        logger.info("Successfully exchanged authorization code for tokens")
        return token_response

    # Only reached when max_retries <= 0
    raise OAuthError("Failed to exchange code for token after all retry attempts")
async def refresh_token(self, refresh_token: str, credentials: Dict[str, Any]) -> Dict[str, Any]:
    """Refresh an expired access token using a refresh token.

    Retry policy:
    - Network errors and non-200 responses other than 400/401 are retried with
      exponential backoff, up to ``self.max_retries`` attempts.
    - A 400/401 means the refresh token is invalid or expired; it fails immediately.

    Args:
        refresh_token: The refresh token to use
        credentials: OAuth configuration including client_id, client_secret, token_url

    Returns:
        Dict containing new access_token, optional refresh_token, and expires_in.
        Both JSON and ``application/x-www-form-urlencoded`` bodies are accepted.

    Raises:
        OAuthError: If token refresh fails, the response cannot be parsed, or it carries no access_token
    """
    # Standard
    from urllib.parse import parse_qsl  # pylint: disable=import-outside-toplevel

    if not refresh_token:
        raise OAuthError("No refresh token available")

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "refresh_token")
    token_url = runtime_credentials.get("token_url")
    if not token_url:
        raise OAuthError("No token URL configured for OAuth provider")

    client_id = runtime_credentials.get("client_id")
    client_secret = runtime_credentials.get("client_secret")

    if not client_id:
        raise OAuthError("No client_id configured for OAuth provider")

    # Prepare token refresh request
    token_data = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
    }

    # Add client_secret if available (some providers require it)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Add resource parameter for JWT access token (RFC 8707)
    # Must be included in refresh requests to maintain JWT token type
    resource = runtime_credentials.get("resource")
    scopes = runtime_credentials.get("scopes", [])
    if self._should_include_resource_parameter(credentials, scopes):
        if isinstance(resource, list):
            # RFC 8707 allows multiple resource parameters - use list of tuples
            form_data: list[tuple[str, str]] = list(token_data.items())
            for r in resource:
                if r:
                    form_data.append(("resource", r))
            token_data = form_data  # type: ignore[assignment]
        else:
            token_data["resource"] = resource

    for attempt in range(self.max_retries):
        is_last_attempt = attempt == self.max_retries - 1
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
        except httpx.HTTPError as e:
            logger.warning(f"Token refresh attempt {attempt + 1} failed: {str(e)}")
            if is_last_attempt:
                raise OAuthError(f"Failed to refresh token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff
            continue

        if response.status_code == 200:
            token_response: Any
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                # Some providers (e.g. GitHub) answer with form-encoded bodies
                token_response = dict(parse_qsl(response.text, keep_blank_values=True))
            else:
                try:
                    token_response = response.json()
                except ValueError as e:
                    raise OAuthError(f"Invalid token refresh response: {e}") from e

            # Validate required fields
            if not isinstance(token_response, dict) or not token_response.get("access_token"):
                raise OAuthError("No access_token in refresh response")

            logger.info("Successfully refreshed OAuth token")
            return token_response

        error_text = response.text
        # If we get a 400/401, the refresh token is likely invalid
        if response.status_code in [400, 401]:
            raise OAuthError(f"Refresh token invalid or expired: {error_text}")
        logger.warning(f"Token refresh failed with status {response.status_code}: {error_text}")
        if not is_last_attempt:
            # Back off on server-side failures too, not only on network errors
            await asyncio.sleep(2**attempt)

    raise OAuthError("Failed to refresh token after all retry attempts")
1335 def _extract_user_id(self, token_response: Dict[str, Any], credentials: Dict[str, Any]) -> str:
1336 """Extract user ID from token response.
1338 Args:
1339 token_response: Response from token exchange
1340 credentials: OAuth configuration
1342 Returns:
1343 User ID string
1344 """
1345 # Try to extract user ID from various common fields in token response
1346 # Different OAuth providers use different field names
1348 # Check for 'sub' (subject) - JWT standard
1349 if "sub" in token_response:
1350 return token_response["sub"]
1352 # Check for 'user_id' - common in some OAuth responses
1353 if "user_id" in token_response:
1354 return token_response["user_id"]
1356 # Check for 'id' - also common
1357 if "id" in token_response:
1358 return token_response["id"]
1360 # Fallback to client_id if no user info is available
1361 if credentials.get("client_id"):
1362 return credentials["client_id"]
1364 # Final fallback
1365 return "unknown_user"
class OAuthError(Exception):
    """Base exception for failures in the OAuth flows handled by this module.

    Examples:
        >>> str(OAuthError("Token acquisition failed"))
        'Token acquisition failed'
        >>> issubclass(OAuthError, Exception)
        True
        >>> try:
        ...     raise OAuthError("Invalid grant type")
        ... except Exception as err:
        ...     isinstance(err, OAuthError)
        True
    """
class OAuthRequiredError(OAuthError):
    """Signal that a server demands OAuth while the caller is unauthenticated.

    The ``server_id`` attribute records which virtual server rejected the
    request, so the middleware can build the matching ``WWW-Authenticate``
    header.

    Examples:
        >>> exc = OAuthRequiredError("auth required", server_id="s1")
        >>> (exc.server_id, str(exc))
        ('s1', 'auth required')
        >>> OAuthRequiredError("auth required").server_id
        ''
        >>> issubclass(OAuthRequiredError, OAuthError)
        True
    """

    def __init__(self, message: str, *, server_id: str = "") -> None:
        """Store the message on the exception and remember ``server_id``.

        Args:
            message: Human-readable error description.
            server_id: Virtual-server identifier that triggered the rejection.
        """
        self.server_id = server_id
        super().__init__(message)
class OAuthEnforcementUnavailableError(OAuthError):
    """Signal that OAuth enforcement could not be evaluated because a backing service failed.

    Raised when the database, or another service needed to read a server's
    ``oauth_enabled`` flag, is unavailable. The middleware turns this into an
    HTTP 503, so the request is refused rather than let through
    unauthenticated (fail-closed).

    Examples:
        >>> exc = OAuthEnforcementUnavailableError("DB down", server_id="s1")
        >>> (exc.server_id, str(exc))
        ('s1', 'DB down')
        >>> OAuthEnforcementUnavailableError("DB down").server_id
        ''
        >>> issubclass(OAuthEnforcementUnavailableError, OAuthError)
        True
    """

    def __init__(self, message: str, *, server_id: str = "") -> None:
        """Store the message on the exception and remember ``server_id``.

        Args:
            message: Human-readable error description.
            server_id: Virtual-server identifier that triggered the rejection.
        """
        self.server_id = server_id
        super().__init__(message)