Coverage for mcpgateway / services / oauth_manager.py: 99%

647 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/oauth_manager.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7OAuth 2.0 Manager for ContextForge. 

8 

9This module handles OAuth 2.0 authentication flows including: 

10- Client Credentials (Machine-to-Machine) 

11- Authorization Code (User Delegation) 

12""" 

13 

14# Standard 

15import asyncio 

16import base64 

17from datetime import datetime, timedelta, timezone 

18import hashlib 

19import logging 

20import secrets 

21from typing import Any, Dict, Optional 

22from urllib.parse import urlparse 

23 

24# Third-Party 

25import httpx 

26import orjson 

27from requests_oauthlib import OAuth2Session 

28 

29# First-Party 

30from mcpgateway.config import get_settings 

31from mcpgateway.services.encryption_service import decrypt_oauth_config_for_runtime, get_encryption_service 

32from mcpgateway.services.http_client_service import get_http_client 

33from mcpgateway.utils.redis_client import get_redis_client as _get_shared_redis_client 

34 

logger = logging.getLogger(__name__)

# In-memory storage for OAuth states with expiration (fallback when neither
# Redis nor the database backend is used or available).
# Keys have the form "oauth:state:{gateway_id}:{state}". Each value is a dict with
# "state", "gateway_id", "code_verifier", "app_user_email", "used" and
# "expires_at", where "expires_at" is an ISO-8601 string (not a datetime object).
_oauth_states: Dict[str, Dict[str, Any]] = {}
# Reverse lookup for callback handlers that only receive state.
# Format: {state: gateway_id}
_oauth_state_lookup: Dict[str, str] = {}
# Serialises coroutine access to the two in-memory maps above. This is an
# asyncio.Lock, so it coordinates tasks on one event loop; it is not a
# threading lock.
_state_lock = asyncio.Lock()

# Lifetime of a stored authorization state, in seconds (5 minutes).
STATE_TTL_SECONDS = 300

# Redis client for distributed state storage (obtained from the shared factory).
# _REDIS_INITIALIZED records that the lookup was already attempted, so the
# result (client or None) is reused for later calls.
_redis_client: Optional[Any] = None
_REDIS_INITIALIZED = False

52 

53 

async def _get_redis_client():
    """Return the shared Redis client used for distributed OAuth state.

    The lookup is attempted only once per process: the outcome (a client or
    ``None``) is memoised in module globals and handed back on later calls.

    Returns:
        Redis client instance, or None when Redis is not configured or the
        shared factory could not provide a client.
    """
    global _redis_client, _REDIS_INITIALIZED  # pylint: disable=global-statement

    if not _REDIS_INITIALIZED:
        config = get_settings()
        resolved = None
        if config.cache_type == "redis" and config.redis_url:
            try:
                resolved = await _get_shared_redis_client()
                if resolved:
                    logger.info("OAuth manager connected to shared Redis client")
            except Exception as e:
                logger.warning(f"Failed to get Redis client, falling back to in-memory storage: {e}")
                resolved = None

        # Remember the outcome so that the factory is not queried again.
        _redis_client = resolved
        _REDIS_INITIALIZED = True

    return _redis_client

81 

82 

83class OAuthManager: 

84 """Manages OAuth 2.0 authentication flows. 

85 

86 Examples: 

87 >>> manager = OAuthManager(request_timeout=30, max_retries=3) 

88 >>> manager.request_timeout 

89 30 

90 >>> manager.max_retries 

91 3 

92 >>> manager.token_storage is None 

93 True 

94 >>> 

95 >>> # Test grant type validation 

96 >>> grant_type = "client_credentials" 

97 >>> grant_type in ["client_credentials", "authorization_code"] 

98 True 

99 >>> grant_type = "invalid_grant" 

100 >>> grant_type in ["client_credentials", "authorization_code"] 

101 False 

102 >>> 

103 >>> # Test encrypted secret detection heuristic 

104 >>> short_secret = "secret123" 

105 >>> len(short_secret) > 50 

106 False 

107 >>> encrypted_secret = "gAAAAABh" + "x" * 60 # Simulated encrypted secret 

108 >>> len(encrypted_secret) > 50 

109 True 

110 >>> 

111 >>> # Test scope list handling 

112 >>> scopes = ["read", "write"] 

113 >>> " ".join(scopes) 

114 'read write' 

115 >>> empty_scopes = [] 

116 >>> " ".join(empty_scopes) 

117 '' 

118 """ 

119 

120 # Known Microsoft Entra login hosts (global + sovereign clouds). 

121 _ENTRA_HOSTS: frozenset[str] = frozenset( 

122 { 

123 "login.microsoftonline.com", 

124 "login.microsoftonline.us", 

125 "login.microsoftonline.de", 

126 "login.partner.microsoftonline.cn", 

127 } 

128 ) 

129 

def __init__(self, request_timeout: int = 30, max_retries: int = 3, token_storage: Optional[Any] = None):
    """Create an OAuth manager.

    Args:
        request_timeout: Timeout, in seconds, applied to OAuth HTTP requests
        max_retries: Maximum number of attempts for token requests
        token_storage: Optional TokenStorageService used to persist tokens
    """
    # Plain attribute assignments; application settings are resolved last.
    self.token_storage = token_storage
    self.max_retries = max_retries
    self.request_timeout = request_timeout
    self.settings = get_settings()

142 

async def _get_client(self) -> httpx.AsyncClient:
    """Fetch the HTTP client used for token requests.

    Returns:
        The httpx.AsyncClient provided by ``get_http_client``
    """
    shared_client = await get_http_client()
    return shared_client

150 

151 def _generate_pkce_params(self) -> Dict[str, str]: 

152 """Generate PKCE parameters for OAuth Authorization Code flow (RFC 7636). 

153 

154 Returns: 

155 Dict containing code_verifier, code_challenge, and code_challenge_method 

