Coverage for mcpgateway / services / sso_service.py: 99%
1024 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/sso_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Single Sign-On (SSO) authentication service for OAuth2 and OIDC providers.
8Handles provider management, OAuth flows, and user authentication.
9"""
11# Future
12from __future__ import annotations
14# Standard
15import asyncio
16import base64
17from dataclasses import dataclass
18from datetime import timedelta
19from enum import Enum
20import hashlib
21import hmac
22import logging
23import secrets
24import string
25from time import monotonic
26from typing import Any, Dict, List, Optional, Tuple, Union
27import urllib.parse
29# Third-Party
30import jwt
31import orjson
32from sqlalchemy import and_, select
33from sqlalchemy.orm import Session
35# First-Party
36from mcpgateway.common.validators import SecurityValidator
37from mcpgateway.config import settings
38from mcpgateway.db import PendingUserApproval, SSOAuthSession, SSOProvider, utc_now
39from mcpgateway.services.email_auth_service import EmailAuthService
40from mcpgateway.services.encryption_service import get_encryption_service
41from mcpgateway.utils.create_jwt_token import create_jwt_token
43# Logger
44logger = logging.getLogger(__name__)
46# Constants
47ADFS_PROVIDER_ID = "adfs"
class SSOError(Exception):
    """Common ancestor of the SSO-specific exceptions defined in this module."""
class SSOAuthenticationError(SSOError):
    """Signals that an SSO authentication attempt did not succeed."""
class SSOProviderConfigError(SSOError):
    """Signals that an SSO provider's configuration is missing pieces or is not valid."""
62class _Unset(Enum):
63 """Sentinel: distinguishes 'caller omitted the argument' from 'caller passed None'."""
65 UNSET = "UNSET"
68_UNSET = _Unset.UNSET
@dataclass
class SSOProviderContext:
    """Minimal bundle of SSO provider details handed to helper methods."""

    # Provider identifier; may be None.
    id: Optional[str]
    # Free-form provider metadata mapping.
    provider_metadata: Dict[str, Any]
79class SSOService:
80 """Service for managing SSO authentication flows and providers.
82 Handles OAuth2/OIDC authentication flows, provider configuration,
83 and integration with the local user system.
85 Examples:
86 Basic construction and helper checks:
87 >>> from unittest.mock import Mock
88 >>> service = SSOService(Mock())
89 >>> isinstance(service, SSOService)
90 True
91 >>> callable(service.list_enabled_providers)
92 True
93 """
95 _OIDC_METADATA_CACHE_TTL_SECONDS = 300
96 _oidc_config_cache: Dict[str, Tuple[float, Dict[str, Any]]] = {}
97 _jwks_client_cache: Dict[str, jwt.PyJWKClient] = {}
98 _STATE_BINDING_SEPARATOR = "."
99 _STATE_BINDING_HEX_LEN = 64
def __init__(self, db: Session):
    """Bind the service to a database session and build its collaborators.

    Creates an ``EmailAuthService`` on the same session and obtains the
    encryption service keyed by ``settings.auth_encryption_secret``.

    Args:
        db: SQLAlchemy database session
    """
    self.db = db
    self.auth_service = EmailAuthService(db)
    encryption_key = settings.auth_encryption_secret
    self._encryption = get_encryption_service(encryption_key)
111 async def _encrypt_secret(self, secret: str) -> str:
112 """Encrypt a client secret for secure storage.
114 Args:
115 secret: Plain text client secret
117 Returns:
118 Encrypted secret string
119 """
120 return await self._encryption.encrypt_secret_async(secret)
122 async def _decrypt_secret(self, encrypted_secret: str) -> Optional[str]:
123 """Decrypt a client secret for use.
125 Args:
126 encrypted_secret: Encrypted secret string
128 Returns:
129 Plain text client secret
130 """
131 decrypted: str | None = await self._encryption.decrypt_secret_async(encrypted_secret)
132 if decrypted:
133 return decrypted
135 return None
def _decode_jwt_claims(self, token: str) -> Optional[Dict[str, Any]]:
    """Decode JWT token payload without verification.

    This is used to extract claims from ID tokens where we've already
    validated the OAuth flow. The token signature is not verified here
    because the token was received directly from the trusted token endpoint.

    Args:
        token: JWT token string

    Returns:
        Decoded payload dict, or None if the token does not have three
        dot-separated parts, the payload cannot be base64/JSON decoded, or
        the decoded payload is not a JSON object.

    Examples:
        >>> from unittest.mock import Mock
        >>> service = SSOService(Mock())
        >>> # Valid JWT structure (header.payload.signature)
        >>> import base64
        >>> payload = base64.urlsafe_b64encode(b'{"sub":"123","groups":["admin"]}').decode().rstrip('=')
        >>> token = f"eyJhbGciOiJSUzI1NiJ9.{payload}.signature"
        >>> claims = service._decode_jwt_claims(token)
        >>> claims is not None
        True
    """
    try:
        # JWT compact form: header.payload.signature
        segments = token.split(".")
        if len(segments) != 3:
            logger.warning("Invalid JWT format: expected 3 parts")
            return None

        # base64url payloads travel unpadded; restore the '=' padding
        # (0-3 characters) before decoding.
        payload_b64 = segments[1]
        payload_b64 += "=" * (-len(payload_b64) % 4)

        claims = orjson.loads(base64.urlsafe_b64decode(payload_b64))

    except (ValueError, orjson.JSONDecodeError, UnicodeDecodeError) as e:
        logger.warning(f"Failed to decode JWT claims: {e}")
        return None

    # A payload can be valid JSON without being an object (e.g. a list or a
    # bare string); honour the declared return type instead of handing a
    # non-dict to callers.
    if not isinstance(claims, dict):
        logger.warning("Invalid JWT payload: expected a JSON object")
        return None
    return claims
182 async def _get_oidc_provider_metadata(self, issuer: str) -> Optional[Dict[str, Any]]:
183 """Discover and cache OIDC provider metadata.
185 Args:
186 issuer: OIDC issuer URL.
188 Returns:
189 Provider metadata dict from discovery endpoint, or None on failure.
190 """
191 normalized_issuer = issuer.rstrip("/")
192 cached = self._oidc_config_cache.get(normalized_issuer)
193 if cached is not None:
194 cached_at, cached_metadata = cached
195 if monotonic() - cached_at < self._OIDC_METADATA_CACHE_TTL_SECONDS:
196 return cached_metadata
197 self._oidc_config_cache.pop(normalized_issuer, None)
199 # First-Party
200 from mcpgateway.services.http_client_service import get_http_client # pylint: disable=import-outside-toplevel
202 discovery_url = f"{normalized_issuer}/.well-known/openid-configuration"
203 try:
204 client = await get_http_client()
205 response = await client.get(discovery_url, timeout=settings.oauth_request_timeout)
206 if response.status_code != 200:
207 logger.warning("OIDC discovery failed for issuer %s with HTTP %s", normalized_issuer, response.status_code)
208 return None
210 metadata = response.json()
211 if not isinstance(metadata, dict):
212 logger.warning("OIDC discovery response for issuer %s is not a JSON object", normalized_issuer)
213 return None
214 self._oidc_config_cache[normalized_issuer] = (monotonic(), metadata)
215 return metadata
216 except Exception as exc:
217 logger.warning("OIDC discovery request failed for issuer %s: %s", normalized_issuer, exc)
218 return None
220 async def _resolve_oidc_issuer_and_jwks(self, provider: SSOProvider) -> Tuple[Optional[str], Optional[str]]:
221 """Resolve issuer and JWKS URI for an OIDC provider.
223 Args:
224 provider: SSO provider configuration.
226 Returns:
227 Tuple of (issuer, jwks_uri), each optional when unavailable.
228 """
229 issuer = provider.issuer.strip() if isinstance(provider.issuer, str) and provider.issuer.strip() else None
230 jwks_uri = provider.jwks_uri.strip() if isinstance(provider.jwks_uri, str) and provider.jwks_uri.strip() else None
232 if issuer and not jwks_uri:
233 metadata = await self._get_oidc_provider_metadata(issuer)
234 if metadata:
235 discovered_jwks = metadata.get("jwks_uri")
236 discovered_issuer = metadata.get("issuer")
237 if isinstance(discovered_jwks, str) and discovered_jwks.strip():
238 jwks_uri = discovered_jwks.strip()
239 if isinstance(discovered_issuer, str) and discovered_issuer.strip():
240 issuer = discovered_issuer.strip()
242 return issuer, jwks_uri
244 def _get_jwks_client(self, jwks_uri: str) -> jwt.PyJWKClient:
245 """Get or create a cached PyJWKClient instance.
247 Args:
248 jwks_uri: JWKS endpoint URL.
250 Returns:
251 Cached or newly created `PyJWKClient`.
252 """
253 if jwks_uri not in self._jwks_client_cache:
254 self._jwks_client_cache[jwks_uri] = jwt.PyJWKClient(jwks_uri)
255 return self._jwks_client_cache[jwks_uri]
257 async def _verify_oidc_id_token(self, provider: SSOProvider, id_token: str, expected_nonce: Optional[str] = None) -> Optional[Dict[str, Any]]:
258 """Verify OIDC ID token signature and claims.
260 Args:
261 provider: SSO provider configuration.
262 id_token: Raw OIDC ID token string.
263 expected_nonce: Expected nonce from auth session, when available.
265 Returns:
266 Verified token claims when validation succeeds, otherwise None.
267 """
268 if provider.provider_type != "oidc":
269 return None
271 issuer, jwks_uri = await self._resolve_oidc_issuer_and_jwks(provider)
272 if not jwks_uri:
273 logger.warning("Skipping id_token claim usage for provider %s: missing jwks_uri", provider.id)
274 return None
276 try:
277 jwks_client = self._get_jwks_client(jwks_uri)
278 signing_key = await asyncio.to_thread(jwks_client.get_signing_key_from_jwt, id_token)
280 decode_kwargs: Dict[str, Any] = {
281 "key": signing_key.key,
282 "algorithms": ["RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512", "EdDSA"],
283 "audience": provider.client_id,
284 "options": {
285 "verify_signature": True,
286 "verify_exp": True,
287 "verify_iat": True,
288 "verify_aud": True,
289 "verify_iss": bool(issuer),
290 },
291 }
292 if issuer:
293 decode_kwargs["issuer"] = issuer
295 claims = await asyncio.to_thread(jwt.decode, id_token, **decode_kwargs)
296 if expected_nonce is not None and claims.get("nonce") != expected_nonce:
297 logger.warning("OIDC id_token nonce validation failed for provider %s", provider.id)
298 return None
299 return claims
300 except jwt.PyJWTError as exc:
301 logger.warning("OIDC id_token verification failed for provider %s: %s", provider.id, exc)
302 return None
303 except Exception as exc:
304 logger.warning("Unexpected OIDC id_token verification error for provider %s: %s", provider.id, exc)
305 return None
def _resolve_entra_graph_fallback_settings(self, provider_metadata: Optional[Dict[str, Any]]) -> Tuple[bool, int, int]:
    """Resolve Entra Graph fallback settings with provider metadata override.

    Starts from the global ``settings.sso_entra_graph_api_*`` values and
    applies per-provider overrides from ``provider_metadata`` when they are
    valid. Invalid overrides are logged and ignored.

    Override rules:
        * ``graph_api_enabled``: bool, int (non-zero is true), or one of the
          strings 1/true/yes/on and 0/false/no/off (case-insensitive).
        * ``graph_api_timeout``: integer-convertible value in 1..120.
        * ``graph_api_max_groups``: integer-convertible value >= 0.

    Args:
        provider_metadata: Optional provider metadata from DB/config. Values
            that are not dicts are treated as "no overrides".

    Returns:
        Tuple of (enabled, timeout_seconds, max_groups).
    """
    # Guard against non-dict metadata so the ``.get`` calls below cannot fail.
    metadata = provider_metadata if isinstance(provider_metadata, dict) else {}
    enabled = settings.sso_entra_graph_api_enabled
    timeout = settings.sso_entra_graph_api_timeout
    max_groups = settings.sso_entra_graph_api_max_groups

    if "graph_api_enabled" in metadata:
        metadata_enabled = metadata.get("graph_api_enabled")
        # bool is tested before int because bool is a subclass of int.
        if isinstance(metadata_enabled, bool):
            enabled = metadata_enabled
        elif isinstance(metadata_enabled, str):
            normalized_enabled = metadata_enabled.strip().lower()
            if normalized_enabled in {"1", "true", "yes", "on"}:
                enabled = True
            elif normalized_enabled in {"0", "false", "no", "off"}:
                enabled = False
            else:
                logger.warning("Invalid provider_metadata.graph_api_enabled=%s; using configured default %s", metadata_enabled, enabled)
        elif isinstance(metadata_enabled, int):
            enabled = metadata_enabled != 0
        else:
            logger.warning("Invalid provider_metadata.graph_api_enabled=%s (type %s); using configured default %s", metadata_enabled, type(metadata_enabled).__name__, enabled)

    metadata_timeout = metadata.get("graph_api_timeout")
    if metadata_timeout is not None:
        try:
            timeout_candidate = int(metadata_timeout)
            if 1 <= timeout_candidate <= 120:
                timeout = timeout_candidate
            else:
                logger.warning("Invalid provider_metadata.graph_api_timeout=%s; using configured default %s", metadata_timeout, timeout)
        # OverflowError: int() of an infinite float.
        except (TypeError, ValueError, OverflowError):
            logger.warning("Invalid provider_metadata.graph_api_timeout=%s; using configured default %s", metadata_timeout, timeout)

    metadata_max_groups = metadata.get("graph_api_max_groups")
    if metadata_max_groups is not None:
        try:
            max_groups_candidate = int(metadata_max_groups)
            if max_groups_candidate >= 0:
                max_groups = max_groups_candidate
            else:
                logger.warning("Invalid provider_metadata.graph_api_max_groups=%s; using configured default %s", metadata_max_groups, max_groups)
        # OverflowError: int() of an infinite float.
        except (TypeError, ValueError, OverflowError):
            logger.warning("Invalid provider_metadata.graph_api_max_groups=%s; using configured default %s", metadata_max_groups, max_groups)

    return enabled, timeout, max_groups
