Coverage for mcpgateway / utils / gateway_access.py: 100%

61 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/utils/gateway_access.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5 

6Gateway access control utilities. 

7 

8This module provides helper functions for checking gateway access permissions 

9in direct_proxy mode, ensuring consistent RBAC enforcement across the codebase. 

10""" 

11 

12# Standard 

13from typing import Dict, List, Optional 

14 

15# Third-Party 

16from sqlalchemy.orm import Session 

17 

18# First-Party 

19from mcpgateway.db import Gateway as DbGateway 

20from mcpgateway.utils.services_auth import decode_auth 

21 

# Name of the request header a client sets to choose the target gateway in
# direct_proxy mode. It lives in one module-level constant so the literal is
# not repeated throughout the codebase.
GATEWAY_ID_HEADER: str = "X-Context-Forge-Gateway-Id"


def extract_gateway_id_from_headers(headers: Optional[Dict[str, str]]) -> Optional[str]:
    """Look up the gateway-ID header, comparing header names case-insensitively.

    Args:
        headers: Mapping of request header names to values; may be None.

    Returns:
        The value of the first header whose name equals ``GATEWAY_ID_HEADER``
        when both are lower-cased, or None if ``headers`` is None/empty or no
        header matches.
    """
    if not headers:
        return None
    wanted_name = GATEWAY_ID_HEADER.lower()
    # next() yields the first match in iteration order, mirroring a loop with an early return.
    return next(
        (header_value for header_name, header_value in headers.items() if header_name.lower() == wanted_name),
        None,
    )

43 

44 

async def check_gateway_access(
    db: Session,
    gateway: DbGateway,
    user_email: Optional[str],
    token_teams: Optional[List[str]],
) -> bool:
    """Check if user has access to a gateway based on visibility rules.

    Used for direct_proxy mode to ensure users can only access gateways they
    have permission to use.

    Access rules, evaluated in this order:
        1. ``visibility == "public"`` (also assumed when the gateway object has
           no ``visibility`` attribute): allowed.
        2. Admin bypass, ``token_teams is None`` AND ``user_email is None``:
           allowed.
        3. No ``user_email`` (and not the admin bypass): denied.
        4. Public-only token (``token_teams == []``): denied, because the
           gateway is known to be non-public at this point.
        5. ``owner_email`` equals ``user_email``: allowed.
        6. ``visibility == "team"`` and the gateway's ``team_id`` is one of the
           user's teams: allowed.
        7. Anything else (e.g. a ``"private"`` gateway the user does not own):
           denied.

    Args:
        db: Database session, used only for the team-membership lookup when
            ``token_teams`` is None.
        gateway: Gateway ORM object.
        user_email: Email of the requesting user. None means unauthenticated,
            or the admin bypass when ``token_teams`` is also None.
        token_teams: List of team IDs from token.
            - None = no team scoping in the token. Together with
              ``user_email=None`` this is unrestricted admin access; with a
              ``user_email`` the user's teams are loaded from the database.
            - [] = public-only token
            - [...] = team-scoped token (used as-is, no database lookup)

    Returns:
        True if access is allowed, False otherwise.
    """
    # NOTE(review): a gateway object without a ``visibility`` attribute is treated
    # as public (fail-open). Presumably this mirrors the ORM column default; confirm.
    visibility = getattr(gateway, "visibility", "public")
    gateway_team_id = getattr(gateway, "team_id", None)
    gateway_owner_email = getattr(gateway, "owner_email", None)

    # Public gateways are accessible by everyone
    if visibility == "public":
        return True

    # Admin bypass: token_teams=None AND user_email=None means unrestricted admin.
    # This happens when is_admin=True and there is no team scoping in the token.
    if token_teams is None and user_email is None:
        return True

    # No user context (but not admin) = deny access to non-public gateways
    if not user_email:
        return False

    # Public-only tokens (empty teams list) can ONLY access public gateways,
    # and public gateways were already handled above.
    if token_teams is not None and not token_teams:
        return False

    # Owner can always access their own gateways
    if gateway_owner_email and gateway_owner_email == user_email:
        return True

    # Only team-visible gateways are shared with team members ("private" is
    # owner-only). Deciding this before the membership lookup avoids a pointless
    # database query for gateways that can never be granted here.
    if visibility != "team" or not gateway_team_id:
        return False

    # Use token_teams if provided, otherwise look up the user's teams from the DB
    if token_teams is not None:
        team_ids = token_teams
    else:
        # First-Party
        from mcpgateway.services.team_management_service import TeamManagementService  # pylint: disable=import-outside-toplevel

        team_service = TeamManagementService(db)
        user_teams = await team_service.get_user_teams(user_email)
        team_ids = [team.id for team in user_teams]

    return gateway_team_id in team_ids

117 

118 

def _strip_bearer_scheme(value: str) -> str:
    """Return the bare token from an Authorization value, dropping a leading ``Bearer`` scheme.

    The scheme is matched case-insensitively and only at the start of the
    value, so a token that merely contains the text ``"Bearer "`` is kept intact.

    Args:
        value: Stored header value, either ``"Bearer <token>"`` or a bare token.

    Returns:
        The token with surrounding whitespace removed, or an empty string when
        no token is present.

    Examples:
        >>> _strip_bearer_scheme("Bearer abc")
        'abc'
        >>> _strip_bearer_scheme("bearer abc")
        'abc'
        >>> _strip_bearer_scheme("abc")
        'abc'
        >>> _strip_bearer_scheme("Bearer ")
        ''
        >>> _strip_bearer_scheme("Bearer")
        ''
    """
    scheme, _, rest = value.strip().partition(" ")
    if scheme.lower() == "bearer":
        return rest.strip()
    return value.strip()


def build_gateway_auth_headers(gateway: DbGateway) -> Dict[str, str]:
    """Build authentication headers for gateway requests.

    Extracts and formats authentication headers from gateway configuration,
    handling both bearer and basic auth types with dict or encoded string values.
    Encoded string values are decoded with ``decode_auth`` before use.

    For ``bearer`` auth, any leading ``Bearer`` scheme in the stored value is
    removed (case-insensitively) and the header is rebuilt as
    ``"Bearer <token>"``. For ``basic`` auth, the stored ``Authorization``
    value is used verbatim.

    Args:
        gateway: Gateway ORM object with auth_type and auth_value attributes.

    Returns:
        Dictionary of HTTP headers with Authorization header if auth is configured.
        Returns empty dict if no auth is configured or if token/credentials are empty.

    Examples:
        >>> gateway = DbGateway(auth_type="bearer", auth_value={"Authorization": "Bearer token123"})
        >>> headers = build_gateway_auth_headers(gateway)
        >>> headers["Authorization"]
        'Bearer token123'
        >>> build_gateway_auth_headers(DbGateway(auth_type="bearer", auth_value={"Authorization": "bearer token123"}))
        {'Authorization': 'Bearer token123'}
        >>> build_gateway_auth_headers(DbGateway(auth_type="bearer", auth_value={"Authorization": "Bearer "}))
        {}
    """
    headers: Dict[str, str] = {}

    auth_type = gateway.auth_type
    raw_value = gateway.auth_value
    if auth_type not in ("bearer", "basic") or not raw_value:
        return headers

    # auth_value is stored either as a plain dict or as an encoded string;
    # normalize both forms to a dict once instead of duplicating each branch.
    if isinstance(raw_value, str):
        auth_dict = decode_auth(raw_value) or {}
    elif isinstance(raw_value, dict):
        auth_dict = raw_value
    else:
        return headers

    # ``or ""`` also covers an explicit None stored under the key.
    auth_header = auth_dict.get("Authorization") or ""
    if auth_type == "bearer":
        token = _strip_bearer_scheme(auth_header)
        if token:  # Only add header if token is not empty
            headers["Authorization"] = f"Bearer {token}"
    elif auth_header:  # basic: only add header if not empty
        headers["Authorization"] = auth_header

    return headers