Coverage for mcpgateway / routers / sso.py: 97%
247 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/routers/sso.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Single Sign-On (SSO) authentication routes for OAuth2/OIDC providers.
8Handles SSO login flows, provider configuration, and callback handling.
9"""
11# Standard
12from typing import Dict, List, Optional
13from urllib.parse import urlparse
15# Third-Party
16from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status
17from pydantic import BaseModel
18from sqlalchemy.orm import Session
20# First-Party
21from mcpgateway.config import settings
22from mcpgateway.db import get_db
23from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission
24from mcpgateway.services.logging_service import LoggingService
25from mcpgateway.services.sso_service import SSOService
# Module-level logger for this router, obtained from the project's LoggingService.
logging_service = LoggingService()
logger = logging_service.get_logger("mcpgateway.routers.sso")
class SSOProviderCreateRequest(BaseModel):
    """Request body for creating an SSO provider.

    The whole model is forwarded to ``SSOService.create_provider`` via
    ``model_dump()`` by the admin create endpoint, which first rejects an
    ``id`` that already exists with HTTP 409.
    """

    id: str  # Unique provider identifier (e.g. 'github'); also used in login/callback URL paths
    name: str
    display_name: str  # Human-readable label returned by the public provider listing
    provider_type: str  # oauth2, oidc
    client_id: str
    client_secret: str  # Never echoed back by the admin read endpoints in this router
    authorization_url: str
    token_url: str
    userinfo_url: str
    issuer: Optional[str] = None
    scope: str = "openid profile email"  # Space-separated default scopes
    trusted_domains: List[str] = []
    auto_create_users: bool = True
    team_mapping: Dict = {}
    provider_metadata: Dict = {}  # Role mappings, groups_claim config, etc.
class SSOProviderUpdateRequest(BaseModel):
    """Partial update for an existing SSO provider.

    Every field is optional. The update endpoint drops fields whose value is
    ``None`` before calling ``SSOService.update_provider``. As a result, a field
    cannot be cleared to ``None`` through this request, and a request with
    every field ``None`` is rejected with HTTP 400.
    """

    name: Optional[str] = None
    display_name: Optional[str] = None
    provider_type: Optional[str] = None
    client_id: Optional[str] = None
    client_secret: Optional[str] = None
    authorization_url: Optional[str] = None
    token_url: Optional[str] = None
    userinfo_url: Optional[str] = None
    issuer: Optional[str] = None
    scope: Optional[str] = None
    trusted_domains: Optional[List[str]] = None
    auto_create_users: Optional[bool] = None
    team_mapping: Optional[Dict] = None
    provider_metadata: Optional[Dict] = None  # Role mappings, groups_claim config, etc.
    is_enabled: Optional[bool] = None  # Toggle provider availability for login
# Router for every endpoint in this module; all paths are mounted under /auth/sso.
sso_router = APIRouter(prefix="/auth/sso", tags=["SSO Authentication"])
class SSOProviderResponse(BaseModel):
    """Public, non-sensitive description of an enabled SSO provider."""

    id: str
    name: str
    display_name: str
    authorization_url: Optional[str] = None  # Only provided when initiating login; left unset by the provider listing
class SSOLoginResponse(BaseModel):
    """Result of starting an SSO login: where to send the user, plus the state value."""

    authorization_url: str  # Provider authorization URL the client should redirect to
    state: str  # The ``state`` query parameter extracted from authorization_url ("" if absent)
class SSOCallbackResponse(BaseModel):
    """Token payload shape for an SSO authentication callback.

    NOTE(review): ``handle_sso_callback`` in this module returns HTTP redirects
    (setting the token as a cookie) rather than this model; confirm whether
    this schema is still used elsewhere.
    """

    access_token: str
    token_type: str = "bearer"
    expires_in: int  # presumably seconds until the token expires; confirm with the token issuer
    user: Dict
@sso_router.get("/providers", response_model=List[SSOProviderResponse])
async def list_sso_providers(
    db: Session = Depends(get_db),
) -> List[SSOProviderResponse]:
    """Return the enabled SSO providers a user may log in with.

    Only public identifying fields are exposed; no secrets or endpoint URLs.

    Args:
        db: Database session

    Returns:
        One ``SSOProviderResponse`` per enabled provider.

    Raises:
        HTTPException: 404 when SSO authentication is disabled globally

    Examples:
        >>> import asyncio
        >>> asyncio.iscoroutinefunction(list_sso_providers)
        True
    """
    if not settings.sso_enabled:
        raise HTTPException(status_code=404, detail="SSO authentication is disabled")

    summaries: List[SSOProviderResponse] = []
    for entry in SSOService(db).list_enabled_providers():
        summaries.append(SSOProviderResponse(id=entry.id, name=entry.name, display_name=entry.display_name))
    return summaries
130def _normalize_origin(scheme: str, host: str, port: int | None) -> str:
131 """Normalize an origin to scheme://host:port format.
133 Args:
134 scheme: URL scheme (http/https)
135 host: Hostname
136 port: Port number (None uses default for scheme)
138 Returns:
139 Normalized origin string
140 """
141 # Use default ports for scheme if not specified
142 default_ports = {"http": 80, "https": 443}
143 if port is None or port == default_ports.get(scheme):
144 return f"{scheme}://{host}"
145 return f"{scheme}://{host}:{port}"
def _validate_redirect_uri(redirect_uri: str, request: Request | None = None) -> bool:
    """Validate redirect_uri to prevent open redirect attacks.

    Validates against a server-side allowlist (settings.allowed_origins and settings.app_domain).
    Does NOT trust the Host header to prevent spoofing attacks.

    Allows:
    - Relative URIs (no scheme/host) that cannot be reinterpreted as scheme-relative
    - URIs matching configured allowed_origins (full origin including scheme and port)
    - URIs matching app_domain (if configured): https, or http only when app_domain is localhost/127.0.0.1

    Rejects input that browsers parse differently from ``urllib.parse``:
    - backslashes, which browsers treat as "/" in http(s) URLs
    - ASCII control characters
    - leading or trailing spaces
    - relative URIs beginning with "//"
    Also rejects absolute URIs with a malformed or out-of-range port.

    Args:
        redirect_uri: The redirect URI to validate
        request: The FastAPI request object (unused, kept for API compatibility)

    Returns:
        True if the redirect_uri is safe, False otherwise
    """
    # Browsers (WHATWG URL parsing) treat "\" like "/" and silently drop control
    # characters and surrounding spaces; urllib.parse does not. Such input can
    # look relative or same-origin here yet resolve to a foreign host in the
    # browser (e.g. "/\\evil.com" or "https://evil.com\\@trusted.example"), so it
    # is rejected outright. Well-formed URIs never contain these raw characters.
    if "\\" in redirect_uri or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in redirect_uri):
        return False
    if redirect_uri.strip(" ") != redirect_uri:
        return False

    parsed = urlparse(redirect_uri)

    # Relative URI (no scheme and no netloc). "//host" parses with a netloc and
    # is handled below, but "///host" parses with an empty netloc while browsers
    # still resolve it to https://host/ -- so relative URIs may not start with "//".
    if not parsed.scheme and not parsed.netloc:
        return not redirect_uri.startswith("//")

    # For absolute URIs, validate against the server-side allowlist only.
    redirect_scheme = parsed.scheme.lower()
    redirect_host = parsed.hostname.lower() if parsed.hostname else ""
    try:
        redirect_port = parsed.port
    except ValueError:
        # Non-numeric or out-of-range port (e.g. "https://host:abc"): reject
        # instead of letting the ValueError surface as a server error.
        return False

    # An absolute URI without a host can never match the allowlist.
    if not redirect_host:
        return False

    redirect_origin = _normalize_origin(redirect_scheme, redirect_host, redirect_port)

    # Check against app_domain (if configured). It is compared as a bare
    # hostname; str() guards against a non-str settings value.
    app_domain = str(getattr(settings, "app_domain", "") or "").lower()
    if app_domain and redirect_host == app_domain:
        # Only allow HTTPS, or plain HTTP when app_domain is a loopback name.
        if redirect_scheme == "https" or (redirect_scheme == "http" and app_domain in ("localhost", "127.0.0.1")):
            return True

    # Check against allowed_origins (full origin match including scheme and port).
    for origin in getattr(settings, "allowed_origins", None) or ():
        origin = origin.strip()
        if not origin:
            continue

        # Bare hostnames in the allowlist are interpreted as https origins.
        origin_parsed = urlparse(origin if "://" in origin else f"https://{origin}")
        try:
            origin_port = origin_parsed.port
        except ValueError:
            # Misconfigured allowlist entry; skip it rather than failing every check.
            continue
        origin_scheme = origin_parsed.scheme.lower() if origin_parsed.scheme else "https"
        origin_host = origin_parsed.hostname.lower() if origin_parsed.hostname else origin.lower()

        if redirect_origin == _normalize_origin(origin_scheme, origin_host, origin_port):
            return True

    return False
@sso_router.get("/login/{provider_id}", response_model=SSOLoginResponse)
async def initiate_sso_login(
    provider_id: str,
    request: Request,
    redirect_uri: str = Query(..., description="Callback URI after authentication"),
    scopes: Optional[str] = Query(None, description="Space-separated OAuth scopes"),
    db: Session = Depends(get_db),
) -> SSOLoginResponse:
    """Initiate SSO authentication flow.

    Validates the redirect_uri against a server-side allowlist to prevent open redirect attacks.
    Only allows relative URIs, URIs matching app_domain, or URIs from configured allowed_origins.
    Does NOT trust the Host header for validation.

    Args:
        provider_id: SSO provider identifier (e.g., 'github', 'google')
        request: FastAPI request object
        redirect_uri: Callback URI after successful authentication
        scopes: Optional custom OAuth scopes (space-separated)
        db: Database session

    Returns:
        Authorization URL and state parameter for redirect.

    Raises:
        HTTPException: If SSO is disabled (404), redirect_uri is invalid (400),
            or the provider is unknown/disabled (404)

    Examples:
        >>> import asyncio
        >>> asyncio.iscoroutinefunction(initiate_sso_login)
        True
    """
    # Standard
    from urllib.parse import parse_qs

    if not settings.sso_enabled:
        raise HTTPException(status_code=404, detail="SSO authentication is disabled")

    # Validate redirect_uri to prevent open redirect attacks
    # Uses server-side allowlist (allowed_origins, app_domain) - does NOT trust Host header
    if not _validate_redirect_uri(redirect_uri, request):
        # redirect_uri is attacker-controlled: !r escapes CR/LF and other control
        # characters so the value cannot forge additional log lines.
        logger.warning(f"SSO login rejected - invalid redirect_uri: {redirect_uri!r}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid redirect_uri. Must be a relative path or URL matching allowed origins.",
        )

    sso_service = SSOService(db)
    scope_list = scopes.split() if scopes else None

    auth_url = sso_service.get_authorization_url(provider_id, redirect_uri, scope_list)
    if not auth_url:
        raise HTTPException(status_code=404, detail=f"SSO provider '{provider_id}' not found or disabled")

    # Echo the state parameter embedded in the authorization URL back to the client.
    state = parse_qs(urlparse(auth_url).query).get("state", [""])[0]

    return SSOLoginResponse(authorization_url=auth_url, state=state)
@sso_router.get("/callback/{provider_id}")
async def handle_sso_callback(
    provider_id: str,
    code: str = Query(..., description="Authorization code from SSO provider"),
    state: str = Query(..., description="CSRF state parameter"),
    request: Request = None,
    response: Response = None,
    db: Session = Depends(get_db),
):
    """Complete an SSO login after the provider redirects back.

    The authorization code and state are exchanged for user info, the local
    user is authenticated (or created), and the browser is redirected to the
    admin UI with the access token stored in an auth cookie. Every failure
    redirects to the admin login page with an ``error`` query parameter:
    ``sso_failed``, ``user_creation_failed`` or ``token_too_large``.

    Args:
        provider_id: SSO provider identifier
        code: Authorization code from provider
        state: CSRF state parameter for validation
        request: FastAPI request object (source of the ASGI root_path)
        response: FastAPI response object (unused)
        db: Database session

    Returns:
        A 302 RedirectResponse to the admin UI or to the login page.

    Raises:
        HTTPException: 404 when SSO authentication is disabled

    Examples:
        >>> import asyncio
        >>> asyncio.iscoroutinefunction(handle_sso_callback)
        True
    """
    if not settings.sso_enabled:
        raise HTTPException(status_code=404, detail="SSO authentication is disabled")

    # Third-Party
    from fastapi.responses import RedirectResponse

    # Prefix every redirect with the app's mount path so it works behind a proxy.
    base_path = request.scope.get("root_path", "") if request else ""
    login_page = f"{base_path}/admin/login"
    service = SSOService(db)

    user_info = await service.handle_oauth_callback(provider_id, code, state)
    if not user_info:
        return RedirectResponse(url=f"{login_page}?error=sso_failed", status_code=302)

    access_token = await service.authenticate_or_create_user(user_info)
    if not access_token:
        return RedirectResponse(url=f"{login_page}?error=user_creation_failed", status_code=302)

    admin_redirect = RedirectResponse(url=f"{base_path}/admin", status_code=302)

    # Set secure HTTP-only cookie using the same method as email auth
    # First-Party
    from mcpgateway.utils.security_cookies import CookieTooLargeError, set_auth_cookie

    try:
        set_auth_cookie(admin_redirect, access_token, remember_me=False)
    except CookieTooLargeError:
        return RedirectResponse(url=f"{login_page}?error=token_too_large", status_code=302)
    return admin_redirect
351# Admin endpoints for SSO provider management
@sso_router.post("/admin/providers", response_model=Dict)
@require_permission("admin.sso_providers:create")
async def create_sso_provider(
    provider_data: SSOProviderCreateRequest,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Register a new SSO provider configuration (Admin only).

    Args:
        provider_data: SSO provider configuration
        db: Database session
        user: Current authenticated user (consumed by the permission decorator)

    Returns:
        Summary of the created provider (id, names, type, enabled flag, creation time).

    Raises:
        HTTPException: 409 if a provider with the same id already exists
    """
    service = SSOService(db)

    if service.get_provider(provider_data.id):
        raise HTTPException(status_code=409, detail=f"SSO provider '{provider_data.id}' already exists")

    created = await service.create_provider(provider_data.model_dump())

    # Snapshot the attributes before commit/close; presumably the session may
    # expire loaded attributes on commit -- confirm session configuration.
    summary_fields = ("id", "name", "display_name", "provider_type", "is_enabled", "created_at")
    summary = {field: getattr(created, field) for field in summary_fields}
    db.commit()
    db.close()
    return summary
