Coverage for mcpgateway / services / oauth_manager.py: 99%

651 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/oauth_manager.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7OAuth 2.0 Manager for ContextForge. 

8 

9This module handles OAuth 2.0 authentication flows including: 

10- Client Credentials (Machine-to-Machine) 

11- Authorization Code (User Delegation) 

12""" 

13 

14# Standard 

15import asyncio 

16import base64 

17from datetime import datetime, timedelta, timezone 

18import hashlib 

19import logging 

20import secrets 

21from typing import Any, Dict, Optional 

22from urllib.parse import urlparse 

23 

24# Third-Party 

25import httpx 

26import orjson 

27from requests_oauthlib import OAuth2Session 

28 

29# First-Party 

30from mcpgateway.common.validators import SecurityValidator 

31from mcpgateway.config import get_settings 

32from mcpgateway.services.encryption_service import decrypt_oauth_config_for_runtime, get_encryption_service 

33from mcpgateway.services.http_client_service import get_http_client 

34from mcpgateway.utils.redis_client import get_redis_client as _get_shared_redis_client 

35 

36logger = logging.getLogger(__name__) 

37 

# In-memory storage for OAuth states with expiration (fallback for single-process)
# Format: {state_key: {"state": state, "gateway_id": gateway_id, "code_verifier": ...,
#                      "app_user_email": ..., "expires_at": ISO-8601 string, "used": bool}}
# where state_key is "oauth:state:{gateway_id}:{state}". Note that "expires_at" is
# stored as an isoformat() string and parsed back with datetime.fromisoformat().
_oauth_states: Dict[str, Dict[str, Any]] = {}
# Reverse lookup for callback handlers that only receive state.
# Format: {state: gateway_id}
_oauth_state_lookup: Dict[str, str] = {}
# Lock serialising state operations between coroutines. asyncio.Lock is NOT
# thread-safe; it only guards concurrent tasks on the same event loop.
_state_lock = asyncio.Lock()

# State TTL in seconds (5 minutes)
STATE_TTL_SECONDS = 300

# Redis client for distributed state storage (uses shared factory)
_redis_client: Optional[Any] = None
# Set to True by _get_redis_client() after its first lookup attempt, whether or
# not a client was obtained, so the lookup is not repeated.
_REDIS_INITIALIZED = False

53 

54 

async def _get_redis_client():
    """Return the shared Redis client used for distributed OAuth state.

    The lookup is attempted only once per process. Its outcome (a client, or
    ``None`` when Redis is not configured or the factory failed) is cached in
    module globals and handed back on every later call.

    Returns:
        Redis client instance or None if unavailable
    """
    global _redis_client, _REDIS_INITIALIZED  # pylint: disable=global-statement

    if not _REDIS_INITIALIZED:
        config = get_settings()
        redis_configured = config.cache_type == "redis" and config.redis_url

        if not redis_configured:
            _redis_client = None
        else:
            try:
                _redis_client = await _get_shared_redis_client()
                if _redis_client:
                    logger.info("OAuth manager connected to shared Redis client")
            except Exception as e:
                # Best effort: callers fall back to other storage when None is returned.
                logger.warning(f"Failed to get Redis client, falling back to in-memory storage: {e}")
                _redis_client = None

        # Remember that the attempt happened, even if it failed.
        _REDIS_INITIALIZED = True

    return _redis_client

82 

83 

84class OAuthManager: 

85 """Manages OAuth 2.0 authentication flows. 

86 

87 Examples: 

88 >>> manager = OAuthManager(request_timeout=30, max_retries=3) 

89 >>> manager.request_timeout 

90 30 

91 >>> manager.max_retries 

92 3 

93 >>> manager.token_storage is None 

94 True 

95 >>> 

96 >>> # Test grant type validation 

97 >>> grant_type = "client_credentials" 

98 >>> grant_type in ["client_credentials", "authorization_code"] 

99 True 

100 >>> grant_type = "invalid_grant" 

101 >>> grant_type in ["client_credentials", "authorization_code"] 

102 False 

103 >>> 

104 >>> # Test encrypted secret detection heuristic 

105 >>> short_secret = "secret123" 

106 >>> len(short_secret) > 50 

107 False 

108 >>> encrypted_secret = "gAAAAABh" + "x" * 60 # Simulated encrypted secret 

109 >>> len(encrypted_secret) > 50 

110 True 

111 >>> 

112 >>> # Test scope list handling 

113 >>> scopes = ["read", "write"] 

114 >>> " ".join(scopes) 

115 'read write' 

116 >>> empty_scopes = [] 

117 >>> " ".join(empty_scopes) 

118 '' 

119 """ 

120 

121 # Known Microsoft Entra login hosts (global + sovereign clouds). 

122 _ENTRA_HOSTS: frozenset[str] = frozenset( 

123 { 

124 "login.microsoftonline.com", 

125 "login.microsoftonline.us", 

126 "login.microsoftonline.de", 

127 "login.partner.microsoftonline.cn", 

128 } 

129 ) 

130 

def __init__(self, request_timeout: int = 30, max_retries: int = 3, token_storage: Optional[Any] = None):
    """Create an OAuth manager.

    Args:
        request_timeout: Timeout for OAuth requests in seconds
        max_retries: Maximum number of retry attempts for token requests
        token_storage: Optional TokenStorageService for storing tokens
    """
    # Application settings are captured once per manager instance.
    self.settings = get_settings()
    self.token_storage = token_storage
    self.max_retries = max_retries
    self.request_timeout = request_timeout

143 

async def _get_client(self) -> httpx.AsyncClient:
    """Fetch the HTTP client used for token endpoint calls.

    Returns:
        The httpx.AsyncClient handed out by ``get_http_client()``
    """
    shared_client = await get_http_client()
    return shared_client

151 

152 def _generate_pkce_params(self) -> Dict[str, str]: 

153 """Generate PKCE parameters for OAuth Authorization Code flow (RFC 7636). 

154 

155 Returns: 

156 Dict containing code_verifier, code_challenge, and code_challenge_method 

