Coverage for mcpgateway / routers / sso.py: 97%

247 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/routers/sso.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Single Sign-On (SSO) authentication routes for OAuth2/OIDC providers. 

8Handles SSO login flows, provider configuration, and callback handling. 

9""" 

10 

11# Standard 

12from typing import Dict, List, Optional 

13from urllib.parse import urlparse 

14 

15# Third-Party 

16from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status 

17from pydantic import BaseModel 

18from sqlalchemy.orm import Session 

19 

20# First-Party 

21from mcpgateway.config import settings 

22from mcpgateway.db import get_db 

23from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission 

24from mcpgateway.services.logging_service import LoggingService 

25from mcpgateway.services.sso_service import SSOService 

26 

# Module-level logging: one LoggingService instance plus the named logger
# used by the SSO route handlers in this module.
logging_service = LoggingService()
logger = logging_service.get_logger("mcpgateway.routers.sso")

30 

31 

class SSOProviderCreateRequest(BaseModel):
    """Request body for creating an SSO provider.

    The identity and endpoint fields are required; ``issuer`` is optional and
    the remaining fields fall back to the defaults declared below.
    """

    # Provider identity
    id: str
    name: str
    display_name: str
    provider_type: str  # oauth2, oidc

    # Client credentials
    client_id: str
    client_secret: str

    # Provider endpoints
    authorization_url: str
    token_url: str
    userinfo_url: str
    issuer: Optional[str] = None

    # Login behaviour
    scope: str = "openid profile email"
    trusted_domains: List[str] = []
    auto_create_users: bool = True
    team_mapping: Dict = {}
    provider_metadata: Dict = {}  # Role mappings, groups_claim config, etc.

50 

51 

class SSOProviderUpdateRequest(BaseModel):
    """Request body for a partial SSO provider update.

    Every field is optional and defaults to ``None``.
    """

    # Provider identity
    name: Optional[str] = None
    display_name: Optional[str] = None
    provider_type: Optional[str] = None

    # Client credentials
    client_id: Optional[str] = None
    client_secret: Optional[str] = None

    # Provider endpoints
    authorization_url: Optional[str] = None
    token_url: Optional[str] = None
    userinfo_url: Optional[str] = None
    issuer: Optional[str] = None

    # Login behaviour
    scope: Optional[str] = None
    trusted_domains: Optional[List[str]] = None
    auto_create_users: Optional[bool] = None
    team_mapping: Optional[Dict] = None
    provider_metadata: Optional[Dict] = None  # Role mappings, groups_claim config, etc.
    is_enabled: Optional[bool] = None

70 

71 

# Router for every SSO endpoint in this module; all routes are mounted under /auth/sso.
sso_router = APIRouter(
    prefix="/auth/sso",
    tags=["SSO Authentication"],
)

74 

75 

class SSOProviderResponse(BaseModel):
    """Public description of an SSO provider as returned to clients."""

    id: str
    name: str
    display_name: str
    # Only provided when initiating login
    authorization_url: Optional[str] = None

83 

84 

class SSOLoginResponse(BaseModel):
    """Payload returned when an SSO login is initiated."""

    # URL the client should be sent to in order to authenticate
    authorization_url: str
    # Value of the "state" query parameter carried by authorization_url
    state: str

90 

91 

class SSOCallbackResponse(BaseModel):
    """Response model describing the result of an SSO authentication callback."""

    access_token: str
    token_type: str = "bearer"
    expires_in: int
    user: Dict

99 

100 

@sso_router.get("/providers", response_model=List[SSOProviderResponse])
async def list_sso_providers(
    db: Session = Depends(get_db),
) -> List[SSOProviderResponse]:
    """Return the SSO providers that are currently enabled for login.

    Only the id, name and display name of each provider are exposed.

    Args:
        db: Database session

    Returns:
        One SSOProviderResponse per enabled provider.

    Raises:
        HTTPException: 404 when settings.sso_enabled is falsy

    Examples:
        >>> import asyncio
        >>> asyncio.iscoroutinefunction(list_sso_providers)
        True
    """
    if not settings.sso_enabled:
        raise HTTPException(status_code=404, detail="SSO authentication is disabled")

    enabled = SSOService(db).list_enabled_providers()

    responses: List[SSOProviderResponse] = []
    for entry in enabled:
        responses.append(SSOProviderResponse(id=entry.id, name=entry.name, display_name=entry.display_name))
    return responses

128 

129 

130def _normalize_origin(scheme: str, host: str, port: int | None) -> str: 

131 """Normalize an origin to scheme://host:port format. 

