Coverage for mcpgateway / services / sso_service.py: 97%

403 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/sso_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Single Sign-On (SSO) authentication service for OAuth2 and OIDC providers. 

8Handles provider management, OAuth flows, and user authentication. 

9""" 

10 

11# Future 

12from __future__ import annotations 

13 

14# Standard 

15import base64 

16from datetime import timedelta 

17import hashlib 

18import logging 

19import secrets 

20import string 

21from typing import Any, Dict, List, Optional, Tuple 

22import urllib.parse 

23 

24# Third-Party 

25import orjson 

26from sqlalchemy import and_, select 

27from sqlalchemy.orm import Session 

28 

29# First-Party 

30from mcpgateway.config import settings 

31from mcpgateway.db import PendingUserApproval, SSOAuthSession, SSOProvider, utc_now 

32from mcpgateway.services.email_auth_service import EmailAuthService 

33from mcpgateway.services.encryption_service import get_encryption_service 

34from mcpgateway.utils.create_jwt_token import create_jwt_token 

35 

36# Logger 

37logger = logging.getLogger(__name__) 

38 

39 

40class SSOService: 

41 """Service for managing SSO authentication flows and providers. 

42 

43 Handles OAuth2/OIDC authentication flows, provider configuration, 

44 and integration with the local user system. 

45 

46 Examples: 

47 Basic construction and helper checks: 

48 >>> from unittest.mock import Mock 

49 >>> service = SSOService(Mock()) 

50 >>> isinstance(service, SSOService) 

51 True 

52 >>> callable(service.list_enabled_providers) 

53 True 

54 """ 

55 

def __init__(self, db: Session):
    """Bind the service to a database session and build its collaborators.

    Args:
        db: SQLAlchemy database session
    """
    self.db = db
    self.auth_service = EmailAuthService(db)

    # The encryption helper is keyed with the application-wide auth
    # encryption secret; it backs _encrypt_secret/_decrypt_secret.
    encryption_secret = settings.auth_encryption_secret
    self._encryption = get_encryption_service(encryption_secret)

65 

66 async def _encrypt_secret(self, secret: str) -> str: 

67 """Encrypt a client secret for secure storage. 

68 

69 Args: 

70 secret: Plain text client secret 

71 

72 Returns: 

73 Encrypted secret string 

74 """ 

75 return await self._encryption.encrypt_secret_async(secret) 

76 

77 async def _decrypt_secret(self, encrypted_secret: str) -> Optional[str]: 

78 """Decrypt a client secret for use. 

79 

80 Args: 

81 encrypted_secret: Encrypted secret string 

82 

83 Returns: 

84 Plain text client secret 