async def _fetch_entra_groups_from_graph_api(self, access_token: str, user_email: str, provider_metadata: Optional[Dict[str, Any]] = None) -> Optional[List[str]]:
    """Fetch Entra group object IDs from Microsoft Graph for overage tokens.

    Posts to ``/me/getMemberObjects`` with the user's access token, then
    strips, de-duplicates (order preserved) and optionally caps the
    returned IDs.

    Args:
        access_token: Delegated OAuth access token from Entra.
        user_email: User identifier for structured logs.
        provider_metadata: Optional provider metadata for runtime overrides.

    Returns:
        List of group IDs on success, empty list when Graph returns no IDs
        (or ``value`` is not a list), or None if retrieval failed, the
        response body is not a JSON object, or the fallback is disabled.
    """
    graph_api_enabled, graph_api_timeout, graph_api_max_groups = self._resolve_entra_graph_fallback_settings(provider_metadata)
    if not graph_api_enabled:
        logger.info("Microsoft Graph fallback for Entra group overage is disabled")
        return None

    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    client = await get_http_client()
    try:
        response = await client.post(
            "https://graph.microsoft.com/v1.0/me/getMemberObjects",
            headers={"Authorization": f"Bearer {access_token}"},
            json={"securityEnabledOnly": True},
            timeout=graph_api_timeout,
        )
    except Exception as e:
        logger.error("Failed to retrieve groups from Graph API for %s: %s", user_email, e)
        return None

    if response.status_code != 200:
        if response.status_code in {401, 403}:
            logger.error(
                "Failed to retrieve groups from Graph API for %s: HTTP %s. Check Entra delegated permissions and consent (minimum User.Read).",
                user_email,
                response.status_code,
            )
        else:
            logger.error("Failed to retrieve groups from Graph API for %s: HTTP %s", user_email, response.status_code)
        return None

    try:
        groups_payload = response.json()
    except ValueError as e:
        logger.error("Failed to parse Graph API response for %s: %s", user_email, e)
        return None

    # A 200 response whose body is valid JSON but not an object would
    # otherwise raise AttributeError on ``.get``; treat it as a failure.
    if not isinstance(groups_payload, dict):
        logger.error(
            "Failed to parse Graph API response for %s: expected JSON object, got %s",
            user_email,
            type(groups_payload).__name__,
        )
        return None

    group_values = groups_payload.get("value", [])
    if not isinstance(group_values, list):
        logger.warning("Unexpected Graph API groups payload for %s: expected list in 'value'", user_email)
        return []

    deduped_groups: List[str] = []
    seen_groups: set[str] = set()
    for group_id in group_values:
        if not isinstance(group_id, str):
            continue
        normalized_group_id = group_id.strip()
        if not normalized_group_id or normalized_group_id in seen_groups:
            continue
        seen_groups.add(normalized_group_id)
        deduped_groups.append(normalized_group_id)

    # A cap of 0 means "unlimited".
    if 0 < graph_api_max_groups < len(deduped_groups):
        logger.warning(
            "Graph API returned %d groups for %s; applying configured cap (%d)",
            len(deduped_groups),
            user_email,
            graph_api_max_groups,
        )
        deduped_groups = deduped_groups[:graph_api_max_groups]

    logger.info("Retrieved %d groups from Graph API for %s", len(deduped_groups), user_email)
    return deduped_groups
def list_enabled_providers(self) -> List[SSOProvider]:
    """Return every SSO provider whose ``is_enabled`` flag is true.

    Returns:
        List of enabled SSO providers

    Examples:
        Returns empty list when DB has no providers:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalars.return_value.all.return_value = []
        >>> service.list_enabled_providers()
        []
    """
    enabled_only = SSOProvider.is_enabled.is_(True)
    query = select(SSOProvider).where(enabled_only)
    rows = self.db.execute(query).scalars().all()
    return list(rows)
def list_all_providers(self) -> List[SSOProvider]:
    """Return every stored SSO provider, regardless of its enabled flag.

    Returns:
        List of all SSO providers

    Examples:
        Returns empty list when DB has no providers:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalars.return_value.all.return_value = []
        >>> service.list_all_providers()
        []
    """
    query = select(SSOProvider)
    rows = self.db.execute(query).scalars().all()
    return list(rows)
def get_provider(self, provider_id: str) -> Optional[SSOProvider]:
    """Look up a single SSO provider by its primary identifier.

    Args:
        provider_id: Provider identifier (e.g., 'github', 'google')

    Returns:
        SSO provider or None if not found

    Examples:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> service.get_provider('x') is None
        True
    """
    id_matches = SSOProvider.id == provider_id
    query = select(SSOProvider).where(id_matches)
    return self.db.execute(query).scalar_one_or_none()
def get_provider_by_name(self, provider_name: str) -> Optional[SSOProvider]:
    """Look up a single SSO provider by its ``name`` column.

    Args:
        provider_name: Provider name (e.g., 'github', 'google')

    Returns:
        SSO provider or None if not found

    Examples:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> service.get_provider_by_name('github') is None
        True
    """
    name_matches = SSOProvider.name == provider_name
    query = select(SSOProvider).where(name_matches)
    return self.db.execute(query).scalar_one_or_none()
516 @staticmethod
517 def _normalize_issuer_url(issuer: str) -> str:
518 """Normalize issuer URL for allowlist comparisons.
520 Args:
521 issuer: Raw issuer URL from provider config or metadata.
523 Returns:
524 Lowercased issuer URL without trailing slash.
525 """
526 return issuer.strip().rstrip("/").lower()
def _enforce_allowed_issuer(self, issuer: Optional[str]) -> None:
    """Reject issuers that are not on the configured ``sso_issuers`` allowlist.

    No check is performed when the allowlist is unset or empty. A blank or
    non-string issuer is not rejected either: a warning is logged and
    enforcement is skipped. Comparison uses normalized issuer URLs.

    Args:
        issuer: Candidate issuer URL from provider configuration.

    Raises:
        ValueError: If issuer is set but not in configured allowlist.
    """
    allowed_issuers = getattr(settings, "sso_issuers", None)
    if not allowed_issuers:
        return

    issuer_is_blank = not isinstance(issuer, str) or not issuer.strip()
    if issuer_is_blank:
        logger.warning("SSO provider has blank/empty issuer while SSO_ISSUERS allowlist is configured; issuer enforcement skipped.")
        return

    candidate = self._normalize_issuer_url(issuer)

    # Only non-blank string entries take part in the comparison.
    allowlist = set()
    for entry in allowed_issuers:
        if isinstance(entry, str) and entry.strip():
            allowlist.add(self._normalize_issuer_url(str(entry)))

    # An allowlist with no usable entries does not block anything.
    if not allowlist or candidate in allowlist:
        return
    raise ValueError("Issuer is not allowed by SSO_ISSUERS configuration")