132 

133 Args: 

134 scheme: URL scheme (http/https) 

135 host: Hostname 

136 port: Port number (None uses default for scheme) 

137 

138 Returns: 

139 Normalized origin string 

140 """ 

141 # Use default ports for scheme if not specified 

142 default_ports = {"http": 80, "https": 443} 

143 if port is None or port == default_ports.get(scheme): 

144 return f"{scheme}://{host}" 

145 return f"{scheme}://{host}:{port}" 

146 

147 

def _validate_redirect_uri(redirect_uri: str, request: Request | None = None) -> bool:
    """Validate redirect_uri to prevent open redirect attacks.

    Validates against a server-side allowlist (settings.allowed_origins and settings.app_domain).
    Does NOT trust the Host header to prevent spoofing attacks.

    Allows:
    - Relative URIs (no scheme/host) that contain no backslash
    - URIs matching configured allowed_origins (full origin including scheme and port)
    - URIs whose host equals app_domain (if configured), over https, or over
      http when app_domain is localhost/127.0.0.1

    Malformed input (anything for which URL parsing raises ValueError, such as
    a non-numeric port or an unterminated IPv6 literal) is rejected rather
    than propagated as an exception.

    Args:
        redirect_uri: The redirect URI to validate
        request: The FastAPI request object (unused, kept for API compatibility)

    Returns:
        True if the redirect_uri is safe, False otherwise
    """
    # redirect_uri is untrusted: both urlparse() and the .port accessor raise
    # ValueError on malformed values, which must mean "reject", not a 500.
    try:
        parsed = urlparse(redirect_uri)
        redirect_port = parsed.port
    except ValueError:
        return False

    # Allow relative URIs (no scheme and no netloc)
    if not parsed.scheme and not parsed.netloc:
        # Browsers treat "\" like "/" in http(s) URLs, so "/\evil.com" would
        # resolve to the protocol-relative "//evil.com". Reject such values.
        return "\\" not in redirect_uri

    # For absolute URIs, validate against server-side allowlist only
    redirect_scheme = parsed.scheme.lower()
    redirect_host = parsed.hostname.lower() if parsed.hostname else ""
    redirect_origin = _normalize_origin(redirect_scheme, redirect_host, redirect_port)

    # Check against app_domain (if configured)
    # NOTE(review): assumes settings.app_domain is a bare hostname string - confirm against the settings definition.
    app_domain = getattr(settings, "app_domain", None)
    if app_domain:
        app_domain = app_domain.lower()
        if redirect_host == app_domain:
            # Only allow HTTPS in production, or HTTP for localhost
            if redirect_scheme == "https" or (redirect_scheme == "http" and app_domain in ("localhost", "127.0.0.1")):
                return True

    # Check against allowed_origins (full origin match including scheme and port)
    allowed_origins = getattr(settings, "allowed_origins", None)
    if allowed_origins:
        for origin in allowed_origins:
            origin = origin.strip()
            if not origin:
                continue

            # Entries without a scheme are treated as https origins
            try:
                origin_parsed = urlparse(origin if "://" in origin else f"https://{origin}")
                origin_port = origin_parsed.port
            except ValueError:
                # A malformed allowlist entry grants nothing; skip it instead
                # of failing every validation request.
                continue

            origin_scheme = origin_parsed.scheme.lower() if origin_parsed.scheme else "https"
            origin_host = origin_parsed.hostname.lower() if origin_parsed.hostname else origin.lower()

            # Normalize and compare full origins
            if redirect_origin == _normalize_origin(origin_scheme, origin_host, origin_port):
                return True

    return False

209 

210 

@sso_router.get("/login/{provider_id}", response_model=SSOLoginResponse)
async def initiate_sso_login(
    provider_id: str,
    request: Request,
    redirect_uri: str = Query(..., description="Callback URI after authentication"),
    scopes: Optional[str] = Query(None, description="Space-separated OAuth scopes"),
    db: Session = Depends(get_db),
) -> SSOLoginResponse:
    """Start an SSO login and return the provider authorization URL.

    The redirect_uri is checked with _validate_redirect_uri before anything
    else is done with it; a rejected value is logged and answered with 400.
    The ``state`` returned to the client is read back from the query string
    of the authorization URL produced by the SSO service.

    Args:
        provider_id: SSO provider identifier (e.g., 'github', 'google')
        request: FastAPI request object
        redirect_uri: Callback URI after successful authentication
        scopes: Optional custom OAuth scopes (space-separated)
        db: Database session

    Returns:
        Authorization URL and state parameter for redirect.

    Raises:
        HTTPException: 404 if SSO is disabled or no authorization URL could be
            obtained for the provider, 400 if redirect_uri is invalid

    Examples:
        >>> import asyncio
        >>> asyncio.iscoroutinefunction(initiate_sso_login)
        True
    """
    if not settings.sso_enabled:
        raise HTTPException(status_code=404, detail="SSO authentication is disabled")

    # Server-side allowlist check; the Host header is not consulted
    redirect_ok = _validate_redirect_uri(redirect_uri, request)
    if not redirect_ok:
        logger.warning(f"SSO login rejected - invalid redirect_uri: {redirect_uri}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid redirect_uri. Must be a relative path or URL matching allowed origins.",
        )

    requested_scopes = None
    if scopes:
        requested_scopes = scopes.split()

    authorization_url = SSOService(db).get_authorization_url(provider_id, redirect_uri, requested_scopes)
    if not authorization_url:
        raise HTTPException(status_code=404, detail=f"SSO provider '{provider_id}' not found or disabled")

    # Standard
    import urllib.parse

    # Surface the state parameter so the client can keep a reference to it
    query_params = urllib.parse.parse_qs(urllib.parse.urlparse(authorization_url).query)
    state_values = query_params.get("state", [""])

    return SSOLoginResponse(authorization_url=authorization_url, state=state_values[0])

271 

272 

@sso_router.get("/callback/{provider_id}")
async def handle_sso_callback(
    provider_id: str,
    code: str = Query(..., description="Authorization code from SSO provider"),
    state: str = Query(..., description="CSRF state parameter"),
    request: Request = None,
    response: Response = None,
    db: Session = Depends(get_db),
):
    """Complete the SSO flow after the provider redirects back.

    Every outcome other than "SSO disabled" is a 302 redirect:
    on success to ``{root_path}/admin`` with the auth cookie set, otherwise
    to ``{root_path}/admin/login`` with an ``error`` query parameter
    (``sso_failed``, ``user_creation_failed`` or ``token_too_large``).

    Args:
        provider_id: SSO provider identifier
        code: Authorization code from provider
        state: CSRF state parameter for validation
        request: FastAPI request object (source of root_path)
        response: FastAPI response object (not used in the body)
        db: Database session

    Returns:
        RedirectResponse as described above.

    Raises:
        HTTPException: 404 if SSO is disabled

    Examples:
        >>> import asyncio
        >>> asyncio.iscoroutinefunction(handle_sso_callback)
        True
    """
    if not settings.sso_enabled:
        raise HTTPException(status_code=404, detail="SSO authentication is disabled")

    # Third-Party
    from fastapi.responses import RedirectResponse

    # Prefix used for every redirect target
    root_path = ""
    if request:
        root_path = request.scope.get("root_path", "")

    service = SSOService(db)

    # Step 1: let the service process the provider callback
    user_info = await service.handle_oauth_callback(provider_id, code, state)
    if not user_info:
        return RedirectResponse(url=f"{root_path}/admin/login?error=sso_failed", status_code=302)

    # Step 2: map the SSO identity to a local user and obtain a token
    access_token = await service.authenticate_or_create_user(user_info)
    if not access_token:
        return RedirectResponse(url=f"{root_path}/admin/login?error=user_creation_failed", status_code=302)

    # Step 3: redirect to the admin UI with the cookie attached
    success_redirect = RedirectResponse(url=f"{root_path}/admin", status_code=302)

    # First-Party
    from mcpgateway.utils.security_cookies import CookieTooLargeError, set_auth_cookie

    try:
        set_auth_cookie(success_redirect, access_token, remember_me=False)
    except CookieTooLargeError:
        return RedirectResponse(
            url=f"{root_path}/admin/login?error=token_too_large",
            status_code=302,
        )

    return success_redirect

349 

350 

# Admin endpoints for SSO provider management
@sso_router.post("/admin/providers", response_model=Dict)
@require_permission("admin.sso_providers:create")
async def create_sso_provider(
    provider_data: SSOProviderCreateRequest,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Register a new SSO provider (Admin only).

    Args:
        provider_data: SSO provider configuration
        db: Database session
        user: Current authenticated user

    Returns:
        Summary of the created provider: id, name, display_name,
        provider_type, is_enabled and created_at.

    Raises:
        HTTPException: 409 if a provider with the same id already exists
    """
    service = SSOService(db)

    # Refuse to overwrite an existing provider
    if service.get_provider(provider_data.id):
        raise HTTPException(status_code=409, detail=f"SSO provider '{provider_data.id}' already exists")

    created = await service.create_provider(provider_data.model_dump())

    # Read the attributes before the session is committed and closed
    summary_fields = ("id", "name", "display_name", "provider_type", "is_enabled", "created_at")
    summary = {field: getattr(created, field) for field in summary_fields}

    db.commit()
    db.close()
    return summary