85 """ 

86 decrypted: str | None = await self._encryption.decrypt_secret_async(encrypted_secret) 

87 if decrypted: 

88 return decrypted 

89 

90 return None 

91 

def _decode_jwt_claims(self, token: str) -> Optional[Dict[str, Any]]:
    """Decode JWT token payload without verification.

    This is used to extract claims from ID tokens where we've already
    validated the OAuth flow. The token signature is not verified here
    because the token was received directly from the trusted token endpoint.

    Args:
        token: JWT token string

    Returns:
        Decoded payload dict, or None if the token is not a string, is not
        in ``header.payload.signature`` form, cannot be decoded, or its
        payload is not a JSON object

    Examples:
        >>> from unittest.mock import Mock
        >>> service = SSOService(Mock())
        >>> # Valid JWT structure (header.payload.signature)
        >>> import base64
        >>> payload = base64.urlsafe_b64encode(b'{"sub":"123","groups":["admin"]}').decode().rstrip('=')
        >>> token = f"eyJhbGciOiJSUzI1NiJ9.{payload}.signature"
        >>> claims = service._decode_jwt_claims(token)
        >>> claims is not None
        True

        A payload that is valid JSON but not an object is rejected:
        >>> list_payload = base64.urlsafe_b64encode(b'[1,2]').decode().rstrip('=')
        >>> service._decode_jwt_claims(f"h.{list_payload}.s") is None
        True
    """
    # Providers may return a null/absent id_token; treat it like any other
    # undecodable token instead of raising AttributeError on .split().
    if not isinstance(token, str):
        logger.warning("Invalid JWT format: expected 3 parts")
        return None

    try:
        # JWT format: header.payload.signature
        parts = token.split(".")
        if len(parts) != 3:
            logger.warning("Invalid JWT format: expected 3 parts")
            return None

        # JWTs strip base64 padding; restore it so the decoder accepts the segment.
        payload_b64 = parts[1]
        payload_b64 += "=" * (-len(payload_b64) % 4)

        claims = orjson.loads(base64.urlsafe_b64decode(payload_b64))

    except (ValueError, orjson.JSONDecodeError, UnicodeDecodeError) as e:
        logger.warning(f"Failed to decode JWT claims: {e}")
        return None

    # Callers use dict methods on the result, so a JSON list/string/number
    # payload must not be passed through as "claims".
    if not isinstance(claims, dict):
        logger.warning("Invalid JWT payload: expected a JSON object")
        return None

    return claims

136 

def list_enabled_providers(self) -> List[SSOProvider]:
    """Fetch every SSO provider whose ``is_enabled`` flag is true.

    Returns:
        List of enabled SSO providers

    Examples:
        Returns empty list when DB has no providers:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalars.return_value.all.return_value = []
        >>> service.list_enabled_providers()
        []
    """
    enabled_query = select(SSOProvider).where(SSOProvider.is_enabled.is_(True))
    rows = self.db.execute(enabled_query).scalars().all()
    return list(rows)

154 

def get_provider(self, provider_id: str) -> Optional[SSOProvider]:
    """Look up a single SSO provider by its identifier.

    Args:
        provider_id: Provider identifier (e.g., 'github', 'google')

    Returns:
        SSO provider or None if not found

    Examples:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> service.get_provider('x') is None
        True
    """
    by_id = select(SSOProvider).where(SSOProvider.id == provider_id)
    return self.db.execute(by_id).scalar_one_or_none()

174 

def get_provider_by_name(self, provider_name: str) -> Optional[SSOProvider]:
    """Look up a single SSO provider by its name.

    Args:
        provider_name: Provider name (e.g., 'github', 'google')

    Returns:
        SSO provider or None if not found

    Examples:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> service.get_provider_by_name('github') is None
        True
    """
    by_name = select(SSOProvider).where(SSOProvider.name == provider_name)
    return self.db.execute(by_name).scalar_one_or_none()

194 

async def create_provider(self, provider_data: Dict[str, Any]) -> SSOProvider:
    """Create new SSO provider configuration.

    The plain ``client_secret`` entry is replaced by its encrypted form
    (``client_secret_encrypted``) before the provider row is built. The
    caller's dict is not modified.

    Args:
        provider_data: Provider configuration data; must contain ``client_secret``

    Returns:
        Created SSO provider

    Raises:
        KeyError: If ``provider_data`` has no ``client_secret`` entry

    Examples:
        >>> import asyncio
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> service = SSOService(MagicMock())
        >>> service._encrypt_secret = AsyncMock(side_effect=lambda s: 'ENC(' + s + ')')
        >>> data = {
        ...     'id': 'github', 'name': 'github', 'display_name': 'GitHub', 'provider_type': 'oauth2',
        ...     'client_id': 'cid', 'client_secret': 'sec',
        ...     'authorization_url': 'https://example/auth', 'token_url': 'https://example/token',
        ...     'userinfo_url': 'https://example/user', 'scope': 'user:email'
        ... }
        >>> provider = asyncio.run(service.create_provider(data))
        >>> hasattr(provider, 'id') and provider.id == 'github'
        True
        >>> provider.client_secret_encrypted.startswith('ENC(')
        True
        >>> 'client_secret' in data  # input dict is left untouched
        True
    """
    # Work on a shallow copy: popping from the caller's dict would silently
    # strip the secret from it and make a retry fail with KeyError.
    fields = dict(provider_data)

    # Encrypt client secret
    fields["client_secret_encrypted"] = await self._encrypt_secret(fields.pop("client_secret"))

    provider = SSOProvider(**fields)
    self.db.add(provider)
    self.db.commit()
    self.db.refresh(provider)
    return provider

230 

async def update_provider(self, provider_id: str, provider_data: Dict[str, Any]) -> Optional[SSOProvider]:
    """Update existing SSO provider configuration.

    A ``client_secret`` entry, when present, is stored encrypted as
    ``client_secret_encrypted``. Keys that are not attributes of the
    provider are ignored. The caller's dict is not modified.

    Args:
        provider_id: Provider identifier
        provider_data: Updated provider data

    Returns:
        Updated SSO provider or None if not found

    Examples:
        >>> import asyncio
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> svc = SSOService(MagicMock())
        >>> # Existing provider object
        >>> existing = SimpleNamespace(id='github', name='github', client_id='old', client_secret_encrypted='X', is_enabled=True)
        >>> svc.get_provider = lambda _id: existing
        >>> svc._encrypt_secret = AsyncMock(side_effect=lambda s: 'ENC-' + s)
        >>> svc.db.commit = lambda: None
        >>> svc.db.refresh = lambda obj: None
        >>> changes = {'client_id': 'new', 'client_secret': 'sec'}
        >>> updated = asyncio.run(svc.update_provider('github', changes))
        >>> updated.client_id
        'new'
        >>> updated.client_secret_encrypted
        'ENC-sec'
        >>> 'client_secret' in changes  # input dict is left untouched
        True
    """
    provider = self.get_provider(provider_id)
    if not provider:
        return None

    # Shallow copy so replacing the secret below does not alter the
    # caller's dict.
    updates = dict(provider_data)

    # Handle client secret encryption if provided
    if "client_secret" in updates:
        updates["client_secret_encrypted"] = await self._encrypt_secret(updates.pop("client_secret"))

    for key, value in updates.items():
        # Only known provider attributes are applied; unknown keys are skipped.
        if hasattr(provider, key):
            setattr(provider, key, value)

    provider.updated_at = utc_now()
    self.db.commit()
    self.db.refresh(provider)
    return provider

275 

def delete_provider(self, provider_id: str) -> bool:
    """Remove an SSO provider configuration if it exists.

    Args:
        provider_id: Provider identifier

    Returns:
        True if deleted, False if not found

    Examples:
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> svc = SSOService(MagicMock())
        >>> svc.db.delete = lambda obj: None
        >>> svc.db.commit = lambda: None
        >>> svc.get_provider = lambda _id: SimpleNamespace(id='github')
        >>> svc.delete_provider('github')
        True
        >>> svc.get_provider = lambda _id: None
        >>> svc.delete_provider('missing')
        False
    """
    target = self.get_provider(provider_id)
    if target:
        self.db.delete(target)
        self.db.commit()
        return True

    # Nothing matched the identifier, so the database is left untouched.
    return False

305 

def generate_pkce_challenge(self) -> Tuple[str, str]:
    """Create a PKCE verifier/challenge pair (S256 method) for OAuth 2.1.

    Returns:
        Tuple of (code_verifier, code_challenge)

    Examples:
        Generate verifier and challenge:
        >>> from unittest.mock import Mock
        >>> service = SSOService(Mock())
        >>> verifier, challenge = service.generate_pkce_challenge()
        >>> isinstance(verifier, str) and isinstance(challenge, str)
        True
        >>> len(verifier) >= 43
        True
        >>> len(challenge) >= 43
        True
    """
    # Verifier: 32 cryptographically random bytes, URL-safe base64 without padding.
    random_bytes = secrets.token_bytes(32)
    code_verifier = base64.urlsafe_b64encode(random_bytes).rstrip(b"=").decode("utf-8")

    # Challenge: SHA-256 digest of the verifier text, encoded the same way.
    digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
    code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("utf-8")

    return code_verifier, code_challenge

331 

def get_authorization_url(self, provider_id: str, redirect_uri: str, scopes: Optional[List[str]] = None) -> Optional[str]:
    """Generate OAuth authorization URL for provider.

    A new auth session holding the CSRF state, PKCE verifier, optional
    OIDC nonce and redirect URI is stored before the URL is returned.
    If the provider's authorization URL already carries a query string,
    the OAuth parameters are appended to it instead of starting a second one.

    Args:
        provider_id: Provider identifier
        redirect_uri: Callback URI after authorization
        scopes: Optional custom scopes (uses provider default if None)

    Returns:
        Authorization URL or None if provider not found or disabled

    Examples:
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> provider = SimpleNamespace(id='github', is_enabled=True, provider_type='oauth2', client_id='cid', authorization_url='https://example/auth', scope='user:email')
        >>> service.get_provider = lambda _pid: provider
        >>> service.db.add = lambda x: None
        >>> service.db.commit = lambda: None
        >>> url = service.get_authorization_url('github', 'https://app/callback', ['email'])
        >>> isinstance(url, str) and 'client_id=cid' in url and 'state=' in url
        True

        Existing query parameters on the authorization URL are retained:
        >>> provider.authorization_url = 'https://example/auth?tenant=t1'
        >>> url = service.get_authorization_url('github', 'https://app/callback')
        >>> url.startswith('https://example/auth?tenant=t1&client_id=cid')
        True

        Missing provider returns None:
        >>> service.get_provider = lambda _pid: None
        >>> service.get_authorization_url('missing', 'https://app/callback') is None
        True
    """
    provider = self.get_provider(provider_id)
    if not provider or not provider.is_enabled:
        return None

    # Generate PKCE parameters
    code_verifier, code_challenge = self.generate_pkce_challenge()

    # Generate CSRF state
    state = secrets.token_urlsafe(32)

    # Generate OIDC nonce if applicable
    nonce = secrets.token_urlsafe(16) if provider.provider_type == "oidc" else None

    # Persist the session first so the callback can look it up by state.
    auth_session = SSOAuthSession(provider_id=provider_id, state=state, code_verifier=code_verifier, nonce=nonce, redirect_uri=redirect_uri)
    self.db.add(auth_session)
    self.db.commit()

    # Build authorization URL
    params = {
        "client_id": provider.client_id,
        "response_type": "code",
        "redirect_uri": redirect_uri,
        "state": state,
        "scope": " ".join(scopes) if scopes else provider.scope,
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
    }

    if nonce:
        params["nonce"] = nonce

    # The endpoint URI may already include a query component, which must be
    # kept; pick the separator that continues it rather than adding a second "?".
    base_url = provider.authorization_url
    if "?" not in base_url:
        separator = "?"
    elif base_url.endswith(("?", "&")):
        separator = ""
    else:
        separator = "&"

    return f"{base_url}{separator}{urllib.parse.urlencode(params)}"

393 

async def handle_oauth_callback(self, provider_id: str, code: str, state: str) -> Optional[Dict[str, Any]]:
    """Handle OAuth callback and exchange code for tokens.

    The auth session matching ``state`` and ``provider_id`` is looked up and
    must be unexpired and belong to an enabled provider. Once that check
    passes the session is deleted whatever the outcome (success, failed
    token exchange, failed user-info lookup, or exception), so a state
    value cannot be presented twice.

    Args:
        provider_id: Provider identifier
        code: Authorization code from callback
        state: CSRF state parameter

    Returns:
        User info dict or None if authentication failed

    Examples:
        Happy-path with patched exchanges and user info:
        >>> import asyncio
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> svc = SSOService(MagicMock())
        >>> # Mock DB auth session lookup
        >>> provider = SimpleNamespace(id='github', is_enabled=True, provider_type='oauth2')
        >>> auth_session = SimpleNamespace(provider_id='github', state='st', provider=provider, is_expired=False)
        >>> svc.db.execute.return_value.scalar_one_or_none.return_value = auth_session
        >>> # Patch token exchange and user info retrieval
        >>> async def _ex(p, sess, c):
        ...     return {'access_token': 'tok', 'id_token': 'id_tok'}
        >>> async def _ui(p, access, token_data=None):
        ...     return {'email': 'user@example.com'}
        >>> svc._exchange_code_for_tokens = _ex
        >>> svc._get_user_info = _ui
        >>> svc.db.delete = lambda obj: None
        >>> svc.db.commit = lambda: None
        >>> out = asyncio.run(svc.handle_oauth_callback('github', 'code', 'st'))
        >>> out['email']
        'user@example.com'

        Early return cases:
        >>> # No session
        >>> svc2 = SSOService(MagicMock())
        >>> svc2.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> asyncio.run(svc2.handle_oauth_callback('github', 'c', 's')) is None
        True
        >>> # Expired session
        >>> expired = SimpleNamespace(provider_id='github', state='st', provider=SimpleNamespace(is_enabled=True), is_expired=True)
        >>> svc3 = SSOService(MagicMock())
        >>> svc3.db.execute.return_value.scalar_one_or_none.return_value = expired
        >>> asyncio.run(svc3.handle_oauth_callback('github', 'c', 'st')) is None
        True
        >>> # Disabled provider
        >>> disabled = SimpleNamespace(provider_id='github', state='st', provider=SimpleNamespace(is_enabled=False), is_expired=False)
        >>> svc4 = SSOService(MagicMock())
        >>> svc4.db.execute.return_value.scalar_one_or_none.return_value = disabled
        >>> asyncio.run(svc4.handle_oauth_callback('github', 'c', 'st')) is None
        True

        A failed token exchange still consumes the auth session:
        >>> svc5 = SSOService(MagicMock())
        >>> svc5.db.execute.return_value.scalar_one_or_none.return_value = auth_session
        >>> async def _fail(p, sess, c):
        ...     return None
        >>> svc5._exchange_code_for_tokens = _fail
        >>> asyncio.run(svc5.handle_oauth_callback('github', 'code', 'st')) is None
        True
        >>> svc5.db.delete.assert_called_once_with(auth_session)
    """
    # Validate auth session
    stmt = select(SSOAuthSession).where(SSOAuthSession.state == state, SSOAuthSession.provider_id == provider_id)
    auth_session = self.db.execute(stmt).scalar_one_or_none()

    if not auth_session or auth_session.is_expired:
        return None

    provider = auth_session.provider
    if not provider or not provider.is_enabled:
        return None

    try:
        user_info = None

        # Exchange authorization code for tokens
        logger.info(f"Starting token exchange for provider {provider_id}")
        token_data = await self._exchange_code_for_tokens(provider, auth_session, code)
        if token_data:
            logger.info(f"Token exchange successful for provider {provider_id}")

            # Get user info from provider (pass full token_data for id_token parsing)
            user_info = await self._get_user_info(provider, token_data["access_token"], token_data)
            if not user_info:
                logger.error(f"Failed to get user info for provider {provider_id}")
        else:
            logger.error(f"Failed to exchange code for tokens for provider {provider_id}")

        # Clean up auth session on every outcome, not just success: the
        # state/PKCE record is single-use and must not linger after a failure.
        self.db.delete(auth_session)
        self.db.commit()

        # Empty/absent user info is reported as None.
        return user_info or None

    except Exception as e:
        # Clean up auth session on error
        logger.error(f"OAuth callback failed for provider {provider_id}: {type(e).__name__}: {str(e)}")
        logger.exception("Full traceback for OAuth callback failure:")
        self.db.delete(auth_session)
        self.db.commit()
        return None

486 

async def _exchange_code_for_tokens(self, provider: SSOProvider, auth_session: SSOAuthSession, code: str) -> Optional[Dict[str, Any]]:
    """Trade an authorization code for the provider's token response.

    Posts an ``authorization_code`` grant to the provider's token URL,
    including the decrypted client secret and the PKCE verifier saved on
    the auth session.

    Args:
        provider: SSO provider configuration
        auth_session: Auth session with PKCE parameters
        code: Authorization code

    Returns:
        Token response dict or None if failed
    """
    client_id = provider.client_id
    client_secret = await self._decrypt_secret(provider.client_secret_encrypted)
    form_fields = {
        "client_id": client_id,
        "client_secret": client_secret,
        "code": code,
        "grant_type": "authorization_code",
        "redirect_uri": auth_session.redirect_uri,
        "code_verifier": auth_session.code_verifier,
    }

    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    http = await get_http_client()
    reply = await http.post(provider.token_url, data=form_fields, headers={"Accept": "application/json"})

    # Anything other than HTTP 200 counts as a failed exchange.
    if reply.status_code != 200:
        logger.error(f"Token exchange failed for {provider.name}: HTTP {reply.status_code} - {reply.text}")
        return None

    return reply.json()

518 

async def _get_user_info(self, provider: SSOProvider, access_token: str, token_data: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
    """Fetch the user's profile from the provider and normalize it.

    The provider's userinfo URL is queried with the access token as a
    bearer credential. Depending on the provider id the result is then
    enriched: GitHub organizations (only when admin orgs are configured),
    and for Entra ID / Keycloak, claims read from the ``id_token``.

    Args:
        provider: SSO provider configuration
        access_token: OAuth access token
        token_data: Optional full token response containing id_token for OIDC providers

    Returns:
        User info dict or None if failed
    """
    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    client = await get_http_client()
    response = await client.get(provider.userinfo_url, headers={"Authorization": f"Bearer {access_token}"})

    # Bail out early on any non-200 answer from the userinfo endpoint.
    if response.status_code != 200:
        logger.error(f"User info request failed for {provider.name}: HTTP {response.status_code} - {response.text}")
        return None

    user_data = response.json()

    # For GitHub, also fetch organizations if admin assignment is configured
    if provider.id == "github" and settings.sso_github_admin_orgs:
        try:
            orgs_reply = await client.get("https://api.github.com/user/orgs", headers={"Authorization": f"Bearer {access_token}"})
            if orgs_reply.status_code != 200:
                logger.warning(f"Failed to fetch GitHub organizations: HTTP {orgs_reply.status_code}")
                user_data["organizations"] = []
            else:
                user_data["organizations"] = [org["login"] for org in orgs_reply.json()]
        except Exception as e:
            # Best effort: an organization lookup failure must not fail the login.
            logger.warning(f"Error fetching GitHub organizations: {e}")
            user_data["organizations"] = []

    # For Entra ID, extract groups/roles from id_token since userinfo doesn't include them
    # Microsoft's /oidc/userinfo endpoint only returns basic claims (sub, name, email, picture)
    # Groups and roles are included in the id_token when configured in Azure Portal
    if provider.id == "entra" and token_data and "id_token" in token_data:
        entra_claims = self._decode_jwt_claims(token_data["id_token"])
        if entra_claims:
            # Detect group overage - when user has too many groups (>200), EntraID returns
            # _claim_names/_claim_sources instead of the actual groups array.
            # See: https://learn.microsoft.com/en-us/entra/identity-platform/id-token-claims-reference
            overage_marker = entra_claims.get("_claim_names", {})
            if isinstance(overage_marker, dict) and "groups" in overage_marker:
                user_email = user_data.get("email") or user_data.get("preferred_username") or "unknown"
                logger.warning(
                    f"Group overage detected for user {user_email} - token contains too many groups (>200). "
                    f"Role mapping may be incomplete. Consider using App Roles or Azure group filtering. "
                    f"See docs/docs/manage/sso-entra-role-mapping.md#token-size-considerations"
                )

            # Extract groups from id_token (Security Groups as Object IDs)
            if "groups" in entra_claims:
                user_data["groups"] = entra_claims["groups"]
                logger.debug(f"Extracted {len(entra_claims['groups'])} groups from Entra ID token")

            # Extract roles from id_token (App Roles)
            if "roles" in entra_claims:
                user_data["roles"] = entra_claims["roles"]
                logger.debug(f"Extracted {len(entra_claims['roles'])} roles from Entra ID token")

            # Fill in basic claims the userinfo response lacked; existing values win.
            for claim in ("email", "name", "preferred_username", "oid", "sub"):
                if claim in entra_claims:
                    user_data.setdefault(claim, entra_claims[claim])

    # For Keycloak, also extract groups/roles from id_token if available
    if provider.id == "keycloak" and token_data and "id_token" in token_data:
        keycloak_claims = self._decode_jwt_claims(token_data["id_token"])
        if keycloak_claims:
            # Keycloak includes realm_access, resource_access, and groups in id_token
            for claim in ("realm_access", "resource_access", "groups"):
                if claim in keycloak_claims:
                    user_data.setdefault(claim, keycloak_claims[claim])

    # Normalize user info across providers
    return self._normalize_user_info(provider, user_data)

600 

601 def _normalize_user_info(self, provider: SSOProvider, user_data: Dict[str, Any]) -> Dict[str, Any]: 

602 """Normalize user info from different providers to common format. 

