Coverage for mcpgateway / services / sso_service.py: 98%
888 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/sso_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Single Sign-On (SSO) authentication service for OAuth2 and OIDC providers.
8Handles provider management, OAuth flows, and user authentication.
9"""
11# Future
12from __future__ import annotations
14# Standard
15import asyncio
16import base64
17from dataclasses import dataclass
18from datetime import timedelta
19import hashlib
20import hmac
21import logging
22import secrets
23import string
24from time import monotonic
25from typing import Any, Dict, List, Optional, Tuple
26import urllib.parse
28# Third-Party
29import jwt
30import orjson
31from sqlalchemy import and_, select
32from sqlalchemy.orm import Session
34# First-Party
35from mcpgateway.config import settings
36from mcpgateway.db import PendingUserApproval, SSOAuthSession, SSOProvider, utc_now
37from mcpgateway.services.email_auth_service import EmailAuthService
38from mcpgateway.services.encryption_service import get_encryption_service
39from mcpgateway.utils.create_jwt_token import create_jwt_token
41# Logger
42logger = logging.getLogger(__name__)
@dataclass
class SSOProviderContext:
    """Minimal provider descriptor handed to helper methods.

    Attributes:
        id: Provider identifier, or ``None`` when no identifier is available.
        provider_metadata: Free-form provider metadata mapping.
    """

    id: Optional[str]
    provider_metadata: Dict[str, Any]
class SSOService:
    """Coordinate SSO provider configuration and OAuth2/OIDC login flows.

    Covers provider CRUD, the OAuth2/OIDC authorization flow, and the
    hand-off into the local user system.

    Examples:
        Basic construction and helper checks:
        >>> from unittest.mock import Mock
        >>> service = SSOService(Mock())
        >>> isinstance(service, SSOService)
        True
        >>> callable(service.list_enabled_providers)
        True
    """

    # Seconds a discovered OIDC metadata document is served from the cache.
    _OIDC_METADATA_CACHE_TTL_SECONDS = 300
    # Issuer URL -> (monotonic timestamp, discovery document). Defined on the
    # class, so the mapping is shared by every service instance.
    _oidc_config_cache: Dict[str, Tuple[float, Dict[str, Any]]] = {}
    # JWKS URI -> PyJWKClient. Also class-level and shared across instances.
    _jwks_client_cache: Dict[str, jwt.PyJWKClient] = {}
    # Separator between the nonce and the signature in a session-bound state.
    _STATE_BINDING_SEPARATOR = "."
    # Expected length of the signature part of a session-bound state.
    _STATE_BINDING_HEX_LEN = 64
def __init__(self, db: Session):
    """Bind the service to a database session and build its collaborators.

    Args:
        db: SQLAlchemy database session
    """
    self.db = db
    # Local-account operations reuse the same session as the SSO service.
    self.auth_service = EmailAuthService(db)
    encryption_key = settings.auth_encryption_secret
    self._encryption = get_encryption_service(encryption_key)
85 async def _encrypt_secret(self, secret: str) -> str:
86 """Encrypt a client secret for secure storage.
88 Args:
89 secret: Plain text client secret
91 Returns:
92 Encrypted secret string
93 """
94 return await self._encryption.encrypt_secret_async(secret)
96 async def _decrypt_secret(self, encrypted_secret: str) -> Optional[str]:
97 """Decrypt a client secret for use.
99 Args:
100 encrypted_secret: Encrypted secret string
102 Returns:
103 Plain text client secret
104 """
105 decrypted: str | None = await self._encryption.decrypt_secret_async(encrypted_secret)
106 if decrypted:
107 return decrypted
109 return None
def _decode_jwt_claims(self, token: str) -> Optional[Dict[str, Any]]:
    """Decode JWT token payload without verification.

    This is used to extract claims from ID tokens where we've already
    validated the OAuth flow. The token signature is not verified here
    because the token was received directly from the trusted token endpoint.

    Args:
        token: JWT token string

    Returns:
        Decoded payload dict, or None if the token is malformed, cannot be
        decoded, or its payload is not a JSON object.

    Examples:
        >>> from unittest.mock import Mock
        >>> service = SSOService(Mock())
        >>> # Valid JWT structure (header.payload.signature)
        >>> import base64
        >>> payload = base64.urlsafe_b64encode(b'{"sub":"123","groups":["admin"]}').decode().rstrip('=')
        >>> token = f"eyJhbGciOiJSUzI1NiJ9.{payload}.signature"
        >>> claims = service._decode_jwt_claims(token)
        >>> claims is not None
        True
    """
    # JWT format: header.payload.signature
    parts = token.split(".")
    if len(parts) != 3:
        logger.warning("Invalid JWT format: expected 3 parts")
        return None

    # JWT segments are unpadded base64url; restore "=" padding to a multiple of 4.
    payload_b64 = parts[1]
    payload_b64 += "=" * (-len(payload_b64) % 4)

    try:
        claims = orjson.loads(base64.urlsafe_b64decode(payload_b64))
    except (ValueError, orjson.JSONDecodeError, UnicodeDecodeError) as e:
        logger.warning("Failed to decode JWT claims: %s", e)
        return None

    # A claims set must be a JSON object; reject lists/scalars so callers
    # relying on the Dict contract never receive another type.
    if not isinstance(claims, dict):
        logger.warning("Failed to decode JWT claims: payload is not a JSON object")
        return None
    return claims
async def _get_oidc_provider_metadata(self, issuer: str) -> Optional[Dict[str, Any]]:
    """Fetch an issuer's OIDC discovery document, using the TTL cache.

    Args:
        issuer: OIDC issuer URL.

    Returns:
        Provider metadata dict from discovery endpoint, or None on failure.
    """
    issuer_key = issuer.rstrip("/")

    entry = self._oidc_config_cache.get(issuer_key)
    if entry is not None:
        stored_at, stored_metadata = entry
        age = monotonic() - stored_at
        if age < self._OIDC_METADATA_CACHE_TTL_SECONDS:
            return stored_metadata
        # Expired: evict now so a failed refresh does not leave stale data.
        self._oidc_config_cache.pop(issuer_key, None)

    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    well_known_url = f"{issuer_key}/.well-known/openid-configuration"
    try:
        http = await get_http_client()
        reply = await http.get(well_known_url, timeout=settings.oauth_request_timeout)
        if reply.status_code != 200:
            logger.warning("OIDC discovery failed for issuer %s with HTTP %s", issuer_key, reply.status_code)
            return None

        document = reply.json()
        if isinstance(document, dict):
            self._oidc_config_cache[issuer_key] = (monotonic(), document)
            return document
        logger.warning("OIDC discovery response for issuer %s is not a JSON object", issuer_key)
        return None
    except Exception as exc:
        # Best effort: any transport/parsing problem is reported as "no metadata".
        logger.warning("OIDC discovery request failed for issuer %s: %s", issuer_key, exc)
        return None
194 async def _resolve_oidc_issuer_and_jwks(self, provider: SSOProvider) -> Tuple[Optional[str], Optional[str]]:
195 """Resolve issuer and JWKS URI for an OIDC provider.
197 Args:
198 provider: SSO provider configuration.
200 Returns:
201 Tuple of (issuer, jwks_uri), each optional when unavailable.
202 """
203 issuer = provider.issuer.strip() if isinstance(provider.issuer, str) and provider.issuer.strip() else None
204 jwks_uri = provider.jwks_uri.strip() if isinstance(provider.jwks_uri, str) and provider.jwks_uri.strip() else None
206 if issuer and not jwks_uri:
207 metadata = await self._get_oidc_provider_metadata(issuer)
208 if metadata:
209 discovered_jwks = metadata.get("jwks_uri")
210 discovered_issuer = metadata.get("issuer")
211 if isinstance(discovered_jwks, str) and discovered_jwks.strip():
212 jwks_uri = discovered_jwks.strip()
213 if isinstance(discovered_issuer, str) and discovered_issuer.strip():
214 issuer = discovered_issuer.strip()
216 return issuer, jwks_uri
218 def _get_jwks_client(self, jwks_uri: str) -> jwt.PyJWKClient:
219 """Get or create a cached PyJWKClient instance.
221 Args:
222 jwks_uri: JWKS endpoint URL.
224 Returns:
225 Cached or newly created `PyJWKClient`.
226 """
227 if jwks_uri not in self._jwks_client_cache:
228 self._jwks_client_cache[jwks_uri] = jwt.PyJWKClient(jwks_uri)
229 return self._jwks_client_cache[jwks_uri]
231 async def _verify_oidc_id_token(self, provider: SSOProvider, id_token: str, expected_nonce: Optional[str] = None) -> Optional[Dict[str, Any]]:
232 """Verify OIDC ID token signature and claims.
234 Args:
235 provider: SSO provider configuration.
236 id_token: Raw OIDC ID token string.
237 expected_nonce: Expected nonce from auth session, when available.
239 Returns:
240 Verified token claims when validation succeeds, otherwise None.
241 """
242 if provider.provider_type != "oidc":
243 return None
245 issuer, jwks_uri = await self._resolve_oidc_issuer_and_jwks(provider)
246 if not jwks_uri:
247 logger.warning("Skipping id_token claim usage for provider %s: missing jwks_uri", provider.id)
248 return None
250 try:
251 jwks_client = self._get_jwks_client(jwks_uri)
252 signing_key = await asyncio.to_thread(jwks_client.get_signing_key_from_jwt, id_token)
254 decode_kwargs: Dict[str, Any] = {
255 "key": signing_key.key,
256 "algorithms": ["RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512", "EdDSA"],
257 "audience": provider.client_id,
258 "options": {
259 "verify_signature": True,
260 "verify_exp": True,
261 "verify_iat": True,
262 "verify_aud": True,
263 "verify_iss": bool(issuer),
264 },
265 }
266 if issuer:
267 decode_kwargs["issuer"] = issuer
269 claims = await asyncio.to_thread(jwt.decode, id_token, **decode_kwargs)
270 if expected_nonce is not None and claims.get("nonce") != expected_nonce:
271 logger.warning("OIDC id_token nonce validation failed for provider %s", provider.id)
272 return None
273 return claims
274 except jwt.PyJWTError as exc:
275 logger.warning("OIDC id_token verification failed for provider %s: %s", provider.id, exc)
276 return None
277 except Exception as exc:
278 logger.warning("Unexpected OIDC id_token verification error for provider %s: %s", provider.id, exc)
279 return None
def _resolve_entra_graph_fallback_settings(self, provider_metadata: Optional[Dict[str, Any]]) -> Tuple[bool, int, int]:
    """Combine global Entra Graph fallback settings with per-provider overrides.

    Each setting starts from the configured value and is replaced only by a
    valid override; an invalid override is logged and ignored.

    Args:
        provider_metadata: Optional provider metadata from DB/config.

    Returns:
        Tuple of (enabled, timeout_seconds, max_groups).
    """
    overrides = provider_metadata or {}
    enabled = settings.sso_entra_graph_api_enabled
    timeout = settings.sso_entra_graph_api_timeout
    max_groups = settings.sso_entra_graph_api_max_groups

    def _to_int(raw: Any) -> Optional[int]:
        # None signals "not convertible"; the caller logs the warning.
        try:
            return int(raw)
        except (TypeError, ValueError):
            return None

    # Checked with "in" so an explicit null override is reported as invalid.
    if "graph_api_enabled" in overrides:
        flag = overrides.get("graph_api_enabled")
        if isinstance(flag, bool):
            enabled = flag
        elif isinstance(flag, str):
            token = flag.strip().lower()
            if token in {"1", "true", "yes", "on"}:
                enabled = True
            elif token in {"0", "false", "no", "off"}:
                enabled = False
            else:
                logger.warning("Invalid provider_metadata.graph_api_enabled=%s; using configured default %s", flag, enabled)
        elif isinstance(flag, int):
            enabled = flag != 0
        else:
            logger.warning("Invalid provider_metadata.graph_api_enabled=%s (type %s); using configured default %s", flag, type(flag).__name__, enabled)

    raw_timeout = overrides.get("graph_api_timeout")
    if raw_timeout is not None:
        parsed_timeout = _to_int(raw_timeout)
        # Accepted range is 1..120 inclusive.
        if parsed_timeout is not None and 1 <= parsed_timeout <= 120:
            timeout = parsed_timeout
        else:
            logger.warning("Invalid provider_metadata.graph_api_timeout=%s; using configured default %s", raw_timeout, timeout)

    raw_max_groups = overrides.get("graph_api_max_groups")
    if raw_max_groups is not None:
        parsed_max_groups = _to_int(raw_max_groups)
        # Any non-negative integer is accepted.
        if parsed_max_groups is not None and parsed_max_groups >= 0:
            max_groups = parsed_max_groups
        else:
            logger.warning("Invalid provider_metadata.graph_api_max_groups=%s; using configured default %s", raw_max_groups, max_groups)

    return enabled, timeout, max_groups
async def _fetch_entra_groups_from_graph_api(self, access_token: str, user_email: str, provider_metadata: Optional[Dict[str, Any]] = None) -> Optional[List[str]]:
    """Fetch Entra group object IDs from Microsoft Graph for overage tokens.

    Args:
        access_token: Delegated OAuth access token from Entra.
        user_email: User identifier for structured logs.
        provider_metadata: Optional provider metadata for runtime overrides.

    Returns:
        List of group IDs on success (stripped, de-duplicated in first-seen
        order, truncated to the configured cap when the cap is positive),
        an empty list when Graph returns no usable IDs, or None if the
        fallback is disabled, the request fails, or the response body is
        not a JSON object.
    """
    graph_api_enabled, graph_api_timeout, graph_api_max_groups = self._resolve_entra_graph_fallback_settings(provider_metadata)
    if not graph_api_enabled:
        logger.info("Microsoft Graph fallback for Entra group overage is disabled")
        return None

    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    client = await get_http_client()
    try:
        response = await client.post(
            "https://graph.microsoft.com/v1.0/me/getMemberObjects",
            headers={"Authorization": f"Bearer {access_token}"},
            json={"securityEnabledOnly": True},
            timeout=graph_api_timeout,
        )
    except Exception as e:
        logger.error("Failed to retrieve groups from Graph API for %s: %s", user_email, e)
        return None

    if response.status_code != 200:
        if response.status_code in {401, 403}:
            logger.error(
                "Failed to retrieve groups from Graph API for %s: HTTP %s. Check Entra delegated permissions and consent (minimum User.Read).",
                user_email,
                response.status_code,
            )
        else:
            logger.error("Failed to retrieve groups from Graph API for %s: HTTP %s", user_email, response.status_code)
        return None

    try:
        groups_payload = response.json()
    except ValueError as e:
        logger.error("Failed to parse Graph API response for %s: %s", user_email, e)
        return None

    # Valid JSON that is not an object (list, string, null) has no "value"
    # member; report a failed retrieval rather than raising AttributeError.
    if not isinstance(groups_payload, dict):
        logger.error("Failed to parse Graph API response for %s: expected a JSON object", user_email)
        return None

    group_values = groups_payload.get("value", [])
    if not isinstance(group_values, list):
        logger.warning("Unexpected Graph API groups payload for %s: expected list in 'value'", user_email)
        return []

    # Keep string IDs only, strip whitespace, drop blanks, and de-duplicate
    # while preserving first-seen order.
    stripped_ids = (group_id.strip() for group_id in group_values if isinstance(group_id, str))
    deduped_groups: List[str] = list(dict.fromkeys(group_id for group_id in stripped_ids if group_id))

    # A cap of 0 means "no cap".
    if 0 < graph_api_max_groups < len(deduped_groups):
        logger.warning(
            "Graph API returned %d groups for %s; applying configured cap (%d)",
            len(deduped_groups),
            user_email,
            graph_api_max_groups,
        )
        deduped_groups = deduped_groups[:graph_api_max_groups]

    logger.info("Retrieved %d groups from Graph API for %s", len(deduped_groups), user_email)
    return deduped_groups
def list_enabled_providers(self) -> List[SSOProvider]:
    """Return every SSO provider whose ``is_enabled`` flag is true.

    Returns:
        List of enabled SSO providers

    Examples:
        Returns empty list when DB has no providers:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalars.return_value.all.return_value = []
        >>> service.list_enabled_providers()
        []
    """
    enabled_only = SSOProvider.is_enabled.is_(True)
    rows = self.db.execute(select(SSOProvider).where(enabled_only)).scalars()
    return list(rows.all())
def get_provider(self, provider_id: str) -> Optional[SSOProvider]:
    """Look up a single SSO provider by its identifier.

    Args:
        provider_id: Provider identifier (e.g., 'github', 'google')

    Returns:
        SSO provider or None if not found

    Examples:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> service.get_provider('x') is None
        True
    """
    query = select(SSOProvider).where(SSOProvider.id == provider_id)
    return self.db.execute(query).scalar_one_or_none()
def get_provider_by_name(self, provider_name: str) -> Optional[SSOProvider]:
    """Look up a single SSO provider by its name.

    Args:
        provider_name: Provider name (e.g., 'github', 'google')

    Returns:
        SSO provider or None if not found

    Examples:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> service.get_provider_by_name('github') is None
        True
    """
    query = select(SSOProvider).where(SSOProvider.name == provider_name)
    return self.db.execute(query).scalar_one_or_none()
472 @staticmethod
473 def _normalize_issuer_url(issuer: str) -> str:
474 """Normalize issuer URL for allowlist comparisons.
476 Args:
477 issuer: Raw issuer URL from provider config or metadata.
479 Returns:
480 Lowercased issuer URL without trailing slash.
481 """
482 return issuer.strip().rstrip("/").lower()
def _enforce_allowed_issuer(self, issuer: Optional[str]) -> None:
    """Reject an issuer that is missing from the configured allowlist.

    Nothing is enforced when no allowlist is configured, when the issuer is
    blank (a warning is logged), or when the allowlist contains no usable
    string entries.

    Args:
        issuer: Candidate issuer URL from provider configuration.

    Raises:
        ValueError: If issuer is set but not in configured allowlist.
    """
    configured_allowlist = getattr(settings, "sso_issuers", None)
    if not configured_allowlist:
        return

    if not (isinstance(issuer, str) and issuer.strip()):
        logger.warning("SSO provider has blank/empty issuer while SSO_ISSUERS allowlist is configured; issuer enforcement skipped.")
        return

    candidate = self._normalize_issuer_url(issuer)

    permitted = set()
    for entry in configured_allowlist:
        # Non-string and blank entries are ignored.
        if isinstance(entry, str) and entry.strip():
            permitted.add(self._normalize_issuer_url(str(entry)))

    if not permitted:
        return
    if candidate not in permitted:
        raise ValueError("Issuer is not allowed by SSO_ISSUERS configuration")
506 @staticmethod
507 def _resolve_team_mapping_target(mapping_value: Any) -> Tuple[Optional[str], str]:
508 """Resolve team mapping value into team id and role.
510 Args:
511 mapping_value: Team mapping target value from provider config.
513 Returns:
514 Tuple of ``(team_id, role)`` where ``team_id`` may be ``None`` and
515 role defaults to ``member`` when not explicitly valid.
516 """
517 if isinstance(mapping_value, str) and mapping_value.strip():
518 return mapping_value.strip(), "member"
520 if isinstance(mapping_value, dict):
521 team_id_value = mapping_value.get("team_id") or mapping_value.get("id")
522 team_id = str(team_id_value).strip() if team_id_value is not None else ""
523 role_value = str(mapping_value.get("role", "member")).strip().lower()
524 role = role_value if role_value in {"owner", "member"} else "member"
525 return (team_id if team_id else None), role
527 return None, "member"
async def _apply_team_mapping(self, user_email: str, user_info: Dict[str, Any], provider: Optional[SSOProvider]) -> None:
    """Apply provider team mappings based on SSO group claims.

    Each ``team_mapping`` key is compared (stripped, case-insensitively)
    against the user's ``groups`` claim. For every match the user is added
    to the mapped team. Failures for one team are logged and do not stop
    the remaining mappings.

    Args:
        user_email: Authenticated user email to map into teams.
        user_info: Identity claims payload containing optional group claims.
            ``groups`` may be a list or a single string; any other type is
            treated as "no groups".
        provider: SSO provider configuration with ``team_mapping`` entries.

    Returns:
        None.
    """
    if not provider:
        return

    mapping = getattr(provider, "team_mapping", None)
    if not isinstance(mapping, dict) or not mapping:
        return

    groups_raw = user_info.get("groups", [])
    if isinstance(groups_raw, str):
        # A single-string claim is treated as a one-element list so it gets
        # the same strip/blank filtering as list claims.
        groups_raw = [groups_raw]
    elif not isinstance(groups_raw, list):
        groups_raw = []

    stripped_groups = (str(group).strip() for group in groups_raw)
    normalized_groups = {group.lower() for group in stripped_groups if group}
    if not normalized_groups:
        return

    # First-Party
    from mcpgateway.services.team_management_service import MemberAlreadyExistsError, TeamManagementError, TeamManagementService  # pylint: disable=import-outside-toplevel

    team_service = TeamManagementService(self.db)
    for source_group, target in mapping.items():
        if not isinstance(source_group, str):
            continue
        source_group_normalized = source_group.strip().lower()
        if not source_group_normalized or source_group_normalized not in normalized_groups:
            continue

        team_id, role = self._resolve_team_mapping_target(target)
        if not team_id:
            logger.warning("Skipping invalid SSO team_mapping entry for provider %s and group '%s'", provider.id, source_group)
            continue

        try:
            await team_service.add_member_to_team(team_id=team_id, user_email=user_email, role=role, invited_by=user_email)
        except MemberAlreadyExistsError:
            logger.debug("SSO team_mapping: user %s already member of team %s", user_email, team_id)
        except TeamManagementError as exc:
            logger.warning("SSO team_mapping failed for user %s, group '%s', team '%s': %s", user_email, source_group, team_id, exc)
        except Exception as exc:
            # Best effort: a failure on one team must not block the login.
            logger.warning("Unexpected SSO team_mapping error for user %s and team '%s': %s", user_email, team_id, exc)
async def create_provider(self, provider_data: Dict[str, Any]) -> SSOProvider:
    """Create new SSO provider configuration.

    The plain ``client_secret`` is encrypted and stored as
    ``client_secret_encrypted``. Keys that are not ``SSOProvider`` columns
    are dropped with a warning. The caller's dict is not modified.

    Args:
        provider_data: Provider configuration data

    Returns:
        Created SSO provider

    Raises:
        ValueError: If the issuer is rejected by the issuer allowlist.
        KeyError: If ``provider_data`` has no ``client_secret`` entry.

    Examples:
        >>> import asyncio
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> service = SSOService(MagicMock())
        >>> service._encrypt_secret = AsyncMock(side_effect=lambda s: 'ENC(' + s + ')')
        >>> data = {
        ...     'id': 'github', 'name': 'github', 'display_name': 'GitHub', 'provider_type': 'oauth2',
        ...     'client_id': 'cid', 'client_secret': 'sec',
        ...     'authorization_url': 'https://example/auth', 'token_url': 'https://example/token',
        ...     'userinfo_url': 'https://example/user', 'scope': 'user:email'
        ... }
        >>> provider = asyncio.run(service.create_provider(data))
        >>> hasattr(provider, 'id') and provider.id == 'github'
        True
        >>> provider.client_secret_encrypted.startswith('ENC(')
        True
    """
    # Work on a shallow copy so the caller's dict keeps its client_secret
    # and can be reused (e.g. for a retry).
    data = dict(provider_data)

    self._enforce_allowed_issuer(data.get("issuer"))

    # Encrypt client secret
    client_secret = data.pop("client_secret")
    data["client_secret_encrypted"] = await self._encrypt_secret(client_secret)

    # Filter to valid SSOProvider columns to prevent TypeError on unknown keys
    valid_columns = {c.key for c in SSOProvider.__table__.columns}
    filtered_data = {k: v for k, v in data.items() if k in valid_columns}
    skipped = set(data) - set(filtered_data)
    if skipped:
        logger.warning("Ignored unknown SSOProvider fields during creation: %s", skipped)

    provider = SSOProvider(**filtered_data)
    self.db.add(provider)
    self.db.commit()
    self.db.refresh(provider)
    return provider
async def update_provider(self, provider_id: str, provider_data: Dict[str, Any]) -> Optional[SSOProvider]:
    """Update existing SSO provider configuration.

    Only keys that are already attributes of the stored provider are
    applied. A supplied ``client_secret`` is encrypted and stored as
    ``client_secret_encrypted``. The caller's dict is not modified.

    Args:
        provider_id: Provider identifier
        provider_data: Updated provider data

    Returns:
        Updated SSO provider or None if not found

    Raises:
        ValueError: If a supplied issuer is rejected by the issuer allowlist.

    Examples:
        >>> import asyncio
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> svc = SSOService(MagicMock())
        >>> # Existing provider object
        >>> existing = SimpleNamespace(id='github', name='github', client_id='old', client_secret_encrypted='X', is_enabled=True)
        >>> svc.get_provider = lambda _id: existing
        >>> svc._encrypt_secret = AsyncMock(side_effect=lambda s: 'ENC-' + s)
        >>> svc.db.commit = lambda: None
        >>> svc.db.refresh = lambda obj: None
        >>> updated = asyncio.run(svc.update_provider('github', {'client_id': 'new', 'client_secret': 'sec'}))
        >>> updated.client_id
        'new'
        >>> updated.client_secret_encrypted
        'ENC-sec'
    """
    provider = self.get_provider(provider_id)
    if not provider:
        return None

    # Work on a shallow copy so the caller's dict keeps its client_secret.
    changes = dict(provider_data)

    if "issuer" in changes:
        self._enforce_allowed_issuer(changes.get("issuer"))

    # Handle client secret encryption if provided
    if "client_secret" in changes:
        client_secret = changes.pop("client_secret")
        changes["client_secret_encrypted"] = await self._encrypt_secret(client_secret)

    # Keys the provider object does not already expose are silently skipped.
    for key, value in changes.items():
        if hasattr(provider, key):
            setattr(provider, key, value)

    provider.updated_at = utc_now()
    self.db.commit()
    self.db.refresh(provider)
    return provider
678 def delete_provider(self, provider_id: str) -> bool:
679 """Delete SSO provider configuration.
681 Args:
682 provider_id: Provider identifier
684 Returns:
685 True if deleted, False if not found
687 Examples:
688 >>> from types import SimpleNamespace
689 >>> from unittest.mock import MagicMock
690 >>> svc = SSOService(MagicMock())
691 >>> svc.db.delete = lambda obj: None
692 >>> svc.db.commit = lambda: None
693 >>> svc.get_provider = lambda _id: SimpleNamespace(id='github')
694 >>> svc.delete_provider('github')
695 True
696 >>> svc.get_provider = lambda _id: None
697 >>> svc.delete_provider('missing')
698 False
699 """
700 provider = self.get_provider(provider_id)
701 if not provider:
702 return False
704 self.db.delete(provider)
705 self.db.commit()
706 return True
def generate_pkce_challenge(self) -> Tuple[str, str]:
    """Create a fresh PKCE verifier and its S256 challenge for OAuth 2.1.

    Returns:
        Tuple of (code_verifier, code_challenge)

    Examples:
        Generate verifier and challenge:
        >>> from unittest.mock import Mock
        >>> service = SSOService(Mock())
        >>> verifier, challenge = service.generate_pkce_challenge()
        >>> isinstance(verifier, str) and isinstance(challenge, str)
        True
        >>> len(verifier) >= 43
        True
        >>> len(challenge) >= 43
        True
    """

    def _b64url_unpadded(raw: bytes) -> str:
        # URL-safe base64 with the trailing "=" padding removed.
        return base64.urlsafe_b64encode(raw).decode("utf-8").rstrip("=")

    # 32 random bytes from the secrets module form the verifier.
    verifier = _b64url_unpadded(secrets.token_bytes(32))
    # The challenge is the SHA-256 digest of the verifier text.
    verifier_digest = hashlib.sha256(verifier.encode("utf-8")).digest()
    challenge = _b64url_unpadded(verifier_digest)
    return verifier, challenge
734 @staticmethod
735 def _normalize_scope_values(values: Optional[List[str] | str]) -> List[str]:
736 """Normalize scope values to a deduplicated, ordered list.
738 Args:
739 values: Scope input as list or space-delimited string.
741 Returns:
742 Ordered, unique scope values.
743 """
744 if values is None:
745 return []
747 raw_values: List[str]
748 if isinstance(values, str):
749 raw_values = values.split()
750 else:
751 raw_values = [str(v) for v in values if isinstance(v, str) and v.strip()]
753 normalized: List[str] = []
754 seen: set[str] = set()
755 for value in raw_values:
756 scope = value.strip()
757 if not scope or scope in seen:
758 continue
759 normalized.append(scope)
760 seen.add(scope)
761 return normalized
763 def _resolve_login_scopes(self, provider: SSOProvider, requested_scopes: Optional[List[str]]) -> List[str]:
764 """Resolve requested SSO scopes against provider allowlists.
766 Args:
767 provider: SSO provider configuration.
768 requested_scopes: Optional scopes requested by the client.
770 Returns:
771 Final scope list to send to the provider.
773 Raises:
774 ValueError: If provider scope configuration is invalid or request includes disallowed scopes.
775 """
776 configured_scopes = self._normalize_scope_values(getattr(provider, "scope", None))
777 if not configured_scopes:
778 raise ValueError("Provider has no configured scopes")
780 allowed_scopes = configured_scopes
781 provider_metadata = getattr(provider, "provider_metadata", {}) or {}
782 metadata_allowed_raw = provider_metadata.get("allowed_scopes")
783 metadata_allowed = self._normalize_scope_values(metadata_allowed_raw) if metadata_allowed_raw else []
784 if metadata_allowed:
785 metadata_allowed_set = set(metadata_allowed)
786 allowed_scopes = [scope for scope in configured_scopes if scope in metadata_allowed_set]
788 if not allowed_scopes:
789 raise ValueError("No allowed scopes configured for provider")
791 if not requested_scopes:
792 return allowed_scopes
794 normalized_requested = self._normalize_scope_values(requested_scopes)
795 if not normalized_requested:
796 return allowed_scopes
798 allowed_set = set(allowed_scopes)
799 invalid_scopes = [scope for scope in normalized_requested if scope not in allowed_set]
800 if invalid_scopes:
801 invalid_csv = ", ".join(invalid_scopes)
802 raise ValueError(f"Invalid scopes requested: {invalid_csv}")
804 return normalized_requested
def _get_state_binding_secret(self) -> bytes:
    """Return the key material used to sign session-bound state values.

    Returns:
        Secret bytes used for HMAC signatures.
    """
    configured = settings.auth_encryption_secret
    # Objects exposing ``get_secret_value`` are unwrapped via that method;
    # anything else is used through ``str()``.
    secret_text = configured.get_secret_value() if hasattr(configured, "get_secret_value") else str(configured)
    return secret_text.encode("utf-8")
817 def _generate_session_bound_state(self, provider_id: str, session_binding: str) -> str:
818 """Generate state value bound to browser session context.
820 Args:
821 provider_id: SSO provider identifier.
822 session_binding: Browser-session marker.
824 Returns:
825 Signed state token including nonce and HMAC signature.
826 """
827 state_nonce = secrets.token_urlsafe(24)
828 message = f"{provider_id}:{session_binding}:{state_nonce}".encode("utf-8")
829 signature = hmac.new(self._get_state_binding_secret(), message, hashlib.sha256).hexdigest()
830 return f"{state_nonce}{self._STATE_BINDING_SEPARATOR}{signature}"
832 def _is_session_bound_state(self, state: str) -> bool:
833 """Return whether state appears to carry a session-binding signature.
835 Args:
836 state: State value from OAuth flow.
838 Returns:
839 ``True`` when state has expected nonce/signature format.
840 """
841 if self._STATE_BINDING_SEPARATOR not in state:
842 return False
843 nonce, signature = state.rsplit(self._STATE_BINDING_SEPARATOR, 1)
844 return bool(nonce) and len(signature) == self._STATE_BINDING_HEX_LEN
846 def _verify_session_bound_state(self, provider_id: str, state: str, session_binding: str) -> bool:
847 """Verify state HMAC binding against the current browser session marker.
849 Args:
850 provider_id: SSO provider identifier.
851 state: State value to validate.
852 session_binding: Current browser-session marker.
854 Returns:
855 ``True`` when state signature matches the expected session binding.
856 """
857 if not session_binding or not self._is_session_bound_state(state):
858 return False
859 nonce, signature = state.rsplit(self._STATE_BINDING_SEPARATOR, 1)
860 message = f"{provider_id}:{session_binding}:{nonce}".encode("utf-8")
861 expected_signature = hmac.new(self._get_state_binding_secret(), message, hashlib.sha256).hexdigest()
862 return hmac.compare_digest(signature, expected_signature)
864 @staticmethod
865 def _is_email_verified_claim(user_info: Dict[str, Any]) -> bool:
866 """Evaluate email verification claim when provided by the IdP.
868 Args:
869 user_info: Normalized user-info payload from provider.
871 Returns:
872 ``True`` only when email verification claim is explicitly verified.
873 """
874 if "email_verified" not in user_info:
875 return False
877 claim_value = user_info.get("email_verified")
878 if isinstance(claim_value, bool):
879 return claim_value
880 if isinstance(claim_value, int):
881 return claim_value == 1
882 if isinstance(claim_value, str):
883 return claim_value.strip().lower() in {"1", "true", "yes", "on"}
884 return False
def get_authorization_url(
    self,
    provider_id: str,
    redirect_uri: str,
    scopes: Optional[List[str]] = None,
    session_binding: Optional[str] = None,
) -> Optional[str]:
    """Generate OAuth authorization URL for provider.

    Creates and commits an ``SSOAuthSession`` holding the CSRF state, PKCE code
    verifier, optional OIDC nonce and redirect URI, then builds the
    authorization URL the browser should be redirected to.

    Args:
        provider_id: Provider identifier
        redirect_uri: Callback URI after authorization
        scopes: Optional custom scopes (uses provider default if None)
        session_binding: Optional browser-session marker for state binding.

    Returns:
        Authorization URL or None if provider not found or disabled

    Examples:
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> provider = SimpleNamespace(id='github', is_enabled=True, provider_type='oauth2', client_id='cid', authorization_url='https://example/auth', scope='user:email')
        >>> service.get_provider = lambda _pid: provider
        >>> service.db.add = lambda x: None
        >>> service.db.commit = lambda: None
        >>> url = service.get_authorization_url('github', 'https://app/callback', ['user:email'])
        >>> isinstance(url, str) and 'client_id=cid' in url and 'state=' in url
        True

        Missing provider returns None:
        >>> service.get_provider = lambda _pid: None
        >>> service.get_authorization_url('missing', 'https://app/callback') is None
        True
    """
    provider = self.get_provider(provider_id)
    if not provider or not provider.is_enabled:
        return None

    # Generate PKCE parameters
    code_verifier, code_challenge = self.generate_pkce_challenge()

    resolved_scopes = self._resolve_login_scopes(provider, scopes)

    # Generate CSRF state (session-bound when browser context is available).
    state = self._generate_session_bound_state(provider_id, session_binding) if session_binding else secrets.token_urlsafe(32)

    # Generate OIDC nonce if applicable
    nonce = secrets.token_urlsafe(16) if provider.provider_type == "oidc" else None

    # Create auth session
    auth_session = SSOAuthSession(provider_id=provider_id, state=state, code_verifier=code_verifier, nonce=nonce, redirect_uri=redirect_uri)
    self.db.add(auth_session)
    self.db.commit()

    # Build authorization URL
    params = {
        "client_id": provider.client_id,
        "response_type": "code",
        "redirect_uri": redirect_uri,
        "state": state,
        "scope": " ".join(resolved_scopes),
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
    }

    if nonce:
        params["nonce"] = nonce

    # The authorization endpoint may already carry a query component, which
    # must be retained; only introduce "?" when the URL has none.
    base_url = provider.authorization_url
    if "?" not in base_url:
        separator = "?"
    elif base_url.endswith(("?", "&")):
        separator = ""
    else:
        separator = "&"
    return f"{base_url}{separator}{urllib.parse.urlencode(params)}"
async def handle_oauth_callback(self, provider_id: str, code: str, state: str, session_binding: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Complete the OAuth callback and return only the user info.

    Convenience wrapper around ``handle_oauth_callback_with_tokens`` that
    discards the raw token response.

    Args:
        provider_id: Provider identifier
        code: Authorization code from callback
        state: CSRF state parameter
        session_binding: Optional browser-session marker used to verify bound state.

    Returns:
        User info dict or None if authentication failed

    Examples:
        Happy-path with patched exchanges and user info:
        >>> import asyncio
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> svc = SSOService(MagicMock())
        >>> # Mock DB auth session lookup
        >>> provider = SimpleNamespace(id='github', is_enabled=True, provider_type='oauth2')
        >>> auth_session = SimpleNamespace(provider_id='github', state='st', provider=provider, is_expired=False, nonce=None)
        >>> svc.db.execute.return_value.scalar_one_or_none.return_value = auth_session
        >>> # Patch token exchange and user info retrieval
        >>> async def _ex(p, sess, c):
        ...     return {'access_token': 'tok', 'id_token': 'id_tok'}
        >>> async def _ui(p, access, token_data=None, expected_nonce=None):
        ...     return {'email': 'user@example.com'}
        >>> svc._exchange_code_for_tokens = _ex
        >>> svc._get_user_info = _ui
        >>> svc.db.delete = lambda obj: None
        >>> svc.db.commit = lambda: None
        >>> out = asyncio.run(svc.handle_oauth_callback('github', 'code', 'st'))
        >>> out['email']
        'user@example.com'

        Early return cases:
        >>> # No session
        >>> svc2 = SSOService(MagicMock())
        >>> svc2.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> asyncio.run(svc2.handle_oauth_callback('github', 'c', 's')) is None
        True
        >>> # Expired session
        >>> expired = SimpleNamespace(provider_id='github', state='st', provider=SimpleNamespace(is_enabled=True), is_expired=True)
        >>> svc3 = SSOService(MagicMock())
        >>> svc3.db.execute.return_value.scalar_one_or_none.return_value = expired
        >>> asyncio.run(svc3.handle_oauth_callback('github', 'c', 'st')) is None
        True
        >>> # Disabled provider
        >>> disabled = SimpleNamespace(provider_id='github', state='st', provider=SimpleNamespace(is_enabled=False), is_expired=False)
        >>> svc4 = SSOService(MagicMock())
        >>> svc4.db.execute.return_value.scalar_one_or_none.return_value = disabled
        >>> asyncio.run(svc4.handle_oauth_callback('github', 'c', 'st')) is None
        True
    """
    outcome = await self.handle_oauth_callback_with_tokens(provider_id, code, state, session_binding=session_binding)
    if outcome:
        # The raw token payload is intentionally dropped here.
        profile, _raw_tokens = outcome
        return profile
    return None