550 @staticmethod
551 def _resolve_team_mapping_target(mapping_value: Any) -> Tuple[Optional[str], str]:
552 """Resolve team mapping value into team id and role.
554 Args:
555 mapping_value: Team mapping target value from provider config.
557 Returns:
558 Tuple of ``(team_id, role)`` where ``team_id`` may be ``None`` and
559 role defaults to ``member`` when not explicitly valid.
560 """
561 if isinstance(mapping_value, str) and mapping_value.strip():
562 return mapping_value.strip(), "member"
564 if isinstance(mapping_value, dict):
565 team_id_value = mapping_value.get("team_id") or mapping_value.get("id")
566 team_id = str(team_id_value).strip() if team_id_value is not None else ""
567 role_value = str(mapping_value.get("role", "member")).strip().lower()
568 role = role_value if role_value in {"owner", "member"} else "member"
569 return (team_id if team_id else None), role
571 return None, "member"
573 async def _apply_team_mapping(self, user_email: str, user_info: Dict[str, Any], provider: Optional[SSOProvider]) -> None:
574 """Apply provider team mappings based on SSO group claims.
576 Reconciles team memberships: grants new SSO-based memberships and revokes
577 stale ones when groups are removed from the identity provider.
579 Args:
580 user_email: Authenticated user email to map into teams.
581 user_info: Identity claims payload containing optional group claims.
582 provider: SSO provider configuration with ``team_mapping`` entries.
584 Returns:
585 None.
586 """
587 if not provider:
588 return
590 mapping = getattr(provider, "team_mapping", None)
591 if not isinstance(mapping, dict) or not mapping:
592 return
594 groups_raw = user_info.get("groups", [])
595 if isinstance(groups_raw, str):
596 groups = [groups_raw]
597 elif isinstance(groups_raw, list):
598 groups = [str(group).strip() for group in groups_raw if str(group).strip()]
599 else:
600 groups = []
602 normalized_groups = {group.lower() for group in groups}
604 # First-Party
605 from mcpgateway.services.team_management_service import ( # pylint: disable=import-outside-toplevel
606 MemberAlreadyExistsError,
607 TeamManagementError,
608 TeamManagementService,
609 )
611 team_service = TeamManagementService(self.db)
613 # Get current SSO-granted team memberships for this user
614 # First-Party
615 from mcpgateway.db import EmailTeamMember # pylint: disable=import-outside-toplevel
617 stmt = select(EmailTeamMember).where(
618 EmailTeamMember.user_email == user_email,
619 EmailTeamMember.grant_source == "sso",
620 EmailTeamMember.is_active.is_(True),
621 )
622 result = self.db.execute(stmt)
623 current_sso_memberships = result.scalars().all()
625 # Build set of desired team IDs from current groups + team_mapping
626 desired_team_ids = set()
627 for source_group, target in mapping.items():
628 if not isinstance(source_group, str):
629 continue
630 source_group_normalized = source_group.strip().lower()
631 if not source_group_normalized:
632 continue
634 if source_group_normalized in normalized_groups:
635 team_id, _ = self._resolve_team_mapping_target(target)
636 if team_id:
637 desired_team_ids.add(team_id)
639 # Revoke SSO memberships that are no longer in desired set
640 for membership in current_sso_memberships:
641 if membership.team_id not in desired_team_ids:
642 try:
643 await team_service.remove_member_from_team(
644 team_id=membership.team_id,
645 user_email=user_email,
646 )
647 logger.info(
648 "Revoked SSO team membership for %s from team %s (group no longer in claims)",
649 SecurityValidator.sanitize_log_message(user_email),
650 membership.team_id,
651 )
652 except Exception as exc:
653 logger.warning(
654 "Failed to revoke SSO team membership for %s from team %s: %s",
655 SecurityValidator.sanitize_log_message(user_email),
656 membership.team_id,
657 exc,
658 )
660 # Grant new SSO memberships
661 for source_group, target in mapping.items():
662 if not isinstance(source_group, str):
663 continue
664 source_group_normalized = source_group.strip().lower()
665 if not source_group_normalized or source_group_normalized not in normalized_groups:
666 continue
668 team_id, role = self._resolve_team_mapping_target(target)
669 if not team_id:
670 logger.warning(
671 "Skipping invalid SSO team_mapping entry for provider %s and group '%s'",
672 provider.id,
673 source_group,
674 )
675 continue
677 try:
678 await team_service.add_member_to_team(
679 team_id=team_id,
680 user_email=user_email,
681 role=role,
682 invited_by=user_email,
683 grant_source="sso",
684 )
685 logger.info(
686 "Granted SSO team membership for %s to team %s via group '%s'",
687 SecurityValidator.sanitize_log_message(user_email),
688 team_id,
689 source_group,
690 )
691 except MemberAlreadyExistsError:
692 logger.debug(
693 "SSO team_mapping: user %s already member of team %s",
694 SecurityValidator.sanitize_log_message(user_email),
695 team_id,
696 )
697 except TeamManagementError as exc:
698 logger.warning(
699 "SSO team_mapping failed for user %s, group '%s', team '%s': %s",
700 SecurityValidator.sanitize_log_message(user_email),
701 source_group,
702 team_id,
703 exc,
704 )
705 except Exception as exc:
706 logger.error(
707 "Unexpected error in SSO team_mapping for user %s, group '%s': %s",
708 SecurityValidator.sanitize_log_message(user_email),
709 source_group,
710 exc,
711 exc_info=True,
712 )
async def create_provider(self, provider_data: Dict[str, Any]) -> SSOProvider:
    """Persist a new SSO provider built from ``provider_data``.

    The issuer is checked against the allowlist first. The plain
    ``client_secret`` is then removed from ``provider_data`` and replaced by
    ``client_secret_encrypted`` (the passed dict is modified in place).
    Keys that are not ``SSOProvider`` columns are dropped with a warning.

    Args:
        provider_data: Provider configuration data

    Returns:
        Created SSO provider

    Examples:
        >>> import asyncio
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> service = SSOService(MagicMock())
        >>> service._encrypt_secret = AsyncMock(side_effect=lambda s: 'ENC(' + s + ')')
        >>> data = {
        ...     'id': 'github', 'name': 'github', 'display_name': 'GitHub', 'provider_type': 'oauth2',
        ...     'client_id': 'cid', 'client_secret': 'sec',
        ...     'authorization_url': 'https://example/auth', 'token_url': 'https://example/token',
        ...     'userinfo_url': 'https://example/user', 'scope': 'user:email'
        ... }
        >>> provider = asyncio.run(service.create_provider(data))
        >>> hasattr(provider, 'id') and provider.id == 'github'
        True
        >>> provider.client_secret_encrypted.startswith('ENC(')
        True
    """
    candidate_issuer = provider_data.get("issuer")
    self._enforce_allowed_issuer(candidate_issuer)

    # Swap the plain secret for its encrypted form
    plaintext_secret = provider_data.pop("client_secret")
    provider_data["client_secret_encrypted"] = await self._encrypt_secret(plaintext_secret)

    # Keep only real SSOProvider columns so the constructor cannot raise
    # TypeError on unknown keys
    known_columns = {column.key for column in SSOProvider.__table__.columns}
    accepted = {}
    for field, value in provider_data.items():
        if field in known_columns:
            accepted[field] = value

    skipped = set(provider_data) - set(accepted)
    if skipped:
        logger.warning("Ignored unknown SSOProvider fields during creation: %s", skipped)

    provider = SSOProvider(**accepted)
    self.db.add(provider)
    self.db.commit()
    self.db.refresh(provider)
    return provider
async def update_provider(self, provider_id: str, provider_data: Dict[str, Any]) -> Optional[SSOProvider]:
    """Apply changes to an existing SSO provider and persist them.

    A supplied ``issuer`` is checked against the allowlist. A supplied
    ``client_secret`` is removed from ``provider_data`` and replaced by
    ``client_secret_encrypted`` (the passed dict is modified in place).
    Only keys that already exist as attributes on the provider are applied.

    Args:
        provider_id: Provider identifier
        provider_data: Updated provider data

    Returns:
        Updated SSO provider or None if not found

    Examples:
        >>> import asyncio
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> svc = SSOService(MagicMock())
        >>> # Existing provider object
        >>> existing = SimpleNamespace(id='github', name='github', client_id='old', client_secret_encrypted='X', is_enabled=True)
        >>> svc.get_provider = lambda _id: existing
        >>> svc._encrypt_secret = AsyncMock(side_effect=lambda s: 'ENC-' + s)
        >>> svc.db.commit = lambda: None
        >>> svc.db.refresh = lambda obj: None
        >>> updated = asyncio.run(svc.update_provider('github', {'client_id': 'new', 'client_secret': 'sec'}))
        >>> updated.client_id
        'new'
        >>> updated.client_secret_encrypted
        'ENC-sec'
    """
    provider = self.get_provider(provider_id)
    if not provider:
        return None

    if "issuer" in provider_data:
        self._enforce_allowed_issuer(provider_data.get("issuer"))

    # Re-encrypt the client secret only when a new one was supplied
    if "client_secret" in provider_data:
        plaintext_secret = provider_data.pop("client_secret")
        encrypted_secret = await self._encrypt_secret(plaintext_secret)
        provider_data["client_secret_encrypted"] = encrypted_secret

    # Copy over only the fields the provider object already has
    for field in provider_data:
        if hasattr(provider, field):
            setattr(provider, field, provider_data[field])

    provider.updated_at = utc_now()
    self.db.commit()
    self.db.refresh(provider)
    return provider
def delete_provider(self, provider_id: str) -> bool:
    """Remove an SSO provider and commit the deletion.

    Args:
        provider_id: Provider identifier

    Returns:
        True if deleted, False if not found

    Examples:
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> svc = SSOService(MagicMock())
        >>> svc.db.delete = lambda obj: None
        >>> svc.db.commit = lambda: None
        >>> svc.get_provider = lambda _id: SimpleNamespace(id='github')
        >>> svc.delete_provider('github')
        True
        >>> svc.get_provider = lambda _id: None
        >>> svc.delete_provider('missing')
        False
    """
    target = self.get_provider(provider_id)
    if target:
        self.db.delete(target)
        self.db.commit()
        return True
    return False
def generate_pkce_challenge(self) -> Tuple[str, str]:
    """Create a PKCE verifier/challenge pair for OAuth 2.1.

    The verifier is 32 random bytes, base64url-encoded without padding.
    The challenge is the unpadded base64url SHA-256 digest of the verifier.

    Returns:
        Tuple of (code_verifier, code_challenge)

    Examples:
        Generate verifier and challenge:
        >>> from unittest.mock import Mock
        >>> service = SSOService(Mock())
        >>> verifier, challenge = service.generate_pkce_challenge()
        >>> isinstance(verifier, str) and isinstance(challenge, str)
        True
        >>> len(verifier) >= 43
        True
        >>> len(challenge) >= 43
        True
    """
    # Verifier: cryptographically random bytes, base64url without '=' padding
    random_bytes = secrets.token_bytes(32)
    code_verifier = base64.urlsafe_b64encode(random_bytes).decode("utf-8").rstrip("=")

    # Challenge: SHA-256 of the verifier text, encoded the same way
    verifier_digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
    code_challenge = base64.urlsafe_b64encode(verifier_digest).decode("utf-8").rstrip("=")

    return code_verifier, code_challenge
863 @staticmethod
864 def _normalize_scope_values(values: Optional[List[str] | str]) -> List[str]:
865 """Normalize scope values to a deduplicated, ordered list.
867 Args:
868 values: Scope input as list or space-delimited string.
870 Returns:
871 Ordered, unique scope values.
872 """
873 if values is None:
874 return []
876 raw_values: List[str]
877 if isinstance(values, str):
878 raw_values = values.split()
879 else:
880 raw_values = [str(v) for v in values if isinstance(v, str) and v.strip()]
882 normalized: List[str] = []
883 seen: set[str] = set()
884 for value in raw_values:
885 scope = value.strip()
886 if not scope or scope in seen:
887 continue
888 normalized.append(scope)
889 seen.add(scope)
890 return normalized
892 def _resolve_login_scopes(self, provider: SSOProvider, requested_scopes: Optional[List[str]]) -> List[str]:
893 """Resolve requested SSO scopes against provider allowlists.
895 Args:
896 provider: SSO provider configuration.
897 requested_scopes: Optional scopes requested by the client.
899 Returns:
900 Final scope list to send to the provider.
902 Raises:
903 ValueError: If provider scope configuration is invalid or request includes disallowed scopes.
904 """
905 configured_scopes = self._normalize_scope_values(getattr(provider, "scope", None))
906 if not configured_scopes:
907 raise ValueError("Provider has no configured scopes")
909 allowed_scopes = configured_scopes
910 provider_metadata = getattr(provider, "provider_metadata", {}) or {}
911 metadata_allowed_raw = provider_metadata.get("allowed_scopes")
912 metadata_allowed = self._normalize_scope_values(metadata_allowed_raw) if metadata_allowed_raw else []
913 if metadata_allowed:
914 metadata_allowed_set = set(metadata_allowed)
915 allowed_scopes = [scope for scope in configured_scopes if scope in metadata_allowed_set]
917 if not allowed_scopes:
918 raise ValueError("No allowed scopes configured for provider")
920 if not requested_scopes:
921 return allowed_scopes
923 normalized_requested = self._normalize_scope_values(requested_scopes)
924 if not normalized_requested:
925 return allowed_scopes
927 allowed_set = set(allowed_scopes)
928 invalid_scopes = [scope for scope in normalized_requested if scope not in allowed_set]
929 if invalid_scopes:
930 invalid_csv = ", ".join(invalid_scopes)
931 raise ValueError(f"Invalid scopes requested: {invalid_csv}")
933 return normalized_requested
def _get_state_binding_secret(self) -> bytes:
    """Return the HMAC key used for state/session binding as UTF-8 bytes.

    The key material is ``settings.auth_encryption_secret``. Objects that
    expose ``get_secret_value()`` are unwrapped through it; anything else
    is converted with ``str()``.

    Returns:
        Secret bytes used for HMAC signatures.
    """
    configured_secret = settings.auth_encryption_secret
    secret_text = configured_secret.get_secret_value() if hasattr(configured_secret, "get_secret_value") else str(configured_secret)
    return secret_text.encode("utf-8")
