Coverage for mcpgateway / services / sso_service.py: 99%

1024 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/sso_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Single Sign-On (SSO) authentication service for OAuth2 and OIDC providers. 

8Handles provider management, OAuth flows, and user authentication. 

9""" 

10 

11# Future 

12from __future__ import annotations 

13 

14# Standard 

15import asyncio 

16import base64 

17from dataclasses import dataclass 

18from datetime import timedelta 

19from enum import Enum 

20import hashlib 

21import hmac 

22import logging 

23import secrets 

24import string 

25from time import monotonic 

26from typing import Any, Dict, List, Optional, Tuple, Union 

27import urllib.parse 

28 

29# Third-Party 

30import jwt 

31import orjson 

32from sqlalchemy import and_, select 

33from sqlalchemy.orm import Session 

34 

35# First-Party 

36from mcpgateway.common.validators import SecurityValidator 

37from mcpgateway.config import settings 

38from mcpgateway.db import PendingUserApproval, SSOAuthSession, SSOProvider, utc_now 

39from mcpgateway.services.email_auth_service import EmailAuthService 

40from mcpgateway.services.encryption_service import get_encryption_service 

41from mcpgateway.utils.create_jwt_token import create_jwt_token 

42 

43# Logger 

44logger = logging.getLogger(__name__) 

45 

46# Constants 

47ADFS_PROVIDER_ID = "adfs" 

48 

49 

50class SSOError(Exception): 

51 """Base class for SSO-related errors.""" 

52 

53 

54class SSOAuthenticationError(SSOError): 

55 """Raised when SSO authentication fails.""" 

56 

57 

58class SSOProviderConfigError(SSOError): 

59 """Raised when SSO provider configuration is invalid or incomplete.""" 

60 

61 

62class _Unset(Enum): 

63 """Sentinel: distinguishes 'caller omitted the argument' from 'caller passed None'.""" 

64 

65 UNSET = "UNSET" 

66 

67 

68_UNSET = _Unset.UNSET 

69 

70 

71@dataclass 

72class SSOProviderContext: 

73 """Lightweight context for SSO provider info passed to helper methods.""" 

74 

75 id: Optional[str] 

76 provider_metadata: Dict[str, Any] 

77 

78 

79class SSOService: 

80 """Service for managing SSO authentication flows and providers. 

81 

82 Handles OAuth2/OIDC authentication flows, provider configuration, 

83 and integration with the local user system. 

84 

85 Examples: 

86 Basic construction and helper checks: 

87 >>> from unittest.mock import Mock 

88 >>> service = SSOService(Mock()) 

89 >>> isinstance(service, SSOService) 

90 True 

91 >>> callable(service.list_enabled_providers) 

92 True 

93 """ 

94 

95 _OIDC_METADATA_CACHE_TTL_SECONDS = 300 

96 _oidc_config_cache: Dict[str, Tuple[float, Dict[str, Any]]] = {} 

97 _jwks_client_cache: Dict[str, jwt.PyJWKClient] = {} 

98 _STATE_BINDING_SEPARATOR = "." 

99 _STATE_BINDING_HEX_LEN = 64 

100 

101 def __init__(self, db: Session): 

102 """Initialize SSO service with database session. 

103 

104 Args: 

105 db: SQLAlchemy database session 

106 """ 

107 self.db = db 

108 self.auth_service = EmailAuthService(db) 

109 self._encryption = get_encryption_service(settings.auth_encryption_secret) 

110 

111 async def _encrypt_secret(self, secret: str) -> str: 

112 """Encrypt a client secret for secure storage. 

113 

114 Args: 

115 secret: Plain text client secret 

116 

117 Returns: 

118 Encrypted secret string 

119 """ 

120 return await self._encryption.encrypt_secret_async(secret) 

121 

122 async def _decrypt_secret(self, encrypted_secret: str) -> Optional[str]: 

123 """Decrypt a client secret for use. 

124 

125 Args: 

126 encrypted_secret: Encrypted secret string 

127 

128 Returns: 

129 Plain text client secret 

130 """ 

131 decrypted: str | None = await self._encryption.decrypt_secret_async(encrypted_secret) 

132 if decrypted: 

133 return decrypted 

134 

135 return None 

136 

137 def _decode_jwt_claims(self, token: str) -> Optional[Dict[str, Any]]: 

138 """Decode JWT token payload without verification. 

139 

140 This is used to extract claims from ID tokens where we've already 

141 validated the OAuth flow. The token signature is not verified here 

142 because the token was received directly from the trusted token endpoint. 

143 

144 Args: 

145 token: JWT token string 

146 

147 Returns: 

148 Decoded payload dict or None if decoding fails 

149 

150 Examples: 

151 >>> from unittest.mock import Mock 

152 >>> service = SSOService(Mock()) 

153 >>> # Valid JWT structure (header.payload.signature) 

154 >>> import base64 

155 >>> payload = base64.urlsafe_b64encode(b'{"sub":"123","groups":["admin"]}').decode().rstrip('=') 

156 >>> token = f"eyJhbGciOiJSUzI1NiJ9.{payload}.signature" 

157 >>> claims = service._decode_jwt_claims(token) 

158 >>> claims is not None 

159 True 

160 """ 

161 try: 

162 # JWT format: header.payload.signature 

163 parts = token.split(".") 

164 if len(parts) != 3: 

165 logger.warning("Invalid JWT format: expected 3 parts") 

166 return None 

167 

168 # Decode payload (middle part) - add padding if needed 

169 payload_b64 = parts[1] 

170 # Add padding for base64 decoding 

171 padding = 4 - len(payload_b64) % 4 

172 if padding != 4: 

173 payload_b64 += "=" * padding 

174 

175 payload_bytes = base64.urlsafe_b64decode(payload_b64) 

176 return orjson.loads(payload_bytes) 

177 

178 except (ValueError, orjson.JSONDecodeError, UnicodeDecodeError) as e: 

179 logger.warning(f"Failed to decode JWT claims: {e}") 

180 return None 

181 

182 async def _get_oidc_provider_metadata(self, issuer: str) -> Optional[Dict[str, Any]]: 

183 """Discover and cache OIDC provider metadata. 

184 

185 Args: 

186 issuer: OIDC issuer URL. 

187 

188 Returns: 

189 Provider metadata dict from discovery endpoint, or None on failure. 

190 """ 

191 normalized_issuer = issuer.rstrip("/") 

192 cached = self._oidc_config_cache.get(normalized_issuer) 

193 if cached is not None: 

194 cached_at, cached_metadata = cached 

195 if monotonic() - cached_at < self._OIDC_METADATA_CACHE_TTL_SECONDS: 

196 return cached_metadata 

197 self._oidc_config_cache.pop(normalized_issuer, None) 

198 

199 # First-Party 

200 from mcpgateway.services.http_client_service import get_http_client # pylint: disable=import-outside-toplevel 

201 

202 discovery_url = f"{normalized_issuer}/.well-known/openid-configuration" 

203 try: 

204 client = await get_http_client() 

205 response = await client.get(discovery_url, timeout=settings.oauth_request_timeout) 

206 if response.status_code != 200: 

207 logger.warning("OIDC discovery failed for issuer %s with HTTP %s", normalized_issuer, response.status_code) 

208 return None 

209 

210 metadata = response.json() 

211 if not isinstance(metadata, dict): 

212 logger.warning("OIDC discovery response for issuer %s is not a JSON object", normalized_issuer) 

213 return None 

214 self._oidc_config_cache[normalized_issuer] = (monotonic(), metadata) 

215 return metadata 

216 except Exception as exc: 

217 logger.warning("OIDC discovery request failed for issuer %s: %s", normalized_issuer, exc) 

218 return None 

219 

220 async def _resolve_oidc_issuer_and_jwks(self, provider: SSOProvider) -> Tuple[Optional[str], Optional[str]]: 

221 """Resolve issuer and JWKS URI for an OIDC provider. 

222 

223 Args: 

224 provider: SSO provider configuration. 

225 

226 Returns: 

227 Tuple of (issuer, jwks_uri), each optional when unavailable. 

228 """ 

229 issuer = provider.issuer.strip() if isinstance(provider.issuer, str) and provider.issuer.strip() else None 

230 jwks_uri = provider.jwks_uri.strip() if isinstance(provider.jwks_uri, str) and provider.jwks_uri.strip() else None 

231 

232 if issuer and not jwks_uri: 

233 metadata = await self._get_oidc_provider_metadata(issuer) 

234 if metadata: 

235 discovered_jwks = metadata.get("jwks_uri") 

236 discovered_issuer = metadata.get("issuer") 

237 if isinstance(discovered_jwks, str) and discovered_jwks.strip(): 

238 jwks_uri = discovered_jwks.strip() 

239 if isinstance(discovered_issuer, str) and discovered_issuer.strip(): 

240 issuer = discovered_issuer.strip() 

241 

242 return issuer, jwks_uri 

243 

244 def _get_jwks_client(self, jwks_uri: str) -> jwt.PyJWKClient: 

245 """Get or create a cached PyJWKClient instance. 

246 

247 Args: 

248 jwks_uri: JWKS endpoint URL. 

249 

250 Returns: 

251 Cached or newly created `PyJWKClient`. 

252 """ 

253 if jwks_uri not in self._jwks_client_cache: 

254 self._jwks_client_cache[jwks_uri] = jwt.PyJWKClient(jwks_uri) 

255 return self._jwks_client_cache[jwks_uri] 

256 

257 async def _verify_oidc_id_token(self, provider: SSOProvider, id_token: str, expected_nonce: Optional[str] = None) -> Optional[Dict[str, Any]]: 

258 """Verify OIDC ID token signature and claims. 

259 

260 Args: 

261 provider: SSO provider configuration. 

262 id_token: Raw OIDC ID token string. 

263 expected_nonce: Expected nonce from auth session, when available. 

264 

265 Returns: 

266 Verified token claims when validation succeeds, otherwise None. 

267 """ 

268 if provider.provider_type != "oidc": 

269 return None 

270 

271 issuer, jwks_uri = await self._resolve_oidc_issuer_and_jwks(provider) 

272 if not jwks_uri: 

273 logger.warning("Skipping id_token claim usage for provider %s: missing jwks_uri", provider.id) 

274 return None 

275 

276 try: 

277 jwks_client = self._get_jwks_client(jwks_uri) 

278 signing_key = await asyncio.to_thread(jwks_client.get_signing_key_from_jwt, id_token) 

279 

280 decode_kwargs: Dict[str, Any] = { 

281 "key": signing_key.key, 

282 "algorithms": ["RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512", "EdDSA"], 

283 "audience": provider.client_id, 

284 "options": { 

285 "verify_signature": True, 

286 "verify_exp": True, 

287 "verify_iat": True, 

288 "verify_aud": True, 

289 "verify_iss": bool(issuer), 

290 }, 

291 } 

292 if issuer: 

293 decode_kwargs["issuer"] = issuer 

294 

295 claims = await asyncio.to_thread(jwt.decode, id_token, **decode_kwargs) 

296 if expected_nonce is not None and claims.get("nonce") != expected_nonce: 

297 logger.warning("OIDC id_token nonce validation failed for provider %s", provider.id) 

298 return None 

299 return claims 

300 except jwt.PyJWTError as exc: 

301 logger.warning("OIDC id_token verification failed for provider %s: %s", provider.id, exc) 

302 return None 

303 except Exception as exc: 

304 logger.warning("Unexpected OIDC id_token verification error for provider %s: %s", provider.id, exc) 

305 return None 

306 

307 def _resolve_entra_graph_fallback_settings(self, provider_metadata: Optional[Dict[str, Any]]) -> Tuple[bool, int, int]: 

308 """Resolve Entra Graph fallback settings with provider metadata override. 

309 

310 Args: 

311 provider_metadata: Optional provider metadata from DB/config. 

312 

313 Returns: 

314 Tuple of (enabled, timeout_seconds, max_groups). 

