Coverage for mcpgateway / services / sso_service.py: 98%

888 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/sso_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Single Sign-On (SSO) authentication service for OAuth2 and OIDC providers. 

8Handles provider management, OAuth flows, and user authentication. 

9""" 

10 

11# Future 

12from __future__ import annotations 

13 

14# Standard 

15import asyncio 

16import base64 

17from dataclasses import dataclass 

18from datetime import timedelta 

19import hashlib 

20import hmac 

21import logging 

22import secrets 

23import string 

24from time import monotonic 

25from typing import Any, Dict, List, Optional, Tuple 

26import urllib.parse 

27 

28# Third-Party 

29import jwt 

30import orjson 

31from sqlalchemy import and_, select 

32from sqlalchemy.orm import Session 

33 

34# First-Party 

35from mcpgateway.config import settings 

36from mcpgateway.db import PendingUserApproval, SSOAuthSession, SSOProvider, utc_now 

37from mcpgateway.services.email_auth_service import EmailAuthService 

38from mcpgateway.services.encryption_service import get_encryption_service 

39from mcpgateway.utils.create_jwt_token import create_jwt_token 

40 

41# Logger 

42logger = logging.getLogger(__name__) 

43 

44 

@dataclass
class SSOProviderContext:
    """Lightweight context for SSO provider info passed to helper methods.

    Attributes:
        id: Provider identifier, or ``None`` when no identifier is available.
        provider_metadata: Provider metadata mapping.
    """

    id: Optional[str]
    provider_metadata: Dict[str, Any]

51 

52 

53class SSOService: 

54 """Service for managing SSO authentication flows and providers. 

55 

56 Handles OAuth2/OIDC authentication flows, provider configuration, 

57 and integration with the local user system. 

58 

59 Examples: 

60 Basic construction and helper checks: 

61 >>> from unittest.mock import Mock 

62 >>> service = SSOService(Mock()) 

63 >>> isinstance(service, SSOService) 

64 True 

65 >>> callable(service.list_enabled_providers) 

66 True 

67 """ 

68 

69 _OIDC_METADATA_CACHE_TTL_SECONDS = 300 

70 _oidc_config_cache: Dict[str, Tuple[float, Dict[str, Any]]] = {} 

71 _jwks_client_cache: Dict[str, jwt.PyJWKClient] = {} 

72 _STATE_BINDING_SEPARATOR = "." 

73 _STATE_BINDING_HEX_LEN = 64 

74 

def __init__(self, db: Session):
    """Initialize SSO service with database session.

    Stores the session, builds an ``EmailAuthService`` on the same session,
    and obtains the encryption service keyed by
    ``settings.auth_encryption_secret``.

    Args:
        db: SQLAlchemy database session
    """
    self.db = db
    self.auth_service = EmailAuthService(db)
    self._encryption = get_encryption_service(settings.auth_encryption_secret)

84 

85 async def _encrypt_secret(self, secret: str) -> str: 

86 """Encrypt a client secret for secure storage. 

87 

88 Args: 

89 secret: Plain text client secret 

90 

91 Returns: 

92 Encrypted secret string 

93 """ 

94 return await self._encryption.encrypt_secret_async(secret) 

95 

96 async def _decrypt_secret(self, encrypted_secret: str) -> Optional[str]: 

97 """Decrypt a client secret for use. 

98 

99 Args: 

100 encrypted_secret: Encrypted secret string 

101 

102 Returns: 

103 Plain text client secret 

104 """ 

105 decrypted: str | None = await self._encryption.decrypt_secret_async(encrypted_secret) 

106 if decrypted: 

107 return decrypted 

108 

109 return None 

110 

def _decode_jwt_claims(self, token: str) -> Optional[Dict[str, Any]]:
    """Decode JWT token payload without verification.

    This is used to extract claims from ID tokens where we've already
    validated the OAuth flow. The token signature is NOT verified here;
    callers must only pass tokens received directly from the trusted
    token endpoint.

    Args:
        token: JWT token string

    Returns:
        Decoded payload dict, or None if the token is malformed, cannot be
        decoded, or its payload is not a JSON object.

    Examples:
        >>> from unittest.mock import Mock
        >>> service = SSOService(Mock())
        >>> # Valid JWT structure (header.payload.signature)
        >>> import base64
        >>> payload = base64.urlsafe_b64encode(b'{"sub":"123","groups":["admin"]}').decode().rstrip('=')
        >>> token = f"eyJhbGciOiJSUzI1NiJ9.{payload}.signature"
        >>> claims = service._decode_jwt_claims(token)
        >>> claims is not None
        True
    """
    try:
        # JWT format: header.payload.signature
        parts = token.split(".")
        if len(parts) != 3:
            logger.warning("Invalid JWT format: expected 3 parts")
            return None

        # JWTs strip base64 padding; restore it to a multiple of four.
        payload_b64 = parts[1]
        payload_b64 += "=" * (-len(payload_b64) % 4)

        claims = orjson.loads(base64.urlsafe_b64decode(payload_b64))
    except (ValueError, orjson.JSONDecodeError, UnicodeDecodeError) as e:
        logger.warning("Failed to decode JWT claims: %s", e)
        return None

    # Valid JSON that is not an object (list, string, number) is not a claim set.
    if not isinstance(claims, dict):
        logger.warning("Failed to decode JWT claims: payload is not a JSON object")
        return None
    return claims

155 

156 async def _get_oidc_provider_metadata(self, issuer: str) -> Optional[Dict[str, Any]]: 

157 """Discover and cache OIDC provider metadata. 

158 

159 Args: 

160 issuer: OIDC issuer URL. 

161 

162 Returns: 

163 Provider metadata dict from discovery endpoint, or None on failure. 

164 """ 

165 normalized_issuer = issuer.rstrip("/") 

166 cached = self._oidc_config_cache.get(normalized_issuer) 

167 if cached is not None: 

168 cached_at, cached_metadata = cached 

169 if monotonic() - cached_at < self._OIDC_METADATA_CACHE_TTL_SECONDS: 

170 return cached_metadata 

171 self._oidc_config_cache.pop(normalized_issuer, None) 

172 

173 # First-Party 

174 from mcpgateway.services.http_client_service import get_http_client # pylint: disable=import-outside-toplevel 

175 

176 discovery_url = f"{normalized_issuer}/.well-known/openid-configuration" 

177 try: 

178 client = await get_http_client() 

179 response = await client.get(discovery_url, timeout=settings.oauth_request_timeout) 

180 if response.status_code != 200: 

181 logger.warning("OIDC discovery failed for issuer %s with HTTP %s", normalized_issuer, response.status_code) 

182 return None 

183 

184 metadata = response.json() 

185 if not isinstance(metadata, dict): 

186 logger.warning("OIDC discovery response for issuer %s is not a JSON object", normalized_issuer) 

187 return None 

188 self._oidc_config_cache[normalized_issuer] = (monotonic(), metadata) 

189 return metadata 

190 except Exception as exc: 

191 logger.warning("OIDC discovery request failed for issuer %s: %s", normalized_issuer, exc) 

192 return None 

193 

194 async def _resolve_oidc_issuer_and_jwks(self, provider: SSOProvider) -> Tuple[Optional[str], Optional[str]]: 

195 """Resolve issuer and JWKS URI for an OIDC provider. 

196 

197 Args: 

198 provider: SSO provider configuration. 

199 

200 Returns: 

201 Tuple of (issuer, jwks_uri), each optional when unavailable. 

202 """ 

203 issuer = provider.issuer.strip() if isinstance(provider.issuer, str) and provider.issuer.strip() else None 

204 jwks_uri = provider.jwks_uri.strip() if isinstance(provider.jwks_uri, str) and provider.jwks_uri.strip() else None 

205 

206 if issuer and not jwks_uri: 

207 metadata = await self._get_oidc_provider_metadata(issuer) 

208 if metadata: 

209 discovered_jwks = metadata.get("jwks_uri") 

210 discovered_issuer = metadata.get("issuer") 

211 if isinstance(discovered_jwks, str) and discovered_jwks.strip(): 

212 jwks_uri = discovered_jwks.strip() 

213 if isinstance(discovered_issuer, str) and discovered_issuer.strip(): 

214 issuer = discovered_issuer.strip() 

215 

216 return issuer, jwks_uri 

217 

218 def _get_jwks_client(self, jwks_uri: str) -> jwt.PyJWKClient: 

219 """Get or create a cached PyJWKClient instance. 

220 

221 Args: 

222 jwks_uri: JWKS endpoint URL. 

223 

224 Returns: 

225 Cached or newly created `PyJWKClient`. 

226 """ 

227 if jwks_uri not in self._jwks_client_cache: 

228 self._jwks_client_cache[jwks_uri] = jwt.PyJWKClient(jwks_uri) 

229 return self._jwks_client_cache[jwks_uri] 

230 

231 async def _verify_oidc_id_token(self, provider: SSOProvider, id_token: str, expected_nonce: Optional[str] = None) -> Optional[Dict[str, Any]]: 

232 """Verify OIDC ID token signature and claims. 

233 

234 Args: 

235 provider: SSO provider configuration. 

236 id_token: Raw OIDC ID token string. 

237 expected_nonce: Expected nonce from auth session, when available. 

238 

239 Returns: 

240 Verified token claims when validation succeeds, otherwise None. 