157 """ 

158 # Generate code_verifier: 43-128 character random string 

159 code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=") 

160 

161 # Generate code_challenge: base64url(SHA256(code_verifier)) 

162 code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest()).decode("utf-8").rstrip("=") 

163 

164 return {"code_verifier": code_verifier, "code_challenge": code_challenge, "code_challenge_method": "S256"} 

165 

async def get_access_token(self, credentials: Dict[str, Any]) -> str:
    """Obtain an access token using the flow named by ``grant_type``.

    ``client_credentials`` and ``password`` are delegated to their flow
    methods. ``authorization_code`` is rejected here because it needs
    interactive user consent.

    Args:
        credentials: OAuth configuration containing grant_type and other params

    Returns:
        Access token string

    Raises:
        ValueError: If grant type is unsupported
        OAuthError: If token acquisition fails

    Examples:
        Client credentials flow:
        >>> import asyncio
        >>> class TestMgr(OAuthManager):
        ...     async def _client_credentials_flow(self, credentials):
        ...         return 'tok'
        >>> mgr = TestMgr()
        >>> asyncio.run(mgr.get_access_token({'grant_type': 'client_credentials'}))
        'tok'

        Authorization code flow requires interactive completion:
        >>> def _auth_code_requires_consent():
        ...     try:
        ...         asyncio.run(mgr.get_access_token({'grant_type': 'authorization_code'}))
        ...     except OAuthError:
        ...         return True
        >>> _auth_code_requires_consent()
        True

        Unsupported grant type raises ValueError:
        >>> def _unsupported():
        ...     try:
        ...         asyncio.run(mgr.get_access_token({'grant_type': 'bad'}))
        ...     except ValueError:
        ...         return True
        >>> _unsupported()
        True
    """
    grant_type = credentials.get("grant_type")
    logger.debug(f"Getting access token for grant type: {grant_type}")

    if grant_type == "client_credentials":
        selected_flow = self._client_credentials_flow
    elif grant_type == "password":
        selected_flow = self._password_flow
    elif grant_type == "authorization_code":
        raise OAuthError("Authorization code flow requires user consent via /oauth/authorize and does not support client_credentials fallback")
    else:
        raise ValueError(f"Unsupported grant type: {grant_type}")

    return await selected_flow(credentials)

217 

@staticmethod
async def _prepare_runtime_credentials(credentials: Dict[str, Any], flow_name: str) -> Dict[str, Any]:
    """Decrypt a stored oauth_config so it can be used for a live request.

    This is best effort: if decryption raises, or does not yield a dict, the
    credentials are returned exactly as they were passed in.

    Args:
        credentials: Stored oauth_config payload.
        flow_name: Flow label for diagnostic logging.

    Returns:
        Dict[str, Any]: Runtime-ready credentials.
    """
    try:
        encryption = get_encryption_service(get_settings().auth_encryption_secret)
        decrypted = await decrypt_oauth_config_for_runtime(credentials, encryption=encryption)
    except Exception as exc:
        logger.warning("Failed to prepare runtime OAuth credentials for %s flow: %s", flow_name, exc)
        return credentials

    return decrypted if isinstance(decrypted, dict) else credentials

238 

@staticmethod
def _decode_token_endpoint_response(response: Any) -> Dict[str, Any]:
    """Turn a token endpoint HTTP response into a dict of token fields.

    Form-encoded bodies (sent by some providers, e.g. GitHub) are split into
    key/value pairs and percent-decoded. Anything else is parsed as JSON. A
    body that is not valid JSON, or is JSON but not an object, is returned as
    ``{"raw_response": <text>}`` so the caller's ``access_token`` check fails
    with a readable error.

    Args:
        response: HTTP response exposing ``headers``, ``text`` and ``json()``.

    Returns:
        Dict of decoded response fields.
    """
    # Local import: the module only imports ``urlparse`` from urllib.parse.
    from urllib.parse import unquote_plus  # pylint: disable=import-outside-toplevel

    content_type = response.headers.get("content-type", "")
    if "application/x-www-form-urlencoded" in content_type:
        fields: Dict[str, Any] = {}
        for pair in response.text.split("&"):
            key, separator, value = pair.partition("=")
            if separator:
                # Values are percent-encoded on the wire ("%2C", "+", ...).
                fields[unquote_plus(key)] = unquote_plus(value)
        return fields

    try:
        parsed = response.json()
    except Exception as e:
        logger.warning(f"Failed to parse JSON response: {e}")
        return {"raw_response": response.text}

    if not isinstance(parsed, dict):
        # A JSON list/string/number cannot carry token fields.
        return {"raw_response": response.text}
    return parsed


async def _fetch_access_token_with_retries(
    self,
    token_url: str,
    token_data: Dict[str, Any],
    *,
    attempt_label: str,
    failure_prefix: str,
    success_message: str,
) -> str:
    """POST ``token_data`` to ``token_url`` and return the access token.

    HTTP errors are retried up to ``self.max_retries`` times with exponential
    backoff (1s, 2s, 4s, ...). A response without ``access_token`` is not
    retried.

    Args:
        token_url: Token endpoint URL.
        token_data: Form fields to send.
        attempt_label: Prefix for the per-attempt warning log message.
        failure_prefix: Prefix for the OAuthError raised when retries run out.
        success_message: Info log message emitted on success.

    Returns:
        Access token string

    Raises:
        OAuthError: If the response lacks an access_token or all attempts fail
    """
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            token_response = self._decode_token_endpoint_response(response)
            if "access_token" not in token_response:
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info(success_message)
            return token_response["access_token"]

        except httpx.HTTPError as e:
            logger.warning(f"{attempt_label} attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                raise OAuthError(f"{failure_prefix} after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Reached only when max_retries is zero or negative (the loop never runs).
    raise OAuthError(f"{failure_prefix} after all retry attempts")


async def _client_credentials_flow(self, credentials: Dict[str, Any]) -> str:
    """Machine-to-machine authentication using client credentials.

    Args:
        credentials: OAuth configuration with client_id, client_secret, token_url

    Returns:
        Access token string

    Raises:
        OAuthError: If token acquisition fails after all retries
    """
    runtime_credentials = await self._prepare_runtime_credentials(credentials, "client_credentials")
    client_id = runtime_credentials["client_id"]
    client_secret = runtime_credentials["client_secret"]
    token_url = runtime_credentials["token_url"]
    scopes = runtime_credentials.get("scopes", [])

    token_data = {
        "grant_type": "client_credentials",
        "client_id": client_id,
        "client_secret": client_secret,
    }

    if scopes:
        # Scopes may be configured as a list or as a ready-made string.
        token_data["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes

    return await self._fetch_access_token_with_retries(
        token_url,
        token_data,
        attempt_label="Token request",
        failure_prefix="Failed to obtain access token",
        success_message="Successfully obtained access token via client credentials",
    )


async def _password_flow(self, credentials: Dict[str, Any]) -> str:
    """Resource Owner Password Credentials flow (RFC 6749 Section 4.3).

    This flow is used when the application can directly handle the user's credentials,
    such as with trusted first-party applications or legacy integrations like Keycloak.

    Args:
        credentials: OAuth configuration with client_id, optional client_secret, token_url, username, password

    Returns:
        Access token string

    Raises:
        OAuthError: If username/password are missing or token acquisition fails after all retries
    """
    runtime_credentials = await self._prepare_runtime_credentials(credentials, "password")
    client_id = runtime_credentials.get("client_id")
    client_secret = runtime_credentials.get("client_secret")
    token_url = runtime_credentials["token_url"]
    username = runtime_credentials.get("username")
    password = runtime_credentials.get("password")
    scopes = runtime_credentials.get("scopes", [])

    if not username or not password:
        raise OAuthError("Username and password are required for password grant type")

    token_data = {
        "grant_type": "password",
        "username": username,
        "password": password,
    }

    # Add client_id (required by most providers including Keycloak)
    if client_id:
        token_data["client_id"] = client_id

    # Add client_secret if present (some providers require it, others don't)
    if client_secret:
        token_data["client_secret"] = client_secret

    if scopes:
        token_data["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes

    return await self._fetch_access_token_with_retries(
        token_url,
        token_data,
        attempt_label="Token request",
        failure_prefix="Failed to obtain access token",
        success_message="Successfully obtained access token via password grant",
    )


async def get_authorization_url(self, credentials: Dict[str, Any]) -> Dict[str, str]:
    """Get authorization URL for user delegation flow.

    The state value is generated by ``OAuth2Session.authorization_url`` and
    returned to the caller; this method does not store it.

    Args:
        credentials: OAuth configuration with client_id, redirect_uri, authorization_url and optional scopes

    Returns:
        Dict containing authorization_url and state
    """
    client_id = credentials["client_id"]
    redirect_uri = credentials["redirect_uri"]
    authorization_url = credentials["authorization_url"]
    scopes = credentials.get("scopes", [])

    oauth = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=scopes)

    # The session adds a state parameter for CSRF protection.
    auth_url, state = oauth.authorization_url(authorization_url)

    logger.info(f"Generated authorization URL for client {client_id}")

    return {"authorization_url": auth_url, "state": state}


async def exchange_code_for_token(self, credentials: Dict[str, Any], code: str, state: str) -> str:  # pylint: disable=unused-argument
    """Exchange authorization code for access token.

    Args:
        credentials: OAuth configuration
        code: Authorization code from callback
        state: State parameter from the callback. It is accepted for interface
            compatibility but is not validated by this method.

    Returns:
        Access token string

    Raises:
        OAuthError: If token exchange fails
    """
    runtime_credentials = await self._prepare_runtime_credentials(credentials, "authorization_code_exchange")
    client_id = runtime_credentials["client_id"]
    client_secret = runtime_credentials.get("client_secret")  # Optional for public clients (PKCE-only)
    token_url = runtime_credentials["token_url"]
    redirect_uri = runtime_credentials["redirect_uri"]

    token_data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": client_id,
    }

    # Only include client_secret if present (public clients don't have secrets)
    if client_secret:
        token_data["client_secret"] = client_secret

    return await self._fetch_access_token_with_retries(
        token_url,
        token_data,
        attempt_label="Token exchange",
        failure_prefix="Failed to exchange code for token",
        success_message="Successfully exchanged authorization code for access token",
    )

492 

async def initiate_authorization_code_flow(self, gateway_id: str, credentials: Dict[str, Any], app_user_email: str = None) -> Dict[str, str]:
    """Start the Authorization Code flow with PKCE.

    Generates PKCE parameters and an opaque state token, and builds the
    authorization URL. The state, code verifier and user email are stored for
    later validation only when a token storage service is configured.

    Args:
        gateway_id: ID of the gateway being configured
        credentials: OAuth configuration with client_id, authorization_url, etc.
        app_user_email: ContextForge user email to associate with tokens

    Returns:
        Dict containing authorization_url, state and gateway_id
    """
    pkce = self._generate_pkce_params()  # RFC 7636
    state = self._generate_state(gateway_id, app_user_email)

    if self.token_storage:
        # Keep the verifier server-side so the callback can complete the exchange.
        await self._store_authorization_state(
            gateway_id,
            state,
            code_verifier=pkce["code_verifier"],
            app_user_email=app_user_email,
        )

    auth_url = self._create_authorization_url_with_pkce(
        credentials,
        state,
        pkce["code_challenge"],
        pkce["code_challenge_method"],
    )

    logger.info(f"Generated authorization URL with PKCE for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")

    return {"authorization_url": auth_url, "state": state, "gateway_id": gateway_id}

526 

async def complete_authorization_code_flow(self, gateway_id: str, code: str, state: str, credentials: Dict[str, Any]) -> Dict[str, Any]:
    """Complete Authorization Code flow with PKCE and store tokens.

    Args:
        gateway_id: ID of the gateway
        code: Authorization code from callback
        state: State parameter for CSRF validation
        credentials: OAuth configuration

    Returns:
        Dict containing success status, user_id, and expires_at (ISO-8601
        string, or None when no token record or expiry is available)

    Raises:
        OAuthError: If state validation fails, user context is missing while
            token storage is configured, or token exchange fails
    """
    # Validate state and retrieve code_verifier
    state_data = await self._validate_and_retrieve_state(gateway_id, state)
    if not state_data:
        raise OAuthError("Invalid or expired state parameter - possible replay attack")

    code_verifier = state_data.get("code_verifier")
    app_user_email = state_data.get("app_user_email")

    # Defence-in-depth: if app_user_email is absent from server-side
    # state (e.g. state stored by an older code path), attempt a
    # gateway-mismatch check via the legacy state parser but NEVER
    # extract identity fields from unsigned payloads (CWE-345).
    # Note: the /oauth/callback router rejects pure legacy states
    # before reaching here (allow_legacy_fallback=False), so this
    # block only fires for server-stored states that lack the email.
    if not app_user_email:
        legacy_state_payload = self._extract_legacy_state_payload(state)
        if legacy_state_payload:
            legacy_gateway_id = legacy_state_payload.get("gateway_id")
            if legacy_gateway_id and legacy_gateway_id != gateway_id:
                raise OAuthError("State parameter gateway mismatch")

        # Sanitised like every other gateway_id log statement in this class.
        safe_gateway_id = SecurityValidator.sanitize_log_message(gateway_id)
        if self.token_storage:
            logger.error("User context (app_user_email) missing from OAuth state; refusing to bind tokens (CWE-287). gateway_id=%s", safe_gateway_id)
            raise OAuthError("User context required for OAuth token storage")
        logger.warning("User context (app_user_email) missing from OAuth state; no token_storage configured — proceeding without binding. gateway_id=%s", safe_gateway_id)

    # Exchange code for tokens with PKCE code_verifier
    token_response = await self._exchange_code_for_tokens(credentials, code, code_verifier=code_verifier)

    # Extract user information from token response
    user_id = self._extract_user_id(token_response, credentials)

    if not self.token_storage:
        return {"success": True, "user_id": user_id, "expires_at": None}

    # "scope" may be missing, null, a space-delimited string, or a JSON array
    # depending on the provider; normalise all of these to a list of strings.
    raw_scope = token_response.get("scope") or ""
    if isinstance(raw_scope, str):
        granted_scopes = raw_scope.split()
    else:
        granted_scopes = [str(item) for item in raw_scope]

    token_record = await self.token_storage.store_tokens(
        gateway_id=gateway_id,
        user_id=user_id,
        app_user_email=app_user_email,  # User from state
        access_token=token_response["access_token"],
        refresh_token=token_response.get("refresh_token"),
        expires_in=token_response.get("expires_in", self.settings.oauth_default_timeout),
        scopes=granted_scopes,
    )

    return {"success": True, "user_id": user_id, "expires_at": token_record.expires_at.isoformat() if token_record.expires_at else None}

588 

async def get_access_token_for_user(self, gateway_id: str, app_user_email: str) -> Optional[str]:
    """Look up a user's access token through the token storage service.

    Args:
        gateway_id: ID of the gateway
        app_user_email: ContextForge user email

    Returns:
        Whatever ``token_storage.get_user_token`` returns, or None when no
        token storage is configured
    """
    storage = self.token_storage
    if not storage:
        return None
    return await storage.get_user_token(gateway_id, app_user_email)

602 

603 def _generate_state(self, _gateway_id: str, _app_user_email: str = None) -> str: 

604 """Generate an opaque state token for CSRF protection. 