315 """ 

316 metadata = provider_metadata or {} 

317 enabled = settings.sso_entra_graph_api_enabled 

318 timeout = settings.sso_entra_graph_api_timeout 

319 max_groups = settings.sso_entra_graph_api_max_groups 

320 

321 if "graph_api_enabled" in metadata: 

322 metadata_enabled = metadata.get("graph_api_enabled") 

323 if isinstance(metadata_enabled, bool): 

324 enabled = metadata_enabled 

325 elif isinstance(metadata_enabled, str): 

326 normalized_enabled = metadata_enabled.strip().lower() 

327 if normalized_enabled in {"1", "true", "yes", "on"}: 

328 enabled = True 

329 elif normalized_enabled in {"0", "false", "no", "off"}: 

330 enabled = False 

331 else: 

332 logger.warning("Invalid provider_metadata.graph_api_enabled=%s; using configured default %s", metadata_enabled, enabled) 

333 elif isinstance(metadata_enabled, int): 

334 enabled = metadata_enabled != 0 

335 else: 

336 logger.warning("Invalid provider_metadata.graph_api_enabled=%s (type %s); using configured default %s", metadata_enabled, type(metadata_enabled).__name__, enabled) 

337 

338 metadata_timeout = metadata.get("graph_api_timeout") 

339 if metadata_timeout is not None: 

340 try: 

341 timeout_candidate = int(metadata_timeout) 

342 if 1 <= timeout_candidate <= 120: 

343 timeout = timeout_candidate 

344 else: 

345 logger.warning("Invalid provider_metadata.graph_api_timeout=%s; using configured default %s", metadata_timeout, timeout) 

346 except (TypeError, ValueError): 

347 logger.warning("Invalid provider_metadata.graph_api_timeout=%s; using configured default %s", metadata_timeout, timeout) 

348 

349 metadata_max_groups = metadata.get("graph_api_max_groups") 

350 if metadata_max_groups is not None: 

351 try: 

352 max_groups_candidate = int(metadata_max_groups) 

353 if max_groups_candidate >= 0: 

354 max_groups = max_groups_candidate 

355 else: 

356 logger.warning("Invalid provider_metadata.graph_api_max_groups=%s; using configured default %s", metadata_max_groups, max_groups) 

357 except (TypeError, ValueError): 

358 logger.warning("Invalid provider_metadata.graph_api_max_groups=%s; using configured default %s", metadata_max_groups, max_groups) 

359 

360 return enabled, timeout, max_groups 

361 

362 async def _fetch_entra_groups_from_graph_api(self, access_token: str, user_email: str, provider_metadata: Optional[Dict[str, Any]] = None) -> Optional[List[str]]: 

363 """Fetch Entra group object IDs from Microsoft Graph for overage tokens. 

364 

365 Args: 

366 access_token: Delegated OAuth access token from Entra. 

367 user_email: User identifier for structured logs. 

368 provider_metadata: Optional provider metadata for runtime overrides. 

369 

370 Returns: 

371 List of group IDs on success, empty list when Graph returns no IDs, 

372 or None if retrieval failed/disabled. 

373 """ 

374 graph_api_enabled, graph_api_timeout, graph_api_max_groups = self._resolve_entra_graph_fallback_settings(provider_metadata) 

375 if not graph_api_enabled: 

376 logger.info("Microsoft Graph fallback for Entra group overage is disabled") 

377 return None 

378 

379 # First-Party 

380 from mcpgateway.services.http_client_service import get_http_client # pylint: disable=import-outside-toplevel 

381 

382 client = await get_http_client() 

383 try: 

384 response = await client.post( 

385 "https://graph.microsoft.com/v1.0/me/getMemberObjects", 

386 headers={"Authorization": f"Bearer {access_token}"}, 

387 json={"securityEnabledOnly": True}, 

388 timeout=graph_api_timeout, 

389 ) 

390 except Exception as e: 

391 logger.error("Failed to retrieve groups from Graph API for %s: %s", user_email, e) 

392 return None 

393 

394 if response.status_code != 200: 

395 if response.status_code in {401, 403}: 

396 logger.error( 

397 "Failed to retrieve groups from Graph API for %s: HTTP %s. Check Entra delegated permissions and consent (minimum User.Read).", 

398 user_email, 

399 response.status_code, 

400 ) 

401 else: 

402 logger.error("Failed to retrieve groups from Graph API for %s: HTTP %s", user_email, response.status_code) 

403 return None 

404 

405 try: 

406 groups_payload = response.json() 

407 except ValueError as e: 

408 logger.error("Failed to parse Graph API response for %s: %s", user_email, e) 

409 return None 

410 

411 group_values = groups_payload.get("value", []) 

412 if not isinstance(group_values, list): 

413 logger.warning("Unexpected Graph API groups payload for %s: expected list in 'value'", user_email) 

414 return [] 

415 

416 deduped_groups: List[str] = [] 

417 seen_groups: set[str] = set() 

418 for group_id in group_values: 

419 if not isinstance(group_id, str): 

420 continue 

421 normalized_group_id = group_id.strip() 

422 if not normalized_group_id or normalized_group_id in seen_groups: 

423 continue 

424 seen_groups.add(normalized_group_id) 

425 deduped_groups.append(normalized_group_id) 

426 

427 if graph_api_max_groups > 0: 

428 if len(deduped_groups) > graph_api_max_groups: 

429 logger.warning( 

430 "Graph API returned %d groups for %s; applying configured cap (%d)", 

431 len(deduped_groups), 

432 user_email, 

433 graph_api_max_groups, 

434 ) 

435 deduped_groups = deduped_groups[:graph_api_max_groups] 

436 

437 logger.info("Retrieved %d groups from Graph API for %s", len(deduped_groups), user_email) 

438 return deduped_groups 

439 

440 def list_enabled_providers(self) -> List[SSOProvider]: 

441 """Get list of enabled SSO providers. 

442 

443 Returns: 

444 List of enabled SSO providers 

445 

446 Examples: 

447 Returns empty list when DB has no providers: 

448 >>> from unittest.mock import MagicMock 

449 >>> service = SSOService(MagicMock()) 

450 >>> service.db.execute.return_value.scalars.return_value.all.return_value = [] 

451 >>> service.list_enabled_providers() 

452 [] 

453 """ 

454 stmt = select(SSOProvider).where(SSOProvider.is_enabled.is_(True)) 

455 result = self.db.execute(stmt) 

456 return list(result.scalars().all()) 

457 

458 def list_all_providers(self) -> List[SSOProvider]: 

459 """Get list of all SSO providers (enabled and disabled). 

460 

461 Returns: 

462 List of all SSO providers 

463 

464 Examples: 

465 Returns empty list when DB has no providers: 

466 >>> from unittest.mock import MagicMock 

467 >>> service = SSOService(MagicMock()) 

468 >>> service.db.execute.return_value.scalars.return_value.all.return_value = [] 

469 >>> service.list_all_providers() 

470 [] 

471 """ 

472 stmt = select(SSOProvider) 

473 result = self.db.execute(stmt) 

474 return list(result.scalars().all()) 

475 

476 def get_provider(self, provider_id: str) -> Optional[SSOProvider]: 

477 """Get SSO provider by ID. 

478 

479 Args: 

480 provider_id: Provider identifier (e.g., 'github', 'google') 

481 

482 Returns: 

483 SSO provider or None if not found 

484 

485 Examples: 

486 >>> from unittest.mock import MagicMock 

487 >>> service = SSOService(MagicMock()) 

488 >>> service.db.execute.return_value.scalar_one_or_none.return_value = None 

489 >>> service.get_provider('x') is None 

490 True 

491 """ 

492 stmt = select(SSOProvider).where(SSOProvider.id == provider_id) 

493 result = self.db.execute(stmt) 

494 return result.scalar_one_or_none() 

495 

496 def get_provider_by_name(self, provider_name: str) -> Optional[SSOProvider]: 

497 """Get SSO provider by name. 

498 

499 Args: 

500 provider_name: Provider name (e.g., 'github', 'google') 

501 

502 Returns: 

503 SSO provider or None if not found 

504 

505 Examples: 

506 >>> from unittest.mock import MagicMock 

507 >>> service = SSOService(MagicMock()) 

508 >>> service.db.execute.return_value.scalar_one_or_none.return_value = None 

509 >>> service.get_provider_by_name('github') is None 

510 True 

511 """ 

512 stmt = select(SSOProvider).where(SSOProvider.name == provider_name) 

513 result = self.db.execute(stmt) 

514 return result.scalar_one_or_none() 

515 

516 @staticmethod 

517 def _normalize_issuer_url(issuer: str) -> str: 

518 """Normalize issuer URL for allowlist comparisons. 

519 

520 Args: 

521 issuer: Raw issuer URL from provider config or metadata. 

522 

523 Returns: 

524 Lowercased issuer URL without trailing slash. 

525 """ 

526 return issuer.strip().rstrip("/").lower() 

527 

528 def _enforce_allowed_issuer(self, issuer: Optional[str]) -> None: 

529 """Enforce configured issuer allowlist when present. 

530 

531 Args: 

532 issuer: Candidate issuer URL from provider configuration. 

533 

534 Raises: 

535 ValueError: If issuer is set but not in configured allowlist. 

536 """ 

537 allowed_issuers = getattr(settings, "sso_issuers", None) 

538 if not allowed_issuers: 

539 return 

540 

541 if not isinstance(issuer, str) or not issuer.strip(): 

542 logger.warning("SSO provider has blank/empty issuer while SSO_ISSUERS allowlist is configured; issuer enforcement skipped.") 

543 return 

544 

545 normalized_candidate = self._normalize_issuer_url(issuer) 

546 normalized_allowlist = {self._normalize_issuer_url(str(allowed_issuer)) for allowed_issuer in allowed_issuers if isinstance(allowed_issuer, str) and allowed_issuer.strip()} 

547 if normalized_allowlist and normalized_candidate not in normalized_allowlist: 

548 raise ValueError("Issuer is not allowed by SSO_ISSUERS configuration") 

549 

550 @staticmethod 

551 def _resolve_team_mapping_target(mapping_value: Any) -> Tuple[Optional[str], str]: 

552 """Resolve team mapping value into team id and role. 

553 

554 Args: 

555 mapping_value: Team mapping target value from provider config. 

556 

557 Returns: 

558 Tuple of ``(team_id, role)`` where ``team_id`` may be ``None`` and 

559 role defaults to ``member`` when not explicitly valid. 

560 """ 

561 if isinstance(mapping_value, str) and mapping_value.strip(): 

562 return mapping_value.strip(), "member" 

563 

564 if isinstance(mapping_value, dict): 

565 team_id_value = mapping_value.get("team_id") or mapping_value.get("id") 

566 team_id = str(team_id_value).strip() if team_id_value is not None else "" 

567 role_value = str(mapping_value.get("role", "member")).strip().lower() 

568 role = role_value if role_value in {"owner", "member"} else "member" 

569 return (team_id if team_id else None), role 

570 

571 return None, "member" 

572 

573 async def _apply_team_mapping(self, user_email: str, user_info: Dict[str, Any], provider: Optional[SSOProvider]) -> None: 

574 """Apply provider team mappings based on SSO group claims. 

575 

576 Reconciles team memberships: grants new SSO-based memberships and revokes 

577 stale ones when groups are removed from the identity provider. 

578 

579 Args: 

580 user_email: Authenticated user email to map into teams. 

581 user_info: Identity claims payload containing optional group claims. 

582 provider: SSO provider configuration with ``team_mapping`` entries. 

583 

584 Returns: 

585 None. 

