Coverage for mcpgateway / services / oauth_manager.py: 100%
583 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/oauth_manager.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7OAuth 2.0 Manager for MCP Gateway.
9This module handles OAuth 2.0 authentication flows including:
10- Client Credentials (Machine-to-Machine)
11- Authorization Code (User Delegation)
12"""
14# Standard
15import asyncio
16import base64
17from datetime import datetime, timedelta, timezone
18import hashlib
19import hmac
20import logging
21import secrets
22from typing import Any, Dict, Optional
24# Third-Party
25import httpx
26import orjson
27from requests_oauthlib import OAuth2Session
29# First-Party
30from mcpgateway.config import get_settings
31from mcpgateway.services.encryption_service import get_encryption_service
32from mcpgateway.services.http_client_service import get_http_client
33from mcpgateway.utils.redis_client import get_redis_client as _get_shared_redis_client
logger = logging.getLogger(__name__)

# Process-local fallback store for OAuth states, used when neither Redis nor the
# database backend accepted the state (single-process / development setups).
# Keys look like "oauth:state:{gateway_id}:{state}"; each value is a dict with
# "state", "gateway_id", "code_verifier", "expires_at" (ISO-8601 string) and "used".
_oauth_states: Dict[str, Dict[str, Any]] = {}

# Guards every read/write of _oauth_states across concurrent coroutines.
_state_lock = asyncio.Lock()

# How long an issued OAuth state stays valid: 5 minutes, expressed in seconds.
STATE_TTL_SECONDS = 5 * 60

# Lazily resolved shared Redis client (None when Redis is disabled/unavailable)
# and the flag recording that the lookup has already been performed.
_redis_client: Optional[Any] = None
_REDIS_INITIALIZED = False
async def _get_redis_client():
    """Resolve (once) and return the shared Redis client for OAuth state storage.

    The lookup goes through the centralized Redis client factory and is only
    attempted when settings select ``cache_type == "redis"`` and provide a
    ``redis_url``. Any error during the lookup is logged and treated as "no
    Redis". The outcome is memoized in module globals, so later calls return
    the cached value without consulting settings again.

    Returns:
        Redis client instance or None if unavailable
    """
    global _redis_client, _REDIS_INITIALIZED  # pylint: disable=global-statement

    if not _REDIS_INITIALIZED:
        cfg = get_settings()
        resolved = None
        if cfg.cache_type == "redis" and cfg.redis_url:
            try:
                resolved = await _get_shared_redis_client()
                if resolved:
                    logger.info("OAuth manager connected to shared Redis client")
            except Exception as e:
                logger.warning(f"Failed to get Redis client, falling back to in-memory storage: {e}")
                resolved = None
        _redis_client = resolved
        _REDIS_INITIALIZED = True

    return _redis_client
80class OAuthManager:
81 """Manages OAuth 2.0 authentication flows.
83 Examples:
84 >>> manager = OAuthManager(request_timeout=30, max_retries=3)
85 >>> manager.request_timeout
86 30
87 >>> manager.max_retries
88 3
89 >>> manager.token_storage is None
90 True
91 >>>
92 >>> # Test grant type validation
93 >>> grant_type = "client_credentials"
94 >>> grant_type in ["client_credentials", "authorization_code"]
95 True
96 >>> grant_type = "invalid_grant"
97 >>> grant_type in ["client_credentials", "authorization_code"]
98 False
99 >>>
100 >>> # Test encrypted secret detection heuristic
101 >>> short_secret = "secret123"
102 >>> len(short_secret) > 50
103 False
104 >>> encrypted_secret = "gAAAAABh" + "x" * 60 # Simulated encrypted secret
105 >>> len(encrypted_secret) > 50
106 True
107 >>>
108 >>> # Test scope list handling
109 >>> scopes = ["read", "write"]
110 >>> " ".join(scopes)
111 'read write'
112 >>> empty_scopes = []
113 >>> " ".join(empty_scopes)
114 ''
115 """
117 def __init__(self, request_timeout: int = 30, max_retries: int = 3, token_storage: Optional[Any] = None):
118 """Initialize OAuth Manager.
120 Args:
121 request_timeout: Timeout for OAuth requests in seconds
122 max_retries: Maximum number of retry attempts for token requests
123 token_storage: Optional TokenStorageService for storing tokens
124 """
125 self.request_timeout = request_timeout
126 self.max_retries = max_retries
127 self.token_storage = token_storage
128 self.settings = get_settings()
130 async def _get_client(self) -> httpx.AsyncClient:
131 """Get the shared singleton HTTP client.
133 Returns:
134 Shared httpx.AsyncClient instance with connection pooling
135 """
136 return await get_http_client()
138 def _generate_pkce_params(self) -> Dict[str, str]:
139 """Generate PKCE parameters for OAuth Authorization Code flow (RFC 7636).
141 Returns:
142 Dict containing code_verifier, code_challenge, and code_challenge_method
143 """
144 # Generate code_verifier: 43-128 character random string
145 code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=")
147 # Generate code_challenge: base64url(SHA256(code_verifier))
148 code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest()).decode("utf-8").rstrip("=")
150 return {"code_verifier": code_verifier, "code_challenge": code_challenge, "code_challenge_method": "S256"}
152 async def get_access_token(self, credentials: Dict[str, Any]) -> str:
153 """Get access token based on grant type.
155 Args:
156 credentials: OAuth configuration containing grant_type and other params
158 Returns:
159 Access token string
161 Raises:
162 ValueError: If grant type is unsupported
163 OAuthError: If token acquisition fails
165 Examples:
166 Client credentials flow:
167 >>> import asyncio
168 >>> class TestMgr(OAuthManager):
169 ... async def _client_credentials_flow(self, credentials):
170 ... return 'tok'
171 >>> mgr = TestMgr()
172 >>> asyncio.run(mgr.get_access_token({'grant_type': 'client_credentials'}))
173 'tok'
175 Authorization code fallback to client credentials:
176 >>> asyncio.run(mgr.get_access_token({'grant_type': 'authorization_code'}))
177 'tok'
179 Unsupported grant type raises ValueError:
180 >>> def _unsupported():
181 ... try:
182 ... asyncio.run(mgr.get_access_token({'grant_type': 'bad'}))
183 ... except ValueError:
184 ... return True
185 >>> _unsupported()
186 True
187 """
188 grant_type = credentials.get("grant_type")
189 logger.debug(f"Getting access token for grant type: {grant_type}")
191 if grant_type == "client_credentials":
192 return await self._client_credentials_flow(credentials)
193 if grant_type == "password":
194 return await self._password_flow(credentials)
195 if grant_type == "authorization_code":
196 # For authorization code flow in gateway initialization, we need to handle this differently
197 # Since this is called during gateway setup, we'll try to use client credentials as fallback
198 # or provide a more helpful error message
199 logger.warning("Authorization code flow requires user interaction. " + "For gateway initialization, consider using 'client_credentials' grant type instead.")
200 # Try to use client credentials flow if possible (some OAuth providers support this)
201 try:
202 return await self._client_credentials_flow(credentials)
203 except Exception as e:
204 raise OAuthError(
205 f"Authorization code flow cannot be used for automatic gateway initialization. "
206 f"Please use 'client_credentials' grant type or complete the OAuth flow manually first. "
207 f"Error: {str(e)}"
208 )
209 else:
210 raise ValueError(f"Unsupported grant type: {grant_type}")
212 async def _client_credentials_flow(self, credentials: Dict[str, Any]) -> str:
213 """Machine-to-machine authentication using client credentials.
215 Args:
216 credentials: OAuth configuration with client_id, client_secret, token_url
218 Returns:
219 Access token string
221 Raises:
222 OAuthError: If token acquisition fails after all retries
223 """
224 client_id = credentials["client_id"]
225 client_secret = credentials["client_secret"]
226 token_url = credentials["token_url"]
227 scopes = credentials.get("scopes", [])
229 # Decrypt client secret if it's encrypted
230 if len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer
231 try:
232 settings = get_settings()
233 encryption = get_encryption_service(settings.auth_encryption_secret)
234 decrypted_secret = await encryption.decrypt_secret_async(client_secret)
235 if decrypted_secret:
236 client_secret = decrypted_secret
237 logger.debug("Successfully decrypted client secret")
238 else:
239 logger.warning("Failed to decrypt client secret, using encrypted version")
240 except Exception as e:
241 logger.warning(f"Failed to decrypt client secret: {e}, using encrypted version")
243 # Prepare token request data
244 token_data = {
245 "grant_type": "client_credentials",
246 "client_id": client_id,
247 "client_secret": client_secret,
248 }
250 if scopes:
251 token_data["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes
253 # Fetch token with retries
254 for attempt in range(self.max_retries):
255 try:
256 client = await self._get_client()
257 response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
258 response.raise_for_status()
260 # GitHub returns form-encoded responses, not JSON
261 content_type = response.headers.get("content-type", "")
262 if "application/x-www-form-urlencoded" in content_type:
263 # Parse form-encoded response
264 text_response = response.text
265 token_response = {}
266 for pair in text_response.split("&"):
267 if "=" in pair:
268 key, value = pair.split("=", 1)
269 token_response[key] = value
270 else:
271 # Try JSON response
272 try:
273 token_response = response.json()
274 except Exception as e:
275 logger.warning(f"Failed to parse JSON response: {e}")
276 # Fallback to text parsing
277 text_response = response.text
278 token_response = {"raw_response": text_response}
280 if "access_token" not in token_response:
281 raise OAuthError(f"No access_token in response: {token_response}")
283 logger.info("""Successfully obtained access token via client credentials""")
284 return token_response["access_token"]
286 except httpx.HTTPError as e:
287 logger.warning(f"Token request attempt {attempt + 1} failed: {str(e)}")
288 if attempt == self.max_retries - 1:
289 raise OAuthError(f"Failed to obtain access token after {self.max_retries} attempts: {str(e)}")
290 await asyncio.sleep(2**attempt) # Exponential backoff
292 # This should never be reached due to the exception above, but needed for type safety
293 raise OAuthError("Failed to obtain access token after all retry attempts")
295 async def _password_flow(self, credentials: Dict[str, Any]) -> str:
296 """Resource Owner Password Credentials flow (RFC 6749 Section 4.3).
298 This flow is used when the application can directly handle the user's credentials,
299 such as with trusted first-party applications or legacy integrations like Keycloak.
301 Args:
302 credentials: OAuth configuration with client_id, optional client_secret, token_url, username, password
304 Returns:
305 Access token string
307 Raises:
308 OAuthError: If token acquisition fails after all retries
309 """
310 client_id = credentials.get("client_id")
311 client_secret = credentials.get("client_secret")
312 token_url = credentials["token_url"]
313 username = credentials.get("username")
314 password = credentials.get("password")
315 scopes = credentials.get("scopes", [])
317 if not username or not password:
318 raise OAuthError("Username and password are required for password grant type")
320 # Decrypt client secret if it's encrypted and present
321 if client_secret and len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer
322 try:
323 settings = get_settings()
324 encryption = get_encryption_service(settings.auth_encryption_secret)
325 decrypted_secret = await encryption.decrypt_secret_async(client_secret)
326 if decrypted_secret:
327 client_secret = decrypted_secret
328 logger.debug("Successfully decrypted client secret")
329 else:
330 logger.warning("Failed to decrypt client secret, using encrypted version")
331 except Exception as e:
332 logger.warning(f"Failed to decrypt client secret: {e}, using encrypted version")
334 # Prepare token request data
335 token_data = {
336 "grant_type": "password",
337 "username": username,
338 "password": password,
339 }
341 # Add client_id (required by most providers including Keycloak)
342 if client_id:
343 token_data["client_id"] = client_id
345 # Add client_secret if present (some providers require it, others don't)
346 if client_secret:
347 token_data["client_secret"] = client_secret
349 if scopes:
350 token_data["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes
352 # Fetch token with retries
353 for attempt in range(self.max_retries):
354 try:
355 client = await self._get_client()
356 response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
357 response.raise_for_status()
359 # Handle both JSON and form-encoded responses
360 content_type = response.headers.get("content-type", "")
361 if "application/x-www-form-urlencoded" in content_type:
362 # Parse form-encoded response
363 text_response = response.text
364 token_response = {}
365 for pair in text_response.split("&"):
366 if "=" in pair:
367 key, value = pair.split("=", 1)
368 token_response[key] = value
369 else:
370 # Try JSON response
371 try:
372 token_response = response.json()
373 except Exception as e:
374 logger.warning(f"Failed to parse JSON response: {e}")
375 # Fallback to text parsing
376 text_response = response.text
377 token_response = {"raw_response": text_response}
379 if "access_token" not in token_response:
380 raise OAuthError(f"No access_token in response: {token_response}")
382 logger.info("Successfully obtained access token via password grant")
383 return token_response["access_token"]
385 except httpx.HTTPError as e:
386 logger.warning(f"Token request attempt {attempt + 1} failed: {str(e)}")
387 if attempt == self.max_retries - 1:
388 raise OAuthError(f"Failed to obtain access token after {self.max_retries} attempts: {str(e)}")
389 await asyncio.sleep(2**attempt) # Exponential backoff
391 # This should never be reached due to the exception above, but needed for type safety
392 raise OAuthError("Failed to obtain access token after all retry attempts")
394 async def get_authorization_url(self, credentials: Dict[str, Any]) -> Dict[str, str]:
395 """Get authorization URL for user delegation flow.
397 Args:
398 credentials: OAuth configuration with client_id, authorization_url, etc.
400 Returns:
401 Dict containing authorization_url and state
402 """
403 client_id = credentials["client_id"]
404 redirect_uri = credentials["redirect_uri"]
405 authorization_url = credentials["authorization_url"]
406 scopes = credentials.get("scopes", [])
408 # Create OAuth2 session
409 oauth = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=scopes)
411 # Generate authorization URL with state for CSRF protection
412 auth_url, state = oauth.authorization_url(authorization_url)
414 logger.info(f"Generated authorization URL for client {client_id}")
416 return {"authorization_url": auth_url, "state": state}
418 async def exchange_code_for_token(self, credentials: Dict[str, Any], code: str, state: str) -> str: # pylint: disable=unused-argument
419 """Exchange authorization code for access token.
421 Args:
422 credentials: OAuth configuration
423 code: Authorization code from callback
424 state: State parameter for CSRF validation
426 Returns:
427 Access token string
429 Raises:
430 OAuthError: If token exchange fails
431 """
432 client_id = credentials["client_id"]
433 client_secret = credentials.get("client_secret") # Optional for public clients (PKCE-only)
434 token_url = credentials["token_url"]
435 redirect_uri = credentials["redirect_uri"]
437 # Decrypt client secret if it's encrypted and present
438 if client_secret and len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer
439 try:
440 settings = get_settings()
441 encryption = get_encryption_service(settings.auth_encryption_secret)
442 decrypted_secret = await encryption.decrypt_secret_async(client_secret)
443 if decrypted_secret:
444 client_secret = decrypted_secret
445 logger.debug("Successfully decrypted client secret")
446 else:
447 logger.warning("Failed to decrypt client secret, using encrypted version")
448 except Exception as e:
449 logger.warning(f"Failed to decrypt client secret: {e}, using encrypted version")
451 # Prepare token exchange data
452 token_data = {
453 "grant_type": "authorization_code",
454 "code": code,
455 "redirect_uri": redirect_uri,
456 "client_id": client_id,
457 }
459 # Only include client_secret if present (public clients don't have secrets)
460 if client_secret:
461 token_data["client_secret"] = client_secret
463 # Exchange code for token with retries
464 for attempt in range(self.max_retries):
465 try:
466 client = await self._get_client()
467 response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
468 response.raise_for_status()
470 # GitHub returns form-encoded responses, not JSON
471 content_type = response.headers.get("content-type", "")
472 if "application/x-www-form-urlencoded" in content_type:
473 # Parse form-encoded response
474 text_response = response.text
475 token_response = {}
476 for pair in text_response.split("&"):
477 if "=" in pair:
478 key, value = pair.split("=", 1)
479 token_response[key] = value
480 else:
481 # Try JSON response
482 try:
483 token_response = response.json()
484 except Exception as e:
485 logger.warning(f"Failed to parse JSON response: {e}")
486 # Fallback to text parsing
487 text_response = response.text
488 token_response = {"raw_response": text_response}
490 if "access_token" not in token_response:
491 raise OAuthError(f"No access_token in response: {token_response}")
493 logger.info("""Successfully exchanged authorization code for access token""")
494 return token_response["access_token"]
496 except httpx.HTTPError as e:
497 logger.warning(f"Token exchange attempt {attempt + 1} failed: {str(e)}")
498 if attempt == self.max_retries - 1:
499 raise OAuthError(f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}")
500 await asyncio.sleep(2**attempt) # Exponential backoff
502 # This should never be reached due to the exception above, but needed for type safety
503 raise OAuthError("Failed to exchange code for token after all retry attempts")
505 async def initiate_authorization_code_flow(self, gateway_id: str, credentials: Dict[str, Any], app_user_email: str = None) -> Dict[str, str]:
506 """Initiate Authorization Code flow with PKCE and return authorization URL.
508 Args:
509 gateway_id: ID of the gateway being configured
510 credentials: OAuth configuration with client_id, authorization_url, etc.
511 app_user_email: MCP Gateway user email to associate with tokens
513 Returns:
514 Dict containing authorization_url and state
515 """
517 # Generate PKCE parameters (RFC 7636)
518 pkce_params = self._generate_pkce_params()
520 # Generate state parameter with user context for CSRF protection
521 state = self._generate_state(gateway_id, app_user_email)
523 # Store state with code_verifier in session/cache for validation
524 if self.token_storage:
525 await self._store_authorization_state(gateway_id, state, code_verifier=pkce_params["code_verifier"])
527 # Generate authorization URL with PKCE
528 auth_url = self._create_authorization_url_with_pkce(credentials, state, pkce_params["code_challenge"], pkce_params["code_challenge_method"])
530 logger.info(f"Generated authorization URL with PKCE for gateway {gateway_id}")
532 return {"authorization_url": auth_url, "state": state, "gateway_id": gateway_id}
534 async def complete_authorization_code_flow(self, gateway_id: str, code: str, state: str, credentials: Dict[str, Any]) -> Dict[str, Any]:
535 """Complete Authorization Code flow with PKCE and store tokens.
537 Args:
538 gateway_id: ID of the gateway
539 code: Authorization code from callback
540 state: State parameter for CSRF validation
541 credentials: OAuth configuration
543 Returns:
544 Dict containing success status, user_id, and expiration info
546 Raises:
547 OAuthError: If state validation fails or token exchange fails
548 """
549 # Validate state and retrieve code_verifier
550 state_data = await self._validate_and_retrieve_state(gateway_id, state)
551 if not state_data:
552 raise OAuthError("Invalid or expired state parameter - possible replay attack")
554 code_verifier = state_data.get("code_verifier")
556 # Decode state to extract user context and verify HMAC
557 try:
558 # Decode base64
559 state_with_sig = base64.urlsafe_b64decode(state.encode())
561 # Split state and signature (HMAC-SHA256 is 32 bytes)
562 state_bytes = state_with_sig[:-32]
563 received_signature = state_with_sig[-32:]
565 # Verify HMAC signature
566 secret_key = self.settings.auth_encryption_secret.get_secret_value().encode() if self.settings.auth_encryption_secret else b"default-secret-key"
567 expected_signature = hmac.new(secret_key, state_bytes, hashlib.sha256).digest()
569 if not hmac.compare_digest(received_signature, expected_signature):
570 raise OAuthError("Invalid state signature - possible CSRF attack")
572 # Parse state data
573 state_json = state_bytes.decode()
574 state_payload = orjson.loads(state_json)
575 app_user_email = state_payload.get("app_user_email")
576 state_gateway_id = state_payload.get("gateway_id")
578 # Validate gateway ID matches
579 if state_gateway_id != gateway_id:
580 raise OAuthError("State parameter gateway mismatch")
581 except Exception as e:
582 # Fallback for legacy state format (gateway_id_random)
583 logger.warning(f"Failed to decode state JSON, trying legacy format: {e}")
584 app_user_email = None
586 # Exchange code for tokens with PKCE code_verifier
587 token_response = await self._exchange_code_for_tokens(credentials, code, code_verifier=code_verifier)
589 # Extract user information from token response
590 user_id = self._extract_user_id(token_response, credentials)
592 # Store tokens if storage service is available
593 if self.token_storage:
594 if not app_user_email:
595 raise OAuthError("User context required for OAuth token storage")
597 token_record = await self.token_storage.store_tokens(
598 gateway_id=gateway_id,
599 user_id=user_id,
600 app_user_email=app_user_email, # User from state
601 access_token=token_response["access_token"],
602 refresh_token=token_response.get("refresh_token"),
603 expires_in=token_response.get("expires_in", self.settings.oauth_default_timeout),
604 scopes=token_response.get("scope", "").split(),
605 )
607 return {"success": True, "user_id": user_id, "expires_at": token_record.expires_at.isoformat() if token_record.expires_at else None}
608 return {"success": True, "user_id": user_id, "expires_at": None}
610 async def get_access_token_for_user(self, gateway_id: str, app_user_email: str) -> Optional[str]:
611 """Get valid access token for a specific user.
613 Args:
614 gateway_id: ID of the gateway
615 app_user_email: MCP Gateway user email
617 Returns:
618 Valid access token or None if not available
619 """
620 if self.token_storage:
621 return await self.token_storage.get_user_token(gateway_id, app_user_email)
622 return None
624 def _generate_state(self, gateway_id: str, app_user_email: str = None) -> str:
625 """Generate a unique state parameter with user context for CSRF protection.
627 Args:
628 gateway_id: ID of the gateway
629 app_user_email: MCP Gateway user email (optional but recommended)
631 Returns:
632 Unique state string with embedded user context and HMAC signature
633 """
634 # Include user email in state for secure user association
635 state_data = {"gateway_id": gateway_id, "app_user_email": app_user_email, "nonce": secrets.token_urlsafe(16), "timestamp": datetime.now(timezone.utc).isoformat()}
637 # Encode state as JSON (orjson produces compact output by default)
638 state_bytes = orjson.dumps(state_data)
640 # Create HMAC signature
641 secret_key = self.settings.auth_encryption_secret.get_secret_value().encode() if self.settings.auth_encryption_secret else b"default-secret-key"
642 signature = hmac.new(secret_key, state_bytes, hashlib.sha256).digest()
644 # Combine state and signature, then base64 encode
645 state_with_sig = state_bytes + signature
646 state_encoded = base64.urlsafe_b64encode(state_with_sig).decode()
648 return state_encoded
650 async def _store_authorization_state(self, gateway_id: str, state: str, code_verifier: str = None) -> None:
651 """Store authorization state for validation with TTL.
653 Args:
654 gateway_id: ID of the gateway
655 state: State parameter to store
656 code_verifier: Optional PKCE code verifier (RFC 7636)
657 """
658 expires_at = datetime.now(timezone.utc) + timedelta(seconds=STATE_TTL_SECONDS)
659 settings = get_settings()
661 # Try Redis first for distributed storage
662 if settings.cache_type == "redis":
663 redis = await _get_redis_client()
664 if redis:
665 try:
666 state_key = f"oauth:state:{gateway_id}:{state}"
667 state_data = {"state": state, "gateway_id": gateway_id, "code_verifier": code_verifier, "expires_at": expires_at.isoformat(), "used": False}
668 # Store in Redis with TTL
669 await redis.setex(state_key, STATE_TTL_SECONDS, orjson.dumps(state_data))
670 logger.debug(f"Stored OAuth state in Redis for gateway {gateway_id}")
671 return
672 except Exception as e:
673 logger.warning(f"Failed to store state in Redis: {e}, falling back")
675 # Try database storage for multi-worker deployments
676 if settings.cache_type == "database":
677 try:
678 # First-Party
679 from mcpgateway.db import get_db, OAuthState # pylint: disable=import-outside-toplevel
681 db_gen = get_db()
682 db = next(db_gen)
683 try:
684 # Clean up expired states first
685 db.query(OAuthState).filter(OAuthState.expires_at < datetime.now(timezone.utc)).delete()
687 # Store new state with code_verifier
688 oauth_state = OAuthState(gateway_id=gateway_id, state=state, code_verifier=code_verifier, expires_at=expires_at, used=False)
689 db.add(oauth_state)
690 db.commit()
691 logger.debug(f"Stored OAuth state in database for gateway {gateway_id}")
692 return
693 finally:
694 db_gen.close()
695 except Exception as e:
696 logger.warning(f"Failed to store state in database: {e}, falling back to memory")
698 # Fallback to in-memory storage for development
699 async with _state_lock:
700 # Clean up expired states first
701 now = datetime.now(timezone.utc)
702 state_key = f"oauth:state:{gateway_id}:{state}"
703 state_data = {"state": state, "gateway_id": gateway_id, "code_verifier": code_verifier, "expires_at": expires_at.isoformat(), "used": False}
704 expired_states = [key for key, data in _oauth_states.items() if datetime.fromisoformat(data["expires_at"]) < now]
705 for key in expired_states:
706 del _oauth_states[key]
707 logger.debug(f"Cleaned up expired state: {key[:20]}...")
709 # Store the new state with expiration
710 _oauth_states[state_key] = state_data
711 logger.debug(f"Stored OAuth state in memory for gateway {gateway_id}")
713 async def _validate_authorization_state(self, gateway_id: str, state: str) -> bool:
714 """Validate authorization state parameter and mark as used.
716 Args:
717 gateway_id: ID of the gateway
718 state: State parameter to validate
720 Returns:
721 True if state is valid and not yet used, False otherwise
722 """
723 settings = get_settings()
725 # Try Redis first for distributed storage
726 if settings.cache_type == "redis":
727 redis = await _get_redis_client()
728 if redis:
729 try:
730 state_key = f"oauth:state:{gateway_id}:{state}"
731 # Get and delete state atomically (single-use)
732 state_json = await redis.getdel(state_key)
733 if not state_json:
734 logger.warning(f"State not found in Redis for gateway {gateway_id}")
735 return False
737 state_data = orjson.loads(state_json)
739 # Parse expires_at as timezone-aware datetime. If the stored value
740 # is naive, assume UTC for compatibility.
741 try:
742 expires_at = datetime.fromisoformat(state_data["expires_at"])
743 except Exception:
744 # Fallback: try parsing without microseconds/offsets
745 expires_at = datetime.strptime(state_data["expires_at"], "%Y-%m-%dT%H:%M:%S")
747 if expires_at.tzinfo is None:
748 # Assume UTC for naive timestamps
749 expires_at = expires_at.replace(tzinfo=timezone.utc)
751 # Check if state has expired
752 if expires_at < datetime.now(timezone.utc):
753 logger.warning(f"State has expired for gateway {gateway_id}")
754 return False
756 # Check if state was already used (should not happen with getdel)
757 if state_data.get("used", False):
758 logger.warning(f"State was already used for gateway {gateway_id} - possible replay attack")
759 return False
761 logger.debug(f"Successfully validated OAuth state from Redis for gateway {gateway_id}")
762 return True
763 except Exception as e:
764 logger.warning(f"Failed to validate state in Redis: {e}, falling back")
766 # Try database storage for multi-worker deployments
767 if settings.cache_type == "database":
768 try:
769 # First-Party
770 from mcpgateway.db import get_db, OAuthState # pylint: disable=import-outside-toplevel
772 db_gen = get_db()
773 db = next(db_gen)
774 try:
775 # Find the state
776 oauth_state = db.query(OAuthState).filter(OAuthState.gateway_id == gateway_id, OAuthState.state == state).first()
778 if not oauth_state:
779 logger.warning(f"State not found in database for gateway {gateway_id}")
780 return False
782 # Check if state has expired
783 # Ensure oauth_state.expires_at is timezone-aware. If naive, assume UTC.
784 expires_at = oauth_state.expires_at
785 if expires_at.tzinfo is None:
786 expires_at = expires_at.replace(tzinfo=timezone.utc)
788 if expires_at < datetime.now(timezone.utc):
789 logger.warning(f"State has expired for gateway {gateway_id}")
790 db.delete(oauth_state)
791 db.commit()
792 return False
794 # Check if state was already used
795 if oauth_state.used:
796 logger.warning(f"State has already been used for gateway {gateway_id} - possible replay attack")
797 return False
799 # Mark as used and delete (single-use)
800 db.delete(oauth_state)
801 db.commit()
802 logger.debug(f"Successfully validated OAuth state from database for gateway {gateway_id}")
803 return True
804 finally:
805 db_gen.close()
806 except Exception as e:
807 logger.warning(f"Failed to validate state in database: {e}, falling back to memory")
809 # Fallback to in-memory storage for development
810 state_key = f"oauth:state:{gateway_id}:{state}"
811 async with _state_lock:
812 state_data = _oauth_states.get(state_key)
814 # Check if state exists
815 if not state_data:
816 logger.warning(f"State not found in memory for gateway {gateway_id}")
817 return False
819 # Parse and normalize expires_at to timezone-aware datetime
820 expires_at = datetime.fromisoformat(state_data["expires_at"])
821 if expires_at.tzinfo is None:
822 expires_at = expires_at.replace(tzinfo=timezone.utc)
824 if expires_at < datetime.now(timezone.utc):
825 logger.warning(f"State has expired for gateway {gateway_id}")
826 del _oauth_states[state_key] # Clean up expired state
827 return False
829 # Check if state has already been used (prevent replay)
830 if state_data.get("used", False):
831 logger.warning(f"State has already been used for gateway {gateway_id} - possible replay attack")
832 return False
834 # Mark state as used and remove it (single-use)
835 del _oauth_states[state_key]
836 logger.debug(f"Successfully validated OAuth state from memory for gateway {gateway_id}")
837 return True
async def _validate_and_retrieve_state(self, gateway_id: str, state: str) -> Optional[Dict[str, Any]]:
    """Consume a pending OAuth state and return its stored payload.

    The backend is chosen from ``settings.cache_type``: Redis (atomic GETDEL),
    then the database, then the process-local dictionary. An *exception* in the
    Redis or database path is logged and falls through to the in-memory store;
    a clean miss or an expired entry returns ``None`` immediately.

    Args:
        gateway_id: ID of the gateway the state was issued for
        state: State parameter received on the OAuth callback

    Returns:
        The stored state dict (including ``code_verifier`` when one was saved),
        or ``None`` if the state is unknown, expired, or already used.
    """

    def _aware(moment: datetime) -> datetime:
        # Stored timestamps may be naive; interpret them as UTC before comparing.
        return moment if moment.tzinfo is not None else moment.replace(tzinfo=timezone.utc)

    backend = get_settings().cache_type
    state_key = f"oauth:state:{gateway_id}:{state}"

    if backend == "redis":
        client = await _get_redis_client()
        if client:
            try:
                raw = await client.getdel(state_key)  # Atomic read-and-consume
                if not raw:
                    return None
                payload = orjson.loads(raw)
                stamp = payload["expires_at"]
                try:
                    deadline = datetime.fromisoformat(stamp)
                except Exception:
                    deadline = datetime.strptime(stamp, "%Y-%m-%dT%H:%M:%S")
                return None if _aware(deadline) < datetime.now(timezone.utc) else payload
            except Exception as e:
                logger.warning(f"Failed to validate state in Redis: {e}, falling back")

    if backend == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            sessions = get_db()
            db = next(sessions)
            try:
                record = db.query(OAuthState).filter(OAuthState.gateway_id == gateway_id, OAuthState.state == state).first()
                if not record:
                    return None
                if _aware(record.expires_at) < datetime.now(timezone.utc):
                    # Expired rows are purged on sight
                    db.delete(record)
                    db.commit()
                    return None
                if record.used:
                    return None
                payload = {"state": record.state, "gateway_id": record.gateway_id, "code_verifier": record.code_verifier, "expires_at": record.expires_at.isoformat()}
                # Single-use: remove the row once it has been consumed
                db.delete(record)
                db.commit()
                return payload
            finally:
                sessions.close()
        except Exception as e:
            logger.warning(f"Failed to validate state in database: {e}")

    # Last resort: process-local storage
    async with _state_lock:
        payload = _oauth_states.get(state_key)
        if not payload:
            return None
        expired = _aware(datetime.fromisoformat(payload["expires_at"])) < datetime.now(timezone.utc)
        # The entry is dropped either way: expired states are cleaned up, valid ones are single-use
        del _oauth_states[state_key]
        return None if expired else payload
def _create_authorization_url(self, credentials: Dict[str, Any], state: str) -> tuple[str, str]:
    """Build a (non-PKCE) provider authorization URL with requests-oauthlib.

    Args:
        credentials: OAuth configuration holding ``client_id``, ``redirect_uri``,
            ``authorization_url`` and optional ``scopes`` (defaults to an empty list)
        state: Anti-CSRF state value to embed in the URL

    Returns:
        Tuple of (authorization_url, state) exactly as produced by
        ``OAuth2Session.authorization_url``
    """
    client_id, redirect_uri, endpoint = credentials["client_id"], credentials["redirect_uri"], credentials["authorization_url"]
    session = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=credentials.get("scopes", []))

    url, echoed_state = session.authorization_url(endpoint, state=state)
    return url, echoed_state
def _create_authorization_url_with_pkce(self, credentials: Dict[str, Any], state: str, code_challenge: str, code_challenge_method: str) -> str:
    """Create authorization URL with PKCE parameters (RFC 7636).

    Any query component already present on the configured authorization
    endpoint is preserved; the OAuth parameters are appended after it.

    Args:
        credentials: OAuth configuration (``client_id``, ``redirect_uri``,
            ``authorization_url``; optional ``scopes`` as a string or list/tuple,
            optional ``resource`` as a string or list/tuple of strings)
        state: State parameter for CSRF protection
        code_challenge: PKCE code challenge
        code_challenge_method: PKCE method (S256)

    Returns:
        Authorization URL string with PKCE parameters
    """
    # Standard
    from urllib.parse import urlencode  # pylint: disable=import-outside-toplevel

    client_id = credentials["client_id"]
    redirect_uri = credentials["redirect_uri"]
    authorization_url = credentials["authorization_url"]
    scopes = credentials.get("scopes", [])

    # Build authorization parameters
    params: Dict[str, Any] = {"response_type": "code", "client_id": client_id, "redirect_uri": redirect_uri, "state": state, "code_challenge": code_challenge, "code_challenge_method": code_challenge_method}

    # Scopes travel as one space-delimited value (RFC 6749 §3.3), never as repeated params
    if scopes:
        params["scope"] = " ".join(scopes) if isinstance(scopes, (list, tuple)) else scopes

    # Add resource parameter for JWT access token (RFC 8707)
    # The resource is the MCP server URL, set by oauth_router.py
    resource = credentials.get("resource")
    if isinstance(resource, (list, tuple)):
        # RFC 8707 allows repeating the parameter; drop empty entries as the token exchange does
        resource = [r for r in resource if r]
    if resource:
        params["resource"] = resource  # urlencode(doseq=True) expands lists into repeated params

    query_string = urlencode(params, doseq=True)

    # RFC 6749 §3.1: the endpoint may already carry a query component, which must be
    # retained when adding parameters (e.g. Azure AD B2C "?p=<policy>").
    if "?" not in authorization_url:
        separator = "?"
    elif authorization_url.endswith(("?", "&")):
        separator = ""
    else:
        separator = "&"
    return f"{authorization_url}{separator}{query_string}"
async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]:
    """Exchange authorization code for tokens with PKCE support.

    The token endpoint is tried up to ``self.max_retries`` times. Any
    ``httpx.HTTPError`` (including non-2xx statuses surfaced by
    ``raise_for_status``) is followed by an exponential backoff of 1s, 2s, 4s, ...
    Both JSON and ``application/x-www-form-urlencoded`` token responses are
    accepted; form values are percent-decoded.

    Args:
        credentials: OAuth configuration (``client_id``, ``token_url``,
            ``redirect_uri``; optional ``client_secret`` and ``resource``)
        code: Authorization code from callback
        code_verifier: Optional PKCE code verifier (RFC 7636)

    Returns:
        Token response dictionary

    Raises:
        OAuthError: If token exchange fails or the response has no access_token
    """
    # Standard
    from urllib.parse import parse_qsl  # pylint: disable=import-outside-toplevel

    client_id = credentials["client_id"]
    client_secret = credentials.get("client_secret")  # Optional for public clients (PKCE-only)
    token_url = credentials["token_url"]
    redirect_uri = credentials["redirect_uri"]

    # Decrypt client secret if it's encrypted and present
    if client_secret and len(client_secret) > 50:  # Simple heuristic: encrypted secrets are longer
        try:
            settings = get_settings()
            encryption = get_encryption_service(settings.auth_encryption_secret)
            decrypted_secret = await encryption.decrypt_secret_async(client_secret)
            if decrypted_secret:
                client_secret = decrypted_secret
                logger.debug("Successfully decrypted client secret")
            else:
                logger.warning("Failed to decrypt client secret, using encrypted version")
        except Exception as e:
            logger.warning(f"Failed to decrypt client secret: {e}, using encrypted version")

    # Prepare token exchange data
    token_data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": client_id,
    }

    # Only include client_secret if present (public clients don't have secrets)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Add PKCE code_verifier if present (RFC 7636)
    if code_verifier:
        token_data["code_verifier"] = code_verifier

    # Add resource parameter to request JWT access token (RFC 8707)
    # The resource identifies the MCP server (resource server), not the OAuth server
    resource = credentials.get("resource")
    if resource:
        if isinstance(resource, list):
            # RFC 8707 allows multiple resource parameters - use list of tuples
            form_data: list[tuple[str, str]] = list(token_data.items())
            for r in resource:
                if r:
                    form_data.append(("resource", r))
            token_data = form_data  # type: ignore[assignment]
        else:
            token_data["resource"] = resource

    # Exchange code for token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # GitHub returns form-encoded responses, not JSON
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                # Values are percent-encoded (e.g. scope=repo%2Cgist); parse_qsl decodes them.
                # Blank values are kept and later duplicates win, like the old manual split.
                token_response = dict(parse_qsl(response.text, keep_blank_values=True))
            else:
                # Try JSON response
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    # Fallback to text parsing
                    token_response = {"raw_response": response.text}

            if "access_token" not in token_response:
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("Successfully exchanged authorization code for tokens")
            return token_response

        except httpx.HTTPError as e:
            logger.warning(f"Token exchange attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # This should never be reached due to the exception above, but needed for type safety
    raise OAuthError("Failed to exchange code for token after all retry attempts")
async def refresh_token(self, refresh_token: str, credentials: Dict[str, Any]) -> Dict[str, Any]:
    """Refresh an expired access token using a refresh token.

    Transport errors and non-200 responses other than 400/401 (e.g. 5xx, 429)
    are retried up to ``self.max_retries`` times, each followed by an
    exponential backoff of 1s, 2s, 4s, ... A 400 or 401 is treated as a
    definitive rejection of the refresh token and is not retried. Both JSON
    and ``application/x-www-form-urlencoded`` success bodies are accepted.

    Args:
        refresh_token: The refresh token to use
        credentials: OAuth configuration including client_id, client_secret, token_url
            and optionally ``resource`` (str or list of str, RFC 8707)

    Returns:
        Dict containing new access_token, optional refresh_token, and expires_in

    Raises:
        OAuthError: If token refresh fails, the refresh token is rejected, or the
            response carries no access_token
    """
    # Standard
    from urllib.parse import parse_qsl  # pylint: disable=import-outside-toplevel

    if not refresh_token:
        raise OAuthError("No refresh token available")

    token_url = credentials.get("token_url")
    if not token_url:
        raise OAuthError("No token URL configured for OAuth provider")

    client_id = credentials.get("client_id")
    client_secret = credentials.get("client_secret")

    if not client_id:
        raise OAuthError("No client_id configured for OAuth provider")

    # Prepare token refresh request
    token_data = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
    }

    # Add client_secret if available (some providers require it)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Add resource parameter for JWT access token (RFC 8707)
    # Must be included in refresh requests to maintain JWT token type
    resource = credentials.get("resource")
    if resource:
        if isinstance(resource, list):
            # RFC 8707 allows multiple resource parameters - use list of tuples
            form_data: list[tuple[str, str]] = list(token_data.items())
            for r in resource:
                if r:
                    form_data.append(("resource", r))
            token_data = form_data  # type: ignore[assignment]
        else:
            token_data["resource"] = resource

    # Attempt token refresh with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            if response.status_code == 200:
                content_type = response.headers.get("content-type", "")
                if "application/x-www-form-urlencoded" in content_type:
                    # Some providers (e.g. GitHub) answer with a percent-encoded form body
                    token_response = dict(parse_qsl(response.text, keep_blank_values=True))
                else:
                    try:
                        token_response = response.json()
                    except ValueError as e:
                        # Report an unparseable body as OAuthError (below) rather than leaking JSONDecodeError
                        logger.warning(f"Failed to parse JSON refresh response: {e}")
                        token_response = {}

                # Validate required fields
                if not isinstance(token_response, dict) or "access_token" not in token_response:
                    raise OAuthError("No access_token in refresh response")

                logger.info("Successfully refreshed OAuth token")
                return token_response

            error_text = response.text
            # If we get a 400/401, the refresh token is likely invalid
            if response.status_code in (400, 401):
                raise OAuthError(f"Refresh token invalid or expired: {error_text}")
            logger.warning(f"Token refresh failed with status {response.status_code}: {error_text}")
            # Transient failure: back off like the transport-error path (no sleep after the last attempt)
            if attempt < self.max_retries - 1:
                await asyncio.sleep(2**attempt)

        except httpx.HTTPError as e:
            logger.warning(f"Token refresh attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to refresh token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    raise OAuthError("Failed to refresh token after all retry attempts")
def _extract_user_id(self, token_response: Dict[str, Any], credentials: Dict[str, Any]) -> str:
    """Extract user ID from token response.

    Providers use different field names, so ``sub`` (JWT standard), then
    ``user_id``, then ``id`` are checked in that order. A field that is present
    but ``None`` or empty is skipped so a later candidate can still be used.
    Non-string identifiers (e.g. an integer ``id``) are converted to ``str``.

    Args:
        token_response: Response from token exchange
        credentials: OAuth configuration; ``client_id`` is the fallback identity

    Returns:
        User ID string, the client_id if no user field is usable, or
        ``"unknown_user"`` as a last resort
    """
    for field in ("sub", "user_id", "id"):
        value = token_response.get(field)
        # A present-but-null/empty field carries no identity; keep looking
        if value is not None and value != "":
            return str(value)

    # Fallback to client_id if no user info is available
    client_id = credentials.get("client_id")
    if client_id:
        return client_id

    # Final fallback
    return "unknown_user"
class OAuthError(Exception):
    """Raised when an OAuth 2.0 operation cannot be completed.

    Used for failures such as token acquisition, authorization-code exchange,
    and token refresh. The constructor message becomes the string form of the
    error.

    Examples:
        >>> err = OAuthError("Token acquisition failed")
        >>> str(err)
        'Token acquisition failed'
        >>> isinstance(OAuthError("Invalid grant type"), Exception)
        True
    """