Coverage for mcpgateway / middleware / http_auth_middleware.py: 98%
58 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""HTTP Authentication Middleware.
4This middleware allows plugins to:
51. Transform request headers before authentication (HTTP_PRE_REQUEST)
62. Inspect responses after request completion (HTTP_POST_REQUEST)
7"""
9# Standard
10import logging
12# Third-Party
13from fastapi import Request
14from starlette.middleware.base import BaseHTTPMiddleware
15from starlette.types import ASGIApp
17# First-Party
18from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, HttpHookType, HttpPostRequestPayload, HttpPreRequestPayload, PluginManager
19from mcpgateway.utils.correlation_id import generate_correlation_id, get_correlation_id
21logger = logging.getLogger(__name__)
class HttpAuthMiddleware(BaseHTTPMiddleware):
    """Middleware for HTTP authentication hooks.

    This middleware invokes plugin hooks for HTTP request processing:
    - HTTP_PRE_REQUEST: Before any authentication, allows header transformation
    - HTTP_POST_REQUEST: After request completion, allows response inspection

    The middleware allows plugins to:
    - Convert custom authentication tokens to standard formats
    - Add tracing/correlation headers
    - Implement custom authentication schemes
    - Audit authentication attempts
    - Log response status and headers
    """

    def __init__(self, app: ASGIApp, plugin_manager: PluginManager | None = None):
        """Initialize the HTTP auth middleware.

        Args:
            app: The ASGI application
            plugin_manager: Optional plugin manager for hook invocation
        """
        super().__init__(app)
        self.plugin_manager = plugin_manager

    @staticmethod
    def _encode_header_text(text: str) -> bytes:
        """Encode a header name or value to the bytes stored in the ASGI scope.

        Args:
            text: Header name or value as a string

        Returns:
            The latin-1 encoding of ``text`` (the encoding Starlette uses when it
            decodes headers), or the UTF-8 encoding when ``text`` contains
            characters outside latin-1.
        """
        try:
            return text.encode("latin-1")
        except UnicodeEncodeError:
            # Keep the previous (UTF-8) behaviour rather than dropping the header.
            return text.encode("utf-8")

    @classmethod
    def _merge_request_headers(cls, raw_headers, overrides):
        """Build a new raw header list with plugin overrides applied.

        Header names are compared case-insensitively, so an override such as
        ``"Authorization"`` replaces an incoming ``authorization`` header instead
        of being appended after it (where first-match lookups would ignore it).
        Headers that are not overridden keep their original bytes, including
        repeated fields.

        Args:
            raw_headers: Iterable of ``(name, value)`` byte pairs from the ASGI scope
            overrides: Mapping of header name to value returned by the plugin

        Returns:
            A list of ``(name, value)`` byte tuples with lowercase names
        """
        # Normalize override names once; a later key wins if two differ only by case.
        normalized = {name.lower(): value for name, value in overrides.items()}
        encoded_overrides = [(cls._encode_header_text(name), cls._encode_header_text(value)) for name, value in normalized.items()]
        replaced_names = {name for name, _ in encoded_overrides}

        merged = [(name, value) for name, value in raw_headers if name.lower() not in replaced_names]
        merged.extend(encoded_overrides)
        return merged

    async def dispatch(self, request: Request, call_next):
        """Process request through plugin hooks.

        Hook failures are logged and never propagate: the request is always
        forwarded to ``call_next`` and its response is always returned.

        Args:
            request: The incoming request
            call_next: The next middleware/handler in the chain

        Returns:
            The response from the application
        """
        # Skip hook invocation if no plugin manager
        if not self.plugin_manager:
            return await call_next(request)

        # Skip payload creation if no HTTP hooks registered
        has_pre = self.plugin_manager.has_hooks_for(HttpHookType.HTTP_PRE_REQUEST)
        has_post = self.plugin_manager.has_hooks_for(HttpHookType.HTTP_POST_REQUEST)

        if not has_pre and not has_post:
            return await call_next(request)

        # Use correlation ID from CorrelationIDMiddleware if available
        # This ensures all hooks and downstream code see the same unified request ID
        request_id = get_correlation_id()
        if not request_id:
            # Fallback if correlation ID middleware is disabled
            request_id = generate_correlation_id()
            logger.debug(f"Correlation ID not found, generated fallback: {request_id}")

        request.state.request_id = request_id

        # Create global context for hooks
        global_context = GlobalContext(
            request_id=request_id,
            server_id=None,  # Not specific to any server
            tenant_id=None,  # Not specific to any tenant
        )

        # Extract client information
        client_host = None
        client_port = None
        if request.client:
            client_host = request.client.host
            client_port = request.client.port

        # Initialize context_table for potential use by POST hook
        context_table = None

        # PRE-REQUEST HOOK: Allow plugins to transform headers before authentication
        # Only create payload and invoke hook if plugins are registered for this hook type
        if has_pre:
            try:
                pre_result, context_table = await self.plugin_manager.invoke_hook(
                    HttpHookType.HTTP_PRE_REQUEST,
                    payload=HttpPreRequestPayload(
                        path=str(request.url.path),
                        method=request.method,
                        headers=HttpHeaderPayload(root=dict(request.headers)),
                        client_host=client_host,
                        client_port=client_port,
                    ),
                    global_context=global_context,
                    local_contexts=None,
                    violations_as_exceptions=False,  # Don't block on pre-request violations
                )

                if context_table:
                    request.state.plugin_context_table = context_table

                if global_context:
                    request.state.plugin_global_context = global_context

                # Apply modified headers if plugin returned them
                if pre_result.modified_payload:
                    # Modify request headers by updating request.scope["headers"]
                    # This is the proper way to modify headers in Starlette/FastAPI
                    # Reference: https://stackoverflow.com/questions/69934160/python-how-to-manipulate-fastapi-request-headers-to-be-mutable
                    modified_headers_dict = pre_result.modified_payload.root

                    # Merge on the raw header list (modified headers take precedence).
                    # Merging two dicts here would be case-sensitive and would also
                    # collapse repeated header fields to a single value.
                    request.scope["headers"] = self._merge_request_headers(request.scope["headers"], modified_headers_dict)

                    logger.debug(f"Pre-request hook modified headers: {list(modified_headers_dict.keys())}")

            except Exception as e:
                # Log but don't fail the request if pre-hook has issues
                logger.warning(f"HTTP_PRE_REQUEST hook failed: {e}", exc_info=True)

        # Process the request through the rest of the application
        response = await call_next(request)

        # POST-REQUEST HOOK: Allow plugins to inspect and modify response
        # Only create payload and invoke hook if plugins are registered for this hook type
        if has_post:
            try:
                # Extract response headers
                response_headers = HttpHeaderPayload(root=dict(response.headers))

                # NOTE(review): request.headers was already read for the pre-hook
                # payload; if Starlette caches that view, the headers sent to the
                # post hook are the pre-modification ones - confirm this is intended.
                post_result, _ = await self.plugin_manager.invoke_hook(
                    HttpHookType.HTTP_POST_REQUEST,
                    payload=HttpPostRequestPayload(
                        path=str(request.url.path),
                        method=request.method,
                        headers=HttpHeaderPayload(root=dict(request.headers)),
                        client_host=client_host,
                        client_port=client_port,
                        response_headers=response_headers,
                        status_code=response.status_code,
                    ),
                    global_context=global_context,
                    local_contexts=context_table,  # Pass context from pre-hook
                    violations_as_exceptions=False,  # Don't block on post-request violations
                )

                # Apply modified response headers if plugin returned them
                if post_result.modified_payload:
                    modified_response_headers = post_result.modified_payload.root
                    # Update response headers (response.headers is mutable)
                    for header_name, header_value in modified_response_headers.items():
                        response.headers[header_name] = header_value
                    logger.debug(f"Post-request hook modified response headers: {list(modified_response_headers.keys())}")

            except Exception as e:
                # Log but don't fail the response if post-hook has issues
                logger.warning(f"HTTP_POST_REQUEST hook failed: {e}", exc_info=True)

        return response