586 """ 

587 if not provider: 

588 return 

589 

590 mapping = getattr(provider, "team_mapping", None) 

591 if not isinstance(mapping, dict) or not mapping: 

592 return 

593 

594 groups_raw = user_info.get("groups", []) 

595 if isinstance(groups_raw, str): 

596 groups = [groups_raw] 

597 elif isinstance(groups_raw, list): 

598 groups = [str(group).strip() for group in groups_raw if str(group).strip()] 

599 else: 

600 groups = [] 

601 

602 normalized_groups = {group.lower() for group in groups} 

603 

604 # First-Party 

605 from mcpgateway.services.team_management_service import ( # pylint: disable=import-outside-toplevel 

606 MemberAlreadyExistsError, 

607 TeamManagementError, 

608 TeamManagementService, 

609 ) 

610 

611 team_service = TeamManagementService(self.db) 

612 

613 # Get current SSO-granted team memberships for this user 

614 # First-Party 

615 from mcpgateway.db import EmailTeamMember # pylint: disable=import-outside-toplevel 

616 

617 stmt = select(EmailTeamMember).where( 

618 EmailTeamMember.user_email == user_email, 

619 EmailTeamMember.grant_source == "sso", 

620 EmailTeamMember.is_active.is_(True), 

621 ) 

622 result = self.db.execute(stmt) 

623 current_sso_memberships = result.scalars().all() 

624 

625 # Build set of desired team IDs from current groups + team_mapping 

626 desired_team_ids = set() 

627 for source_group, target in mapping.items(): 

628 if not isinstance(source_group, str): 

629 continue 

630 source_group_normalized = source_group.strip().lower() 

631 if not source_group_normalized: 

632 continue 

633 

634 if source_group_normalized in normalized_groups: 

635 team_id, _ = self._resolve_team_mapping_target(target) 

636 if team_id: 

637 desired_team_ids.add(team_id) 

638 

639 # Revoke SSO memberships that are no longer in desired set 

640 for membership in current_sso_memberships: 

641 if membership.team_id not in desired_team_ids: 

642 try: 

643 await team_service.remove_member_from_team( 

644 team_id=membership.team_id, 

645 user_email=user_email, 

646 ) 

647 logger.info( 

648 "Revoked SSO team membership for %s from team %s (group no longer in claims)", 

649 SecurityValidator.sanitize_log_message(user_email), 

650 membership.team_id, 

651 ) 

652 except Exception as exc: 

653 logger.warning( 

654 "Failed to revoke SSO team membership for %s from team %s: %s", 

655 SecurityValidator.sanitize_log_message(user_email), 

656 membership.team_id, 

657 exc, 

658 ) 

659 

660 # Grant new SSO memberships 

661 for source_group, target in mapping.items(): 

662 if not isinstance(source_group, str): 

663 continue 

664 source_group_normalized = source_group.strip().lower() 

665 if not source_group_normalized or source_group_normalized not in normalized_groups: 

666 continue 

667 

668 team_id, role = self._resolve_team_mapping_target(target) 

669 if not team_id: 

670 logger.warning( 

671 "Skipping invalid SSO team_mapping entry for provider %s and group '%s'", 

672 provider.id, 

673 source_group, 

674 ) 

675 continue 

676 

677 try: 

678 await team_service.add_member_to_team( 

679 team_id=team_id, 

680 user_email=user_email, 

681 role=role, 

682 invited_by=user_email, 

683 grant_source="sso", 

684 ) 

685 logger.info( 

686 "Granted SSO team membership for %s to team %s via group '%s'", 

687 SecurityValidator.sanitize_log_message(user_email), 

688 team_id, 

689 source_group, 

690 ) 

691 except MemberAlreadyExistsError: 

692 logger.debug( 

693 "SSO team_mapping: user %s already member of team %s", 

694 SecurityValidator.sanitize_log_message(user_email), 

695 team_id, 

696 ) 

697 except TeamManagementError as exc: 

698 logger.warning( 

699 "SSO team_mapping failed for user %s, group '%s', team '%s': %s", 

700 SecurityValidator.sanitize_log_message(user_email), 

701 source_group, 

702 team_id, 

703 exc, 

704 ) 

705 except Exception as exc: 

706 logger.error( 

707 "Unexpected error in SSO team_mapping for user %s, group '%s': %s", 

708 SecurityValidator.sanitize_log_message(user_email), 

709 source_group, 

710 exc, 

711 exc_info=True, 

712 ) 

713 

714 async def create_provider(self, provider_data: Dict[str, Any]) -> SSOProvider: 

715 """Create new SSO provider configuration. 

716 

717 Args: 

718 provider_data: Provider configuration data 

719 

720 Returns: 

721 Created SSO provider 

722 

723 Examples: 

724 >>> import asyncio 

725 >>> from unittest.mock import MagicMock, AsyncMock 

726 >>> service = SSOService(MagicMock()) 

727 >>> service._encrypt_secret = AsyncMock(side_effect=lambda s: 'ENC(' + s + ')') 

728 >>> data = { 

729 ... 'id': 'github', 'name': 'github', 'display_name': 'GitHub', 'provider_type': 'oauth2', 

730 ... 'client_id': 'cid', 'client_secret': 'sec', 

731 ... 'authorization_url': 'https://example/auth', 'token_url': 'https://example/token', 

732 ... 'userinfo_url': 'https://example/user', 'scope': 'user:email' 

733 ... } 

734 >>> provider = asyncio.run(service.create_provider(data)) 

735 >>> hasattr(provider, 'id') and provider.id == 'github' 

736 True 

737 >>> provider.client_secret_encrypted.startswith('ENC(') 

738 True 

739 """ 

740 self._enforce_allowed_issuer(provider_data.get("issuer")) 

741 

742 # Encrypt client secret 

743 client_secret = provider_data.pop("client_secret") 

744 provider_data["client_secret_encrypted"] = await self._encrypt_secret(client_secret) 

745 

746 # Filter to valid SSOProvider columns to prevent TypeError on unknown keys 

747 valid_columns = {c.key for c in SSOProvider.__table__.columns} 

748 filtered_data = {k: v for k, v in provider_data.items() if k in valid_columns} 

749 skipped = set(provider_data) - set(filtered_data) 

750 if skipped: 

751 logger.warning("Ignored unknown SSOProvider fields during creation: %s", skipped) 

752 

753 provider = SSOProvider(**filtered_data) 

754 self.db.add(provider) 

755 self.db.commit() 

756 self.db.refresh(provider) 

757 return provider 

758 

759 async def update_provider(self, provider_id: str, provider_data: Dict[str, Any]) -> Optional[SSOProvider]: 

760 """Update existing SSO provider configuration. 

761 

762 Args: 

763 provider_id: Provider identifier 

764 provider_data: Updated provider data 

765 

766 Returns: 

767 Updated SSO provider or None if not found 

768 

769 Examples: 

770 >>> import asyncio 

771 >>> from types import SimpleNamespace 

772 >>> from unittest.mock import MagicMock, AsyncMock 

773 >>> svc = SSOService(MagicMock()) 

774 >>> # Existing provider object 

775 >>> existing = SimpleNamespace(id='github', name='github', client_id='old', client_secret_encrypted='X', is_enabled=True) 

776 >>> svc.get_provider = lambda _id: existing 

777 >>> svc._encrypt_secret = AsyncMock(side_effect=lambda s: 'ENC-' + s) 

778 >>> svc.db.commit = lambda: None 

779 >>> svc.db.refresh = lambda obj: None 

780 >>> updated = asyncio.run(svc.update_provider('github', {'client_id': 'new', 'client_secret': 'sec'})) 

781 >>> updated.client_id 

782 'new' 

783 >>> updated.client_secret_encrypted 

784 'ENC-sec' 

785 """ 

786 provider = self.get_provider(provider_id) 

787 if not provider: 

788 return None 

789 

790 if "issuer" in provider_data: 

791 self._enforce_allowed_issuer(provider_data.get("issuer")) 

792 

793 # Handle client secret encryption if provided 

794 if "client_secret" in provider_data: 

795 client_secret = provider_data.pop("client_secret") 

796 provider_data["client_secret_encrypted"] = await self._encrypt_secret(client_secret) 

797 

798 for key, value in provider_data.items(): 

799 if hasattr(provider, key): 

800 setattr(provider, key, value) 

801 

802 provider.updated_at = utc_now() 

803 self.db.commit() 

804 self.db.refresh(provider) 

805 return provider 

806 

807 def delete_provider(self, provider_id: str) -> bool: 

808 """Delete SSO provider configuration. 

809 

810 Args: 

811 provider_id: Provider identifier 

812 

813 Returns: 

814 True if deleted, False if not found 

815 

816 Examples: 

817 >>> from types import SimpleNamespace 

818 >>> from unittest.mock import MagicMock 

819 >>> svc = SSOService(MagicMock()) 

820 >>> svc.db.delete = lambda obj: None 

821 >>> svc.db.commit = lambda: None 

822 >>> svc.get_provider = lambda _id: SimpleNamespace(id='github') 

823 >>> svc.delete_provider('github') 

824 True 

825 >>> svc.get_provider = lambda _id: None 

826 >>> svc.delete_provider('missing') 

827 False 

828 """ 

829 provider = self.get_provider(provider_id) 

830 if not provider: 

831 return False 

832 

833 self.db.delete(provider) 

834 self.db.commit() 

835 return True 

836 

837 def generate_pkce_challenge(self) -> Tuple[str, str]: 

838 """Generate PKCE code verifier and challenge for OAuth 2.1. 

839 

840 Returns: 

841 Tuple of (code_verifier, code_challenge) 

842 

843 Examples: 

844 Generate verifier and challenge: 

845 >>> from unittest.mock import Mock 

846 >>> service = SSOService(Mock()) 

847 >>> verifier, challenge = service.generate_pkce_challenge() 

848 >>> isinstance(verifier, str) and isinstance(challenge, str) 

849 True 

850 >>> len(verifier) >= 43 

851 True 

852 >>> len(challenge) >= 43 

853 True 

854 """ 

855 # Generate cryptographically random code verifier 

856 code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=") 

857 

858 # Generate code challenge using SHA256 

859 code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest()).decode("utf-8").rstrip("=") 

860 

861 return code_verifier, code_challenge 

862 

863 @staticmethod 

864 def _normalize_scope_values(values: Optional[List[str] | str]) -> List[str]: 

865 """Normalize scope values to a deduplicated, ordered list. 

866 

867 Args: 

868 values: Scope input as list or space-delimited string. 

869 

870 Returns: 

871 Ordered, unique scope values. 

872 """ 

873 if values is None: 

874 return [] 

875 

876 raw_values: List[str] 

877 if isinstance(values, str): 

878 raw_values = values.split() 

879 else: 

880 raw_values = [str(v) for v in values if isinstance(v, str) and v.strip()] 

881 

882 normalized: List[str] = [] 

883 seen: set[str] = set() 

884 for value in raw_values: 

885 scope = value.strip() 

886 if not scope or scope in seen: 

887 continue 

888 normalized.append(scope) 

889 seen.add(scope) 

890 return normalized 

891 

892 def _resolve_login_scopes(self, provider: SSOProvider, requested_scopes: Optional[List[str]]) -> List[str]: 

893 """Resolve requested SSO scopes against provider allowlists. 

894 

895 Args: 

896 provider: SSO provider configuration. 

897 requested_scopes: Optional scopes requested by the client. 

898 

899 Returns: 

900 Final scope list to send to the provider. 

901 

902 Raises: 

903 ValueError: If provider scope configuration is invalid or request includes disallowed scopes. 

904 """ 

905 configured_scopes = self._normalize_scope_values(getattr(provider, "scope", None)) 

906 if not configured_scopes: 

907 raise ValueError("Provider has no configured scopes") 

908 

909 allowed_scopes = configured_scopes 

910 provider_metadata = getattr(provider, "provider_metadata", {}) or {} 

911 metadata_allowed_raw = provider_metadata.get("allowed_scopes") 

912 metadata_allowed = self._normalize_scope_values(metadata_allowed_raw) if metadata_allowed_raw else [] 

913 if metadata_allowed: 

914 metadata_allowed_set = set(metadata_allowed) 

915 allowed_scopes = [scope for scope in configured_scopes if scope in metadata_allowed_set] 

916 

917 if not allowed_scopes: 

918 raise ValueError("No allowed scopes configured for provider") 

919 

920 if not requested_scopes: 

921 return allowed_scopes 

922 

923 normalized_requested = self._normalize_scope_values(requested_scopes) 

924 if not normalized_requested: 

925 return allowed_scopes 

926 

927 allowed_set = set(allowed_scopes) 

928 invalid_scopes = [scope for scope in normalized_requested if scope not in allowed_set] 

929 if invalid_scopes: 

930 invalid_csv = ", ".join(invalid_scopes) 

931 raise ValueError(f"Invalid scopes requested: {invalid_csv}") 

932 

933 return normalized_requested 

934 

935 def _get_state_binding_secret(self) -> bytes: 

936 """Resolve secret bytes used for state/session binding HMAC. 

937 

938 Returns: 

939 Secret bytes used for HMAC signatures. 

940 """ 

941 secret_source = settings.auth_encryption_secret 

942 if hasattr(secret_source, "get_secret_value"): 

943 return secret_source.get_secret_value().encode("utf-8") 

944 return str(secret_source).encode("utf-8") 

945 

946 def _generate_session_bound_state(self, provider_id: str, session_binding: str) -> str: 

947 """Generate state value bound to browser session context. 

948 

949 Args: 

950 provider_id: SSO provider identifier. 

951 session_binding: Browser-session marker. 

952 

953 Returns: 

954 Signed state token including nonce and HMAC signature. 

955 """ 

956 state_nonce = secrets.token_urlsafe(24) 

957 message = f"{provider_id}:{session_binding}:{state_nonce}".encode("utf-8") 

958 signature = hmac.new(self._get_state_binding_secret(), message, hashlib.sha256).hexdigest() 

959 return f"{state_nonce}{self._STATE_BINDING_SEPARATOR}{signature}" 

960 

961 def _is_session_bound_state(self, state: str) -> bool: 

962 """Return whether state appears to carry a session-binding signature. 

963 

964 Args: 

965 state: State value from OAuth flow. 

966 

967 Returns: 

968 ``True`` when state has expected nonce/signature format. 

969 """ 

970 if self._STATE_BINDING_SEPARATOR not in state: 

971 return False 

972 nonce, signature = state.rsplit(self._STATE_BINDING_SEPARATOR, 1) 

973 return bool(nonce) and len(signature) == self._STATE_BINDING_HEX_LEN 

974 

975 def _verify_session_bound_state(self, provider_id: str, state: str, session_binding: str) -> bool: 

976 """Verify state HMAC binding against the current browser session marker. 