605 

606 Args: 

607 _gateway_id: Gateway identifier (reserved for compatibility with 

608 prior embedded-state call sites). 

609 _app_user_email: ContextForge user email (reserved for 

610 compatibility with prior embedded-state call sites). 

611 

612 Returns: 

613 Opaque random state token 

614 """ 

615 return secrets.token_urlsafe(48) 

616 

617 @staticmethod 

618 def _extract_legacy_state_payload(state: str) -> Optional[Dict[str, Any]]: 

619 """Best-effort decode of legacy state payloads used before opaque states. 

620 

621 Legacy formats supported: 

622 - base64url(payload || signature) where payload is JSON 

623 - gateway_id_random suffix format 

624 

625 Security: Legacy payloads lack signature verification, so only 

626 ``gateway_id`` is returned — never identity-sensitive fields like 

627 ``app_user_email`` which could be forged (CWE-345). 

628 

629 Args: 

630 state: Callback state token to decode. 

631 

632 Returns: 

633 Dict containing only ``gateway_id`` when format is recognized; 

634 otherwise ``None``. 

635 """ 

636 safe_legacy_fields = {"gateway_id"} 

637 

638 try: 

639 state_raw = base64.urlsafe_b64decode(state.encode()) 

640 if len(state_raw) <= 32: 

641 return None 

642 

643 payload_bytes = state_raw[:-32] 

644 payload = orjson.loads(payload_bytes) 

645 if isinstance(payload, dict): 

646 # Only return gateway_id — unsigned payloads must not 

647 # carry identity claims. 

648 safe = {k: v for k, v in payload.items() if k in safe_legacy_fields} 

649 return safe if safe else None 

650 except Exception: 

651 # Fall back to legacy gateway_id_random format 

652 if "_" in state: 

653 gateway_id = state.split("_", 1)[0] 

654 if gateway_id: 

655 return {"gateway_id": gateway_id} 

656 return None 

657 

async def resolve_gateway_id_from_state(self, state: str, allow_legacy_fallback: bool = True) -> Optional[str]:
    """Find the gateway a callback state token belongs to, without consuming it.

    Lookup order: Redis reverse index (when ``cache_type`` is "redis"), the
    OAuthState table (when ``cache_type`` is "database"), the in-memory
    reverse lookup, and finally the legacy state decoder if allowed. Redis and
    database failures are logged and the next source is tried.

    Args:
        state: OAuth callback state parameter
        allow_legacy_fallback: Whether to decode legacy callback state formats.

    Returns:
        Gateway ID when resolvable, otherwise ``None``.
    """
    config = get_settings()

    if config.cache_type == "redis":
        redis = await _get_redis_client()
        if redis:
            try:
                stored = await redis.get(f"oauth:state_lookup:{state}")
                if stored:
                    # The client may hand back bytes or str.
                    return stored.decode("utf-8") if isinstance(stored, bytes) else stored
            except Exception as e:
                logger.warning(f"Failed to resolve state gateway in Redis: {e}")

    if config.cache_type == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            db_gen = get_db()
            db = next(db_gen)
            try:
                row = db.query(OAuthState).filter(OAuthState.state == state).first()
                if row:
                    return row.gateway_id
            finally:
                db_gen.close()
        except Exception as e:
            logger.warning(f"Failed to resolve state gateway in database: {e}")

    async with _state_lock:
        # Purge expired in-memory entries before consulting the reverse lookup.
        now = datetime.now(timezone.utc)
        stale_keys = [key for key, entry in _oauth_states.items() if datetime.fromisoformat(entry["expires_at"]) < now]
        for stale_key in stale_keys:
            stale_state = _oauth_states.pop(stale_key).get("state")
            if stale_state:
                _oauth_state_lookup.pop(stale_state, None)

        in_memory_gateway = _oauth_state_lookup.get(state)
        if in_memory_gateway:
            return in_memory_gateway

    if not allow_legacy_fallback:
        return None

    legacy_payload = self._extract_legacy_state_payload(state)
    return legacy_payload.get("gateway_id") if legacy_payload else None

716 

async def _store_authorization_state(
    self,
    gateway_id: str,
    state: str,
    code_verifier: str = None,
    app_user_email: str = None,
) -> None:
    """Persist an authorization state so the callback can validate it.

    Storage is chosen by ``settings.cache_type``: Redis (with a TTL of
    ``STATE_TTL_SECONDS``), then the OAuthState database table, and finally
    the in-memory dictionaries. A failure in Redis or the database is logged
    and the next backend is tried.

    Args:
        gateway_id: ID of the gateway
        state: State parameter to store
        code_verifier: Optional PKCE code verifier (RFC 7636)
        app_user_email: Requesting user email for token association
    """
    expires_at = datetime.now(timezone.utc) + timedelta(seconds=STATE_TTL_SECONDS)
    settings = get_settings()

    # Same key layout and record shape for the Redis and in-memory backends.
    state_key = f"oauth:state:{gateway_id}:{state}"
    record = {
        "state": state,
        "gateway_id": gateway_id,
        "code_verifier": code_verifier,
        "app_user_email": app_user_email,
        "expires_at": expires_at.isoformat(),
        "used": False,
    }

    # Redis: distributed storage, expiry handled by the key TTL.
    if settings.cache_type == "redis":
        redis = await _get_redis_client()
        if redis:
            try:
                await redis.setex(state_key, STATE_TTL_SECONDS, orjson.dumps(record))
                await redis.setex(f"oauth:state_lookup:{state}", STATE_TTL_SECONDS, gateway_id)
                logger.debug(f"Stored OAuth state in Redis for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
                return
            except Exception as e:
                logger.warning(f"Failed to store state in Redis: {e}, falling back")

    # Database: shared storage for multi-worker deployments.
    if settings.cache_type == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            db_gen = get_db()
            db = next(db_gen)
            try:
                # Drop rows that have already expired.
                db.query(OAuthState).filter(OAuthState.expires_at < datetime.now(timezone.utc)).delete()

                row_fields = {
                    "gateway_id": gateway_id,
                    "state": state,
                    "code_verifier": code_verifier,
                    "expires_at": expires_at,
                    "used": False,
                }
                # Only set the email when the model defines that attribute.
                if hasattr(OAuthState, "app_user_email"):
                    row_fields["app_user_email"] = app_user_email

                db.add(OAuthState(**row_fields))
                db.commit()
                logger.debug(f"Stored OAuth state in database for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
                return
            finally:
                db_gen.close()
        except Exception as e:
            logger.warning(f"Failed to store state in database: {e}, falling back to memory")

    # In-memory fallback (single process only).
    async with _state_lock:
        now = datetime.now(timezone.utc)
        stale_keys = [key for key, entry in _oauth_states.items() if datetime.fromisoformat(entry["expires_at"]) < now]
        for stale_key in stale_keys:
            stale_state = _oauth_states.pop(stale_key).get("state")
            if stale_state:
                _oauth_state_lookup.pop(stale_state, None)
            logger.debug(f"Cleaned up expired state: {stale_key[:20]}...")

        _oauth_states[state_key] = record
        _oauth_state_lookup[state] = gateway_id
        logger.debug(f"Stored OAuth state in memory for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")

816 

async def _validate_authorization_state(self, gateway_id: str, state: str) -> bool:
    """Check that an OAuth state value is known, unexpired and unused, then consume it.

    The store consulted depends on ``settings.cache_type``: Redis for
    ``"redis"``, the database for ``"database"``, and the in-process
    dictionary otherwise. If the Redis or database attempt raises, a warning
    is logged and the in-process dictionary is consulted instead.

    Args:
        gateway_id: ID of the gateway
        state: State parameter to validate

    Returns:
        True if state is valid and not yet used, False otherwise
    """
    cfg = get_settings()

    # Redis: shared storage for distributed deployments
    if cfg.cache_type == "redis":
        redis_conn = await _get_redis_client()
        if redis_conn:
            try:
                primary_key = f"oauth:state:{gateway_id}:{state}"
                reverse_key = f"oauth:state_lookup:{state}"
                # GETDEL reads and removes the entry in one step, making the state single-use
                raw_state = await redis_conn.getdel(primary_key)
                await redis_conn.delete(reverse_key)
                if not raw_state:
                    logger.warning(f"State not found in Redis for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
                    return False

                payload = orjson.loads(raw_state)

                # Stored expiry is an ISO string; fall back to a plain
                # second-resolution format if fromisoformat rejects it.
                try:
                    deadline = datetime.fromisoformat(payload["expires_at"])
                except Exception:
                    deadline = datetime.strptime(payload["expires_at"], "%Y-%m-%dT%H:%M:%S")

                # Naive timestamps are treated as UTC
                if deadline.tzinfo is None:
                    deadline = deadline.replace(tzinfo=timezone.utc)

                if deadline < datetime.now(timezone.utc):
                    logger.warning(f"State has expired for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
                    return False

                # Defensive: GETDEL should already prevent a second read
                if payload.get("used", False):
                    logger.warning(f"State was already used for gateway {SecurityValidator.sanitize_log_message(gateway_id)} - possible replay attack")
                    return False

                logger.debug(f"Successfully validated OAuth state from Redis for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
                return True
            except Exception as exc:
                logger.warning(f"Failed to validate state in Redis: {exc}, falling back")

    # Database: storage for multi-worker deployments
    if cfg.cache_type == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            session_gen = get_db()
            session = next(session_gen)
            try:
                record = session.query(OAuthState).filter(OAuthState.gateway_id == gateway_id, OAuthState.state == state).first()

                if not record:
                    logger.warning(f"State not found in database for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
                    return False

                # Naive timestamps are treated as UTC
                deadline = record.expires_at
                if deadline.tzinfo is None:
                    deadline = deadline.replace(tzinfo=timezone.utc)

                if deadline < datetime.now(timezone.utc):
                    logger.warning(f"State has expired for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
                    # Expired rows are removed on sight
                    session.delete(record)
                    session.commit()
                    return False

                if record.used:
                    logger.warning(f"State has already been used for gateway {SecurityValidator.sanitize_log_message(gateway_id)} - possible replay attack")
                    return False

                # Deleting the row is what makes the state single-use
                session.delete(record)
                session.commit()
                logger.debug(f"Successfully validated OAuth state from database for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
                return True
            finally:
                session_gen.close()
        except Exception as exc:
            logger.warning(f"Failed to validate state in database: {exc}, falling back to memory")

    # In-process dictionary: development fallback
    memory_key = f"oauth:state:{gateway_id}:{state}"
    async with _state_lock:
        entry = _oauth_states.get(memory_key)

        if not entry:
            logger.warning(f"State not found in memory for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
            return False

        # Naive timestamps are treated as UTC
        deadline = datetime.fromisoformat(entry["expires_at"])
        if deadline.tzinfo is None:
            deadline = deadline.replace(tzinfo=timezone.utc)

        if deadline < datetime.now(timezone.utc):
            logger.warning(f"State has expired for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
            del _oauth_states[memory_key]  # Clean up expired state
            _oauth_state_lookup.pop(state, None)
            return False

        # Replay guard
        if entry.get("used", False):
            logger.warning(f"State has already been used for gateway {SecurityValidator.sanitize_log_message(gateway_id)} - possible replay attack")
            return False

        # Consume the state so it cannot be presented again
        del _oauth_states[memory_key]
        _oauth_state_lookup.pop(state, None)
        logger.debug(f"Successfully validated OAuth state from memory for gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
        return True

946 

async def _validate_and_retrieve_state(self, gateway_id: str, state: str) -> Optional[Dict[str, Any]]:
    """Consume an OAuth state value and hand back the data stored with it.

    The store consulted depends on ``settings.cache_type``: Redis for
    ``"redis"``, the database for ``"database"``, and the in-process
    dictionary otherwise. If the Redis or database attempt raises, a warning
    is logged and the in-process dictionary is consulted instead.

    Args:
        gateway_id: ID of the gateway
        state: State parameter to validate

    Returns:
        Dict with state data including code_verifier, or None if invalid/expired
    """
    cfg = get_settings()

    # Redis
    if cfg.cache_type == "redis":
        redis_conn = await _get_redis_client()
        if redis_conn:
            try:
                primary_key = f"oauth:state:{gateway_id}:{state}"
                reverse_key = f"oauth:state_lookup:{state}"
                # GETDEL reads and removes the entry in one step (single-use)
                raw_state = await redis_conn.getdel(primary_key)
                await redis_conn.delete(reverse_key)
                if not raw_state:
                    return None

                payload = orjson.loads(raw_state)

                # Stored expiry is an ISO string; fall back to a plain
                # second-resolution format if fromisoformat rejects it.
                try:
                    deadline = datetime.fromisoformat(payload["expires_at"])
                except Exception:
                    deadline = datetime.strptime(payload["expires_at"], "%Y-%m-%dT%H:%M:%S")

                # Naive timestamps are treated as UTC
                if deadline.tzinfo is None:
                    deadline = deadline.replace(tzinfo=timezone.utc)

                if deadline < datetime.now(timezone.utc):
                    return None

                return payload
            except Exception as exc:
                logger.warning(f"Failed to validate state in Redis: {exc}, falling back")

    # Database
    if cfg.cache_type == "database":
        try:
            # First-Party
            from mcpgateway.db import get_db, OAuthState  # pylint: disable=import-outside-toplevel

            session_gen = get_db()
            session = next(session_gen)
            try:
                record = session.query(OAuthState).filter(OAuthState.gateway_id == gateway_id, OAuthState.state == state).first()

                if not record:
                    return None

                # Naive timestamps are treated as UTC
                deadline = record.expires_at
                if deadline.tzinfo is None:
                    deadline = deadline.replace(tzinfo=timezone.utc)

                if deadline < datetime.now(timezone.utc):
                    # Expired rows are removed on sight
                    session.delete(record)
                    session.commit()
                    return None

                if record.used:
                    return None

                # Snapshot the row before it is deleted
                payload = {
                    "state": record.state,
                    "gateway_id": record.gateway_id,
                    "code_verifier": record.code_verifier,
                    "expires_at": record.expires_at.isoformat(),
                }
                if hasattr(record, "app_user_email"):
                    payload["app_user_email"] = getattr(record, "app_user_email", None)

                # Deleting the row is what makes the state single-use
                session.delete(record)
                session.commit()

                return payload
            finally:
                session_gen.close()
        except Exception as exc:
            logger.warning(f"Failed to validate state in database: {exc}")

    # In-process dictionary
    memory_key = f"oauth:state:{gateway_id}:{state}"
    async with _state_lock:
        entry = _oauth_states.get(memory_key)
        if not entry:
            return None

        # Naive timestamps are treated as UTC
        deadline = datetime.fromisoformat(entry["expires_at"])
        if deadline.tzinfo is None:
            deadline = deadline.replace(tzinfo=timezone.utc)

        is_expired = deadline < datetime.now(timezone.utc)

        # Expired or not, the entry leaves memory: either cleaned up or consumed
        del _oauth_states[memory_key]
        _oauth_state_lookup.pop(state, None)
        return None if is_expired else entry

1058 

def _create_authorization_url(self, credentials: Dict[str, Any], state: str) -> tuple[str, str]:
    """Build the provider authorization URL carrying the given state.

    Args:
        credentials: OAuth configuration; ``client_id``, ``redirect_uri`` and
            ``authorization_url`` are required, ``scopes`` is optional
        state: State parameter for CSRF protection

    Returns:
        Tuple of (authorization_url, state)
    """
    # Required keys are read up front so a missing one raises KeyError before any session exists
    client_id, redirect_uri, authorize_endpoint = (credentials[key] for key in ("client_id", "redirect_uri", "authorization_url"))
    requested_scopes = credentials.get("scopes", [])

    session = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=requested_scopes)

    # The session appends the state (CSRF protection) and echoes it back
    built_url, issued_state = session.authorization_url(authorize_endpoint, state=state)
    return built_url, issued_state

1081 

1082 @staticmethod 

1083 def _is_microsoft_entra_v2_endpoint(endpoint_url: Any) -> bool: 

1084 """Return True when endpoint matches Microsoft Entra v2 login endpoints. 

