Coverage for mcpgateway / services / oauth_manager.py: 100%

583 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/oauth_manager.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7OAuth 2.0 Manager for MCP Gateway. 

8 

9This module handles OAuth 2.0 authentication flows including: 

10- Client Credentials (Machine-to-Machine) 

11- Authorization Code (User Delegation) 

12""" 

13 

14# Standard 

15import asyncio 

16import base64 

17from datetime import datetime, timedelta, timezone 

18import hashlib 

19import hmac 

20import logging 

21import secrets 

22from typing import Any, Dict, Optional 

23 

24# Third-Party 

25import httpx 

26import orjson 

27from requests_oauthlib import OAuth2Session 

28 

29# First-Party 

30from mcpgateway.config import get_settings 

31from mcpgateway.services.encryption_service import get_encryption_service 

32from mcpgateway.services.http_client_service import get_http_client 

33from mcpgateway.utils.redis_client import get_redis_client as _get_shared_redis_client 

34 

# Module-wide logger used by every OAuth flow in this file.
logger = logging.getLogger(__name__)

# Lifetime of an issued authorization state: 5 minutes, expressed in seconds.
STATE_TTL_SECONDS = 5 * 60

# Process-local state store, used when neither Redis nor the database handles the state.
# Shape: {state_key: {"state": str, "gateway_id": str, "code_verifier": str | None,
#                     "expires_at": ISO-8601 string, "used": bool}}
_oauth_states: Dict[str, Dict[str, Any]] = {}
# Serialises coroutine access to _oauth_states. This is an asyncio lock, so it
# coordinates tasks on an event loop rather than OS threads.
_state_lock = asyncio.Lock()

# Shared Redis client, resolved lazily by _get_redis_client(); the flag records
# that resolution has been attempted so it is not retried on every call.
_redis_client: Optional[Any] = None
_REDIS_INITIALIZED = False

49 

50 

async def _get_redis_client():
    """Return the Redis client used for OAuth state, resolving it on first use.

    The first call asks the shared Redis factory for a client, but only when
    the settings select the ``redis`` cache type and provide a Redis URL. The
    outcome (a client or ``None``) is remembered in module globals, so later
    calls return it without contacting the factory again.

    Returns:
        Redis client instance or None if unavailable
    """
    global _redis_client, _REDIS_INITIALIZED  # pylint: disable=global-statement

    if not _REDIS_INITIALIZED:
        resolved = None
        cfg = get_settings()
        if cfg.cache_type == "redis" and cfg.redis_url:
            try:
                resolved = await _get_shared_redis_client()
                if resolved:
                    logger.info("OAuth manager connected to shared Redis client")
            except Exception as e:
                # Any factory failure simply means "no Redis"; callers fall back.
                logger.warning(f"Failed to get Redis client, falling back to in-memory storage: {e}")
                resolved = None

        _redis_client = resolved
        _REDIS_INITIALIZED = True

    return _redis_client

78 

79 

80class OAuthManager: 

81 """Manages OAuth 2.0 authentication flows. 

82 

83 Examples: 

84 >>> manager = OAuthManager(request_timeout=30, max_retries=3) 

85 >>> manager.request_timeout 

86 30 

87 >>> manager.max_retries 

88 3 

89 >>> manager.token_storage is None 

90 True 

91 >>> 

92 >>> # Test grant type validation 

93 >>> grant_type = "client_credentials" 

94 >>> grant_type in ["client_credentials", "authorization_code"] 

95 True 

96 >>> grant_type = "invalid_grant" 

97 >>> grant_type in ["client_credentials", "authorization_code"] 

98 False 

99 >>> 

100 >>> # Test encrypted secret detection heuristic 

101 >>> short_secret = "secret123" 

102 >>> len(short_secret) > 50 

103 False 

104 >>> encrypted_secret = "gAAAAABh" + "x" * 60 # Simulated encrypted secret 

105 >>> len(encrypted_secret) > 50 

106 True 

107 >>> 

108 >>> # Test scope list handling 

109 >>> scopes = ["read", "write"] 

110 >>> " ".join(scopes) 

111 'read write' 

112 >>> empty_scopes = [] 

113 >>> " ".join(empty_scopes) 

114 '' 

115 """ 

116 

def __init__(self, request_timeout: int = 30, max_retries: int = 3, token_storage: Optional[Any] = None):
    """Create a manager and capture its configuration.

    Args:
        request_timeout: Timeout for OAuth requests in seconds
        max_retries: Maximum number of retry attempts for token requests
        token_storage: Optional TokenStorageService for storing tokens
    """
    # Application settings are looked up once and kept on the instance.
    self.settings = get_settings()
    self.token_storage = token_storage
    self.request_timeout, self.max_retries = request_timeout, max_retries

129 

async def _get_client(self) -> httpx.AsyncClient:
    """Fetch the HTTP client used for token requests.

    Returns:
        The client returned by ``get_http_client()``
    """
    shared_client = await get_http_client()
    return shared_client

137 

138 def _generate_pkce_params(self) -> Dict[str, str]: 

139 """Generate PKCE parameters for OAuth Authorization Code flow (RFC 7636). 

140 

141 Returns: 

142 Dict containing code_verifier, code_challenge, and code_challenge_method 