946 def _generate_session_bound_state(self, provider_id: str, session_binding: str) -> str:
947 """Generate state value bound to browser session context.
949 Args:
950 provider_id: SSO provider identifier.
951 session_binding: Browser-session marker.
953 Returns:
954 Signed state token including nonce and HMAC signature.
955 """
956 state_nonce = secrets.token_urlsafe(24)
957 message = f"{provider_id}:{session_binding}:{state_nonce}".encode("utf-8")
958 signature = hmac.new(self._get_state_binding_secret(), message, hashlib.sha256).hexdigest()
959 return f"{state_nonce}{self._STATE_BINDING_SEPARATOR}{signature}"
961 def _is_session_bound_state(self, state: str) -> bool:
962 """Return whether state appears to carry a session-binding signature.
964 Args:
965 state: State value from OAuth flow.
967 Returns:
968 ``True`` when state has expected nonce/signature format.
969 """
970 if self._STATE_BINDING_SEPARATOR not in state:
971 return False
972 nonce, signature = state.rsplit(self._STATE_BINDING_SEPARATOR, 1)
973 return bool(nonce) and len(signature) == self._STATE_BINDING_HEX_LEN
975 def _verify_session_bound_state(self, provider_id: str, state: str, session_binding: str) -> bool:
976 """Verify state HMAC binding against the current browser session marker.
978 Args:
979 provider_id: SSO provider identifier.
980 state: State value to validate.
981 session_binding: Current browser-session marker.
983 Returns:
984 ``True`` when state signature matches the expected session binding.
985 """
986 if not session_binding or not self._is_session_bound_state(state):
987 return False
988 nonce, signature = state.rsplit(self._STATE_BINDING_SEPARATOR, 1)
989 message = f"{provider_id}:{session_binding}:{nonce}".encode("utf-8")
990 expected_signature = hmac.new(self._get_state_binding_secret(), message, hashlib.sha256).hexdigest()
991 return hmac.compare_digest(signature, expected_signature)
993 @staticmethod
994 def _is_email_verified_claim(user_info: Dict[str, Any]) -> bool:
995 """Evaluate email verification claim when provided by the IdP.
997 When the ``email_verified`` claim is **absent** from ``user_info`` (e.g.
998 Microsoft Entra ID and GitHub do not include it for work / school
999 accounts) the function returns ``True`` so that those providers are not
1000 incorrectly blocked. The check is only enforced when the IdP
1001 *explicitly* supplies the claim — a ``False``/``0``/``"false"`` value
1002 means the provider has flagged the address as unverified and the user
1003 should be rejected.
1005 Args:
1006 user_info: Normalized user-info payload from provider.
1008 Returns:
1009 ``True`` when the claim is absent (provider does not restrict) or
1010 when it is explicitly set to a truthy value; ``False`` only when the
1011 provider explicitly indicates the address is *not* verified.
1012 """
1013 if "email_verified" not in user_info:
1014 # Claim not provided by IdP — treat as no restriction (pass through).
1015 return True
1017 claim_value = user_info.get("email_verified")
1018 if isinstance(claim_value, bool):
1019 return claim_value
1020 if isinstance(claim_value, int):
1021 return claim_value == 1
1022 if isinstance(claim_value, str):
1023 return claim_value.strip().lower() in {"1", "true", "yes", "on"}
1024 return False
def get_authorization_url(
    self,
    provider_id: str,
    redirect_uri: str,
    scopes: Optional[List[str]] = None,
    session_binding: Optional[str] = None,
) -> Optional[str]:
    """Generate OAuth authorization URL for provider.

    A PKCE verifier/challenge pair, a CSRF ``state`` value and (for OIDC
    providers) a nonce are generated, persisted as an ``SSOAuthSession``,
    and encoded into the returned URL. If the provider's configured
    authorization endpoint already carries a query component, that
    component is retained and the OAuth parameters are appended to it.

    Args:
        provider_id: Provider identifier
        redirect_uri: Callback URI after authorization
        scopes: Optional custom scopes (uses provider default if None)
        session_binding: Optional browser-session marker for state binding.

    Returns:
        Authorization URL or None if provider not found or disabled

    Raises:
        ValueError: Propagated from scope resolution when the provider scope
            configuration is invalid or a requested scope is not allowed.

    Examples:
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> provider = SimpleNamespace(id='github', is_enabled=True, provider_type='oauth2', client_id='cid', authorization_url='https://example/auth', scope='user:email')
        >>> service.get_provider = lambda _pid: provider
        >>> service.db.add = lambda x: None
        >>> service.db.commit = lambda: None
        >>> url = service.get_authorization_url('github', 'https://app/callback', ['user:email'])
        >>> isinstance(url, str) and 'client_id=cid' in url and 'state=' in url
        True

        Missing provider returns None:
        >>> service.get_provider = lambda _pid: None
        >>> service.get_authorization_url('missing', 'https://app/callback') is None
        True
    """
    provider = self.get_provider(provider_id)
    if not provider or not provider.is_enabled:
        return None

    # Generate PKCE parameters
    code_verifier, code_challenge = self.generate_pkce_challenge()

    resolved_scopes = self._resolve_login_scopes(provider, scopes)

    # Generate CSRF state (session-bound when browser context is available).
    state = self._generate_session_bound_state(provider_id, session_binding) if session_binding else secrets.token_urlsafe(32)

    # Generate OIDC nonce if applicable
    nonce = secrets.token_urlsafe(16) if provider.provider_type == "oidc" else None

    # Persist the auth session so the callback can look up the verifier/nonce by state.
    auth_session = SSOAuthSession(provider_id=provider_id, state=state, code_verifier=code_verifier, nonce=nonce, redirect_uri=redirect_uri)
    self.db.add(auth_session)
    self.db.commit()

    # Build authorization URL
    params = {
        "client_id": provider.client_id,
        "response_type": "code",
        "redirect_uri": redirect_uri,
        "state": state,
        "scope": " ".join(resolved_scopes),
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
    }

    if nonce:
        params["nonce"] = nonce

    # The configured endpoint may already contain a query string; it must be
    # kept, so pick the joiner accordingly instead of always using "?".
    endpoint = str(provider.authorization_url)
    if "?" not in endpoint:
        joiner = "?"
    elif endpoint.endswith(("?", "&")):
        joiner = ""
    else:
        joiner = "&"

    return f"{endpoint}{joiner}{urllib.parse.urlencode(params)}"
