Coverage for mcpgateway / routers / sso.py: 98%
278 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/routers/sso.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Single Sign-On (SSO) authentication routes for OAuth2/OIDC providers.
8Handles SSO login flows, provider configuration, and callback handling.
9"""
11# Standard
12import secrets
13from typing import Dict, List, Optional
14from urllib.parse import urlparse
16# Third-Party
17from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status
18from pydantic import BaseModel
19from sqlalchemy.orm import Session
21# First-Party
22from mcpgateway.config import settings
23from mcpgateway.db import get_db
24from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission
25from mcpgateway.services.logging_service import LoggingService
26from mcpgateway.services.sso_service import SSOService
27from mcpgateway.utils.log_sanitizer import sanitize_for_log
# Initialize logging
# Module-level LoggingService instance and the named logger used by the routes in this module.
logging_service = LoggingService()
logger = logging_service.get_logger("mcpgateway.routers.sso")
class SSOProviderCreateRequest(BaseModel):
    """Request to create SSO provider.

    Body of ``POST /auth/sso/admin/providers``. ``create_sso_provider`` rejects an
    ``id`` that already exists and otherwise passes ``model_dump()`` of this model
    to ``SSOService.create_provider``.
    """

    id: str  # provider identifier; checked for duplicates before creation
    name: str
    display_name: str  # human-readable label returned by the public provider listing
    provider_type: str  # oauth2, oidc
    client_id: str
    client_secret: str
    authorization_url: str
    token_url: str
    userinfo_url: str
    issuer: Optional[str] = None
    jwks_uri: Optional[str] = None
    scope: str = "openid profile email"
    # NOTE(review): the list/dict defaults below rely on pydantic copying mutable defaults per instance.
    trusted_domains: List[str] = []
    auto_create_users: bool = True
    team_mapping: Dict = {}
    provider_metadata: Dict = {}  # Role mappings, groups_claim config, etc.
class SSOProviderUpdateRequest(BaseModel):
    """Request to update SSO provider.

    Partial-update body for ``PUT /auth/sso/admin/providers/{provider_id}``.
    Every field defaults to ``None``. ``update_sso_provider`` drops ``None`` values
    before calling ``SSOService.update_provider``, so only supplied fields are
    changed and a field cannot be cleared to ``None`` through this model.
    """

    name: Optional[str] = None
    display_name: Optional[str] = None
    provider_type: Optional[str] = None  # oauth2, oidc
    client_id: Optional[str] = None
    client_secret: Optional[str] = None
    authorization_url: Optional[str] = None
    token_url: Optional[str] = None
    userinfo_url: Optional[str] = None
    issuer: Optional[str] = None
    jwks_uri: Optional[str] = None
    scope: Optional[str] = None
    trusted_domains: Optional[List[str]] = None
    auto_create_users: Optional[bool] = None
    team_mapping: Optional[Dict] = None
    provider_metadata: Optional[Dict] = None  # Role mappings, groups_claim config, etc.
    is_enabled: Optional[bool] = None  # provider enabled flag
# Create router
# Every route in this module is registered on this router, under the /auth/sso prefix.
sso_router = APIRouter(prefix="/auth/sso", tags=["SSO Authentication"])
class SSOProviderResponse(BaseModel):
    """SSO provider information for client.

    Public summary returned by ``GET /auth/sso/providers``; exposes no secrets.
    """

    id: str
    name: str
    display_name: str
    authorization_url: Optional[str] = None  # Only provided when initiating login
class SSOLoginResponse(BaseModel):
    """SSO login initiation response.

    Returned by ``initiate_sso_login``. ``state`` is the ``state`` query parameter
    read back from ``authorization_url`` (empty string if it is absent).
    """

    authorization_url: str
    state: str
class SSOCallbackResponse(BaseModel):
    """SSO authentication callback response.

    NOTE(review): ``handle_sso_callback`` in this module answers with a redirect
    and an auth cookie rather than this model; confirm whether it is still used.
    """

    access_token: str
    token_type: str = "bearer"
    expires_in: int
    user: Dict
@sso_router.get("/providers", response_model=List[SSOProviderResponse])
async def list_sso_providers(
    db: Session = Depends(get_db),
) -> List[SSOProviderResponse]:
    """Return the enabled SSO providers a user can choose from on the login page.

    Args:
        db: Database session

    Returns:
        One summary (id, name, display name) per enabled provider.

    Raises:
        HTTPException: 404 when SSO authentication is turned off

    Examples:
        >>> import asyncio
        >>> asyncio.iscoroutinefunction(list_sso_providers)
        True
    """
    if settings.sso_enabled:
        service = SSOService(db)
        enabled_providers = service.list_enabled_providers()
        return [
            SSOProviderResponse(id=entry.id, name=entry.name, display_name=entry.display_name)
            for entry in enabled_providers
        ]

    raise HTTPException(status_code=404, detail="SSO authentication is disabled")
134def _normalize_origin(scheme: str, host: str, port: int | None) -> str:
135 """Normalize an origin to scheme://host:port format.
137 Args:
138 scheme: URL scheme (http/https)
139 host: Hostname
140 port: Port number (None uses default for scheme)
142 Returns:
143 Normalized origin string
144 """
145 # Use default ports for scheme if not specified
146 default_ports = {"http": 80, "https": 443}
147 if port is None or port == default_ports.get(scheme):
148 return f"{scheme}://{host}"
149 return f"{scheme}://{host}:{port}"
152def _validate_redirect_uri(redirect_uri: str, request: Request | None = None) -> bool:
153 """Validate redirect_uri to prevent open redirect attacks.
155 Validates against a server-side allowlist (settings.allowed_origins and settings.app_domain).
156 Does NOT trust the Host header to prevent spoofing attacks.
158 Allows:
159 - Relative URIs (no scheme/host)
160 - URIs matching configured allowed_origins (full origin including scheme and port)
161 - URIs matching app_domain (if configured)
163 Args:
164 redirect_uri: The redirect URI to validate
165 request: The FastAPI request object (unused, kept for API compatibility)
167 Returns:
168 True if the redirect_uri is safe, False otherwise
169 """
170 parsed = urlparse(redirect_uri)
172 # Allow relative URIs (no scheme and no netloc)
173 if not parsed.scheme and not parsed.netloc:
174 return True
176 # For absolute URIs, validate against server-side allowlist only
177 # Extract full origin components from redirect_uri
178 redirect_scheme = parsed.scheme.lower()
179 redirect_host = parsed.hostname.lower() if parsed.hostname else ""
180 redirect_port = parsed.port
182 # Normalize the redirect origin
183 redirect_origin = _normalize_origin(redirect_scheme, redirect_host, redirect_port)
185 # Check against app_domain (if configured)
186 if hasattr(settings, "app_domain") and settings.app_domain:
187 # app_domain is an HttpUrl - extract the hostname for comparison
188 app_domain_host = urlparse(str(settings.app_domain)).hostname or ""
189 app_domain_host = app_domain_host.lower()
190 if redirect_host == app_domain_host:
191 # Only allow HTTPS in production, or HTTP for localhost
192 if redirect_scheme == "https" or (redirect_scheme == "http" and app_domain_host in ("localhost", "127.0.0.1")):
193 return True
195 # Check against allowed_origins (full origin match including scheme and port)
196 if hasattr(settings, "allowed_origins") and settings.allowed_origins:
197 for origin in settings.allowed_origins:
198 origin = origin.strip()
199 if not origin:
200 continue
202 # Parse the allowed origin
203 origin_parsed = urlparse(origin if "://" in origin else f"https://{origin}")
204 origin_scheme = origin_parsed.scheme.lower() if origin_parsed.scheme else "https"
205 origin_host = origin_parsed.hostname.lower() if origin_parsed.hostname else origin.lower()
206 origin_port = origin_parsed.port
208 # Normalize and compare full origins
209 allowed_origin = _normalize_origin(origin_scheme, origin_host, origin_port)
210 if redirect_origin == allowed_origin:
211 return True
213 return False
@sso_router.get("/login/{provider_id}", response_model=SSOLoginResponse)
async def initiate_sso_login(
    provider_id: str,
    request: Request,
    response: Response,
    redirect_uri: str = Query(..., description="Callback URI after authentication"),
    scopes: Optional[str] = Query(None, description="Space-separated OAuth scopes"),
    db: Session = Depends(get_db),
) -> SSOLoginResponse:
    """Start an SSO login and return the provider authorization URL.

    redirect_uri is checked only against the server-side allowlist (relative
    URIs, app_domain, allowed_origins) and never against the Host header. A
    random session-binding value is stored in the ``sso_session_id`` cookie and
    passed to the SSO service along with the authorization request.

    Args:
        provider_id: SSO provider identifier (e.g., 'github', 'google')
        request: FastAPI request object
        response: FastAPI response object used to set session-binding cookie
        redirect_uri: Callback URI after successful authentication
        scopes: Optional custom OAuth scopes (space-separated)
        db: Database session

    Returns:
        Authorization URL and state parameter for redirect.

    Raises:
        HTTPException: If SSO is disabled, provider not found, or redirect_uri is invalid

    Examples:
        >>> import asyncio
        >>> asyncio.iscoroutinefunction(initiate_sso_login)
        True
    """
    if not settings.sso_enabled:
        raise HTTPException(status_code=404, detail="SSO authentication is disabled")

    # Open-redirect guard: server-side allowlist only, Host header is not consulted
    redirect_allowed = _validate_redirect_uri(redirect_uri, request)
    if not redirect_allowed:
        # redirect_uri is caller-controlled, so it is sanitized before being logged
        logger.warning(f"SSO login rejected - invalid redirect_uri: {sanitize_for_log(redirect_uri)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid redirect_uri. Must be a relative path or URL matching allowed origins.",
        )

    service = SSOService(db)
    requested_scopes = None
    if scopes:
        requested_scopes = scopes.split()
    session_binding = secrets.token_urlsafe(32)

    try:
        authorization_url = service.get_authorization_url(provider_id, redirect_uri, requested_scopes, session_binding=session_binding)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc

    if not authorization_url:
        raise HTTPException(status_code=404, detail=f"SSO provider '{provider_id}' not found or disabled")

    # Standard
    import urllib.parse

    # Read the state parameter back out of the URL so the client can reference it
    query_params = urllib.parse.parse_qs(urllib.parse.urlparse(authorization_url).query)
    state_value = query_params.get("state", [""])[0]

    # The callback reads this cookie and hands it back to the SSO service as session_binding
    response.set_cookie(
        key="sso_session_id",
        value=session_binding,
        httponly=True,
        secure=(settings.environment == "production") or settings.secure_cookies,
        samesite=settings.cookie_samesite,
        path=settings.app_root_path or "/",
    )

    return SSOLoginResponse(authorization_url=authorization_url, state=state_value)
@sso_router.get("/callback/{provider_id}")
async def handle_sso_callback(
    provider_id: str,
    code: str = Query(..., description="Authorization code from SSO provider"),
    state: str = Query(..., description="CSRF state parameter"),
    request: Request = None,
    response: Response = None,
    db: Session = Depends(get_db),
):
    """Handle SSO authentication callback.

    Exchanges the authorization code, authenticates or creates the user, sets the
    auth cookie and redirects to the admin UI. Every failure path redirects to the
    admin login page with an ``error`` query parameter instead of raising.

    Args:
        provider_id: SSO provider identifier
        code: Authorization code from provider
        state: CSRF state parameter for validation
        request: FastAPI request object
        response: FastAPI response object (unused)
        db: Database session

    Returns:
        A 302 ``RedirectResponse``: to ``{root_path}/admin`` with the auth cookie set
        on success, or to ``{root_path}/admin/login?error=<code>`` on failure.

    Raises:
        HTTPException: If SSO is disabled

    Examples:
        >>> import asyncio
        >>> asyncio.iscoroutinefunction(handle_sso_callback)
        True
    """
    if not settings.sso_enabled:
        raise HTTPException(status_code=404, detail="SSO authentication is disabled")

    # Third-Party
    from fastapi.responses import RedirectResponse

    # Get root path for URL construction
    root_path = request.scope.get("root_path", "") if request else ""

    def _login_error(error: str) -> RedirectResponse:
        """Build a redirect back to the admin login page carrying an error code.

        Args:
            error: Error code placed in the ``error`` query parameter

        Returns:
            302 redirect to the admin login page.
        """
        return RedirectResponse(url=f"{root_path}/admin/login?error={error}", status_code=302)

    sso_service = SSOService(db)

    # Session-binding value stored in the sso_session_id cookie by initiate_sso_login
    browser_session_binding = request.cookies.get("sso_session_id") if request else None
    if not browser_session_binding:
        return _login_error("sso_failed")

    # Handle OAuth callback — returns (user_info, token_data) or None
    user_info: Optional[Dict[str, object]] = None
    token_data: Dict[str, object] = {}
    callback_result = await sso_service.handle_oauth_callback_with_tokens(provider_id, code, state, session_binding=browser_session_binding)
    if callback_result:
        user_info, raw_token_data = callback_result
        # Tolerate a missing token payload so the id_token lookup below cannot fail on None
        token_data = raw_token_data or {}

    if not user_info:
        return _login_error("sso_failed")

    # Authenticate or create user
    access_token = await sso_service.authenticate_or_create_user(user_info)
    if not access_token:
        return _login_error("user_creation_failed")

    redirect_response = RedirectResponse(url=f"{root_path}/admin", status_code=302)

    # Set secure HTTP-only cookie using the same method as email auth
    # First-Party
    from mcpgateway.utils.security_cookies import CookieTooLargeError, set_auth_cookie

    try:
        set_auth_cookie(redirect_response, access_token, remember_me=False)
    except CookieTooLargeError:
        return _login_error("token_too_large")

    # Persist Keycloak ID token as short-lived, HTTP-only hint for RP-initiated logout.
    # Without id_token_hint, some Keycloak versions show confirmation and may preserve SSO.
    id_token = token_data.get("id_token")
    if provider_id == "keycloak" and isinstance(id_token, str) and id_token:
        if len(id_token) > 3800:  # Leave room for cookie metadata within browser 4KB limit
            logger.warning("Keycloak id_token too large for cookie storage. RP-initiated logout will not include id_token_hint.")
        else:
            use_secure = (settings.environment == "production") or settings.secure_cookies
            redirect_response.set_cookie(
                key="sso_id_token_hint",
                value=id_token,
                max_age=settings.token_expiry * 60,  # match session token lifetime
                httponly=True,
                secure=use_secure,
                samesite=settings.cookie_samesite,
                path=settings.app_root_path or "/",
            )

    return redirect_response
# Admin endpoints for SSO provider management
@sso_router.post("/admin/providers", response_model=Dict)
@require_permission("admin.sso_providers:create")
async def create_sso_provider(
    provider_data: SSOProviderCreateRequest,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Register a new SSO provider configuration (Admin only).

    Args:
        provider_data: SSO provider configuration
        db: Database session
        user: Current authenticated user

    Returns:
        Summary of the created provider (id, names, type, enabled flag, creation time).

    Raises:
        HTTPException: 409 if the provider id is taken, 400 if the service rejects the data
    """
    service = SSOService(db)

    # Refuse to overwrite a provider that already uses this id
    if service.get_provider(provider_data.id):
        raise HTTPException(status_code=409, detail=f"SSO provider '{provider_data.id}' already exists")

    try:
        created = await service.create_provider(provider_data.model_dump())
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    summary_fields = ("id", "name", "display_name", "provider_type", "is_enabled", "created_at")
    summary = {field: getattr(created, field) for field in summary_fields}

    # Commit and release the session explicitly before responding
    db.commit()
    db.close()
    return summary