603 

604 Args: 

605 provider: SSO provider configuration 

606 user_data: Raw user data from provider 

607 

608 Returns: 

609 Normalized user info dict 

610 """ 

611 # Handle GitHub provider 

612 if provider.id == "github": 

613 return { 

614 "email": user_data.get("email"), 

615 "full_name": user_data.get("name") or user_data.get("login"), 

616 "avatar_url": user_data.get("avatar_url"), 

617 "provider_id": user_data.get("id"), 

618 "username": user_data.get("login"), 

619 "provider": "github", 

620 "organizations": user_data.get("organizations", []), 

621 } 

622 

623 # Handle Google provider 

624 if provider.id == "google": 

625 return { 

626 "email": user_data.get("email"), 

627 "full_name": user_data.get("name"), 

628 "avatar_url": user_data.get("picture"), 

629 "provider_id": user_data.get("sub"), 

630 "username": user_data.get("email", "").split("@")[0], 

631 "provider": "google", 

632 } 

633 

634 # Handle IBM Verify provider 

635 if provider.id == "ibm_verify": 

636 return { 

637 "email": user_data.get("email"), 

638 "full_name": user_data.get("name"), 

639 "avatar_url": user_data.get("picture"), 

640 "provider_id": user_data.get("sub"), 

641 "username": user_data.get("preferred_username") or user_data.get("email", "").split("@")[0], 

642 "provider": "ibm_verify", 

643 } 

644 

645 # Handle Okta provider 

646 if provider.id == "okta": 

647 return { 

648 "email": user_data.get("email"), 

649 "full_name": user_data.get("name"), 

650 "avatar_url": user_data.get("picture"), 

651 "provider_id": user_data.get("sub"), 

652 "username": user_data.get("preferred_username") or user_data.get("email", "").split("@")[0], 

653 "provider": "okta", 

654 } 

655 

656 # Handle Keycloak provider with role mapping 

657 if provider.id == "keycloak": 

658 metadata = provider.provider_metadata or {} 

659 username_claim = metadata.get("username_claim", "preferred_username") 

660 email_claim = metadata.get("email_claim", "email") 

661 groups_claim = metadata.get("groups_claim", "groups") 

662 

663 groups = [] 

664 

665 # Extract realm roles 

666 if metadata.get("map_realm_roles"): 

667 realm_access = user_data.get("realm_access", {}) 

668 realm_roles = realm_access.get("roles", []) 

669 groups.extend(realm_roles) 

670 

671 # Extract client roles 

672 if metadata.get("map_client_roles"): 

673 resource_access = user_data.get("resource_access", {}) 

674 for client, access in resource_access.items(): 

675 client_roles = access.get("roles", []) 

676 # Prefix with client name to avoid conflicts 

677 groups.extend([f"{client}:{role}" for role in client_roles]) 

678 

679 # Extract groups from custom claim 

680 if groups_claim in user_data: 

681 custom_groups = user_data.get(groups_claim, []) 

682 if isinstance(custom_groups, list): 682 ↛ 685line 682 didn't jump to line 685 because the condition on line 682 was always true

683 groups.extend(custom_groups) 

684 

685 return { 

686 "email": user_data.get(email_claim), 

687 "full_name": user_data.get("name"), 

688 "avatar_url": user_data.get("picture"), 

689 "provider_id": user_data.get("sub"), 

690 "username": user_data.get(username_claim) or user_data.get(email_claim, "").split("@")[0], 

691 "provider": "keycloak", 

692 "groups": list(set(groups)), # Deduplicate 

693 } 

694 

695 # Handle Microsoft Entra ID provider with role mapping 

696 if provider.id == "entra": 

697 metadata = provider.provider_metadata or {} 

698 groups_claim = metadata.get("groups_claim", "groups") 

699 

700 # Microsoft's userinfo endpoint often omits the email claim 

701 # Fallback: preferred_username (UPN) or upn claim 

702 email = user_data.get("email") or user_data.get("preferred_username") or user_data.get("upn") 

703 

704 # Extract username from email/UPN 

705 username = None 

706 if user_data.get("preferred_username"): 

707 username = user_data.get("preferred_username") 

708 elif email: 

709 username = email.split("@")[0] 

710 

711 # Extract groups from token 

712 groups = [] 

713 

714 # Check configured groups claim (default: 'groups') 

715 if groups_claim in user_data: 

716 groups_value = user_data.get(groups_claim, []) 

717 if isinstance(groups_value, list): 717 ↛ 721line 717 didn't jump to line 721 because the condition on line 717 was always true

718 groups.extend(groups_value) 

719 

720 # Also check 'roles' claim for App Role assignments 

721 if "roles" in user_data: 

722 roles_value = user_data.get("roles", []) 

723 if isinstance(roles_value, list): 723 ↛ 726line 723 didn't jump to line 726 because the condition on line 723 was always true

724 groups.extend(roles_value) 

725 

726 return { 

727 "email": email, 

728 "full_name": user_data.get("name") or email, # Fallback to email if name missing 

729 "avatar_url": user_data.get("picture"), 

730 "provider_id": user_data.get("sub") or user_data.get("oid"), 

731 "username": username, 

732 "provider": "entra", 

733 "groups": list(set(groups)), # Deduplicate 

734 } 

735 

736 # Generic OIDC format for all other providers 

737 return { 

738 "email": user_data.get("email"), 

739 "full_name": user_data.get("name"), 

740 "avatar_url": user_data.get("picture"), 

741 "provider_id": user_data.get("sub"), 

742 "username": user_data.get("preferred_username") or user_data.get("email", "").split("@")[0], 

743 "provider": provider.id, 

744 } 

745 

746 async def authenticate_or_create_user(self, user_info: Dict[str, Any]) -> Optional[str]: 

747 """Authenticate existing user or create new user from SSO info. 