143 """ 

144 # Generate code_verifier: 43-128 character random string 

145 code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=") 

146 

147 # Generate code_challenge: base64url(SHA256(code_verifier)) 

148 code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest()).decode("utf-8").rstrip("=") 

149 

150 return {"code_verifier": code_verifier, "code_challenge": code_challenge, "code_challenge_method": "S256"} 

151 

async def get_access_token(self, credentials: Dict[str, Any]) -> str:
    """Get access token based on grant type.

    Dispatches on ``credentials["grant_type"]``:

    - ``client_credentials`` and ``password`` run the matching flow.
    - ``authorization_code`` cannot be completed without a user, so the
      client-credentials flow is attempted as a fallback; if that fails an
      OAuthError is raised with the underlying failure chained as its cause.

    Args:
        credentials: OAuth configuration containing grant_type and other params

    Returns:
        Access token string

    Raises:
        ValueError: If grant type is unsupported
        OAuthError: If token acquisition fails

    Examples:
        Client credentials flow:
        >>> import asyncio
        >>> class TestMgr(OAuthManager):
        ...     async def _client_credentials_flow(self, credentials):
        ...         return 'tok'
        >>> mgr = TestMgr()
        >>> asyncio.run(mgr.get_access_token({'grant_type': 'client_credentials'}))
        'tok'

        Authorization code fallback to client credentials:
        >>> asyncio.run(mgr.get_access_token({'grant_type': 'authorization_code'}))
        'tok'

        Unsupported grant type raises ValueError:
        >>> def _unsupported():
        ...     try:
        ...         asyncio.run(mgr.get_access_token({'grant_type': 'bad'}))
        ...     except ValueError:
        ...         return True
        >>> _unsupported()
        True
    """
    grant_type = credentials.get("grant_type")
    logger.debug(f"Getting access token for grant type: {grant_type}")

    if grant_type == "client_credentials":
        return await self._client_credentials_flow(credentials)
    if grant_type == "password":
        return await self._password_flow(credentials)
    if grant_type != "authorization_code":
        raise ValueError(f"Unsupported grant type: {grant_type}")

    # This path is reached during gateway setup, where no user is present to
    # complete the authorization-code redirect. Some providers accept client
    # credentials for the same client, so that is tried before giving up.
    logger.warning("Authorization code flow requires user interaction. " "For gateway initialization, consider using 'client_credentials' grant type instead.")
    try:
        return await self._client_credentials_flow(credentials)
    except Exception as e:
        # Chain the original failure so its traceback is not lost.
        raise OAuthError(
            "Authorization code flow cannot be used for automatic gateway initialization. "
            "Please use 'client_credentials' grant type or complete the OAuth flow manually first. "
            f"Error: {str(e)}"
        ) from e

211 

async def _client_credentials_flow(self, credentials: Dict[str, Any]) -> str:
    """Machine-to-machine authentication using client credentials.

    Posts a ``client_credentials`` grant to the token endpoint, retrying
    transport/HTTP failures up to ``self.max_retries`` times with exponential
    backoff (1s, 2s, 4s, ...). Both JSON and form-encoded token responses are
    understood; form-encoded keys and values are percent-decoded.

    Args:
        credentials: OAuth configuration with client_id, client_secret, token_url
            and optional scopes (list or space-separated string)

    Returns:
        Access token string

    Raises:
        OAuthError: If the response has no access_token, or token acquisition
            fails after all retries
    """
    # Standard
    from urllib.parse import unquote_plus  # pylint: disable=import-outside-toplevel

    client_id = credentials["client_id"]
    client_secret = credentials["client_secret"]
    token_url = credentials["token_url"]
    scopes = credentials.get("scopes", [])

    # Decrypt client secret if it's encrypted
    if len(client_secret) > 50:  # Simple heuristic: encrypted secrets are longer
        try:
            settings = get_settings()
            encryption = get_encryption_service(settings.auth_encryption_secret)
            decrypted_secret = await encryption.decrypt_secret_async(client_secret)
            if decrypted_secret:
                client_secret = decrypted_secret
                logger.debug("Successfully decrypted client secret")
            else:
                logger.warning("Failed to decrypt client secret, using encrypted version")
        except Exception as e:
            # Best effort: a long plain-text secret is sent as-is.
            logger.warning(f"Failed to decrypt client secret: {e}, using encrypted version")

    # Prepare token request data
    token_data = {
        "grant_type": "client_credentials",
        "client_id": client_id,
        "client_secret": client_secret,
    }

    if scopes:
        token_data["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes

    # Fetch token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # GitHub returns form-encoded responses, not JSON
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                # Form bodies are percent-encoded ("+" means space), so both
                # sides of each pair are decoded. Pairs without "=" are ignored.
                token_response = {}
                for pair in response.text.split("&"):
                    key, separator, value = pair.partition("=")
                    if separator:
                        token_response[unquote_plus(key)] = unquote_plus(value)
            else:
                # Try JSON response
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    # Keep the raw body so the error below shows what came back
                    token_response = {"raw_response": response.text}

            if "access_token" not in token_response:
                # Not an httpx error, so this is not retried.
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("""Successfully obtained access token via client credentials""")
            return token_response["access_token"]

        except httpx.HTTPError as e:
            logger.warning(f"Token request attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to obtain access token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Only reached when max_retries is zero or negative (the loop never runs)
    raise OAuthError("Failed to obtain access token after all retry attempts")

294 

async def _password_flow(self, credentials: Dict[str, Any]) -> str:
    """Resource Owner Password Credentials flow (RFC 6749 Section 4.3).

    This flow is used when the application can directly handle the user's credentials,
    such as with trusted first-party applications or legacy integrations like Keycloak.

    Transport/HTTP failures are retried up to ``self.max_retries`` times with
    exponential backoff. Both JSON and form-encoded token responses are
    understood; form-encoded keys and values are percent-decoded.

    Args:
        credentials: OAuth configuration with client_id, optional client_secret, token_url, username, password

    Returns:
        Access token string

    Raises:
        OAuthError: If username or password is missing, the response has no
            access_token, or token acquisition fails after all retries
    """
    # Standard
    from urllib.parse import unquote_plus  # pylint: disable=import-outside-toplevel

    client_id = credentials.get("client_id")
    client_secret = credentials.get("client_secret")
    token_url = credentials["token_url"]
    username = credentials.get("username")
    password = credentials.get("password")
    scopes = credentials.get("scopes", [])

    if not username or not password:
        raise OAuthError("Username and password are required for password grant type")

    # Decrypt client secret if it's encrypted and present
    if client_secret and len(client_secret) > 50:  # Simple heuristic: encrypted secrets are longer
        try:
            settings = get_settings()
            encryption = get_encryption_service(settings.auth_encryption_secret)
            decrypted_secret = await encryption.decrypt_secret_async(client_secret)
            if decrypted_secret:
                client_secret = decrypted_secret
                logger.debug("Successfully decrypted client secret")
            else:
                logger.warning("Failed to decrypt client secret, using encrypted version")
        except Exception as e:
            # Best effort: a long plain-text secret is sent as-is.
            logger.warning(f"Failed to decrypt client secret: {e}, using encrypted version")

    # Prepare token request data
    token_data = {
        "grant_type": "password",
        "username": username,
        "password": password,
    }

    # Add client_id (required by most providers including Keycloak)
    if client_id:
        token_data["client_id"] = client_id

    # Add client_secret if present (some providers require it, others don't)
    if client_secret:
        token_data["client_secret"] = client_secret

    if scopes:
        token_data["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes

    # Fetch token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # Handle both JSON and form-encoded responses
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                # Form bodies are percent-encoded ("+" means space), so both
                # sides of each pair are decoded. Pairs without "=" are ignored.
                token_response = {}
                for pair in response.text.split("&"):
                    key, separator, value = pair.partition("=")
                    if separator:
                        token_response[unquote_plus(key)] = unquote_plus(value)
            else:
                # Try JSON response
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    # Keep the raw body so the error below shows what came back
                    token_response = {"raw_response": response.text}

            if "access_token" not in token_response:
                # Not an httpx error, so this is not retried.
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("Successfully obtained access token via password grant")
            return token_response["access_token"]

        except httpx.HTTPError as e:
            logger.warning(f"Token request attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to obtain access token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Only reached when max_retries is zero or negative (the loop never runs)
    raise OAuthError("Failed to obtain access token after all retry attempts")

393 

async def get_authorization_url(self, credentials: Dict[str, Any]) -> Dict[str, str]:
    """Build the provider URL a user must visit to grant access.

    Args:
        credentials: OAuth configuration with client_id, redirect_uri,
            authorization_url and optional scopes

    Returns:
        Dict containing authorization_url and state
    """
    oauth_client_id = credentials["client_id"]
    callback_uri = credentials["redirect_uri"]
    authorize_endpoint = credentials["authorization_url"]
    requested_scopes = credentials.get("scopes", [])

    # The session object produces the URL together with a state value,
    # which the caller uses for CSRF protection.
    session = OAuth2Session(oauth_client_id, redirect_uri=callback_uri, scope=requested_scopes)
    url_with_state = session.authorization_url(authorize_endpoint)
    auth_url, state = url_with_state

    logger.info(f"Generated authorization URL for client {oauth_client_id}")

    result = {"authorization_url": auth_url, "state": state}
    return result

417 

async def exchange_code_for_token(self, credentials: Dict[str, Any], code: str, state: str) -> str:  # pylint: disable=unused-argument
    """Exchange authorization code for access token.

    Posts an ``authorization_code`` grant to the token endpoint, retrying
    transport/HTTP failures up to ``self.max_retries`` times with exponential
    backoff. Both JSON and form-encoded token responses are understood;
    form-encoded keys and values are percent-decoded.

    Args:
        credentials: OAuth configuration with client_id, token_url, redirect_uri
            and optional client_secret
        code: Authorization code from callback
        state: State parameter from the callback; accepted for interface
            compatibility but not inspected by this method

    Returns:
        Access token string

    Raises:
        OAuthError: If the response has no access_token, or token exchange
            fails after all retries
    """
    # Standard
    from urllib.parse import unquote_plus  # pylint: disable=import-outside-toplevel

    client_id = credentials["client_id"]
    client_secret = credentials.get("client_secret")  # Optional for public clients (PKCE-only)
    token_url = credentials["token_url"]
    redirect_uri = credentials["redirect_uri"]

    # Decrypt client secret if it's encrypted and present
    if client_secret and len(client_secret) > 50:  # Simple heuristic: encrypted secrets are longer
        try:
            settings = get_settings()
            encryption = get_encryption_service(settings.auth_encryption_secret)
            decrypted_secret = await encryption.decrypt_secret_async(client_secret)
            if decrypted_secret:
                client_secret = decrypted_secret
                logger.debug("Successfully decrypted client secret")
            else:
                logger.warning("Failed to decrypt client secret, using encrypted version")
        except Exception as e:
            # Best effort: a long plain-text secret is sent as-is.
            logger.warning(f"Failed to decrypt client secret: {e}, using encrypted version")

    # Prepare token exchange data
    token_data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": client_id,
    }

    # Only include client_secret if present (public clients don't have secrets)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Exchange code for token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # GitHub returns form-encoded responses, not JSON
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                # Form bodies are percent-encoded ("+" means space), so both
                # sides of each pair are decoded. Pairs without "=" are ignored.
                token_response = {}
                for pair in response.text.split("&"):
                    key, separator, value = pair.partition("=")
                    if separator:
                        token_response[unquote_plus(key)] = unquote_plus(value)
            else:
                # Try JSON response
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    # Keep the raw body so the error below shows what came back
                    token_response = {"raw_response": response.text}

            if "access_token" not in token_response:
                # Not an httpx error, so this is not retried.
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("""Successfully exchanged authorization code for access token""")
            return token_response["access_token"]

        except httpx.HTTPError as e:
            logger.warning(f"Token exchange attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Only reached when max_retries is zero or negative (the loop never runs)
    raise OAuthError("Failed to exchange code for token after all retry attempts")

504 

async def initiate_authorization_code_flow(self, gateway_id: str, credentials: Dict[str, Any], app_user_email: str = None) -> Dict[str, str]:
    """Start the PKCE-protected Authorization Code flow for a gateway.

    A PKCE pair and a signed state value are generated. When a token storage
    service is configured, the state is persisted together with the PKCE
    verifier. The provider authorization URL is then built from the state
    and the PKCE challenge.

    Args:
        gateway_id: ID of the gateway being configured
        credentials: OAuth configuration with client_id, authorization_url, etc.
        app_user_email: MCP Gateway user email to associate with tokens

    Returns:
        Dict containing authorization_url, state and gateway_id
    """
    pkce = self._generate_pkce_params()
    signed_state = self._generate_state(gateway_id, app_user_email)

    # Persisting the state (and verifier) only happens when tokens can be stored.
    if self.token_storage:
        verifier = pkce["code_verifier"]
        await self._store_authorization_state(gateway_id, signed_state, code_verifier=verifier)

    auth_url = self._create_authorization_url_with_pkce(
        credentials,
        signed_state,
        pkce["code_challenge"],
        pkce["code_challenge_method"],
    )

    logger.info(f"Generated authorization URL with PKCE for gateway {gateway_id}")

    return {
        "authorization_url": auth_url,
        "state": signed_state,
        "gateway_id": gateway_id,
    }

533 

async def complete_authorization_code_flow(self, gateway_id: str, code: str, state: str, credentials: Dict[str, Any]) -> Dict[str, Any]:
    """Complete Authorization Code flow with PKCE and store tokens.

    Steps:

    1. The state is looked up (and consumed) in the state store, which also
       yields the PKCE code_verifier.
    2. The state is decoded. A state in the signed format (base64 of a JSON
       object followed by a 32-byte HMAC-SHA256) must carry a valid signature
       and the expected gateway ID; otherwise OAuthError is raised. A state
       that cannot be decoded as that format is treated as a legacy state
       carrying no user context.
    3. The code is exchanged for tokens, which are stored when a token
       storage service is configured.

    Args:
        gateway_id: ID of the gateway
        code: Authorization code from callback
        state: State parameter for CSRF validation
        credentials: OAuth configuration

    Returns:
        Dict containing success status, user_id, and expiration info

    Raises:
        OAuthError: If state validation fails or token exchange fails
    """
    # Validate state and retrieve code_verifier
    state_data = await self._validate_and_retrieve_state(gateway_id, state)
    if not state_data:
        raise OAuthError("Invalid or expired state parameter - possible replay attack")

    code_verifier = state_data.get("code_verifier")

    # Legacy states carry no user context, so this stays None unless a
    # signed state is decoded and verified below.
    app_user_email = None

    try:
        # Signed layout: <JSON payload bytes><32-byte HMAC-SHA256>, base64url-encoded
        state_with_sig = base64.urlsafe_b64decode(state.encode())
        state_bytes = state_with_sig[:-32]
        received_signature = state_with_sig[-32:]
        state_payload = orjson.loads(state_bytes)
        if not isinstance(state_payload, dict):
            raise ValueError("state payload is not a JSON object")
    except Exception as e:
        # Only decoding problems end up here: the state is not in the signed
        # format. Fallback for legacy state format (gateway_id_random).
        logger.warning(f"Failed to decode state JSON, trying legacy format: {e}")
    else:
        # The state is in the signed format, so verification failures must
        # abort the flow instead of being downgraded to the legacy path.
        # NOTE(review): the hard-coded fallback key matches _generate_state;
        # confirm deployments always configure auth_encryption_secret.
        secret_key = self.settings.auth_encryption_secret.get_secret_value().encode() if self.settings.auth_encryption_secret else b"default-secret-key"
        expected_signature = hmac.new(secret_key, state_bytes, hashlib.sha256).digest()

        # Constant-time comparison of the two signatures
        if not hmac.compare_digest(received_signature, expected_signature):
            raise OAuthError("Invalid state signature - possible CSRF attack")

        # Validate gateway ID matches
        if state_payload.get("gateway_id") != gateway_id:
            raise OAuthError("State parameter gateway mismatch")

        app_user_email = state_payload.get("app_user_email")

    # Exchange code for tokens with PKCE code_verifier
    token_response = await self._exchange_code_for_tokens(credentials, code, code_verifier=code_verifier)

    # Extract user information from token response
    user_id = self._extract_user_id(token_response, credentials)

    # Without a storage service there is nothing to persist
    if not self.token_storage:
        return {"success": True, "user_id": user_id, "expires_at": None}

    if not app_user_email:
        raise OAuthError("User context required for OAuth token storage")

    token_record = await self.token_storage.store_tokens(
        gateway_id=gateway_id,
        user_id=user_id,
        app_user_email=app_user_email,  # User from state
        access_token=token_response["access_token"],
        refresh_token=token_response.get("refresh_token"),
        expires_in=token_response.get("expires_in", self.settings.oauth_default_timeout),
        # "scope" may be absent or null in the token response
        scopes=(token_response.get("scope") or "").split(),
    )

    return {"success": True, "user_id": user_id, "expires_at": token_record.expires_at.isoformat() if token_record.expires_at else None}

609 

async def get_access_token_for_user(self, gateway_id: str, app_user_email: str) -> Optional[str]:
    """Look up the stored access token for one user of one gateway.

    Args:
        gateway_id: ID of the gateway
        app_user_email: MCP Gateway user email

    Returns:
        Whatever the token storage service returns for that user, or None
        when no token storage service is configured
    """
    if not self.token_storage:
        return None

    user_token = await self.token_storage.get_user_token(gateway_id, app_user_email)
    return user_token

623 

def _generate_state(self, gateway_id: str, app_user_email: str = None) -> str:
    """Produce a signed, single-request state value carrying user context.

    The state is the URL-safe base64 encoding of a JSON payload (gateway ID,
    user email, random nonce, UTC timestamp) followed by the HMAC-SHA256 of
    that payload.

    Args:
        gateway_id: ID of the gateway
        app_user_email: MCP Gateway user email (optional but recommended)

    Returns:
        Unique state string with embedded user context and HMAC signature
    """
    # Key order is kept stable because the serialized bytes are what gets signed.
    payload: Dict[str, Any] = {}
    payload["gateway_id"] = gateway_id
    payload["app_user_email"] = app_user_email
    payload["nonce"] = secrets.token_urlsafe(16)
    payload["timestamp"] = datetime.now(timezone.utc).isoformat()

    payload_bytes = orjson.dumps(payload)

    # Sign with the configured encryption secret, or a fixed fallback key when none is set.
    configured_secret = self.settings.auth_encryption_secret
    if configured_secret:
        signing_key = configured_secret.get_secret_value().encode()
    else:
        signing_key = b"default-secret-key"
    mac = hmac.new(signing_key, payload_bytes, hashlib.sha256).digest()

    # Payload and signature travel together in one base64url blob.
    return base64.urlsafe_b64encode(payload_bytes + mac).decode()

649 

async def _store_authorization_state(self, gateway_id: str, state: str, code_verifier: str = None) -> None:
    """Persist an authorization state so the callback can validate it later.

    The backend follows the configured cache type: Redis (entry written with
    a TTL), then the database, and finally the in-process dictionary. A
    failure in Redis or the database is logged and the next option is used.

    Args:
        gateway_id: ID of the gateway
        state: State parameter to store
        code_verifier: Optional PKCE code verifier (RFC 7636)
    """
    expires_at = datetime.now(timezone.utc) + timedelta(seconds=STATE_TTL_SECONDS)
    backend = get_settings().cache_type

    # The Redis and in-memory backends share the same key and record layout.
    state_key = f"oauth:state:{gateway_id}:{state}"
    record = {
        "state": state,
        "gateway_id": gateway_id,
        "code_verifier": code_verifier,
        "expires_at": expires_at.isoformat(),
        "used": False,
    }

    # Redis: distributed storage, expiry handled by the TTL on the key
    if backend == "redis":
        redis = await _get_redis_client()
        if redis:
            try:
                await redis.setex(state_key, STATE_TTL_SECONDS, orjson.dumps(record))
                logger.debug(f"Stored OAuth state in Redis for gateway {gateway_id}")
                return
            except Exception as e:
                logger.warning(f"Failed to store state in Redis: {e}, falling back")

    # Database: shared storage for multi-worker deployments
    if backend == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            db_gen = get_db()
            db = next(db_gen)
            try:
                # Purge rows that have already expired before adding the new one
                db.query(OAuthState).filter(OAuthState.expires_at < datetime.now(timezone.utc)).delete()

                db.add(OAuthState(gateway_id=gateway_id, state=state, code_verifier=code_verifier, expires_at=expires_at, used=False))
                db.commit()
                logger.debug(f"Stored OAuth state in database for gateway {gateway_id}")
                return
            finally:
                db_gen.close()
        except Exception as e:
            logger.warning(f"Failed to store state in database: {e}, falling back to memory")

    # In-memory fallback for development
    async with _state_lock:
        now = datetime.now(timezone.utc)

        # Collect expired keys first; the dict cannot be mutated while iterating it
        stale_keys = [key for key, data in _oauth_states.items() if datetime.fromisoformat(data["expires_at"]) < now]
        for stale_key in stale_keys:
            del _oauth_states[stale_key]
            logger.debug(f"Cleaned up expired state: {stale_key[:20]}...")

        _oauth_states[state_key] = record
        logger.debug(f"Stored OAuth state in memory for gateway {gateway_id}")

712 

async def _validate_authorization_state(self, gateway_id: str, state: str) -> bool:
    """Check an authorization state and consume it so it cannot be reused.

    The backend follows the configured cache type: Redis, then the database,
    then the in-process dictionary. A state is accepted only when it is
    found, has not expired, and is not flagged as used; an accepted state is
    removed from its store. A failure in Redis or the database is logged and
    the next option is tried.

    Args:
        gateway_id: ID of the gateway
        state: State parameter to validate

    Returns:
        True if state is valid and not yet used, False otherwise
    """

    def _as_utc(moment: datetime) -> datetime:
        # Timestamps stored without tzinfo are assumed to be UTC.
        return moment.replace(tzinfo=timezone.utc) if moment.tzinfo is None else moment

    backend = get_settings().cache_type
    state_key = f"oauth:state:{gateway_id}:{state}"

    # Redis: distributed storage
    if backend == "redis":
        redis = await _get_redis_client()
        if redis:
            try:
                # GETDEL reads and removes the entry in one step (single-use)
                raw_record = await redis.getdel(state_key)
                if not raw_record:
                    logger.warning(f"State not found in Redis for gateway {gateway_id}")
                    return False

                record = orjson.loads(raw_record)

                try:
                    deadline = datetime.fromisoformat(record["expires_at"])
                except Exception:
                    # Fallback: try parsing without microseconds/offsets
                    deadline = datetime.strptime(record["expires_at"], "%Y-%m-%dT%H:%M:%S")

                if _as_utc(deadline) < datetime.now(timezone.utc):
                    logger.warning(f"State has expired for gateway {gateway_id}")
                    return False

                # Should not happen with getdel, kept as a defensive check
                if record.get("used", False):
                    logger.warning(f"State was already used for gateway {gateway_id} - possible replay attack")
                    return False

                logger.debug(f"Successfully validated OAuth state from Redis for gateway {gateway_id}")
                return True
            except Exception as e:
                logger.warning(f"Failed to validate state in Redis: {e}, falling back")

    # Database: shared storage for multi-worker deployments
    if backend == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            db_gen = get_db()
            db = next(db_gen)
            try:
                row = db.query(OAuthState).filter(OAuthState.gateway_id == gateway_id, OAuthState.state == state).first()

                if not row:
                    logger.warning(f"State not found in database for gateway {gateway_id}")
                    return False

                if _as_utc(row.expires_at) < datetime.now(timezone.utc):
                    logger.warning(f"State has expired for gateway {gateway_id}")
                    # Expired rows are removed on sight
                    db.delete(row)
                    db.commit()
                    return False

                if row.used:
                    logger.warning(f"State has already been used for gateway {gateway_id} - possible replay attack")
                    return False

                # Consuming the state means deleting its row (single-use)
                db.delete(row)
                db.commit()
                logger.debug(f"Successfully validated OAuth state from database for gateway {gateway_id}")
                return True
            finally:
                db_gen.close()
        except Exception as e:
            logger.warning(f"Failed to validate state in database: {e}, falling back to memory")

    # In-memory fallback for development
    async with _state_lock:
        record = _oauth_states.get(state_key)

        if not record:
            logger.warning(f"State not found in memory for gateway {gateway_id}")
            return False

        if _as_utc(datetime.fromisoformat(record["expires_at"])) < datetime.now(timezone.utc):
            logger.warning(f"State has expired for gateway {gateway_id}")
            del _oauth_states[state_key]  # Clean up expired state
            return False

        # Prevent replay of a state that was already consumed
        if record.get("used", False):
            logger.warning(f"State has already been used for gateway {gateway_id} - possible replay attack")
            return False

        # Consume the state by removing it (single-use)
        del _oauth_states[state_key]
        logger.debug(f"Successfully validated OAuth state from memory for gateway {gateway_id}")
        return True

838 

async def _validate_and_retrieve_state(self, gateway_id: str, state: str) -> Optional[Dict[str, Any]]:
    """Consume a stored OAuth state and return its payload.

    The backend is chosen from ``settings.cache_type``: Redis is tried first,
    then the database, and the process-local dictionary is the last resort
    (also used when a configured backend raises). A state that is found is
    removed from the backend that served it, so it can only be redeemed once.

    Args:
        gateway_id: ID of the gateway
        state: State parameter to validate

    Returns:
        Dict with state data including code_verifier, or None if invalid/expired
    """

    def _is_past(moment) -> bool:
        # Naive timestamps are interpreted as UTC before comparing with "now".
        aware = moment.replace(tzinfo=timezone.utc) if moment.tzinfo is None else moment
        return aware < datetime.now(timezone.utc)

    cache_type = get_settings().cache_type

    # --- Redis backend -------------------------------------------------
    if cache_type == "redis":
        redis_conn = await _get_redis_client()
        if redis_conn:
            try:
                # GETDEL fetches and removes the key in one atomic step
                raw_payload = await redis_conn.getdel(f"oauth:state:{gateway_id}:{state}")
                if not raw_payload:
                    return None

                payload = orjson.loads(raw_payload)

                # Accept ISO-8601 first, then the plain second-resolution format
                try:
                    deadline = datetime.fromisoformat(payload["expires_at"])
                except Exception:
                    deadline = datetime.strptime(payload["expires_at"], "%Y-%m-%dT%H:%M:%S")

                return None if _is_past(deadline) else payload
            except Exception as e:
                logger.warning(f"Failed to validate state in Redis: {e}, falling back")

    # --- Database backend ----------------------------------------------
    if cache_type == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            session_gen = get_db()
            session = next(session_gen)
            try:
                record = session.query(OAuthState).filter(OAuthState.gateway_id == gateway_id, OAuthState.state == state).first()
                if not record:
                    return None

                # Expired rows are purged on sight
                if _is_past(record.expires_at):
                    session.delete(record)
                    session.commit()
                    return None

                # A row flagged as used is rejected without touching it
                if record.used:
                    return None

                # Snapshot the fields before the row is deleted
                payload = {
                    "state": record.state,
                    "gateway_id": record.gateway_id,
                    "code_verifier": record.code_verifier,
                    "expires_at": record.expires_at.isoformat(),
                }

                session.delete(record)
                session.commit()
                return payload
            finally:
                session_gen.close()
        except Exception as e:
            logger.warning(f"Failed to validate state in database: {e}")

    # --- In-memory fallback --------------------------------------------
    memory_key = f"oauth:state:{gateway_id}:{state}"
    async with _state_lock:
        payload = _oauth_states.get(memory_key)
        if not payload:
            return None

        expired = _is_past(datetime.fromisoformat(payload["expires_at"]))

        # The entry is dropped whether it expired or is being redeemed (single-use)
        del _oauth_states[memory_key]
        return None if expired else payload

939 

def _create_authorization_url(self, credentials: Dict[str, Any], state: str) -> tuple[str, str]:
    """Build the provider authorization URL carrying the given state.

    Args:
        credentials: OAuth configuration
        state: State parameter for CSRF protection

    Returns:
        Tuple of (authorization_url, state)
    """
    # Required settings are read up front so a missing key fails before any session is built
    client_id, redirect_uri, endpoint = (credentials[key] for key in ("client_id", "redirect_uri", "authorization_url"))
    requested_scopes = credentials.get("scopes", [])

    session = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=requested_scopes)

    # The session appends the state (CSRF protection) and hands it back with the URL
    url, returned_state = session.authorization_url(endpoint, state=state)
    return url, returned_state

962 

963 def _create_authorization_url_with_pkce(self, credentials: Dict[str, Any], state: str, code_challenge: str, code_challenge_method: str) -> str: 

964 """Create authorization URL with PKCE parameters (RFC 7636). 