156 """ 

157 # Generate code_verifier: 43-128 character random string 

158 code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=") 

159 

160 # Generate code_challenge: base64url(SHA256(code_verifier)) 

161 code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest()).decode("utf-8").rstrip("=") 

162 

163 return {"code_verifier": code_verifier, "code_challenge": code_challenge, "code_challenge_method": "S256"} 

164 

async def get_access_token(self, credentials: Dict[str, Any]) -> str:
    """Obtain an access token using the flow named by ``grant_type``.

    Args:
        credentials: OAuth configuration containing grant_type and other params

    Returns:
        Access token string

    Raises:
        ValueError: If grant type is unsupported
        OAuthError: If token acquisition fails

    Examples:
        Client credentials flow:
        >>> import asyncio
        >>> class TestMgr(OAuthManager):
        ...     async def _client_credentials_flow(self, credentials):
        ...         return 'tok'
        >>> mgr = TestMgr()
        >>> asyncio.run(mgr.get_access_token({'grant_type': 'client_credentials'}))
        'tok'

        Authorization code flow requires interactive completion:
        >>> def _auth_code_requires_consent():
        ...     try:
        ...         asyncio.run(mgr.get_access_token({'grant_type': 'authorization_code'}))
        ...     except OAuthError:
        ...         return True
        >>> _auth_code_requires_consent()
        True

        Unsupported grant type raises ValueError:
        >>> def _unsupported():
        ...     try:
        ...         asyncio.run(mgr.get_access_token({'grant_type': 'bad'}))
        ...     except ValueError:
        ...         return True
        >>> _unsupported()
        True
    """
    grant_type = credentials.get("grant_type")
    logger.debug(f"Getting access token for grant type: {grant_type}")

    # The interactive flow cannot be completed from here; reject it up front.
    if grant_type == "authorization_code":
        raise OAuthError("Authorization code flow requires user consent via /oauth/authorize and does not support client_credentials fallback")

    if grant_type == "client_credentials":
        flow = self._client_credentials_flow
    elif grant_type == "password":
        flow = self._password_flow
    else:
        raise ValueError(f"Unsupported grant type: {grant_type}")

    return await flow(credentials)

216 

@staticmethod
async def _prepare_runtime_credentials(credentials: Dict[str, Any], flow_name: str) -> Dict[str, Any]:
    """Produce the credentials to use at request time, decrypting where possible.

    This is best-effort: when decryption raises, or does not yield a dict, the
    stored credentials are returned unchanged.

    Args:
        credentials: Stored oauth_config payload.
        flow_name: Flow label for diagnostic logging.

    Returns:
        Dict[str, Any]: Runtime-ready credentials.
    """
    try:
        secret = get_settings().auth_encryption_secret
        encryption = get_encryption_service(secret)
        decrypted = await decrypt_oauth_config_for_runtime(credentials, encryption=encryption)
    except Exception as exc:
        logger.warning("Failed to prepare runtime OAuth credentials for %s flow: %s", flow_name, exc)
        return credentials

    return decrypted if isinstance(decrypted, dict) else credentials

237 

async def _client_credentials_flow(self, credentials: Dict[str, Any]) -> str:
    """Machine-to-machine authentication using client credentials.

    The token endpoint is POSTed to up to ``max_retries`` times, sleeping
    ``2**attempt`` seconds between failed attempts. Both JSON and
    ``application/x-www-form-urlencoded`` responses are understood.

    Args:
        credentials: OAuth configuration with client_id, client_secret, token_url

    Returns:
        Access token string

    Raises:
        OAuthError: If the response carries no access_token, or if token
            acquisition fails after all retries
    """
    # Local import: only this block needs percent-decoding.
    from urllib.parse import unquote_plus  # pylint: disable=import-outside-toplevel

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "client_credentials")
    client_id = runtime_credentials["client_id"]
    client_secret = runtime_credentials["client_secret"]
    token_url = runtime_credentials["token_url"]
    scopes = runtime_credentials.get("scopes", [])

    # Prepare token request data
    token_data = {
        "grant_type": "client_credentials",
        "client_id": client_id,
        "client_secret": client_secret,
    }

    if scopes:
        token_data["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes

    # Fetch token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # GitHub returns form-encoded responses, not JSON
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                # Form-encoded values are percent-encoded on the wire, so
                # decode both key and value; pairs without "=" are ignored.
                token_response = {}
                for pair in response.text.split("&"):
                    key, separator, value = pair.partition("=")
                    if separator:
                        token_response[unquote_plus(key)] = unquote_plus(value)
            else:
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    # Keep the raw body so the error below shows what came back.
                    token_response = {"raw_response": response.text}

            # A JSON body that is not an object (list, string, ...) cannot hold a token.
            # NOTE(review): the message embeds the provider response - confirm that is acceptable for logs.
            if not isinstance(token_response, dict) or "access_token" not in token_response:
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("Successfully obtained access token via client credentials")
            return token_response["access_token"]

        except httpx.HTTPError as e:
            logger.warning(f"Token request attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to obtain access token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Only reachable when max_retries <= 0, i.e. the loop body never ran.
    raise OAuthError("Failed to obtain access token after all retry attempts")

307 

async def _password_flow(self, credentials: Dict[str, Any]) -> str:
    """Resource Owner Password Credentials flow (RFC 6749 Section 4.3).

    This flow is used when the application can directly handle the user's credentials,
    such as with trusted first-party applications or legacy integrations like Keycloak.
    The token endpoint is POSTed to up to ``max_retries`` times, sleeping
    ``2**attempt`` seconds between failed attempts.

    Args:
        credentials: OAuth configuration with client_id, optional client_secret, token_url, username, password

    Returns:
        Access token string

    Raises:
        OAuthError: If username or password is missing, if the response carries
            no access_token, or if token acquisition fails after all retries
    """
    # Local import: only this block needs percent-decoding.
    from urllib.parse import unquote_plus  # pylint: disable=import-outside-toplevel

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "password")
    client_id = runtime_credentials.get("client_id")
    client_secret = runtime_credentials.get("client_secret")
    token_url = runtime_credentials["token_url"]
    username = runtime_credentials.get("username")
    password = runtime_credentials.get("password")
    scopes = runtime_credentials.get("scopes", [])

    if not username or not password:
        raise OAuthError("Username and password are required for password grant type")

    # Prepare token request data
    token_data = {
        "grant_type": "password",
        "username": username,
        "password": password,
    }

    # Add client_id (required by most providers including Keycloak)
    if client_id:
        token_data["client_id"] = client_id

    # Add client_secret if present (some providers require it, others don't)
    if client_secret:
        token_data["client_secret"] = client_secret

    if scopes:
        token_data["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes

    # Fetch token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # Handle both JSON and form-encoded responses
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                # Form-encoded values are percent-encoded on the wire, so
                # decode both key and value; pairs without "=" are ignored.
                token_response = {}
                for pair in response.text.split("&"):
                    key, separator, value = pair.partition("=")
                    if separator:
                        token_response[unquote_plus(key)] = unquote_plus(value)
            else:
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    # Keep the raw body so the error below shows what came back.
                    token_response = {"raw_response": response.text}

            # A JSON body that is not an object (list, string, ...) cannot hold a token.
            # NOTE(review): the message embeds the provider response - confirm that is acceptable for logs.
            if not isinstance(token_response, dict) or "access_token" not in token_response:
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("Successfully obtained access token via password grant")
            return token_response["access_token"]

        except httpx.HTTPError as e:
            logger.warning(f"Token request attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to obtain access token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Only reachable when max_retries <= 0, i.e. the loop body never ran.
    raise OAuthError("Failed to obtain access token after all retry attempts")

393 

async def get_authorization_url(self, credentials: Dict[str, Any]) -> Dict[str, str]:
    """Build the provider authorization URL for the user delegation flow.

    Args:
        credentials: OAuth configuration with client_id, redirect_uri,
            authorization_url and optional scopes

    Returns:
        Dict containing authorization_url and state
    """
    client_id = credentials["client_id"]
    callback_uri = credentials["redirect_uri"]
    provider_endpoint = credentials["authorization_url"]
    requested_scopes = credentials.get("scopes", [])

    # The session object generates the state value returned alongside the URL.
    session = OAuth2Session(client_id, redirect_uri=callback_uri, scope=requested_scopes)
    url_and_state = session.authorization_url(provider_endpoint)
    auth_url, state = url_and_state

    logger.info(f"Generated authorization URL for client {client_id}")

    return dict(authorization_url=auth_url, state=state)

417 

async def exchange_code_for_token(self, credentials: Dict[str, Any], code: str, state: str) -> str:  # pylint: disable=unused-argument
    """Exchange authorization code for access token.

    The token endpoint is POSTed to up to ``max_retries`` times, sleeping
    ``2**attempt`` seconds between failed attempts. ``state`` is accepted for
    interface compatibility but is not inspected by this method.

    Args:
        credentials: OAuth configuration
        code: Authorization code from callback
        state: State parameter for CSRF validation (unused here)

    Returns:
        Access token string

    Raises:
        OAuthError: If the response carries no access_token, or if the token
            exchange fails after all retries
    """
    # Local import: only this block needs percent-decoding.
    from urllib.parse import unquote_plus  # pylint: disable=import-outside-toplevel

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "authorization_code_exchange")
    client_id = runtime_credentials["client_id"]
    client_secret = runtime_credentials.get("client_secret")  # Optional for public clients (PKCE-only)
    token_url = runtime_credentials["token_url"]
    redirect_uri = runtime_credentials["redirect_uri"]

    # Prepare token exchange data
    token_data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": client_id,
    }

    # Only include client_secret if present (public clients don't have secrets)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Exchange code for token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # GitHub returns form-encoded responses, not JSON
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                # Form-encoded values are percent-encoded on the wire, so
                # decode both key and value; pairs without "=" are ignored.
                token_response = {}
                for pair in response.text.split("&"):
                    key, separator, value = pair.partition("=")
                    if separator:
                        token_response[unquote_plus(key)] = unquote_plus(value)
            else:
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    # Keep the raw body so the error below shows what came back.
                    token_response = {"raw_response": response.text}

            # A JSON body that is not an object (list, string, ...) cannot hold a token.
            # NOTE(review): the message embeds the provider response - confirm that is acceptable for logs.
            if not isinstance(token_response, dict) or "access_token" not in token_response:
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("Successfully exchanged authorization code for access token")
            return token_response["access_token"]

        except httpx.HTTPError as e:
            logger.warning(f"Token exchange attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Only reachable when max_retries <= 0, i.e. the loop body never ran.
    raise OAuthError("Failed to exchange code for token after all retry attempts")

491 

async def initiate_authorization_code_flow(self, gateway_id: str, credentials: Dict[str, Any], app_user_email: str = None) -> Dict[str, str]:
    """Start the Authorization Code flow with PKCE.

    A PKCE pair and an opaque state token are generated. When a token storage
    service is configured, the state is persisted together with the PKCE
    verifier and the requesting user so the callback can be validated.

    Args:
        gateway_id: ID of the gateway being configured
        credentials: OAuth configuration with client_id, authorization_url, etc.
        app_user_email: ContextForge user email to associate with tokens

    Returns:
        Dict containing authorization_url, state and gateway_id
    """
    pkce = self._generate_pkce_params()  # RFC 7636 parameters
    state = self._generate_state(gateway_id, app_user_email)

    # State is persisted only when a token storage service is configured.
    if self.token_storage:
        await self._store_authorization_state(
            gateway_id,
            state,
            code_verifier=pkce["code_verifier"],
            app_user_email=app_user_email,
        )

    challenge = pkce["code_challenge"]
    challenge_method = pkce["code_challenge_method"]
    auth_url = self._create_authorization_url_with_pkce(credentials, state, challenge, challenge_method)

    logger.info(f"Generated authorization URL with PKCE for gateway {gateway_id}")

    return {"authorization_url": auth_url, "state": state, "gateway_id": gateway_id}

525 

async def complete_authorization_code_flow(self, gateway_id: str, code: str, state: str, credentials: Dict[str, Any]) -> Dict[str, Any]:
    """Complete Authorization Code flow with PKCE and store tokens.

    Args:
        gateway_id: ID of the gateway
        code: Authorization code from callback
        state: State parameter for CSRF validation
        credentials: OAuth configuration

    Returns:
        Dict containing success status, user_id, and expiration info

    Raises:
        OAuthError: If state validation fails, the legacy state names another
            gateway, user context is missing for storage, or token exchange fails
    """
    # Validate state and retrieve code_verifier
    state_data = await self._validate_and_retrieve_state(gateway_id, state)
    if not state_data:
        raise OAuthError("Invalid or expired state parameter - possible replay attack")

    code_verifier = state_data.get("code_verifier")
    app_user_email = state_data.get("app_user_email")

    # Backward compatibility for in-flight legacy states that embedded user context.
    if not app_user_email:
        legacy_state_payload = self._extract_legacy_state_payload(state)
        if legacy_state_payload:
            legacy_gateway_id = legacy_state_payload.get("gateway_id")
            if legacy_gateway_id and legacy_gateway_id != gateway_id:
                raise OAuthError("State parameter gateway mismatch")
            app_user_email = legacy_state_payload.get("app_user_email")

    # Exchange code for tokens with PKCE code_verifier
    token_response = await self._exchange_code_for_tokens(credentials, code, code_verifier=code_verifier)

    # Extract user information from token response
    user_id = self._extract_user_id(token_response, credentials)

    # Without a storage service there is nothing to persist.
    if not self.token_storage:
        return {"success": True, "user_id": user_id, "expires_at": None}

    if not app_user_email:
        raise OAuthError("User context required for OAuth token storage")

    # Normalise "scope": a missing or null value means no scopes, a string is
    # space-delimited, and a list/tuple is taken item by item. Calling
    # .split() directly would raise on null or list values.
    raw_scope = token_response.get("scope")
    if isinstance(raw_scope, str):
        granted_scopes = raw_scope.split()
    elif isinstance(raw_scope, (list, tuple)):
        granted_scopes = [str(item) for item in raw_scope]
    else:
        granted_scopes = []

    token_record = await self.token_storage.store_tokens(
        gateway_id=gateway_id,
        user_id=user_id,
        app_user_email=app_user_email,  # User from state
        access_token=token_response["access_token"],
        refresh_token=token_response.get("refresh_token"),
        expires_in=token_response.get("expires_in", self.settings.oauth_default_timeout),
        scopes=granted_scopes,
    )

    expires_at = token_record.expires_at.isoformat() if token_record.expires_at else None
    return {"success": True, "user_id": user_id, "expires_at": expires_at}

581 

async def get_access_token_for_user(self, gateway_id: str, app_user_email: str) -> Optional[str]:
    """Look up the stored access token of one user for one gateway.

    Args:
        gateway_id: ID of the gateway
        app_user_email: ContextForge user email

    Returns:
        Whatever the storage service returns for the user, or None when no
        storage service is configured
    """
    if not self.token_storage:
        return None
    return await self.token_storage.get_user_token(gateway_id, app_user_email)

595 

596 def _generate_state(self, _gateway_id: str, _app_user_email: str = None) -> str: 

597 """Generate an opaque state token for CSRF protection. 