1085 

1086 Args: 

1087 endpoint_url: OAuth endpoint URL to check 

1088 

1089 Returns: 

1090 True if the endpoint is a Microsoft Entra v2 OAuth endpoint 

1091 """ 

1092 if not isinstance(endpoint_url, str) or not endpoint_url: 

1093 return False 

1094 

1095 parsed = urlparse(endpoint_url) 

1096 host = (parsed.hostname or "").lower() 

1097 path = parsed.path.lower() 

1098 

1099 return host in OAuthManager._ENTRA_HOSTS and "/oauth2/v2.0/" in path 

1100 

1101 @staticmethod 

1102 def _is_enabled_flag(value: Any) -> bool: 

1103 """Parse boolean-like config values from oauth_config. 

1104 

1105 Args: 

1106 value: Config value to interpret as boolean 

1107 

1108 Returns: 

1109 True if value represents an enabled/truthy setting 

1110 """ 

1111 if isinstance(value, bool): 

1112 return value 

1113 if isinstance(value, str): 

1114 return value.strip().lower() in {"1", "true", "yes", "on"} 

1115 return False 

1116 

1117 def _should_include_resource_parameter(self, credentials: Dict[str, Any], scopes: Any) -> bool: 

1118 """Determine whether RFC 8707 resource should be sent for this request. 