965 

966 Args: 

967 credentials: OAuth configuration 

968 state: State parameter for CSRF protection 

969 code_challenge: PKCE code challenge 

970 code_challenge_method: PKCE method (S256) 

971 

972 Returns: 

973 Authorization URL string with PKCE parameters 

974 """ 

975 # Standard 

976 from urllib.parse import urlencode # pylint: disable=import-outside-toplevel 

977 

978 client_id = credentials["client_id"] 

979 redirect_uri = credentials["redirect_uri"] 

980 authorization_url = credentials["authorization_url"] 

981 scopes = credentials.get("scopes", []) 

982 

983 # Build authorization parameters 

984 params = {"response_type": "code", "client_id": client_id, "redirect_uri": redirect_uri, "state": state, "code_challenge": code_challenge, "code_challenge_method": code_challenge_method} 

985 

986 # Add scopes if present 

987 if scopes: 

988 params["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes 

989 

990 # Add resource parameter for JWT access token (RFC 8707) 

991 # The resource is the MCP server URL, set by oauth_router.py 

992 resource = credentials.get("resource") 

993 if resource: 

994 # RFC 8707 allows multiple resource parameters 

995 if isinstance(resource, list): 

996 params["resource"] = resource # urlencode with doseq=True handles lists 

997 else: 

998 params["resource"] = resource 

999 

1000 # Build full URL (doseq=True handles list values like multiple resource params) 

1001 query_string = urlencode(params, doseq=True) 

1002 return f"{authorization_url}?{query_string}" 

1003 

async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]:
    """Exchange authorization code for tokens with PKCE support.

    Posts a form-encoded ``authorization_code`` grant to ``credentials["token_url"]``
    and retries up to ``self.max_retries`` times with exponential backoff when the
    HTTP client raises ``httpx.HTTPError``.

    Args:
        credentials: OAuth configuration. Reads ``client_id``, ``token_url`` and
            ``redirect_uri`` (required) plus ``client_secret`` and ``resource``
            (optional).
        code: Authorization code from callback
        code_verifier: Optional PKCE code verifier (RFC 7636)

    Returns:
        Token response dictionary

    Raises:
        OAuthError: If token exchange fails or the response has no access_token
        KeyError: If a required key is missing from ``credentials``
    """
    # Standard
    from urllib.parse import parse_qsl  # pylint: disable=import-outside-toplevel

    client_id = credentials["client_id"]
    client_secret = credentials.get("client_secret")  # Optional for public clients (PKCE-only)
    token_url = credentials["token_url"]
    redirect_uri = credentials["redirect_uri"]

    # Decrypt client secret if it's encrypted and present
    if client_secret and len(client_secret) > 50:  # Simple heuristic: encrypted secrets are longer
        try:
            settings = get_settings()
            encryption = get_encryption_service(settings.auth_encryption_secret)
            decrypted_secret = await encryption.decrypt_secret_async(client_secret)
            if decrypted_secret:
                client_secret = decrypted_secret
                logger.debug("Successfully decrypted client secret")
            else:
                logger.warning("Failed to decrypt client secret, using encrypted version")
        except Exception as e:
            # Best effort: a secret that merely looks long may not be encrypted at all
            logger.warning(f"Failed to decrypt client secret: {e}, using encrypted version")

    # Prepare token exchange data
    token_data: Dict[str, Any] = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": client_id,
    }

    # Only include client_secret if present (public clients don't have secrets)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Add PKCE code_verifier if present (RFC 7636)
    if code_verifier:
        token_data["code_verifier"] = code_verifier

    # Add resource parameter to request JWT access token (RFC 8707)
    # The resource identifies the MCP server (resource server), not the OAuth server.
    # The payload stays a dict: httpx form-encodes a list value as repeated keys,
    # while a non-mapping ``data=`` argument is treated as raw content.
    resource = credentials.get("resource")
    if isinstance(resource, list):
        resource = [entry for entry in resource if entry]
    if resource:
        token_data["resource"] = resource

    # Exchange code for token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # GitHub returns form-encoded responses, not JSON
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                # parse_qsl percent-decodes keys and values (e.g. "scope=repo%2Cgist")
                token_response = dict(parse_qsl(response.text, keep_blank_values=True))
            else:
                # Try JSON response
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    # Keep the raw body so the error below shows what the provider sent
                    token_response = {"raw_response": response.text}

            if "access_token" not in token_response:
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("""Successfully exchanged authorization code for tokens""")
            return token_response

        except httpx.HTTPError as e:
            logger.warning(f"Token exchange attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Only reached when max_retries is 0; also keeps the return type total
    raise OAuthError("Failed to exchange code for token after all retry attempts")

1108 

async def refresh_token(self, refresh_token: str, credentials: Dict[str, Any]) -> Dict[str, Any]:
    """Refresh an expired access token using a refresh token.

    Posts a form-encoded ``refresh_token`` grant to the configured token URL.
    Transport errors and non-200 responses other than 400/401 are retried up to
    ``self.max_retries`` times with exponential backoff between attempts; a
    400/401 response fails immediately.

    Args:
        refresh_token: The refresh token to use
        credentials: OAuth configuration including client_id, client_secret, token_url

    Returns:
        Dict containing new access_token, optional refresh_token, and expires_in

    Raises:
        OAuthError: If token refresh fails
    """
    if not refresh_token:
        raise OAuthError("No refresh token available")

    token_url = credentials.get("token_url")
    if not token_url:
        raise OAuthError("No token URL configured for OAuth provider")

    client_id = credentials.get("client_id")
    client_secret = credentials.get("client_secret")

    if not client_id:
        raise OAuthError("No client_id configured for OAuth provider")

    # Prepare token refresh request
    token_data: Dict[str, Any] = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
    }

    # Add client_secret if available (some providers require it)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Add resource parameter for JWT access token (RFC 8707)
    # Must be included in refresh requests to maintain JWT token type.
    # The payload stays a dict: httpx form-encodes a list value as repeated keys,
    # while a non-mapping ``data=`` argument is treated as raw content.
    resource = credentials.get("resource")
    if isinstance(resource, list):
        resource = [entry for entry in resource if entry]
    if resource:
        token_data["resource"] = resource

    # Attempt token refresh with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            if response.status_code == 200:
                try:
                    token_response = response.json()
                except ValueError as e:
                    # A 200 with an unparsable body is a provider failure, reported via the documented error type
                    raise OAuthError(f"Invalid JSON in token refresh response: {e}") from e

                # Validate required fields
                if "access_token" not in token_response:
                    raise OAuthError("No access_token in refresh response")

                logger.info("Successfully refreshed OAuth token")
                return token_response

            error_text = response.text
            # If we get a 400/401, the refresh token is likely invalid
            if response.status_code in [400, 401]:
                raise OAuthError(f"Refresh token invalid or expired: {error_text}")
            logger.warning(f"Token refresh failed with status {response.status_code}: {error_text}")

        except httpx.HTTPError as e:
            logger.warning(f"Token refresh attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to refresh token after {self.max_retries} attempts: {str(e)}") from e

        # Exponential backoff before the next attempt, for transport errors and
        # retryable status codes alike; skipped after the final attempt
        if attempt < self.max_retries - 1:
            await asyncio.sleep(2**attempt)

    raise OAuthError("Failed to refresh token after all retry attempts")

1188 

1189 def _extract_user_id(self, token_response: Dict[str, Any], credentials: Dict[str, Any]) -> str: 

1190 """Extract user ID from token response. 