598 

599 Args: 

600 _gateway_id: Gateway identifier (reserved for compatibility with 

601 prior embedded-state call sites). 

602 _app_user_email: ContextForge user email (reserved for 

603 compatibility with prior embedded-state call sites). 

604 

605 Returns: 

606 Opaque random state token 

607 """ 

608 return secrets.token_urlsafe(48) 

609 

610 @staticmethod 

611 def _extract_legacy_state_payload(state: str) -> Optional[Dict[str, Any]]: 

612 """Best-effort decode of legacy state payloads used before opaque states. 

613 

614 Legacy formats supported: 

615 - base64url(payload || signature) where payload is JSON 

616 - gateway_id_random suffix format 

617 

618 Args: 

619 state: Callback state token to decode. 

620 

621 Returns: 

622 Decoded legacy payload when format is recognized; otherwise ``None``. 

623 """ 

624 try: 

625 state_raw = base64.urlsafe_b64decode(state.encode()) 

626 if len(state_raw) <= 32: 

627 return None 

628 

629 payload_bytes = state_raw[:-32] 

630 payload = orjson.loads(payload_bytes) 

631 if isinstance(payload, dict): 

632 return payload 

633 except Exception: 

634 # Fall back to legacy gateway_id_random format 

635 if "_" in state: 

636 gateway_id = state.split("_", 1)[0] 

637 if gateway_id: 

638 return {"gateway_id": gateway_id} 

639 return None 

640 

async def resolve_gateway_id_from_state(self, state: str, allow_legacy_fallback: bool = True) -> Optional[str]:
    """Find the gateway a callback state belongs to, leaving the state in place.

    Lookup order: Redis (when cache_type is "redis"), the database (when
    cache_type is "database"), the in-memory map, and finally the legacy
    state formats when permitted.

    Args:
        state: OAuth callback state parameter
        allow_legacy_fallback: Whether to decode legacy callback state formats.

    Returns:
        Gateway ID when resolvable, otherwise ``None``.
    """
    backend = get_settings().cache_type

    if backend == "redis":
        redis = await _get_redis_client()
        if redis:
            try:
                stored = await redis.get(f"oauth:state_lookup:{state}")
                if stored:
                    # Redis may hand back bytes depending on client configuration.
                    return stored.decode("utf-8") if isinstance(stored, bytes) else stored
            except Exception as e:
                logger.warning(f"Failed to resolve state gateway in Redis: {e}")

    if backend == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            session_gen = get_db()
            session = next(session_gen)
            try:
                row = session.query(OAuthState).filter(OAuthState.state == state).first()
                if row:
                    return row.gateway_id
            finally:
                session_gen.close()
        except Exception as e:
            logger.warning(f"Failed to resolve state gateway in database: {e}")

    async with _state_lock:
        # Purge expired entries before consulting the reverse lookup.
        cutoff = datetime.now(timezone.utc)
        stale_keys = [name for name, entry in _oauth_states.items() if datetime.fromisoformat(entry["expires_at"]) < cutoff]
        for name in stale_keys:
            stale_state = _oauth_states.pop(name).get("state")
            if stale_state:
                _oauth_state_lookup.pop(stale_state, None)

        known_gateway = _oauth_state_lookup.get(state)
        if known_gateway:
            return known_gateway

    if not allow_legacy_fallback:
        return None

    legacy_payload = self._extract_legacy_state_payload(state)
    return legacy_payload.get("gateway_id") if legacy_payload else None

699 

async def _store_authorization_state(
    self,
    gateway_id: str,
    state: str,
    code_verifier: str = None,
    app_user_email: str = None,
) -> None:
    """Persist an authorization state so the callback can be validated later.

    The state expires after ``STATE_TTL_SECONDS``. Storage is attempted in
    Redis (cache_type "redis"), then the database (cache_type "database");
    if neither applies or the attempt fails, the in-memory map is used.

    Args:
        gateway_id: ID of the gateway
        state: State parameter to store
        code_verifier: Optional PKCE code verifier (RFC 7636)
        app_user_email: Requesting user email for token association
    """
    expires_at = datetime.now(timezone.utc) + timedelta(seconds=STATE_TTL_SECONDS)
    backend = get_settings().cache_type

    # The same record is used by the Redis and the in-memory backends.
    state_key = f"oauth:state:{gateway_id}:{state}"
    record = {
        "state": state,
        "gateway_id": gateway_id,
        "code_verifier": code_verifier,
        "app_user_email": app_user_email,
        "expires_at": expires_at.isoformat(),
        "used": False,
    }

    # Try Redis first for distributed storage
    if backend == "redis":
        redis = await _get_redis_client()
        if redis:
            try:
                # Both keys expire together after the TTL.
                await redis.setex(state_key, STATE_TTL_SECONDS, orjson.dumps(record))
                await redis.setex(f"oauth:state_lookup:{state}", STATE_TTL_SECONDS, gateway_id)
                logger.debug(f"Stored OAuth state in Redis for gateway {gateway_id}")
                return
            except Exception as e:
                logger.warning(f"Failed to store state in Redis: {e}, falling back")

    # Try database storage for multi-worker deployments
    if backend == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            session_gen = get_db()
            session = next(session_gen)
            try:
                # Clean up expired states first
                session.query(OAuthState).filter(OAuthState.expires_at < datetime.now(timezone.utc)).delete()

                row_fields = dict(
                    gateway_id=gateway_id,
                    state=state,
                    code_verifier=code_verifier,
                    expires_at=expires_at,
                    used=False,
                )
                # Set the user column only when the model defines it.
                if hasattr(OAuthState, "app_user_email"):
                    row_fields["app_user_email"] = app_user_email

                session.add(OAuthState(**row_fields))
                session.commit()
                logger.debug(f"Stored OAuth state in database for gateway {gateway_id}")
                return
            finally:
                session_gen.close()
        except Exception as e:
            logger.warning(f"Failed to store state in database: {e}, falling back to memory")

    # Fallback to in-memory storage for development
    async with _state_lock:
        now = datetime.now(timezone.utc)

        # Drop expired entries (and their reverse-lookup rows) first.
        stale_keys = [name for name, entry in _oauth_states.items() if datetime.fromisoformat(entry["expires_at"]) < now]
        for key in stale_keys:
            stale_state = _oauth_states.pop(key).get("state")
            if stale_state:
                _oauth_state_lookup.pop(stale_state, None)
            logger.debug(f"Cleaned up expired state: {key[:20]}...")

        _oauth_states[state_key] = record
        _oauth_state_lookup[state] = gateway_id
        logger.debug(f"Stored OAuth state in memory for gateway {gateway_id}")

799 

async def _validate_authorization_state(self, gateway_id: str, state: str) -> bool:
    """Check an OAuth ``state`` value for a gateway and consume it.

    Redis is consulted when ``cache_type`` is ``"redis"`` and a client is
    available, the database when ``cache_type`` is ``"database"``. If the
    selected backend raises (or none applies), the process-local dictionary
    is used instead. A state that validates is removed from its store so it
    cannot be accepted twice.

    Args:
        gateway_id: ID of the gateway the state was issued for
        state: State parameter received on the callback

    Returns:
        True when the state exists, has not expired and is not flagged as
        used; False otherwise
    """
    settings = get_settings()

    # Redis backend (distributed deployments)
    if settings.cache_type == "redis":
        redis = await _get_redis_client()
        if redis:
            try:
                redis_state_key = f"oauth:state:{gateway_id}:{state}"
                redis_lookup_key = f"oauth:state_lookup:{state}"
                # GETDEL reads and removes the entry in one command, which
                # is what makes the state single-use.
                raw_state = await redis.getdel(redis_state_key)
                await redis.delete(redis_lookup_key)
                if not raw_state:
                    logger.warning(f"State not found in Redis for gateway {gateway_id}")
                    return False

                payload = orjson.loads(raw_state)

                try:
                    expiry = datetime.fromisoformat(payload["expires_at"])
                except Exception:
                    # Second chance: plain timestamp without fraction/offset
                    expiry = datetime.strptime(payload["expires_at"], "%Y-%m-%dT%H:%M:%S")

                # Naive timestamps are interpreted as UTC
                if expiry.tzinfo is None:
                    expiry = expiry.replace(tzinfo=timezone.utc)

                is_expired = expiry < datetime.now(timezone.utc)
                if is_expired:
                    logger.warning(f"State has expired for gateway {gateway_id}")
                    return False

                # Defensive: GETDEL already removed the entry above
                if payload.get("used", False):
                    logger.warning(f"State was already used for gateway {gateway_id} - possible replay attack")
                    return False

                logger.debug(f"Successfully validated OAuth state from Redis for gateway {gateway_id}")
                return True
            except Exception as exc:
                logger.warning(f"Failed to validate state in Redis: {exc}, falling back")

    # Database backend (multi-worker deployments)
    if settings.cache_type == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            session_gen = get_db()
            session = next(session_gen)
            try:
                record = session.query(OAuthState).filter(OAuthState.gateway_id == gateway_id, OAuthState.state == state).first()

                if not record:
                    logger.warning(f"State not found in database for gateway {gateway_id}")
                    return False

                # Naive timestamps are interpreted as UTC
                expiry = record.expires_at
                if expiry.tzinfo is None:
                    expiry = expiry.replace(tzinfo=timezone.utc)

                is_expired = expiry < datetime.now(timezone.utc)
                if is_expired:
                    logger.warning(f"State has expired for gateway {gateway_id}")
                    session.delete(record)
                    session.commit()
                    return False

                if record.used:
                    logger.warning(f"State has already been used for gateway {gateway_id} - possible replay attack")
                    return False

                # Valid: drop the row so the state cannot be replayed
                session.delete(record)
                session.commit()
                logger.debug(f"Successfully validated OAuth state from database for gateway {gateway_id}")
                return True
            finally:
                session_gen.close()
        except Exception as exc:
            logger.warning(f"Failed to validate state in database: {exc}, falling back to memory")

    # Process-local fallback
    memory_key = f"oauth:state:{gateway_id}:{state}"
    async with _state_lock:
        stored = _oauth_states.get(memory_key)

        if not stored:
            logger.warning(f"State not found in memory for gateway {gateway_id}")
            return False

        # Naive timestamps are interpreted as UTC
        expiry = datetime.fromisoformat(stored["expires_at"])
        if expiry.tzinfo is None:
            expiry = expiry.replace(tzinfo=timezone.utc)

        is_expired = expiry < datetime.now(timezone.utc)
        if is_expired:
            logger.warning(f"State has expired for gateway {gateway_id}")
            _oauth_states.pop(memory_key)  # expired entries are purged on sight
            _oauth_state_lookup.pop(state, None)
            return False

        if stored.get("used", False):
            logger.warning(f"State has already been used for gateway {gateway_id} - possible replay attack")
            return False

        # Valid: remove the entry and its reverse lookup (single-use)
        _oauth_states.pop(memory_key)
        _oauth_state_lookup.pop(state, None)
        logger.debug(f"Successfully validated OAuth state from memory for gateway {gateway_id}")
        return True

929 

async def _validate_and_retrieve_state(self, gateway_id: str, state: str) -> Optional[Dict[str, Any]]:
    """Consume an OAuth state and return everything that was stored with it.

    Backends are tried in the same order as the boolean validator: Redis
    when ``cache_type`` is ``"redis"`` and a client is available, the
    database when ``cache_type`` is ``"database"``, then the process-local
    dictionary. A backend that raises is logged and skipped.

    Args:
        gateway_id: ID of the gateway the state was issued for
        state: State parameter received on the callback

    Returns:
        The stored state data (including ``code_verifier``), or None when
        the state is missing, expired or already used
    """
    settings = get_settings()

    # Redis backend
    if settings.cache_type == "redis":
        redis = await _get_redis_client()
        if redis:
            try:
                redis_state_key = f"oauth:state:{gateway_id}:{state}"
                redis_lookup_key = f"oauth:state_lookup:{state}"
                # GETDEL reads and removes the entry in one command
                raw_state = await redis.getdel(redis_state_key)
                await redis.delete(redis_lookup_key)
                if not raw_state:
                    return None

                payload = orjson.loads(raw_state)

                try:
                    expiry = datetime.fromisoformat(payload["expires_at"])
                except Exception:
                    expiry = datetime.strptime(payload["expires_at"], "%Y-%m-%dT%H:%M:%S")

                # Naive timestamps are interpreted as UTC
                if expiry.tzinfo is None:
                    expiry = expiry.replace(tzinfo=timezone.utc)

                is_expired = expiry < datetime.now(timezone.utc)
                if is_expired:
                    return None

                return payload
            except Exception as exc:
                logger.warning(f"Failed to validate state in Redis: {exc}, falling back")

    # Database backend
    if settings.cache_type == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            session_gen = get_db()
            session = next(session_gen)
            try:
                record = session.query(OAuthState).filter(OAuthState.gateway_id == gateway_id, OAuthState.state == state).first()

                if not record:
                    return None

                # Naive timestamps are interpreted as UTC
                expiry = record.expires_at
                if expiry.tzinfo is None:
                    expiry = expiry.replace(tzinfo=timezone.utc)

                is_expired = expiry < datetime.now(timezone.utc)
                if is_expired:
                    session.delete(record)
                    session.commit()
                    return None

                if record.used:
                    return None

                # Snapshot the row before it is deleted
                result = {
                    "state": record.state,
                    "gateway_id": record.gateway_id,
                    "code_verifier": record.code_verifier,
                    "expires_at": record.expires_at.isoformat(),
                }
                if hasattr(record, "app_user_email"):
                    result["app_user_email"] = getattr(record, "app_user_email", None)

                # Single-use: remove the row once its data has been captured
                session.delete(record)
                session.commit()

                return result
            finally:
                session_gen.close()
        except Exception as exc:
            logger.warning(f"Failed to validate state in database: {exc}")

    # Process-local fallback
    memory_key = f"oauth:state:{gateway_id}:{state}"
    async with _state_lock:
        stored = _oauth_states.get(memory_key)
        if not stored:
            return None

        # Naive timestamps are interpreted as UTC
        expiry = datetime.fromisoformat(stored["expires_at"])
        if expiry.tzinfo is None:
            expiry = expiry.replace(tzinfo=timezone.utc)

        is_expired = expiry < datetime.now(timezone.utc)

        # Expired or valid, the entry and its reverse lookup leave the store
        _oauth_states.pop(memory_key)
        _oauth_state_lookup.pop(state, None)
        return None if is_expired else stored

1041 

def _create_authorization_url(self, credentials: Dict[str, Any], state: str) -> tuple[str, str]:
    """Build the provider authorization URL carrying the given state.

    Args:
        credentials: OAuth configuration; ``client_id``, ``redirect_uri`` and
            ``authorization_url`` are required, ``scopes`` is optional
        state: State parameter for CSRF protection

    Returns:
        Tuple of (authorization_url, state)
    """
    # Required settings are read in a fixed order; a missing one raises KeyError
    client_id, redirect_uri, authorize_endpoint = (credentials[key] for key in ("client_id", "redirect_uri", "authorization_url"))
    requested_scopes = credentials.get("scopes", [])

    session = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=requested_scopes)

    # The session appends the OAuth query parameters, including our state
    url, issued_state = session.authorization_url(authorize_endpoint, state=state)
    return url, issued_state

1064 

1065 @staticmethod 

1066 def _is_microsoft_entra_v2_endpoint(endpoint_url: Any) -> bool: 

1067 """Return True when endpoint matches Microsoft Entra v2 login endpoints. 