241 """ 

242 if provider.provider_type != "oidc": 

243 return None 

244 

245 issuer, jwks_uri = await self._resolve_oidc_issuer_and_jwks(provider) 

246 if not jwks_uri: 

247 logger.warning("Skipping id_token claim usage for provider %s: missing jwks_uri", provider.id) 

248 return None 

249 

250 try: 

251 jwks_client = self._get_jwks_client(jwks_uri) 

252 signing_key = await asyncio.to_thread(jwks_client.get_signing_key_from_jwt, id_token) 

253 

254 decode_kwargs: Dict[str, Any] = { 

255 "key": signing_key.key, 

256 "algorithms": ["RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512", "EdDSA"], 

257 "audience": provider.client_id, 

258 "options": { 

259 "verify_signature": True, 

260 "verify_exp": True, 

261 "verify_iat": True, 

262 "verify_aud": True, 

263 "verify_iss": bool(issuer), 

264 }, 

265 } 

266 if issuer: 

267 decode_kwargs["issuer"] = issuer 

268 

269 claims = await asyncio.to_thread(jwt.decode, id_token, **decode_kwargs) 

270 if expected_nonce is not None and claims.get("nonce") != expected_nonce: 

271 logger.warning("OIDC id_token nonce validation failed for provider %s", provider.id) 

272 return None 

273 return claims 

274 except jwt.PyJWTError as exc: 

275 logger.warning("OIDC id_token verification failed for provider %s: %s", provider.id, exc) 

276 return None 

277 except Exception as exc: 

278 logger.warning("Unexpected OIDC id_token verification error for provider %s: %s", provider.id, exc) 

279 return None 

280 

def _resolve_entra_graph_fallback_settings(self, provider_metadata: Optional[Dict[str, Any]]) -> Tuple[bool, int, int]:
    """Resolve Entra Graph fallback settings with provider metadata override.

    Defaults come from ``settings.sso_entra_graph_api_*``. The metadata keys
    ``graph_api_enabled``, ``graph_api_timeout`` and ``graph_api_max_groups``
    override them when valid; invalid overrides are logged and ignored.

    Args:
        provider_metadata: Optional provider metadata from DB/config.

    Returns:
        Tuple of (enabled, timeout_seconds, max_groups).
    """
    metadata = provider_metadata or {}
    enabled = settings.sso_entra_graph_api_enabled
    timeout = settings.sso_entra_graph_api_timeout
    max_groups = settings.sso_entra_graph_api_max_groups

    if "graph_api_enabled" in metadata:
        metadata_enabled = metadata.get("graph_api_enabled")
        # bool is tested before int because bool is an int subclass.
        if isinstance(metadata_enabled, bool):
            enabled = metadata_enabled
        elif isinstance(metadata_enabled, str):
            normalized_enabled = metadata_enabled.strip().lower()
            if normalized_enabled in {"1", "true", "yes", "on"}:
                enabled = True
            elif normalized_enabled in {"0", "false", "no", "off"}:
                enabled = False
            else:
                logger.warning("Invalid provider_metadata.graph_api_enabled=%s; using configured default %s", metadata_enabled, enabled)
        elif isinstance(metadata_enabled, int):
            enabled = metadata_enabled != 0
        else:
            logger.warning("Invalid provider_metadata.graph_api_enabled=%s (type %s); using configured default %s", metadata_enabled, type(metadata_enabled).__name__, enabled)

    metadata_timeout = metadata.get("graph_api_timeout")
    if metadata_timeout is not None:
        try:
            timeout_candidate = int(metadata_timeout)
        except (TypeError, ValueError):
            logger.warning("Invalid provider_metadata.graph_api_timeout=%s; using configured default %s", metadata_timeout, timeout)
        else:
            # Only timeouts between 1 and 120 inclusive are accepted.
            if 1 <= timeout_candidate <= 120:
                timeout = timeout_candidate
            else:
                logger.warning("Invalid provider_metadata.graph_api_timeout=%s; using configured default %s", metadata_timeout, timeout)

    metadata_max_groups = metadata.get("graph_api_max_groups")
    if metadata_max_groups is not None:
        try:
            max_groups_candidate = int(metadata_max_groups)
        except (TypeError, ValueError):
            logger.warning("Invalid provider_metadata.graph_api_max_groups=%s; using configured default %s", metadata_max_groups, max_groups)
        else:
            # Negative caps are rejected; zero is accepted.
            if max_groups_candidate >= 0:
                max_groups = max_groups_candidate
            else:
                logger.warning("Invalid provider_metadata.graph_api_max_groups=%s; using configured default %s", metadata_max_groups, max_groups)

    return enabled, timeout, max_groups

335 

async def _fetch_entra_groups_from_graph_api(self, access_token: str, user_email: str, provider_metadata: Optional[Dict[str, Any]] = None) -> Optional[List[str]]:
    """Fetch Entra group object IDs from Microsoft Graph for overage tokens.

    Posts to ``/me/getMemberObjects`` with the delegated access token, then
    strips, de-duplicates (order preserved) and optionally caps the returned
    IDs.

    Args:
        access_token: Delegated OAuth access token from Entra.
        user_email: User identifier for structured logs.
        provider_metadata: Optional provider metadata for runtime overrides.

    Returns:
        List of group IDs on success, empty list when Graph returns no IDs,
        or None if retrieval failed/disabled.
    """
    graph_api_enabled, graph_api_timeout, graph_api_max_groups = self._resolve_entra_graph_fallback_settings(provider_metadata)
    if not graph_api_enabled:
        logger.info("Microsoft Graph fallback for Entra group overage is disabled")
        return None

    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    try:
        # Client acquisition is guarded too, so a failure to obtain the shared
        # client is reported as "retrieval failed" rather than propagating.
        client = await get_http_client()
        response = await client.post(
            "https://graph.microsoft.com/v1.0/me/getMemberObjects",
            headers={"Authorization": f"Bearer {access_token}"},
            json={"securityEnabledOnly": True},
            timeout=graph_api_timeout,
        )
    except Exception as e:
        logger.error("Failed to retrieve groups from Graph API for %s: %s", user_email, e)
        return None

    if response.status_code != 200:
        if response.status_code in {401, 403}:
            logger.error(
                "Failed to retrieve groups from Graph API for %s: HTTP %s. Check Entra delegated permissions and consent (minimum User.Read).",
                user_email,
                response.status_code,
            )
        else:
            logger.error("Failed to retrieve groups from Graph API for %s: HTTP %s", user_email, response.status_code)
        return None

    try:
        groups_payload = response.json()
    except ValueError as e:
        logger.error("Failed to parse Graph API response for %s: %s", user_email, e)
        return None

    # A JSON body that is not an object cannot carry a "value" list.
    if not isinstance(groups_payload, dict):
        logger.error("Failed to parse Graph API response for %s: %s", user_email, "expected a JSON object")
        return None

    group_values = groups_payload.get("value", [])
    if not isinstance(group_values, list):
        logger.warning("Unexpected Graph API groups payload for %s: expected list in 'value'", user_email)
        return []

    deduped_groups: List[str] = []
    seen_groups: set[str] = set()
    for group_id in group_values:
        if not isinstance(group_id, str):
            continue
        normalized_group_id = group_id.strip()
        if not normalized_group_id or normalized_group_id in seen_groups:
            continue
        seen_groups.add(normalized_group_id)
        deduped_groups.append(normalized_group_id)

    # A cap of zero (or less) applies no limit.
    if 0 < graph_api_max_groups < len(deduped_groups):
        logger.warning(
            "Graph API returned %d groups for %s; applying configured cap (%d)",
            len(deduped_groups),
            user_email,
            graph_api_max_groups,
        )
        deduped_groups = deduped_groups[:graph_api_max_groups]

    logger.info("Retrieved %d groups from Graph API for %s", len(deduped_groups), user_email)
    return deduped_groups

413 

def list_enabled_providers(self) -> List[SSOProvider]:
    """Get list of enabled SSO providers.

    Returns:
        List of enabled SSO providers

    Examples:
        Returns empty list when DB has no providers:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalars.return_value.all.return_value = []
        >>> service.list_enabled_providers()
        []
    """
    enabled_query = select(SSOProvider).where(SSOProvider.is_enabled.is_(True))
    return list(self.db.execute(enabled_query).scalars().all())

431 

def get_provider(self, provider_id: str) -> Optional[SSOProvider]:
    """Get SSO provider by ID.

    Args:
        provider_id: Provider identifier (e.g., 'github', 'google')

    Returns:
        SSO provider or None if not found

    Examples:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> service.get_provider('x') is None
        True
    """
    by_id_query = select(SSOProvider).where(SSOProvider.id == provider_id)
    return self.db.execute(by_id_query).scalar_one_or_none()

451 

def get_provider_by_name(self, provider_name: str) -> Optional[SSOProvider]:
    """Get SSO provider by name.

    Args:
        provider_name: Provider name (e.g., 'github', 'google')

    Returns:
        SSO provider or None if not found

    Examples:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> service.get_provider_by_name('github') is None
        True
    """
    by_name_query = select(SSOProvider).where(SSOProvider.name == provider_name)
    return self.db.execute(by_name_query).scalar_one_or_none()

471 

472 @staticmethod 

473 def _normalize_issuer_url(issuer: str) -> str: 

474 """Normalize issuer URL for allowlist comparisons. 

475 

476 Args: 

477 issuer: Raw issuer URL from provider config or metadata. 

478 

479 Returns: 

480 Lowercased issuer URL without trailing slash. 

481 """ 

482 return issuer.strip().rstrip("/").lower() 

483 

def _enforce_allowed_issuer(self, issuer: Optional[str]) -> None:
    """Enforce configured issuer allowlist when present.

    No check is made when ``settings.sso_issuers`` is missing or empty. A
    blank or non-string issuer is not rejected: a warning is logged and
    enforcement is skipped.

    Args:
        issuer: Candidate issuer URL from provider configuration.

    Raises:
        ValueError: If issuer is set but not in configured allowlist.
    """
    allowed_issuers = getattr(settings, "sso_issuers", None)
    if not allowed_issuers:
        return

    if not isinstance(issuer, str) or not issuer.strip():
        logger.warning("SSO provider has blank/empty issuer while SSO_ISSUERS allowlist is configured; issuer enforcement skipped.")
        return

    normalized_candidate = self._normalize_issuer_url(issuer)
    # Blank and non-string allowlist entries are ignored.
    normalized_allowlist = {self._normalize_issuer_url(allowed_issuer) for allowed_issuer in allowed_issuers if isinstance(allowed_issuer, str) and allowed_issuer.strip()}
    if normalized_allowlist and normalized_candidate not in normalized_allowlist:
        raise ValueError("Issuer is not allowed by SSO_ISSUERS configuration")

505 

506 @staticmethod 

507 def _resolve_team_mapping_target(mapping_value: Any) -> Tuple[Optional[str], str]: 

508 """Resolve team mapping value into team id and role. 

509 

510 Args: 

511 mapping_value: Team mapping target value from provider config. 

512 

513 Returns: 

514 Tuple of ``(team_id, role)`` where ``team_id`` may be ``None`` and 

515 role defaults to ``member`` when not explicitly valid. 

516 """ 

517 if isinstance(mapping_value, str) and mapping_value.strip(): 

518 return mapping_value.strip(), "member" 

519 

520 if isinstance(mapping_value, dict): 

521 team_id_value = mapping_value.get("team_id") or mapping_value.get("id") 

522 team_id = str(team_id_value).strip() if team_id_value is not None else "" 

523 role_value = str(mapping_value.get("role", "member")).strip().lower() 

524 role = role_value if role_value in {"owner", "member"} else "member" 

525 return (team_id if team_id else None), role 

526 

527 return None, "member" 

528 

529 async def _apply_team_mapping(self, user_email: str, user_info: Dict[str, Any], provider: Optional[SSOProvider]) -> None: 

530 """Apply provider team mappings based on SSO group claims. 

531 

532 Args: 

533 user_email: Authenticated user email to map into teams. 

534 user_info: Identity claims payload containing optional group claims. 

535 provider: SSO provider configuration with ``team_mapping`` entries. 

536 

537 Returns: 

538 None. 

