Coverage for mcpgateway / middleware / http_auth_middleware.py: 98%

58 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""HTTP Authentication Middleware. 

3 

4This middleware allows plugins to: 

51. Transform request headers before authentication (HTTP_PRE_REQUEST) 

62. Inspect responses after request completion (HTTP_POST_REQUEST) 

7""" 

8 

9# Standard 

10import logging 

11 

12# Third-Party 

13from fastapi import Request 

14from starlette.middleware.base import BaseHTTPMiddleware 

15from starlette.types import ASGIApp 

16 

17# First-Party 

18from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, HttpHookType, HttpPostRequestPayload, HttpPreRequestPayload, PluginManager 

19from mcpgateway.utils.correlation_id import generate_correlation_id, get_correlation_id 

20 

logger = logging.getLogger(__name__)


def _apply_header_overrides(raw_headers, overrides):
    """Return a new raw ASGI header list with plugin header overrides applied.

    Args:
        raw_headers: Original ``scope["headers"]`` entries as ``(name, value)`` byte pairs.
        overrides: Mapping of header name to new value returned by a plugin. Names are
            matched case-insensitively against the original headers.

    Returns:
        A new list of ``(name, value)`` byte pairs in which:
        - every original occurrence of an overridden name is replaced by a single entry
          holding the override value, placed where that name first appeared;
        - override names not present originally are appended at the end;
        - all other headers (including repeated ones) are kept unchanged and in order.

    Raises:
        UnicodeEncodeError: If an override name or value is not latin-1 encodable.
    """
    # ASGI header names are lowercase bytes, and header bytes are latin-1 (the codec
    # Starlette uses for headers), so values round-trip byte-for-byte. Encoding with
    # UTF-8 would corrupt any non-ASCII value. Lowercasing the override names also
    # guarantees that e.g. "Authorization" replaces the existing "authorization"
    # entry instead of being added next to it.
    encoded_overrides = {name.lower().encode("latin-1"): value.encode("latin-1") for name, value in overrides.items()}

    merged = []
    replaced = set()
    for name, value in raw_headers:
        key = name.lower()
        if key not in encoded_overrides:
            merged.append((name, value))
        elif key not in replaced:
            merged.append((key, encoded_overrides[key]))
            replaced.add(key)
        # Otherwise this is a later duplicate of an overridden header: drop it so the
        # override is the only value downstream code sees.

    merged.extend((key, value) for key, value in encoded_overrides.items() if key not in replaced)
    return merged


class HttpAuthMiddleware(BaseHTTPMiddleware):
    """Middleware for HTTP authentication hooks.

    This middleware invokes plugin hooks for HTTP request processing:
    - HTTP_PRE_REQUEST: Before any authentication, allows header transformation
    - HTTP_POST_REQUEST: After request completion, allows response inspection
      and modification of response headers

    The middleware allows plugins to:
    - Convert custom authentication tokens to standard formats
    - Add tracing/correlation headers
    - Implement custom authentication schemes
    - Audit authentication attempts
    - Log response status and headers

    Hook errors are logged and never fail the request, and hooks are invoked with
    ``violations_as_exceptions=False`` so plugin violations do not block it either.
    """

    def __init__(self, app: ASGIApp, plugin_manager: PluginManager | None = None):
        """Initialize the HTTP auth middleware.

        Args:
            app: The ASGI application
            plugin_manager: Optional plugin manager for hook invocation; when ``None``
                the middleware is a pass-through.
        """
        super().__init__(app)
        self.plugin_manager = plugin_manager

    async def dispatch(self, request: Request, call_next):
        """Process request through plugin hooks.

        Args:
            request: The incoming request
            call_next: The next middleware/handler in the chain

        Returns:
            The response from the application, with any response-header changes
            returned by HTTP_POST_REQUEST plugins applied.
        """
        # Skip hook invocation if no plugin manager
        if not self.plugin_manager:
            return await call_next(request)

        # Skip payload creation if no HTTP hooks registered
        has_pre = self.plugin_manager.has_hooks_for(HttpHookType.HTTP_PRE_REQUEST)
        has_post = self.plugin_manager.has_hooks_for(HttpHookType.HTTP_POST_REQUEST)
        if not has_pre and not has_post:
            return await call_next(request)

        # Use correlation ID from CorrelationIDMiddleware if available so all hooks and
        # downstream code see the same unified request ID.
        request_id = get_correlation_id()
        if not request_id:
            # Fallback if correlation ID middleware is disabled
            request_id = generate_correlation_id()
            logger.debug("Correlation ID not found, generated fallback: %s", request_id)

        request.state.request_id = request_id

        # Global context shared by both hooks; not tied to a server or tenant.
        global_context = GlobalContext(
            request_id=request_id,
            server_id=None,
            tenant_id=None,
        )

        client_host = request.client.host if request.client else None
        client_port = request.client.port if request.client else None

        # Local plugin contexts produced by the pre-hook, forwarded to the post-hook.
        context_table = None

        # PRE-REQUEST HOOK: allow plugins to transform headers before authentication.
        if has_pre:
            try:
                pre_result, context_table = await self.plugin_manager.invoke_hook(
                    HttpHookType.HTTP_PRE_REQUEST,
                    payload=HttpPreRequestPayload(
                        path=str(request.url.path),
                        method=request.method,
                        headers=HttpHeaderPayload(root=dict(request.headers)),
                        client_host=client_host,
                        client_port=client_port,
                    ),
                    global_context=global_context,
                    local_contexts=None,
                    violations_as_exceptions=False,  # Don't block on pre-request violations
                )

                if context_table:
                    request.state.plugin_context_table = context_table
                request.state.plugin_global_context = global_context

                if pre_result.modified_payload:
                    modified_headers = pre_result.modified_payload.root
                    # Rewrite the raw ASGI header list that downstream handlers read.
                    # The new list is built fully before assignment, so an encoding
                    # error leaves the original headers untouched.
                    request.scope["headers"] = _apply_header_overrides(request.scope.get("headers", []), modified_headers)
                    logger.debug("Pre-request hook modified headers: %s", list(modified_headers.keys()))

            except Exception as e:
                # Log but don't fail the request if pre-hook has issues
                logger.warning("HTTP_PRE_REQUEST hook failed: %s", e, exc_info=True)

        # Process the request through the rest of the application
        response = await call_next(request)

        # POST-REQUEST HOOK: allow plugins to inspect the response and modify its headers.
        if has_post:
            try:
                post_result, _ = await self.plugin_manager.invoke_hook(
                    HttpHookType.HTTP_POST_REQUEST,
                    payload=HttpPostRequestPayload(
                        path=str(request.url.path),
                        method=request.method,
                        # NOTE(review): Starlette caches request.headers on first access,
                        # so after a pre-hook rewrite this likely reports the ORIGINAL
                        # request headers - confirm that is the intended audit view.
                        headers=HttpHeaderPayload(root=dict(request.headers)),
                        client_host=client_host,
                        client_port=client_port,
                        response_headers=HttpHeaderPayload(root=dict(response.headers)),
                        status_code=response.status_code,
                    ),
                    global_context=global_context,
                    local_contexts=context_table,  # Pass context from pre-hook
                    violations_as_exceptions=False,  # Don't block on post-request violations
                )

                if post_result.modified_payload:
                    modified_response_headers = post_result.modified_payload.root
                    # response.headers is mutable; assignment replaces existing values.
                    for header_name, header_value in modified_response_headers.items():
                        response.headers[header_name] = header_value
                    logger.debug("Post-request hook modified response headers: %s", list(modified_response_headers.keys()))

            except Exception as e:
                # Log but don't fail the response if post-hook has issues
                logger.warning("HTTP_POST_REQUEST hook failed: %s", e, exc_info=True)

        return response