1068 

1069 Args: 

1070 endpoint_url: OAuth endpoint URL to check 

1071 

1072 Returns: 

1073 True if the endpoint is a Microsoft Entra v2 OAuth endpoint 

1074 """ 

1075 if not isinstance(endpoint_url, str) or not endpoint_url: 

1076 return False 

1077 

1078 parsed = urlparse(endpoint_url) 

1079 host = (parsed.hostname or "").lower() 

1080 path = parsed.path.lower() 

1081 

1082 return host in OAuthManager._ENTRA_HOSTS and "/oauth2/v2.0/" in path 

1083 

1084 @staticmethod 

1085 def _is_enabled_flag(value: Any) -> bool: 

1086 """Parse boolean-like config values from oauth_config. 

1087 

1088 Args: 

1089 value: Config value to interpret as boolean 

1090 

1091 Returns: 

1092 True if value represents an enabled/truthy setting 

1093 """ 

1094 if isinstance(value, bool): 

1095 return value 

1096 if isinstance(value, str): 

1097 return value.strip().lower() in {"1", "true", "yes", "on"} 

1098 return False 

1099 

1100 def _should_include_resource_parameter(self, credentials: Dict[str, Any], scopes: Any) -> bool: 

1101 """Determine whether RFC 8707 resource should be sent for this request. 