async def handle_oauth_callback_with_tokens(
    self,
    provider_id: str,
    code: str,
    state: str,
    session_binding: Optional[str] = None,
) -> Optional[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """Process the OAuth callback, yielding user info together with the token response.

    The stored auth session is looked up by state and provider, checked for
    expiry and (for session-bound states) for a matching browser binding. The
    authorization code is then exchanged for tokens; OIDC providers must return
    an ``id_token`` that verifies against the session nonce before user info is
    requested. The auth session is deleted on success, on expiry, and when an
    exception is raised during the exchange.

    Args:
        provider_id: Provider identifier
        code: Authorization code from callback
        state: CSRF state parameter
        session_binding: Optional browser-session marker used to verify bound state.

    Returns:
        Tuple of (user_info, token_data) or None if authentication fails
    """
    # Locate the auth session created when the authorization URL was issued.
    lookup = select(SSOAuthSession).where(SSOAuthSession.state == state, SSOAuthSession.provider_id == provider_id)
    pending_session = self.db.execute(lookup).scalar_one_or_none()

    if not pending_session:
        logger.warning(f"OAuth callback: no auth session found for state/provider {provider_id}. Possible CSRF or replay.")
        return None

    if pending_session.is_expired:
        logger.warning(f"OAuth callback: auth session expired for provider {provider_id}.")
        self.db.delete(pending_session)
        self.db.commit()
        return None

    # Session-bound states must match the marker of the calling browser session.
    if self._is_session_bound_state(state) and (not session_binding or not self._verify_session_bound_state(provider_id, state, session_binding)):
        logger.warning(
            "OAuth callback: state/session mismatch for provider %s. Possible CSRF or cross-session replay.",
            provider_id,
        )
        return None

    idp = pending_session.provider
    if not idp:
        logger.error(f"OAuth callback: provider '{provider_id}' not found for auth session.")
        return None

    if not idp.is_enabled:
        logger.warning(f"OAuth callback: provider '{provider_id}' is disabled.")
        return None

    try:
        # Trade the authorization code for tokens.
        logger.info(f"Starting token exchange for provider {provider_id}")
        tokens = await self._exchange_code_for_tokens(idp, pending_session, code)
        if not tokens:
            logger.error(f"Failed to exchange code for tokens for provider {provider_id}")
            return None
        logger.info(f"Token exchange successful for provider {provider_id}")
        session_nonce = getattr(pending_session, "nonce", None)

        # OIDC: the id_token has to verify before any of its claims are used.
        if idp.provider_type == "oidc":
            if session_nonce is None:
                logger.error("OAuth callback: missing nonce for OIDC provider %s.", provider_id)
                return None
            raw_id_token = tokens.get("id_token")
            if not isinstance(raw_id_token, str):
                logger.error("OAuth callback: missing id_token for OIDC provider %s.", provider_id)
                return None
            claims = await self._verify_oidc_id_token(idp, raw_id_token, expected_nonce=session_nonce)
            if claims is None:
                logger.error(f"id_token verification failed for provider {provider_id}")
                return None
            # Work on a copy so the verified claims travel with the token data.
            tokens = dict(tokens)
            tokens["_verified_id_token_claims"] = claims

        # Fetch the user profile (the full token data is passed for id_token parsing).
        profile = await self._get_user_info(idp, tokens["access_token"], tokens, expected_nonce=session_nonce)
        if not profile:
            logger.error(f"Failed to get user info for provider {provider_id}")
            return None

        # The auth session is single-use: remove it once the flow succeeded.
        self.db.delete(pending_session)
        self.db.commit()

        return profile, tokens

    except Exception as e:
        # Drop the auth session when the exchange blew up.
        logger.error(f"OAuth callback failed for provider {provider_id}: {type(e).__name__}: {str(e)}")
        logger.exception("Full traceback for OAuth callback failure:")
        self.db.delete(pending_session)
        self.db.commit()
        return None
async def _exchange_code_for_tokens(self, provider: SSOProvider, auth_session: SSOAuthSession, code: str) -> Optional[Dict[str, Any]]:
    """Redeem an authorization code at the provider's token endpoint.

    Posts a form-encoded ``authorization_code`` grant that includes the
    decrypted client secret and the PKCE code verifier stored on the auth
    session.

    Args:
        provider: SSO provider configuration
        auth_session: Auth session with PKCE parameters
        code: Authorization code

    Returns:
        Token response dict or None if failed
    """
    client_secret = await self._decrypt_secret(provider.client_secret_encrypted)
    form_fields = {
        "client_id": provider.client_id,
        "client_secret": client_secret,
        "code": code,
        "grant_type": "authorization_code",
        "redirect_uri": auth_session.redirect_uri,
        "code_verifier": auth_session.code_verifier,
    }

    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    http = await get_http_client()
    reply = await http.post(provider.token_url, data=form_fields, headers={"Accept": "application/json"})

    if reply.status_code != 200:
        logger.error(f"Token exchange failed for {provider.name}: HTTP {reply.status_code} - {reply.text}")
        return None

    return reply.json()
async def _get_user_info(self, provider: SSOProvider, access_token: str, token_data: Optional[Dict[str, Any]] = None, expected_nonce: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Get user information from provider using access token.

    Requests the provider's userinfo endpoint with the access token as a bearer
    credential. On HTTP 200 the payload is enriched with provider-specific data
    (GitHub organizations, Entra ID groups/roles, Keycloak role claims) and
    passed through ``_normalize_user_info``. For Keycloak only, an HTTP 401
    combined with a split-host configuration falls back to verified id_token
    claims.

    Args:
        provider: SSO provider configuration
        access_token: OAuth access token
        token_data: Optional full token response containing id_token for OIDC providers
        expected_nonce: Nonce bound to the current auth session for OIDC id_token verification

    Returns:
        User info dict or None if failed
    """
    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    client = await get_http_client()
    # Prefer claims already verified by the caller; otherwise verify the raw
    # id_token here, but only when a nonce is available to check it against.
    verified_id_token_claims: Optional[Dict[str, Any]] = None
    if token_data and isinstance(token_data.get("_verified_id_token_claims"), dict):
        verified_id_token_claims = token_data.get("_verified_id_token_claims")
    elif provider.provider_type == "oidc" and token_data and isinstance(token_data.get("id_token"), str):
        if expected_nonce is None:
            logger.warning("Skipping OIDC id_token fallback verification for provider %s because expected nonce is unavailable.", provider.id)
        else:
            verified_id_token_claims = await self._verify_oidc_id_token(provider, token_data["id_token"], expected_nonce=expected_nonce)

    # Keycloak-only view of the verified claims, used for claim merging and the 401 fallback below.
    keycloak_id_token_claims: Optional[Dict[str, Any]] = None
    if provider.id == "keycloak" and verified_id_token_claims:
        keycloak_id_token_claims = verified_id_token_claims

    response = await client.get(provider.userinfo_url, headers={"Authorization": f"Bearer {access_token}"})

    if response.status_code == 200:
        user_data = response.json()

        # For GitHub, also fetch organizations if admin assignment is configured
        if provider.id == "github" and settings.sso_github_admin_orgs:
            # Best effort: any failure leaves the user with an empty organization list.
            try:
                orgs_response = await client.get("https://api.github.com/user/orgs", headers={"Authorization": f"Bearer {access_token}"})
                if orgs_response.status_code == 200:
                    orgs_data = orgs_response.json()
                    user_data["organizations"] = [org["login"] for org in orgs_data]
                else:
                    logger.warning(f"Failed to fetch GitHub organizations: HTTP {orgs_response.status_code}")
                    user_data["organizations"] = []
            except Exception as e:
                logger.warning(f"Error fetching GitHub organizations: {e}")
                user_data["organizations"] = []

        # For Entra ID, extract groups/roles from id_token since userinfo doesn't include them
        # Microsoft's /oidc/userinfo endpoint only returns basic claims (sub, name, email, picture)
        # Groups and roles are included in the id_token when configured in Azure Portal
        if provider.id == "entra" and verified_id_token_claims:
            id_token_claims = verified_id_token_claims
            if id_token_claims:
                # Stays None unless the Graph fallback below returns a group list.
                entra_groups_from_graph: Optional[List[str]] = None
                # Detect group overage - when user has too many groups (>200), EntraID can return
                # overage markers (e.g. _claim_names/_claim_sources, hasgroups, groups:srcN)
                # instead of an inline groups array.
                # See: https://learn.microsoft.com/en-us/entra/identity-platform/id-token-claims-reference
                claim_names = id_token_claims.get("_claim_names", {})
                has_groups_src_key = any(isinstance(key, str) and key.startswith("groups:src") for key in id_token_claims)
                groups_claim_value = id_token_claims.get("groups")
                has_group_overage = (
                    (isinstance(claim_names, dict) and "groups" in claim_names) or bool(id_token_claims.get("hasgroups")) or has_groups_src_key or isinstance(groups_claim_value, str)
                )
                if has_group_overage:
                    user_email = user_data.get("email") or user_data.get("preferred_username") or "unknown"
                    logger.warning(
                        "Group overage detected for user %s - token contains too many groups (>200). Attempting Microsoft Graph fallback to resolve complete group membership.",
                        user_email,
                    )
                    entra_groups_from_graph = await self._fetch_entra_groups_from_graph_api(access_token, user_email, provider.provider_metadata)
                    if entra_groups_from_graph is None:
                        logger.warning("Proceeding without Graph-resolved Entra groups for user %s", user_email)

                # Extract groups from id_token (Security Groups as Object IDs)
                # Graph-resolved groups take precedence over the inline claim.
                if entra_groups_from_graph is not None:
                    user_data["groups"] = entra_groups_from_graph
                elif "groups" in id_token_claims:
                    user_data["groups"] = id_token_claims["groups"]
                    logger.debug(f"Extracted {len(id_token_claims['groups'])} groups from Entra ID token")

                # Extract roles from id_token (App Roles)
                if "roles" in id_token_claims:
                    user_data["roles"] = id_token_claims["roles"]
                    logger.debug(f"Extracted {len(id_token_claims['roles'])} roles from Entra ID token")

                # Also extract any missing basic claims from id_token
                # (values already present in the userinfo response are kept).
                for claim in ["email", "name", "preferred_username", "oid", "sub"]:
                    if claim not in user_data and claim in id_token_claims:
                        user_data[claim] = id_token_claims[claim]

        # For Keycloak, also extract groups/roles from id_token if available
        if provider.id == "keycloak" and keycloak_id_token_claims:
            # Keycloak includes realm_access, resource_access, and groups in id_token
            for claim in ["realm_access", "resource_access", "groups"]:
                if claim in keycloak_id_token_claims and claim not in user_data:
                    user_data[claim] = keycloak_id_token_claims[claim]

        # Normalize user info across providers
        return self._normalize_user_info(provider, user_data)

    # Keycloak can issue tokens using the browser-facing issuer URL; if userinfo
    # is called on a different host/port, Keycloak may reject the token with 401.
    # Only fall back to id_token claims for 401 with split-host configuration.
    # Other errors (403=revoked, 500=server error) must NOT fall back — the user
    # should be denied access, not silently authenticated via stale id_token claims.
    if provider.id == "keycloak" and keycloak_id_token_claims and response.status_code == 401:
        metadata = provider.provider_metadata or {}
        public_base_url = metadata.get("public_base_url")
        # "Split-host" here means a public_base_url is set and differs from base_url.
        if public_base_url and public_base_url != metadata.get("base_url"):
            logger.warning(
                "User info request returned 401 for keycloak with split-host config (public=%s, internal=%s). Falling back to id_token claims.",
                public_base_url,
                metadata.get("base_url"),
            )
            return self._normalize_user_info(provider, keycloak_id_token_claims)

    logger.error(f"User info request failed for {provider.name}: HTTP {response.status_code} - {response.text}")

    return None
1266 def _normalize_user_info(self, provider: SSOProvider, user_data: Dict[str, Any]) -> Dict[str, Any]:
1267 """Normalize user info from different providers to common format.
1269 Args:
1270 provider: SSO provider configuration
1271 user_data: Raw user data from provider
1273 Returns:
1274 Normalized user info dict
1275 """
1276 # Handle GitHub provider
1277 if provider.id == "github":
1278 normalized = {
1279 "email": user_data.get("email"),
1280 "full_name": user_data.get("name") or user_data.get("login"),
1281 "avatar_url": user_data.get("avatar_url"),
1282 "provider_id": user_data.get("id"),
1283 "username": user_data.get("login"),
1284 "provider": "github",
1285 "organizations": user_data.get("organizations", []),
1286 }
1287 # GitHub /user responses do not include an email_verified claim. Preserve
1288 # backward-compatible login behavior by only enforcing verification when
1289 # a concrete claim value is present.
1290 github_email_verified = user_data.get("email_verified")
1291 if github_email_verified is None:
1292 github_email_verified = user_data.get("verified")
1293 if github_email_verified is not None:
1294 normalized["email_verified"] = github_email_verified
1295 return normalized
1297 # Handle Google provider
1298 if provider.id == "google":
1299 return {
1300 "email": user_data.get("email"),
1301 "email_verified": user_data.get("email_verified"),
1302 "full_name": user_data.get("name"),
1303 "avatar_url": user_data.get("picture"),
1304 "provider_id": user_data.get("sub"),
1305 "username": user_data.get("email", "").split("@")[0],
1306 "provider": "google",
1307 }
1309 # Handle IBM Verify provider
1310 if provider.id == "ibm_verify":
1311 return {
1312 "email": user_data.get("email"),
1313 "email_verified": user_data.get("email_verified"),
1314 "full_name": user_data.get("name"),
1315 "avatar_url": user_data.get("picture"),
1316 "provider_id": user_data.get("sub"),
1317 "username": user_data.get("preferred_username") or user_data.get("email", "").split("@")[0],
1318 "provider": "ibm_verify",
1319 }
1321 # Handle Okta provider
1322 if provider.id == "okta":
1323 return {
1324 "email": user_data.get("email"),
1325 "email_verified": user_data.get("email_verified"),
1326 "full_name": user_data.get("name"),
1327 "avatar_url": user_data.get("picture"),
1328 "provider_id": user_data.get("sub"),
1329 "username": user_data.get("preferred_username") or user_data.get("email", "").split("@")[0],
1330 "provider": "okta",
1331 }
1333 # Handle Keycloak provider with role mapping
1334 if provider.id == "keycloak":
1335 metadata = provider.provider_metadata or {}
1336 username_claim = metadata.get("username_claim", "preferred_username")
1337 email_claim = metadata.get("email_claim", "email")
1338 groups_claim = metadata.get("groups_claim", "groups")
1340 groups = []
1342 # Extract realm roles
1343 if metadata.get("map_realm_roles"):
1344 realm_access = user_data.get("realm_access", {})
1345 realm_roles = realm_access.get("roles", [])
1346 groups.extend(realm_roles)
1348 # Extract client roles
1349 if metadata.get("map_client_roles"):
1350 resource_access = user_data.get("resource_access", {})
1351 for client, access in resource_access.items():
1352 client_roles = access.get("roles", [])
1353 # Prefix with client name to avoid conflicts
1354 groups.extend([f"{client}:{role}" for role in client_roles])
1356 # Extract groups from custom claim
1357 if groups_claim in user_data:
1358 custom_groups = user_data.get(groups_claim, [])
1359 if isinstance(custom_groups, list):
1360 groups.extend(custom_groups)
1362 return {
1363 "email": user_data.get(email_claim),
1364 "email_verified": user_data.get("email_verified"),
1365 "full_name": user_data.get("name"),
1366 "avatar_url": user_data.get("picture"),
1367 "provider_id": user_data.get("sub"),
1368 "username": user_data.get(username_claim) or user_data.get(email_claim, "").split("@")[0],
1369 "provider": "keycloak",
1370 "groups": list(set(groups)), # Deduplicate
1371 }
1373 # Handle Microsoft Entra ID provider with role mapping
1374 if provider.id == "entra":
1375 metadata = provider.provider_metadata or {}
1376 groups_claim = metadata.get("groups_claim", "groups")
1378 # Microsoft's userinfo endpoint often omits the email claim
1379 # Fallback: preferred_username (UPN) or upn claim
1380 email = user_data.get("email") or user_data.get("preferred_username") or user_data.get("upn")
1382 # Extract username from email/UPN
1383 username = None
1384 if user_data.get("preferred_username"):
1385 username = user_data.get("preferred_username")
1386 elif email:
1387 username = email.split("@")[0]
1389 # Extract groups from token
1390 groups = []
1392 # Check configured groups claim (default: 'groups')
1393 if groups_claim in user_data:
1394 groups_value = user_data.get(groups_claim, [])
1395 if isinstance(groups_value, list):
1396 groups.extend(groups_value)
1398 # Also check 'roles' claim for App Role assignments
1399 if "roles" in user_data:
1400 roles_value = user_data.get("roles", [])
1401 if isinstance(roles_value, list):
1402 groups.extend(roles_value)
1404 return {
1405 "email": email,
1406 "email_verified": user_data.get("email_verified"),
1407 "full_name": user_data.get("name") or email, # Fallback to email if name missing
1408 "avatar_url": user_data.get("picture"),
1409 "provider_id": user_data.get("sub") or user_data.get("oid"),
1410 "username": username,
1411 "provider": "entra",
1412 "groups": list(set(groups)), # Deduplicate
1413 }
1415 # Generic OIDC format for all other providers
1416 return {
1417 "email": user_data.get("email"),
1418 "email_verified": user_data.get("email_verified"),
1419 "full_name": user_data.get("name"),
1420 "avatar_url": user_data.get("picture"),
1421 "provider_id": user_data.get("sub"),
1422 "username": user_data.get("preferred_username") or user_data.get("email", "").split("@")[0],
1423 "provider": provider.id,
1424 }
async def authenticate_or_create_user(self, user_info: Dict[str, Any]) -> Optional[str]:
    """Authenticate existing user or create new user from SSO info.

    The email must be present and carry an explicit verified claim, and its
    domain must be allowed by the provider's ``trusted_domains`` when that list
    is configured. Existing users are updated (name, provider, verification,
    last login, SSO-granted admin flag) and have roles/teams synchronized. New
    users are created only when the provider enables auto-creation and, if
    ``settings.sso_require_admin_approval`` is set, after an approved pending
    request exists.

    Args:
        user_info: Normalized user info from SSO provider

    Returns:
        JWT token for authenticated user or None if failed
    """
    raw_email = user_info.get("email")
    if not raw_email:
        logger.warning("SSO authenticate_or_create_user: no email in user_info from provider '%s'. User cannot be authenticated without an email address.", user_info.get("provider", "unknown"))
        return None

    email = str(raw_email).strip().lower()
    if not email:
        logger.warning("SSO authenticate_or_create_user: email is empty after normalization from provider '%s'.", user_info.get("provider", "unknown"))
        return None

    if not self._is_email_verified_claim(user_info):
        logger.warning(
            "SSO authenticate_or_create_user: unverified email claim for provider '%s' and email '%s'.",
            user_info.get("provider", "unknown"),
            email,
        )
        return None

    incoming_provider = str(user_info.get("provider", "sso")).strip().lower() or "sso"
    provider = self.get_provider(incoming_provider)

    # Enforce trusted-domain policy consistently for both existing and new users.
    trusted_domains = getattr(provider, "trusted_domains", None) if provider else None
    if trusted_domains:
        # The domain is everything after the LAST "@": taking the second segment
        # of a plain split would let "a@trusted.com@evil.example" pass the check.
        domain = email.rsplit("@", 1)[-1].lower() if "@" in email else ""
        normalized_trusted_domains = [d.lower() for d in trusted_domains if isinstance(d, str)]
        if domain not in normalized_trusted_domains:
            logger.warning(
                "SSO authenticate_or_create_user: email domain '%s' is not allowed for provider '%s'.",
                domain,
                incoming_provider,
            )
            return None

    # Use stable local values for JWT payload generation to avoid lazy-loading
    # expired ORM attributes after commit/flush boundaries.
    resolved_email = email
    resolved_full_name = user_info.get("full_name", email)
    resolved_auth_provider = incoming_provider
    resolved_is_admin = False

    # Check if user exists
    user = await self.auth_service.get_user_by_email(email)

    if user:
        current_full_name = user.full_name or resolved_full_name
        current_auth_provider = str(user.auth_provider or resolved_auth_provider).strip().lower()
        current_is_admin = bool(user.is_admin)
        current_admin_origin = user.admin_origin

        # An account owned by a different provider is never linked implicitly.
        if user.auth_provider and current_auth_provider != incoming_provider:
            logger.warning(
                "SSO authenticate_or_create_user: account-linking required for email '%s' (existing provider='%s', incoming='%s').",
                email,
                current_auth_provider,
                incoming_provider,
            )
            return None

        provider_id: Optional[str] = None
        provider_metadata: Dict[str, Any] = {}
        provider_ctx: Optional[Any] = None
        if provider:
            provider_id = provider.id
            provider_metadata = provider.provider_metadata or {}
            provider_ctx = SSOProviderContext(id=provider_id, provider_metadata=provider_metadata)

        # Update user info from SSO
        if user_info.get("full_name") and user_info["full_name"] != current_full_name:
            user.full_name = user_info["full_name"]
            current_full_name = user_info["full_name"]

        # Initialize auth_provider for legacy accounts without provider metadata.
        if not user.auth_provider:
            user.auth_provider = incoming_provider
            current_auth_provider = incoming_provider

        # Persist verification status from provider claims.
        user.email_verified = self._is_email_verified_claim(user_info)
        user.last_login = utc_now()

        # Synchronize is_admin status based on current group membership
        # Track origin to support both promotion AND demotion for SSO-granted admins
        # Manual/API grants are "sticky" - never auto-demoted by SSO
        # Only users with admin_origin="sso" can be demoted on login
        if provider_ctx:
            should_be_admin = self._should_user_be_admin(email, user_info, provider_ctx)
            if should_be_admin:
                # Grant admin access
                if not current_is_admin:
                    logger.info(f"Upgrading is_admin to True for {email} based on SSO admin groups")
                    user.is_admin = True
                    # Track that admin was granted via SSO (only set on initial grant)
                    user.admin_origin = "sso"
                    current_is_admin = True
                # Do NOT change admin_origin if already admin - preserve manual/API grants
            elif current_is_admin and current_admin_origin == "sso":
                # User was SSO admin but no longer in admin groups - revoke access
                logger.info(f"Revoking is_admin for {email} - removed from SSO admin groups")
                user.is_admin = False
                user.admin_origin = None
                current_is_admin = False

        self.db.commit()

        # Determine if syncing should happen (default True, respect provider-level and Entra setting)
        should_sync = True
        if provider_ctx:
            # Check provider-level sync_roles flag in provider_metadata (allows disabling per-provider)
            if "sync_roles" in provider_metadata:
                should_sync = provider_metadata.get("sync_roles", True)
            # Legacy Entra-specific setting (fallback for backwards compatibility)
            elif provider_id == "entra" and hasattr(settings, "sso_entra_sync_roles_on_login"):
                should_sync = settings.sso_entra_sync_roles_on_login

        if provider_ctx and should_sync:
            role_assignments = await self._map_groups_to_roles(email, user_info.get("groups", []), provider_ctx)
            await self._sync_user_roles(email, role_assignments, provider_ctx)
        # NOTE(review): nesting of this call was reconstructed from a flattened
        # listing; confirm it is meant to run independently of role syncing.
        await self._apply_team_mapping(email, user_info, provider)

        user_email = getattr(user, "email", None)
        if isinstance(user_email, str) and user_email.strip():
            resolved_email = user_email.strip().lower()
        resolved_full_name = current_full_name
        resolved_auth_provider = current_auth_provider
        resolved_is_admin = current_is_admin
    else:
        # Auto-create user if enabled
        if not provider or not provider.auto_create_users:
            return None

        provider_id = provider.id
        provider_metadata = provider.provider_metadata or {}
        provider_ctx = SSOProviderContext(id=provider_id, provider_metadata=provider_metadata)

        # Check if admin approval is required
        if settings.sso_require_admin_approval:
            # Check if user is already pending approval

            pending = self.db.execute(select(PendingUserApproval).where(PendingUserApproval.email == email)).scalar_one_or_none()

            if pending:
                if pending.status == "pending":
                    if pending.is_expired():
                        # Record the expiry first, then reopen the request with a fresh window.
                        pending.status = "expired"
                        self.db.commit()

                        pending.status = "pending"
                        pending.requested_at = utc_now()
                        pending.expires_at = utc_now() + timedelta(days=30)
                        pending.auth_provider = incoming_provider
                        pending.sso_metadata = user_info
                        pending.approved_by = None
                        pending.approved_at = None
                        pending.rejection_reason = None
                        pending.admin_notes = None
                        self.db.commit()
                        logger.info(f"Refreshed expired pending approval request for SSO user: {email}")
                        return None
                    return None  # Still waiting for approval
                if pending.status == "rejected":
                    return None  # User was rejected
                if pending.status == "approved":
                    # A still-valid approval falls through to user creation below.
                    if pending.is_expired():
                        pending.status = "expired"
                        self.db.commit()
                        return None
                elif pending.status == "expired":
                    pending.status = "pending"
                    pending.requested_at = utc_now()
                    pending.expires_at = utc_now() + timedelta(days=30)
                    pending.auth_provider = incoming_provider
                    pending.sso_metadata = user_info
                    pending.approved_by = None
                    pending.approved_at = None
                    pending.rejection_reason = None
                    pending.admin_notes = None
                    self.db.commit()
                    logger.info(f"Renewed expired pending approval request for SSO user: {email}")
                    return None
                elif pending.status in {"completed"}:
                    return None
                elif pending.status != "approved":
                    logger.warning(f"Unknown SSO pending approval status '{pending.status}' for user {email}. Denying by default.")
                    return None
            else:
                # Create pending approval request
                pending = PendingUserApproval(
                    email=email,
                    full_name=user_info.get("full_name", email),
                    auth_provider=incoming_provider,
                    sso_metadata=user_info,
                    expires_at=utc_now() + timedelta(days=30),  # 30-day approval window
                )
                self.db.add(pending)
                self.db.commit()
                logger.info(f"Created pending approval request for SSO user: {email}")
                return None  # No token until approved

        # Create new user (either no approval required, or approval already granted)
        # Generate a secure random password for SSO users (they won't use it)

        random_password = "".join(secrets.choice(string.ascii_letters + string.digits + string.punctuation) for _ in range(32))

        # Determine if user should be admin based on domain/organization
        is_admin = self._should_user_be_admin(email, user_info, provider_ctx)

        user = await self.auth_service.create_user(
            email=email,
            password=random_password,  # Random password for SSO users (not used)
            full_name=user_info.get("full_name", email),
            is_admin=is_admin,
            auth_provider=incoming_provider,
        )
        if not user:
            return None

        user_email = getattr(user, "email", None)
        if isinstance(user_email, str) and user_email.strip():
            resolved_email = user_email.strip().lower()
        resolved_full_name = user_info.get("full_name", email)
        resolved_auth_provider = incoming_provider
        resolved_is_admin = is_admin

        # Assign RBAC roles based on SSO groups (or default role if no groups)
        # Check provider-level sync_roles flag in provider_metadata
        should_sync = provider_metadata.get("sync_roles", True)
        # Legacy Entra-specific setting (fallback for backwards compatibility)
        if "sync_roles" not in provider_metadata and provider_id == "entra" and hasattr(settings, "sso_entra_sync_roles_on_login"):
            should_sync = settings.sso_entra_sync_roles_on_login

        if should_sync:
            role_assignments = await self._map_groups_to_roles(email, user_info.get("groups", []), provider_ctx)
            if role_assignments:
                await self._sync_user_roles(email, role_assignments, provider_ctx)
        # NOTE(review): nesting of this call was reconstructed from a flattened
        # listing; confirm it is meant to run independently of role syncing.
        await self._apply_team_mapping(email, user_info, provider)

        # If user was created from approved request, mark request as used
        if settings.sso_require_admin_approval:
            pending = self.db.execute(select(PendingUserApproval).where(and_(PendingUserApproval.email == email, PendingUserApproval.status == "approved"))).scalar_one_or_none()
            if pending:
                # Mark as used (we could delete or keep for audit trail)
                pending.status = "completed"
                # NOTE(review): commit placed inside the guard; nesting reconstructed - confirm.
                self.db.commit()

    # Generate JWT token for user — session token (teams resolved server-side)
    token_data = {
        "sub": resolved_email,
        "email": resolved_email,
        "full_name": resolved_full_name,
        "auth_provider": resolved_auth_provider,
        "iat": int(utc_now().timestamp()),
        "user": {
            "email": resolved_email,
            "full_name": resolved_full_name,
            "is_admin": resolved_is_admin,
            "auth_provider": resolved_auth_provider,
        },
        "token_use": "session",  # nosec B105 - token type marker, not a password
        # Scopes
        "scopes": {"server_id": None, "permissions": ["*"] if resolved_is_admin else [], "ip_restrictions": [], "time_restrictions": {}},
    }

    # Create JWT token
    token = await create_jwt_token(token_data)
    return token
def _should_user_be_admin(self, email: str, user_info: Dict[str, Any], provider: SSOProviderContext) -> bool:
    """Determine if SSO user should be granted admin privileges.

    Admin is granted when any one of these holds (all comparisons are
    case-insensitive):

    * the email domain is listed in ``settings.sso_auto_admin_domains``;
    * the provider is ``github`` and one of ``user_info["organizations"]``
      is listed in ``settings.sso_github_admin_orgs``;
    * the provider is ``google`` and the email domain is listed in
      ``settings.sso_google_admin_domains``;
    * the provider is ``entra`` and one of ``user_info["groups"]`` is listed
      in ``settings.sso_entra_admin_groups``.

    Args:
        email: User's email address
        user_info: Normalized user info from SSO provider
        provider: SSO provider configuration

    Returns:
        True if user should be admin, False otherwise
    """

    def _lowered(values: Any) -> set:
        """Return the lower-cased string members of ``values`` as a set.

        Args:
            values: Iterable of candidate names, or None.

        Returns:
            set: Lower-cased strings; non-string entries are ignored.
        """
        return {value.lower() for value in (values or []) if isinstance(value, str)}

    # The domain is whatever follows the LAST "@": the local part of an address may
    # itself contain "@", and an address without "@" has no domain at all
    # (previously split("@")[1] raised IndexError or picked the wrong segment).
    domain = ""
    if isinstance(email, str):
        _, separator, tail = email.rpartition("@")
        if separator:
            domain = tail.lower()

    # Check domain-based admin assignment (an empty domain never matches)
    if domain and domain in _lowered(settings.sso_auto_admin_domains):
        return True

    # Check provider-specific admin assignment
    if provider.id == "github" and settings.sso_github_admin_orgs:
        # For GitHub, we'd need to fetch user's organizations
        # This is a placeholder - in production, you'd make API calls to get orgs
        github_orgs = _lowered(user_info.get("organizations"))
        if not github_orgs.isdisjoint(_lowered(settings.sso_github_admin_orgs)):
            return True

    if provider.id == "google" and settings.sso_google_admin_domains:
        # Check if user's domain is in admin domains
        if domain and domain in _lowered(settings.sso_google_admin_domains):
            return True

    # Check EntraID admin groups
    if provider.id == "entra" and settings.sso_entra_admin_groups:
        user_groups = _lowered(user_info.get("groups"))
        if not user_groups.isdisjoint(_lowered(settings.sso_entra_admin_groups)):
            return True

    return False
async def _map_groups_to_roles(self, user_email: str, user_groups: List[str], provider: SSOProviderContext) -> List[Dict[str, Any]]:
    """Map SSO groups to ContextForge RBAC roles.

    Resolution order:

    1. For the ``entra`` provider, membership in any of
       ``settings.sso_entra_admin_groups`` (case-insensitive) yields the
       global ``settings.default_admin_role``.
    2. Every group found in the provider's ``role_mappings`` yields its mapped
       role. The names ``"admin"`` and ``settings.default_admin_role`` map to
       the global admin role; other names are looked up by name, team scope
       first and then global scope.
    3. If nothing was assigned, the provider's ``default_role`` is used.

    For ``entra``, ``settings.sso_entra_role_mappings`` and
    ``settings.sso_entra_default_role`` act as fallbacks when the provider
    metadata does not define them. The returned list never contains two
    identical assignments.

    Args:
        user_email: User's email address
        user_groups: List of groups from SSO provider
        provider: SSO provider configuration

    Returns:
        List of role assignments: [{"role_name": str, "scope": str, "scope_id": Optional[str]}]
    """
    # pylint: disable=import-outside-toplevel
    # First-Party
    from mcpgateway.services.role_service import RoleService

    role_assignments: List[Dict[str, Any]] = []
    user_groups = user_groups or []

    # Generic Role Mapping Logic
    metadata = provider.provider_metadata or {}
    # "or {}" also covers an explicit null in the metadata, which .get(key, {}) would pass through as None
    role_mappings = metadata.get("role_mappings") or {}
    provider_default_role: Optional[str] = metadata.get("default_role")
    resolve_team_scope_to_personal_team = bool(metadata.get("resolve_team_scope_to_personal_team", False))
    has_provider_default_role = isinstance(provider_default_role, str) and bool(provider_default_role.strip())

    # Merge with legacy Entra specific settings if applicable
    has_entra_admin_groups = provider.id == "entra" and settings.sso_entra_admin_groups

    if provider.id == "entra":
        # Use generic role_mappings fallback to legacy setting
        if not role_mappings and settings.sso_entra_role_mappings:
            role_mappings = settings.sso_entra_role_mappings
        # Legacy fallback for default role configuration
        if not has_provider_default_role and settings.sso_entra_default_role:
            provider_default_role = settings.sso_entra_default_role
            has_provider_default_role = True

    # Early exit: Skip role mapping if no configuration exists
    if not role_mappings and not has_entra_admin_groups and not has_provider_default_role:
        logger.debug(f"No role mappings configured for provider {provider.id}, skipping role sync")
        return role_assignments

    def _add_assignment(role_name: str, scope: str, scope_id: Optional[str]) -> bool:
        """Append an assignment unless an identical one is already present.

        Args:
            role_name: Name of the role to assign.
            scope: Scope of the assignment.
            scope_id: Scope identifier, or None.

        Returns:
            bool: True when the assignment was added, False when it was a duplicate.
        """
        candidate = {"role_name": role_name, "scope": scope, "scope_id": scope_id}
        if candidate in role_assignments:
            return False
        role_assignments.append(candidate)
        return True

    personal_team_id: Optional[str] = None
    personal_team_checked = False

    async def _resolve_team_scope_id_if_needed(role_scope: str) -> Optional[str]:
        """Resolve team scope to personal-team id when provider mapping requires it.

        The lookup is attempted at most once per call of the enclosing method;
        later calls reuse the first outcome, including a failed one.

        Args:
            role_scope: Role scope value from mapping metadata.

        Returns:
            Optional[str]: Personal team id when resolution succeeds, else None.
        """
        nonlocal personal_team_id, personal_team_checked

        if role_scope != "team" or not resolve_team_scope_to_personal_team:
            return None

        if personal_team_checked:
            return personal_team_id

        personal_team_checked = True
        try:
            # First-Party
            from mcpgateway.services.personal_team_service import PersonalTeamService

            personal_team = await PersonalTeamService(self.db).get_personal_team(user_email)
            personal_team_id = personal_team.id if personal_team else None
            if not personal_team_id:
                logger.warning(f"Could not resolve personal team for {user_email}; skipping team-scoped SSO role mapping")
        except Exception as e:
            logger.error(f"Failed to resolve personal team for {user_email}: {e}. All team-scoped SSO role assignments will be skipped for this login.")
            personal_team_id = None

        return personal_team_id

    # Handle EntraID admin groups -> admin role
    if has_entra_admin_groups:
        admin_groups_lower = {g.lower() for g in settings.sso_entra_admin_groups if isinstance(g, str)}
        if any(isinstance(group, str) and group.lower() in admin_groups_lower for group in user_groups):
            _add_assignment(settings.default_admin_role, "global", None)
            logger.debug(f"Mapped EntraID admin group to {settings.default_admin_role} role for {user_email}")

    # Role names that are shorthand for the global admin role
    admin_role_names = ("admin", settings.default_admin_role)

    # Resolve the user's groups to mapped role names once, preserving group order
    mapped_role_names = [role_mappings[group] for group in user_groups if group in role_mappings]

    # Batch role lookups: collect all role names that need to be looked up
    role_names_to_lookup = {role_name for role_name in mapped_role_names if role_name not in admin_role_names}

    # Add default role to lookup if needed
    if has_provider_default_role and provider_default_role:
        role_names_to_lookup.add(provider_default_role)

    # Pre-fetch each distinct role once so the mapping loop below does no repeated lookups
    role_service = RoleService(self.db)
    role_cache: Dict[str, Any] = {}
    for role_name in role_names_to_lookup:
        # Try team scope first, then global
        role = await role_service.get_role_by_name(role_name, scope="team")
        if not role:
            role = await role_service.get_role_by_name(role_name, scope="global")
        if role:
            role_cache[role_name] = role

    # Process role mappings for ALL providers
    for role_name in mapped_role_names:
        # Special case for "admin" shorthand or configured admin role name
        if role_name in admin_role_names:
            # De-duplicated: the Entra admin-group step or an earlier group may already have granted it
            if _add_assignment(settings.default_admin_role, "global", None):
                logger.debug(f"Mapped group to {settings.default_admin_role} role for {user_email}")
            continue

        # Use pre-fetched role from cache
        role = role_cache.get(role_name)
        if not role:
            logger.warning(f"Role '{role_name}' not found for group mapping")
            continue

        scope_id = await _resolve_team_scope_id_if_needed(role.scope)
        if role.scope == "team" and resolve_team_scope_to_personal_team and not scope_id:
            # Team-scoped role requested but no personal team could be resolved
            continue
        if _add_assignment(role.name, role.scope, scope_id):
            logger.debug(f"Mapped group to role '{role.name}' for {user_email}")

    # Apply default role if no mappings found
    if not role_assignments and has_provider_default_role and provider_default_role:
        default_role = role_cache.get(provider_default_role)
        if default_role:
            scope_id = await _resolve_team_scope_id_if_needed(default_role.scope)
            if default_role.scope == "team" and resolve_team_scope_to_personal_team and not scope_id:
                return role_assignments
            role_assignments.append({"role_name": default_role.name, "scope": default_role.scope, "scope_id": scope_id})
            logger.info(f"Assigned default role '{default_role.name}' to {user_email}")

    return role_assignments
async def _sync_user_roles(self, user_email: str, role_assignments: List[Dict[str, Any]], _provider: SSOProviderContext) -> None:
    """Reconcile the user's SSO-granted roles with the requested assignments.

    Roles returned by ``list_user_roles(..., include_expired=False)`` whose
    ``grant_source`` is ``"sso"`` and which are not among ``role_assignments``
    are revoked. Each requested assignment that is missing or inactive is then
    granted with ``grant_source="sso"``. A failure while granting one role is
    logged and the session rolled back; the remaining assignments are still
    attempted unless the rollback itself fails.

    Args:
        user_email: User's email address
        role_assignments: List of role assignments to apply
        _provider: SSO provider configuration (reserved for future use)
    """
    # pylint: disable=import-outside-toplevel
    # First-Party
    from mcpgateway.services.role_service import RoleService

    roles = RoleService(self.db)

    # Only roles that were granted through SSO are candidates for revocation
    held_roles = await roles.list_user_roles(user_email, include_expired=False)
    sso_granted = [held for held in held_roles if getattr(held, "grant_source", None) == "sso"]

    # (role name, scope, scope id) triples the user should end up with
    wanted = {(item["role_name"], item["scope"], item.get("scope_id")) for item in role_assignments}

    # Revoke SSO roles that are no longer wanted
    for user_role in sso_granted:
        if (user_role.role.name, user_role.scope, user_role.scope_id) in wanted:
            continue
        await roles.revoke_role_from_user(user_email=user_email, role_id=user_role.role_id, scope=user_role.scope, scope_id=user_role.scope_id)
        logger.info(f"Revoked SSO role '{user_role.role.name}' from {user_email} (no longer in groups)")

    # Grant whatever is missing
    for assignment in role_assignments:
        try:
            role = await roles.get_role_by_name(assignment["role_name"], scope=assignment["scope"])
            if not role:
                logger.warning(f"Role '{assignment['role_name']}' not found, skipping assignment for {user_email}")
                continue

            current = await roles.get_user_role_assignment(user_email=user_email, role_id=role.id, scope=assignment["scope"], scope_id=assignment.get("scope_id"))
            if current and current.is_active:
                # Already granted and active: nothing to do
                continue

            await roles.assign_role_to_user(
                user_email=user_email, role_id=role.id, scope=assignment["scope"], scope_id=assignment.get("scope_id"), granted_by=user_email, grant_source="sso"
            )
            logger.info(f"Assigned SSO role '{role.name}' to {user_email}")

        except Exception as e:
            logger.warning(f"Failed to assign role '{assignment['role_name']}' to {user_email}: {e}", exc_info=True)
            # Roll back so the next assignment starts from a usable session
            try:
                self.db.rollback()
            except Exception as rollback_error:
                logger.error(f"Database rollback failed after role assignment error for {user_email}: {rollback_error}. Aborting remaining role assignments.")
                break
1935 break