539 """ 

540 if not provider: 

541 return 

542 

543 mapping = getattr(provider, "team_mapping", None) 

544 if not isinstance(mapping, dict) or not mapping: 

545 return 

546 

547 groups_raw = user_info.get("groups", []) 

548 if isinstance(groups_raw, str): 

549 groups = [groups_raw] 

550 elif isinstance(groups_raw, list): 

551 groups = [str(group).strip() for group in groups_raw if str(group).strip()] 

552 else: 

553 groups = [] 

554 

555 if not groups: 

556 return 

557 

558 normalized_groups = {group.lower() for group in groups} 

559 

560 # First-Party 

561 from mcpgateway.services.team_management_service import MemberAlreadyExistsError, TeamManagementError, TeamManagementService # pylint: disable=import-outside-toplevel 

562 

563 team_service = TeamManagementService(self.db) 

564 for source_group, target in mapping.items(): 

565 if not isinstance(source_group, str): 

566 continue 

567 source_group_normalized = source_group.strip().lower() 

568 if not source_group_normalized or source_group_normalized not in normalized_groups: 

569 continue 

570 

571 team_id, role = self._resolve_team_mapping_target(target) 

572 if not team_id: 

573 logger.warning("Skipping invalid SSO team_mapping entry for provider %s and group '%s'", provider.id, source_group) 

574 continue 

575 

576 try: 

577 await team_service.add_member_to_team(team_id=team_id, user_email=user_email, role=role, invited_by=user_email) 

578 except MemberAlreadyExistsError: 

579 logger.debug("SSO team_mapping: user %s already member of team %s", user_email, team_id) 

580 except TeamManagementError as exc: 

581 logger.warning("SSO team_mapping failed for user %s, group '%s', team '%s': %s", user_email, source_group, team_id, exc) 

582 except Exception as exc: 

583 logger.warning("Unexpected SSO team_mapping error for user %s and team '%s': %s", user_email, team_id, exc) 

584 

async def create_provider(self, provider_data: Dict[str, Any]) -> SSOProvider:
    """Create new SSO provider configuration.

    Note:
        ``provider_data`` is modified in place: ``client_secret`` is removed
        and ``client_secret_encrypted`` is added.

    Args:
        provider_data: Provider configuration data

    Returns:
        Created SSO provider

    Raises:
        ValueError: If the issuer is rejected by the issuer allowlist.
        KeyError: If ``client_secret`` is missing from ``provider_data``.

    Examples:
        >>> import asyncio
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> service = SSOService(MagicMock())
        >>> service._encrypt_secret = AsyncMock(side_effect=lambda s: 'ENC(' + s + ')')
        >>> data = {
        ...     'id': 'github', 'name': 'github', 'display_name': 'GitHub', 'provider_type': 'oauth2',
        ...     'client_id': 'cid', 'client_secret': 'sec',
        ...     'authorization_url': 'https://example/auth', 'token_url': 'https://example/token',
        ...     'userinfo_url': 'https://example/user', 'scope': 'user:email'
        ... }
        >>> provider = asyncio.run(service.create_provider(data))
        >>> hasattr(provider, 'id') and provider.id == 'github'
        True
        >>> provider.client_secret_encrypted.startswith('ENC(')
        True
    """
    self._enforce_allowed_issuer(provider_data.get("issuer"))

    # Swap the plain-text secret for its encrypted form before persisting.
    client_secret = provider_data.pop("client_secret")
    provider_data["client_secret_encrypted"] = await self._encrypt_secret(client_secret)

    # Filter to valid SSOProvider columns to prevent TypeError on unknown keys
    valid_columns = {c.key for c in SSOProvider.__table__.columns}
    filtered_data = {k: v for k, v in provider_data.items() if k in valid_columns}
    skipped = set(provider_data) - set(filtered_data)
    if skipped:
        logger.warning("Ignored unknown SSOProvider fields during creation: %s", skipped)

    provider = SSOProvider(**filtered_data)
    self.db.add(provider)
    self.db.commit()
    self.db.refresh(provider)
    return provider

629 

async def update_provider(self, provider_id: str, provider_data: Dict[str, Any]) -> Optional[SSOProvider]:
    """Update existing SSO provider configuration.

    Only keys that already exist as attributes on the provider are applied.

    Note:
        When ``client_secret`` is supplied, ``provider_data`` is modified in
        place: the key is removed and ``client_secret_encrypted`` is added.

    Args:
        provider_id: Provider identifier
        provider_data: Updated provider data

    Returns:
        Updated SSO provider or None if not found

    Raises:
        ValueError: If a supplied issuer is rejected by the issuer allowlist.

    Examples:
        >>> import asyncio
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> svc = SSOService(MagicMock())
        >>> # Existing provider object
        >>> existing = SimpleNamespace(id='github', name='github', client_id='old', client_secret_encrypted='X', is_enabled=True)
        >>> svc.get_provider = lambda _id: existing
        >>> svc._encrypt_secret = AsyncMock(side_effect=lambda s: 'ENC-' + s)
        >>> svc.db.commit = lambda: None
        >>> svc.db.refresh = lambda obj: None
        >>> updated = asyncio.run(svc.update_provider('github', {'client_id': 'new', 'client_secret': 'sec'}))
        >>> updated.client_id
        'new'
        >>> updated.client_secret_encrypted
        'ENC-sec'
    """
    provider = self.get_provider(provider_id)
    if not provider:
        return None

    # The allowlist is re-checked only when the caller changes the issuer.
    if "issuer" in provider_data:
        self._enforce_allowed_issuer(provider_data.get("issuer"))

    # Handle client secret encryption if provided
    if "client_secret" in provider_data:
        client_secret = provider_data.pop("client_secret")
        provider_data["client_secret_encrypted"] = await self._encrypt_secret(client_secret)

    for key, value in provider_data.items():
        if hasattr(provider, key):
            setattr(provider, key, value)

    provider.updated_at = utc_now()
    self.db.commit()
    self.db.refresh(provider)
    return provider

677 

def delete_provider(self, provider_id: str) -> bool:
    """Delete SSO provider configuration.

    Args:
        provider_id: Provider identifier

    Returns:
        True if deleted, False if not found

    Examples:
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> svc = SSOService(MagicMock())
        >>> svc.db.delete = lambda obj: None
        >>> svc.db.commit = lambda: None
        >>> svc.get_provider = lambda _id: SimpleNamespace(id='github')
        >>> svc.delete_provider('github')
        True
        >>> svc.get_provider = lambda _id: None
        >>> svc.delete_provider('missing')
        False
    """
    provider = self.get_provider(provider_id)
    if provider:
        self.db.delete(provider)
        self.db.commit()
        return True
    return False

707 

def generate_pkce_challenge(self) -> Tuple[str, str]:
    """Generate PKCE code verifier and challenge for OAuth 2.1.

    The verifier is 32 random bytes encoded as unpadded URL-safe base64. The
    challenge is the unpadded URL-safe base64 of the verifier's SHA-256
    digest.

    Returns:
        Tuple of (code_verifier, code_challenge)

    Examples:
        Generate verifier and challenge:
        >>> from unittest.mock import Mock
        >>> service = SSOService(Mock())
        >>> verifier, challenge = service.generate_pkce_challenge()
        >>> isinstance(verifier, str) and isinstance(challenge, str)
        True
        >>> len(verifier) >= 43
        True
        >>> len(challenge) >= 43
        True
    """
    # token_urlsafe(32) is the unpadded URL-safe base64 of 32 random bytes.
    code_verifier = secrets.token_urlsafe(32)

    digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
    code_challenge = base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=")

    return code_verifier, code_challenge

733 

734 @staticmethod 

735 def _normalize_scope_values(values: Optional[List[str] | str]) -> List[str]: 

736 """Normalize scope values to a deduplicated, ordered list. 

737 

738 Args: 

739 values: Scope input as list or space-delimited string. 

740 

741 Returns: 

742 Ordered, unique scope values. 

743 """ 

744 if values is None: 

745 return [] 

746 

747 raw_values: List[str] 

748 if isinstance(values, str): 

749 raw_values = values.split() 

750 else: 

751 raw_values = [str(v) for v in values if isinstance(v, str) and v.strip()] 

752 

753 normalized: List[str] = [] 

754 seen: set[str] = set() 

755 for value in raw_values: 

756 scope = value.strip() 

757 if not scope or scope in seen: 

758 continue 

759 normalized.append(scope) 

760 seen.add(scope) 

761 return normalized 

762 

763 def _resolve_login_scopes(self, provider: SSOProvider, requested_scopes: Optional[List[str]]) -> List[str]: 

764 """Resolve requested SSO scopes against provider allowlists. 

765 

766 Args: 

767 provider: SSO provider configuration. 

768 requested_scopes: Optional scopes requested by the client. 

769 

770 Returns: 

771 Final scope list to send to the provider. 

772 

773 Raises: 

774 ValueError: If provider scope configuration is invalid or request includes disallowed scopes. 

775 """ 

776 configured_scopes = self._normalize_scope_values(getattr(provider, "scope", None)) 

777 if not configured_scopes: 

778 raise ValueError("Provider has no configured scopes") 

779 

780 allowed_scopes = configured_scopes 

781 provider_metadata = getattr(provider, "provider_metadata", {}) or {} 

782 metadata_allowed_raw = provider_metadata.get("allowed_scopes") 

783 metadata_allowed = self._normalize_scope_values(metadata_allowed_raw) if metadata_allowed_raw else [] 

784 if metadata_allowed: 

785 metadata_allowed_set = set(metadata_allowed) 

786 allowed_scopes = [scope for scope in configured_scopes if scope in metadata_allowed_set] 

787 

788 if not allowed_scopes: 

789 raise ValueError("No allowed scopes configured for provider") 

790 

791 if not requested_scopes: 

792 return allowed_scopes 

793 

794 normalized_requested = self._normalize_scope_values(requested_scopes) 

795 if not normalized_requested: 

796 return allowed_scopes 

797 

798 allowed_set = set(allowed_scopes) 

799 invalid_scopes = [scope for scope in normalized_requested if scope not in allowed_set] 

800 if invalid_scopes: 

801 invalid_csv = ", ".join(invalid_scopes) 

802 raise ValueError(f"Invalid scopes requested: {invalid_csv}") 

803 

804 return normalized_requested 

805 

def _get_state_binding_secret(self) -> bytes:
    """Resolve secret bytes used for state/session binding HMAC.

    Reads ``settings.auth_encryption_secret``. Objects exposing
    ``get_secret_value()`` are unwrapped through it; anything else is
    converted with ``str()``.

    Returns:
        Secret bytes used for HMAC signatures.
    """
    secret_source = settings.auth_encryption_secret
    if hasattr(secret_source, "get_secret_value"):
        secret_text = secret_source.get_secret_value()
    else:
        # NOTE(review): an unset (None) secret would become the literal
        # string "None" here - confirm settings always provide a value.
        secret_text = str(secret_source)
    return secret_text.encode("utf-8")

816 

817 def _generate_session_bound_state(self, provider_id: str, session_binding: str) -> str: 

818 """Generate state value bound to browser session context. 

