Coverage for mcpgateway / middleware / http_auth_middleware.py: 100%
63 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""HTTP Authentication Middleware.
4This middleware allows plugins to:
51. Transform request headers before authentication (HTTP_PRE_REQUEST)
62. Inspect responses after request completion (HTTP_POST_REQUEST)
7"""
9# Standard
10import logging
12# Third-Party
13from fastapi import Request
14from starlette.middleware.base import BaseHTTPMiddleware
15from starlette.types import ASGIApp
17# First-Party
18from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, HttpHookType, HttpPostRequestPayload, HttpPreRequestPayload, PluginManager
19from mcpgateway.utils.correlation_id import generate_correlation_id, get_correlation_id
# Module-scoped logger, named after this module's dotted import path.
logger = logging.getLogger(name=__name__)
class HttpAuthMiddleware(BaseHTTPMiddleware):
    """Middleware for HTTP authentication hooks.

    This middleware invokes plugin hooks for HTTP request processing:
    - HTTP_PRE_REQUEST: Before any authentication, allows header transformation
    - HTTP_POST_REQUEST: After request completion, allows response inspection
      and response-header modification

    The middleware allows plugins to:
    - Convert custom authentication tokens to standard formats
    - Add tracing/correlation headers
    - Implement custom authentication schemes
    - Audit authentication attempts
    - Log response status and headers

    Hooks are invoked with ``violations_as_exceptions=False``, and any
    exception raised while running a hook is logged and swallowed, so a
    misbehaving plugin never fails the request.
    """

    # Auth-sensitive request headers that a pre-request hook may *create* but
    # must never *replace* once the client has sent them. Stored lowercase
    # because all comparisons below are done on lowercased header names.
    _AUTH_PROTECTED_HEADERS = frozenset({"authorization", "cookie", "x-api-key", "proxy-authorization"})

    def __init__(self, app: ASGIApp, plugin_manager: PluginManager | None = None):
        """Initialize the HTTP auth middleware.

        Args:
            app: The ASGI application
            plugin_manager: Optional plugin manager for hook invocation
        """
        super().__init__(app)
        self.plugin_manager = plugin_manager

    async def dispatch(self, request: Request, call_next):
        """Process request through plugin hooks.

        Args:
            request: The incoming request
            call_next: The next middleware/handler in the chain

        Returns:
            The response from the application
        """
        # Skip hook invocation if no plugin manager
        if not self.plugin_manager:
            return await call_next(request)

        # Skip payload creation if no HTTP hooks registered
        has_pre = self.plugin_manager.has_hooks_for(HttpHookType.HTTP_PRE_REQUEST)
        has_post = self.plugin_manager.has_hooks_for(HttpHookType.HTTP_POST_REQUEST)
        if not has_pre and not has_post:
            return await call_next(request)

        # Use correlation ID from CorrelationIDMiddleware if available so all
        # hooks and downstream code see the same unified request ID.
        request_id = get_correlation_id()
        if not request_id:
            # Fallback if correlation ID middleware is disabled
            request_id = generate_correlation_id()
            logger.debug("Correlation ID not found, generated fallback: %s", request_id)

        request.state.request_id = request_id

        global_context = GlobalContext(
            request_id=request_id,
            server_id=None,  # Not specific to any server
            tenant_id=None,  # Not specific to any tenant
        )

        client_host = None
        client_port = None
        if request.client:
            client_host = request.client.host
            client_port = request.client.port

        # Context produced by the pre-hook (if any) is handed to the post-hook.
        context_table = None
        if has_pre:
            context_table = await self._run_pre_request_hook(request, global_context, client_host, client_port)

        response = await call_next(request)

        if has_post:
            await self._run_post_request_hook(request, response, global_context, context_table, client_host, client_port)

        return response

    async def _run_pre_request_hook(self, request: Request, global_context, client_host, client_port):
        """Invoke HTTP_PRE_REQUEST hooks and apply any header changes to the request.

        Args:
            request: The incoming request; its ``scope["headers"]`` may be rewritten.
            global_context: Shared hook context for this request.
            client_host: Client host, or None when unknown.
            client_port: Client port, or None when unknown.

        Returns:
            The context table returned by ``invoke_hook``, or None if the hook
            invocation itself failed.
        """
        context_table = None
        try:
            pre_result, context_table = await self.plugin_manager.invoke_hook(
                HttpHookType.HTTP_PRE_REQUEST,
                payload=HttpPreRequestPayload(
                    path=str(request.url.path),
                    method=request.method,
                    headers=HttpHeaderPayload(root=dict(request.headers)),
                    client_host=client_host,
                    client_port=client_port,
                ),
                global_context=global_context,
                local_contexts=None,
                violations_as_exceptions=False,  # Don't block on pre-request violations
            )

            if context_table:
                request.state.plugin_context_table = context_table

            if global_context:
                request.state.plugin_global_context = global_context

            # Apply modified headers if plugin returned them
            if pre_result.modified_payload:
                self._apply_modified_request_headers(request, pre_result.modified_payload.root)
        except Exception as e:
            # Log but don't fail the request if pre-hook has issues
            logger.warning("HTTP_PRE_REQUEST hook failed: %s", e, exc_info=True)
        return context_table

    @classmethod
    def _apply_modified_request_headers(cls, request: Request, modified_headers: dict[str, str]) -> None:
        """Write plugin-modified headers into ``request.scope["headers"]``.

        Rewriting the raw ASGI header list is what makes the change visible to
        downstream middleware/handlers, which build their own ``Request`` from
        the scope.

        Args:
            request: The request whose scope is updated in place.
            modified_headers: Header mapping returned by the pre-request hook.
        """
        merged, changed, stripped = cls._merge_request_headers(request.scope["headers"], modified_headers)
        if stripped:
            logger.warning("Pre-request hook attempted to override existing auth headers (stripped): %s", stripped)
        request.scope["headers"] = merged
        logger.debug("Pre-request hook modified headers: %s", list(changed))

    @classmethod
    def _merge_request_headers(cls, raw_headers, modified_headers: dict[str, str]) -> tuple[list[tuple[bytes, bytes]], dict[str, str], set[str]]:
        """Merge plugin-supplied headers into a raw ASGI header list.

        Security: plugins MAY create new auth headers (e.g. an
        x-api-key -> authorization transform) but MUST NOT replace
        auth-sensitive values the client already sent. The check is
        case-insensitive so ``"Authorization"`` cannot slip past it.

        Headers the plugin did not change keep their original raw bytes and
        repeated entries; a changed header replaces every existing entry of
        that name at the position of its first occurrence, and brand-new
        headers are appended.

        Args:
            raw_headers: Iterable of ``(name, value)`` byte pairs from the ASGI scope.
            modified_headers: Header mapping returned by the plugin.

        Returns:
            ``(merged, changed, stripped)``: the new raw header list, the
            lowercased headers actually applied, and the plugin-supplied names
            rejected because they would override a client-sent auth header.
        """
        raw_headers = list(raw_headers)

        # First value per lowercased name, decoded as latin-1 exactly like
        # Starlette's Headers, so comparisons match what request.headers shows.
        current: dict[str, str] = {}
        for raw_name, raw_value in raw_headers:
            current.setdefault(raw_name.decode("latin-1").lower(), raw_value.decode("latin-1"))

        changed: dict[str, str] = {}
        stripped: set[str] = set()
        for name, value in modified_headers.items():
            key = name.lower()
            if current.get(key) == value:
                # Echoed unchanged: skipping keeps repeated headers intact and
                # avoids a spurious security warning for untouched auth headers.
                continue
            if key in cls._AUTH_PROTECTED_HEADERS and key in current:
                stripped.add(name)
                continue
            changed[key] = value

        merged: list[tuple[bytes, bytes]] = []
        emitted: set[str] = set()
        for raw_name, raw_value in raw_headers:
            key = raw_name.decode("latin-1").lower()
            if key not in changed:
                # Untouched header: keep original bytes (no lossy re-encoding).
                merged.append((raw_name, raw_value))
            elif key not in emitted:
                merged.append((key.encode("latin-1"), changed[key].encode("latin-1")))
                emitted.add(key)
            # else: a further occurrence of a replaced header is dropped.
        merged.extend((key.encode("latin-1"), value.encode("latin-1")) for key, value in changed.items() if key not in emitted)
        return merged, changed, stripped

    async def _run_post_request_hook(self, request: Request, response, global_context, context_table, client_host, client_port) -> None:
        """Invoke HTTP_POST_REQUEST hooks and apply any response-header changes.

        Args:
            request: The original request.
            response: The response produced by the downstream application;
                its headers may be modified in place.
            global_context: Shared hook context for this request.
            context_table: Context returned by the pre-request hook, or None.
            client_host: Client host, or None when unknown.
            client_port: Client port, or None when unknown.
        """
        try:
            response_headers = HttpHeaderPayload(root=dict(response.headers))

            post_result, _ = await self.plugin_manager.invoke_hook(
                HttpHookType.HTTP_POST_REQUEST,
                payload=HttpPostRequestPayload(
                    path=str(request.url.path),
                    method=request.method,
                    headers=HttpHeaderPayload(root=dict(request.headers)),
                    client_host=client_host,
                    client_port=client_port,
                    response_headers=response_headers,
                    status_code=response.status_code,
                ),
                global_context=global_context,
                local_contexts=context_table,  # Pass context from pre-hook
                violations_as_exceptions=False,  # Don't block on post-request violations
            )

            # Apply modified response headers if plugin returned them
            if post_result.modified_payload:
                modified_response_headers = post_result.modified_payload.root
                applied = []
                for header_name, header_value in modified_response_headers.items():
                    # Assigning replaces every existing value of that header,
                    # so only assign real changes; echoing e.g. Set-Cookie back
                    # unchanged must not collapse multiple cookies into one.
                    if response.headers.get(header_name) != header_value:
                        response.headers[header_name] = header_value
                        applied.append(header_name)
                logger.debug("Post-request hook modified response headers: %s", applied)
        except Exception as e:
            # Log but don't fail the response if post-hook has issues
            logger.warning("HTTP_POST_REQUEST hook failed: %s", e, exc_info=True)