1119 

1120 Args: 

1121 credentials: OAuth configuration containing resource and endpoint URLs 

1122 scopes: OAuth scopes for the request 

1123 

1124 Returns: 

1125 True if the resource parameter should be included in the request 

1126 """ 

1127 if not credentials.get("resource"): 

1128 return False 

1129 

1130 if self._is_enabled_flag(credentials.get("omit_resource")): 

1131 return False 

1132 

1133 # Microsoft Entra v2 does not accept legacy resource with v2 scope-based requests. 

1134 if scopes and (self._is_microsoft_entra_v2_endpoint(credentials.get("authorization_url")) or self._is_microsoft_entra_v2_endpoint(credentials.get("token_url"))): 

1135 logger.info("Omitting OAuth resource parameter for Microsoft Entra v2 scope-based flow") 

1136 return False 

1137 

1138 return True 

1139 

1140 def _create_authorization_url_with_pkce(self, credentials: Dict[str, Any], state: str, code_challenge: str, code_challenge_method: str) -> str: 

1141 """Create authorization URL with PKCE parameters (RFC 7636). 

1142 

1143 Args: 

1144 credentials: OAuth configuration 

1145 state: State parameter for CSRF protection 

1146 code_challenge: PKCE code challenge 

1147 code_challenge_method: PKCE method (S256) 