1191 

1192 Args: 

1193 token_response: Response from token exchange 

1194 credentials: OAuth configuration 

1195 

1196 Returns: 

1197 User ID string 

1198 """ 

1199 # Try to extract user ID from various common fields in token response 

1200 # Different OAuth providers use different field names 

1201 

1202 # Check for 'sub' (subject) - JWT standard 

1203 if "sub" in token_response: 

1204 return token_response["sub"] 

1205 

1206 # Check for 'user_id' - common in some OAuth responses 

1207 if "user_id" in token_response: 

1208 return token_response["user_id"] 

1209 

1210 # Check for 'id' - also common 

1211 if "id" in token_response: 

1212 return token_response["id"] 

1213 

1214 # Fallback to client_id if no user info is available 

1215 if credentials.get("client_id"): 

1216 return credentials["client_id"] 

1217 

1218 # Final fallback 

1219 return "unknown_user" 

1220 

1221 

class OAuthError(Exception):
    """Exception type raised for failures in OAuth flows.

    It adds nothing to ``Exception``; it exists so callers can catch OAuth
    failures separately from other errors.

    Examples:
        >>> err = OAuthError("Token acquisition failed")
        >>> str(err)
        'Token acquisition failed'
        >>> isinstance(err, Exception)
        True
        >>> try:
        ...     raise OAuthError("Invalid grant type")
        ... except OAuthError as caught:
        ...     caught.args
        ('Invalid grant type',)
    """