@sso_router.get("/admin/providers", response_model=List[Dict])
@require_permission("admin.sso_providers:read")
async def list_all_sso_providers(
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> List[Dict]:
    """Return every SSO provider, enabled or not (Admin only).

    Args:
        db: Database session
        user: Current authenticated user (consumed by the permission decorator)

    Returns:
        One summary dict per provider; secrets and endpoint URLs are not included.
    """
    # Third-Party
    from sqlalchemy import select

    # First-Party
    from mcpgateway.db import SSOProvider

    exported = (
        "id",
        "name",
        "display_name",
        "provider_type",
        "is_enabled",
        "trusted_domains",
        "auto_create_users",
        "created_at",
        "updated_at",
    )
    rows = db.execute(select(SSOProvider)).scalars().all()
    # Build the payload while the session is still open.
    payload = [{attr: getattr(row, attr) for attr in exported} for row in rows]
    db.commit()
    db.close()
    return payload
@sso_router.get("/admin/providers/{provider_id}", response_model=Dict)
@require_permission("admin.sso_providers:read")
async def get_sso_provider(
    provider_id: str,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Return the full configuration of one SSO provider (Admin only).

    The client secret is deliberately omitted from the response.

    Args:
        provider_id: Provider identifier
        db: Database session
        user: Current authenticated user (consumed by the permission decorator)

    Returns:
        Provider configuration details.

    Raises:
        HTTPException: 404 if no provider has this id
    """
    provider = SSOService(db).get_provider(provider_id)
    if not provider:
        raise HTTPException(status_code=404, detail=f"SSO provider '{provider_id}' not found")

    exported = (
        "id",
        "name",
        "display_name",
        "provider_type",
        "client_id",
        "authorization_url",
        "token_url",
        "userinfo_url",
        "issuer",
        "scope",
        "trusted_domains",
        "auto_create_users",
        "team_mapping",
        "is_enabled",
        "created_at",
        "updated_at",
    )
    details = {attr: getattr(provider, attr) for attr in exported}
    db.commit()
    db.close()
    return details
@sso_router.put("/admin/providers/{provider_id}", response_model=Dict)
@require_permission("admin.sso_providers:update")
async def update_sso_provider(
    provider_id: str,
    provider_data: SSOProviderUpdateRequest,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Apply a partial update to an SSO provider (Admin only).

    Only fields supplied with a non-None value are changed.

    Args:
        provider_id: Provider identifier
        provider_data: Fields to update
        db: Database session
        user: Current authenticated user (consumed by the permission decorator)

    Returns:
        Summary of the updated provider.

    Raises:
        HTTPException: 400 if no field was supplied, 404 if the provider does not exist
    """
    service = SSOService(db)

    # None means "not supplied", so such fields are left untouched.
    changes = {}
    for field, value in provider_data.model_dump().items():
        if value is not None:
            changes[field] = value

    if not changes:
        raise HTTPException(status_code=400, detail="No update data provided")

    updated = await service.update_provider(provider_id, changes)
    if not updated:
        raise HTTPException(status_code=404, detail=f"SSO provider '{provider_id}' not found")

    exported = ("id", "name", "display_name", "provider_type", "is_enabled", "updated_at")
    summary = {attr: getattr(updated, attr) for attr in exported}
    db.commit()
    db.close()
    return summary
@sso_router.delete("/admin/providers/{provider_id}")
@require_permission("admin.sso_providers:delete")
async def delete_sso_provider(
    provider_id: str,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Remove an SSO provider configuration (Admin only).

    Args:
        provider_id: Provider identifier
        db: Database session
        user: Current authenticated user (consumed by the permission decorator)

    Returns:
        Confirmation message.

    Raises:
        HTTPException: 404 if the service reports nothing was deleted
    """
    removed = SSOService(db).delete_provider(provider_id)
    if not removed:
        raise HTTPException(status_code=404, detail=f"SSO provider '{provider_id}' not found")

    db.commit()
    db.close()
    return {"message": f"SSO provider '{provider_id}' deleted successfully"}
563# ---------------------------------------------------------------------------
564# SSO User Approval Management Endpoints
565# ---------------------------------------------------------------------------
class PendingUserApprovalResponse(BaseModel):
    """A pending SSO user registration awaiting admin approval."""

    id: str
    email: str
    full_name: str
    auth_provider: str  # SSO provider the user signed in with
    requested_at: str  # ISO 8601 timestamp (datetime.isoformat())
    expires_at: str  # ISO 8601 timestamp (datetime.isoformat())
    status: str  # The listing endpoint only returns "pending" entries
    sso_metadata: Optional[Dict] = None
class ApprovalActionRequest(BaseModel):
    """Admin decision on a pending SSO registration.

    Any ``action`` other than "approve" or "reject" is rejected by the endpoint with HTTP 400.
    """

    action: str  # "approve" or "reject"
    reason: Optional[str] = None  # Required (non-empty) for rejection
    notes: Optional[str] = None  # Optional reviewer notes, passed through for both actions
@sso_router.get("/pending-approvals", response_model=List[PendingUserApprovalResponse])
@require_permission("admin.user_management")
async def list_pending_approvals(
    include_expired: bool = Query(False, description="Include expired approval requests"),
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> List[PendingUserApprovalResponse]:
    """List SSO registrations still awaiting an admin decision (Admin only).

    Results are ordered newest request first.

    Args:
        include_expired: When False (default), requests past their expiry time are hidden
        db: Database session
        user: Current authenticated admin user (consumed by the permission decorator)

    Returns:
        Pending approval requests as response models.
    """
    # Third-Party
    from sqlalchemy import select

    # First-Party
    from mcpgateway.db import PendingUserApproval

    criteria = []
    if not include_expired:
        # First-Party
        from mcpgateway.db import utc_now

        criteria.append(PendingUserApproval.expires_at > utc_now())
    criteria.append(PendingUserApproval.status == "pending")

    # Multiple where() criteria are AND-ed together.
    stmt = select(PendingUserApproval).where(*criteria).order_by(PendingUserApproval.requested_at.desc())
    records = db.execute(stmt).scalars().all()

    def _as_response(record) -> PendingUserApprovalResponse:
        # Timestamps are serialized as ISO 8601 strings for the response model.
        return PendingUserApprovalResponse(
            id=record.id,
            email=record.email,
            full_name=record.full_name,
            auth_provider=record.auth_provider,
            requested_at=record.requested_at.isoformat(),
            expires_at=record.expires_at.isoformat(),
            status=record.status,
            sso_metadata=record.sso_metadata,
        )

    return list(map(_as_response, records))
@sso_router.post("/pending-approvals/{approval_id}/action")
@require_permission("admin.user_management")
async def handle_approval_request(
    approval_id: str,
    request: ApprovalActionRequest,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Approve or reject a pending SSO user registration (Admin only).

    A request found to be expired is marked "expired" (and committed) before
    the error is raised.

    Args:
        approval_id: ID of the approval request
        request: Approval action (approve/reject) with optional reason/notes
        db: Database session
        user: Current authenticated admin user; its "email" is recorded as the reviewer

    Returns:
        Confirmation message.

    Raises:
        HTTPException: 404 if the request does not exist; 400 if it is no longer
            pending, has expired, lacks a rejection reason, or the action is unknown
    """
    # Third-Party
    from sqlalchemy import select

    # First-Party
    from mcpgateway.db import PendingUserApproval

    lookup = select(PendingUserApproval).where(PendingUserApproval.id == approval_id)
    approval = db.execute(lookup).scalar_one_or_none()

    if not approval:
        raise HTTPException(status_code=404, detail="Approval request not found")
    if approval.status != "pending":
        raise HTTPException(status_code=400, detail=f"Approval request is already {approval.status}")
    if approval.is_expired():
        # Persist the expiry so the request no longer appears as pending.
        approval.status = "expired"
        db.commit()
        raise HTTPException(status_code=400, detail="Approval request has expired")

    reviewer = user["email"]
    action = request.action

    if action == "reject":
        if not request.reason:
            raise HTTPException(status_code=400, detail="Rejection reason is required")
        approval.reject(reviewer, request.reason, request.notes)
        db.commit()
        return {"message": f"User {approval.email} rejected"}

    if action != "approve":
        raise HTTPException(status_code=400, detail="Invalid action. Must be 'approve' or 'reject'")

    approval.approve(reviewer, request.notes)
    db.commit()
    return {"message": f"User {approval.email} approved successfully"}