1148 

1149 Returns: 

1150 Authorization URL string with PKCE parameters 

1151 """ 

1152 # Standard 

1153 from urllib.parse import urlencode # pylint: disable=import-outside-toplevel 

1154 

1155 client_id = credentials["client_id"] 

1156 redirect_uri = credentials["redirect_uri"] 

1157 authorization_url = credentials["authorization_url"] 

1158 scopes = credentials.get("scopes", []) 

1159 

1160 # Build authorization parameters 

1161 params = {"response_type": "code", "client_id": client_id, "redirect_uri": redirect_uri, "state": state, "code_challenge": code_challenge, "code_challenge_method": code_challenge_method} 

1162 

1163 # Add scopes if present 

1164 if scopes: 

1165 params["scope"] = " ".join(scopes) if isinstance(scopes, list) else scopes 

1166 

1167 # Add resource parameter for JWT access token (RFC 8707) 

1168 # The resource is the MCP server URL, set by oauth_router.py 

1169 resource = credentials.get("resource") 

1170 if self._should_include_resource_parameter(credentials, scopes): 

1171 params["resource"] = resource # urlencode with doseq=True handles lists 

1172 

1173 # Build full URL (doseq=True handles list values like multiple resource params) 

1174 query_string = urlencode(params, doseq=True) 

1175 return f"{authorization_url}?{query_string}" 

1176 

async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]:
    """Exchange authorization code for tokens with PKCE support.

    The request is retried up to ``self.max_retries`` times with exponential
    backoff when an ``httpx.HTTPError`` occurs. Responses are accepted as
    JSON or as ``application/x-www-form-urlencoded``; form-encoded bodies are
    percent-decoded.

    Args:
        credentials: OAuth configuration
        code: Authorization code from callback
        code_verifier: Optional PKCE code verifier (RFC 7636)

    Returns:
        Token response dictionary

    Raises:
        OAuthError: If token exchange fails
    """
    # Standard
    from urllib.parse import parse_qsl  # pylint: disable=import-outside-toplevel

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "authorization_code_exchange_with_pkce")
    client_id = runtime_credentials["client_id"]
    client_secret = runtime_credentials.get("client_secret")  # Optional for public clients (PKCE-only)
    token_url = runtime_credentials["token_url"]
    redirect_uri = runtime_credentials["redirect_uri"]

    # Prepare token exchange data
    token_data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": client_id,
    }

    # Only include client_secret if present (public clients don't have secrets)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Add PKCE code_verifier if present (RFC 7636)
    if code_verifier:
        token_data["code_verifier"] = code_verifier

    # Add resource parameter to request JWT access token (RFC 8707)
    # The resource identifies the MCP server (resource server), not the OAuth server
    resource = runtime_credentials.get("resource")
    scopes = runtime_credentials.get("scopes", [])
    if self._should_include_resource_parameter(credentials, scopes):
        if isinstance(resource, list):
            # RFC 8707 allows multiple resource parameters - a dict cannot repeat
            # a key, so switch to a list of (name, value) tuples
            form_data: list[tuple[str, str]] = list(token_data.items())
            form_data.extend(("resource", r) for r in resource if r)
            token_data = form_data  # type: ignore[assignment]
        else:
            token_data["resource"] = resource

    # Exchange code for token with retries
    for attempt in range(self.max_retries):
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            response.raise_for_status()

            # GitHub returns form-encoded responses, not JSON
            content_type = response.headers.get("content-type", "")
            if "application/x-www-form-urlencoded" in content_type:
                # parse_qsl percent-decodes keys and values (e.g. scope=repo%2Cgist);
                # keep_blank_values retains fields that arrive empty
                token_response = dict(parse_qsl(response.text, keep_blank_values=True))
            else:
                # Try JSON response
                try:
                    token_response = response.json()
                except Exception as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    # Keep the raw body; the access_token check below will reject it
                    token_response = {"raw_response": response.text}

            if "access_token" not in token_response:
                raise OAuthError(f"No access_token in response: {token_response}")

            logger.info("Successfully exchanged authorization code for tokens")
            return token_response

        except httpx.HTTPError as e:
            logger.warning(f"Token exchange attempt {attempt + 1} failed: {str(e)}")
            if attempt == self.max_retries - 1:
                # Chain the transport error so the root cause stays in the traceback
                raise OAuthError(f"Failed to exchange code for token after {self.max_retries} attempts: {str(e)}") from e
            await asyncio.sleep(2**attempt)  # Exponential backoff

    # Reached only when max_retries is zero; also keeps the return type total
    raise OAuthError("Failed to exchange code for token after all retry attempts")

1269 

async def refresh_token(self, refresh_token: str, credentials: Dict[str, Any]) -> Dict[str, Any]:
    """Refresh an expired access token using a refresh token.

    The request is attempted up to ``self.max_retries`` times. Both transport
    errors and retryable HTTP statuses (anything other than 200, 400 or 401)
    are followed by exponential backoff before the next attempt. A 400 or 401
    response aborts immediately because the refresh token itself is rejected.

    Args:
        refresh_token: The refresh token to use
        credentials: OAuth configuration including client_id, client_secret, token_url

    Returns:
        Dict containing new access_token, optional refresh_token, and expires_in

    Raises:
        OAuthError: If token refresh fails
    """
    if not refresh_token:
        raise OAuthError("No refresh token available")

    runtime_credentials = await self._prepare_runtime_credentials(credentials, "refresh_token")
    token_url = runtime_credentials.get("token_url")
    if not token_url:
        raise OAuthError("No token URL configured for OAuth provider")

    client_id = runtime_credentials.get("client_id")
    client_secret = runtime_credentials.get("client_secret")

    if not client_id:
        raise OAuthError("No client_id configured for OAuth provider")

    # Prepare token refresh request
    token_data = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
    }

    # Add client_secret if available (some providers require it)
    if client_secret:
        token_data["client_secret"] = client_secret

    # Add resource parameter for JWT access token (RFC 8707)
    # Must be included in refresh requests to maintain JWT token type
    resource = runtime_credentials.get("resource")
    scopes = runtime_credentials.get("scopes", [])
    if self._should_include_resource_parameter(credentials, scopes):
        if isinstance(resource, list):
            # RFC 8707 allows multiple resource parameters - a dict cannot repeat
            # a key, so switch to a list of (name, value) tuples
            form_data: list[tuple[str, str]] = list(token_data.items())
            form_data.extend(("resource", r) for r in resource if r)
            token_data = form_data  # type: ignore[assignment]
        else:
            token_data["resource"] = resource

    # Attempt token refresh with retries
    for attempt in range(self.max_retries):
        is_last_attempt = attempt == self.max_retries - 1
        try:
            client = await self._get_client()
            response = await client.post(token_url, data=token_data, timeout=self.request_timeout)
            if response.status_code == 200:
                token_response = response.json()

                # Validate required fields
                if "access_token" not in token_response:
                    raise OAuthError("No access_token in refresh response")

                logger.info("Successfully refreshed OAuth token")
                return token_response

            error_text = response.text
            # If we get a 400/401, the refresh token is likely invalid - retrying cannot help
            if response.status_code in [400, 401]:
                raise OAuthError(f"Refresh token invalid or expired: {error_text}")
            logger.warning(f"Token refresh failed with status {response.status_code}: {error_text}")

        except httpx.HTTPError as e:
            logger.warning(f"Token refresh attempt {attempt + 1} failed: {str(e)}")
            if is_last_attempt:
                # Chain the transport error so the root cause stays in the traceback
                raise OAuthError(f"Failed to refresh token after {self.max_retries} attempts: {str(e)}") from e

        # Back off after every failed attempt (transport error or retryable
        # status) so a struggling provider is not hit again immediately
        if not is_last_attempt:
            await asyncio.sleep(2**attempt)  # Exponential backoff

    raise OAuthError("Failed to refresh token after all retry attempts")

1351 

1352 def _extract_user_id(self, token_response: Dict[str, Any], credentials: Dict[str, Any]) -> str: 

1353 """Extract user ID from token response. 