1102 

1103 Args: 

1104 credentials: OAuth configuration containing resource and endpoint URLs 

1105 scopes: OAuth scopes for the request 

1106 

1107 Returns: 

1108 True if the resource parameter should be included in the request 

1109 """ 

1110 if not credentials.get("resource"): 

1111 return False 

1112 

1113 if self._is_enabled_flag(credentials.get("omit_resource")): 

1114 return False 

1115 

1116 # Microsoft Entra v2 does not accept legacy resource with v2 scope-based requests. 

1117 if scopes and (self._is_microsoft_entra_v2_endpoint(credentials.get("authorization_url")) or self._is_microsoft_entra_v2_endpoint(credentials.get("token_url"))): 

1118 logger.info("Omitting OAuth resource parameter for Microsoft Entra v2 scope-based flow") 

1119 return False 

1120 

1121 return True 

1122 

1123 def _create_authorization_url_with_pkce(self, credentials: Dict[str, Any], state: str, code_challenge: str, code_challenge_method: str) -> str: 

1124 """Create authorization URL with PKCE parameters (RFC 7636). 

1125 

1126 Args: 

1127 credentials: OAuth configuration 

1128 state: State parameter for CSRF protection 

1129 code_challenge: PKCE code challenge 

1130 code_challenge_method: PKCE method (S256) 

