Coverage for mcpgateway / middleware / http_auth_middleware.py: 100%

79 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

# -*- coding: utf-8 -*-
"""HTTP Authentication Middleware.

This middleware allows plugins to:

1. Transform request headers before authentication (HTTP_PRE_REQUEST)
2. Inspect responses, and optionally modify response headers, after request
   completion (HTTP_POST_REQUEST)

The pre-request hook runner, ``run_pre_request_hooks``, is shared by
``HttpAuthMiddleware`` (Python flow) and ``_run_internal_mcp_authentication``
(Rust flow) so both apply identical plugin behavior. That includes a guard
that stops plugins from replacing auth-sensitive headers the client already
sent, unless ``settings.plugins_can_override_auth_headers`` is enabled.
"""

8 

9# Standard 

10import logging 

11from typing import Optional 

12 

13# Third-Party 

14from fastapi import Request 

15from starlette.middleware.base import BaseHTTPMiddleware 

16from starlette.types import ASGIApp 

17 

18# First-Party 

19from mcpgateway.config import settings 

20from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, HttpHookType, HttpPostRequestPayload, HttpPreRequestPayload, PluginManager 

21from mcpgateway.utils.correlation_id import generate_correlation_id, get_correlation_id 

22 

# Module-level logger used by both the hook runner and the middleware class.
logger = logging.getLogger(__name__)

24 

25 

# Header names (lowercase) that plugins may add but must not replace when the
# client already sent them, unless settings.plugins_can_override_auth_headers
# is enabled.
_AUTH_PROTECTED_HEADERS = frozenset({"authorization", "cookie", "x-api-key", "proxy-authorization"})


async def run_pre_request_hooks(
    plugin_manager: PluginManager,
    headers: dict[str, str],
    path: str,
    method: str,
    client_host: Optional[str] = None,
    client_port: Optional[int] = None,
    global_context: Optional[GlobalContext] = None,
) -> tuple[dict[str, str], Optional[GlobalContext], Optional[dict]]:
    """Run HTTP_PRE_REQUEST plugin hooks and return (possibly modified) headers.

    This is the shared hook runner used by both HttpAuthMiddleware (Python flow)
    and _run_internal_mcp_authentication (Rust flow) to ensure identical
    plugin behavior regardless of transport.

    Hook failures are logged and swallowed. In that case the original headers
    are returned unchanged so the request can proceed.

    Args:
        plugin_manager: The plugin manager instance.
        headers: Original request headers (not mutated).
        path: Request path.
        method: HTTP method.
        client_host: Client IP address.
        client_port: Client port.
        global_context: Optional pre-created global context. Created if not provided.

    Returns:
        Tuple of (merged_headers, global_context, context_table).
        When no hook modified the payload (or a hook failed), merged_headers is
        the ``headers`` object that was passed in. Otherwise it is a new dict
        with lowercase keys that reflects plugin modifications, with the
        auth-header override guard applied. context_table is None when no hooks
        are registered or a hook failed.
    """
    if not plugin_manager.has_hooks_for(HttpHookType.HTTP_PRE_REQUEST):
        return headers, global_context, None

    if global_context is None:
        request_id = get_correlation_id() or generate_correlation_id()
        global_context = GlobalContext(request_id=request_id, server_id=None, tenant_id=None)

    try:
        pre_result, context_table = await plugin_manager.invoke_hook(
            HttpHookType.HTTP_PRE_REQUEST,
            payload=HttpPreRequestPayload(
                path=path,
                method=method,
                headers=HttpHeaderPayload(root=dict(headers)),
                client_host=client_host,
                client_port=client_port,
            ),
            global_context=global_context,
            local_contexts=None,
            violations_as_exceptions=False,
        )

        if not pre_result.modified_payload:
            return headers, global_context, context_table

        modified_headers_dict = pre_result.modified_payload.root

        # Normalize to lowercase keys to avoid duplicate logical headers from
        # casing differences (e.g. "Authorization" vs "authorization").
        merged_headers = {k.lower(): v for k, v in headers.items()}

        # Security: prevent plugin hooks from overriding auth-sensitive
        # headers that were already present on the inbound request.
        # Plugins MAY create new auth headers (e.g. x-api-key -> authorization
        # transform) but MUST NOT replace values the client already sent.
        #
        # This guard can be disabled with PLUGINS_CAN_OVERRIDE_AUTH_HEADERS=true
        # for deployments that require plugin-driven token exchange (e.g. WXO auth).
        if not settings.plugins_can_override_auth_headers:
            blocked = {k.lower() for k in modified_headers_dict if k.lower() in _AUTH_PROTECTED_HEADERS and k.lower() in merged_headers}
            if blocked:
                # Plugins commonly echo the full header map back. Only warn when
                # a protected value actually differs from what the client sent.
                # Echoed, unchanged values are harmless, and dropping them still
                # leaves the client's value in place.
                changed = sorted({k.lower() for k, v in modified_headers_dict.items() if k.lower() in blocked and v != merged_headers[k.lower()]})
                if changed:
                    logger.warning("Pre-request hook attempted to override existing auth headers (stripped): %s", changed)
                modified_headers_dict = {k: v for k, v in modified_headers_dict.items() if k.lower() not in blocked}

        merged_headers.update({k.lower(): v for k, v in modified_headers_dict.items()})
        logger.debug("Pre-request hook modified headers: %s", list(modified_headers_dict.keys()))
        return merged_headers, global_context, context_table

    except Exception as e:
        # Best effort: a failing plugin must not block the request.
        logger.warning("HTTP_PRE_REQUEST hook failed: %s", e, exc_info=True)
        return headers, global_context, None