1354 

1355 Args: 

1356 token_response: Response from token exchange 

1357 credentials: OAuth configuration 

1358 

1359 Returns: 

1360 User ID string 

1361 """ 

1362 # Try to extract user ID from various common fields in token response 

1363 # Different OAuth providers use different field names 

1364 

1365 # Check for 'sub' (subject) - JWT standard 

1366 if "sub" in token_response: 

1367 return token_response["sub"] 

1368 

1369 # Check for 'user_id' - common in some OAuth responses 

1370 if "user_id" in token_response: 

1371 return token_response["user_id"] 

1372 

1373 # Check for 'id' - also common 

1374 if "id" in token_response: 

1375 return token_response["id"] 

1376 

1377 # Fallback to client_id if no user info is available 

1378 if credentials.get("client_id"): 

1379 return credentials["client_id"] 

1380 

1381 # Final fallback 

1382 return "unknown_user" 

1383 

1384 

class OAuthError(Exception):
    """Base exception for OAuth failures; more specific OAuth errors derive from it.

    Examples:
        >>> try:
        ...     raise OAuthError("Token acquisition failed")
        ... except OAuthError as e:
        ...     str(e)
        'Token acquisition failed'
        >>> try:
        ...     raise OAuthError("Invalid grant type")
        ... except Exception as e:
        ...     isinstance(e, OAuthError)
        True
    """

1400 

1401 

class OAuthRequiredError(OAuthError):
    """Signals that a server demands OAuth while the caller is unauthenticated.

    The ``server_id`` attribute records which server triggered the
    rejection, so the middleware can identify it when constructing the
    ``WWW-Authenticate`` header.

    Examples:
        >>> err = OAuthRequiredError("auth required", server_id="s1")
        >>> err.server_id
        's1'
        >>> isinstance(err, OAuthError)
        True
    """

    def __init__(self, message: str, *, server_id: str = "") -> None:
        """Store the server identifier and pass the message to the base exception.

        Args:
            message: Human-readable error description.
            server_id: Virtual-server identifier that triggered the rejection.
        """
        self.server_id = server_id
        super().__init__(message)

1426 

1427 

class OAuthEnforcementUnavailableError(OAuthError):
    """Signals that OAuth enforcement could not run because infrastructure failed.

    Raised when the database or another backing service needed to read a
    server's ``oauth_enabled`` flag is unavailable. The middleware turns
    it into an HTTP 503 so that unauthenticated access is never silently
    allowed (fail-closed).

    Examples:
        >>> err = OAuthEnforcementUnavailableError("DB down", server_id="s1")
        >>> err.server_id
        's1'
        >>> isinstance(err, OAuthError)
        True
    """

    def __init__(self, message: str, *, server_id: str = "") -> None:
        """Store the server identifier and pass the message to the base exception.

        Args:
            message: Human-readable error description.
            server_id: Virtual-server identifier that triggered the rejection.
        """
        self.server_id = server_id
        super().__init__(message)