748 

749 Args: 

750 user_info: Normalized user info from SSO provider 

751 

752 Returns: 

753 JWT token for authenticated user or None if failed 

754 """ 

755 email = user_info.get("email") 

756 if not email: 

757 return None 

758 

759 # Check if user exists 

760 user = await self.auth_service.get_user_by_email(email) 

761 

762 if user: 

763 # Update user info from SSO 

764 if user_info.get("full_name") and user_info["full_name"] != user.full_name: 

765 user.full_name = user_info["full_name"] 

766 

767 # Update auth provider if changed 

768 if user.auth_provider == "local" or user.auth_provider != user_info.get("provider"): 

769 user.auth_provider = user_info.get("provider", "sso") 

770 

771 # Mark email as verified for SSO users 

772 user.email_verified = True 

773 user.last_login = utc_now() 

774 

775 # Synchronize is_admin status based on current group membership 

776 # Track origin to support both promotion AND demotion for SSO-granted admins 

777 # Manual/API grants are "sticky" - never auto-demoted by SSO 

778 # Only users with admin_origin="sso" can be demoted on login 

779 provider = self.get_provider(user_info.get("provider")) 

780 if provider: 780 ↛ 796line 780 didn't jump to line 796 because the condition on line 780 was always true

781 should_be_admin = self._should_user_be_admin(email, user_info, provider) 

782 if should_be_admin: 

783 # Grant admin access 

784 if not user.is_admin: 

785 logger.info(f"Upgrading is_admin to True for {email} based on SSO admin groups") 

786 user.is_admin = True 

787 # Track that admin was granted via SSO (only set on initial grant) 

788 user.admin_origin = "sso" 

789 # Do NOT change admin_origin if already admin - preserve manual/API grants 

790 elif user.is_admin and user.admin_origin == "sso": 

791 # User was SSO admin but no longer in admin groups - revoke access 

792 logger.info(f"Revoking is_admin for {email} - removed from SSO admin groups") 

793 user.is_admin = False 

794 user.admin_origin = None 

795 

796 self.db.commit() 

797 

798 # Determine if syncing should happen (default True, respect provider-level and Entra setting) 

799 should_sync = True 

800 if provider: 800 ↛ 809line 800 didn't jump to line 809 because the condition on line 800 was always true

801 # Check provider-level sync_roles flag in provider_metadata (allows disabling per-provider) 

802 metadata = provider.provider_metadata or {} 

803 if "sync_roles" in metadata: 

804 should_sync = metadata.get("sync_roles", True) 

805 # Legacy Entra-specific setting (fallback for backwards compatibility) 

806 elif provider.id == "entra" and hasattr(settings, "sso_entra_sync_roles_on_login"): 

807 should_sync = settings.sso_entra_sync_roles_on_login 

808 

809 if provider and should_sync: 809 ↛ 893line 809 didn't jump to line 893 because the condition on line 809 was always true

810 role_assignments = await self._map_groups_to_roles(email, user_info.get("groups", []), provider) 

811 await self._sync_user_roles(email, role_assignments, provider) 

812 else: 

813 # Auto-create user if enabled 

814 provider = self.get_provider(user_info.get("provider")) 

815 if not provider or not provider.auto_create_users: 

816 return None 

817 

818 # Check trusted domains if configured 

819 if provider.trusted_domains: 

820 domain = email.split("@")[1].lower() 

821 if domain not in [d.lower() for d in provider.trusted_domains]: 821 ↛ 825line 821 didn't jump to line 825 because the condition on line 821 was always true

822 return None 

823 

824 # Check if admin approval is required 

825 if settings.sso_require_admin_approval: 

826 # Check if user is already pending approval 

827 

828 pending = self.db.execute(select(PendingUserApproval).where(PendingUserApproval.email == email)).scalar_one_or_none() 

829 

830 if pending: 

831 if pending.status == "pending" and not pending.is_expired(): 

832 return None # Still waiting for approval 

833 if pending.status == "rejected": 

834 return None # User was rejected 

835 if pending.status == "approved": 835 ↛ 856line 835 didn't jump to line 856 because the condition on line 835 was always true

836 # User was approved, create account now 

837 pass # Continue with user creation below 

838 else: 

839 # Create pending approval request 

840 

841 pending = PendingUserApproval( 

842 email=email, 

843 full_name=user_info.get("full_name", email), 

844 auth_provider=user_info.get("provider", "sso"), 

845 sso_metadata=user_info, 

846 expires_at=utc_now() + timedelta(days=30), # 30-day approval window 

847 ) 

848 self.db.add(pending) 

849 self.db.commit() 

850 logger.info(f"Created pending approval request for SSO user: {email}") 

851 return None # No token until approved 

852 

853 # Create new user (either no approval required, or approval already granted) 

854 # Generate a secure random password for SSO users (they won't use it) 

855 

856 random_password = "".join(secrets.choice(string.ascii_letters + string.digits + string.punctuation) for _ in range(32)) 

857 

858 # Determine if user should be admin based on domain/organization 

859 is_admin = self._should_user_be_admin(email, user_info, provider) 

860 

861 user = await self.auth_service.create_user( 

862 email=email, 

863 password=random_password, # Random password for SSO users (not used) 

864 full_name=user_info.get("full_name", email), 

865 is_admin=is_admin, 

866 auth_provider=user_info.get("provider", "sso"), 

867 ) 

868 if not user: 

869 return None 

870 

871 # Assign RBAC roles based on SSO groups (or default role if no groups) 

872 # Check provider-level sync_roles flag in provider_metadata 

873 metadata = provider.provider_metadata or {} 

874 should_sync = metadata.get("sync_roles", True) 

875 # Legacy Entra-specific setting (fallback for backwards compatibility) 

876 if "sync_roles" not in metadata and provider.id == "entra" and hasattr(settings, "sso_entra_sync_roles_on_login"): 876 ↛ 877line 876 didn't jump to line 877 because the condition on line 876 was never true

877 should_sync = settings.sso_entra_sync_roles_on_login 

878 

879 if should_sync: 879 ↛ 885line 879 didn't jump to line 885 because the condition on line 879 was always true

880 role_assignments = await self._map_groups_to_roles(email, user_info.get("groups", []), provider) 

881 if role_assignments: 881 ↛ 882line 881 didn't jump to line 882 because the condition on line 881 was never true

882 await self._sync_user_roles(email, role_assignments, provider) 

883 

884 # If user was created from approved request, mark request as used 

885 if settings.sso_require_admin_approval: 

886 pending = self.db.execute(select(PendingUserApproval).where(and_(PendingUserApproval.email == email, PendingUserApproval.status == "approved"))).scalar_one_or_none() 

887 if pending: 887 ↛ 893line 887 didn't jump to line 893 because the condition on line 887 was always true

888 # Mark as used (we could delete or keep for audit trail) 

889 pending.status = "completed" 

890 self.db.commit() 

891 

892 # Generate JWT token for user — session token (teams resolved server-side) 

893 token_data = { 

894 "sub": user.email, 

895 "email": user.email, 

896 "full_name": user.full_name, 

897 "auth_provider": user.auth_provider, 

898 "iat": int(utc_now().timestamp()), 

899 "user": {"email": user.email, "full_name": user.full_name, "is_admin": user.is_admin, "auth_provider": user.auth_provider}, 

900 "token_use": "session", # nosec B105 - token type marker, not a password 

901 # Scopes 

902 "scopes": {"server_id": None, "permissions": ["*"] if user.is_admin else [], "ip_restrictions": [], "time_restrictions": {}}, 

903 } 

904 

905 # Create JWT token 

906 token = await create_jwt_token(token_data) 

907 return token 

908 

def _should_user_be_admin(self, email: str, user_info: Dict[str, Any], provider: SSOProvider) -> bool:
    """Determine if SSO user should be granted admin privileges.

    A user qualifies when any of the following holds:

    * the email domain is listed in ``settings.sso_auto_admin_domains``;
    * the provider is ``github`` and one of the user's ``organizations``
      is listed in ``settings.sso_github_admin_orgs``;
    * the provider is ``google`` and the email domain is listed in
      ``settings.sso_google_admin_domains``;
    * the provider is ``entra`` and one of the user's ``groups`` is
      listed in ``settings.sso_entra_admin_groups``.

    All comparisons are case-insensitive.

    Args:
        email: User's email address
        user_info: Normalized user info from SSO provider
        provider: SSO provider configuration

    Returns:
        True if user should be admin, False otherwise. An email without a
        domain part never matches a domain-based rule.
    """
    # The domain is the text after the LAST "@". Taking element [1] of a
    # split would pick the middle segment of "a@admin.example@evil.example"
    # and raise IndexError for an address with no "@" at all.
    _, at_sign, raw_domain = email.rpartition("@")
    domain = raw_domain.lower() if at_sign else ""

    # Check domain-based admin assignment
    if domain and domain in {d.lower() for d in settings.sso_auto_admin_domains}:
        return True

    # Check provider-specific admin assignment
    if provider.id == "github" and settings.sso_github_admin_orgs:
        # Relies on "organizations" being present in user_info; when the
        # key is absent no organization can match.
        admin_orgs = {o.lower() for o in settings.sso_github_admin_orgs}
        if any(org.lower() in admin_orgs for org in user_info.get("organizations", [])):
            return True

    if provider.id == "google" and settings.sso_google_admin_domains:
        # Check if user's domain is in admin domains
        if domain and domain in {d.lower() for d in settings.sso_google_admin_domains}:
            return True

    # Check EntraID admin groups
    if provider.id == "entra" and settings.sso_entra_admin_groups:
        admin_groups = {g.lower() for g in settings.sso_entra_admin_groups}
        if any(group.lower() in admin_groups for group in user_info.get("groups", [])):
            return True

    return False

945 

async def _map_groups_to_roles(self, user_email: str, user_groups: List[str], provider: SSOProvider) -> List[Dict[str, Any]]:
    """Map SSO groups to Context Forge RBAC roles.

    Mappings come from ``provider.provider_metadata["role_mappings"]``
    (group -> role name). For the ``entra`` provider the legacy settings
    ``sso_entra_role_mappings``, ``sso_entra_admin_groups`` and
    ``sso_entra_default_role`` are honoured as fallbacks. The role names
    ``"admin"`` and ``"platform_admin"`` are shorthands for the global
    ``platform_admin`` role and are not looked up through the role service.

    Each role name appears at most once in the result, including
    ``platform_admin`` when several groups map to it.

    Args:
        user_email: User's email address
        user_groups: List of groups from SSO provider (``None`` is treated as empty)
        provider: SSO provider configuration

    Returns:
        List of role assignments: [{"role_name": str, "scope": str, "scope_id": Optional[str]}]
    """
    # pylint: disable=import-outside-toplevel
    # First-Party
    from mcpgateway.services.role_service import RoleService

    admin_aliases = ("admin", "platform_admin")
    user_groups = user_groups or []

    role_assignments: List[Dict[str, Any]] = []
    assigned_role_names = set()

    def add_assignment(role_name: str, scope: str) -> bool:
        """Append an assignment unless that role name is already present.

        Args:
            role_name: Name of the role to assign
            scope: Scope recorded on the assignment

        Returns:
            True if the assignment was appended, False if it was a duplicate
        """
        if role_name in assigned_role_names:
            return False
        assigned_role_names.add(role_name)
        role_assignments.append({"role_name": role_name, "scope": scope, "scope_id": None})
        return True

    # Generic Role Mapping Logic
    metadata = provider.provider_metadata or {}
    role_mappings = metadata.get("role_mappings", {})

    # Merge with legacy Entra specific settings if applicable
    is_entra = provider.id == "entra"
    has_entra_admin_groups = is_entra and settings.sso_entra_admin_groups
    has_entra_default_role = is_entra and settings.sso_entra_default_role

    # Provider-level role_mappings win; the legacy setting is only a fallback
    if is_entra and not role_mappings and settings.sso_entra_role_mappings:
        role_mappings = settings.sso_entra_role_mappings

    # Early exit: Skip role mapping if no configuration exists
    if not role_mappings and not has_entra_admin_groups and not has_entra_default_role:
        logger.debug(f"No role mappings configured for provider {provider.id}, skipping role sync")
        return role_assignments

    # Handle EntraID admin groups -> platform_admin
    if has_entra_admin_groups:
        admin_groups_lower = {g.lower() for g in settings.sso_entra_admin_groups}
        if any(group.lower() in admin_groups_lower for group in user_groups):
            add_assignment("platform_admin", "global")
            logger.debug(f"Mapped EntraID admin group to platform_admin role for {user_email}")

    # Role names the user's groups map to, in group order (one entry per mapped group)
    mapped_role_names = [role_mappings[group] for group in user_groups if group in role_mappings]

    # Batch role lookups: collect all role names that need to be looked up
    role_names_to_lookup = {name for name in mapped_role_names if name not in admin_aliases}

    # Add default role to lookup if needed
    if has_entra_default_role:
        role_names_to_lookup.add(settings.sso_entra_default_role)

    # Pre-fetch each distinct role once so repeated groups do not repeat lookups
    role_service = RoleService(self.db)
    role_cache: Dict[str, Any] = {}
    for role_name in role_names_to_lookup:
        # Try team scope first, then global
        role = await role_service.get_role_by_name(role_name, scope="team")
        if not role:
            role = await role_service.get_role_by_name(role_name, scope="global")
        if role:
            role_cache[role_name] = role

    # Process role mappings for ALL providers
    for role_name in mapped_role_names:
        # Special case for "admin"/"platform_admin" shorthand
        if role_name in admin_aliases:
            if add_assignment("platform_admin", "global"):
                logger.debug(f"Mapped group to platform_admin role for {user_email}")
            continue

        # Use pre-fetched role from cache
        role = role_cache.get(role_name)
        if not role:
            logger.warning(f"Role '{role_name}' not found for group mapping")
            continue

        # Avoid duplicate assignments
        if add_assignment(role.name, role.scope):
            logger.debug(f"Mapped group to role '{role.name}' for {user_email}")

    # Apply default role if no mappings found (Entra legacy fallback)
    if not role_assignments and has_entra_default_role:
        default_role = role_cache.get(settings.sso_entra_default_role)
        if default_role:
            add_assignment(default_role.name, default_role.scope)
            logger.info(f"Assigned default role '{default_role.name}' to {user_email}")

    return role_assignments

1041 

async def _sync_user_roles(self, user_email: str, role_assignments: List[Dict[str, Any]], _provider: SSOProvider) -> None:
    """Reconcile the user's SSO-granted roles with the desired assignments.

    Roles previously granted with ``granted_by == "sso_system"`` that are
    absent from ``role_assignments`` are revoked first. Every requested
    assignment that is missing or inactive is then granted with
    ``granted_by="sso_system"``. A failure while granting one role is
    logged and does not stop the remaining assignments; errors raised
    while revoking are not caught here.

    Args:
        user_email: User's email address
        role_assignments: List of role assignments to apply
        _provider: SSO provider configuration (reserved for future use)
    """
    # pylint: disable=import-outside-toplevel
    # First-Party
    from mcpgateway.services.role_service import RoleService

    roles = RoleService(self.db)

    # Only roles this sync granted earlier are candidates for revocation
    active_roles = await roles.list_user_roles(user_email, include_expired=False)
    sso_granted = []
    for entry in active_roles:
        if entry.granted_by == "sso_system":
            sso_granted.append(entry)

    # (role name, scope, scope_id) triples the user should end up with
    wanted = set()
    for spec in role_assignments:
        wanted.add((spec["role_name"], spec["scope"], spec.get("scope_id")))

    # Revoke roles that are no longer in the desired set
    for granted in sso_granted:
        key = (granted.role.name, granted.scope, granted.scope_id)
        if key in wanted:
            continue
        await roles.revoke_role_from_user(
            user_email=user_email,
            role_id=granted.role_id,
            scope=granted.scope,
            scope_id=granted.scope_id,
        )
        logger.info(f"Revoked SSO role '{granted.role.name}' from {user_email} (no longer in groups)")

    # Assign new roles
    for spec in role_assignments:
        try:
            # Get role by name
            role = await roles.get_role_by_name(spec["role_name"], scope=spec["scope"])
            if not role:
                logger.warning(f"Role '{spec['role_name']}' not found, skipping assignment for {user_email}")
                continue

            # Check if assignment already exists
            existing = await roles.get_user_role_assignment(
                user_email=user_email,
                role_id=role.id,
                scope=spec["scope"],
                scope_id=spec.get("scope_id"),
            )
            if existing and existing.is_active:
                # Already in place; nothing to grant
                continue

            # Assign role to user
            await roles.assign_role_to_user(
                user_email=user_email,
                role_id=role.id,
                scope=spec["scope"],
                scope_id=spec.get("scope_id"),
                granted_by="sso_system",
            )
            logger.info(f"Assigned SSO role '{role.name}' to {user_email}")

        except Exception as e:
            logger.warning(f"Failed to assign role '{spec['role_name']}' to {user_email}: {e}")