819 

820 Args: 

821 provider_id: SSO provider identifier. 

822 session_binding: Browser-session marker. 

823 

824 Returns: 

825 Signed state token including nonce and HMAC signature. 

826 """ 

827 state_nonce = secrets.token_urlsafe(24) 

828 message = f"{provider_id}:{session_binding}:{state_nonce}".encode("utf-8") 

829 signature = hmac.new(self._get_state_binding_secret(), message, hashlib.sha256).hexdigest() 

830 return f"{state_nonce}{self._STATE_BINDING_SEPARATOR}{signature}" 

831 

832 def _is_session_bound_state(self, state: str) -> bool: 

833 """Return whether state appears to carry a session-binding signature. 

834 

835 Args: 

836 state: State value from OAuth flow. 

837 

838 Returns: 

839 ``True`` when state has expected nonce/signature format. 

840 """ 

841 if self._STATE_BINDING_SEPARATOR not in state: 

842 return False 

843 nonce, signature = state.rsplit(self._STATE_BINDING_SEPARATOR, 1) 

844 return bool(nonce) and len(signature) == self._STATE_BINDING_HEX_LEN 

845 

846 def _verify_session_bound_state(self, provider_id: str, state: str, session_binding: str) -> bool: 

847 """Verify state HMAC binding against the current browser session marker. 

848 

849 Args: 

850 provider_id: SSO provider identifier. 

851 state: State value to validate. 

852 session_binding: Current browser-session marker. 

853 

854 Returns: 

855 ``True`` when state signature matches the expected session binding. 

856 """ 

857 if not session_binding or not self._is_session_bound_state(state): 

858 return False 

859 nonce, signature = state.rsplit(self._STATE_BINDING_SEPARATOR, 1) 

860 message = f"{provider_id}:{session_binding}:{nonce}".encode("utf-8") 

861 expected_signature = hmac.new(self._get_state_binding_secret(), message, hashlib.sha256).hexdigest() 

862 return hmac.compare_digest(signature, expected_signature) 

863 

864 @staticmethod 

865 def _is_email_verified_claim(user_info: Dict[str, Any]) -> bool: 

866 """Evaluate email verification claim when provided by the IdP. 

867 

868 Args: 

869 user_info: Normalized user-info payload from provider. 

870 

871 Returns: 

872 ``True`` only when email verification claim is explicitly verified. 

873 """ 

874 if "email_verified" not in user_info: 

875 return False 

876 

877 claim_value = user_info.get("email_verified") 

878 if isinstance(claim_value, bool): 

879 return claim_value 

880 if isinstance(claim_value, int): 

881 return claim_value == 1 

882 if isinstance(claim_value, str): 

883 return claim_value.strip().lower() in {"1", "true", "yes", "on"} 

884 return False 

885 

def get_authorization_url(
    self,
    provider_id: str,
    redirect_uri: str,
    scopes: Optional[List[str]] = None,
    session_binding: Optional[str] = None,
) -> Optional[str]:
    """Generate OAuth authorization URL for provider.

    A PKCE verifier/challenge pair, a CSRF ``state`` value and (for OIDC
    providers) a nonce are generated and persisted as an ``SSOAuthSession``
    before the URL is returned. If the provider's authorization endpoint
    already carries a query component, the OAuth parameters are appended to
    it rather than replacing it.

    Args:
        provider_id: Provider identifier
        redirect_uri: Callback URI after authorization
        scopes: Optional custom scopes (uses provider default if None)
        session_binding: Optional browser-session marker for state binding.

    Returns:
        Authorization URL or None if provider not found or disabled

    Examples:
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> provider = SimpleNamespace(id='github', is_enabled=True, provider_type='oauth2', client_id='cid', authorization_url='https://example/auth', scope='user:email')
        >>> service.get_provider = lambda _pid: provider
        >>> service.db.add = lambda x: None
        >>> service.db.commit = lambda: None
        >>> url = service.get_authorization_url('github', 'https://app/callback', ['user:email'])
        >>> isinstance(url, str) and 'client_id=cid' in url and 'state=' in url
        True

        Missing provider returns None:
        >>> service.get_provider = lambda _pid: None
        >>> service.get_authorization_url('missing', 'https://app/callback') is None
        True
    """
    provider = self.get_provider(provider_id)
    if not provider or not provider.is_enabled:
        return None

    # Generate PKCE parameters
    code_verifier, code_challenge = self.generate_pkce_challenge()

    # NOTE(review): scope resolution presumably raises ValueError for
    # disallowed scopes; that error is not caught here - confirm with callers.
    resolved_scopes = self._resolve_login_scopes(provider, scopes)

    # Generate CSRF state (session-bound when browser context is available).
    state = self._generate_session_bound_state(provider_id, session_binding) if session_binding else secrets.token_urlsafe(32)

    # Generate OIDC nonce if applicable
    nonce = secrets.token_urlsafe(16) if provider.provider_type == "oidc" else None

    # Create auth session so the callback can look up state, PKCE verifier and nonce
    auth_session = SSOAuthSession(provider_id=provider_id, state=state, code_verifier=code_verifier, nonce=nonce, redirect_uri=redirect_uri)
    self.db.add(auth_session)
    self.db.commit()

    # Build authorization URL
    params = {
        "client_id": provider.client_id,
        "response_type": "code",
        "redirect_uri": redirect_uri,
        "state": state,
        "scope": " ".join(resolved_scopes),
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
    }

    if nonce:
        params["nonce"] = nonce

    endpoint = provider.authorization_url
    query = urllib.parse.urlencode(params)
    if "?" not in endpoint:
        return f"{endpoint}?{query}"

    # The endpoint already has a query component (RFC 6749 section 3.1 says it
    # must be retained), so extend it instead of starting a second one.
    joiner = "" if endpoint.endswith(("?", "&")) else "&"
    return f"{endpoint}{joiner}{query}"

956 

async def handle_oauth_callback(self, provider_id: str, code: str, state: str, session_binding: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Process an OAuth callback and return only the user info.

    Delegates to ``handle_oauth_callback_with_tokens`` and discards the raw
    token response from its result.

    Args:
        provider_id: Provider identifier
        code: Authorization code from callback
        state: CSRF state parameter
        session_binding: Optional browser-session marker used to verify bound state.

    Returns:
        User info dict or None if authentication failed

    Examples:
        Happy-path with patched exchanges and user info:
        >>> import asyncio
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> svc = SSOService(MagicMock())
        >>> # Mock DB auth session lookup
        >>> provider = SimpleNamespace(id='github', is_enabled=True, provider_type='oauth2')
        >>> auth_session = SimpleNamespace(provider_id='github', state='st', provider=provider, is_expired=False, nonce=None)
        >>> svc.db.execute.return_value.scalar_one_or_none.return_value = auth_session
        >>> # Patch token exchange and user info retrieval
        >>> async def _ex(p, sess, c):
        ...     return {'access_token': 'tok', 'id_token': 'id_tok'}
        >>> async def _ui(p, access, token_data=None, expected_nonce=None):
        ...     return {'email': 'user@example.com'}
        >>> svc._exchange_code_for_tokens = _ex
        >>> svc._get_user_info = _ui
        >>> svc.db.delete = lambda obj: None
        >>> svc.db.commit = lambda: None
        >>> out = asyncio.run(svc.handle_oauth_callback('github', 'code', 'st'))
        >>> out['email']
        'user@example.com'

        Early return cases:
        >>> # No session
        >>> svc2 = SSOService(MagicMock())
        >>> svc2.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> asyncio.run(svc2.handle_oauth_callback('github', 'c', 's')) is None
        True
        >>> # Expired session
        >>> expired = SimpleNamespace(provider_id='github', state='st', provider=SimpleNamespace(is_enabled=True), is_expired=True)
        >>> svc3 = SSOService(MagicMock())
        >>> svc3.db.execute.return_value.scalar_one_or_none.return_value = expired
        >>> asyncio.run(svc3.handle_oauth_callback('github', 'c', 'st')) is None
        True
        >>> # Disabled provider
        >>> disabled = SimpleNamespace(provider_id='github', state='st', provider=SimpleNamespace(is_enabled=False), is_expired=False)
        >>> svc4 = SSOService(MagicMock())
        >>> svc4.db.execute.return_value.scalar_one_or_none.return_value = disabled
        >>> asyncio.run(svc4.handle_oauth_callback('github', 'c', 'st')) is None
        True
    """
    outcome = await self.handle_oauth_callback_with_tokens(provider_id, code, state, session_binding=session_binding)
    if outcome:
        # The delegate returns (user_info, token_data); only the first half is exposed here.
        profile, _raw_tokens = outcome
        return profile
    return None

1016 