107 

108 

class HttpAuthMiddleware(BaseHTTPMiddleware):
    """Middleware for HTTP authentication hooks.

    This middleware invokes plugin hooks for HTTP request processing:
    - HTTP_PRE_REQUEST: Before any authentication, allows header transformation
    - HTTP_POST_REQUEST: After request completion, allows response inspection
      and response-header modification

    The middleware allows plugins to:
    - Convert custom authentication tokens to standard formats
    - Add tracing/correlation headers
    - Implement custom authentication schemes
    - Audit authentication attempts
    - Log response status and headers

    Only headers whose values actually change are written back, both to the
    ASGI request scope and to the response. Repeated headers (e.g. several
    ``Set-Cookie`` entries) therefore survive hooks that echo the full header map.
    """

    def __init__(self, app: ASGIApp, plugin_manager: PluginManager | None = None):
        """Initialize the HTTP auth middleware.

        Args:
            app: The ASGI application
            plugin_manager: Optional plugin manager for hook invocation
        """
        super().__init__(app)
        self.plugin_manager = plugin_manager

    async def dispatch(self, request: Request, call_next):
        """Process request through plugin hooks.

        Args:
            request: The incoming request
            call_next: The next middleware/handler in the chain

        Returns:
            The response from the application
        """
        # Skip hook invocation if no plugin manager
        if not self.plugin_manager:
            logger.debug("HttpAuthMiddleware: no plugin_manager, skipping hooks")
            return await call_next(request)

        # Skip payload creation if no HTTP hooks registered
        has_pre = self.plugin_manager.has_hooks_for(HttpHookType.HTTP_PRE_REQUEST)
        has_post = self.plugin_manager.has_hooks_for(HttpHookType.HTTP_POST_REQUEST)

        if not has_pre and not has_post:
            logger.debug("HttpAuthMiddleware: has_pre=%s has_post=%s, skipping hooks", has_pre, has_post)
            return await call_next(request)

        # Use correlation ID from CorrelationIDMiddleware if available
        request_id = get_correlation_id()
        if not request_id:
            request_id = generate_correlation_id()
            logger.debug("Correlation ID not found, generated fallback: %s", request_id)

        request.state.request_id = request_id

        global_context = GlobalContext(
            request_id=request_id,
            server_id=None,
            tenant_id=None,
        )

        client_host = request.client.host if request.client else None
        client_port = request.client.port if request.client else None

        context_table = None

        # PRE-REQUEST HOOK: Allow plugins to transform headers before authentication
        if has_pre:
            original_headers = dict(request.headers)
            merged_headers, global_context, context_table = await run_pre_request_hooks(
                plugin_manager=self.plugin_manager,
                headers=original_headers,
                path=str(request.url.path),
                method=request.method,
                client_host=client_host,
                client_port=client_port,
                global_context=global_context,
            )

            if context_table:
                request.state.plugin_context_table = context_table
            if global_context:
                request.state.plugin_global_context = global_context

            self._apply_request_headers(request, original_headers, merged_headers)

        # Process the request through the rest of the application
        response = await call_next(request)

        # POST-REQUEST HOOK: Allow plugins to inspect and modify response
        if has_post:
            await self._run_post_request_hooks(request, response, global_context, context_table, client_host, client_port)

        return response

    @staticmethod
    def _encode_header_text(text: str) -> bytes:
        """Encode a header name or value for the ASGI scope.

        Starlette decodes raw header bytes as latin-1, so latin-1 round-trips
        the original bytes exactly. Plugin-supplied text outside latin-1 falls
        back to UTF-8 (the encoding previously used) rather than failing the
        request.

        Args:
            text: Header name or value.

        Returns:
            The encoded bytes.
        """
        try:
            return text.encode("latin-1")
        except UnicodeEncodeError:
            return text.encode("utf-8")

    @classmethod
    def _apply_request_headers(cls, request: Request, original_headers: dict[str, str], merged_headers: dict[str, str]) -> None:
        """Write plugin header changes back into ``request.scope["headers"]``.

        Raw scope entries for headers the plugins did not change are kept
        byte-for-byte, including repeated headers that ``dict(request.headers)``
        collapses to their first value. A changed header replaces all of its
        occurrences with a single entry at the position of the first one. New
        headers are appended. When nothing changed, the scope is left untouched.

        Args:
            request: The incoming request whose scope is updated in place.
            original_headers: Headers as passed to the pre-request hooks.
            merged_headers: Headers returned by ``run_pre_request_hooks``.
        """
        original = {name.lower(): value for name, value in original_headers.items()}
        merged = {name.lower(): value for name, value in merged_headers.items()}
        changed = {name: value for name, value in merged.items() if original.get(name) != value}
        removed = original.keys() - merged.keys()
        if not changed and not removed:
            return

        rewritten: list[tuple[bytes, bytes]] = []
        replaced: set[str] = set()
        for raw_name, raw_value in request.scope.get("headers", []):
            name = raw_name.decode("latin-1").lower()
            if name in removed:
                continue
            if name in changed:
                if name not in replaced:
                    rewritten.append((cls._encode_header_text(name), cls._encode_header_text(changed[name])))
                    replaced.add(name)
                continue
            rewritten.append((raw_name, raw_value))

        # Headers introduced by plugins go after the existing ones.
        rewritten.extend((cls._encode_header_text(name), cls._encode_header_text(value)) for name, value in changed.items() if name not in replaced)
        request.scope["headers"] = rewritten

    async def _run_post_request_hooks(
        self,
        request: Request,
        response,
        global_context: Optional[GlobalContext],
        context_table: Optional[dict],
        client_host: Optional[str],
        client_port: Optional[int],
    ) -> None:
        """Invoke HTTP_POST_REQUEST hooks and apply changed response headers.

        Best effort: any exception is logged and the response is returned as-is.

        Args:
            request: The incoming request.
            response: The response returned by ``call_next``. Its headers may be updated in place.
            global_context: Plugin global context for this request.
            context_table: Local contexts produced by the pre-request hooks, if any.
            client_host: Client IP address.
            client_port: Client port.
        """
        try:
            original_response_headers = dict(response.headers)

            post_result, _ = await self.plugin_manager.invoke_hook(
                HttpHookType.HTTP_POST_REQUEST,
                payload=HttpPostRequestPayload(
                    path=str(request.url.path),
                    method=request.method,
                    # NOTE(review): request.headers was read before the scope rewrite
                    # in dispatch and may be cached on the Request object. Confirm
                    # whether post hooks should see the original or transformed headers.
                    headers=HttpHeaderPayload(root=dict(request.headers)),
                    client_host=client_host,
                    client_port=client_port,
                    response_headers=HttpHeaderPayload(root=dict(original_response_headers)),
                    status_code=response.status_code,
                ),
                global_context=global_context,
                local_contexts=context_table,
                violations_as_exceptions=False,
            )

            if post_result.modified_payload:
                modified_response_headers = post_result.modified_payload.root
                applied = []
                for header_name, header_value in modified_response_headers.items():
                    # Assigning through response.headers replaces every existing
                    # entry for that name (e.g. all Set-Cookie values). Skip
                    # values the hook merely echoed back unchanged.
                    if original_response_headers.get(header_name.lower()) == header_value:
                        continue
                    response.headers[header_name] = header_value
                    applied.append(header_name)
                logger.debug("Post-request hook modified response headers: %s", applied)

        except Exception as e:
            logger.warning("HTTP_POST_REQUEST hook failed: %s", e, exc_info=True)