async def handle_oauth_callback(self, provider_id: str, code: str, state: str, session_binding: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Handle OAuth callback and return only the user info.

    Delegates to ``handle_oauth_callback_with_tokens`` and discards the raw
    token response from its result.

    Args:
        provider_id: Provider identifier
        code: Authorization code from callback
        state: CSRF state parameter
        session_binding: Optional browser-session marker used to verify bound state.

    Returns:
        User info dict or None if authentication failed

    Examples:
        Happy-path with patched exchanges and user info:
        >>> import asyncio
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> svc = SSOService(MagicMock())
        >>> # Mock DB auth session lookup
        >>> provider = SimpleNamespace(id='github', is_enabled=True, provider_type='oauth2')
        >>> auth_session = SimpleNamespace(provider_id='github', state='st', provider=provider, is_expired=False, nonce=None)
        >>> svc.db.execute.return_value.scalar_one_or_none.return_value = auth_session
        >>> # Patch token exchange and user info retrieval
        >>> async def _ex(p, sess, c):
        ...     return {'access_token': 'tok', 'id_token': 'id_tok'}
        >>> async def _ui(p, access, token_data=None, expected_nonce=None):
        ...     return {'email': 'user@example.com'}
        >>> svc._exchange_code_for_tokens = _ex
        >>> svc._get_user_info = _ui
        >>> svc.db.delete = lambda obj: None
        >>> svc.db.commit = lambda: None
        >>> out = asyncio.run(svc.handle_oauth_callback('github', 'code', 'st'))
        >>> out['email']
        'user@example.com'

        Early return cases:
        >>> # No session
        >>> svc2 = SSOService(MagicMock())
        >>> svc2.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> asyncio.run(svc2.handle_oauth_callback('github', 'c', 's')) is None
        True
        >>> # Expired session
        >>> expired = SimpleNamespace(provider_id='github', state='st', provider=SimpleNamespace(is_enabled=True), is_expired=True)
        >>> svc3 = SSOService(MagicMock())
        >>> svc3.db.execute.return_value.scalar_one_or_none.return_value = expired
        >>> asyncio.run(svc3.handle_oauth_callback('github', 'c', 'st')) is None
        True
        >>> # Disabled provider
        >>> disabled = SimpleNamespace(provider_id='github', state='st', provider=SimpleNamespace(is_enabled=False), is_expired=False)
        >>> svc4 = SSOService(MagicMock())
        >>> svc4.db.execute.return_value.scalar_one_or_none.return_value = disabled
        >>> asyncio.run(svc4.handle_oauth_callback('github', 'c', 'st')) is None
        True
    """
    outcome = await self.handle_oauth_callback_with_tokens(provider_id, code, state, session_binding=session_binding)
    if outcome:
        # Only the user info is exposed by this entry point; tokens are dropped.
        user_info, _raw_tokens = outcome
        return user_info
    return None
async def handle_oauth_callback_with_tokens(
    self,
    provider_id: str,
    code: str,
    state: str,
    session_binding: Optional[str] = None,
) -> Optional[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """Handle OAuth callback and return both user info and raw token response.

    Steps, in order: look up the pending auth session by state and provider,
    reject expired sessions (deleting them), verify the session binding when
    the state is in bound format, check that the provider exists and is
    enabled, exchange the code for tokens, verify the id_token for OIDC
    providers, fetch user info, and finally delete the auth session.

    Args:
        provider_id: Provider identifier
        code: Authorization code from callback
        state: CSRF state parameter
        session_binding: Optional browser-session marker used to verify bound state.

    Returns:
        Tuple of (user_info, token_data) or None if authentication fails
    """
    # Validate auth session
    session_query = select(SSOAuthSession).where(SSOAuthSession.state == state, SSOAuthSession.provider_id == provider_id)
    auth_session = self.db.execute(session_query).scalar_one_or_none()

    if not auth_session:
        logger.warning(f"OAuth callback: no auth session found for state/provider {provider_id}. Possible CSRF or replay.")
        return None

    if auth_session.is_expired:
        logger.warning(f"OAuth callback: auth session expired for provider {provider_id}.")
        self.db.delete(auth_session)
        self.db.commit()
        return None

    # A state in bound format must verify against the caller's session marker.
    if self._is_session_bound_state(state) and not (session_binding and self._verify_session_bound_state(provider_id, state, session_binding)):
        logger.warning(
            "OAuth callback: state/session mismatch for provider %s. Possible CSRF or cross-session replay.",
            provider_id,
        )
        return None

    provider = auth_session.provider
    if not provider:
        logger.error(f"OAuth callback: provider '{provider_id}' not found for auth session.")
        return None

    if not provider.is_enabled:
        logger.warning(f"OAuth callback: provider '{provider_id}' is disabled.")
        return None

    try:
        # Exchange authorization code for tokens
        logger.info(f"Starting token exchange for provider {provider_id}")
        token_data = await self._exchange_code_for_tokens(provider, auth_session, code)
        if not token_data:
            logger.error(f"Failed to exchange code for tokens for provider {provider_id}")
            return None
        logger.info(f"Token exchange successful for provider {provider_id}")
        callback_nonce = getattr(auth_session, "nonce", None)

        # For OIDC providers, verify id_token before any claim extraction.
        if provider.provider_type == "oidc":
            if callback_nonce is None:
                logger.error("OAuth callback: missing nonce for OIDC provider %s.", provider_id)
                return None

            id_token = token_data.get("id_token")
            if not isinstance(id_token, str):
                logger.error("OAuth callback: missing id_token for OIDC provider %s.", provider_id)
                return None

            verified_claims = await self._verify_oidc_id_token(provider, id_token, expected_nonce=callback_nonce)
            if verified_claims is not None:
                # Work on a copy so the provider's response dict is not mutated.
                token_data = {**token_data, "_verified_id_token_claims": verified_claims}
            elif provider.id == ADFS_PROVIDER_ID:
                # ADFS: on-prem deployments often lack a discoverable JWKS endpoint,
                # so verification may fail. Fall through to _get_user_info which
                # decodes the id_token received over the direct TLS channel.
                logger.warning("OIDC id_token verification failed for ADFS provider %s; falling back to TLS-trust decode.", provider_id)
                # Mark that verification was attempted so _get_user_info
                # does not redundantly re-attempt it.
                token_data = {**token_data, "_adfs_verification_attempted": True}
            else:
                logger.error("id_token verification failed for provider %s", provider_id)
                return None

        # Get user info from provider (pass full token_data for id_token parsing)
        user_info = await self._get_user_info(provider, token_data["access_token"], token_data, expected_nonce=callback_nonce)
        if not user_info:
            logger.error(f"Failed to get user info for provider {provider_id}")
            return None

        # Clean up auth session
        self.db.delete(auth_session)
        self.db.commit()

        return user_info, token_data

    except Exception as e:
        # Clean up auth session on error
        logger.error(f"OAuth callback failed for provider {provider_id}: {type(e).__name__}: {str(e)}")
        logger.exception("Full traceback for OAuth callback failure:")
        self.db.delete(auth_session)
        self.db.commit()
        return None
async def _exchange_code_for_tokens(self, provider: SSOProvider, auth_session: SSOAuthSession, code: str) -> Optional[Dict[str, Any]]:
    """Trade an authorization code for the provider's token response.

    Posts an ``authorization_code`` grant (with the PKCE verifier and the
    redirect URI stored on the auth session) to the provider's token URL.

    Args:
        provider: SSO provider configuration
        auth_session: Auth session with PKCE parameters
        code: Authorization code

    Returns:
        Parsed JSON token response on HTTP 200, otherwise None
    """
    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    token_params = {
        "client_id": provider.client_id,
        "client_secret": await self._decrypt_secret(provider.client_secret_encrypted),
        "code": code,
        "grant_type": "authorization_code",
        "redirect_uri": auth_session.redirect_uri,
        "code_verifier": auth_session.code_verifier,
    }

    client = await get_http_client()
    response = await client.post(provider.token_url, data=token_params, headers={"Accept": "application/json"})

    if response.status_code != 200:
        logger.error(f"Token exchange failed for {provider.name}: HTTP {response.status_code} - {response.text}")
        return None

    return response.json()
1295 async def _enrich_user_data_from_claims(
1296 self,
1297 provider: SSOProvider,
1298 user_data: Dict[str, Any],
1299 access_token: str,
1300 verified_id_token_claims: Optional[Dict[str, Any]],
1301 ) -> None:
1302 """Enrich userinfo response with provider-specific claims.
1304 Mutates ``user_data`` in place by merging groups, roles, and other
1305 claims from the id_token or external APIs (GitHub orgs, Entra Graph)
1306 that the userinfo endpoint does not include on its own.
1308 Args:
1309 provider: SSO provider configuration.
1310 user_data: Mutable userinfo dict from the provider endpoint.
1311 access_token: OAuth access token for follow-up API calls.
1312 verified_id_token_claims: Verified id_token claims, if available.
1313 """
1314 # GitHub: fetch organizations for admin assignment
1315 if provider.id == "github" and settings.sso_github_admin_orgs:
1316 # First-Party
1317 from mcpgateway.services.http_client_service import get_http_client # pylint: disable=import-outside-toplevel
1319 client = await get_http_client()
1320 try:
1321 orgs_response = await client.get("https://api.github.com/user/orgs", headers={"Authorization": f"Bearer {access_token}"})
1322 if orgs_response.status_code == 200:
1323 user_data["organizations"] = [org["login"] for org in orgs_response.json()]
1324 else:
1325 logger.warning(f"Failed to fetch GitHub organizations: HTTP {orgs_response.status_code}")
1326 user_data["organizations"] = []
1327 except Exception as e:
1328 logger.warning(f"Error fetching GitHub organizations: {e}")
1329 user_data["organizations"] = []
1330 return
1332 # Entra ID: extract groups/roles from id_token since userinfo doesn't include them.
1333 # Microsoft's /oidc/userinfo endpoint only returns basic claims (sub, name, email, picture).
1334 # Groups and roles are included in the id_token when configured in Azure Portal.
1335 if provider.id == "entra" and verified_id_token_claims:
1336 entra_groups_from_graph: Optional[List[str]] = None
1337 # Detect group overage — when user has too many groups (>200), EntraID can return
1338 # overage markers (e.g. _claim_names/_claim_sources, hasgroups, groups:srcN)
1339 # instead of an inline groups array.
1340 # See: https://learn.microsoft.com/en-us/entra/identity-platform/id-token-claims-reference
1341 claim_names = verified_id_token_claims.get("_claim_names", {})
1342 has_groups_src_key = any(isinstance(key, str) and key.startswith("groups:src") for key in verified_id_token_claims)
1343 groups_claim_value = verified_id_token_claims.get("groups")
1344 has_group_overage = (
1345 (isinstance(claim_names, dict) and "groups" in claim_names) or bool(verified_id_token_claims.get("hasgroups")) or has_groups_src_key or isinstance(groups_claim_value, str)
1346 )
1347 if has_group_overage:
1348 user_email = user_data.get("email") or user_data.get("preferred_username") or "unknown"
1349 logger.warning(
1350 "Group overage detected for user %s - token contains too many groups (>200). Attempting Microsoft Graph fallback to resolve complete group membership.",
1351 user_email,
1352 )
1353 entra_groups_from_graph = await self._fetch_entra_groups_from_graph_api(access_token, user_email, provider.provider_metadata)
1354 if entra_groups_from_graph is None:
1355 logger.warning("Proceeding without Graph-resolved Entra groups for user %s", user_email)
1357 # Extract groups from id_token (Security Groups as Object IDs)
1358 if entra_groups_from_graph is not None:
1359 user_data["groups"] = entra_groups_from_graph
1360 elif "groups" in verified_id_token_claims:
1361 user_data["groups"] = verified_id_token_claims["groups"]
1362 logger.debug(f"Extracted {len(verified_id_token_claims['groups'])} groups from Entra ID token")
1364 # Extract roles from id_token (App Roles)
1365 if "roles" in verified_id_token_claims:
1366 user_data["roles"] = verified_id_token_claims["roles"]
1367 logger.debug(f"Extracted {len(verified_id_token_claims['roles'])} roles from Entra ID token")
1369 # Also extract any missing basic claims from id_token
1370 for claim in ["email", "name", "preferred_username", "oid", "sub"]:
1371 if claim not in user_data and claim in verified_id_token_claims:
1372 user_data[claim] = verified_id_token_claims[claim]
1373 return
1375 # Keycloak: merge realm_access, resource_access, and groups from id_token
1376 if provider.id == "keycloak" and verified_id_token_claims:
1377 for claim in ["realm_access", "resource_access", "groups"]:
1378 if claim in verified_id_token_claims and claim not in user_data:
1379 user_data[claim] = verified_id_token_claims[claim]
1380 return
1382 # Generic OIDC (including Okta, IBM Verify, and any custom provider):
1383 # merge groups and roles claims from the verified id_token when the
1384 # userinfo response does not already contain them.
1385 if provider.id not in ("github", "google") and verified_id_token_claims:
1386 metadata = provider.provider_metadata or {}
1387 groups_claim = metadata.get("groups_claim", "groups")
1388 for claim in {groups_claim, "roles"}:
1389 if claim in verified_id_token_claims and claim not in user_data:
1390 user_data[claim] = verified_id_token_claims[claim]
async def _get_user_info(self, provider: SSOProvider, access_token: str, token_data: Optional[Dict[str, Any]] = None, expected_nonce: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Fetch and normalize the authenticated user's profile.

    For ADFS the profile comes from the id_token (verified claims when
    available, otherwise a decoded token whose aud/iss/exp/nonce are
    checked). For every other provider the userinfo endpoint is called with
    the access token and the response is enriched from id_token claims.

    Args:
        provider: SSO provider configuration
        access_token: OAuth access token
        token_data: Optional full token response containing id_token for OIDC providers
        expected_nonce: Nonce bound to the current auth session for OIDC id_token verification

    Returns:
        User info dict or None if failed

    Raises:
        SSOProviderConfigError: If the ADFS provider is missing the required id_token.
    """
    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    client = await get_http_client()

    # Treat a missing/empty token response as an empty mapping for lookups.
    token_payload: Dict[str, Any] = token_data or {}

    verified_id_token_claims: Optional[Dict[str, Any]] = None
    upstream_claims = token_payload.get("_verified_id_token_claims")
    if isinstance(upstream_claims, dict):
        verified_id_token_claims = upstream_claims
    elif not token_payload.get("_adfs_verification_attempted"):
        # When the ADFS marker is set, verification was already attempted upstream in
        # handle_oauth_callback_with_tokens and failed (on-prem ADFS without JWKS),
        # so the redundant re-attempt is skipped.
        if provider.provider_type == "oidc" and isinstance(token_payload.get("id_token"), str):
            if expected_nonce is None:
                logger.warning("Skipping OIDC id_token fallback verification for provider %s because expected nonce is unavailable.", provider.id)
            else:
                verified_id_token_claims = await self._verify_oidc_id_token(provider, token_payload["id_token"], expected_nonce=expected_nonce)

    # ADFS does not support GET on the userinfo endpoint.
    # Extract user info directly from the ID token instead.
    # Prefer verified claims when available (ADFS with discoverable JWKS);
    # fall back to unverified decode for on-prem ADFS without JWKS.
    if provider.id == ADFS_PROVIDER_ID:
        # Use verified claims if OIDC verification succeeded upstream
        if verified_id_token_claims:
            logger.debug("ADFS: using verified id_token claims (keys: %s)", list(verified_id_token_claims.keys()))
            return self._normalize_user_info(provider, verified_id_token_claims)

        raw_id_token = token_payload.get("id_token")
        if not isinstance(raw_id_token, str):
            logger.error("ADFS provider requires id_token but none was provided in token_data")
            raise SSOProviderConfigError("ADFS provider requires id_token in token response")

        # Fall back to unverified decode (on-prem ADFS without JWKS).
        # The id_token was received server-to-server over TLS from the token
        # endpoint, so we trust the transport. We still validate aud, iss,
        # exp, and nonce as defense-in-depth against token confusion.
        id_token_claims = self._decode_jwt_claims(raw_id_token)
        if not id_token_claims:
            logger.error("Failed to decode ADFS ID token claims")
            return None

        # Validate audience — must match our client_id
        token_aud = id_token_claims.get("aud")
        aud_match = (provider.client_id in token_aud) if isinstance(token_aud, list) else (token_aud == provider.client_id)
        if not aud_match:
            logger.error("ADFS id_token audience mismatch: expected %s, got %s", provider.client_id, token_aud)
            return None

        # Validate issuer — must match configured issuer
        if provider.issuer and id_token_claims.get("iss") != provider.issuer:
            logger.error("ADFS id_token issuer mismatch: expected %s, got %s", provider.issuer, id_token_claims.get("iss"))
            return None

        # Validate expiration — reject missing or non-numeric exp
        # Standard
        import time  # pylint: disable=import-outside-toplevel

        exp = id_token_claims.get("exp")
        if not isinstance(exp, (int, float)):
            logger.error("ADFS id_token missing or malformed exp claim: %r", exp)
            return None
        if exp < time.time():
            logger.error("ADFS id_token has expired (exp=%s)", exp)
            return None

        # Validate nonce — prevents replay attacks
        if expected_nonce and id_token_claims.get("nonce") != expected_nonce:
            logger.error("ADFS id_token nonce mismatch")
            return None

        logger.debug("ADFS: using decoded id_token claims with validated aud/iss/exp/nonce (keys: %s)", list(id_token_claims.keys()))
        return self._normalize_user_info(provider, id_token_claims)

    response = await client.get(provider.userinfo_url, headers={"Authorization": f"Bearer {access_token}"})

    if response.status_code == 200:
        user_data = response.json()
        await self._enrich_user_data_from_claims(provider, user_data, access_token, verified_id_token_claims)
        return self._normalize_user_info(provider, user_data)

    # Keycloak can issue tokens using the browser-facing issuer URL; if userinfo
    # is called on a different host/port, Keycloak may reject the token with 401.
    # Only fall back to id_token claims for 401 with split-host configuration.
    # Other errors (403=revoked, 500=server error) must NOT fall back — the user
    # should be denied access, not silently authenticated via stale id_token claims.
    if provider.id == "keycloak" and verified_id_token_claims and response.status_code == 401:
        metadata = provider.provider_metadata or {}
        public_base_url = metadata.get("public_base_url")
        if public_base_url and public_base_url != metadata.get("base_url"):
            logger.warning(
                "User info request returned 401 for keycloak with split-host config (public=%s, internal=%s). Falling back to id_token claims.",
                public_base_url,
                metadata.get("base_url"),
            )
            return self._normalize_user_info(provider, verified_id_token_claims)

    logger.error(f"User info request failed for {provider.name}: HTTP {response.status_code} - {response.text}")

    return None
1509 def _get_default_email_domain(self, provider: SSOProvider) -> Optional[str]:
1510 """Get default email domain from provider metadata or global settings.
1512 Args:
1513 provider: SSO provider instance
1515 Returns:
1516 Default email domain string, or None if not configured
1517 """
1518 metadata = provider.provider_metadata or {}
1519 default_domain = metadata.get("default_email_domain")
1520 if not default_domain:
1521 default_domain = settings.sso_adfs_default_email_domain
1522 return default_domain
1524 def _normalize_adfs_email(self, raw_email: str, default_domain: Optional[str]) -> Optional[str]:
1525 """Normalize ADFS email/UPN to standard email format.
1527 ADFS may return UPN in various formats:
1528 - user@domain.com (already valid email)
1529 - DOMAIN\\username (Windows domain format)
1530 - username (plain username without domain)
1532 Args:
1533 raw_email: Raw email/UPN string from ADFS claims
1534 default_domain: Default email domain to append if needed
1536 Returns:
1537 Normalized email address, or None if normalization fails
1538 """
1539 if not raw_email:
1540 return None
1542 raw_email_str = str(raw_email).strip()
1544 # Already a valid email format
1545 if "@" in raw_email_str and "." in raw_email_str.split("@")[-1]:
1546 return raw_email_str
1548 # Handle DOMAIN\username format
1549 if "\\" in raw_email_str:
1550 username_part = raw_email_str.split("\\")[-1]
1551 if default_domain:
1552 return f"{username_part}@{default_domain}"
1553 logger.warning("ADFS UPN in DOMAIN\\username format but no default_email_domain configured")
1554 return None
1556 # Handle plain username without domain
1557 if "@" not in raw_email_str:
1558 if default_domain:
1559 return f"{raw_email_str}@{default_domain}"
1560 logger.warning("ADFS UPN is plain username but no default_email_domain configured")
1561 return None
1563 return raw_email_str
1565 @staticmethod
1566 def _extract_groups_and_roles(user_data: Dict[str, Any], groups_claim: str = "groups") -> list[str]:
1567 """Extract groups and roles from user data into a unified list.
1569 Args:
1570 user_data: Raw user data from provider.
1571 groups_claim: Claim key for groups (default: ``"groups"``).
1573 Returns:
1574 Combined list of group and role strings.
1575 """
1576 groups: list[str] = []
1578 if groups_claim in user_data:
1579 groups_value = user_data.get(groups_claim, [])
1580 if isinstance(groups_value, list):
1581 groups.extend(g for g in groups_value if isinstance(g, str))
1582 elif isinstance(groups_value, str):
1583 groups.append(groups_value)
1585 if "roles" in user_data:
1586 roles_value = user_data.get("roles", [])
1587 if isinstance(roles_value, list):
1588 groups.extend(r for r in roles_value if isinstance(r, str))
1589 elif isinstance(roles_value, str):
1590 groups.append(roles_value)
1592 return groups
@staticmethod
def _build_normalized_user_info(
    user_data: Dict[str, Any],
    provider_name: str,
    groups: list[str],
    *,
    email: Union[Optional[str], _Unset] = _UNSET,
    full_name: Union[Optional[str], _Unset] = _UNSET,
    avatar_url: Union[Optional[str], _Unset] = _UNSET,
    provider_id: Union[Optional[Any], _Unset] = _UNSET,
    username: Union[Optional[str], _Unset] = _UNSET,
    extra: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build a normalized user-info dict with common fields.

    Provider-specific branches call this with overrides only for fields
    that deviate from the standard OIDC claim mapping. The
    ``email_verified`` claim is propagated only when the IdP explicitly
    includes it so that ``_is_email_verified_claim``'s
    absent-means-pass-through logic applies correctly.

    Pass ``None`` explicitly to force a field to ``None``; omit the
    argument (default ``_UNSET``) to fall back to the standard claim.

    When no username override is given, the username is
    ``preferred_username`` or, failing that, the local part of ``email``
    (an empty string when the email is missing or ``None``).

    Args:
        user_data: Raw user data from the provider.
        provider_name: Provider identifier for the ``"provider"`` field.
        groups: Pre-extracted groups list (deduplicated, first-seen order kept).
        email: Override for ``user_data["email"]``.
        full_name: Override for ``user_data["name"]``.
        avatar_url: Override for ``user_data["picture"]``.
        provider_id: Override for ``user_data["sub"]``.
        username: Override for the computed username.
        extra: Additional keys merged into the result.

    Returns:
        Normalized user-info dict.
    """
    if username is _UNSET:
        # ``get("email", "")`` would still return None for an explicit null
        # claim and crash on ``.split``; coerce falsy values to "" instead.
        resolved_username = user_data.get("preferred_username") or (user_data.get("email") or "").split("@")[0]
    else:
        resolved_username = username

    normalized: Dict[str, Any] = {
        "email": email if email is not _UNSET else user_data.get("email"),
        "full_name": full_name if full_name is not _UNSET else user_data.get("name"),
        "avatar_url": avatar_url if avatar_url is not _UNSET else user_data.get("picture"),
        "provider_id": provider_id if provider_id is not _UNSET else user_data.get("sub"),
        "username": resolved_username,
        "provider": provider_name,
        # dict.fromkeys de-duplicates while keeping a deterministic order,
        # unlike set(), whose string ordering varies between processes.
        "groups": list(dict.fromkeys(groups)),
    }
    if "email_verified" in user_data:
        normalized["email_verified"] = user_data["email_verified"]
    if extra:
        normalized.update(extra)
    return normalized
def _normalize_user_info(self, provider: SSOProvider, user_data: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize user info from different providers to common format.

    Dispatches on ``provider.id``. GitHub responses are mapped by hand; every
    other provider is delegated to ``self._build_normalized_user_info`` with
    provider-specific overrides. Provider ids without a dedicated branch use
    the generic OIDC handling at the end.

    Claims that are present but ``None`` (for example ``"email": None``) are
    treated the same as absent claims when deriving a username, so a sparse
    userinfo payload no longer raises ``AttributeError``.

    Args:
        provider: SSO provider configuration
        user_data: Raw user data from provider

    Returns:
        Normalized user info dict
    """
    # Handle GitHub provider
    if provider.id == "github":
        normalized: Dict[str, Any] = {
            "email": user_data.get("email"),
            "full_name": user_data.get("name") or user_data.get("login"),
            "avatar_url": user_data.get("avatar_url"),
            "provider_id": user_data.get("id"),
            "username": user_data.get("login"),
            "provider": "github",
            "organizations": user_data.get("organizations", []),
        }
        # GitHub /user responses do not include an email_verified claim. Preserve
        # backward-compatible login behavior by only enforcing verification when
        # a concrete claim value is present.
        github_email_verified = user_data.get("email_verified")
        if github_email_verified is None:
            github_email_verified = user_data.get("verified")
        if github_email_verified is not None:
            normalized["email_verified"] = github_email_verified
        return normalized

    # Handle Google provider
    if provider.id == "google":
        return self._build_normalized_user_info(
            user_data,
            "google",
            [],
            # ``or ""`` covers both a missing key and an explicit None value.
            username=(user_data.get("email") or "").split("@")[0],
        )

    metadata = provider.provider_metadata or {}
    groups_claim = metadata.get("groups_claim", "groups")

    # Handle IBM Verify provider
    if provider.id == "ibm_verify":
        groups = self._extract_groups_and_roles(user_data, groups_claim)
        return self._build_normalized_user_info(user_data, "ibm_verify", groups)

    # Handle Okta provider
    if provider.id == "okta":
        groups = self._extract_groups_and_roles(user_data, groups_claim)
        return self._build_normalized_user_info(user_data, "okta", groups)

    # Handle Keycloak provider with role mapping
    if provider.id == "keycloak":
        username_claim = metadata.get("username_claim", "preferred_username")
        email_claim = metadata.get("email_claim", "email")

        groups: list[str] = []

        # Extract realm roles
        if metadata.get("map_realm_roles"):
            realm_access = user_data.get("realm_access") or {}
            groups.extend(realm_access.get("roles") or [])

        # Extract client roles, prefixed with the client they belong to
        if metadata.get("map_client_roles"):
            resource_access = user_data.get("resource_access") or {}
            for client, access in resource_access.items():
                groups.extend(f"{client}:{role}" for role in ((access or {}).get("roles") or []))

        # Extract groups from custom claim (ignored unless it is a list)
        if groups_claim in user_data:
            custom_groups = user_data.get(groups_claim, [])
            if isinstance(custom_groups, list):
                groups.extend(custom_groups)

        return self._build_normalized_user_info(
            user_data,
            "keycloak",
            groups,
            email=user_data.get(email_claim),
            # Fall back to the email local part; tolerate a None email claim.
            username=user_data.get(username_claim) or (user_data.get(email_claim) or "").split("@")[0],
        )

    # Handle Microsoft Entra ID provider with role mapping
    if provider.id == "entra":
        # Microsoft's userinfo endpoint often omits the email claim
        # Fallback: preferred_username (UPN) or upn claim
        email = user_data.get("email") or user_data.get("preferred_username") or user_data.get("upn")
        username = user_data.get("preferred_username") or (email.split("@")[0] if email else None)

        groups = self._extract_groups_and_roles(user_data, groups_claim)
        return self._build_normalized_user_info(
            user_data,
            "entra",
            groups,
            email=email,
            full_name=user_data.get("name") or email,
            provider_id=user_data.get("sub") or user_data.get("oid"),
            username=username,
        )

    # Handle ADFS provider
    if provider.id == ADFS_PROVIDER_ID:
        # ADFS uses UPN (User Principal Name) as the primary identifier.
        # Claim priority: email > preferred_username > upn > unique_name
        raw_email = user_data.get("email") or user_data.get("preferred_username") or user_data.get("upn") or user_data.get("unique_name")

        email = None
        if raw_email:
            default_domain = self._get_default_email_domain(provider)
            email = self._normalize_adfs_email(str(raw_email), default_domain)

        username = None
        if email:
            username = email.split("@")[0]
        elif raw_email:
            # No usable email: derive a username from DOMAIN\user, user@host, or the raw value.
            raw_str = str(raw_email).strip()
            if "\\" in raw_str:
                username = raw_str.split("\\")[-1]
            elif "@" in raw_str:
                username = raw_str.split("@")[0]
            else:
                username = raw_str

        full_name = user_data.get("name")
        if not full_name and user_data.get("given_name") and user_data.get("family_name"):
            full_name = f"{user_data.get('given_name')} {user_data.get('family_name')}"

        adfs_groups = self._extract_groups_and_roles(user_data, groups_claim)

        return self._build_normalized_user_info(
            user_data,
            ADFS_PROVIDER_ID,
            adfs_groups,
            email=email,
            full_name=full_name or email or username,
            provider_id=user_data.get("sub") or user_data.get("oid") or email or username,
            username=username or email,
            extra={"email_verified": True},  # ADFS tokens are trusted after successful authentication
        )

    # Generic OIDC format for all other providers.
    groups = self._extract_groups_and_roles(user_data, groups_claim)
    return self._build_normalized_user_info(user_data, provider.id, groups)
def _reset_pending_approval(self, pending: PendingUserApproval, incoming_provider: str, user_info: Dict[str, Any]) -> None:
    """Return an existing approval record to the ``"pending"`` state.

    The request timestamp is renewed, the expiry is pushed 30 days out, the
    provider and SSO metadata are replaced with the incoming values, and any
    earlier approval/rejection details are cleared. The change is committed
    on ``self.db``.

    Args:
        pending: Existing pending approval record.
        incoming_provider: SSO provider identifier.
        user_info: Normalized user info from SSO provider.
    """
    # Restart the approval window.
    pending.status = "pending"
    pending.requested_at = utc_now()
    pending.expires_at = utc_now() + timedelta(days=30)

    # Record the details of the current login attempt.
    pending.auth_provider = incoming_provider
    pending.sso_metadata = user_info

    # Wipe whatever decision was recorded previously.
    for decision_field in ("approved_by", "approved_at", "rejection_reason", "admin_notes"):
        setattr(pending, decision_field, None)

    self.db.commit()
@staticmethod
def _should_sync_roles(provider_id: Optional[str], provider_metadata: Dict[str, Any]) -> bool:
    """Decide whether RBAC role synchronization runs for this login.

    An explicit ``sync_roles`` key in ``provider_metadata`` always wins.
    Without it, the Entra provider falls back to the legacy
    ``settings.sso_entra_sync_roles_on_login`` flag when that setting exists.
    In every other case role sync is enabled.

    Args:
        provider_id: SSO provider identifier (e.g. ``"entra"``).
        provider_metadata: Provider metadata dict from DB/config.

    Returns:
        True if role sync should proceed, False otherwise.
    """
    if "sync_roles" not in provider_metadata:
        # No provider-level override: consult the legacy Entra-only setting.
        legacy_flag_applies = provider_id == "entra" and hasattr(settings, "sso_entra_sync_roles_on_login")
        if legacy_flag_applies:
            return bool(settings.sso_entra_sync_roles_on_login)
        return True

    return bool(provider_metadata["sync_roles"])
def _check_pending_approval(self, email: str, incoming_provider: str, user_info: Dict[str, Any]) -> bool:
    """Check admin approval state for a new SSO user.

    Returns ``True`` only when the user has an active, non-expired
    ``"approved"`` record and may proceed to account creation. All
    other states return ``False`` (caller should deny access).

    Side effects on ``self.db``:

    * no record: a new approval request is created with a 30-day expiry;
    * ``"pending"`` and expired: the record is marked expired, then reset
      to a fresh pending request;
    * ``"approved"`` and expired: the record is marked ``"expired"``;
    * ``"expired"``: the record is reset to a fresh pending request.

    Args:
        email: Normalized user email.
        incoming_provider: SSO provider identifier.
        user_info: Normalized user info from SSO provider.

    Returns:
        True when user is approved for creation, False otherwise.
    """
    pending = self.db.execute(select(PendingUserApproval).where(PendingUserApproval.email == email)).scalar_one_or_none()

    if not pending:
        pending = PendingUserApproval(
            email=email,
            # The "full_name" key may be present with a None value, in which
            # case ``dict.get``'s default is not used; fall back explicitly.
            full_name=user_info.get("full_name") or email,
            auth_provider=incoming_provider,
            sso_metadata=user_info,
            expires_at=utc_now() + timedelta(days=30),
        )
        self.db.add(pending)
        self.db.commit()
        logger.info(f"Created pending approval request for SSO user: {SecurityValidator.sanitize_log_message(email)}")
        return False

    if pending.status == "pending":
        if pending.is_expired():
            # Record the expiry first, then reopen the request with fresh metadata.
            pending.status = "expired"
            self.db.commit()
            self._reset_pending_approval(pending, incoming_provider, user_info)
            logger.info(f"Refreshed expired pending approval request for SSO user: {SecurityValidator.sanitize_log_message(email)}")
        return False

    if pending.status == "rejected":
        return False

    if pending.status == "approved":
        if pending.is_expired():
            # An approval that was never used within its window no longer counts.
            pending.status = "expired"
            self.db.commit()
            return False
        return True

    if pending.status == "expired":
        self._reset_pending_approval(pending, incoming_provider, user_info)
        logger.info(f"Renewed expired pending approval request for SSO user: {SecurityValidator.sanitize_log_message(email)}")
        return False

    if pending.status == "completed":
        return False

    logger.warning(f"Unknown SSO pending approval status '{pending.status}' for user {SecurityValidator.sanitize_log_message(email)}. Denying by default.")
    return False
async def authenticate_or_create_user(self, user_info: Dict[str, Any]) -> Optional[str]:
    """Authenticate existing user or create new user from SSO info.

    Flow:

    1. Reject logins without an email, with an unverified email claim, or
       whose email domain is outside the provider's ``trusted_domains``.
    2. Existing user: refuse when the account belongs to a different auth
       provider; otherwise refresh name / verification / last login,
       reconcile ``is_admin`` (only SSO-originated admin grants can be
       revoked), then sync roles and team mapping.
    3. New user: requires ``provider.auto_create_users`` and, when
       ``settings.sso_require_admin_approval`` is set, an approved request.
       The account is created with a random password, then roles and team
       mapping are applied and the approval request is marked completed.
    4. A session JWT is issued from locally captured values.

    Args:
        user_info: Normalized user info from SSO provider

    Returns:
        JWT token for authenticated user or None if failed
    """
    raw_email = user_info.get("email")
    if not raw_email:
        logger.warning("SSO authenticate_or_create_user: no email in user_info from provider '%s'. User cannot be authenticated without an email address.", user_info.get("provider", "unknown"))
        return None

    email = str(raw_email).strip().lower()
    if not email:
        logger.warning("SSO authenticate_or_create_user: email is empty after normalization from provider '%s'.", user_info.get("provider", "unknown"))
        return None

    if not self._is_email_verified_claim(user_info):
        logger.warning(
            "SSO authenticate_or_create_user: unverified email claim for provider '%s' and email '%s'.",
            user_info.get("provider", "unknown"),
            SecurityValidator.sanitize_log_message(email),
        )
        return None

    incoming_provider = str(user_info.get("provider", "sso")).strip().lower() or "sso"
    provider = self.get_provider(incoming_provider)

    # Enforce trusted-domain policy consistently for both existing and new users.
    trusted_domains = getattr(provider, "trusted_domains", None) if provider else None
    if trusted_domains:
        # The domain is whatever follows the LAST "@"; taking the segment after
        # the first "@" would let "x@trusted.example@evil.example" pass.
        domain = email.rsplit("@", 1)[1].lower() if "@" in email else ""
        normalized_trusted_domains = [d.lower() for d in trusted_domains if isinstance(d, str)]
        if domain not in normalized_trusted_domains:
            logger.warning(
                "SSO authenticate_or_create_user: email domain '%s' is not allowed for provider '%s'.",
                SecurityValidator.sanitize_log_message(domain),
                incoming_provider,
            )
            return None

    # "full_name" may be present with a None value, so ``dict.get``'s default
    # is not enough to guarantee a usable display name.
    incoming_full_name = user_info.get("full_name") or email

    # Use stable local values for JWT payload generation to avoid lazy-loading
    # expired ORM attributes after commit/flush boundaries.
    resolved_email = email
    resolved_full_name = incoming_full_name
    resolved_auth_provider = incoming_provider
    resolved_is_admin = False

    # Check if user exists
    user = await self.auth_service.get_user_by_email(email)

    if user:
        current_full_name = user.full_name or resolved_full_name
        current_auth_provider = str(user.auth_provider or resolved_auth_provider).strip().lower()
        current_is_admin = bool(user.is_admin)
        current_admin_origin = user.admin_origin

        if user.auth_provider and current_auth_provider != incoming_provider:
            logger.warning(
                "SSO authenticate_or_create_user: account-linking required for email '%s' (existing provider='%s', incoming='%s').",
                SecurityValidator.sanitize_log_message(email),
                current_auth_provider,
                incoming_provider,
            )
            return None

        provider_id: Optional[str] = None
        provider_metadata: Dict[str, Any] = {}
        provider_ctx: Optional[Any] = None
        if provider:
            provider_id = provider.id
            provider_metadata = provider.provider_metadata or {}
            provider_ctx = SSOProviderContext(id=provider_id, provider_metadata=provider_metadata)

        # Update user info from SSO
        if user_info.get("full_name") and user_info["full_name"] != current_full_name:
            user.full_name = user_info["full_name"]
            current_full_name = user_info["full_name"]

        # Initialize auth_provider for legacy accounts without provider metadata.
        if not user.auth_provider:
            user.auth_provider = incoming_provider
            current_auth_provider = incoming_provider

        # Persist verification status from provider claims.
        user.email_verified = self._is_email_verified_claim(user_info)
        user.last_login = utc_now()

        # Synchronize is_admin status based on current group membership
        # Track origin to support both promotion AND demotion for SSO-granted admins
        # Manual/API grants are "sticky" - never auto-demoted by SSO
        # Only users with admin_origin="sso" can be demoted on login
        if provider_ctx:
            should_be_admin = self._should_user_be_admin(email, user_info, provider_ctx)
            if should_be_admin:
                # Grant admin access
                if not current_is_admin:
                    logger.info(f"Upgrading is_admin to True for {SecurityValidator.sanitize_log_message(email)} based on SSO admin groups")
                    user.is_admin = True
                    # Track that admin was granted via SSO (only set on initial grant)
                    user.admin_origin = "sso"
                    current_is_admin = True
                # Do NOT change admin_origin if already admin - preserve manual/API grants
            elif current_is_admin and current_admin_origin == "sso":
                # User was SSO admin but no longer in admin groups - revoke access
                logger.info(f"Revoking is_admin for {SecurityValidator.sanitize_log_message(email)} - removed from SSO admin groups")
                user.is_admin = False
                user.admin_origin = None
                current_is_admin = False

        self.db.commit()

        if provider_ctx and self._should_sync_roles(provider_id, provider_metadata):
            role_assignments = await self._map_groups_to_roles(email, user_info.get("groups", []), provider_ctx)
            await self._sync_user_roles(email, role_assignments, provider_ctx)
        await self._apply_team_mapping(email, user_info, provider)

        user_email = getattr(user, "email", None)
        if isinstance(user_email, str) and user_email.strip():
            resolved_email = user_email.strip().lower()
        resolved_full_name = current_full_name
        resolved_auth_provider = current_auth_provider
        resolved_is_admin = current_is_admin
    else:
        # Auto-create user if enabled
        if not provider or not provider.auto_create_users:
            return None

        provider_id = provider.id
        provider_metadata = provider.provider_metadata or {}
        provider_ctx = SSOProviderContext(id=provider_id, provider_metadata=provider_metadata)

        # Check if admin approval is required
        if settings.sso_require_admin_approval:
            approval_result = self._check_pending_approval(email, incoming_provider, user_info)
            if approval_result is not True:
                return None  # Blocked by approval workflow

        # Create new user (either no approval required, or approval already granted)
        # Generate a secure random password for SSO users (they won't use it)

        random_password = "".join(secrets.choice(string.ascii_letters + string.digits + string.punctuation) for _ in range(32))

        # Determine if user should be admin based on domain/organization
        is_admin = self._should_user_be_admin(email, user_info, provider_ctx)

        user = await self.auth_service.create_user(
            email=email,
            password=random_password,  # Random password for SSO users (not used)
            full_name=incoming_full_name,
            is_admin=is_admin,
            auth_provider=incoming_provider,
        )
        if not user:
            return None

        user_email = getattr(user, "email", None)
        if isinstance(user_email, str) and user_email.strip():
            resolved_email = user_email.strip().lower()
        resolved_full_name = incoming_full_name
        resolved_auth_provider = incoming_provider
        resolved_is_admin = is_admin

        if self._should_sync_roles(provider_id, provider_metadata):
            role_assignments = await self._map_groups_to_roles(email, user_info.get("groups", []), provider_ctx)
            if role_assignments:
                await self._sync_user_roles(email, role_assignments, provider_ctx)
        await self._apply_team_mapping(email, user_info, provider)

        # If user was created from approved request, mark request as used
        if settings.sso_require_admin_approval:
            pending = self.db.execute(select(PendingUserApproval).where(and_(PendingUserApproval.email == email, PendingUserApproval.status == "approved"))).scalar_one_or_none()
            if pending:
                # Mark as used (we could delete or keep for audit trail)
                pending.status = "completed"
                self.db.commit()

    # Generate JWT token for user — session token (teams resolved server-side)
    token_data = {
        "sub": resolved_email,
        "email": resolved_email,
        "full_name": resolved_full_name,
        "auth_provider": resolved_auth_provider,
        "iat": int(utc_now().timestamp()),
        "user": {
            "email": resolved_email,
            "full_name": resolved_full_name,
            "is_admin": resolved_is_admin,
            "auth_provider": resolved_auth_provider,
        },
        "token_use": "session",  # nosec B105 - token type marker, not a password
        # Scopes
        "scopes": {"server_id": None, "permissions": ["*"] if resolved_is_admin else [], "ip_restrictions": [], "time_restrictions": {}},
    }

    # Create JWT token
    token = await create_jwt_token(token_data)
    return token
def _should_user_be_admin(self, email: str, user_info: Dict[str, Any], provider: SSOProviderContext) -> bool:
    """Determine if SSO user should be granted admin privileges.

    Admin is granted when any of the following holds:

    * the email domain is listed in ``settings.sso_auto_admin_domains``;
    * GitHub provider: one of the user's organizations is listed in
      ``settings.sso_github_admin_orgs``;
    * Google provider: the email domain is listed in
      ``settings.sso_google_admin_domains``;
    * Entra provider: one of the user's groups is listed in
      ``settings.sso_entra_admin_groups``.

    All comparisons are case-insensitive.

    Args:
        email: User's email address
        user_info: Normalized user info from SSO provider
        provider: SSO provider configuration

    Returns:
        True if user should be admin, False otherwise
    """
    # Validate email format — reject admin checks for invalid emails
    if not email or "@" not in email:
        logger.warning("Invalid email format for admin check: %r. Rejecting admin privilege.", email)
        return False

    # Check domain-based admin assignment. The domain is the part after the
    # LAST "@": splitting on the first one would read "admin.example" out of
    # "x@admin.example@evil.example" and wrongly grant admin.
    domain = email.rsplit("@", 1)[1].lower()
    if domain in {d.lower() for d in settings.sso_auto_admin_domains}:
        return True

    # Check provider-specific admin assignment
    if provider.id == "github" and settings.sso_github_admin_orgs:
        github_admin_orgs = {o.lower() for o in settings.sso_github_admin_orgs}
        github_orgs = user_info.get("organizations") or []
        # Non-string entries can never match a configured name; skip them instead of crashing.
        if any(isinstance(org, str) and org.lower() in github_admin_orgs for org in github_orgs):
            return True

    if provider.id == "google" and settings.sso_google_admin_domains:
        if domain in {d.lower() for d in settings.sso_google_admin_domains}:
            return True

    # Check EntraID admin groups
    if provider.id == "entra" and settings.sso_entra_admin_groups:
        entra_admin_groups = {g.lower() for g in settings.sso_entra_admin_groups}
        user_groups = user_info.get("groups") or []
        if any(isinstance(group, str) and group.lower() in entra_admin_groups for group in user_groups):
            return True

    return False
async def _map_groups_to_roles(self, user_email: str, user_groups: List[str], provider: SSOProviderContext) -> List[Dict[str, Any]]:
    """Translate SSO group membership into ContextForge RBAC role assignments.

    Mapping configuration is read from ``provider.provider_metadata``
    (``role_mappings``, ``default_role``,
    ``resolve_team_scope_to_personal_team``). For the Entra provider the
    legacy ``settings.sso_entra_*`` values fill in whatever the metadata
    leaves unset. When nothing at all is configured the result is empty.

    Args:
        user_email: User's email address
        user_groups: List of groups from SSO provider
        provider: SSO provider configuration

    Returns:
        List of role assignments: [{"role_name": str, "scope": str, "scope_id": Optional[str]}]
    """
    # pylint: disable=import-outside-toplevel
    # First-Party
    from mcpgateway.services.role_service import RoleService

    assignments: List[Dict[str, Any]] = []

    # Provider-level configuration
    meta = provider.provider_metadata or {}
    group_to_role = meta.get("role_mappings", {})
    fallback_role_name: Optional[str] = meta.get("default_role")
    team_scope_uses_personal_team = bool(meta.get("resolve_team_scope_to_personal_team", False))
    fallback_configured = isinstance(fallback_role_name, str) and bool(fallback_role_name.strip())

    is_entra = provider.id == "entra"
    entra_admin_groups_active = is_entra and settings.sso_entra_admin_groups

    if is_entra:
        # Legacy Entra settings only apply where the metadata is silent.
        if not group_to_role and settings.sso_entra_role_mappings:
            group_to_role = settings.sso_entra_role_mappings
        if not fallback_configured and settings.sso_entra_default_role:
            fallback_role_name = settings.sso_entra_default_role
            fallback_configured = True

    # Nothing configured at all: no role sync to perform.
    if not (group_to_role or entra_admin_groups_active or fallback_configured):
        logger.debug(f"No role mappings configured for provider {provider.id}, skipping role sync")
        return assignments

    cached_team_id: Optional[str] = None
    team_lookup_done = False

    async def _personal_team_scope_id(role_scope: str) -> Optional[str]:
        """Return the user's personal-team id for team-scoped roles, if enabled.

        The lookup runs at most once per call of the enclosing method; later
        invocations reuse the first outcome, including a failed one.

        Args:
            role_scope: Role scope value from mapping metadata.

        Returns:
            Optional[str]: Personal team id when resolution succeeds, else None.
        """
        nonlocal cached_team_id, team_lookup_done

        if not (role_scope == "team" and team_scope_uses_personal_team):
            return None

        if not team_lookup_done:
            team_lookup_done = True
            try:
                # First-Party
                from mcpgateway.services.personal_team_service import PersonalTeamService

                personal_team = await PersonalTeamService(self.db).get_personal_team(user_email)
                cached_team_id = personal_team.id if personal_team else None
                if not cached_team_id:
                    logger.warning(f"Could not resolve personal team for {SecurityValidator.sanitize_log_message(user_email)}; skipping team-scoped SSO role mapping")
            except Exception as e:
                logger.error(f"Failed to resolve personal team for {SecurityValidator.sanitize_log_message(user_email)}: {e}. All team-scoped SSO role assignments will be skipped for this login.")
                cached_team_id = None

        return cached_team_id

    def _is_admin_alias(name: Any) -> bool:
        """Tell whether a mapped role name denotes the admin role.

        Args:
            name: Role name taken from the role mappings.

        Returns:
            bool: True for the ``"admin"`` shorthand or the configured admin role name.
        """
        return name in ["admin", settings.default_admin_role]

    # Handle EntraID admin groups -> admin role (a single assignment is enough)
    if entra_admin_groups_active:
        admin_groups_lower = [g.lower() for g in settings.sso_entra_admin_groups]
        if any(group.lower() in admin_groups_lower for group in user_groups):
            assignments.append({"role_name": settings.default_admin_role, "scope": "global", "scope_id": None})
            logger.debug(f"Mapped EntraID admin group to {settings.default_admin_role} role for {SecurityValidator.sanitize_log_message(user_email)}")

    # Collect every non-admin role name that has to be fetched
    names_to_fetch = set()
    for group in user_groups:
        if group not in group_to_role:
            continue
        mapped_name = group_to_role[group]
        if not _is_admin_alias(mapped_name):
            names_to_fetch.add(mapped_name)

    if fallback_configured and fallback_role_name:
        names_to_fetch.add(fallback_role_name)

    # Pre-fetch roles by name: team scope is tried first, then global
    role_service = RoleService(self.db)
    roles_by_name: Dict[str, Any] = {}
    for name in names_to_fetch:
        fetched = await role_service.get_role_by_name(name, scope="team") or await role_service.get_role_by_name(name, scope="global")
        if fetched:
            roles_by_name[name] = fetched

    # Process role mappings for ALL providers
    for group in user_groups:
        if group not in group_to_role:
            continue
        role_name = group_to_role[group]

        # "admin" shorthand or the configured admin role name
        if _is_admin_alias(role_name):
            assignments.append({"role_name": settings.default_admin_role, "scope": "global", "scope_id": None})
            logger.debug(f"Mapped group to {settings.default_admin_role} role for {SecurityValidator.sanitize_log_message(user_email)}")
            continue

        role = roles_by_name.get(role_name)
        if not role:
            logger.warning(f"Role '{role_name}' not found for group mapping")
            continue

        scope_id = await _personal_team_scope_id(role.scope)
        if role.scope == "team" and team_scope_uses_personal_team and not scope_id:
            # Team-scoped role without a resolvable personal team: skip it.
            continue

        # Avoid duplicate assignments
        already_listed = any(r["role_name"] == role.name and r["scope"] == role.scope and r.get("scope_id") == scope_id for r in assignments)
        if not already_listed:
            assignments.append({"role_name": role.name, "scope": role.scope, "scope_id": scope_id})
            logger.debug(f"Mapped group to role '{role.name}' for {SecurityValidator.sanitize_log_message(user_email)}")

    # Apply default role if no mappings found
    if not assignments and fallback_configured and fallback_role_name:
        default_role = roles_by_name.get(fallback_role_name)
        if default_role:
            scope_id = await _personal_team_scope_id(default_role.scope)
            unresolved_team_scope = default_role.scope == "team" and team_scope_uses_personal_team and not scope_id
            if not unresolved_team_scope:
                assignments.append({"role_name": default_role.name, "scope": default_role.scope, "scope_id": scope_id})
                logger.info(f"Assigned default role '{default_role.name}' to {SecurityValidator.sanitize_log_message(user_email)}")

    return assignments
async def _sync_user_roles(self, user_email: str, role_assignments: List[Dict[str, Any]], _provider: SSOProviderContext) -> None:
    """Bring the user's SSO-granted roles in line with ``role_assignments``.

    Active roles whose ``grant_source`` is ``"sso"`` and that are not in the
    desired set are revoked; desired roles that are missing or inactive are
    assigned with ``grant_source="sso"``. A failing assignment is logged and
    the session rolled back before moving on; if the rollback itself fails,
    the remaining assignments are abandoned.

    Args:
        user_email: User's email address
        role_assignments: List of role assignments to apply
        _provider: SSO provider configuration (reserved for future use)
    """
    # pylint: disable=import-outside-toplevel
    # First-Party
    from mcpgateway.services.role_service import RoleService

    role_service = RoleService(self.db)

    # Roles currently held that were granted through SSO
    active_roles = await role_service.list_user_roles(user_email, include_expired=False)
    sso_granted = [held for held in active_roles if getattr(held, "grant_source", None) == "sso"]

    # Target state as (role name, scope, scope id) tuples
    wanted = {(item["role_name"], item["scope"], item.get("scope_id")) for item in role_assignments}

    # Revoke roles that are no longer in the desired set
    for held in sso_granted:
        if (held.role.name, held.scope, held.scope_id) in wanted:
            continue
        await role_service.revoke_role_from_user(user_email=user_email, role_id=held.role_id, scope=held.scope, scope_id=held.scope_id)
        logger.info(f"Revoked SSO role '{held.role.name}' from {SecurityValidator.sanitize_log_message(user_email)} (no longer in groups)")

    # Assign new roles
    for assignment in role_assignments:
        try:
            # Get role by name
            role = await role_service.get_role_by_name(assignment["role_name"], scope=assignment["scope"])
            if not role:
                logger.warning(f"Role '{assignment['role_name']}' not found, skipping assignment for {SecurityValidator.sanitize_log_message(user_email)}")
                continue

            # Check if assignment already exists
            existing = await role_service.get_user_role_assignment(user_email=user_email, role_id=role.id, scope=assignment["scope"], scope_id=assignment.get("scope_id"))

            already_active = existing and existing.is_active
            if not already_active:
                # Assign role to user
                await role_service.assign_role_to_user(
                    user_email=user_email,
                    role_id=role.id,
                    scope=assignment["scope"],
                    scope_id=assignment.get("scope_id"),
                    granted_by=user_email,
                    grant_source="sso",
                )
                logger.info(f"Assigned SSO role '{role.name}' to {SecurityValidator.sanitize_log_message(user_email)}")

        except Exception as e:
            logger.warning(f"Failed to assign role '{assignment['role_name']}' to {SecurityValidator.sanitize_log_message(user_email)}: {e}", exc_info=True)
            try:
                # Reset the session so the next assignment starts clean.
                self.db.rollback()
            except Exception as rollback_error:
                logger.error(
                    f"Database rollback failed after role assignment error for {SecurityValidator.sanitize_log_message(user_email)}: {rollback_error}. Aborting remaining role assignments."
                )
                break