977 

978 Args: 

979 provider_id: SSO provider identifier. 

980 state: State value to validate. 

981 session_binding: Current browser-session marker. 

982 

983 Returns: 

984 ``True`` when state signature matches the expected session binding. 

985 """ 

986 if not session_binding or not self._is_session_bound_state(state): 

987 return False 

988 nonce, signature = state.rsplit(self._STATE_BINDING_SEPARATOR, 1) 

989 message = f"{provider_id}:{session_binding}:{nonce}".encode("utf-8") 

990 expected_signature = hmac.new(self._get_state_binding_secret(), message, hashlib.sha256).hexdigest() 

991 return hmac.compare_digest(signature, expected_signature) 

992 

993 @staticmethod 

994 def _is_email_verified_claim(user_info: Dict[str, Any]) -> bool: 

995 """Evaluate email verification claim when provided by the IdP. 

996 

997 When the ``email_verified`` claim is **absent** from ``user_info`` (e.g. 

998 Microsoft Entra ID and GitHub do not include it for work / school 

999 accounts) the function returns ``True`` so that those providers are not 

1000 incorrectly blocked. The check is only enforced when the IdP 

1001 *explicitly* supplies the claim — a ``False``/``0``/``"false"`` value 

1002 means the provider has flagged the address as unverified and the user 

1003 should be rejected. 

1004 

1005 Args: 

1006 user_info: Normalized user-info payload from provider. 

1007 

1008 Returns: 

1009 ``True`` when the claim is absent (provider does not restrict) or 

1010 when it is explicitly set to a truthy value; ``False`` only when the 

1011 provider explicitly indicates the address is *not* verified. 

1012 """ 

1013 if "email_verified" not in user_info: 

1014 # Claim not provided by IdP — treat as no restriction (pass through). 

1015 return True 

1016 

1017 claim_value = user_info.get("email_verified") 

1018 if isinstance(claim_value, bool): 

1019 return claim_value 

1020 if isinstance(claim_value, int): 

1021 return claim_value == 1 

1022 if isinstance(claim_value, str): 

1023 return claim_value.strip().lower() in {"1", "true", "yes", "on"} 

1024 return False 

1025 

1026 def get_authorization_url( 

1027 self, 

1028 provider_id: str, 

1029 redirect_uri: str, 

1030 scopes: Optional[List[str]] = None, 

1031 session_binding: Optional[str] = None, 

1032 ) -> Optional[str]: 

1033 """Generate OAuth authorization URL for provider. 

1034 

1035 Args: 

1036 provider_id: Provider identifier 

1037 redirect_uri: Callback URI after authorization 

1038 scopes: Optional custom scopes (uses provider default if None) 

1039 session_binding: Optional browser-session marker for state binding. 

1040 

1041 Returns: 

1042 Authorization URL or None if provider not found 

1043 

1044 Examples: 

1045 >>> from types import SimpleNamespace 

1046 >>> from unittest.mock import MagicMock 

1047 >>> service = SSOService(MagicMock()) 

1048 >>> provider = SimpleNamespace(id='github', is_enabled=True, provider_type='oauth2', client_id='cid', authorization_url='https://example/auth', scope='user:email') 

1049 >>> service.get_provider = lambda _pid: provider 

1050 >>> service.db.add = lambda x: None 

1051 >>> service.db.commit = lambda: None 

1052 >>> url = service.get_authorization_url('github', 'https://app/callback', ['user:email']) 

1053 >>> isinstance(url, str) and 'client_id=cid' in url and 'state=' in url 

1054 True 

1055 

1056 Missing provider returns None: 

1057 >>> service.get_provider = lambda _pid: None 

1058 >>> service.get_authorization_url('missing', 'https://app/callback') is None 

1059 True 

1060 """ 

1061 provider = self.get_provider(provider_id) 

1062 if not provider or not provider.is_enabled: 

1063 return None 

1064 

1065 # Generate PKCE parameters 

1066 code_verifier, code_challenge = self.generate_pkce_challenge() 

1067 

1068 resolved_scopes = self._resolve_login_scopes(provider, scopes) 

1069 

1070 # Generate CSRF state (session-bound when browser context is available). 

1071 state = self._generate_session_bound_state(provider_id, session_binding) if session_binding else secrets.token_urlsafe(32) 

1072 

1073 # Generate OIDC nonce if applicable 

1074 nonce = secrets.token_urlsafe(16) if provider.provider_type == "oidc" else None 

1075 

1076 # Create auth session 

1077 auth_session = SSOAuthSession(provider_id=provider_id, state=state, code_verifier=code_verifier, nonce=nonce, redirect_uri=redirect_uri) 

1078 self.db.add(auth_session) 

1079 self.db.commit() 

1080 

1081 # Build authorization URL 

1082 params = { 

1083 "client_id": provider.client_id, 

1084 "response_type": "code", 

1085 "redirect_uri": redirect_uri, 

1086 "state": state, 

1087 "scope": " ".join(resolved_scopes), 

1088 "code_challenge": code_challenge, 

1089 "code_challenge_method": "S256", 

1090 } 

1091 

1092 if nonce: 

1093 params["nonce"] = nonce 

1094 

1095 return f"{provider.authorization_url}?{urllib.parse.urlencode(params)}" 

1096 

1097 async def handle_oauth_callback(self, provider_id: str, code: str, state: str, session_binding: Optional[str] = None) -> Optional[Dict[str, Any]]: 

1098 """Handle OAuth callback and exchange code for tokens. 

1099 

1100 Args: 

1101 provider_id: Provider identifier 

1102 code: Authorization code from callback 

1103 state: CSRF state parameter 

1104 session_binding: Optional browser-session marker used to verify bound state. 

1105 

1106 Returns: 

1107 User info dict or None if authentication failed 

1108 

1109 Examples: 

1110 Happy-path with patched exchanges and user info: 

1111 >>> import asyncio 

1112 >>> from types import SimpleNamespace 

1113 >>> from unittest.mock import MagicMock 

1114 >>> svc = SSOService(MagicMock()) 

1115 >>> # Mock DB auth session lookup 

1116 >>> provider = SimpleNamespace(id='github', is_enabled=True, provider_type='oauth2') 

1117 >>> auth_session = SimpleNamespace(provider_id='github', state='st', provider=provider, is_expired=False, nonce=None) 

1118 >>> svc.db.execute.return_value.scalar_one_or_none.return_value = auth_session 

1119 >>> # Patch token exchange and user info retrieval 

1120 >>> async def _ex(p, sess, c): 

1121 ... return {'access_token': 'tok', 'id_token': 'id_tok'} 

1122 >>> async def _ui(p, access, token_data=None, expected_nonce=None): 

1123 ... return {'email': 'user@example.com'} 

1124 >>> svc._exchange_code_for_tokens = _ex 

1125 >>> svc._get_user_info = _ui 

1126 >>> svc.db.delete = lambda obj: None 

1127 >>> svc.db.commit = lambda: None 

1128 >>> out = asyncio.run(svc.handle_oauth_callback('github', 'code', 'st')) 

1129 >>> out['email'] 

1130 'user@example.com' 

1131 

1132 Early return cases: 

1133 >>> # No session 

1134 >>> svc2 = SSOService(MagicMock()) 

1135 >>> svc2.db.execute.return_value.scalar_one_or_none.return_value = None 

1136 >>> asyncio.run(svc2.handle_oauth_callback('github', 'c', 's')) is None 

1137 True 

1138 >>> # Expired session 

1139 >>> expired = SimpleNamespace(provider_id='github', state='st', provider=SimpleNamespace(is_enabled=True), is_expired=True) 

1140 >>> svc3 = SSOService(MagicMock()) 

1141 >>> svc3.db.execute.return_value.scalar_one_or_none.return_value = expired 

1142 >>> asyncio.run(svc3.handle_oauth_callback('github', 'c', 'st')) is None 

1143 True 

1144 >>> # Disabled provider 

1145 >>> disabled = SimpleNamespace(provider_id='github', state='st', provider=SimpleNamespace(is_enabled=False), is_expired=False) 

1146 >>> svc4 = SSOService(MagicMock()) 

1147 >>> svc4.db.execute.return_value.scalar_one_or_none.return_value = disabled 

1148 >>> asyncio.run(svc4.handle_oauth_callback('github', 'c', 'st')) is None 

1149 True 

1150 """ 

1151 callback_result = await self.handle_oauth_callback_with_tokens(provider_id, code, state, session_binding=session_binding) 

1152 if not callback_result: 

1153 return None 

1154 user_info, _token_data = callback_result 

1155 return user_info 

1156 

1157 async def handle_oauth_callback_with_tokens( 

1158 self, 

1159 provider_id: str, 

1160 code: str, 

1161 state: str, 

1162 session_binding: Optional[str] = None, 

1163 ) -> Optional[Tuple[Dict[str, Any], Dict[str, Any]]]: 

1164 """Handle OAuth callback and return both user info and raw token response. 

1165 

1166 Args: 

1167 provider_id: Provider identifier 

1168 code: Authorization code from callback 

1169 state: CSRF state parameter 

1170 session_binding: Optional browser-session marker used to verify bound state. 

1171 

1172 Returns: 

1173 Tuple of (user_info, token_data) or None if authentication fails 

1174 """ 

1175 # Validate auth session 

1176 stmt = select(SSOAuthSession).where(SSOAuthSession.state == state, SSOAuthSession.provider_id == provider_id) 

1177 auth_session = self.db.execute(stmt).scalar_one_or_none() 

1178 

1179 if not auth_session: 

1180 logger.warning(f"OAuth callback: no auth session found for state/provider {provider_id}. Possible CSRF or replay.") 

1181 return None 

1182 

1183 if auth_session.is_expired: 

1184 logger.warning(f"OAuth callback: auth session expired for provider {provider_id}.") 

1185 self.db.delete(auth_session) 

1186 self.db.commit() 

1187 return None 

1188 

1189 if self._is_session_bound_state(state): 

1190 if not session_binding or not self._verify_session_bound_state(provider_id, state, session_binding): 

1191 logger.warning( 

1192 "OAuth callback: state/session mismatch for provider %s. Possible CSRF or cross-session replay.", 

1193 provider_id, 

1194 ) 

1195 return None 

1196 

1197 provider = auth_session.provider 

1198 if not provider: 

1199 logger.error(f"OAuth callback: provider '{provider_id}' not found for auth session.") 

1200 return None 

1201 

1202 if not provider.is_enabled: 

1203 logger.warning(f"OAuth callback: provider '{provider_id}' is disabled.") 

1204 return None 

1205 

1206 try: 

1207 # Exchange authorization code for tokens 

1208 logger.info(f"Starting token exchange for provider {provider_id}") 

1209 token_data = await self._exchange_code_for_tokens(provider, auth_session, code) 

1210 if not token_data: 

1211 logger.error(f"Failed to exchange code for tokens for provider {provider_id}") 

1212 return None 

1213 logger.info(f"Token exchange successful for provider {provider_id}") 

1214 callback_nonce = getattr(auth_session, "nonce", None) 

1215 

1216 # For OIDC providers, verify id_token before any claim extraction. 

1217 if provider.provider_type == "oidc": 

1218 if callback_nonce is None: 

1219 logger.error("OAuth callback: missing nonce for OIDC provider %s.", provider_id) 

1220 return None 

1221 id_token = token_data.get("id_token") 

1222 if not isinstance(id_token, str): 

1223 logger.error("OAuth callback: missing id_token for OIDC provider %s.", provider_id) 

1224 return None 

1225 verified_claims = await self._verify_oidc_id_token(provider, id_token, expected_nonce=callback_nonce) 

1226 if verified_claims is None: 

1227 # ADFS: on-prem deployments often lack a discoverable JWKS endpoint, 

1228 # so verification may fail. Fall through to _get_user_info which 

1229 # decodes the id_token received over the direct TLS channel. 

1230 if provider.id == ADFS_PROVIDER_ID: 

1231 logger.warning("OIDC id_token verification failed for ADFS provider %s; falling back to TLS-trust decode.", provider_id) 

1232 # Mark that verification was attempted so _get_user_info 

1233 # does not redundantly re-attempt it. 

1234 token_data = dict(token_data) 

1235 token_data["_adfs_verification_attempted"] = True 

1236 else: 

1237 logger.error("id_token verification failed for provider %s", provider_id) 

1238 return None 

1239 else: 

1240 token_data = dict(token_data) 

1241 token_data["_verified_id_token_claims"] = verified_claims 

1242 

1243 # Get user info from provider (pass full token_data for id_token parsing) 

1244 user_info = await self._get_user_info(provider, token_data["access_token"], token_data, expected_nonce=callback_nonce) 

1245 if not user_info: 

1246 logger.error(f"Failed to get user info for provider {provider_id}") 

1247 return None 

1248 

1249 # Clean up auth session 

1250 self.db.delete(auth_session) 

1251 self.db.commit() 

1252 

1253 return user_info, token_data 

1254 

1255 except Exception as e: 

1256 # Clean up auth session on error 

1257 logger.error(f"OAuth callback failed for provider {provider_id}: {type(e).__name__}: {str(e)}") 

1258 logger.exception("Full traceback for OAuth callback failure:") 

1259 self.db.delete(auth_session) 

1260 self.db.commit() 

1261 return None 

1262 

1263 async def _exchange_code_for_tokens(self, provider: SSOProvider, auth_session: SSOAuthSession, code: str) -> Optional[Dict[str, Any]]: 

1264 """Exchange authorization code for access tokens. 