1131 

1132 Returns: 

1133 Authorization URL string with PKCE parameters 

1134 """ 

1135 # Standard 

1136 from urllib.parse import urlencode # pylint: disable=import-outside-toplevel 

1137 

1138 client_id = credentials["client_id"] 

1139 redirect_uri = credentials["redirect_uri"] 

1140 authorization_url = credentials["authorization_url"] 

1141 scopes = credentials.get("scopes", []) 

1142 

1143 # Build authorization parameters 

1144 params = {"response_type": "code", "client_id": client_id, "redirect_uri": redirect_uri, "state": state, "code_challenge": code_challenge, "code_challenge_method": code_challenge_method} 

1145 

1146 # Add scopes if present 

1147 if scopes: 

1148 params["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes 

1149 

1150 # Add resource parameter for JWT access token (RFC 8707) 

1151 # The resource is the MCP server URL, set by oauth_router.py 

1152 resource = credentials.get("resource") 

1153 if self._should_include_resource_parameter(credentials, scopes): 

1154 params["resource"] = resource # urlencode with doseq=True handles lists 

1155 

1156 # Build full URL (doseq=True handles list values like multiple resource params) 

1157 query_string = urlencode(params, doseq=True) 

1158 return f"{authorization_url}?{query_string}" 

1159 

async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]:
    """Exchange authorization code for tokens with PKCE support.

    The request is POSTed to the token endpoint up to ``self.max_retries``
    times, sleeping ``2**attempt`` seconds between failed attempts. Both
    JSON and form-encoded token responses are understood.

    Args:
        credentials: OAuth configuration
        code: Authorization code from callback
        code_verifier: Optional PKCE code verifier (RFC 7636)

    Returns:
        Token response dictionary

    Raises:
        OAuthError: If token exchange fails
    """
    # Standard
    from urllib.parse import parse_qsl  # pylint: disable=import-outside-toplevel

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "authorization_code_exchange_with_pkce")
    client_id = runtime_credentials["client_id"]
    client_secret = runtime_credentials.get("client_secret")  # Optional for public clients (PKCE-only)
    token_url = runtime_credentials["token_url"]
    redirect_uri = runtime_credentials["redirect_uri"]

    # Prepare token exchange data
    token_data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": client_id,
    }

    # Only include client_secret if present (public clients don't have secrets)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Add PKCE code_verifier if present (RFC 7636)
    if code_verifier:
        token_data["code_verifier"] = code_verifier

    # Add resource parameter to request JWT access token (RFC 8707)
    # The resource identifies the MCP server (resource server), not the OAuth server
    # NOTE(review): the policy check receives the caller's ``credentials`` while
    # the value itself comes from ``runtime_credentials`` - confirm both always
    # carry the same resource/endpoint settings.
    resource = runtime_credentials.get("resource")
    scopes = runtime_credentials.get("scopes", [])
    if self._should_include_resource_parameter(credentials, scopes):
        if isinstance(resource, list):
            # RFC 8707 allows multiple resource parameters - use list of tuples
            form_data: list[tuple[str, str]] = list(token_data.items())
            form_data.extend(("resource", r) for r in resource if r)
            token_data = form_data  # type: ignore[assignment]
        else:
            token_data["resource"] = resource

    # Exchange code for token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # GitHub returns form-encoded responses, not JSON
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                # parse_qsl percent-decodes keys and values, which a plain
                # split on "&"/"=" would leave encoded (e.g. "repo%2Cgist").
                token_response = dict(parse_qsl(response.text, keep_blank_values=True))
            else:
                # Try JSON response
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    # Keep the raw body so the failure below is diagnosable
                    token_response = {"raw_response": response.text}

            if "access_token" not in token_response:
                # NOTE(review): this message embeds the provider response, which may
                # contain sensitive fields - consider redacting before surfacing it.
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("Successfully exchanged authorization code for tokens")
            return token_response

        except httpx.HTTPError as e:
            logger.warning(f"Token exchange attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                # Chain the transport error so the root cause stays in the traceback
                raise OAuthError(f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Only reachable when max_retries is 0 (the loop never runs); also keeps
    # the return type total for type checkers
    raise OAuthError("Failed to exchange code for token after all retry attempts")

1252 

async def refresh_token(self, refresh_token: str, credentials: Dict[str, Any]) -> Dict[str, Any]:
    """Refresh an expired access token using a refresh token.

    The request is POSTed to the token endpoint up to ``self.max_retries``
    times. A 400/401 response aborts immediately (the refresh token is
    treated as invalid); any other failure - transport error or non-200
    status - is retried after a ``2**attempt`` second pause.

    Args:
        refresh_token: The refresh token to use
        credentials: OAuth configuration including client_id, client_secret, token_url

    Returns:
        Dict containing new access_token, optional refresh_token, and expires_in

    Raises:
        OAuthError: If token refresh fails
    """
    if not refresh_token:
        raise OAuthError("No refresh token available")

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "refresh_token")
    token_url = runtime_credentials.get("token_url")
    if not token_url:
        raise OAuthError("No token URL configured for OAuth provider")

    client_id = runtime_credentials.get("client_id")
    client_secret = runtime_credentials.get("client_secret")

    if not client_id:
        raise OAuthError("No client_id configured for OAuth provider")

    # Prepare token refresh request
    token_data = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
    }

    # Add client_secret if available (some providers require it)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Add resource parameter for JWT access token (RFC 8707)
    # Must be included in refresh requests to maintain JWT token type
    resource = runtime_credentials.get("resource")
    scopes = runtime_credentials.get("scopes", [])
    if self._should_include_resource_parameter(credentials, scopes):
        if isinstance(resource, list):
            # RFC 8707 allows multiple resource parameters - use list of tuples
            form_data: list[tuple[str, str]] = list(token_data.items())
            form_data.extend(("resource", r) for r in resource if r)
            token_data = form_data  # type: ignore[assignment]
        else:
            token_data["resource"] = resource

    # Attempt token refresh with retries
    for attempt in range(self.max_retries):
        is_last_attempt = attempt == self.max_retries - 1
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            if response.status_code == 200:
                token_response = response.json()

                # Validate required fields
                if "access_token" not in token_response:
                    raise OAuthError("No access_token in refresh response")

                logger.info("Successfully refreshed OAuth token")
                return token_response

            error_text = response.text
            # If we get a 400/401, the refresh token is likely invalid
            if response.status_code in [400, 401]:
                raise OAuthError(f"Refresh token invalid or expired: {error_text}")
            logger.warning(f"Token refresh failed with status {response.status_code}: {error_text}")

        except httpx.HTTPError as e:
            logger.warning(f"Token refresh attempt {attempt + 1} failed: {str(e)}")
            if is_last_attempt:
                # Chain the transport error so the root cause stays in the traceback
                raise OAuthError(f"Failed to refresh token after {self.max_retries} attempts: {str(e)}") from e

        # Back off before the next attempt for every retryable failure, not
        # only transport errors: a provider answering 5xx should not be
        # hit again immediately.
        if not is_last_attempt:
            await asyncio.sleep(2**attempt)  # Exponential backoff

    raise OAuthError("Failed to refresh token after all retry attempts")

1334 

1335 def _extract_user_id(self, token_response: Dict[str, Any], credentials: Dict[str, Any]) -> str: 

1336 """Extract user ID from token response. 

