Coverage for mcpgateway / middleware / http_auth_middleware.py: 100%
79 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""HTTP Authentication Middleware.
4This middleware allows plugins to:
51. Transform request headers before authentication (HTTP_PRE_REQUEST)
62. Inspect responses after request completion (HTTP_POST_REQUEST)
7"""
9# Standard
10import logging
11from typing import Optional
13# Third-Party
14from fastapi import Request
15from starlette.middleware.base import BaseHTTPMiddleware
16from starlette.types import ASGIApp
18# First-Party
19from mcpgateway.config import settings
20from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, HttpHookType, HttpPostRequestPayload, HttpPreRequestPayload, PluginManager
21from mcpgateway.utils.correlation_id import generate_correlation_id, get_correlation_id
23logger = logging.getLogger(__name__)
async def run_pre_request_hooks(
    plugin_manager: PluginManager,
    headers: dict[str, str],
    path: str,
    method: str,
    client_host: Optional[str] = None,
    client_port: Optional[int] = None,
    global_context: Optional[GlobalContext] = None,
) -> tuple[dict[str, str], Optional[GlobalContext], Optional[dict]]:
    """Run HTTP_PRE_REQUEST plugin hooks and return (possibly modified) headers.

    This is the shared hook runner used by both HttpAuthMiddleware (Python flow)
    and _run_internal_mcp_authentication (Rust flow) to ensure identical
    plugin behavior regardless of transport.

    Args:
        plugin_manager: The plugin manager instance.
        headers: Original request headers (not mutated).
        path: Request path.
        method: HTTP method.
        client_host: Client IP address.
        client_port: Client port.
        global_context: Optional pre-created global context. Created if not provided.

    Returns:
        Tuple of (merged_headers, global_context, context_table).
        merged_headers reflects any plugin modifications with the auth-header
        override guard applied. When no hook is registered, no hook modifies
        the payload, or the hook invocation raises, the ``headers`` argument
        is returned as-is (keys are NOT lowercased in that case). When the
        invocation raises, context_table is None.
    """
    if not plugin_manager.has_hooks_for(HttpHookType.HTTP_PRE_REQUEST):
        return headers, global_context, None

    if global_context is None:
        request_id = get_correlation_id() or generate_correlation_id()
        global_context = GlobalContext(request_id=request_id, server_id=None, tenant_id=None)

    try:
        pre_result, context_table = await plugin_manager.invoke_hook(
            HttpHookType.HTTP_PRE_REQUEST,
            payload=HttpPreRequestPayload(
                path=path,
                method=method,
                # Hand the plugins a copy so the caller's dict is never mutated.
                headers=HttpHeaderPayload(root=dict(headers)),
                client_host=client_host,
                client_port=client_port,
            ),
            global_context=global_context,
            local_contexts=None,
            violations_as_exceptions=False,
        )

        if not pre_result.modified_payload:
            return headers, global_context, context_table

        modified_headers_dict = pre_result.modified_payload.root

        # Security: prevent plugin hooks from overriding auth-sensitive
        # headers that were already present on the inbound request.
        # Plugins MAY create new auth headers (e.g. x-api-key → authorization
        # transform) but MUST NOT replace values the client already sent.
        #
        # This guard can be disabled with PLUGINS_CAN_OVERRIDE_AUTH_HEADERS=true
        # for deployments that require plugin-driven token exchange (e.g. WXO auth).
        if not settings.plugins_can_override_auth_headers:
            _auth_protected_headers = {"authorization", "cookie", "x-api-key", "proxy-authorization"}
            original_lower = {name.lower(): value for name, value in headers.items()}
            # Only a *changed* value counts as an override attempt. A plugin
            # that echoes the inbound headers back unchanged (e.g. returns the
            # full header dict with one extra header added) must not trigger
            # the warning; the merged result is identical either way.
            overridden = {
                name.lower()
                for name, value in modified_headers_dict.items()
                if name.lower() in _auth_protected_headers and name.lower() in original_lower and value != original_lower[name.lower()]
            }
            if overridden:
                logger.warning("Pre-request hook attempted to override existing auth headers (stripped): %s", overridden)
                modified_headers_dict = {k: v for k, v in modified_headers_dict.items() if k.lower() not in overridden}

        # Normalize to lowercase keys to avoid duplicate logical headers from
        # casing differences (e.g. "Authorization" vs "authorization").
        merged_headers = {k.lower(): v for k, v in headers.items()}
        merged_headers.update({k.lower(): v for k, v in modified_headers_dict.items()})
        logger.debug("Pre-request hook modified headers: %s", list(modified_headers_dict.keys()))
        return merged_headers, global_context, context_table

    except Exception as e:
        # Best-effort: a failing plugin hook must not fail the request, so fall
        # back to the untouched inbound headers.
        logger.warning("HTTP_PRE_REQUEST hook failed: %s", e, exc_info=True)
        return headers, global_context, None
class HttpAuthMiddleware(BaseHTTPMiddleware):
    """Middleware for HTTP authentication hooks.

    This middleware invokes plugin hooks for HTTP request processing:
    - HTTP_PRE_REQUEST: Before any authentication, allows header transformation
    - HTTP_POST_REQUEST: After request completion, allows response inspection

    The middleware allows plugins to:
    - Convert custom authentication tokens to standard formats
    - Add tracing/correlation headers
    - Implement custom authentication schemes
    - Audit authentication attempts
    - Log response status and headers
    """

    def __init__(self, app: ASGIApp, plugin_manager: PluginManager | None = None):
        """Initialize the HTTP auth middleware.

        Args:
            app: The ASGI application
            plugin_manager: Optional plugin manager for hook invocation
        """
        super().__init__(app)
        self.plugin_manager = plugin_manager

    @staticmethod
    def _encode_header_text(text: str) -> bytes:
        """Encode a header name or value back into raw ASGI header bytes.

        Starlette exposes header bytes as ``str`` by decoding them as latin-1,
        so latin-1 is the encoding that round-trips inbound headers exactly.
        Text that cannot be represented in latin-1 (only possible for
        plugin-supplied values) falls back to UTF-8.

        Args:
            text: Header name or value as text.

        Returns:
            The encoded bytes.
        """
        try:
            return text.encode("latin-1")
        except UnicodeEncodeError:
            return text.encode("utf-8")

    async def dispatch(self, request: Request, call_next):
        """Process request through plugin hooks.

        Args:
            request: The incoming request
            call_next: The next middleware/handler in the chain

        Returns:
            The response from the application
        """
        # Skip hook invocation if no plugin manager
        if not self.plugin_manager:
            logger.debug("HttpAuthMiddleware: no plugin_manager, skipping hooks")
            return await call_next(request)

        # Skip payload creation if no HTTP hooks registered
        has_pre = self.plugin_manager.has_hooks_for(HttpHookType.HTTP_PRE_REQUEST)
        has_post = self.plugin_manager.has_hooks_for(HttpHookType.HTTP_POST_REQUEST)

        if not has_pre and not has_post:
            logger.debug("HttpAuthMiddleware: has_pre=%s has_post=%s, skipping hooks", has_pre, has_post)
            return await call_next(request)

        # Use correlation ID from CorrelationIDMiddleware if available
        request_id = get_correlation_id()
        if not request_id:
            request_id = generate_correlation_id()
            logger.debug("Correlation ID not found, generated fallback: %s", request_id)

        request.state.request_id = request_id

        global_context = GlobalContext(
            request_id=request_id,
            server_id=None,
            tenant_id=None,
        )

        client_host = None
        client_port = None
        if request.client:
            client_host = request.client.host
            client_port = request.client.port

        context_table = None

        # PRE-REQUEST HOOK: Allow plugins to transform headers before authentication
        if has_pre:
            merged_headers, global_context, context_table = await run_pre_request_hooks(
                plugin_manager=self.plugin_manager,
                headers=dict(request.headers),
                path=str(request.url.path),
                method=request.method,
                client_host=client_host,
                client_port=client_port,
                global_context=global_context,
            )

            # Expose plugin state on the request for later consumers.
            if context_table:
                request.state.plugin_context_table = context_table
            if global_context:
                request.state.plugin_global_context = global_context

            # Apply modified headers to the request scope. Encode with latin-1
            # (UTF-8 fallback) rather than the default UTF-8 so inbound header
            # bytes >= 0x80 survive the str round-trip unchanged.
            encode = self._encode_header_text
            request.scope["headers"] = [(encode(name.lower()), encode(value)) for name, value in merged_headers.items()]

        # Process the request through the rest of the application
        response = await call_next(request)

        # POST-REQUEST HOOK: Allow plugins to inspect and modify response
        if has_post:
            try:
                response_headers = HttpHeaderPayload(root=dict(response.headers))

                post_result, _ = await self.plugin_manager.invoke_hook(
                    HttpHookType.HTTP_POST_REQUEST,
                    payload=HttpPostRequestPayload(
                        path=str(request.url.path),
                        method=request.method,
                        headers=HttpHeaderPayload(root=dict(request.headers)),
                        client_host=client_host,
                        client_port=client_port,
                        response_headers=response_headers,
                        status_code=response.status_code,
                    ),
                    global_context=global_context,
                    local_contexts=context_table,
                    violations_as_exceptions=False,
                )

                if post_result.modified_payload:
                    modified_response_headers = post_result.modified_payload.root
                    for header_name, header_value in modified_response_headers.items():
                        response.headers[header_name] = header_value
                    logger.debug("Post-request hook modified response headers: %s", list(modified_response_headers.keys()))

            except Exception as e:
                # Best-effort: a failing post-request hook must not replace the
                # application's response.
                logger.warning("HTTP_POST_REQUEST hook failed: %s", e, exc_info=True)

        return response