1265 

1266 Args: 

1267 provider: SSO provider configuration 

1268 auth_session: Auth session with PKCE parameters 

1269 code: Authorization code 

1270 

1271 Returns: 

1272 Token response dict or None if failed 

1273 """ 

1274 token_params = { 

1275 "client_id": provider.client_id, 

1276 "client_secret": await self._decrypt_secret(provider.client_secret_encrypted), 

1277 "code": code, 

1278 "grant_type": "authorization_code", 

1279 "redirect_uri": auth_session.redirect_uri, 

1280 "code_verifier": auth_session.code_verifier, 

1281 } 

1282 

1283 # First-Party 

1284 from mcpgateway.services.http_client_service import get_http_client # pylint: disable=import-outside-toplevel 

1285 

1286 client = await get_http_client() 

1287 response = await client.post(provider.token_url, data=token_params, headers={"Accept": "application/json"}) 

1288 

1289 if response.status_code == 200: 

1290 return response.json() 

1291 logger.error(f"Token exchange failed for {provider.name}: HTTP {response.status_code} - {response.text}") 

1292 

1293 return None 

1294 

1295 async def _enrich_user_data_from_claims( 

1296 self, 

1297 provider: SSOProvider, 

1298 user_data: Dict[str, Any], 

1299 access_token: str, 

1300 verified_id_token_claims: Optional[Dict[str, Any]], 

1301 ) -> None: 

1302 """Enrich userinfo response with provider-specific claims. 

1303 

1304 Mutates ``user_data`` in place by merging groups, roles, and other 

1305 claims from the id_token or external APIs (GitHub orgs, Entra Graph) 

1306 that the userinfo endpoint does not include on its own. 

1307 

1308 Args: 

1309 provider: SSO provider configuration. 

1310 user_data: Mutable userinfo dict from the provider endpoint. 

1311 access_token: OAuth access token for follow-up API calls. 

1312 verified_id_token_claims: Verified id_token claims, if available. 

1313 """ 

1314 # GitHub: fetch organizations for admin assignment 

1315 if provider.id == "github" and settings.sso_github_admin_orgs: 

1316 # First-Party 

1317 from mcpgateway.services.http_client_service import get_http_client # pylint: disable=import-outside-toplevel 

1318 

1319 client = await get_http_client() 

1320 try: 

1321 orgs_response = await client.get("https://api.github.com/user/orgs", headers={"Authorization": f"Bearer {access_token}"}) 

1322 if orgs_response.status_code == 200: 

1323 user_data["organizations"] = [org["login"] for org in orgs_response.json()] 

1324 else: 

1325 logger.warning(f"Failed to fetch GitHub organizations: HTTP {orgs_response.status_code}") 

1326 user_data["organizations"] = [] 

1327 except Exception as e: 

1328 logger.warning(f"Error fetching GitHub organizations: {e}") 

1329 user_data["organizations"] = [] 

1330 return 

1331 

1332 # Entra ID: extract groups/roles from id_token since userinfo doesn't include them. 

1333 # Microsoft's /oidc/userinfo endpoint only returns basic claims (sub, name, email, picture). 

1334 # Groups and roles are included in the id_token when configured in Azure Portal. 

1335 if provider.id == "entra" and verified_id_token_claims: 

1336 entra_groups_from_graph: Optional[List[str]] = None 

1337 # Detect group overage — when user has too many groups (>200), EntraID can return 

1338 # overage markers (e.g. _claim_names/_claim_sources, hasgroups, groups:srcN) 

1339 # instead of an inline groups array. 

1340 # See: https://learn.microsoft.com/en-us/entra/identity-platform/id-token-claims-reference 

1341 claim_names = verified_id_token_claims.get("_claim_names", {}) 

1342 has_groups_src_key = any(isinstance(key, str) and key.startswith("groups:src") for key in verified_id_token_claims) 

1343 groups_claim_value = verified_id_token_claims.get("groups") 

1344 has_group_overage = ( 

1345 (isinstance(claim_names, dict) and "groups" in claim_names) or bool(verified_id_token_claims.get("hasgroups")) or has_groups_src_key or isinstance(groups_claim_value, str) 

1346 ) 

1347 if has_group_overage: 

1348 user_email = user_data.get("email") or user_data.get("preferred_username") or "unknown" 

1349 logger.warning( 

1350 "Group overage detected for user %s - token contains too many groups (>200). Attempting Microsoft Graph fallback to resolve complete group membership.", 

1351 user_email, 

1352 ) 

1353 entra_groups_from_graph = await self._fetch_entra_groups_from_graph_api(access_token, user_email, provider.provider_metadata) 

1354 if entra_groups_from_graph is None: 

1355 logger.warning("Proceeding without Graph-resolved Entra groups for user %s", user_email) 

1356 

1357 # Extract groups from id_token (Security Groups as Object IDs) 

1358 if entra_groups_from_graph is not None: 

1359 user_data["groups"] = entra_groups_from_graph 

1360 elif "groups" in verified_id_token_claims: 

1361 user_data["groups"] = verified_id_token_claims["groups"] 

1362 logger.debug(f"Extracted {len(verified_id_token_claims['groups'])} groups from Entra ID token") 

1363 

1364 # Extract roles from id_token (App Roles) 

1365 if "roles" in verified_id_token_claims: 

1366 user_data["roles"] = verified_id_token_claims["roles"] 

1367 logger.debug(f"Extracted {len(verified_id_token_claims['roles'])} roles from Entra ID token") 

1368 

1369 # Also extract any missing basic claims from id_token 

1370 for claim in ["email", "name", "preferred_username", "oid", "sub"]: 

1371 if claim not in user_data and claim in verified_id_token_claims: 

1372 user_data[claim] = verified_id_token_claims[claim] 

1373 return 

1374 

1375 # Keycloak: merge realm_access, resource_access, and groups from id_token 

1376 if provider.id == "keycloak" and verified_id_token_claims: 

1377 for claim in ["realm_access", "resource_access", "groups"]: 

1378 if claim in verified_id_token_claims and claim not in user_data: 

1379 user_data[claim] = verified_id_token_claims[claim] 

1380 return 

1381 

1382 # Generic OIDC (including Okta, IBM Verify, and any custom provider): 

1383 # merge groups and roles claims from the verified id_token when the 

1384 # userinfo response does not already contain them. 

1385 if provider.id not in ("github", "google") and verified_id_token_claims: 

1386 metadata = provider.provider_metadata or {} 

1387 groups_claim = metadata.get("groups_claim", "groups") 

1388 for claim in {groups_claim, "roles"}: 

1389 if claim in verified_id_token_claims and claim not in user_data: 

1390 user_data[claim] = verified_id_token_claims[claim] 

1391 

1392 async def _get_user_info(self, provider: SSOProvider, access_token: str, token_data: Optional[Dict[str, Any]] = None, expected_nonce: Optional[str] = None) -> Optional[Dict[str, Any]]: 

1393 """Get user information from provider using access token. 

1394 

1395 Args: 

1396 provider: SSO provider configuration 

1397 access_token: OAuth access token 

1398 token_data: Optional full token response containing id_token for OIDC providers 

1399 expected_nonce: Nonce bound to the current auth session for OIDC id_token verification 

1400 

1401 Returns: 

1402 User info dict or None if failed 

1403 

1404 Raises: 

1405 SSOProviderConfigError: If the ADFS provider is missing the required id_token. 

1406 """ 

1407 # First-Party 

1408 from mcpgateway.services.http_client_service import get_http_client # pylint: disable=import-outside-toplevel 

1409 

1410 client = await get_http_client() 

1411 verified_id_token_claims: Optional[Dict[str, Any]] = None 

1412 if token_data and isinstance(token_data.get("_verified_id_token_claims"), dict): 

1413 verified_id_token_claims = token_data.get("_verified_id_token_claims") 

1414 elif token_data and token_data.get("_adfs_verification_attempted"): 

1415 # ADFS verification was already attempted upstream in handle_oauth_callback_with_tokens 

1416 # and failed (on-prem ADFS without JWKS). Skip redundant re-attempt. 

1417 pass 

1418 elif provider.provider_type == "oidc" and token_data and isinstance(token_data.get("id_token"), str): 

1419 if expected_nonce is None: 

1420 logger.warning("Skipping OIDC id_token fallback verification for provider %s because expected nonce is unavailable.", provider.id) 

1421 else: 

1422 verified_id_token_claims = await self._verify_oidc_id_token(provider, token_data["id_token"], expected_nonce=expected_nonce) 

1423 

1424 # ADFS does not support GET on the userinfo endpoint. 

1425 # Extract user info directly from the ID token instead. 

1426 # Prefer verified claims when available (ADFS with discoverable JWKS); 

1427 # fall back to unverified decode for on-prem ADFS without JWKS. 

1428 if provider.id == ADFS_PROVIDER_ID: 

1429 # Use verified claims if OIDC verification succeeded upstream 

1430 if verified_id_token_claims: 

1431 logger.debug("ADFS: using verified id_token claims (keys: %s)", list(verified_id_token_claims.keys())) 

1432 return self._normalize_user_info(provider, verified_id_token_claims) 

1433 

1434 # Fall back to unverified decode (on-prem ADFS without JWKS). 

1435 # The id_token was received server-to-server over TLS from the token 

1436 # endpoint, so we trust the transport. We still validate aud, iss, 

1437 # exp, and nonce as defense-in-depth against token confusion. 

1438 if token_data and isinstance(token_data.get("id_token"), str): 

1439 id_token_claims = self._decode_jwt_claims(token_data["id_token"]) 

1440 if not id_token_claims: 

1441 logger.error("Failed to decode ADFS ID token claims") 

1442 return None 

1443 

1444 # Validate audience — must match our client_id 

1445 token_aud = id_token_claims.get("aud") 

1446 if isinstance(token_aud, list): 

1447 aud_match = provider.client_id in token_aud 

1448 else: 

1449 aud_match = token_aud == provider.client_id 

1450 if not aud_match: 

1451 logger.error("ADFS id_token audience mismatch: expected %s, got %s", provider.client_id, token_aud) 

1452 return None 

1453 

1454 # Validate issuer — must match configured issuer 

1455 if provider.issuer and id_token_claims.get("iss") != provider.issuer: 

1456 logger.error("ADFS id_token issuer mismatch: expected %s, got %s", provider.issuer, id_token_claims.get("iss")) 

1457 return None 

1458 

1459 # Validate expiration — reject missing or non-numeric exp 

1460 # Standard 

1461 import time # pylint: disable=import-outside-toplevel 

1462 

1463 exp = id_token_claims.get("exp") 

1464 if not isinstance(exp, (int, float)): 

1465 logger.error("ADFS id_token missing or malformed exp claim: %r", exp) 

1466 return None 

1467 if exp < time.time(): 

1468 logger.error("ADFS id_token has expired (exp=%s)", exp) 

1469 return None 

1470 

1471 # Validate nonce — prevents replay attacks 

1472 if expected_nonce and id_token_claims.get("nonce") != expected_nonce: 

1473 logger.error("ADFS id_token nonce mismatch") 

1474 return None 

1475 

1476 logger.debug("ADFS: using decoded id_token claims with validated aud/iss/exp/nonce (keys: %s)", list(id_token_claims.keys())) 

1477 return self._normalize_user_info(provider, id_token_claims) 

1478 

1479 logger.error("ADFS provider requires id_token but none was provided in token_data") 

1480 raise SSOProviderConfigError("ADFS provider requires id_token in token response") 

1481 

1482 response = await client.get(provider.userinfo_url, headers={"Authorization": f"Bearer {access_token}"}) 

1483 

1484 if response.status_code == 200: 

1485 user_data = response.json() 

1486 await self._enrich_user_data_from_claims(provider, user_data, access_token, verified_id_token_claims) 

1487 return self._normalize_user_info(provider, user_data) 

1488 

1489 # Keycloak can issue tokens using the browser-facing issuer URL; if userinfo 

1490 # is called on a different host/port, Keycloak may reject the token with 401. 

1491 # Only fall back to id_token claims for 401 with split-host configuration. 

1492 # Other errors (403=revoked, 500=server error) must NOT fall back — the user 

1493 # should be denied access, not silently authenticated via stale id_token claims. 

1494 if provider.id == "keycloak" and verified_id_token_claims and response.status_code == 401: 

1495 metadata = provider.provider_metadata or {} 

1496 public_base_url = metadata.get("public_base_url") 

1497 if public_base_url and public_base_url != metadata.get("base_url"): 

1498 logger.warning( 

1499 "User info request returned 401 for keycloak with split-host config (public=%s, internal=%s). Falling back to id_token claims.", 

1500 public_base_url, 

1501 metadata.get("base_url"), 

1502 ) 

1503 return self._normalize_user_info(provider, verified_id_token_claims) 

1504 

1505 logger.error(f"User info request failed for {provider.name}: HTTP {response.status_code} - {response.text}") 

1506 

1507 return None 

1508 

1509 def _get_default_email_domain(self, provider: SSOProvider) -> Optional[str]: 

1510 """Get default email domain from provider metadata or global settings. 