1337 

1338 Args: 

1339 token_response: Response from token exchange 

1340 credentials: OAuth configuration 

1341 

1342 Returns: 

1343 User ID string 

1344 """ 

1345 # Try to extract user ID from various common fields in token response 

1346 # Different OAuth providers use different field names 

1347 

1348 # Check for 'sub' (subject) - JWT standard 

1349 if "sub" in token_response: 

1350 return token_response["sub"] 

1351 

1352 # Check for 'user_id' - common in some OAuth responses 

1353 if "user_id" in token_response: 

1354 return token_response["user_id"] 

1355 

1356 # Check for 'id' - also common 

1357 if "id" in token_response: 

1358 return token_response["id"] 

1359 

1360 # Fallback to client_id if no user info is available 

1361 if credentials.get("client_id"): 

1362 return credentials["client_id"] 

1363 

1364 # Final fallback 

1365 return "unknown_user" 

1366 

1367 

class OAuthError(Exception):
    """Base exception for failures raised by the OAuth manager.

    It adds nothing to ``Exception`` beyond a distinct type that callers can
    catch; the message passed in is what ``str()`` returns.

    Examples:
        >>> try:
        ...     raise OAuthError("Token acquisition failed")
        ... except OAuthError as e:
        ...     str(e)
        'Token acquisition failed'
        >>> try:
        ...     raise OAuthError("Invalid grant type")
        ... except Exception as e:
        ...     isinstance(e, OAuthError)
        True
    """

1383 

1384 

class OAuthRequiredError(OAuthError):
    """Signal that a server demands OAuth and the caller supplied none.

    The ``server_id`` attribute records which server rejected the request,
    so the middleware can name it when building the ``WWW-Authenticate``
    header.

    Examples:
        >>> err = OAuthRequiredError("auth required", server_id="s1")
        >>> err.server_id
        's1'
        >>> isinstance(err, OAuthError)
        True
    """

    def __init__(self, message: str, *, server_id: str = "") -> None:
        """Store the offending server alongside the error message.

        Args:
            message: Human-readable error description.
            server_id: Virtual-server identifier that triggered the rejection.
        """
        self.server_id = server_id
        super().__init__(message)

1409 

1410 

class OAuthEnforcementUnavailableError(OAuthError):
    """Signal that OAuth enforcement could not run because infrastructure failed.

    Raised when the database (or another backing service) needed to read a
    server's ``oauth_enabled`` flag cannot be reached. The middleware turns
    it into an HTTP 503 so that an outage never results in unauthenticated
    access being waved through (fail-closed).

    Examples:
        >>> err = OAuthEnforcementUnavailableError("DB down", server_id="s1")
        >>> err.server_id
        's1'
        >>> isinstance(err, OAuthError)
        True
    """

    def __init__(self, message: str, *, server_id: str = "") -> None:
        """Store the affected server alongside the error message.

        Args:
            message: Human-readable error description.
            server_id: Virtual-server identifier that triggered the rejection.
        """
        self.server_id = server_id
        super().__init__(message)