Coverage for mcpgateway / services / sso_service.py: 97%
403 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/sso_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Single Sign-On (SSO) authentication service for OAuth2 and OIDC providers.
8Handles provider management, OAuth flows, and user authentication.
9"""
11# Future
12from __future__ import annotations
14# Standard
15import base64
16from datetime import timedelta
17import hashlib
18import logging
19import secrets
20import string
21from typing import Any, Dict, List, Optional, Tuple
22import urllib.parse
24# Third-Party
25import orjson
26from sqlalchemy import and_, select
27from sqlalchemy.orm import Session
29# First-Party
30from mcpgateway.config import settings
31from mcpgateway.db import PendingUserApproval, SSOAuthSession, SSOProvider, utc_now
32from mcpgateway.services.email_auth_service import EmailAuthService
33from mcpgateway.services.encryption_service import get_encryption_service
34from mcpgateway.utils.create_jwt_token import create_jwt_token
36# Logger
37logger = logging.getLogger(__name__)
40class SSOService:
41 """Service for managing SSO authentication flows and providers.
43 Handles OAuth2/OIDC authentication flows, provider configuration,
44 and integration with the local user system.
46 Examples:
47 Basic construction and helper checks:
48 >>> from unittest.mock import Mock
49 >>> service = SSOService(Mock())
50 >>> isinstance(service, SSOService)
51 True
52 >>> callable(service.list_enabled_providers)
53 True
54 """
def __init__(self, db: Session):
    """Initialize SSO service with database session.

    Stores the session and builds the two collaborators the other methods
    rely on: the email auth service (used for user lookup and creation) and
    the encryption service (used to encrypt/decrypt provider client secrets).

    Args:
        db: SQLAlchemy database session
    """
    self.db = db
    # User lookup/creation is delegated to the email auth service, sharing this session.
    self.auth_service = EmailAuthService(db)
    # Encryption helper derived from the configured auth encryption secret.
    self._encryption = get_encryption_service(settings.auth_encryption_secret)
66 async def _encrypt_secret(self, secret: str) -> str:
67 """Encrypt a client secret for secure storage.
69 Args:
70 secret: Plain text client secret
72 Returns:
73 Encrypted secret string
74 """
75 return await self._encryption.encrypt_secret_async(secret)
77 async def _decrypt_secret(self, encrypted_secret: str) -> Optional[str]:
78 """Decrypt a client secret for use.
80 Args:
81 encrypted_secret: Encrypted secret string
83 Returns:
84 Plain text client secret
85 """
86 decrypted: str | None = await self._encryption.decrypt_secret_async(encrypted_secret)
87 if decrypted:
88 return decrypted
90 return None
def _decode_jwt_claims(self, token: str) -> Optional[Dict[str, Any]]:
    """Decode JWT token payload without verification.

    This is used to extract claims from ID tokens where we've already
    validated the OAuth flow. The token signature is not verified here
    because the token was received directly from the trusted token endpoint.

    Callers treat the result as a mapping (``.get`` / ``in``), so anything
    that does not decode to a JSON object is reported as a failure rather
    than being returned.

    Args:
        token: JWT token string

    Returns:
        Decoded payload dict, or None if the token is not a string, does not
        have three dot-separated parts, cannot be base64/JSON decoded, or its
        payload is not a JSON object

    Examples:
        >>> from unittest.mock import Mock
        >>> service = SSOService(Mock())
        >>> # Valid JWT structure (header.payload.signature)
        >>> import base64
        >>> payload = base64.urlsafe_b64encode(b'{"sub":"123","groups":["admin"]}').decode().rstrip('=')
        >>> token = f"eyJhbGciOiJSUzI1NiJ9.{payload}.signature"
        >>> claims = service._decode_jwt_claims(token)
        >>> claims is not None
        True
    """
    # A missing or non-string id_token must not raise out of a best-effort decoder.
    if not isinstance(token, str):
        logger.warning("Invalid JWT format: token is not a string")
        return None

    # JWT format: header.payload.signature
    parts = token.split(".")
    if len(parts) != 3:
        logger.warning("Invalid JWT format: expected 3 parts")
        return None

    # JWTs strip base64 padding; restore it to a multiple of four characters.
    payload_b64 = parts[1]
    payload_b64 += "=" * (-len(payload_b64) % 4)

    try:
        claims = orjson.loads(base64.urlsafe_b64decode(payload_b64))
    except (ValueError, orjson.JSONDecodeError, UnicodeDecodeError) as e:
        logger.warning(f"Failed to decode JWT claims: {e}")
        return None

    # Valid JSON that is not an object (list, string, number) is not a claims set.
    if not isinstance(claims, dict):
        logger.warning("Failed to decode JWT claims: payload is not a JSON object")
        return None

    return claims
def list_enabled_providers(self) -> List[SSOProvider]:
    """Fetch every SSO provider whose ``is_enabled`` flag is true.

    Returns:
        List of enabled SSO providers (empty when there are none)

    Examples:
        Returns empty list when DB has no providers:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalars.return_value.all.return_value = []
        >>> service.list_enabled_providers()
        []
    """
    enabled_query = select(SSOProvider).where(SSOProvider.is_enabled.is_(True))
    rows = self.db.execute(enabled_query).scalars().all()
    # Always hand back a plain list, whatever sequence type the result exposes.
    return list(rows)
def get_provider(self, provider_id: str) -> Optional[SSOProvider]:
    """Look up a single SSO provider by its identifier.

    Args:
        provider_id: Provider identifier (e.g., 'github', 'google')

    Returns:
        The matching SSO provider, or None when no row has that ID

    Examples:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> service.get_provider('x') is None
        True
    """
    by_id = select(SSOProvider).where(SSOProvider.id == provider_id)
    return self.db.execute(by_id).scalar_one_or_none()
def get_provider_by_name(self, provider_name: str) -> Optional[SSOProvider]:
    """Look up a single SSO provider by its ``name`` column.

    Args:
        provider_name: Provider name (e.g., 'github', 'google')

    Returns:
        The matching SSO provider, or None when no row has that name

    Examples:
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> service.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> service.get_provider_by_name('github') is None
        True
    """
    by_name = select(SSOProvider).where(SSOProvider.name == provider_name)
    return self.db.execute(by_name).scalar_one_or_none()
async def create_provider(self, provider_data: Dict[str, Any]) -> SSOProvider:
    """Persist a new SSO provider built from ``provider_data``.

    The ``client_secret`` entry is removed from ``provider_data`` and replaced
    by ``client_secret_encrypted`` before the model is constructed, so the
    caller's dict is modified in place.

    Args:
        provider_data: Provider configuration data; must contain ``client_secret``

    Returns:
        Created SSO provider, refreshed from the session after commit

    Examples:
        >>> import asyncio
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> service = SSOService(MagicMock())
        >>> service._encrypt_secret = AsyncMock(side_effect=lambda s: 'ENC(' + s + ')')
        >>> data = {
        ...     'id': 'github', 'name': 'github', 'display_name': 'GitHub', 'provider_type': 'oauth2',
        ...     'client_id': 'cid', 'client_secret': 'sec',
        ...     'authorization_url': 'https://example/auth', 'token_url': 'https://example/token',
        ...     'userinfo_url': 'https://example/user', 'scope': 'user:email'
        ... }
        >>> provider = asyncio.run(service.create_provider(data))
        >>> hasattr(provider, 'id') and provider.id == 'github'
        True
        >>> provider.client_secret_encrypted.startswith('ENC(')
        True
    """
    # Swap the plain text secret for its encrypted form before building the model.
    plaintext_secret = provider_data.pop("client_secret")
    provider_data["client_secret_encrypted"] = await self._encrypt_secret(plaintext_secret)

    new_provider = SSOProvider(**provider_data)

    session = self.db
    session.add(new_provider)
    session.commit()
    session.refresh(new_provider)
    return new_provider
async def update_provider(self, provider_id: str, provider_data: Dict[str, Any]) -> Optional[SSOProvider]:
    """Apply a set of field changes to an existing SSO provider.

    A ``client_secret`` entry, when present, is removed from ``provider_data``
    and replaced by its encrypted form. Only keys that already exist as
    attributes on the provider are applied; ``updated_at`` is always refreshed.

    Args:
        provider_id: Provider identifier
        provider_data: Updated provider data

    Returns:
        Updated SSO provider or None if not found

    Examples:
        >>> import asyncio
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> svc = SSOService(MagicMock())
        >>> # Existing provider object
        >>> existing = SimpleNamespace(id='github', name='github', client_id='old', client_secret_encrypted='X', is_enabled=True)
        >>> svc.get_provider = lambda _id: existing
        >>> svc._encrypt_secret = AsyncMock(side_effect=lambda s: 'ENC-' + s)
        >>> svc.db.commit = lambda: None
        >>> svc.db.refresh = lambda obj: None
        >>> updated = asyncio.run(svc.update_provider('github', {'client_id': 'new', 'client_secret': 'sec'}))
        >>> updated.client_id
        'new'
        >>> updated.client_secret_encrypted
        'ENC-sec'
    """
    target = self.get_provider(provider_id)
    if not target:
        return None

    # A new plain text secret is stored only in encrypted form.
    if "client_secret" in provider_data:
        provider_data["client_secret_encrypted"] = await self._encrypt_secret(provider_data.pop("client_secret"))

    # Copy over only the fields the provider object actually has.
    for field_name, new_value in provider_data.items():
        if not hasattr(target, field_name):
            continue
        setattr(target, field_name, new_value)

    target.updated_at = utc_now()

    session = self.db
    session.commit()
    session.refresh(target)
    return target
def delete_provider(self, provider_id: str) -> bool:
    """Remove an SSO provider configuration if it exists.

    Args:
        provider_id: Provider identifier

    Returns:
        True if deleted, False if not found

    Examples:
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> svc = SSOService(MagicMock())
        >>> svc.db.delete = lambda obj: None
        >>> svc.db.commit = lambda: None
        >>> svc.get_provider = lambda _id: SimpleNamespace(id='github')
        >>> svc.delete_provider('github')
        True
        >>> svc.get_provider = lambda _id: None
        >>> svc.delete_provider('missing')
        False
    """
    doomed = self.get_provider(provider_id)
    if doomed:
        # Delete and commit together; nothing is touched when the lookup misses.
        self.db.delete(doomed)
        self.db.commit()
        return True
    return False
def generate_pkce_challenge(self) -> Tuple[str, str]:
    """Produce a fresh PKCE verifier together with its S256 challenge.

    The verifier is 32 random bytes, URL-safe base64 encoded with padding
    removed; the challenge is the SHA-256 digest of the verifier text,
    encoded the same way.

    Returns:
        Tuple of (code_verifier, code_challenge)

    Examples:
        Generate verifier and challenge:
        >>> from unittest.mock import Mock
        >>> service = SSOService(Mock())
        >>> verifier, challenge = service.generate_pkce_challenge()
        >>> isinstance(verifier, str) and isinstance(challenge, str)
        True
        >>> len(verifier) >= 43
        True
        >>> len(challenge) >= 43
        True
    """
    # Verifier: cryptographically random bytes, base64url without '=' padding.
    random_bytes = secrets.token_bytes(32)
    code_verifier = base64.urlsafe_b64encode(random_bytes).decode("utf-8").rstrip("=")

    # Challenge: SHA-256 over the verifier text, base64url without '=' padding.
    digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
    code_challenge = base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=")

    return code_verifier, code_challenge
def get_authorization_url(self, provider_id: str, redirect_uri: str, scopes: Optional[List[str]] = None) -> Optional[str]:
    """Generate OAuth authorization URL for provider.

    Creates and commits an auth session holding the CSRF state, the PKCE
    code verifier, the optional OIDC nonce and the redirect URI, then
    returns the provider's authorization endpoint with the request
    parameters appended. An endpoint that already carries a query string
    keeps it: the new parameters are joined with ``&`` instead of a second
    ``?``.

    Args:
        provider_id: Provider identifier
        redirect_uri: Callback URI after authorization
        scopes: Optional custom scopes (uses provider default if None)

    Returns:
        Authorization URL or None if provider not found or disabled

    Examples:
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> service = SSOService(MagicMock())
        >>> provider = SimpleNamespace(id='github', is_enabled=True, provider_type='oauth2', client_id='cid', authorization_url='https://example/auth', scope='user:email')
        >>> service.get_provider = lambda _pid: provider
        >>> service.db.add = lambda x: None
        >>> service.db.commit = lambda: None
        >>> url = service.get_authorization_url('github', 'https://app/callback', ['email'])
        >>> isinstance(url, str) and 'client_id=cid' in url and 'state=' in url
        True

        Missing provider returns None:
        >>> service.get_provider = lambda _pid: None
        >>> service.get_authorization_url('missing', 'https://app/callback') is None
        True
    """
    provider = self.get_provider(provider_id)
    if not provider or not provider.is_enabled:
        return None

    # Generate PKCE parameters
    code_verifier, code_challenge = self.generate_pkce_challenge()

    # Generate CSRF state
    state = secrets.token_urlsafe(32)

    # Generate OIDC nonce if applicable
    nonce = secrets.token_urlsafe(16) if provider.provider_type == "oidc" else None

    # Persist the session so the callback can look it up by state and reuse the verifier.
    auth_session = SSOAuthSession(provider_id=provider_id, state=state, code_verifier=code_verifier, nonce=nonce, redirect_uri=redirect_uri)
    self.db.add(auth_session)
    self.db.commit()

    # Build authorization URL
    params = {
        "client_id": provider.client_id,
        "response_type": "code",
        "redirect_uri": redirect_uri,
        "state": state,
        "scope": " ".join(scopes) if scopes else provider.scope,
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
    }

    if nonce:
        params["nonce"] = nonce

    # The endpoint may already include a query component, which must be retained;
    # pick the separator accordingly.
    base_url = provider.authorization_url
    if base_url.endswith(("?", "&")):
        separator = ""
    elif "?" in base_url:
        separator = "&"
    else:
        separator = "?"

    return f"{base_url}{separator}{urllib.parse.urlencode(params)}"
async def handle_oauth_callback(self, provider_id: str, code: str, state: str) -> Optional[Dict[str, Any]]:
    """Complete the OAuth flow for a callback request.

    Looks up the auth session matching ``state`` and ``provider_id``,
    exchanges the authorization code for tokens, fetches the user's profile
    and removes the auth session once the profile has been obtained (or an
    exception occurred).

    Args:
        provider_id: Provider identifier
        code: Authorization code from callback
        state: CSRF state parameter

    Returns:
        User info dict or None if authentication failed

    Examples:
        Happy-path with patched exchanges and user info:
        >>> import asyncio
        >>> from types import SimpleNamespace
        >>> from unittest.mock import MagicMock
        >>> svc = SSOService(MagicMock())
        >>> # Mock DB auth session lookup
        >>> provider = SimpleNamespace(id='github', is_enabled=True, provider_type='oauth2')
        >>> auth_session = SimpleNamespace(provider_id='github', state='st', provider=provider, is_expired=False)
        >>> svc.db.execute.return_value.scalar_one_or_none.return_value = auth_session
        >>> # Patch token exchange and user info retrieval
        >>> async def _ex(p, sess, c):
        ...     return {'access_token': 'tok', 'id_token': 'id_tok'}
        >>> async def _ui(p, access, token_data=None):
        ...     return {'email': 'user@example.com'}
        >>> svc._exchange_code_for_tokens = _ex
        >>> svc._get_user_info = _ui
        >>> svc.db.delete = lambda obj: None
        >>> svc.db.commit = lambda: None
        >>> out = asyncio.run(svc.handle_oauth_callback('github', 'code', 'st'))
        >>> out['email']
        'user@example.com'

        Early return cases:
        >>> # No session
        >>> svc2 = SSOService(MagicMock())
        >>> svc2.db.execute.return_value.scalar_one_or_none.return_value = None
        >>> asyncio.run(svc2.handle_oauth_callback('github', 'c', 's')) is None
        True
        >>> # Expired session
        >>> expired = SimpleNamespace(provider_id='github', state='st', provider=SimpleNamespace(is_enabled=True), is_expired=True)
        >>> svc3 = SSOService(MagicMock())
        >>> svc3.db.execute.return_value.scalar_one_or_none.return_value = expired
        >>> asyncio.run(svc3.handle_oauth_callback('github', 'c', 'st')) is None
        True
        >>> # Disabled provider
        >>> disabled = SimpleNamespace(provider_id='github', state='st', provider=SimpleNamespace(is_enabled=False), is_expired=False)
        >>> svc4 = SSOService(MagicMock())
        >>> svc4.db.execute.return_value.scalar_one_or_none.return_value = disabled
        >>> asyncio.run(svc4.handle_oauth_callback('github', 'c', 'st')) is None
        True
    """
    # The state must belong to a live session created for this same provider.
    session_lookup = select(SSOAuthSession).where(SSOAuthSession.state == state, SSOAuthSession.provider_id == provider_id)
    pending_session = self.db.execute(session_lookup).scalar_one_or_none()

    if not pending_session or pending_session.is_expired:
        return None

    idp = pending_session.provider
    if not (idp and idp.is_enabled):
        return None

    try:
        # Step 1: authorization code -> tokens
        logger.info(f"Starting token exchange for provider {provider_id}")
        tokens = await self._exchange_code_for_tokens(idp, pending_session, code)
        if not tokens:
            logger.error(f"Failed to exchange code for tokens for provider {provider_id}")
            return None
        logger.info(f"Token exchange successful for provider {provider_id}")

        # Step 2: tokens -> profile (the full token response is passed along for id_token parsing)
        profile = await self._get_user_info(idp, tokens["access_token"], tokens)
        if not profile:
            logger.error(f"Failed to get user info for provider {provider_id}")
            return None

        # Step 3: the session has served its purpose
        self.db.delete(pending_session)
        self.db.commit()

        return profile

    except Exception as e:
        # Any failure above still removes the session before reporting no result.
        logger.error(f"OAuth callback failed for provider {provider_id}: {type(e).__name__}: {str(e)}")
        logger.exception("Full traceback for OAuth callback failure:")
        self.db.delete(pending_session)
        self.db.commit()
        return None
async def _exchange_code_for_tokens(self, provider: SSOProvider, auth_session: SSOAuthSession, code: str) -> Optional[Dict[str, Any]]:
    """POST the authorization code to the provider's token endpoint.

    The request carries the client credentials (secret decrypted on the fly),
    the redirect URI and the PKCE code verifier stored on the auth session.

    Args:
        provider: SSO provider configuration
        auth_session: Auth session with PKCE parameters
        code: Authorization code

    Returns:
        Parsed JSON token response on HTTP 200, otherwise None
    """
    form_fields = {
        "client_id": provider.client_id,
        "client_secret": await self._decrypt_secret(provider.client_secret_encrypted),
        "code": code,
        "grant_type": "authorization_code",
        "redirect_uri": auth_session.redirect_uri,
        "code_verifier": auth_session.code_verifier,
    }

    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    http = await get_http_client()
    response = await http.post(provider.token_url, data=form_fields, headers={"Accept": "application/json"})

    if response.status_code != 200:
        logger.error(f"Token exchange failed for {provider.name}: HTTP {response.status_code} - {response.text}")
        return None

    return response.json()
async def _get_user_info(self, provider: SSOProvider, access_token: str, token_data: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
    """Fetch the user's profile from the provider and normalize it.

    Calls the provider's userinfo endpoint with the access token, then
    enriches the raw profile for specific providers before normalization:
    GitHub organizations (when admin orgs are configured), and claims taken
    from the id_token for Entra ID and Keycloak.

    Args:
        provider: SSO provider configuration
        access_token: OAuth access token
        token_data: Optional full token response containing id_token for OIDC providers

    Returns:
        Normalized user info dict, or None if the userinfo request did not return HTTP 200
    """
    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    http = await get_http_client()
    response = await http.get(provider.userinfo_url, headers={"Authorization": f"Bearer {access_token}"})

    if response.status_code != 200:
        logger.error(f"User info request failed for {provider.name}: HTTP {response.status_code} - {response.text}")
        return None

    user_data = response.json()

    # GitHub: organizations are only fetched when admin assignment by org is configured.
    if provider.id == "github" and settings.sso_github_admin_orgs:
        try:
            orgs_response = await http.get("https://api.github.com/user/orgs", headers={"Authorization": f"Bearer {access_token}"})
            if orgs_response.status_code != 200:
                logger.warning(f"Failed to fetch GitHub organizations: HTTP {orgs_response.status_code}")
                user_data["organizations"] = []
            else:
                user_data["organizations"] = [org["login"] for org in orgs_response.json()]
        except Exception as e:
            # Best effort: a failed lookup leaves the user with no organizations.
            logger.warning(f"Error fetching GitHub organizations: {e}")
            user_data["organizations"] = []

    # Entra ID: Microsoft's /oidc/userinfo endpoint only returns basic claims
    # (sub, name, email, picture); groups and roles arrive in the id_token when
    # configured in Azure Portal, so they are read from there.
    if provider.id == "entra" and token_data and "id_token" in token_data:
        claims = self._decode_jwt_claims(token_data["id_token"])
        if claims:
            # Group overage: with too many groups (>200) EntraID returns
            # _claim_names/_claim_sources instead of the actual groups array.
            # See: https://learn.microsoft.com/en-us/entra/identity-platform/id-token-claims-reference
            overage_names = claims.get("_claim_names", {})
            if isinstance(overage_names, dict) and "groups" in overage_names:
                user_email = user_data.get("email") or user_data.get("preferred_username") or "unknown"
                logger.warning(
                    f"Group overage detected for user {user_email} - token contains too many groups (>200). "
                    f"Role mapping may be incomplete. Consider using App Roles or Azure group filtering. "
                    f"See docs/docs/manage/sso-entra-role-mapping.md#token-size-considerations"
                )

            # Security Groups (Object IDs) and App Roles both overwrite the userinfo values.
            for kind in ("groups", "roles"):
                if kind in claims:
                    user_data[kind] = claims[kind]
                    logger.debug(f"Extracted {len(claims[kind])} {kind} from Entra ID token")

            # Basic claims are only filled in when userinfo did not supply them.
            for basic_claim in ("email", "name", "preferred_username", "oid", "sub"):
                if basic_claim not in user_data and basic_claim in claims:
                    user_data[basic_claim] = claims[basic_claim]

    # Keycloak: realm_access, resource_access and groups are copied from the
    # id_token when userinfo lacks them.
    if provider.id == "keycloak" and token_data and "id_token" in token_data:
        claims = self._decode_jwt_claims(token_data["id_token"])
        if claims:
            for extra_claim in ("realm_access", "resource_access", "groups"):
                if extra_claim in claims and extra_claim not in user_data:
                    user_data[extra_claim] = claims[extra_claim]

    # Normalize user info across providers
    return self._normalize_user_info(provider, user_data)
601 def _normalize_user_info(self, provider: SSOProvider, user_data: Dict[str, Any]) -> Dict[str, Any]:
602 """Normalize user info from different providers to common format.
604 Args:
605 provider: SSO provider configuration
606 user_data: Raw user data from provider
608 Returns:
609 Normalized user info dict
610 """
611 # Handle GitHub provider
612 if provider.id == "github":
613 return {
614 "email": user_data.get("email"),
615 "full_name": user_data.get("name") or user_data.get("login"),
616 "avatar_url": user_data.get("avatar_url"),
617 "provider_id": user_data.get("id"),
618 "username": user_data.get("login"),
619 "provider": "github",
620 "organizations": user_data.get("organizations", []),
621 }
623 # Handle Google provider
624 if provider.id == "google":
625 return {
626 "email": user_data.get("email"),
627 "full_name": user_data.get("name"),
628 "avatar_url": user_data.get("picture"),
629 "provider_id": user_data.get("sub"),
630 "username": user_data.get("email", "").split("@")[0],
631 "provider": "google",
632 }
634 # Handle IBM Verify provider
635 if provider.id == "ibm_verify":
636 return {
637 "email": user_data.get("email"),
638 "full_name": user_data.get("name"),
639 "avatar_url": user_data.get("picture"),
640 "provider_id": user_data.get("sub"),
641 "username": user_data.get("preferred_username") or user_data.get("email", "").split("@")[0],
642 "provider": "ibm_verify",
643 }
645 # Handle Okta provider
646 if provider.id == "okta":
647 return {
648 "email": user_data.get("email"),
649 "full_name": user_data.get("name"),
650 "avatar_url": user_data.get("picture"),
651 "provider_id": user_data.get("sub"),
652 "username": user_data.get("preferred_username") or user_data.get("email", "").split("@")[0],
653 "provider": "okta",
654 }
656 # Handle Keycloak provider with role mapping
657 if provider.id == "keycloak":
658 metadata = provider.provider_metadata or {}
659 username_claim = metadata.get("username_claim", "preferred_username")
660 email_claim = metadata.get("email_claim", "email")
661 groups_claim = metadata.get("groups_claim", "groups")
663 groups = []
665 # Extract realm roles
666 if metadata.get("map_realm_roles"):
667 realm_access = user_data.get("realm_access", {})
668 realm_roles = realm_access.get("roles", [])
669 groups.extend(realm_roles)
671 # Extract client roles
672 if metadata.get("map_client_roles"):
673 resource_access = user_data.get("resource_access", {})
674 for client, access in resource_access.items():
675 client_roles = access.get("roles", [])
676 # Prefix with client name to avoid conflicts
677 groups.extend([f"{client}:{role}" for role in client_roles])
679 # Extract groups from custom claim
680 if groups_claim in user_data:
681 custom_groups = user_data.get(groups_claim, [])
682 if isinstance(custom_groups, list): 682 ↛ 685line 682 didn't jump to line 685 because the condition on line 682 was always true
683 groups.extend(custom_groups)
685 return {
686 "email": user_data.get(email_claim),
687 "full_name": user_data.get("name"),
688 "avatar_url": user_data.get("picture"),
689 "provider_id": user_data.get("sub"),
690 "username": user_data.get(username_claim) or user_data.get(email_claim, "").split("@")[0],
691 "provider": "keycloak",
692 "groups": list(set(groups)), # Deduplicate
693 }
695 # Handle Microsoft Entra ID provider with role mapping
696 if provider.id == "entra":
697 metadata = provider.provider_metadata or {}
698 groups_claim = metadata.get("groups_claim", "groups")
700 # Microsoft's userinfo endpoint often omits the email claim
701 # Fallback: preferred_username (UPN) or upn claim
702 email = user_data.get("email") or user_data.get("preferred_username") or user_data.get("upn")
704 # Extract username from email/UPN
705 username = None
706 if user_data.get("preferred_username"):
707 username = user_data.get("preferred_username")
708 elif email:
709 username = email.split("@")[0]
711 # Extract groups from token
712 groups = []
714 # Check configured groups claim (default: 'groups')
715 if groups_claim in user_data:
716 groups_value = user_data.get(groups_claim, [])
717 if isinstance(groups_value, list): 717 ↛ 721line 717 didn't jump to line 721 because the condition on line 717 was always true
718 groups.extend(groups_value)
720 # Also check 'roles' claim for App Role assignments
721 if "roles" in user_data:
722 roles_value = user_data.get("roles", [])
723 if isinstance(roles_value, list): 723 ↛ 726line 723 didn't jump to line 726 because the condition on line 723 was always true
724 groups.extend(roles_value)
726 return {
727 "email": email,
728 "full_name": user_data.get("name") or email, # Fallback to email if name missing
729 "avatar_url": user_data.get("picture"),
730 "provider_id": user_data.get("sub") or user_data.get("oid"),
731 "username": username,
732 "provider": "entra",
733 "groups": list(set(groups)), # Deduplicate
734 }
736 # Generic OIDC format for all other providers
737 return {
738 "email": user_data.get("email"),
739 "full_name": user_data.get("name"),
740 "avatar_url": user_data.get("picture"),
741 "provider_id": user_data.get("sub"),
742 "username": user_data.get("preferred_username") or user_data.get("email", "").split("@")[0],
743 "provider": provider.id,
744 }
746 async def authenticate_or_create_user(self, user_info: Dict[str, Any]) -> Optional[str]:
747 """Authenticate existing user or create new user from SSO info.
749 Args:
750 user_info: Normalized user info from SSO provider
752 Returns:
753 JWT token for authenticated user or None if failed
754 """
755 email = user_info.get("email")
756 if not email:
757 return None
759 # Check if user exists
760 user = await self.auth_service.get_user_by_email(email)
762 if user:
763 # Update user info from SSO
764 if user_info.get("full_name") and user_info["full_name"] != user.full_name:
765 user.full_name = user_info["full_name"]
767 # Update auth provider if changed
768 if user.auth_provider == "local" or user.auth_provider != user_info.get("provider"):
769 user.auth_provider = user_info.get("provider", "sso")
771 # Mark email as verified for SSO users
772 user.email_verified = True
773 user.last_login = utc_now()
775 # Synchronize is_admin status based on current group membership
776 # Track origin to support both promotion AND demotion for SSO-granted admins
777 # Manual/API grants are "sticky" - never auto-demoted by SSO
778 # Only users with admin_origin="sso" can be demoted on login
779 provider = self.get_provider(user_info.get("provider"))
780 if provider: 780 ↛ 796line 780 didn't jump to line 796 because the condition on line 780 was always true
781 should_be_admin = self._should_user_be_admin(email, user_info, provider)
782 if should_be_admin:
783 # Grant admin access
784 if not user.is_admin:
785 logger.info(f"Upgrading is_admin to True for {email} based on SSO admin groups")
786 user.is_admin = True
787 # Track that admin was granted via SSO (only set on initial grant)
788 user.admin_origin = "sso"
789 # Do NOT change admin_origin if already admin - preserve manual/API grants
790 elif user.is_admin and user.admin_origin == "sso":
791 # User was SSO admin but no longer in admin groups - revoke access
792 logger.info(f"Revoking is_admin for {email} - removed from SSO admin groups")
793 user.is_admin = False
794 user.admin_origin = None
796 self.db.commit()
798 # Determine if syncing should happen (default True, respect provider-level and Entra setting)
799 should_sync = True
800 if provider: 800 ↛ 809line 800 didn't jump to line 809 because the condition on line 800 was always true
801 # Check provider-level sync_roles flag in provider_metadata (allows disabling per-provider)
802 metadata = provider.provider_metadata or {}
803 if "sync_roles" in metadata:
804 should_sync = metadata.get("sync_roles", True)
805 # Legacy Entra-specific setting (fallback for backwards compatibility)
806 elif provider.id == "entra" and hasattr(settings, "sso_entra_sync_roles_on_login"):
807 should_sync = settings.sso_entra_sync_roles_on_login
809 if provider and should_sync: 809 ↛ 893line 809 didn't jump to line 893 because the condition on line 809 was always true
810 role_assignments = await self._map_groups_to_roles(email, user_info.get("groups", []), provider)
811 await self._sync_user_roles(email, role_assignments, provider)
812 else:
813 # Auto-create user if enabled
814 provider = self.get_provider(user_info.get("provider"))
815 if not provider or not provider.auto_create_users:
816 return None
818 # Check trusted domains if configured
819 if provider.trusted_domains:
820 domain = email.split("@")[1].lower()
821 if domain not in [d.lower() for d in provider.trusted_domains]: 821 ↛ 825line 821 didn't jump to line 825 because the condition on line 821 was always true
822 return None
824 # Check if admin approval is required
825 if settings.sso_require_admin_approval:
826 # Check if user is already pending approval
828 pending = self.db.execute(select(PendingUserApproval).where(PendingUserApproval.email == email)).scalar_one_or_none()
830 if pending:
831 if pending.status == "pending" and not pending.is_expired():
832 return None # Still waiting for approval
833 if pending.status == "rejected":
834 return None # User was rejected
835 if pending.status == "approved": 835 ↛ 856line 835 didn't jump to line 856 because the condition on line 835 was always true
836 # User was approved, create account now
837 pass # Continue with user creation below
838 else:
839 # Create pending approval request
841 pending = PendingUserApproval(
842 email=email,
843 full_name=user_info.get("full_name", email),
844 auth_provider=user_info.get("provider", "sso"),
845 sso_metadata=user_info,
846 expires_at=utc_now() + timedelta(days=30), # 30-day approval window
847 )
848 self.db.add(pending)
849 self.db.commit()
850 logger.info(f"Created pending approval request for SSO user: {email}")
851 return None # No token until approved
853 # Create new user (either no approval required, or approval already granted)
854 # Generate a secure random password for SSO users (they won't use it)
856 random_password = "".join(secrets.choice(string.ascii_letters + string.digits + string.punctuation) for _ in range(32))
858 # Determine if user should be admin based on domain/organization
859 is_admin = self._should_user_be_admin(email, user_info, provider)
861 user = await self.auth_service.create_user(
862 email=email,
863 password=random_password, # Random password for SSO users (not used)
864 full_name=user_info.get("full_name", email),
865 is_admin=is_admin,
866 auth_provider=user_info.get("provider", "sso"),
867 )
868 if not user:
869 return None
871 # Assign RBAC roles based on SSO groups (or default role if no groups)
872 # Check provider-level sync_roles flag in provider_metadata
873 metadata = provider.provider_metadata or {}
874 should_sync = metadata.get("sync_roles", True)
875 # Legacy Entra-specific setting (fallback for backwards compatibility)
876 if "sync_roles" not in metadata and provider.id == "entra" and hasattr(settings, "sso_entra_sync_roles_on_login"): 876 ↛ 877line 876 didn't jump to line 877 because the condition on line 876 was never true
877 should_sync = settings.sso_entra_sync_roles_on_login
879 if should_sync: 879 ↛ 885line 879 didn't jump to line 885 because the condition on line 879 was always true
880 role_assignments = await self._map_groups_to_roles(email, user_info.get("groups", []), provider)
881 if role_assignments: 881 ↛ 882line 881 didn't jump to line 882 because the condition on line 881 was never true
882 await self._sync_user_roles(email, role_assignments, provider)
884 # If user was created from approved request, mark request as used
885 if settings.sso_require_admin_approval:
886 pending = self.db.execute(select(PendingUserApproval).where(and_(PendingUserApproval.email == email, PendingUserApproval.status == "approved"))).scalar_one_or_none()
887 if pending: 887 ↛ 893line 887 didn't jump to line 893 because the condition on line 887 was always true
888 # Mark as used (we could delete or keep for audit trail)
889 pending.status = "completed"
890 self.db.commit()
892 # Generate JWT token for user — session token (teams resolved server-side)
893 token_data = {
894 "sub": user.email,
895 "email": user.email,
896 "full_name": user.full_name,
897 "auth_provider": user.auth_provider,
898 "iat": int(utc_now().timestamp()),
899 "user": {"email": user.email, "full_name": user.full_name, "is_admin": user.is_admin, "auth_provider": user.auth_provider},
900 "token_use": "session", # nosec B105 - token type marker, not a password
901 # Scopes
902 "scopes": {"server_id": None, "permissions": ["*"] if user.is_admin else [], "ip_restrictions": [], "time_restrictions": {}},
903 }
905 # Create JWT token
906 token = await create_jwt_token(token_data)
907 return token
def _should_user_be_admin(self, email: str, user_info: Dict[str, Any], provider: SSOProvider) -> bool:
    """Determine if SSO user should be granted admin privileges.

    Admin status is granted when any one of these rules matches
    (all comparisons are case-insensitive):

    * the email domain is listed in ``settings.sso_auto_admin_domains``;
    * provider ``github``: one of ``user_info["organizations"]`` is listed
      in ``settings.sso_github_admin_orgs``;
    * provider ``google``: the email domain is listed in
      ``settings.sso_google_admin_domains``;
    * provider ``entra``: one of ``user_info["groups"]`` is listed in
      ``settings.sso_entra_admin_groups``.

    Args:
        email: User's email address
        user_info: Normalized user info from SSO provider
        provider: SSO provider configuration

    Returns:
        True if user should be admin, False otherwise
    """
    # An address without "@" has no domain; treat it as empty instead of
    # raising IndexError. For well-formed input this yields the same value
    # as the previous ``email.split("@")[1]``.
    address_parts = email.split("@")
    domain = address_parts[1].lower() if len(address_parts) > 1 else ""

    # Check domain-based admin assignment.
    # The ``domain and`` guard keeps an empty domain from matching an empty
    # entry in a misconfigured list.
    if domain and domain in {d.lower() for d in (settings.sso_auto_admin_domains or ())}:
        return True

    # Check provider-specific admin assignment
    if provider.id == "github" and settings.sso_github_admin_orgs:
        # Organizations are read from user_info as supplied by the caller;
        # nothing is fetched from GitHub here.
        admin_orgs = {o.lower() for o in settings.sso_github_admin_orgs}
        github_orgs = user_info.get("organizations") or []
        if not admin_orgs.isdisjoint(org.lower() for org in github_orgs):
            return True

    if provider.id == "google" and settings.sso_google_admin_domains:
        # Check if user's domain is in admin domains
        if domain and domain in {d.lower() for d in settings.sso_google_admin_domains}:
            return True

    # Check EntraID admin groups
    if provider.id == "entra" and settings.sso_entra_admin_groups:
        admin_groups = {g.lower() for g in settings.sso_entra_admin_groups}
        user_groups = user_info.get("groups") or []
        if not admin_groups.isdisjoint(group.lower() for group in user_groups):
            return True

    return False
async def _map_groups_to_roles(self, user_email: str, user_groups: List[str], provider: SSOProvider) -> List[Dict[str, Any]]:
    """Map SSO groups to Context Forge RBAC roles.

    Mapping sources, in order:

    1. ``provider.provider_metadata["role_mappings"]`` (group -> role name);
       for the ``entra`` provider, ``settings.sso_entra_role_mappings`` is
       used when the provider has no mappings of its own.
    2. For ``entra`` only: membership in ``settings.sso_entra_admin_groups``
       yields ``platform_admin`` (case-insensitive group match).
    3. For ``entra`` only: ``settings.sso_entra_default_role`` is applied
       when nothing else matched.

    A mapped role name of ``"admin"`` or ``"platform_admin"`` is shorthand
    for the global ``platform_admin`` role. Each role name appears at most
    once in the result.

    Args:
        user_email: User's email address
        user_groups: List of groups from SSO provider
        provider: SSO provider configuration

    Returns:
        List of role assignments: [{"role_name": str, "scope": str, "scope_id": Optional[str]}]
    """
    # pylint: disable=import-outside-toplevel
    # First-Party
    from mcpgateway.services.role_service import RoleService

    admin_shorthands = ("admin", "platform_admin")

    role_assignments: List[Dict[str, Any]] = []
    assigned_names = set()  # role names already present in role_assignments

    def _add_assignment(role_name: str, scope: str) -> bool:
        """Append an assignment unless that role name is already present.

        Args:
            role_name: Name of the role to assign
            scope: Scope of the role

        Returns:
            True if the assignment was added, False if it was a duplicate
        """
        if role_name in assigned_names:
            return False
        assigned_names.add(role_name)
        role_assignments.append({"role_name": role_name, "scope": scope, "scope_id": None})
        return True

    # Tolerate a provider that reports groups as null
    user_groups = user_groups or []

    # Generic Role Mapping Logic
    metadata = provider.provider_metadata or {}
    role_mappings = metadata.get("role_mappings") or {}

    # Merge with legacy Entra specific settings if applicable
    is_entra = provider.id == "entra"
    has_entra_admin_groups = is_entra and settings.sso_entra_admin_groups
    has_entra_default_role = is_entra and settings.sso_entra_default_role

    # Use generic role_mappings, falling back to the legacy Entra setting
    if is_entra and not role_mappings and settings.sso_entra_role_mappings:
        role_mappings = settings.sso_entra_role_mappings

    # Early exit: Skip role mapping if no configuration exists
    if not role_mappings and not has_entra_admin_groups and not has_entra_default_role:
        logger.debug(f"No role mappings configured for provider {provider.id}, skipping role sync")
        return role_assignments

    # Handle EntraID admin groups -> platform_admin
    if has_entra_admin_groups:
        admin_groups_lower = {g.lower() for g in settings.sso_entra_admin_groups}
        if any(group.lower() in admin_groups_lower for group in user_groups):
            _add_assignment("platform_admin", "global")
            logger.debug(f"Mapped EntraID admin group to platform_admin role for {user_email}")

    # Role names the user's groups map to, in group order (exact group match)
    mapped_role_names = [role_mappings[group] for group in user_groups if group in role_mappings]

    # Collect the distinct role names that need a database lookup
    role_names_to_lookup = {name for name in mapped_role_names if name not in admin_shorthands}
    if has_entra_default_role:
        role_names_to_lookup.add(settings.sso_entra_default_role)

    # Pre-fetch each distinct role once so repeated mappings share one lookup
    role_service = RoleService(self.db)
    role_cache: Dict[str, Any] = {}
    for role_name in role_names_to_lookup:
        # Try team scope first, then global
        role = await role_service.get_role_by_name(role_name, scope="team")
        if not role:
            role = await role_service.get_role_by_name(role_name, scope="global")
        if role:
            role_cache[role_name] = role

    # Process role mappings for ALL providers
    for role_name in mapped_role_names:
        # Special case for "admin"/"platform_admin" shorthand
        if role_name in admin_shorthands:
            if _add_assignment("platform_admin", "global"):
                logger.debug(f"Mapped group to platform_admin role for {user_email}")
            continue

        # Use pre-fetched role from cache
        role = role_cache.get(role_name)
        if not role:
            logger.warning(f"Role '{role_name}' not found for group mapping")
            continue

        if _add_assignment(role.name, role.scope):
            logger.debug(f"Mapped group to role '{role.name}' for {user_email}")

    # Apply default role if no mappings found (Entra legacy fallback)
    if not role_assignments and has_entra_default_role:
        default_role = role_cache.get(settings.sso_entra_default_role)
        if default_role:
            _add_assignment(default_role.name, default_role.scope)
            logger.info(f"Assigned default role '{default_role.name}' to {user_email}")

    return role_assignments
async def _sync_user_roles(self, user_email: str, role_assignments: List[Dict[str, Any]], _provider: SSOProvider) -> None:
    """Bring the user's SSO-granted roles in line with ``role_assignments``.

    Roles previously granted by ``sso_system`` that are absent from
    ``role_assignments`` are revoked; listed roles the user does not
    actively hold are assigned with ``granted_by="sso_system"``. Roles
    granted by anything else are left untouched. A failure while assigning
    one role is logged as a warning and does not stop the remaining
    assignments.

    Args:
        user_email: User's email address
        role_assignments: List of role assignments to apply
        _provider: SSO provider configuration (reserved for future use)
    """
    # pylint: disable=import-outside-toplevel
    # First-Party
    from mcpgateway.services.role_service import RoleService

    roles = RoleService(self.db)

    # Only roles this service granted earlier are candidates for revocation
    active_roles = await roles.list_user_roles(user_email, include_expired=False)
    sso_granted = list(filter(lambda held: held.granted_by == "sso_system", active_roles))

    # Target state as (role name, scope, scope id) triples
    wanted = set()
    for item in role_assignments:
        wanted.add((item["role_name"], item["scope"], item.get("scope_id")))

    # Revoke whatever SSO granted before but no longer asks for
    for held in sso_granted:
        if (held.role.name, held.scope, held.scope_id) in wanted:
            continue
        await roles.revoke_role_from_user(user_email=user_email, role_id=held.role_id, scope=held.scope, scope_id=held.scope_id)
        logger.info(f"Revoked SSO role '{held.role.name}' from {user_email} (no longer in groups)")

    # Grant what is wanted but not yet actively held
    for item in role_assignments:
        try:
            wanted_scope = item["scope"]
            wanted_scope_id = item.get("scope_id")

            role = await roles.get_role_by_name(item["role_name"], scope=wanted_scope)
            if not role:
                logger.warning(f"Role '{item['role_name']}' not found, skipping assignment for {user_email}")
                continue

            held_assignment = await roles.get_user_role_assignment(user_email=user_email, role_id=role.id, scope=wanted_scope, scope_id=wanted_scope_id)
            if held_assignment and held_assignment.is_active:
                # Already active, nothing to do
                continue

            await roles.assign_role_to_user(user_email=user_email, role_id=role.id, scope=wanted_scope, scope_id=wanted_scope_id, granted_by="sso_system")
            logger.info(f"Assigned SSO role '{role.name}' to {user_email}")
        except Exception as exc:
            logger.warning(f"Failed to assign role '{item['role_name']}' to {user_email}: {exc}")