1511 

1512 Args: 

1513 provider: SSO provider instance 

1514 

1515 Returns: 

1516 Default email domain string, or None if not configured 

1517 """ 

1518 metadata = provider.provider_metadata or {} 

1519 default_domain = metadata.get("default_email_domain") 

1520 if not default_domain: 

1521 default_domain = settings.sso_adfs_default_email_domain 

1522 return default_domain 

1523 

1524 def _normalize_adfs_email(self, raw_email: str, default_domain: Optional[str]) -> Optional[str]: 

1525 """Normalize ADFS email/UPN to standard email format. 

1526 

1527 ADFS may return UPN in various formats: 

1528 - user@domain.com (already valid email) 

1529 - DOMAIN\\username (Windows domain format) 

1530 - username (plain username without domain) 

1531 

1532 Args: 

1533 raw_email: Raw email/UPN string from ADFS claims 

1534 default_domain: Default email domain to append if needed 

1535 

1536 Returns: 

1537 Normalized email address, or None if normalization fails 

1538 """ 

1539 if not raw_email: 

1540 return None 

1541 

1542 raw_email_str = str(raw_email).strip() 

1543 

1544 # Already a valid email format 

1545 if "@" in raw_email_str and "." in raw_email_str.split("@")[-1]: 

1546 return raw_email_str 

1547 

1548 # Handle DOMAIN\username format 

1549 if "\\" in raw_email_str: 

1550 username_part = raw_email_str.split("\\")[-1] 

1551 if default_domain: 

1552 return f"{username_part}@{default_domain}" 

1553 logger.warning("ADFS UPN in DOMAIN\\username format but no default_email_domain configured") 

1554 return None 

1555 

1556 # Handle plain username without domain 

1557 if "@" not in raw_email_str: 

1558 if default_domain: 

1559 return f"{raw_email_str}@{default_domain}" 

1560 logger.warning("ADFS UPN is plain username but no default_email_domain configured") 

1561 return None 

1562 

1563 return raw_email_str 

1564 

@staticmethod
def _extract_groups_and_roles(user_data: Dict[str, Any], groups_claim: str = "groups") -> list[str]:
    """Collect string-valued group and role claims into one flat list.

    The groups claim (configurable, ``"groups"`` by default) is read first,
    then the ``"roles"`` claim. Each claim may be a single string or a list;
    non-string list items and any other claim types are ignored. Ordering is
    preserved and duplicates are kept.

    Args:
        user_data: Raw user data from provider.
        groups_claim: Claim key for groups (default: ``"groups"``).

    Returns:
        Combined list of group and role strings.
    """
    collected: list[str] = []
    for claim_key in (groups_claim, "roles"):
        claim_value = user_data.get(claim_key)
        if isinstance(claim_value, str):
            collected.append(claim_value)
        elif isinstance(claim_value, list):
            collected.extend(item for item in claim_value if isinstance(item, str))
    return collected

1593 

@staticmethod
def _build_normalized_user_info(
    user_data: Dict[str, Any],
    provider_name: str,
    groups: list[str],
    *,
    email: Union[Optional[str], _Unset] = _UNSET,
    full_name: Union[Optional[str], _Unset] = _UNSET,
    avatar_url: Union[Optional[str], _Unset] = _UNSET,
    provider_id: Union[Optional[Any], _Unset] = _UNSET,
    username: Union[Optional[str], _Unset] = _UNSET,
    extra: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build a normalized user-info dict with common fields.

    Provider-specific branches call this with overrides only for fields
    that deviate from the standard OIDC claim mapping. The
    ``email_verified`` claim is propagated only when the IdP explicitly
    includes it so that ``_is_email_verified_claim``'s
    absent-means-pass-through logic applies correctly.

    Pass ``None`` explicitly to force a field to ``None``; omit the
    argument (default ``_UNSET``) to fall back to the standard claim.

    Args:
        user_data: Raw user data from the provider.
        provider_name: Provider identifier for the ``"provider"`` field.
        groups: Pre-extracted groups list (deduplicated, first occurrence wins,
            original order preserved).
        email: Override for ``user_data["email"]``.
        full_name: Override for ``user_data["name"]``.
        avatar_url: Override for ``user_data["picture"]``.
        provider_id: Override for ``user_data["sub"]``.
        username: Override for the computed username. When omitted, the
            ``preferred_username`` claim is used, else the local part of
            ``user_data["email"]`` (empty string when the claim is absent,
            null or not a string).
        extra: Additional keys merged into the result.

    Returns:
        Normalized user-info dict.
    """
    if username is _UNSET:
        # Some IdPs send "email": null; calling .split() on it would raise AttributeError.
        raw_email = user_data.get("email")
        email_local_part = raw_email.split("@")[0] if isinstance(raw_email, str) else ""
        username = user_data.get("preferred_username") or email_local_part

    normalized: Dict[str, Any] = {
        "email": email if email is not _UNSET else user_data.get("email"),
        "full_name": full_name if full_name is not _UNSET else user_data.get("name"),
        "avatar_url": avatar_url if avatar_url is not _UNSET else user_data.get("picture"),
        "provider_id": provider_id if provider_id is not _UNSET else user_data.get("sub"),
        "username": username,
        "provider": provider_name,
        # dict.fromkeys dedupes while keeping a deterministic (first-seen) order,
        # unlike set(), whose iteration order varies between processes.
        "groups": list(dict.fromkeys(groups)),
    }
    if "email_verified" in user_data:
        normalized["email_verified"] = user_data["email_verified"]
    if extra:
        normalized.update(extra)
    return normalized

1646 

