Coverage for mcpgateway / middleware / rbac.py: 100%
316 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/middleware/rbac.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7RBAC Permission Checking Middleware.
9This module provides middleware for FastAPI to enforce role-based access control
10on API endpoints. It includes permission decorators and dependency injection
11functions for protecting routes.
12"""
14# Standard
15import functools
16from functools import wraps
17import logging
18from typing import Callable, Generator, List, Optional
19import uuid
21# Third-Party
22from fastapi import Cookie, Depends, HTTPException, Request, status
23from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
24from sqlalchemy.orm import Session
26# First-Party
27from mcpgateway.auth import get_current_user
28from mcpgateway.config import settings
29from mcpgateway.db import fresh_db_session, SessionLocal
30from mcpgateway.services.permission_service import PermissionService
31from mcpgateway.utils.verify_credentials import is_proxy_auth_trust_active
# Module-level logger for RBAC decisions and authentication failures.
logger: logging.Logger = logging.getLogger(__name__)

# Single, deliberately generic 403 detail: callers must not learn which permission was missing.
_ACCESS_DENIED_MSG: str = "Access denied"

# Bearer-token extractor for the Authorization header. auto_error=False makes a missing
# header resolve to None instead of raising, so the dependency below can fall back to cookies.
security: HTTPBearer = HTTPBearer(auto_error=False)
def get_db() -> Generator[Session, None, None]:
    """Yield a SQLAlchemy session for ``Depends(get_db)`` injection.

    DEPRECATED: Use fresh_db_session() context manager instead to avoid session accumulation.
    This function is kept for backwards compatibility with endpoints that still use Depends(get_db).

    Transaction handling:

    * If the consumer finishes normally, the transaction is committed, so read-only
      requests do not end in an implicit rollback.
    * If anything raises, the session is rolled back. If the rollback itself fails,
      the session is invalidated as a last resort.
    * In every case the session is closed.

    Yields:
        Session: SQLAlchemy database session

    Raises:
        Exception: Re-raises any exception after rolling back the transaction.

    Examples:
        >>> gen = get_db()
        >>> db = next(gen)
        >>> hasattr(db, 'query')
        True
    """
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except Exception:
        try:
            session.rollback()
        except Exception:
            # Rollback failed too (e.g. the connection is gone); invalidate as a last resort.
            try:
                session.invalidate()
            except Exception:
                pass  # nosec B110 - Best effort cleanup on connection failure
        raise
    finally:
        session.close()
async def get_permission_service(db: Session = Depends(get_db)) -> PermissionService:
    """Provide a ``PermissionService`` bound to the injected database session.

    DEPRECATED: Use PermissionService(db) directly with fresh_db_session() context manager instead.
    This function is kept for backwards compatibility with endpoints that still use dependency injection.

    Args:
        db: Database session supplied by ``get_db``.

    Returns:
        PermissionService: Permission checking service instance

    Examples:
        >>> import asyncio
        >>> asyncio.iscoroutinefunction(get_permission_service)
        True
    """
    service = PermissionService(db)
    return service
def _is_html_client(request: Request) -> bool:
    """Return True when the request looks like a browser page load or an HTMX call.

    Args:
        request: Incoming FastAPI request.

    Returns:
        bool: True if the ``Accept`` header mentions ``text/html`` or ``HX-Request`` equals ``"true"``.
    """
    return "text/html" in request.headers.get("accept", "") or request.headers.get("hx-request") == "true"


def _login_redirect_exception() -> HTTPException:
    """Build the 302 exception that sends browser clients to the admin login page.

    Returns:
        HTTPException: 302 response whose ``Location`` header points at ``{app_root_path}/admin/login``.
    """
    return HTTPException(
        status_code=status.HTTP_302_FOUND,
        detail="Authentication required",
        headers={"Location": f"{settings.app_root_path}/admin/login"},
    )


def _make_auth_context(request: Request, *, email, full_name, is_admin, auth_method, **extra) -> dict:
    """Assemble the user-context dict returned by ``get_current_user_with_permissions``.

    Request metadata (client IP, user agent, request/team ids and plugin contexts) is read
    from ``request`` / ``request.state`` at call time, so call this only after anything that
    populates ``request.state`` (e.g. ``get_current_user``) has run.

    Args:
        request: Incoming FastAPI request.
        email: Authenticated (or placeholder) user email.
        full_name: Display name.
        is_admin: Whether the user is treated as an admin.
        auth_method: Label describing how the user was authenticated.
        **extra: Additional keys (e.g. ``token_teams``, ``token_use``) merged into the result.

    Returns:
        dict: User context with permission-auditing metadata.
    """
    context = {
        "email": email,
        "full_name": full_name,
        "is_admin": is_admin,
        "ip_address": request.client.host if request.client else None,
        "user_agent": request.headers.get("user-agent"),
        "db": None,  # Session closed; use endpoint's db param instead
        "auth_method": auth_method,
        "request_id": getattr(request.state, "request_id", None),
        "team_id": getattr(request.state, "team_id", None),
        # Plugin contexts set by HttpAuthMiddleware for cross-hook context sharing
        "plugin_context_table": getattr(request.state, "plugin_context_table", None),
        "plugin_global_context": getattr(request.state, "plugin_global_context", None),
    }
    context.update(extra)
    return context


def _lookup_proxy_user_profile(proxy_user: str) -> tuple[bool, str]:
    """Resolve admin flag and display name for a proxy-authenticated user.

    The platform admin email short-circuits to an admin profile. Otherwise the EmailUser
    table is queried; any lookup failure is logged and yields a non-admin profile.

    Args:
        proxy_user: User identifier taken from the trusted proxy header.

    Returns:
        tuple[bool, str]: ``(is_admin, full_name)``.
    """
    if proxy_user == settings.platform_admin_email:
        return True, "Platform Admin"

    is_admin = False
    full_name = proxy_user
    try:
        # Third-Party
        from sqlalchemy import select  # pylint: disable=import-outside-toplevel

        # First-Party
        from mcpgateway.db import EmailUser  # pylint: disable=import-outside-toplevel

        # Use fresh_db_session for short-lived database access
        with fresh_db_session() as db:
            user = db.execute(select(EmailUser).where(EmailUser.email == proxy_user)).scalar_one_or_none()
            if user:
                is_admin = user.is_admin
                full_name = user.full_name or proxy_user
    except Exception as e:
        # Continue with is_admin=False if lookup fails
        logger.debug("Could not lookup proxy user in DB: %s", e)
    return is_admin, full_name


async def get_current_user_with_permissions(request: Request, credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), jwt_token: Optional[str] = Cookie(default=None)):
    """Extract current user from JWT token and prepare for permission checking.

    Uses fresh_db_session() context manager to avoid session accumulation under high load.
    Database sessions are created only when needed and closed immediately after use.

    When MCP client auth is disabled, the trusted proxy header is used (if proxy trust is
    active), otherwise an anonymous context is returned unless ``auth_required`` is set.
    Otherwise the token is taken from the Authorization header, then the ``jwt_token`` /
    ``access_token`` cookies. Cookie-only tokens are accepted only for browser/admin-UI
    requests.

    Args:
        request: FastAPI request object for IP/user-agent extraction
        credentials: HTTP Bearer credentials
        jwt_token: JWT token from cookie

    Returns:
        dict: User information with permission checking context

    Raises:
        HTTPException: If authentication fails (401), or 302 redirect to the login page for
            browser requests.

    Examples:
        Use as FastAPI dependency::

            @app.get("/protected-endpoint")
            async def protected_route(user = Depends(get_current_user_with_permissions)):
                return {"user": user["email"]}
    """
    # --- Proxy / anonymous flow when MCP client auth is disabled ---
    if not settings.mcp_client_auth_enabled:
        if is_proxy_auth_trust_active(settings):
            proxy_user = request.headers.get(settings.proxy_user_header)
            if proxy_user:
                is_admin, full_name = _lookup_proxy_user_profile(proxy_user)
                return _make_auth_context(request, email=proxy_user, full_name=full_name, is_admin=is_admin, auth_method="proxy")
            missing_auth_detail = "Proxy authentication header required"
        else:
            # Warning: MCP auth disabled without proxy trust - security risk!
            # This case is already warned about in config validation
            missing_auth_detail = "Authentication required but no auth method configured"

        # Check auth_required to align with WebSocket behavior:
        # browsers are redirected to login, API clients get 401.
        if settings.auth_required:
            if _is_html_client(request):
                raise _login_redirect_exception()
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=missing_auth_detail)

        # auth_required=false: allow anonymous access
        return _make_auth_context(request, email="anonymous", full_name="Anonymous User", is_admin=False, auth_method="anonymous")

    # --- Standard JWT authentication flow ---
    # Sources in priority order: Authorization header, manual cookie read, Cookie dependency.
    token = None
    token_from_cookie = False

    if credentials and credentials.credentials:
        token = credentials.credentials

    if not token and request.cookies:
        manual_token = request.cookies.get("jwt_token") or request.cookies.get("access_token")
        if manual_token:
            token = manual_token
            token_from_cookie = True

    if not token and jwt_token:
        token = jwt_token
        token_from_cookie = True

    # Browser/admin-UI request (as opposed to an external API request)
    referer = request.headers.get("referer", "")
    is_browser_request = _is_html_client(request) or "/admin" in referer

    # SECURITY: Reject cookie-only authentication for API requests
    # Cookies should only be used for browser/HTML requests (including admin UI fetch calls)
    if token_from_cookie and not is_browser_request:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Cookie authentication not allowed for API requests. Use Authorization header.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not token:
        if is_browser_request:
            raise _login_redirect_exception()

        # AUTH_REQUIRED=false no longer implies admin access.
        # Preserve explicit unsafe override for local-only compatibility.
        if not settings.auth_required and getattr(settings, "allow_unauthenticated_admin", False) is True:
            return _make_auth_context(request, email=settings.platform_admin_email, full_name="Platform Admin", is_admin=True, auth_method="disabled")

        if not settings.auth_required:
            return _make_auth_context(request, email="anonymous", full_name="Anonymous User", is_admin=False, auth_method="anonymous")

        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authorization token required")

    try:
        # Create credentials object if we got token from cookie
        if not credentials:
            credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)

        # Pass request so plugins can store auth_method / token metadata in request.state
        user = await get_current_user(credentials, request=request)

        # request.state is read after get_current_user so values it sets are included
        return _make_auth_context(
            request,
            email=user.email,
            full_name=user.full_name,
            is_admin=user.is_admin,
            auth_method=getattr(request.state, "auth_method", None),
            token_teams=getattr(request.state, "token_teams", None),  # query-level scoping
            token_use=getattr(request.state, "token_use", None),  # RBAC team derivation
        )
    except Exception as e:
        logger.error("Authentication failed: %s: %s", type(e).__name__, e)

        # For browser requests (HTML Accept header or HTMX), redirect to login
        if _is_html_client(request):
            raise _login_redirect_exception()

        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication credentials")
364# --- Team derivation helpers for multi-team session tokens ---
@functools.lru_cache(maxsize=1)
def _get_resource_param_to_model():
    """Return the mapping from endpoint kwarg names to SQLAlchemy model classes.

    The model import is deferred until first call, and ``lru_cache(maxsize=1)`` means
    the dict is built once and the same object is returned afterwards. Insertion order
    is significant: ``_derive_team_from_resource`` consults the first matching kwarg.

    Returns:
        dict: Mapping of URL parameter names to SQLAlchemy model classes.
    """
    # First-Party
    from mcpgateway.db import A2AAgent, Gateway, Prompt, Resource, Server, Tool  # pylint: disable=import-outside-toplevel

    return dict(
        tool_id=Tool,
        server_id=Server,
        resource_id=Resource,
        prompt_id=Prompt,
        gateway_id=Gateway,
        agent_id=A2AAgent,
    )
def _derive_team_from_resource(kwargs, db_session) -> Optional[str]:
    """Look up resource's team_id from DB for RBAC context (Tier 1).

    For endpoints that target a specific resource (get, update, delete, execute),
    derive the team context from the resource's owner team. Only the first resource-ID
    kwarg present (in mapping order) is consulted.

    Args:
        kwargs: Endpoint function kwargs containing resource ID params
        db_session: Active SQLAlchemy session

    Returns:
        team_id string if found, None otherwise (including when the lookup fails)
    """
    for param_name, model_cls in _get_resource_param_to_model().items():
        resource_id = kwargs.get(param_name)
        if not resource_id:
            continue
        try:
            resource = db_session.get(model_cls, resource_id)
            if resource:
                return getattr(resource, "team_id", None)
        except Exception as exc:  # nosec B110 - DB lookup failure falls through to None
            # Best effort: keep going without a team, but leave a trace for debugging.
            logger.debug("Team derivation lookup failed for %s=%s: %s", param_name, resource_id, exc)
        return None  # Resource not found; let endpoint handle 404
    return None  # No resource ID param
async def _derive_team_from_payload(kwargs) -> Optional[str]:
    """Extract team_id from create payload objects or form data (Tier 3).

    For create endpoints, derive team context from the Pydantic payload or form data.
    Payload kwargs are checked first, in a fixed order. The ``request`` form is consulted
    only for form content types.

    Args:
        kwargs: Endpoint function kwargs

    Returns:
        team_id string if found, None otherwise
    """
    # Try Pydantic payload objects (API endpoints)
    for param_name in ("gateway", "tool", "server", "resource", "prompt", "agent"):
        payload_obj = kwargs.get(param_name)
        if payload_obj and hasattr(payload_obj, "team_id"):
            tid = getattr(payload_obj, "team_id", None)
            if tid:
                return tid

    # Try request form data (admin UI endpoints)
    # Note: use 'is not None' rather than truthiness check because some
    # objects (e.g. Pydantic models) may be truthy yet lack .headers.
    request = kwargs.get("request")
    if request is not None and isinstance(request, Request):
        content_type = request.headers.get("content-type", "")
        if "form" in content_type:
            try:
                form = await request.form()
                tid = form.get("team_id")
                # Multipart fields may be uploaded files; only a text value is a team id.
                if tid and isinstance(tid, str):
                    return tid
            except Exception as exc:  # nosec B110 - Form parse failure is non-fatal
                logger.debug("Could not read form data for team derivation: %s", exc)

    return None
# Permissions that indicate create/mutate operations (not safe for "any-team" aggregation)
_MUTATE_PERMISSION_ACTIONS = frozenset(
    (
        "create",
        "update",
        "delete",
        "execute",
        "invoke",
        "toggle",
        "set_state",
        "revoke",
        "manage_members",
        "join",
        "manage",
        "share",
        "invite",
        "use",
    )
)


def _is_mutate_permission(permission: str) -> bool:
    """Tell whether a permission string names a mutating action.

    Two formats are understood:

    * colon-separated (``admin.sso_providers:create``): the action is the text after the
      last colon;
    * dot-separated (``tools.create``): the action is the text after the last dot.

    A string containing neither separator has no action component and yields False.

    Args:
        permission: Permission string like 'tools.create' or 'admin.sso_providers:create'.

    Returns:
        bool: True if the permission's action component is a mutating operation.
    """
    _, colon, colon_action = permission.rpartition(":")
    if colon:
        return colon_action in _MUTATE_PERMISSION_ACTIONS
    _, dot, dot_action = permission.rpartition(".")
    return bool(dot) and dot_action in _MUTATE_PERMISSION_ACTIONS
async def _resolve_permission_team(kwargs, user_context) -> tuple[Optional[str], bool]:
    """Determine the team scope for an RBAC check.

    Order: ``team_id`` path/kwarg, then the ``team_id`` in the user context. For
    multi-team session tokens (``token_use == "session"``) with no team yet and a DB
    session available, the team is derived from the targeted resource (Tier 1) and then
    from the create payload/form (Tier 3).

    Args:
        kwargs: Keyword arguments of the decorated endpoint call.
        user_context: Authenticated user context dict.

    Returns:
        tuple[Optional[str], bool]: ``(team_id, check_any_team)``. ``check_any_team`` is True
        only when derivation ran for a session token and still found no team. In that case
        the permission is checked across all of the user's teams. This separates
        authorization ("does this user have the permission?") from resource scoping
        ("which team owns this resource?"). Team assignment is enforced downstream by
        endpoint logic (e.g. verify_team_for_user, token team membership checks).
    """
    team_id = kwargs.get("team_id") or user_context.get("team_id", None)
    if team_id or user_context.get("token_use") != "session":
        return team_id, False

    db_session = kwargs.get("db") or user_context.get("db")
    if not db_session:
        return team_id, False

    team_id = _derive_team_from_resource(kwargs, db_session)
    if team_id is None:
        team_id = await _derive_team_from_payload(kwargs)
    return team_id, not team_id


def require_permission(permission: str, resource_type: Optional[str] = None, allow_admin_bypass: bool = True):
    """Decorator to require specific permission for accessing an endpoint.

    Plugins registered for ``HTTP_AUTH_CHECK_PERMISSION`` are consulted first. A plugin
    denial raises 403. A plugin grant is honored only when
    ``settings.plugins_can_override_rbac`` is set; otherwise the standard RBAC check in
    ``PermissionService`` decides.

    Args:
        permission: Required permission (e.g., 'tools.create')
        resource_type: Optional resource type for resource-specific permissions
        allow_admin_bypass: If True (default), admin users bypass all permission checks.
                           If False, even admins must have explicit permissions.
                           Use False for admin UI routes to enforce granular RBAC.

    Returns:
        Callable: Decorated function that enforces the permission requirement

    Examples:
        >>> decorator = require_permission("tools.create", "tools")
        >>> callable(decorator)
        True

        Execute wrapped function when permission granted:
        >>> import asyncio
        >>> class DummyPS:
        ...     def __init__(self, db):
        ...         pass
        ...     async def check_permission(self, **kwargs):
        ...         return True
        >>> @require_permission("tools.read")
        ... async def demo(user=None):
        ...     return "ok"
        >>> from unittest.mock import patch
        >>> with patch('mcpgateway.middleware.rbac.PermissionService', DummyPS):
        ...     asyncio.run(demo(user={"email": "u", "db": object()}))
        'ok'
    """

    def decorator(func: Callable) -> Callable:
        """Decorator function that wraps the original function with permission checking.

        Args:
            func: The function to be decorated

        Returns:
            Callable: The wrapped function with permission checking
        """

        @wraps(func)
        async def wrapper(*args, **kwargs):
            """Async wrapper function that performs permission check before calling original function.

            Args:
                *args: Positional arguments passed to the wrapped function
                **kwargs: Keyword arguments passed to the wrapped function

            Returns:
                Any: Result from the wrapped function if permission check passes

            Raises:
                HTTPException: If user authentication or permission check fails
            """
            # Extract user context from named kwargs only (security: avoid picking up request body dicts)
            user_context = kwargs.get("user") or kwargs.get("_user") or kwargs.get("current_user") or kwargs.get("current_user_ctx")
            if not user_context or not isinstance(user_context, dict) or "email" not in user_context:
                raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User authentication required")

            team_id, check_any_team = await _resolve_permission_team(kwargs, user_context)

            # First, check if any plugins want to handle permission checking
            # First-Party
            from mcpgateway.plugins.framework import get_plugin_manager, GlobalContext, HttpAuthCheckPermissionPayload, HttpHookType  # pylint: disable=import-outside-toplevel

            plugin_manager = get_plugin_manager()
            if plugin_manager and plugin_manager.has_hooks_for(HttpHookType.HTTP_AUTH_CHECK_PERMISSION):
                # Contexts stored in request.state by HttpAuthMiddleware enable cross-hook
                # sharing between HTTP_PRE_REQUEST and HTTP_AUTH_CHECK_PERMISSION
                plugin_context_table = user_context.get("plugin_context_table")
                plugin_global_context = user_context.get("plugin_global_context")

                # Reuse the middleware's global context when present; otherwise create one
                if plugin_global_context:
                    global_context = plugin_global_context
                else:
                    request_id = user_context.get("request_id") or uuid.uuid4().hex
                    global_context = GlobalContext(
                        request_id=request_id,
                        server_id=None,
                        tenant_id=None,
                    )

                result, _ = await plugin_manager.invoke_hook(
                    HttpHookType.HTTP_AUTH_CHECK_PERMISSION,
                    payload=HttpAuthCheckPermissionPayload(
                        user_email=user_context["email"],
                        permission=permission,
                        resource_type=resource_type,
                        team_id=team_id,
                        is_admin=user_context.get("is_admin", False),
                        auth_method=user_context.get("auth_method"),
                        client_host=user_context.get("ip_address"),
                        user_agent=user_context.get("user_agent"),
                    ),
                    global_context=global_context,
                    local_contexts=plugin_context_table,  # Pass context table for cross-hook state
                )

                # If a plugin made a decision, respect it
                if result and result.modified_payload and hasattr(result.modified_payload, "granted"):
                    decision_plugin = "unknown"
                    decision_reason = getattr(result.modified_payload, "reason", None)
                    result_metadata = result.metadata if isinstance(result.metadata, dict) else {}
                    if result_metadata.get("_decision_plugin"):
                        decision_plugin = str(result_metadata["_decision_plugin"])
                    # Fall back to other metadata keys only while the plugin is still unknown
                    for key in ("plugin_name", "plugin", "source_plugin", "handler"):
                        if decision_plugin != "unknown":
                            break
                        plugin_name = result_metadata.get(key)
                        if plugin_name:
                            decision_plugin = str(plugin_name)

                    logger.info(
                        "Plugin permission decision: plugin=%s user=%s permission=%s granted=%s reason=%s",
                        decision_plugin,
                        user_context["email"],
                        permission,
                        result.modified_payload.granted,
                        decision_reason,
                    )

                    if result.modified_payload.granted:
                        if settings.plugins_can_override_rbac:
                            logger.warning(
                                "Plugin RBAC grant override applied: plugin=%s user=%s permission=%s reason=%s",
                                decision_plugin,
                                user_context["email"],
                                permission,
                                decision_reason,
                            )
                            return await func(*args, **kwargs)
                        # Grants are advisory by default: fall through to standard RBAC
                        logger.info(
                            "Plugin RBAC grant decision ignored by default policy: plugin=%s user=%s permission=%s",
                            decision_plugin,
                            user_context["email"],
                            permission,
                        )
                    else:
                        logger.warning(
                            "Permission denied by plugin: plugin=%s user=%s permission=%s reason=%s",
                            decision_plugin,
                            user_context["email"],
                            permission,
                            decision_reason,
                        )
                        raise HTTPException(
                            status_code=status.HTTP_403_FORBIDDEN,
                            detail=_ACCESS_DENIED_MSG,
                        )

            # Standard RBAC check; the same arguments are used whichever session is chosen
            check_args = {
                "user_email": user_context["email"],
                "permission": permission,
                "resource_type": resource_type,
                "team_id": team_id,
                "token_teams": user_context.get("token_teams"),
                "ip_address": user_context.get("ip_address"),
                "user_agent": user_context.get("user_agent"),
                "allow_admin_bypass": allow_admin_bypass,
                "check_any_team": check_any_team,
            }

            # Prefer endpoint's db param, then user_context["db"], then a fresh short-lived session
            db_session = kwargs.get("db") or user_context.get("db")
            if db_session:
                granted = await PermissionService(db_session).check_permission(**check_args)
            else:
                with fresh_db_session() as db:
                    granted = await PermissionService(db).check_permission(**check_args)

            if not granted:
                logger.warning("Permission denied: user=%s, permission=%s, resource_type=%s", user_context["email"], permission, resource_type)
                raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=_ACCESS_DENIED_MSG)

            # Permission granted, execute the original function
            return await func(*args, **kwargs)

        return wrapper

    return decorator
def require_admin_permission():
    """Decorator to require admin permissions for accessing an endpoint.

    Admin status is checked via ``PermissionService.check_admin_permission``. It uses the
    endpoint's ``db`` kwarg, the user context's ``db``, or a fresh short-lived session,
    in that order.

    Returns:
        Callable: Decorated function that enforces admin permission requirement

    Examples:
        >>> decorator = require_admin_permission()
        >>> callable(decorator)
        True

        Execute when admin permission granted:
        >>> import asyncio
        >>> class DummyPS:
        ...     def __init__(self, db):
        ...         pass
        ...     async def check_admin_permission(self, email):
        ...         return True
        >>> @require_admin_permission()
        ... async def demo(user=None):
        ...     return "admin-ok"
        >>> from unittest.mock import patch
        >>> with patch('mcpgateway.middleware.rbac.PermissionService', DummyPS):
        ...     asyncio.run(demo(user={"email": "u", "db": object()}))
        'admin-ok'
    """

    def decorator(func: Callable) -> Callable:
        """Decorator function that wraps the original function with admin permission checking.

        Args:
            func: The function to be decorated

        Returns:
            Callable: The wrapped function with admin permission checking
        """

        @wraps(func)
        async def wrapper(*args, **kwargs):
            """Async wrapper function that performs admin permission check before calling original function.

            Args:
                *args: Positional arguments passed to the wrapped function
                **kwargs: Keyword arguments passed to the wrapped function

            Returns:
                Any: Result from the wrapped function if admin permission check passes

            Raises:
                HTTPException: If user authentication or admin permission check fails
            """
            # Extract user context from named kwargs only (security: avoid picking up request body dicts)
            user_context = kwargs.get("user") or kwargs.get("_user") or kwargs.get("current_user") or kwargs.get("current_user_ctx")
            if not user_context or not isinstance(user_context, dict) or "email" not in user_context:
                raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User authentication required")

            email = user_context["email"]

            # Prefer endpoint's db param, then user_context["db"], then a fresh short-lived session
            db_session = kwargs.get("db") or user_context.get("db")
            if db_session:
                has_admin_permission = await PermissionService(db_session).check_admin_permission(email)
            else:
                with fresh_db_session() as db:
                    has_admin_permission = await PermissionService(db).check_admin_permission(email)

            if not has_admin_permission:
                logger.warning("Admin permission denied: user=%s", email)
                raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=_ACCESS_DENIED_MSG)

            # Admin permission granted, execute the original function
            return await func(*args, **kwargs)

        return wrapper

    return decorator
async def _any_permission_granted(
    permission_service: PermissionService,
    permissions: List[str],
    user_context: dict,
    resource_type: Optional[str],
    team_id: Optional[str],
    allow_admin_bypass: bool,
    check_any_team: bool,
) -> bool:
    """Return True as soon as the user holds any one of ``permissions``.

    Permissions are checked in list order and the loop stops at the first grant,
    so later permissions are not queried once one succeeds.

    Args:
        permission_service: Permission service bound to the session to query.
        permissions: Candidate permissions; the user needs at least one.
        user_context: Authenticated user context (must contain ``"email"``).
        resource_type: Optional resource type for resource-specific permissions.
        team_id: Optional team scope for the check.
        allow_admin_bypass: Forwarded to ``PermissionService.check_permission``.
        check_any_team: Forwarded to ``PermissionService.check_permission``.

    Returns:
        bool: True if any permission is granted; False otherwise (including an empty list).
    """
    for permission in permissions:
        if await permission_service.check_permission(
            user_email=user_context["email"],
            permission=permission,
            resource_type=resource_type,
            team_id=team_id,
            token_teams=user_context.get("token_teams"),
            ip_address=user_context.get("ip_address"),
            user_agent=user_context.get("user_agent"),
            allow_admin_bypass=allow_admin_bypass,
            check_any_team=check_any_team,
        ):
            return True
    return False


def require_any_permission(permissions: List[str], resource_type: Optional[str] = None, allow_admin_bypass: bool = True):
    """Decorator to require any of the specified permissions for accessing an endpoint.

    Args:
        permissions: List of permissions, user needs at least one
        resource_type: Optional resource type for resource-specific permissions
        allow_admin_bypass: If True (default), admin users bypass all permission checks.
                           If False, even admins must have explicit permissions.

    Returns:
        Callable: Decorated function that enforces the permission requirements

    Examples:
        >>> decorator = require_any_permission(["tools.read", "tools.execute"], "tools")
        >>> callable(decorator)
        True

        Execute when any permission granted:
        >>> import asyncio
        >>> class DummyPS:
        ...     def __init__(self, db):
        ...         pass
        ...     async def check_permission(self, **kwargs):
        ...         return True
        >>> @require_any_permission(["tools.read", "tools.execute"], "tools")
        ... async def demo(user=None):
        ...     return "any-ok"
        >>> from unittest.mock import patch
        >>> with patch('mcpgateway.middleware.rbac.PermissionService', DummyPS):
        ...     asyncio.run(demo(user={"email": "u", "db": object()}))
        'any-ok'
    """

    def decorator(func: Callable) -> Callable:
        """Decorator function that wraps the original function with any-permission checking.

        Args:
            func: The function to be decorated

        Returns:
            Callable: The wrapped function with any-permission checking
        """

        @wraps(func)
        async def wrapper(*args, **kwargs):
            """Async wrapper function that performs any-permission check before calling original function.

            Args:
                *args: Positional arguments passed to the wrapped function
                **kwargs: Keyword arguments passed to the wrapped function

            Returns:
                Any: Result from the wrapped function if any-permission check passes

            Raises:
                HTTPException: 401 if no authenticated user context is found, 403 if none of the permissions is granted
            """
            # Extract user context from named kwargs only (security: avoid picking up request body dicts)
            user_context = kwargs.get("user") or kwargs.get("_user") or kwargs.get("current_user") or kwargs.get("current_user_ctx")
            if not user_context or not isinstance(user_context, dict) or "email" not in user_context:
                raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User authentication required")

            # Prefer the endpoint's db param, then user_context["db"]; a fresh session is opened below if neither exists.
            db_session = kwargs.get("db") or user_context.get("db")

            # Team scope: path/kwarg team_id first, falling back to the token's team_id when blank.
            team_id = kwargs.get("team_id") or user_context.get("team_id", None)

            # For multi-team session tokens (team_id is None), derive team from context.
            # NOTE(review): nesting reconstructed from a whitespace-stripped listing; the
            # cross-team fallback is only enabled when a db session is available — confirm.
            check_any_team = False
            if not team_id and user_context.get("token_use") == "session" and db_session:
                # Tier 1: Try to derive team from existing resource
                team_id = _derive_team_from_resource(kwargs, db_session)
                # Tier 3: Try to derive team from create payload / form
                if team_id is None:
                    team_id = await _derive_team_from_payload(kwargs)
                # If still no team_id, check permission across all of the user's teams.
                # Authorization ("does this user have the permission?") is separate
                # from resource scoping ("which team owns this resource?").
                if not team_id:
                    check_any_team = True

            check_args = (permissions, user_context, resource_type, team_id, allow_admin_bypass, check_any_team)
            if db_session:
                # Use existing session from endpoint or user_context
                granted = await _any_permission_granted(PermissionService(db_session), *check_args)
            else:
                # One fresh session shared by all permission checks
                with fresh_db_session() as db:
                    granted = await _any_permission_granted(PermissionService(db), *check_args)

            if not granted:
                # Permission names are logged server-side only; the client gets the generic message.
                logger.warning("Permission denied: user=%s, permissions=%s, resource_type=%s", user_context["email"], permissions, resource_type)
                raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=_ACCESS_DENIED_MSG)

            # Permission granted, execute the original function
            return await func(*args, **kwargs)

        return wrapper

    return decorator
939class PermissionChecker:
940 """Context manager for manual permission checking.
942 Useful for complex permission logic that can't be handled by decorators.
944 Examples:
945 >>> from unittest.mock import Mock
946 >>> checker = PermissionChecker({"email": "user@example.com", "db": Mock()})
947 >>> hasattr(checker, 'has_permission') and hasattr(checker, 'has_admin_permission')
948 True
949 """
951 def __init__(self, user_context: dict):
952 """Initialize permission checker with user context.
954 Args:
955 user_context: User context from get_current_user_with_permissions
956 """
957 self.user_context = user_context
958 self.db_session = user_context.get("db")
960 async def has_permission(self, permission: str, resource_type: Optional[str] = None, resource_id: Optional[str] = None, team_id: Optional[str] = None, check_any_team: bool = False) -> bool:
961 """Check if user has specific permission.
963 Args:
964 permission: Permission to check
965 resource_type: Optional resource type
966 resource_id: Optional resource ID
967 team_id: Optional team context
968 check_any_team: If True, check across all teams the user belongs to
970 Returns:
971 bool: True if user has permission
972 """
973 if self.db_session:
974 # Use existing session
975 permission_service = PermissionService(self.db_session)
976 return await permission_service.check_permission(
977 user_email=self.user_context["email"],
978 permission=permission,
979 resource_type=resource_type,
980 resource_id=resource_id,
981 team_id=team_id,
982 token_teams=self.user_context.get("token_teams"),
983 ip_address=self.user_context.get("ip_address"),
984 user_agent=self.user_context.get("user_agent"),
985 check_any_team=check_any_team,
986 )
987 # Create fresh db session
988 with fresh_db_session() as db:
989 permission_service = PermissionService(db)
990 return await permission_service.check_permission(
991 user_email=self.user_context["email"],
992 permission=permission,
993 resource_type=resource_type,
994 resource_id=resource_id,
995 team_id=team_id,
996 token_teams=self.user_context.get("token_teams"),
997 ip_address=self.user_context.get("ip_address"),
998 user_agent=self.user_context.get("user_agent"),
999 check_any_team=check_any_team,
1000 )
1002 async def has_admin_permission(self) -> bool:
1003 """Check if user has admin permissions.
1005 Returns:
1006 bool: True if user has admin permissions
1007 """
1008 if self.db_session:
1009 # Use existing session
1010 permission_service = PermissionService(self.db_session)
1011 return await permission_service.check_admin_permission(self.user_context["email"])
1012 # Create fresh db session
1013 with fresh_db_session() as db:
1014 permission_service = PermissionService(db)
1015 return await permission_service.check_admin_permission(self.user_context["email"])
1017 async def has_any_permission(self, permissions: List[str], resource_type: Optional[str] = None, team_id: Optional[str] = None) -> bool:
1018 """Check if user has any of the specified permissions.
1020 Args:
1021 permissions: List of permissions to check
1022 resource_type: Optional resource type
1023 team_id: Optional team context
1025 Returns:
1026 bool: True if user has at least one permission
1027 """
1028 if self.db_session:
1029 # Use existing session for all checks
1030 permission_service = PermissionService(self.db_session)
1031 for permission in permissions:
1032 if await permission_service.check_permission(
1033 user_email=self.user_context["email"],
1034 permission=permission,
1035 resource_type=resource_type,
1036 team_id=team_id,
1037 token_teams=self.user_context.get("token_teams"),
1038 ip_address=self.user_context.get("ip_address"),
1039 user_agent=self.user_context.get("user_agent"),
1040 ):
1041 return True
1042 return False
1043 # Create single fresh session for all checks (avoid N sessions for N permissions)
1044 with fresh_db_session() as db:
1045 permission_service = PermissionService(db)
1046 for permission in permissions:
1047 if await permission_service.check_permission(
1048 user_email=self.user_context["email"],
1049 permission=permission,
1050 resource_type=resource_type,
1051 team_id=team_id,
1052 token_teams=self.user_context.get("token_teams"),
1053 ip_address=self.user_context.get("ip_address"),
1054 user_agent=self.user_context.get("user_agent"),
1055 ):
1056 return True
1057 return False
1059 async def require_permission(self, permission: str, resource_type: Optional[str] = None, resource_id: Optional[str] = None, team_id: Optional[str] = None) -> None:
1060 """Require specific permission, raise HTTPException if not granted.
1062 Args:
1063 permission: Required permission
1064 resource_type: Optional resource type
1065 resource_id: Optional resource ID
1066 team_id: Optional team context
1068 Raises:
1069 HTTPException: If permission is not granted
1070 """
1071 if not await self.has_permission(permission, resource_type, resource_id, team_id):
1072 logger.warning(f"{_ACCESS_DENIED_MSG}: user '{self.user_context.get('email')}' missing permission '{permission}'")
1073 raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=_ACCESS_DENIED_MSG)