async def handle_oauth_callback_with_tokens(
    self,
    provider_id: str,
    code: str,
    state: str,
    session_binding: Optional[str] = None,
) -> Optional[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """Handle OAuth callback and return both user info and raw token response.

    The stored auth session is looked up by ``state`` and ``provider_id`` and
    checked for expiry, session binding and provider availability before the
    authorization code is exchanged. For OIDC providers the ``id_token`` is
    verified against the stored nonce and the verified claims are attached to
    the returned token data under ``_verified_id_token_claims``. The auth
    session is deleted on success, on expiry, when the token response carries
    no access token, and when an unexpected exception occurs.

    Args:
        provider_id: Provider identifier
        code: Authorization code from callback
        state: CSRF state parameter
        session_binding: Optional browser-session marker used to verify bound state.

    Returns:
        Tuple of (user_info, token_data) or None if authentication fails
    """
    # Validate auth session
    stmt = select(SSOAuthSession).where(SSOAuthSession.state == state, SSOAuthSession.provider_id == provider_id)
    auth_session = self.db.execute(stmt).scalar_one_or_none()

    if not auth_session:
        logger.warning(f"OAuth callback: no auth session found for state/provider {provider_id}. Possible CSRF or replay.")
        return None

    if auth_session.is_expired:
        logger.warning(f"OAuth callback: auth session expired for provider {provider_id}.")
        self.db.delete(auth_session)
        self.db.commit()
        return None

    # A state that carries a signature must be presented together with the
    # browser-session marker it was signed for.
    if self._is_session_bound_state(state):
        if not session_binding or not self._verify_session_bound_state(provider_id, state, session_binding):
            logger.warning(
                "OAuth callback: state/session mismatch for provider %s. Possible CSRF or cross-session replay.",
                provider_id,
            )
            return None

    provider = auth_session.provider
    if not provider:
        logger.error(f"OAuth callback: provider '{provider_id}' not found for auth session.")
        return None

    if not provider.is_enabled:
        logger.warning(f"OAuth callback: provider '{provider_id}' is disabled.")
        return None

    try:
        # Exchange authorization code for tokens
        logger.info(f"Starting token exchange for provider {provider_id}")
        token_data = await self._exchange_code_for_tokens(provider, auth_session, code)
        if not token_data:
            logger.error(f"Failed to exchange code for tokens for provider {provider_id}")
            return None
        logger.info(f"Token exchange successful for provider {provider_id}")
        callback_nonce = getattr(auth_session, "nonce", None)

        # For OIDC providers, verify id_token before any claim extraction.
        if provider.provider_type == "oidc":
            if callback_nonce is None:
                logger.error("OAuth callback: missing nonce for OIDC provider %s.", provider_id)
                return None
            id_token = token_data.get("id_token")
            if not isinstance(id_token, str):
                logger.error("OAuth callback: missing id_token for OIDC provider %s.", provider_id)
                return None
            verified_claims = await self._verify_oidc_id_token(provider, id_token, expected_nonce=callback_nonce)
            if verified_claims is None:
                logger.error(f"id_token verification failed for provider {provider_id}")
                return None
            # Copy before annotating so the dict returned by the exchange is not mutated.
            token_data = dict(token_data)
            token_data["_verified_id_token_claims"] = verified_claims

        # Some providers (GitHub, for example) answer a rejected code exchange
        # with HTTP 200 and an error payload, so a successful status does not
        # guarantee a token. Fail explicitly instead of relying on a KeyError.
        access_token = token_data.get("access_token")
        if not access_token:
            logger.error(
                "OAuth callback: token response for provider %s contained no access_token (error=%s).",
                provider_id,
                token_data.get("error"),
            )
            self.db.delete(auth_session)
            self.db.commit()
            return None

        # Get user info from provider (pass full token_data for id_token parsing)
        user_info = await self._get_user_info(provider, access_token, token_data, expected_nonce=callback_nonce)
        if not user_info:
            logger.error(f"Failed to get user info for provider {provider_id}")
            return None

        # Clean up auth session
        self.db.delete(auth_session)
        self.db.commit()

        return user_info, token_data

    except Exception as e:
        # Clean up auth session on error
        logger.error(f"OAuth callback failed for provider {provider_id}: {type(e).__name__}: {str(e)}")
        logger.exception("Full traceback for OAuth callback failure:")
        self.db.delete(auth_session)
        self.db.commit()
        return None

1111 

async def _exchange_code_for_tokens(self, provider: SSOProvider, auth_session: SSOAuthSession, code: str) -> Optional[Dict[str, Any]]:
    """Trade an authorization code for the provider's token response.

    Posts a form-encoded ``authorization_code`` grant, including the PKCE
    verifier and redirect URI stored on the auth session, to the provider's
    token URL.

    Args:
        provider: SSO provider configuration
        auth_session: Auth session with PKCE parameters
        code: Authorization code

    Returns:
        Token response dict or None if failed
    """
    form_fields = {
        "client_id": provider.client_id,
        "client_secret": await self._decrypt_secret(provider.client_secret_encrypted),
        "code": code,
        "grant_type": "authorization_code",
        "redirect_uri": auth_session.redirect_uri,
        "code_verifier": auth_session.code_verifier,
    }

    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    http = await get_http_client()
    reply = await http.post(provider.token_url, data=form_fields, headers={"Accept": "application/json"})

    # Anything other than HTTP 200 is treated as a failed exchange.
    if reply.status_code != 200:
        logger.error(f"Token exchange failed for {provider.name}: HTTP {reply.status_code} - {reply.text}")
        return None

    return reply.json()

1143 

async def _get_user_info(self, provider: SSOProvider, access_token: str, token_data: Optional[Dict[str, Any]] = None, expected_nonce: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Get user information from provider using access token.

    Calls the provider's userinfo endpoint with a bearer token and, on HTTP
    200, enriches the payload with provider-specific data before normalizing
    it: GitHub organizations (when ``settings.sso_github_admin_orgs`` is set),
    Entra groups/roles/basic claims from verified id_token claims, and
    Keycloak ``realm_access``/``resource_access``/``groups`` claims.

    Verified id_token claims are taken from
    ``token_data["_verified_id_token_claims"]`` when present; otherwise, for
    OIDC providers with an ``id_token`` string, they are obtained by verifying
    the token against ``expected_nonce`` (skipped when no nonce is available).

    Args:
        provider: SSO provider configuration
        access_token: OAuth access token
        token_data: Optional full token response containing id_token for OIDC providers
        expected_nonce: Nonce bound to the current auth session for OIDC id_token verification

    Returns:
        User info dict or None if failed
    """
    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    client = await get_http_client()
    verified_id_token_claims: Optional[Dict[str, Any]] = None
    # Prefer claims that the caller already verified; only verify here as a fallback.
    if token_data and isinstance(token_data.get("_verified_id_token_claims"), dict):
        verified_id_token_claims = token_data.get("_verified_id_token_claims")
    elif provider.provider_type == "oidc" and token_data and isinstance(token_data.get("id_token"), str):
        if expected_nonce is None:
            logger.warning("Skipping OIDC id_token fallback verification for provider %s because expected nonce is unavailable.", provider.id)
        else:
            verified_id_token_claims = await self._verify_oidc_id_token(provider, token_data["id_token"], expected_nonce=expected_nonce)

    # Only Keycloak uses the id_token claims both for enrichment and for the 401 fallback below.
    keycloak_id_token_claims: Optional[Dict[str, Any]] = None
    if provider.id == "keycloak" and verified_id_token_claims:
        keycloak_id_token_claims = verified_id_token_claims

    response = await client.get(provider.userinfo_url, headers={"Authorization": f"Bearer {access_token}"})

    if response.status_code == 200:
        user_data = response.json()

        # For GitHub, also fetch organizations if admin assignment is configured
        if provider.id == "github" and settings.sso_github_admin_orgs:
            # Best effort: any failure leaves the user with an empty organization list.
            try:
                orgs_response = await client.get("https://api.github.com/user/orgs", headers={"Authorization": f"Bearer {access_token}"})
                if orgs_response.status_code == 200:
                    orgs_data = orgs_response.json()
                    user_data["organizations"] = [org["login"] for org in orgs_data]
                else:
                    logger.warning(f"Failed to fetch GitHub organizations: HTTP {orgs_response.status_code}")
                    user_data["organizations"] = []
            except Exception as e:
                logger.warning(f"Error fetching GitHub organizations: {e}")
                user_data["organizations"] = []

        # For Entra ID, extract groups/roles from id_token since userinfo doesn't include them
        # Microsoft's /oidc/userinfo endpoint only returns basic claims (sub, name, email, picture)
        # Groups and roles are included in the id_token when configured in Azure Portal
        if provider.id == "entra" and verified_id_token_claims:
            id_token_claims = verified_id_token_claims
            if id_token_claims:
                entra_groups_from_graph: Optional[List[str]] = None
                # Detect group overage - when user has too many groups (>200), EntraID can return
                # overage markers (e.g. _claim_names/_claim_sources, hasgroups, groups:srcN)
                # instead of an inline groups array.
                # See: https://learn.microsoft.com/en-us/entra/identity-platform/id-token-claims-reference
                claim_names = id_token_claims.get("_claim_names", {})
                has_groups_src_key = any(isinstance(key, str) and key.startswith("groups:src") for key in id_token_claims)
                groups_claim_value = id_token_claims.get("groups")
                has_group_overage = (
                    (isinstance(claim_names, dict) and "groups" in claim_names) or bool(id_token_claims.get("hasgroups")) or has_groups_src_key or isinstance(groups_claim_value, str)
                )
                if has_group_overage:
                    user_email = user_data.get("email") or user_data.get("preferred_username") or "unknown"
                    logger.warning(
                        "Group overage detected for user %s - token contains too many groups (>200). Attempting Microsoft Graph fallback to resolve complete group membership.",
                        user_email,
                    )
                    entra_groups_from_graph = await self._fetch_entra_groups_from_graph_api(access_token, user_email, provider.provider_metadata)
                    if entra_groups_from_graph is None:
                        logger.warning("Proceeding without Graph-resolved Entra groups for user %s", user_email)

                # Extract groups from id_token (Security Groups as Object IDs)
                # Graph-resolved groups, when available, take precedence over the id_token claim.
                if entra_groups_from_graph is not None:
                    user_data["groups"] = entra_groups_from_graph
                elif "groups" in id_token_claims:
                    user_data["groups"] = id_token_claims["groups"]
                    logger.debug(f"Extracted {len(id_token_claims['groups'])} groups from Entra ID token")

                # Extract roles from id_token (App Roles)
                if "roles" in id_token_claims:
                    user_data["roles"] = id_token_claims["roles"]
                    logger.debug(f"Extracted {len(id_token_claims['roles'])} roles from Entra ID token")

                # Also extract any missing basic claims from id_token
                # (values already present in the userinfo payload are not overwritten)
                for claim in ["email", "name", "preferred_username", "oid", "sub"]:
                    if claim not in user_data and claim in id_token_claims:
                        user_data[claim] = id_token_claims[claim]

        # For Keycloak, also extract groups/roles from id_token if available
        if provider.id == "keycloak" and keycloak_id_token_claims:
            # Keycloak includes realm_access, resource_access, and groups in id_token
            for claim in ["realm_access", "resource_access", "groups"]:
                if claim in keycloak_id_token_claims and claim not in user_data:
                    user_data[claim] = keycloak_id_token_claims[claim]

        # Normalize user info across providers
        return self._normalize_user_info(provider, user_data)

    # Keycloak can issue tokens using the browser-facing issuer URL; if userinfo
    # is called on a different host/port, Keycloak may reject the token with 401.
    # Only fall back to id_token claims for 401 with split-host configuration.
    # Other errors (403=revoked, 500=server error) must NOT fall back — the user
    # should be denied access, not silently authenticated via stale id_token claims.
    if provider.id == "keycloak" and keycloak_id_token_claims and response.status_code == 401:
        metadata = provider.provider_metadata or {}
        public_base_url = metadata.get("public_base_url")
        # Split-host means a public_base_url is configured and differs from base_url.
        if public_base_url and public_base_url != metadata.get("base_url"):
            logger.warning(
                "User info request returned 401 for keycloak with split-host config (public=%s, internal=%s). Falling back to id_token claims.",
                public_base_url,
                metadata.get("base_url"),
            )
            return self._normalize_user_info(provider, keycloak_id_token_claims)

    logger.error(f"User info request failed for {provider.name}: HTTP {response.status_code} - {response.text}")

    return None

1265 

1266 def _normalize_user_info(self, provider: SSOProvider, user_data: Dict[str, Any]) -> Dict[str, Any]: 

1267 """Normalize user info from different providers to common format. 

1268 

1269 Args: 

1270 provider: SSO provider configuration 

1271 user_data: Raw user data from provider 

1272 

1273 Returns: 

1274 Normalized user info dict 

1275 """ 

1276 # Handle GitHub provider 

1277 if provider.id == "github": 

1278 normalized = { 

1279 "email": user_data.get("email"), 

1280 "full_name": user_data.get("name") or user_data.get("login"), 

1281 "avatar_url": user_data.get("avatar_url"), 

1282 "provider_id": user_data.get("id"), 

1283 "username": user_data.get("login"), 

1284 "provider": "github", 

1285 "organizations": user_data.get("organizations", []), 

1286 } 

1287 # GitHub /user responses do not include an email_verified claim. Preserve 

1288 # backward-compatible login behavior by only enforcing verification when 

1289 # a concrete claim value is present. 

1290 github_email_verified = user_data.get("email_verified") 

1291 if github_email_verified is None: 

1292 github_email_verified = user_data.get("verified") 

1293 if github_email_verified is not None: 

1294 normalized["email_verified"] = github_email_verified 

1295 return normalized 

1296 

1297 # Handle Google provider 

1298 if provider.id == "google": 

1299 return { 

1300 "email": user_data.get("email"), 

1301 "email_verified": user_data.get("email_verified"), 

1302 "full_name": user_data.get("name"), 

1303 "avatar_url": user_data.get("picture"), 

1304 "provider_id": user_data.get("sub"), 

1305 "username": user_data.get("email", "").split("@")[0], 

1306 "provider": "google", 

1307 } 

1308 

1309 # Handle IBM Verify provider 

1310 if provider.id == "ibm_verify": 

1311 return { 

1312 "email": user_data.get("email"), 

1313 "email_verified": user_data.get("email_verified"), 

1314 "full_name": user_data.get("name"), 

1315 "avatar_url": user_data.get("picture"), 

1316 "provider_id": user_data.get("sub"), 

1317 "username": user_data.get("preferred_username") or user_data.get("email", "").split("@")[0], 

1318 "provider": "ibm_verify", 

1319 } 

1320 

1321 # Handle Okta provider 

1322 if provider.id == "okta": 

1323 return { 

1324 "email": user_data.get("email"), 

1325 "email_verified": user_data.get("email_verified"), 

1326 "full_name": user_data.get("name"), 

1327 "avatar_url": user_data.get("picture"), 

1328 "provider_id": user_data.get("sub"), 

1329 "username": user_data.get("preferred_username") or user_data.get("email", "").split("@")[0], 

1330 "provider": "okta", 

1331 } 

1332 

1333 # Handle Keycloak provider with role mapping 

1334 if provider.id == "keycloak": 

1335 metadata = provider.provider_metadata or {} 

1336 username_claim = metadata.get("username_claim", "preferred_username") 

1337 email_claim = metadata.get("email_claim", "email") 

1338 groups_claim = metadata.get("groups_claim", "groups") 

1339 

1340 groups = [] 

1341 

1342 # Extract realm roles 

1343 if metadata.get("map_realm_roles"): 

1344 realm_access = user_data.get("realm_access", {}) 

1345 realm_roles = realm_access.get("roles", []) 

1346 groups.extend(realm_roles) 

1347 

1348 # Extract client roles 

1349 if metadata.get("map_client_roles"): 

1350 resource_access = user_data.get("resource_access", {}) 

1351 for client, access in resource_access.items(): 

1352 client_roles = access.get("roles", []) 

1353 # Prefix with client name to avoid conflicts 

1354 groups.extend([f"{client}:{role}" for role in client_roles]) 

1355 

1356 # Extract groups from custom claim 

1357 if groups_claim in user_data: 

1358 custom_groups = user_data.get(groups_claim, []) 

1359 if isinstance(custom_groups, list): 

1360 groups.extend(custom_groups) 

1361 

1362 return { 

1363 "email": user_data.get(email_claim), 

1364 "email_verified": user_data.get("email_verified"), 

1365 "full_name": user_data.get("name"), 

1366 "avatar_url": user_data.get("picture"), 

1367 "provider_id": user_data.get("sub"), 

1368 "username": user_data.get(username_claim) or user_data.get(email_claim, "").split("@")[0], 

1369 "provider": "keycloak", 

1370 "groups": list(set(groups)), # Deduplicate 

1371 } 

1372 

1373 # Handle Microsoft Entra ID provider with role mapping 

1374 if provider.id == "entra": 

1375 metadata = provider.provider_metadata or {} 

1376 groups_claim = metadata.get("groups_claim", "groups") 

1377 

1378 # Microsoft's userinfo endpoint often omits the email claim 

1379 # Fallback: preferred_username (UPN) or upn claim 

1380 email = user_data.get("email") or user_data.get("preferred_username") or user_data.get("upn") 

1381 

1382 # Extract username from email/UPN 

1383 username = None 

1384 if user_data.get("preferred_username"): 

1385 username = user_data.get("preferred_username") 

1386 elif email: 

1387 username = email.split("@")[0] 

1388 

1389 # Extract groups from token 

1390 groups = [] 

1391 

1392 # Check configured groups claim (default: 'groups') 

1393 if groups_claim in user_data: 

1394 groups_value = user_data.get(groups_claim, []) 

1395 if isinstance(groups_value, list): 

1396 groups.extend(groups_value) 

1397 

1398 # Also check 'roles' claim for App Role assignments 

1399 if "roles" in user_data: 

1400 roles_value = user_data.get("roles", []) 

1401 if isinstance(roles_value, list): 

1402 groups.extend(roles_value) 

1403 

1404 return { 

1405 "email": email, 

1406 "email_verified": user_data.get("email_verified"), 

1407 "full_name": user_data.get("name") or email, # Fallback to email if name missing 

1408 "avatar_url": user_data.get("picture"), 

1409 "provider_id": user_data.get("sub") or user_data.get("oid"), 

1410 "username": username, 

1411 "provider": "entra", 

1412 "groups": list(set(groups)), # Deduplicate 

1413 } 

1414 

1415 # Generic OIDC format for all other providers 

1416 return { 

1417 "email": user_data.get("email"), 

1418 "email_verified": user_data.get("email_verified"), 

1419 "full_name": user_data.get("name"), 

1420 "avatar_url": user_data.get("picture"), 

1421 "provider_id": user_data.get("sub"), 

1422 "username": user_data.get("preferred_username") or user_data.get("email", "").split("@")[0], 

1423 "provider": provider.id, 

1424 } 

1425 

1426 async def authenticate_or_create_user(self, user_info: Dict[str, Any]) -> Optional[str]: 

1427 """Authenticate existing user or create new user from SSO info. 

1428 

1429 Args: 

1430 user_info: Normalized user info from SSO provider 

1431 

1432 Returns: 

1433 JWT token for authenticated user or None if failed 

1434 """ 

1435 raw_email = user_info.get("email") 

1436 if not raw_email: 

1437 logger.warning("SSO authenticate_or_create_user: no email in user_info from provider '%s'. User cannot be authenticated without an email address.", user_info.get("provider", "unknown")) 

1438 return None 

1439 

1440 email = str(raw_email).strip().lower() 

1441 if not email: 

1442 logger.warning("SSO authenticate_or_create_user: email is empty after normalization from provider '%s'.", user_info.get("provider", "unknown")) 

1443 return None 

1444 

1445 if not self._is_email_verified_claim(user_info): 

1446 logger.warning( 

1447 "SSO authenticate_or_create_user: unverified email claim for provider '%s' and email '%s'.", 

1448 user_info.get("provider", "unknown"), 

1449 email, 

1450 ) 

1451 return None 

1452 

1453 incoming_provider = str(user_info.get("provider", "sso")).strip().lower() or "sso" 

1454 provider = self.get_provider(incoming_provider) 

1455 

1456 # Enforce trusted-domain policy consistently for both existing and new users. 

1457 trusted_domains = getattr(provider, "trusted_domains", None) if provider else None 

1458 if trusted_domains: 

1459 domain = email.split("@")[1].lower() if "@" in email else "" 

1460 normalized_trusted_domains = [d.lower() for d in trusted_domains if isinstance(d, str)] 

1461 if domain not in normalized_trusted_domains: 

1462 logger.warning( 

1463 "SSO authenticate_or_create_user: email domain '%s' is not allowed for provider '%s'.", 

1464 domain, 

1465 incoming_provider, 

1466 ) 

1467 return None 

1468 

1469 # Use stable local values for JWT payload generation to avoid lazy-loading 

1470 # expired ORM attributes after commit/flush boundaries. 

1471 resolved_email = email 

1472 resolved_full_name = user_info.get("full_name", email) 

1473 resolved_auth_provider = incoming_provider 

1474 resolved_is_admin = False 

1475 

1476 # Check if user exists 

1477 user = await self.auth_service.get_user_by_email(email) 

1478 

1479 if user: 

1480 current_full_name = user.full_name or resolved_full_name 

1481 current_auth_provider = str(user.auth_provider or resolved_auth_provider).strip().lower() 

1482 current_is_admin = bool(user.is_admin) 

1483 current_admin_origin = user.admin_origin 

1484 

1485 if user.auth_provider and current_auth_provider != incoming_provider: 

1486 logger.warning( 

1487 "SSO authenticate_or_create_user: account-linking required for email '%s' (existing provider='%s', incoming='%s').", 

1488 email, 

1489 current_auth_provider, 

1490 incoming_provider, 

1491 ) 

1492 return None 

1493 

1494 provider_id: Optional[str] = None 

1495 provider_metadata: Dict[str, Any] = {} 

1496 provider_ctx: Optional[Any] = None 

1497 if provider: 

1498 provider_id = provider.id 

1499 provider_metadata = provider.provider_metadata or {} 

1500 provider_ctx = SSOProviderContext(id=provider_id, provider_metadata=provider_metadata) 

1501 

1502 # Update user info from SSO 

1503 if user_info.get("full_name") and user_info["full_name"] != current_full_name: 

1504 user.full_name = user_info["full_name"] 

1505 current_full_name = user_info["full_name"] 

1506 

1507 # Initialize auth_provider for legacy accounts without provider metadata. 

1508 if not user.auth_provider: 

1509 user.auth_provider = incoming_provider 

1510 current_auth_provider = incoming_provider 

1511 

1512 # Persist verification status from provider claims. 

1513 user.email_verified = self._is_email_verified_claim(user_info) 

1514 user.last_login = utc_now() 

1515 

1516 # Synchronize is_admin status based on current group membership 

1517 # Track origin to support both promotion AND demotion for SSO-granted admins 

1518 # Manual/API grants are "sticky" - never auto-demoted by SSO 

1519 # Only users with admin_origin="sso" can be demoted on login 

1520 if provider_ctx: 

1521 should_be_admin = self._should_user_be_admin(email, user_info, provider_ctx) 

1522 if should_be_admin: 

1523 # Grant admin access 

1524 if not current_is_admin: 

1525 logger.info(f"Upgrading is_admin to True for {email} based on SSO admin groups") 

1526 user.is_admin = True 

1527 # Track that admin was granted via SSO (only set on initial grant) 

1528 user.admin_origin = "sso" 

1529 current_is_admin = True 

1530 # Do NOT change admin_origin if already admin - preserve manual/API grants 

1531 elif current_is_admin and current_admin_origin == "sso": 

1532 # User was SSO admin but no longer in admin groups - revoke access 

1533 logger.info(f"Revoking is_admin for {email} - removed from SSO admin groups") 

1534 user.is_admin = False 

1535 user.admin_origin = None 

1536 current_is_admin = False 

1537 

1538 self.db.commit() 

1539 

1540 # Determine if syncing should happen (default True, respect provider-level and Entra setting) 

1541 should_sync = True 

1542 if provider_ctx: 

1543 # Check provider-level sync_roles flag in provider_metadata (allows disabling per-provider) 

1544 if "sync_roles" in provider_metadata: 

1545 should_sync = provider_metadata.get("sync_roles", True) 

1546 # Legacy Entra-specific setting (fallback for backwards compatibility) 

1547 elif provider_id == "entra" and hasattr(settings, "sso_entra_sync_roles_on_login"): 

1548 should_sync = settings.sso_entra_sync_roles_on_login 

1549 

1550 if provider_ctx and should_sync: 

1551 role_assignments = await self._map_groups_to_roles(email, user_info.get("groups", []), provider_ctx) 

1552 await self._sync_user_roles(email, role_assignments, provider_ctx) 

1553 await self._apply_team_mapping(email, user_info, provider) 

1554 

1555 user_email = getattr(user, "email", None) 

1556 if isinstance(user_email, str) and user_email.strip(): 

1557 resolved_email = user_email.strip().lower() 

1558 resolved_full_name = current_full_name 

1559 resolved_auth_provider = current_auth_provider 

1560 resolved_is_admin = current_is_admin 

1561 else: 

1562 # Auto-create user if enabled 

1563 if not provider or not provider.auto_create_users: 

1564 return None 

1565 

1566 provider_id = provider.id 

1567 provider_metadata = provider.provider_metadata or {} 

1568 provider_ctx = SSOProviderContext(id=provider_id, provider_metadata=provider_metadata) 

1569 

1570 # Check if admin approval is required 

1571 if settings.sso_require_admin_approval: 

1572 # Check if user is already pending approval 

1573 

1574 pending = self.db.execute(select(PendingUserApproval).where(PendingUserApproval.email == email)).scalar_one_or_none() 

1575 

1576 if pending: 

1577 if pending.status == "pending": 

1578 if pending.is_expired(): 

1579 pending.status = "expired" 

1580 self.db.commit() 

1581 

1582 pending.status = "pending" 

1583 pending.requested_at = utc_now() 

1584 pending.expires_at = utc_now() + timedelta(days=30) 

1585 pending.auth_provider = incoming_provider 

1586 pending.sso_metadata = user_info 

1587 pending.approved_by = None 

1588 pending.approved_at = None 

1589 pending.rejection_reason = None 

1590 pending.admin_notes = None 

1591 self.db.commit() 

1592 logger.info(f"Refreshed expired pending approval request for SSO user: {email}") 

1593 return None 

1594 return None # Still waiting for approval 

1595 if pending.status == "rejected": 

1596 return None # User was rejected 

1597 if pending.status == "approved": 

1598 if pending.is_expired(): 

1599 pending.status = "expired" 

1600 self.db.commit() 

1601 return None 

1602 elif pending.status == "expired": 

1603 pending.status = "pending" 

1604 pending.requested_at = utc_now() 

1605 pending.expires_at = utc_now() + timedelta(days=30) 

1606 pending.auth_provider = incoming_provider 

1607 pending.sso_metadata = user_info 

1608 pending.approved_by = None 

1609 pending.approved_at = None 

1610 pending.rejection_reason = None 

1611 pending.admin_notes = None 

1612 self.db.commit() 

1613 logger.info(f"Renewed expired pending approval request for SSO user: {email}") 

1614 return None 

1615 elif pending.status in {"completed"}: 

1616 return None 

1617 elif pending.status != "approved": 

1618 logger.warning(f"Unknown SSO pending approval status '{pending.status}' for user {email}. Denying by default.") 

1619 return None 

1620 else: 

1621 # Create pending approval request 

1622 pending = PendingUserApproval( 

1623 email=email, 

1624 full_name=user_info.get("full_name", email), 

1625 auth_provider=incoming_provider, 

1626 sso_metadata=user_info, 

1627 expires_at=utc_now() + timedelta(days=30), # 30-day approval window 

1628 ) 

1629 self.db.add(pending) 

1630 self.db.commit() 

1631 logger.info(f"Created pending approval request for SSO user: {email}") 

1632 return None # No token until approved 

1633 

1634 # Create new user (either no approval required, or approval already granted) 

1635 # Generate a secure random password for SSO users (they won't use it) 

1636 

1637 random_password = "".join(secrets.choice(string.ascii_letters + string.digits + string.punctuation) for _ in range(32)) 

1638 

1639 # Determine if user should be admin based on domain/organization 

1640 is_admin = self._should_user_be_admin(email, user_info, provider_ctx) 

1641 

1642 user = await self.auth_service.create_user( 

1643 email=email, 

1644 password=random_password, # Random password for SSO users (not used) 

1645 full_name=user_info.get("full_name", email), 

1646 is_admin=is_admin, 

1647 auth_provider=incoming_provider, 

1648 ) 

1649 if not user: 

1650 return None 

1651 

1652 user_email = getattr(user, "email", None) 

1653 if isinstance(user_email, str) and user_email.strip(): 

1654 resolved_email = user_email.strip().lower() 

1655 resolved_full_name = user_info.get("full_name", email) 

1656 resolved_auth_provider = incoming_provider 

1657 resolved_is_admin = is_admin 

1658 

1659 # Assign RBAC roles based on SSO groups (or default role if no groups) 

1660 # Check provider-level sync_roles flag in provider_metadata 

1661 should_sync = provider_metadata.get("sync_roles", True) 

1662 # Legacy Entra-specific setting (fallback for backwards compatibility) 

1663 if "sync_roles" not in provider_metadata and provider_id == "entra" and hasattr(settings, "sso_entra_sync_roles_on_login"): 

1664 should_sync = settings.sso_entra_sync_roles_on_login 

1665 

1666 if should_sync: 

1667 role_assignments = await self._map_groups_to_roles(email, user_info.get("groups", []), provider_ctx) 

1668 if role_assignments: 

1669 await self._sync_user_roles(email, role_assignments, provider_ctx) 

1670 await self._apply_team_mapping(email, user_info, provider) 

1671 

1672 # If user was created from approved request, mark request as used 

1673 if settings.sso_require_admin_approval: 

1674 pending = self.db.execute(select(PendingUserApproval).where(and_(PendingUserApproval.email == email, PendingUserApproval.status == "approved"))).scalar_one_or_none() 

1675 if pending: 

1676 # Mark as used (we could delete or keep for audit trail) 

1677 pending.status = "completed" 

1678 self.db.commit() 

1679 

1680 # Generate JWT token for user — session token (teams resolved server-side) 

1681 token_data = { 

1682 "sub": resolved_email, 

1683 "email": resolved_email, 

1684 "full_name": resolved_full_name, 

1685 "auth_provider": resolved_auth_provider, 

1686 "iat": int(utc_now().timestamp()), 

1687 "user": { 

1688 "email": resolved_email, 

1689 "full_name": resolved_full_name, 

1690 "is_admin": resolved_is_admin, 

1691 "auth_provider": resolved_auth_provider, 

1692 }, 

1693 "token_use": "session", # nosec B105 - token type marker, not a password 

1694 # Scopes 

1695 "scopes": {"server_id": None, "permissions": ["*"] if resolved_is_admin else [], "ip_restrictions": [], "time_restrictions": {}}, 

1696 } 

1697 

1698 # Create JWT token 

1699 token = await create_jwt_token(token_data) 

1700 return token 

1701 

def _should_user_be_admin(self, email: str, user_info: Dict[str, Any], provider: SSOProviderContext) -> bool:
    """Determine if SSO user should be granted admin privileges.

    Checks are evaluated in this order and the first match wins:

    1. The email domain is listed in ``settings.sso_auto_admin_domains``.
    2. GitHub provider: one of ``user_info["organizations"]`` is listed in
       ``settings.sso_github_admin_orgs``.
    3. Google provider: the email domain is listed in
       ``settings.sso_google_admin_domains``.
    4. Entra provider: one of ``user_info["groups"]`` is listed in
       ``settings.sso_entra_admin_groups``.

    All comparisons are case-insensitive. An email without a domain part
    never matches a domain rule (it does not raise); the group/organization
    rules are still evaluated for it.

    Args:
        email: User's email address
        user_info: Normalized user info from SSO provider
        provider: SSO provider configuration

    Returns:
        True if user should be admin, False otherwise
    """
    # The domain is whatever follows the LAST "@". Using rpartition avoids an
    # IndexError for addresses without "@" and avoids treating an inner
    # segment of "a@b@c" as the domain when deciding on admin rights.
    _local_part, at_sign, raw_domain = (email or "").rpartition("@")
    domain = raw_domain.lower() if at_sign else ""

    # Check domain-based admin assignment
    if domain and domain in {d.lower() for d in (settings.sso_auto_admin_domains or [])}:
        return True

    # Check provider-specific admin assignment
    if provider.id == "github" and settings.sso_github_admin_orgs:
        # For GitHub, we'd need to fetch user's organizations
        # This is a placeholder - in production, you'd make API calls to get orgs
        admin_orgs = {o.lower() for o in settings.sso_github_admin_orgs}
        github_orgs = user_info.get("organizations") or []
        if any(isinstance(org, str) and org.lower() in admin_orgs for org in github_orgs):
            return True

    if provider.id == "google" and settings.sso_google_admin_domains:
        # Check if user's domain is in admin domains
        if domain and domain in {d.lower() for d in settings.sso_google_admin_domains}:
            return True

    # Check EntraID admin groups
    if provider.id == "entra" and settings.sso_entra_admin_groups:
        admin_groups = {g.lower() for g in settings.sso_entra_admin_groups}
        user_groups = user_info.get("groups") or []
        if any(isinstance(group, str) and group.lower() in admin_groups for group in user_groups):
            return True

    return False

1738 

async def _map_groups_to_roles(self, user_email: str, user_groups: List[str], provider: SSOProviderContext) -> List[Dict[str, Any]]:
    """Map SSO groups to ContextForge RBAC roles.

    Resolution order:

    1. Entra only: membership in ``settings.sso_entra_admin_groups`` yields
       the configured admin role at global scope.
    2. Every group found in ``role_mappings`` (provider metadata, falling
       back to ``settings.sso_entra_role_mappings`` for Entra) yields the
       mapped role. ``"admin"`` and ``settings.default_admin_role`` are
       shorthand for the global admin role; other names are looked up as a
       team-scoped role first, then as a global role.
    3. If nothing was produced by steps 1-2 and a default role is configured
       (provider metadata ``default_role`` or, for Entra,
       ``settings.sso_entra_default_role``), that role is assigned.

    When the provider metadata sets ``resolve_team_scope_to_personal_team``,
    team-scoped roles are bound to the user's personal team; if that team
    cannot be resolved the team-scoped assignment is skipped.

    Identical assignments are emitted only once.

    Args:
        user_email: User's email address
        user_groups: List of groups from SSO provider (``None`` is treated as empty)
        provider: SSO provider configuration

    Returns:
        List of role assignments: [{"role_name": str, "scope": str, "scope_id": Optional[str]}]
    """
    # pylint: disable=import-outside-toplevel
    # First-Party
    from mcpgateway.services.role_service import RoleService

    role_assignments: List[Dict[str, Any]] = []

    # A missing groups claim may arrive as None rather than an empty list.
    user_groups = user_groups or []

    # Generic Role Mapping Logic
    metadata = provider.provider_metadata or {}
    # "or {}" instead of a .get() default: the default is only used when the key
    # is absent, so an explicit null in the metadata would otherwise leave
    # role_mappings as None and break the membership tests below.
    role_mappings = metadata.get("role_mappings") or {}
    provider_default_role: Optional[str] = metadata.get("default_role")
    resolve_team_scope_to_personal_team = bool(metadata.get("resolve_team_scope_to_personal_team", False))
    has_provider_default_role = isinstance(provider_default_role, str) and bool(provider_default_role.strip())

    # Merge with legacy Entra specific settings if applicable
    has_entra_admin_groups = provider.id == "entra" and settings.sso_entra_admin_groups

    if provider.id == "entra":
        # Use generic role_mappings fallback to legacy setting
        if not role_mappings and settings.sso_entra_role_mappings:
            role_mappings = settings.sso_entra_role_mappings
        # Legacy fallback for default role configuration
        if not has_provider_default_role and settings.sso_entra_default_role:
            provider_default_role = settings.sso_entra_default_role
            has_provider_default_role = True

    # Early exit: Skip role mapping if no configuration exists
    if not role_mappings and not has_entra_admin_groups and not has_provider_default_role:
        logger.debug(f"No role mappings configured for provider {provider.id}, skipping role sync")
        return role_assignments

    # Role names that always resolve to the global admin role.
    admin_role_names = ("admin", settings.default_admin_role)

    def _add_assignment(role_name: str, scope: str, scope_id: Optional[str]) -> bool:
        """Append an assignment unless an identical one is already present.

        Args:
            role_name: Name of the role to assign.
            scope: Scope of the assignment.
            scope_id: Scope identifier, or None.

        Returns:
            bool: True when the assignment was appended, False when it was a duplicate.
        """
        candidate = {"role_name": role_name, "scope": scope, "scope_id": scope_id}
        if candidate in role_assignments:
            return False
        role_assignments.append(candidate)
        return True

    personal_team_id: Optional[str] = None
    personal_team_checked = False

    async def _resolve_team_scope_id_if_needed(role_scope: str) -> Optional[str]:
        """Resolve team scope to personal-team id when provider mapping requires it.

        The lookup is performed at most once per call of the enclosing method;
        later calls reuse the cached result (including a failed lookup).

        Args:
            role_scope: Role scope value from mapping metadata.

        Returns:
            Optional[str]: Personal team id when resolution succeeds, else None.
        """
        nonlocal personal_team_id, personal_team_checked

        if role_scope != "team" or not resolve_team_scope_to_personal_team:
            return None

        if personal_team_checked:
            return personal_team_id

        personal_team_checked = True
        try:
            # First-Party
            from mcpgateway.services.personal_team_service import PersonalTeamService

            personal_team = await PersonalTeamService(self.db).get_personal_team(user_email)
            personal_team_id = personal_team.id if personal_team else None
            if not personal_team_id:
                logger.warning(f"Could not resolve personal team for {user_email}; skipping team-scoped SSO role mapping")
        except Exception as e:
            # Best effort: a failed lookup only disables team-scoped mappings for this login.
            logger.error(f"Failed to resolve personal team for {user_email}: {e}. All team-scoped SSO role assignments will be skipped for this login.")
            personal_team_id = None

        return personal_team_id

    # Handle EntraID admin groups -> admin role
    if has_entra_admin_groups:
        admin_groups_lower = {g.lower() for g in settings.sso_entra_admin_groups}
        for group in user_groups:
            if group.lower() in admin_groups_lower:
                _add_assignment(settings.default_admin_role, "global", None)
                logger.debug(f"Mapped EntraID admin group to {settings.default_admin_role} role for {user_email}")
                break  # Only need one admin assignment

    # Batch role lookups: collect all role names that need to be looked up
    role_names_to_lookup = set()
    for group in user_groups:
        if group in role_mappings:
            role_name = role_mappings[group]
            if role_name not in admin_role_names:
                role_names_to_lookup.add(role_name)

    # Add default role to lookup if needed
    if has_provider_default_role and provider_default_role:
        role_names_to_lookup.add(provider_default_role)

    # Pre-fetch each distinct role once so groups sharing a role do not repeat the lookup
    role_service = RoleService(self.db)
    role_cache: Dict[str, Any] = {}
    for role_name in role_names_to_lookup:
        # Try team scope first, then global
        role = await role_service.get_role_by_name(role_name, scope="team")
        if not role:
            role = await role_service.get_role_by_name(role_name, scope="global")
        if role:
            role_cache[role_name] = role

    # Process role mappings for ALL providers
    for group in user_groups:
        if group not in role_mappings:
            continue

        role_name = role_mappings[group]
        # Special case for "admin" shorthand or configured admin role name
        if role_name in admin_role_names:
            # Several groups may map to admin (or the Entra admin-group rule may
            # already have fired); emit the admin assignment only once.
            if _add_assignment(settings.default_admin_role, "global", None):
                logger.debug(f"Mapped group to {settings.default_admin_role} role for {user_email}")
            continue

        # Use pre-fetched role from cache
        role = role_cache.get(role_name)
        if not role:
            logger.warning(f"Role '{role_name}' not found for group mapping")
            continue

        scope_id = await _resolve_team_scope_id_if_needed(role.scope)
        if role.scope == "team" and resolve_team_scope_to_personal_team and not scope_id:
            # Personal team could not be resolved; skip this team-scoped mapping.
            continue
        # Avoid duplicate assignments
        if _add_assignment(role.name, role.scope, scope_id):
            logger.debug(f"Mapped group to role '{role.name}' for {user_email}")

    # Apply default role if no mappings found
    if not role_assignments and has_provider_default_role and provider_default_role:
        default_role = role_cache.get(provider_default_role)
        if default_role:
            scope_id = await _resolve_team_scope_id_if_needed(default_role.scope)
            if default_role.scope == "team" and resolve_team_scope_to_personal_team and not scope_id:
                return role_assignments
            _add_assignment(default_role.name, default_role.scope, scope_id)
            logger.info(f"Assigned default role '{default_role.name}' to {user_email}")

    return role_assignments

1881 

async def _sync_user_roles(self, user_email: str, role_assignments: List[Dict[str, Any]], _provider: SSOProviderContext) -> None:
    """Reconcile the user's SSO-granted roles with the desired assignments.

    Two passes are made:

    1. Every current, non-expired role whose ``grant_source`` is ``"sso"`` and
       whose ``(role name, scope, scope_id)`` is absent from
       ``role_assignments`` is revoked. Roles from other grant sources are
       left untouched. Errors raised here propagate to the caller.
    2. Every entry of ``role_assignments`` is assigned with
       ``grant_source="sso"`` unless an active assignment already exists.
       A failure on one entry is logged and the session rolled back before
       continuing; if the rollback itself fails the remaining entries are
       abandoned.

    Args:
        user_email: User's email address
        role_assignments: List of role assignments to apply
        _provider: SSO provider configuration (reserved for future use)
    """
    # pylint: disable=import-outside-toplevel
    # First-Party
    from mcpgateway.services.role_service import RoleService

    role_service = RoleService(self.db)

    # Only roles that were granted through SSO are candidates for revocation
    active_roles = await role_service.list_user_roles(user_email, include_expired=False)
    sso_granted = [entry for entry in active_roles if getattr(entry, "grant_source", None) == "sso"]

    # Identity of each desired assignment: (role name, scope, scope id)
    wanted = set()
    for item in role_assignments:
        wanted.add((item["role_name"], item["scope"], item.get("scope_id")))

    # Pass 1: drop SSO roles that are no longer wanted
    for user_role in sso_granted:
        if (user_role.role.name, user_role.scope, user_role.scope_id) in wanted:
            continue
        await role_service.revoke_role_from_user(
            user_email=user_email,
            role_id=user_role.role_id,
            scope=user_role.scope,
            scope_id=user_role.scope_id,
        )
        logger.info(f"Revoked SSO role '{user_role.role.name}' from {user_email} (no longer in groups)")

    # Pass 2: grant whatever is wanted but not yet active
    for assignment in role_assignments:
        try:
            role = await role_service.get_role_by_name(assignment["role_name"], scope=assignment["scope"])
            if not role:
                logger.warning(f"Role '{assignment['role_name']}' not found, skipping assignment for {user_email}")
                continue

            target_scope = assignment["scope"]
            target_scope_id = assignment.get("scope_id")

            existing = await role_service.get_user_role_assignment(
                user_email=user_email,
                role_id=role.id,
                scope=target_scope,
                scope_id=target_scope_id,
            )
            if existing and existing.is_active:
                # Already granted and active; nothing to do
                continue

            await role_service.assign_role_to_user(
                user_email=user_email,
                role_id=role.id,
                scope=target_scope,
                scope_id=target_scope_id,
                granted_by=user_email,
                grant_source="sso",
            )
            logger.info(f"Assigned SSO role '{role.name}' to {user_email}")

        except Exception as e:
            logger.warning(f"Failed to assign role '{assignment['role_name']}' to {user_email}: {e}", exc_info=True)
            # Reset the session so the next assignment starts from a clean state
            try:
                self.db.rollback()
            except Exception as rollback_error:
                logger.error(f"Database rollback failed after role assignment error for {user_email}: {rollback_error}. Aborting remaining role assignments.")
                break