@sso_router.get("/admin/providers", response_model=List[Dict])
@require_permission("admin.sso_providers:read")
async def list_all_sso_providers(
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> List[Dict]:
    """Return every SSO provider, enabled or not (Admin only).

    Args:
        db: Database session
        user: Current authenticated user

    Returns:
        One summary dict per provider row, including trusted domains and timestamps.
    """
    # Third-Party
    from sqlalchemy import select

    # First-Party
    from mcpgateway.db import SSOProvider

    def _summarize(row) -> Dict:
        """Project a provider row onto the admin listing fields.

        Args:
            row: SSOProvider ORM instance

        Returns:
            Dict of the exposed provider attributes.
        """
        return {
            "id": row.id,
            "name": row.name,
            "display_name": row.display_name,
            "provider_type": row.provider_type,
            "is_enabled": row.is_enabled,
            "trusted_domains": row.trusted_domains,
            "auto_create_users": row.auto_create_users,
            "created_at": row.created_at,
            "updated_at": row.updated_at,
        }

    provider_rows = db.execute(select(SSOProvider)).scalars().all()
    listing = [_summarize(row) for row in provider_rows]

    # Commit and release the session explicitly before responding
    db.commit()
    db.close()
    return listing
@sso_router.get("/admin/providers/{provider_id}", response_model=Dict)
@require_permission("admin.sso_providers:read")
async def get_sso_provider(
    provider_id: str,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Return the full configuration of one SSO provider (Admin only).

    The client secret is not part of the returned fields.

    Args:
        provider_id: Provider identifier
        db: Database session
        user: Current authenticated user

    Returns:
        Provider configuration details.

    Raises:
        HTTPException: 404 if no provider has this id
    """
    provider = SSOService(db).get_provider(provider_id)
    if not provider:
        raise HTTPException(status_code=404, detail=f"SSO provider '{provider_id}' not found")

    exposed_fields = (
        "id",
        "name",
        "display_name",
        "provider_type",
        "client_id",
        "authorization_url",
        "token_url",
        "userinfo_url",
        "issuer",
        "jwks_uri",
        "scope",
        "trusted_domains",
        "auto_create_users",
        "team_mapping",
        "is_enabled",
        "created_at",
        "updated_at",
    )
    details = {field: getattr(provider, field) for field in exposed_fields}

    # Commit and release the session explicitly before responding
    db.commit()
    db.close()
    return details
@sso_router.put("/admin/providers/{provider_id}", response_model=Dict)
@require_permission("admin.sso_providers:update")
async def update_sso_provider(
    provider_id: str,
    provider_data: SSOProviderUpdateRequest,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Apply a partial update to an SSO provider (Admin only).

    Fields sent as ``None`` (or omitted) are ignored, so they keep their current value.

    Args:
        provider_id: Provider identifier
        provider_data: Updated provider configuration
        db: Database session
        user: Current authenticated user

    Returns:
        Summary of the updated provider including its update timestamp.

    Raises:
        HTTPException: 400 if nothing to update or the service rejects the data, 404 if not found
    """
    service = SSOService(db)

    changes = {field: value for field, value in provider_data.model_dump().items() if value is not None}
    if not changes:
        raise HTTPException(status_code=400, detail="No update data provided")

    try:
        updated = await service.update_provider(provider_id, changes)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    if not updated:
        raise HTTPException(status_code=404, detail=f"SSO provider '{provider_id}' not found")

    summary_fields = ("id", "name", "display_name", "provider_type", "is_enabled", "updated_at")
    summary = {field: getattr(updated, field) for field in summary_fields}

    # Commit and release the session explicitly before responding
    db.commit()
    db.close()
    return summary
@sso_router.delete("/admin/providers/{provider_id}")
@require_permission("admin.sso_providers:delete")
async def delete_sso_provider(
    provider_id: str,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Remove an SSO provider configuration (Admin only).

    Args:
        provider_id: Provider identifier
        db: Database session
        user: Current authenticated user

    Returns:
        Confirmation message naming the deleted provider.

    Raises:
        HTTPException: 404 if no provider has this id
    """
    deleted = SSOService(db).delete_provider(provider_id)
    if not deleted:
        raise HTTPException(status_code=404, detail=f"SSO provider '{provider_id}' not found")

    # Commit and release the session explicitly before responding
    db.commit()
    db.close()
    return {"message": f"SSO provider '{provider_id}' deleted successfully"}
625# ---------------------------------------------------------------------------
626# SSO User Approval Management Endpoints
627# ---------------------------------------------------------------------------
class PendingUserApprovalResponse(BaseModel):
    """Response model for pending user approval.

    Built by ``list_pending_approvals`` from ``PendingUserApproval`` rows.
    """

    id: str
    email: str
    full_name: str
    auth_provider: str
    requested_at: str  # ISO-8601 string produced with datetime.isoformat()
    expires_at: str  # ISO-8601 string produced with datetime.isoformat()
    status: str  # list_pending_approvals only returns rows whose status is "pending"
    sso_metadata: Optional[Dict] = None
class ApprovalActionRequest(BaseModel):
    """Request model for approval actions.

    Any ``action`` other than "approve" or "reject" is rejected with HTTP 400 by
    ``handle_approval_request``.
    """

    action: str  # "approve" or "reject"
    reason: Optional[str] = None  # Required for rejection
    notes: Optional[str] = None  # optional admin notes passed to approve()/reject()
@sso_router.get("/pending-approvals", response_model=List[PendingUserApprovalResponse])
@require_permission("admin.user_management")
async def list_pending_approvals(
    include_expired: bool = Query(False, description="Include expired approval requests"),
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> List[PendingUserApprovalResponse]:
    """Return SSO registrations still waiting for an admin decision (Admin only).

    Only rows with status "pending" are returned, newest request first. Unless
    ``include_expired`` is set, rows whose ``expires_at`` has passed are excluded.

    Args:
        include_expired: Whether to include expired requests
        db: Database session
        user: Current authenticated admin user

    Returns:
        List of pending approval requests
    """
    # Third-Party
    from sqlalchemy import select

    # First-Party
    from mcpgateway.db import PendingUserApproval

    statement = select(PendingUserApproval)

    if not include_expired:
        # First-Party
        from mcpgateway.db import utc_now

        statement = statement.where(PendingUserApproval.expires_at > utc_now())

    statement = statement.where(PendingUserApproval.status == "pending").order_by(PendingUserApproval.requested_at.desc())

    def _to_response(row) -> PendingUserApprovalResponse:
        """Convert a PendingUserApproval row into its API representation.

        Args:
            row: PendingUserApproval ORM instance

        Returns:
            Response model with ISO-formatted timestamps.
        """
        return PendingUserApprovalResponse(
            id=row.id,
            email=row.email,
            full_name=row.full_name,
            auth_provider=row.auth_provider,
            requested_at=row.requested_at.isoformat(),
            expires_at=row.expires_at.isoformat(),
            status=row.status,
            sso_metadata=row.sso_metadata,
        )

    rows = db.execute(statement).scalars().all()
    return [_to_response(row) for row in rows]
@sso_router.post("/pending-approvals/{approval_id}/action")
@require_permission("admin.user_management")
async def handle_approval_request(
    approval_id: str,
    request: ApprovalActionRequest,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Approve or reject a pending SSO registration (Admin only).

    The approval must exist, still be "pending" and not be expired. An expired
    request is marked "expired" and committed before the 400 error is raised.

    Args:
        approval_id: ID of the approval request
        request: Approval action (approve/reject) with optional reason/notes
        db: Database session
        user: Current authenticated admin user

    Returns:
        Action confirmation message

    Raises:
        HTTPException: If approval not found or invalid action
    """
    # Third-Party
    from sqlalchemy import select

    # First-Party
    from mcpgateway.db import PendingUserApproval

    lookup = select(PendingUserApproval).where(PendingUserApproval.id == approval_id)
    approval = db.execute(lookup).scalar_one_or_none()

    if not approval:
        raise HTTPException(status_code=404, detail="Approval request not found")

    if approval.status != "pending":
        raise HTTPException(status_code=400, detail=f"Approval request is already {approval.status}")

    if approval.is_expired():
        # Persist the expiry so the request no longer shows up as pending
        approval.status = "expired"
        db.commit()
        raise HTTPException(status_code=400, detail="Approval request has expired")

    admin_email = user["email"]
    action = request.action

    if action == "reject":
        if not request.reason:
            raise HTTPException(status_code=400, detail="Rejection reason is required")
        approval.reject(admin_email, request.reason, request.notes)
        db.commit()
        return {"message": f"User {approval.email} rejected"}

    if action != "approve":
        raise HTTPException(status_code=400, detail="Invalid action. Must be 'approve' or 'reject'")

    approval.approve(admin_email, request.notes)
    db.commit()
    return {"message": f"User {approval.email} approved successfully"}