392 

393 

@sso_router.get("/admin/providers", response_model=List[Dict])
@require_permission("admin.sso_providers:read")
async def list_all_sso_providers(
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> List[Dict]:
    """Return every stored SSO provider, with no filter on enabled state (Admin only).

    Args:
        db: Database session
        user: Current authenticated user

    Returns:
        One dict per provider with id, name, display_name, provider_type,
        is_enabled, trusted_domains, auto_create_users, created_at and
        updated_at.
    """
    # Third-Party
    from sqlalchemy import select

    # First-Party
    from mcpgateway.db import SSOProvider

    rows = db.execute(select(SSOProvider)).scalars().all()

    listed_fields = (
        "id",
        "name",
        "display_name",
        "provider_type",
        "is_enabled",
        "trusted_domains",
        "auto_create_users",
        "created_at",
        "updated_at",
    )

    # Materialise plain dicts before the session is committed and closed
    payload = []
    for row in rows:
        payload.append({field: getattr(row, field) for field in listed_fields})

    db.commit()
    db.close()
    return payload

436 

437 

@sso_router.get("/admin/providers/{provider_id}", response_model=Dict)
@require_permission("admin.sso_providers:read")
async def get_sso_provider(
    provider_id: str,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Return the stored configuration of one SSO provider (Admin only).

    The client secret is not part of the returned dict.

    Args:
        provider_id: Provider identifier
        db: Database session
        user: Current authenticated user

    Returns:
        Provider configuration details.

    Raises:
        HTTPException: 404 if the provider is not found
    """
    found = SSOService(db).get_provider(provider_id)
    if not found:
        raise HTTPException(status_code=404, detail=f"SSO provider '{provider_id}' not found")

    detail_fields = (
        "id",
        "name",
        "display_name",
        "provider_type",
        "client_id",
        "authorization_url",
        "token_url",
        "userinfo_url",
        "issuer",
        "scope",
        "trusted_domains",
        "auto_create_users",
        "team_mapping",
        "is_enabled",
        "created_at",
        "updated_at",
    )

    # Copy the values out before the session is committed and closed
    details = {field: getattr(found, field) for field in detail_fields}

    db.commit()
    db.close()
    return details

485 

486 

@sso_router.put("/admin/providers/{provider_id}", response_model=Dict)
@require_permission("admin.sso_providers:update")
async def update_sso_provider(
    provider_id: str,
    provider_data: SSOProviderUpdateRequest,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Apply a partial update to an SSO provider (Admin only).

    Fields whose value is None in the request are ignored; a request in
    which every field is None is rejected.

    Args:
        provider_id: Provider identifier
        provider_data: Updated provider configuration
        db: Database session
        user: Current authenticated user

    Returns:
        Summary of the updated provider: id, name, display_name,
        provider_type, is_enabled and updated_at.

    Raises:
        HTTPException: 400 if no field carries a value, 404 if the provider
            is not found
    """
    service = SSOService(db)

    # Keep only the fields that carry a value
    changes = {}
    for field, value in provider_data.model_dump().items():
        if value is not None:
            changes[field] = value

    if not changes:
        raise HTTPException(status_code=400, detail="No update data provided")

    updated = await service.update_provider(provider_id, changes)
    if not updated:
        raise HTTPException(status_code=404, detail=f"SSO provider '{provider_id}' not found")

    # Read the attributes before the session is committed and closed
    summary_fields = ("id", "name", "display_name", "provider_type", "is_enabled", "updated_at")
    summary = {field: getattr(updated, field) for field in summary_fields}

    db.commit()
    db.close()
    return summary

531 

532 

@sso_router.delete("/admin/providers/{provider_id}")
@require_permission("admin.sso_providers:delete")
async def delete_sso_provider(
    provider_id: str,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Remove an SSO provider configuration (Admin only).

    Args:
        provider_id: Provider identifier
        db: Database session
        user: Current authenticated user

    Returns:
        Dict with a single "message" entry confirming the deletion.

    Raises:
        HTTPException: 404 if the service reports nothing was deleted
    """
    deleted = SSOService(db).delete_provider(provider_id)
    if not deleted:
        raise HTTPException(status_code=404, detail=f"SSO provider '{provider_id}' not found")

    db.commit()
    db.close()

    confirmation = {"message": f"SSO provider '{provider_id}' deleted successfully"}
    return confirmation

561 

562 

563# --------------------------------------------------------------------------- 

564# SSO User Approval Management Endpoints 

565# --------------------------------------------------------------------------- 

566 

567 

class PendingUserApprovalResponse(BaseModel):
    """Serialized view of one pending SSO user approval request."""

    id: str
    email: str
    full_name: str
    auth_provider: str
    # Timestamps are carried as strings
    requested_at: str
    expires_at: str
    status: str
    sso_metadata: Optional[Dict] = None

579 

580 

class ApprovalActionRequest(BaseModel):
    """Request body describing the decision on a pending approval."""

    # "approve" or "reject"
    action: str
    # Required for rejection
    reason: Optional[str] = None
    notes: Optional[str] = None

587 

588 

@sso_router.get("/pending-approvals", response_model=List[PendingUserApprovalResponse])
@require_permission("admin.user_management")
async def list_pending_approvals(
    include_expired: bool = Query(False, description="Include expired approval requests"),
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> List[PendingUserApprovalResponse]:
    """Return SSO approval requests whose status is "pending" (Admin only).

    Results are ordered newest request first. Unless include_expired is set,
    only requests whose expires_at lies after the current time are returned.

    Args:
        include_expired: Whether to include expired requests
        db: Database session
        user: Current authenticated admin user

    Returns:
        List of pending approval requests
    """
    # Third-Party
    from sqlalchemy import select

    # First-Party
    from mcpgateway.db import PendingUserApproval

    stmt = select(PendingUserApproval)

    if not include_expired:
        # First-Party
        from mcpgateway.db import utc_now

        stmt = stmt.where(PendingUserApproval.expires_at > utc_now())

    # Restrict to pending requests, newest first
    stmt = stmt.where(PendingUserApproval.status == "pending").order_by(PendingUserApproval.requested_at.desc())

    rows = db.execute(stmt).scalars().all()

    approvals: List[PendingUserApprovalResponse] = []
    for row in rows:
        approvals.append(
            PendingUserApprovalResponse(
                id=row.id,
                email=row.email,
                full_name=row.full_name,
                auth_provider=row.auth_provider,
                requested_at=row.requested_at.isoformat(),
                expires_at=row.expires_at.isoformat(),
                status=row.status,
                sso_metadata=row.sso_metadata,
            )
        )
    return approvals

640 

641 

@sso_router.post("/pending-approvals/{approval_id}/action")
@require_permission("admin.user_management")
async def handle_approval_request(
    approval_id: str,
    request: ApprovalActionRequest,
    db: Session = Depends(get_db),
    user=Depends(get_current_user_with_permissions),
) -> Dict:
    """Approve or reject a pending SSO user registration (Admin only).

    Checks run in this order: the request must exist, still be "pending",
    and not be expired (an expired request is marked "expired" and committed
    before the error is raised). Only then is the action itself examined.

    Args:
        approval_id: ID of the approval request
        request: Approval action (approve/reject) with optional reason/notes
        db: Database session
        user: Current authenticated admin user

    Returns:
        Action confirmation message

    Raises:
        HTTPException: 404 if the approval is not found; 400 if it is no
            longer pending, has expired, the action is unknown, or a
            rejection carries no reason
    """
    # Third-Party
    from sqlalchemy import select

    # First-Party
    from mcpgateway.db import PendingUserApproval

    lookup = select(PendingUserApproval).where(PendingUserApproval.id == approval_id)
    approval = db.execute(lookup).scalar_one_or_none()

    if not approval:
        raise HTTPException(status_code=404, detail="Approval request not found")

    if approval.status != "pending":
        raise HTTPException(status_code=400, detail=f"Approval request is already {approval.status}")

    if approval.is_expired():
        # Persist the expiry so the request is no longer reported as pending
        approval.status = "expired"
        db.commit()
        raise HTTPException(status_code=400, detail="Approval request has expired")

    admin_email = user["email"]
    action = request.action

    if action == "approve":
        approval.approve(admin_email, request.notes)
        db.commit()
        return {"message": f"User {approval.email} approved successfully"}

    if action != "reject":
        raise HTTPException(status_code=400, detail="Invalid action. Must be 'approve' or 'reject'")

    # Rejections must be justified
    if not request.reason:
        raise HTTPException(status_code=400, detail="Rejection reason is required")

    approval.reject(admin_email, request.reason, request.notes)
    db.commit()
    return {"message": f"User {approval.email} rejected"}