def _normalize_user_info(self, provider: SSOProvider, user_data: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize user info from different providers to common format.

    Provider-specific handling:
    - ``github``: name falls back to ``login``; ``email_verified`` is taken
      from ``email_verified`` or ``verified`` only when one is present.
    - ``google``: username is the local part of the email.
    - ``ibm_verify`` / ``okta``: groups and roles come from the configured
      ``groups_claim`` plus ``roles``.
    - ``keycloak``: realm/client roles are included when
      ``map_realm_roles`` / ``map_client_roles`` are set; email and username
      claims are configurable. Only string roles and groups are kept.
    - ``entra``: email falls back to ``preferred_username`` then ``upn``;
      provider id falls back to ``oid``.
    - ADFS: the UPN is normalized to an email, and ``email_verified`` is set
      to True.
    - anything else: generic OIDC claim mapping.

    Args:
        provider: SSO provider configuration
        user_data: Raw user data from provider

    Returns:
        Normalized user info dict
    """

    def _email_local_part(value: Any) -> str:
        """Return the text before the first '@' of a string value, or "" for non-strings (e.g. a null claim)."""
        return value.split("@")[0] if isinstance(value, str) else ""

    def _string_items(value: Any) -> list[str]:
        """Return the string elements of a list-valued claim; anything else yields []."""
        return [item for item in value if isinstance(item, str)] if isinstance(value, list) else []

    # Handle GitHub provider
    if provider.id == "github":
        normalized: Dict[str, Any] = {
            "email": user_data.get("email"),
            "full_name": user_data.get("name") or user_data.get("login"),
            "avatar_url": user_data.get("avatar_url"),
            "provider_id": user_data.get("id"),
            "username": user_data.get("login"),
            "provider": "github",
            "organizations": user_data.get("organizations", []),
        }
        # GitHub /user responses do not include an email_verified claim. Preserve
        # backward-compatible login behavior by only enforcing verification when
        # a concrete claim value is present.
        github_email_verified = user_data.get("email_verified")
        if github_email_verified is None:
            github_email_verified = user_data.get("verified")
        if github_email_verified is not None:
            normalized["email_verified"] = github_email_verified
        return normalized

    # Handle Google provider
    if provider.id == "google":
        return self._build_normalized_user_info(
            user_data,
            "google",
            [],
            username=_email_local_part(user_data.get("email")),
        )

    metadata = provider.provider_metadata or {}
    groups_claim = metadata.get("groups_claim", "groups")

    # Handle IBM Verify provider
    if provider.id == "ibm_verify":
        groups = self._extract_groups_and_roles(user_data, groups_claim)
        return self._build_normalized_user_info(user_data, "ibm_verify", groups)

    # Handle Okta provider
    if provider.id == "okta":
        groups = self._extract_groups_and_roles(user_data, groups_claim)
        return self._build_normalized_user_info(user_data, "okta", groups)

    # Handle Keycloak provider with role mapping
    if provider.id == "keycloak":
        username_claim = metadata.get("username_claim", "preferred_username")
        email_claim = metadata.get("email_claim", "email")

        groups: list[str] = []

        # Extract realm roles (tolerate a missing/null/non-object realm_access claim)
        if metadata.get("map_realm_roles"):
            realm_access = user_data.get("realm_access")
            if isinstance(realm_access, dict):
                groups.extend(_string_items(realm_access.get("roles")))

        # Extract client roles, prefixed with the client id
        if metadata.get("map_client_roles"):
            resource_access = user_data.get("resource_access")
            if isinstance(resource_access, dict):
                for client, access in resource_access.items():
                    if isinstance(access, dict):
                        groups.extend(f"{client}:{role}" for role in _string_items(access.get("roles")))

        # Extract groups from custom claim (list-valued only, strings only — matches
        # the str filtering applied by _extract_groups_and_roles for other providers)
        groups.extend(_string_items(user_data.get(groups_claim)))

        return self._build_normalized_user_info(
            user_data,
            "keycloak",
            groups,
            email=user_data.get(email_claim),
            username=user_data.get(username_claim) or _email_local_part(user_data.get(email_claim)),
        )

    # Handle Microsoft Entra ID provider with role mapping
    if provider.id == "entra":
        # Microsoft's userinfo endpoint often omits the email claim
        # Fallback: preferred_username (UPN) or upn claim
        email = user_data.get("email") or user_data.get("preferred_username") or user_data.get("upn")
        username = user_data.get("preferred_username") or (email.split("@")[0] if email else None)

        groups = self._extract_groups_and_roles(user_data, groups_claim)
        return self._build_normalized_user_info(
            user_data,
            "entra",
            groups,
            email=email,
            full_name=user_data.get("name") or email,
            provider_id=user_data.get("sub") or user_data.get("oid"),
            username=username,
        )

    # Handle ADFS provider
    if provider.id == ADFS_PROVIDER_ID:
        # ADFS uses UPN (User Principal Name) as the primary identifier.
        # Claim priority: email > preferred_username > upn > unique_name
        raw_email = user_data.get("email") or user_data.get("preferred_username") or user_data.get("upn") or user_data.get("unique_name")

        email = None
        if raw_email:
            default_domain = self._get_default_email_domain(provider)
            email = self._normalize_adfs_email(str(raw_email), default_domain)

        # Derive a username even when the UPN could not be turned into an email.
        username = None
        if email:
            username = email.split("@")[0]
        elif raw_email:
            raw_str = str(raw_email).strip()
            if "\\" in raw_str:
                username = raw_str.split("\\")[-1]
            elif "@" in raw_str:
                username = raw_str.split("@")[0]
            else:
                username = raw_str

        full_name = user_data.get("name")
        if not full_name and user_data.get("given_name") and user_data.get("family_name"):
            full_name = f"{user_data.get('given_name')} {user_data.get('family_name')}"

        adfs_groups = self._extract_groups_and_roles(user_data, groups_claim)

        return self._build_normalized_user_info(
            user_data,
            ADFS_PROVIDER_ID,
            adfs_groups,
            email=email,
            full_name=full_name or email or username,
            provider_id=user_data.get("sub") or user_data.get("oid") or email or username,
            username=username or email,
            extra={"email_verified": True},  # ADFS tokens are trusted after successful authentication
        )

    # Generic OIDC format for all other providers.
    groups = self._extract_groups_and_roles(user_data, groups_claim)
    return self._build_normalized_user_info(user_data, provider.id, groups)

1793 

def _reset_pending_approval(self, pending: PendingUserApproval, incoming_provider: str, user_info: Dict[str, Any]) -> None:
    """Reset a pending approval request to pending state with fresh metadata.

    Clears any previous decision (approver, approval time, rejection reason,
    admin notes), re-stamps the request with a 30-day expiry window measured
    from a single timestamp, and commits the session.

    Args:
        pending: Existing pending approval record.
        incoming_provider: SSO provider identifier.
        user_info: Normalized user info from SSO provider.
    """
    # Read the clock once so expires_at is exactly 30 days after requested_at.
    now = utc_now()
    pending.status = "pending"
    pending.requested_at = now
    pending.expires_at = now + timedelta(days=30)
    pending.auth_provider = incoming_provider
    pending.sso_metadata = user_info
    pending.approved_by = None
    pending.approved_at = None
    pending.rejection_reason = None
    pending.admin_notes = None
    self.db.commit()

1812 

@staticmethod
def _should_sync_roles(provider_id: Optional[str], provider_metadata: Dict[str, Any]) -> bool:
    """Determine whether RBAC role sync should run for a login.

    Checks the provider-level ``sync_roles`` flag in ``provider_metadata``
    first, then falls back to the legacy Entra-specific setting. Defaults
    to True when neither is configured.

    String flag values are interpreted textually: ``""``, ``"0"``,
    ``"false"``, ``"no"`` and ``"off"`` (case-insensitive, surrounding
    whitespace ignored) disable sync. Any other string enables it. Non-string
    values use normal truthiness.

    Args:
        provider_id: SSO provider identifier (e.g. ``"entra"``).
        provider_metadata: Provider metadata dict from DB/config.

    Returns:
        True if role sync should proceed, False otherwise.
    """
    if "sync_roles" in provider_metadata:
        flag = provider_metadata.get("sync_roles", True)
        if isinstance(flag, str):
            # bool("false") is True, so textual config values need explicit parsing.
            return flag.strip().lower() not in {"", "0", "false", "no", "off"}
        return bool(flag)
    if provider_id == "entra" and hasattr(settings, "sso_entra_sync_roles_on_login"):
        return bool(settings.sso_entra_sync_roles_on_login)
    return True

1832 

def _check_pending_approval(self, email: str, incoming_provider: str, user_info: Dict[str, Any]) -> bool:
    """Check admin approval state for a new SSO user.

    Returns ``True`` only when the user has an active, non-expired
    ``"approved"`` record and may proceed to account creation. All
    other states return ``False`` (caller should deny access).

    Side effects per state:
    - no record: a new ``"pending"`` request (30-day expiry) is created.
    - ``"pending"`` past its expiry, or ``"expired"``: the request is
      reset to ``"pending"`` with fresh metadata.
    - ``"approved"`` past its expiry: marked ``"expired"``.
    - ``"rejected"`` / ``"completed"``: untouched.
    - any other status: a warning is logged.

    Args:
        email: Normalized user email.
        incoming_provider: SSO provider identifier.
        user_info: Normalized user info from SSO provider.

    Returns:
        True when user is approved for creation, False otherwise.
    """
    pending = self.db.execute(select(PendingUserApproval).where(PendingUserApproval.email == email)).scalar_one_or_none()

    if not pending:
        pending = PendingUserApproval(
            email=email,
            # Fall back to the email when the claim is absent, null or empty.
            full_name=user_info.get("full_name") or email,
            auth_provider=incoming_provider,
            sso_metadata=user_info,
            expires_at=utc_now() + timedelta(days=30),
        )
        self.db.add(pending)
        self.db.commit()
        logger.info(f"Created pending approval request for SSO user: {SecurityValidator.sanitize_log_message(email)}")
        return False

    status = pending.status

    if status == "approved":
        if pending.is_expired():
            pending.status = "expired"
            self.db.commit()
            return False
        return True

    if status == "pending":
        if pending.is_expired():
            # _reset_pending_approval commits the refreshed request itself; an
            # interim "expired" commit would only add a redundant round-trip.
            self._reset_pending_approval(pending, incoming_provider, user_info)
            logger.info(f"Refreshed expired pending approval request for SSO user: {SecurityValidator.sanitize_log_message(email)}")
        return False

    if status == "expired":
        self._reset_pending_approval(pending, incoming_provider, user_info)
        logger.info(f"Renewed expired pending approval request for SSO user: {SecurityValidator.sanitize_log_message(email)}")
        return False

    if status in ("rejected", "completed"):
        return False

    logger.warning(f"Unknown SSO pending approval status '{pending.status}' for user {SecurityValidator.sanitize_log_message(email)}. Denying by default.")
    return False

1891 

async def authenticate_or_create_user(self, user_info: Dict[str, Any]) -> Optional[str]:
    """Authenticate existing user or create new user from SSO info.

    Flow:
    1. Require a non-empty email claim and a verified-email claim.
    2. Enforce the provider's ``trusted_domains`` policy.
    3. Existing user: refuse cross-provider logins, then refresh the name,
       auth provider, verification flag, last login and SSO-origin admin
       status, and run role/team mapping.
    4. New user: create only when the provider has ``auto_create_users``,
       optionally gated by the admin-approval workflow.
    5. Mint a session JWT from locally captured values.

    Args:
        user_info: Normalized user info from SSO provider

    Returns:
        JWT token for authenticated user or None if failed
    """
    raw_email = user_info.get("email")
    if not raw_email:
        logger.warning("SSO authenticate_or_create_user: no email in user_info from provider '%s'. User cannot be authenticated without an email address.", user_info.get("provider", "unknown"))
        return None

    email = str(raw_email).strip().lower()
    if not email:
        logger.warning("SSO authenticate_or_create_user: email is empty after normalization from provider '%s'.", user_info.get("provider", "unknown"))
        return None

    if not self._is_email_verified_claim(user_info):
        logger.warning(
            "SSO authenticate_or_create_user: unverified email claim for provider '%s' and email '%s'.",
            user_info.get("provider", "unknown"),
            SecurityValidator.sanitize_log_message(email),
        )
        return None

    incoming_provider = str(user_info.get("provider", "sso")).strip().lower() or "sso"
    provider = self.get_provider(incoming_provider)

    # Enforce trusted-domain policy consistently for both existing and new users.
    trusted_domains = getattr(provider, "trusted_domains", None) if provider else None
    if trusted_domains:
        # The domain is the text after the LAST "@": "user@trusted.example@evil.example"
        # must be judged as evil.example, not trusted.example.
        domain = email.rpartition("@")[2].lower() if "@" in email else ""
        normalized_trusted_domains = [d.lower() for d in trusted_domains if isinstance(d, str)]
        if domain not in normalized_trusted_domains:
            logger.warning(
                "SSO authenticate_or_create_user: email domain '%s' is not allowed for provider '%s'.",
                SecurityValidator.sanitize_log_message(domain),
                incoming_provider,
            )
            return None

    # Use stable local values for JWT payload generation to avoid lazy-loading
    # expired ORM attributes after commit/flush boundaries.
    resolved_email = email
    resolved_full_name = user_info.get("full_name", email)
    resolved_auth_provider = incoming_provider
    resolved_is_admin = False

    # Check if user exists
    user = await self.auth_service.get_user_by_email(email)

    if user:
        current_full_name = user.full_name or resolved_full_name
        current_auth_provider = str(user.auth_provider or resolved_auth_provider).strip().lower()
        current_is_admin = bool(user.is_admin)
        current_admin_origin = user.admin_origin

        if user.auth_provider and current_auth_provider != incoming_provider:
            logger.warning(
                "SSO authenticate_or_create_user: account-linking required for email '%s' (existing provider='%s', incoming='%s').",
                SecurityValidator.sanitize_log_message(email),
                current_auth_provider,
                incoming_provider,
            )
            return None

        provider_id: Optional[str] = None
        provider_metadata: Dict[str, Any] = {}
        provider_ctx: Optional[Any] = None
        if provider:
            provider_id = provider.id
            provider_metadata = provider.provider_metadata or {}
            provider_ctx = SSOProviderContext(id=provider_id, provider_metadata=provider_metadata)

        # Update user info from SSO
        if user_info.get("full_name") and user_info["full_name"] != current_full_name:
            user.full_name = user_info["full_name"]
            current_full_name = user_info["full_name"]

        # Initialize auth_provider for legacy accounts without provider metadata.
        if not user.auth_provider:
            user.auth_provider = incoming_provider
            current_auth_provider = incoming_provider

        # Persist verification status from provider claims.
        user.email_verified = self._is_email_verified_claim(user_info)
        user.last_login = utc_now()

        # Synchronize is_admin status based on current group membership
        # Track origin to support both promotion AND demotion for SSO-granted admins
        # Manual/API grants are "sticky" - never auto-demoted by SSO
        # Only users with admin_origin="sso" can be demoted on login
        if provider_ctx:
            should_be_admin = self._should_user_be_admin(email, user_info, provider_ctx)
            if should_be_admin:
                # Grant admin access
                if not current_is_admin:
                    logger.info(f"Upgrading is_admin to True for {SecurityValidator.sanitize_log_message(email)} based on SSO admin groups")
                    user.is_admin = True
                    # Track that admin was granted via SSO (only set on initial grant)
                    user.admin_origin = "sso"
                    current_is_admin = True
                # Do NOT change admin_origin if already admin - preserve manual/API grants
            elif current_is_admin and current_admin_origin == "sso":
                # User was SSO admin but no longer in admin groups - revoke access
                logger.info(f"Revoking is_admin for {SecurityValidator.sanitize_log_message(email)} - removed from SSO admin groups")
                user.is_admin = False
                user.admin_origin = None
                current_is_admin = False

        self.db.commit()

        if provider_ctx and self._should_sync_roles(provider_id, provider_metadata):
            role_assignments = await self._map_groups_to_roles(email, user_info.get("groups", []), provider_ctx)
            await self._sync_user_roles(email, role_assignments, provider_ctx)
        await self._apply_team_mapping(email, user_info, provider)

        user_email = getattr(user, "email", None)
        if isinstance(user_email, str) and user_email.strip():
            resolved_email = user_email.strip().lower()
        resolved_full_name = current_full_name
        resolved_auth_provider = current_auth_provider
        resolved_is_admin = current_is_admin
    else:
        # Auto-create user if enabled
        if not provider or not provider.auto_create_users:
            return None

        provider_id = provider.id
        provider_metadata = provider.provider_metadata or {}
        provider_ctx = SSOProviderContext(id=provider_id, provider_metadata=provider_metadata)

        # Check if admin approval is required
        if settings.sso_require_admin_approval:
            approval_result = self._check_pending_approval(email, incoming_provider, user_info)
            if approval_result is not True:
                return None  # Blocked by approval workflow

        # Create new user (either no approval required, or approval already granted)
        # Generate a secure random password for SSO users (they won't use it)

        random_password = "".join(secrets.choice(string.ascii_letters + string.digits + string.punctuation) for _ in range(32))

        # Determine if user should be admin based on domain/organization
        is_admin = self._should_user_be_admin(email, user_info, provider_ctx)

        user = await self.auth_service.create_user(
            email=email,
            password=random_password,  # Random password for SSO users (not used)
            full_name=user_info.get("full_name", email),
            is_admin=is_admin,
            auth_provider=incoming_provider,
        )
        if not user:
            return None

        user_email = getattr(user, "email", None)
        if isinstance(user_email, str) and user_email.strip():
            resolved_email = user_email.strip().lower()
        resolved_full_name = user_info.get("full_name", email)
        resolved_auth_provider = incoming_provider
        resolved_is_admin = is_admin

        if self._should_sync_roles(provider_id, provider_metadata):
            role_assignments = await self._map_groups_to_roles(email, user_info.get("groups", []), provider_ctx)
            if role_assignments:
                await self._sync_user_roles(email, role_assignments, provider_ctx)
        await self._apply_team_mapping(email, user_info, provider)

        # If user was created from approved request, mark request as used
        if settings.sso_require_admin_approval:
            pending = self.db.execute(select(PendingUserApproval).where(and_(PendingUserApproval.email == email, PendingUserApproval.status == "approved"))).scalar_one_or_none()
            if pending:
                # Mark as used (we could delete or keep for audit trail)
                pending.status = "completed"
                self.db.commit()

    # Generate JWT token for user — session token (teams resolved server-side)
    token_data = {
        "sub": resolved_email,
        "email": resolved_email,
        "full_name": resolved_full_name,
        "auth_provider": resolved_auth_provider,
        "iat": int(utc_now().timestamp()),
        "user": {
            "email": resolved_email,
            "full_name": resolved_full_name,
            "is_admin": resolved_is_admin,
            "auth_provider": resolved_auth_provider,
        },
        "token_use": "session",  # nosec B105 - token type marker, not a password
        # Scopes
        "scopes": {"server_id": None, "permissions": ["*"] if resolved_is_admin else [], "ip_restrictions": [], "time_restrictions": {}},
    }

    # Create JWT token
    token = await create_jwt_token(token_data)
    return token

2092 

def _should_user_be_admin(self, email: str, user_info: Dict[str, Any], provider: SSOProviderContext) -> bool:
    """Determine if SSO user should be granted admin privileges.

    Admin is granted when any of the following holds:
    - the email domain is in ``settings.sso_auto_admin_domains``
    - GitHub: an organization is in ``settings.sso_github_admin_orgs``
    - Google: the email domain is in ``settings.sso_google_admin_domains``
    - Entra: a group is in ``settings.sso_entra_admin_groups``

    All comparisons are case-insensitive. Non-string organization or group
    entries are ignored.

    Args:
        email: User's email address
        user_info: Normalized user info from SSO provider
        provider: SSO provider configuration

    Returns:
        True if user should be admin, False otherwise
    """
    # Validate email format — reject admin checks for invalid emails
    if not email or "@" not in email:
        logger.warning("Invalid email format for admin check: %r. Rejecting admin privilege.", email)
        return False

    # The domain is the text after the LAST "@"; taking the first segment would let
    # "user@corp.example@attacker.example" match an admin domain of corp.example.
    domain = email.rpartition("@")[2].lower()
    if domain in {d.lower() for d in settings.sso_auto_admin_domains}:
        return True

    # Check provider-specific admin assignment
    if provider.id == "github" and settings.sso_github_admin_orgs:
        github_admin_orgs = {o.lower() for o in settings.sso_github_admin_orgs}
        github_orgs = user_info.get("organizations") or []
        if any(isinstance(org, str) and org.lower() in github_admin_orgs for org in github_orgs):
            return True

    if provider.id == "google" and settings.sso_google_admin_domains:
        if domain in {d.lower() for d in settings.sso_google_admin_domains}:
            return True

    # Check EntraID admin groups
    if provider.id == "entra" and settings.sso_entra_admin_groups:
        entra_admin_groups = {g.lower() for g in settings.sso_entra_admin_groups}
        user_groups = user_info.get("groups") or []
        if any(isinstance(group, str) and group.lower() in entra_admin_groups for group in user_groups):
            return True

    return False

2133 

async def _map_groups_to_roles(self, user_email: str, user_groups: List[str], provider: SSOProviderContext) -> List[Dict[str, Any]]:
    """Map SSO groups to ContextForge RBAC roles.

    Sources, in order:
    1. Legacy Entra admin groups grant ``settings.default_admin_role``
       globally.
    2. The provider's ``role_mappings`` (group -> role name). ``"admin"``
       or the default admin role name map to a global admin assignment;
       other roles are looked up by name (team scope first, then global).
    3. The provider default role, when nothing above produced an assignment.

    For Entra providers, missing ``role_mappings`` and ``default_role`` fall
    back to the legacy ``sso_entra_*`` settings. When the provider metadata
    sets ``resolve_team_scope_to_personal_team``, team-scoped roles are bound
    to the user's personal team and skipped if it cannot be resolved. Each
    ``(role_name, scope, scope_id)`` triple appears at most once in the result.

    Args:
        user_email: User's email address
        user_groups: List of groups from SSO provider
        provider: SSO provider configuration

    Returns:
        List of role assignments: [{"role_name": str, "scope": str, "scope_id": Optional[str]}]
    """
    # pylint: disable=import-outside-toplevel
    # First-Party
    from mcpgateway.services.role_service import RoleService

    role_assignments: List[Dict[str, Any]] = []

    def _add_assignment(role_name: str, scope: str, scope_id: Optional[str]) -> bool:
        """Append an assignment unless an identical triple is already present; return True if appended."""
        for existing in role_assignments:
            if existing["role_name"] == role_name and existing["scope"] == scope and existing.get("scope_id") == scope_id:
                return False
        role_assignments.append({"role_name": role_name, "scope": scope, "scope_id": scope_id})
        return True

    # Generic Role Mapping Logic
    metadata = provider.provider_metadata or {}
    role_mappings = metadata.get("role_mappings", {})
    provider_default_role: Optional[str] = metadata.get("default_role")
    resolve_team_scope_to_personal_team = bool(metadata.get("resolve_team_scope_to_personal_team", False))
    has_provider_default_role = isinstance(provider_default_role, str) and bool(provider_default_role.strip())

    # Merge with legacy Entra specific settings if applicable
    has_entra_admin_groups = provider.id == "entra" and settings.sso_entra_admin_groups

    if provider.id == "entra":
        # Use generic role_mappings fallback to legacy setting
        if not role_mappings and settings.sso_entra_role_mappings:
            role_mappings = settings.sso_entra_role_mappings
        # Legacy fallback for default role configuration
        if not has_provider_default_role and settings.sso_entra_default_role:
            provider_default_role = settings.sso_entra_default_role
            has_provider_default_role = True

    # Early exit: Skip role mapping if no configuration exists
    if not role_mappings and not has_entra_admin_groups and not has_provider_default_role:
        logger.debug(f"No role mappings configured for provider {provider.id}, skipping role sync")
        return role_assignments

    # Personal-team lookup result is memoized: resolved at most once per call.
    personal_team_id: Optional[str] = None
    personal_team_checked = False

    async def _resolve_team_scope_id_if_needed(role_scope: str) -> Optional[str]:
        """Resolve team scope to personal-team id when provider mapping requires it.

        Args:
            role_scope: Role scope value from mapping metadata.

        Returns:
            Optional[str]: Personal team id when resolution succeeds, else None.
        """
        nonlocal personal_team_id, personal_team_checked

        if role_scope != "team" or not resolve_team_scope_to_personal_team:
            return None

        if personal_team_checked:
            return personal_team_id

        personal_team_checked = True
        try:
            # First-Party
            from mcpgateway.services.personal_team_service import PersonalTeamService

            personal_team = await PersonalTeamService(self.db).get_personal_team(user_email)
            personal_team_id = personal_team.id if personal_team else None
            if not personal_team_id:
                logger.warning(f"Could not resolve personal team for {SecurityValidator.sanitize_log_message(user_email)}; skipping team-scoped SSO role mapping")
        except Exception as e:
            logger.error(f"Failed to resolve personal team for {SecurityValidator.sanitize_log_message(user_email)}: {e}. All team-scoped SSO role assignments will be skipped for this login.")
            personal_team_id = None

        return personal_team_id

    # Handle EntraID admin groups -> admin role
    if has_entra_admin_groups:
        admin_groups_lower = [g.lower() for g in settings.sso_entra_admin_groups]
        for group in user_groups:
            if group.lower() in admin_groups_lower:
                _add_assignment(settings.default_admin_role, "global", None)
                logger.debug(f"Mapped EntraID admin group to {settings.default_admin_role} role for {SecurityValidator.sanitize_log_message(user_email)}")
                break  # Only need one admin assignment

    # Batch role lookups: collect all role names that need to be looked up
    role_names_to_lookup = set()
    for group in user_groups:
        if group in role_mappings:
            role_name = role_mappings[group]
            if role_name not in ["admin", settings.default_admin_role]:
                role_names_to_lookup.add(role_name)

    # Add default role to lookup if needed
    if has_provider_default_role and provider_default_role:
        role_names_to_lookup.add(provider_default_role)

    # Pre-fetch all roles by name once each (reduces DB round-trips)
    role_service = RoleService(self.db)
    role_cache: Dict[str, Any] = {}
    for role_name in role_names_to_lookup:
        # Try team scope first, then global
        role = await role_service.get_role_by_name(role_name, scope="team")
        if not role:
            role = await role_service.get_role_by_name(role_name, scope="global")
        if role:
            role_cache[role_name] = role

    # Process role mappings for ALL providers
    for group in user_groups:
        if group in role_mappings:
            role_name = role_mappings[group]
            # Special case for "admin" shorthand or configured admin role name.
            # Deduplicated like every other assignment: several admin groups (or an
            # Entra admin group plus an "admin" mapping) yield a single entry.
            if role_name in ["admin", settings.default_admin_role]:
                if _add_assignment(settings.default_admin_role, "global", None):
                    logger.debug(f"Mapped group to {settings.default_admin_role} role for {SecurityValidator.sanitize_log_message(user_email)}")
                continue

            # Use pre-fetched role from cache
            role = role_cache.get(role_name)
            if role:
                scope_id = await _resolve_team_scope_id_if_needed(role.scope)
                if role.scope == "team" and resolve_team_scope_to_personal_team and not scope_id:
                    continue
                if _add_assignment(role.name, role.scope, scope_id):
                    logger.debug(f"Mapped group to role '{role.name}' for {SecurityValidator.sanitize_log_message(user_email)}")
            else:
                logger.warning(f"Role '{role_name}' not found for group mapping")

    # Apply default role if no mappings found
    if not role_assignments and has_provider_default_role and provider_default_role:
        default_role = role_cache.get(provider_default_role)
        if default_role:
            scope_id = await _resolve_team_scope_id_if_needed(default_role.scope)
            if default_role.scope == "team" and resolve_team_scope_to_personal_team and not scope_id:
                return role_assignments
            _add_assignment(default_role.name, default_role.scope, scope_id)
            logger.info(f"Assigned default role '{default_role.name}' to {SecurityValidator.sanitize_log_message(user_email)}")

    return role_assignments

2276 

2277 async def _sync_user_roles(self, user_email: str, role_assignments: List[Dict[str, Any]], _provider: SSOProviderContext) -> None: 

2278 """Synchronize user's SSO-based role assignments. 

2279 

2280 Args: 

2281 user_email: User's email address 

2282 role_assignments: List of role assignments to apply 

2283 _provider: SSO provider configuration (reserved for future use) 

2284 """ 

2285 # pylint: disable=import-outside-toplevel 

2286 # First-Party 

2287 from mcpgateway.services.role_service import RoleService 

2288 

2289 role_service = RoleService(self.db) 

2290 

2291 # Get current SSO-granted roles 

2292 current_roles = await role_service.list_user_roles(user_email, include_expired=False) 

2293 sso_roles = [r for r in current_roles if getattr(r, "grant_source", None) == "sso"] 

2294 

2295 # Build set of desired role assignments 

2296 desired_roles = {(r["role_name"], r["scope"], r.get("scope_id")) for r in role_assignments} 

2297 

2298 # Revoke roles that are no longer in the desired set 

2299 for user_role in sso_roles: 

2300 role_tuple = (user_role.role.name, user_role.scope, user_role.scope_id) 

2301 if role_tuple not in desired_roles: 

2302 await role_service.revoke_role_from_user(user_email=user_email, role_id=user_role.role_id, scope=user_role.scope, scope_id=user_role.scope_id) 

2303 logger.info(f"Revoked SSO role '{user_role.role.name}' from {SecurityValidator.sanitize_log_message(user_email)} (no longer in groups)") 

2304 

2305 # Assign new roles 

2306 for assignment in role_assignments: 

2307 try: 

2308 # Get role by name 

2309 role = await role_service.get_role_by_name(assignment["role_name"], scope=assignment["scope"]) 

2310 if not role: 

2311 logger.warning(f"Role '{assignment['role_name']}' not found, skipping assignment for {SecurityValidator.sanitize_log_message(user_email)}") 

2312 continue 

2313 

2314 # Check if assignment already exists 

2315 existing = await role_service.get_user_role_assignment(user_email=user_email, role_id=role.id, scope=assignment["scope"], scope_id=assignment.get("scope_id")) 

2316 

2317 if not existing or not existing.is_active: 

2318 # Assign role to user 

2319 await role_service.assign_role_to_user( 

2320 user_email=user_email, role_id=role.id, scope=assignment["scope"], scope_id=assignment.get("scope_id"), granted_by=user_email, grant_source="sso" 

2321 ) 

2322 logger.info(f"Assigned SSO role '{role.name}' to {SecurityValidator.sanitize_log_message(user_email)}") 

2323 

2324 except Exception as e: 

2325 logger.warning(f"Failed to assign role '{assignment['role_name']}' to {SecurityValidator.sanitize_log_message(user_email)}: {e}", exc_info=True) 

2326 try: 

2327 self.db.rollback() 

2328 except Exception as rollback_error: 

2329 logger.error( 

2330 f"Database rollback failed after role assignment error for {SecurityValidator.sanitize_log_message(user_email)}: {rollback_error}. Aborting remaining role assignments." 

2331 ) 

2332 break