Coverage for mcpgateway / middleware / observability_middleware.py: 100%
121 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/middleware/observability_middleware.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Observability Middleware for automatic request/response tracing.
9This middleware automatically captures HTTP requests and responses as observability traces,
10providing comprehensive visibility into all gateway operations.
12Examples:
13 >>> from mcpgateway.middleware.observability_middleware import ObservabilityMiddleware # doctest: +SKIP
14 >>> app.add_middleware(ObservabilityMiddleware) # doctest: +SKIP
15"""
17# Standard
18import logging
19import time
20import traceback
21from typing import Callable, Optional
23# Third-Party
24from starlette.middleware.base import BaseHTTPMiddleware
25from starlette.requests import Request
26from starlette.responses import Response
28# First-Party
29from mcpgateway.config import settings
30from mcpgateway.db import SessionLocal
31from mcpgateway.instrumentation.sqlalchemy import attach_trace_to_session
32from mcpgateway.middleware.path_filter import should_skip_observability
33from mcpgateway.plugins.framework.observability import current_trace_id as plugins_trace_id
34from mcpgateway.services.observability_service import current_trace_id, ObservabilityService, parse_traceparent
35from mcpgateway.utils.log_sanitizer import sanitize_for_log
36from mcpgateway.utils.trace_redaction import sanitize_trace_text
# Module-level logger named after this module, so log records emitted by the
# middleware can be filtered/configured by module path.
logger = logging.getLogger(__name__)
class ObservabilityMiddleware(BaseHTTPMiddleware):
    """Middleware for automatic HTTP request/response tracing.

    Captures every HTTP request as a trace with timing, status codes,
    and user context. Automatically creates spans for the request lifecycle.

    This middleware is disabled by default and can be enabled via the
    MCPGATEWAY_OBSERVABILITY_ENABLED environment variable.

    Tracing is strictly best-effort: a failure while creating, updating or
    closing a trace is logged and never prevents the request from being
    served. Only exceptions raised by the downstream application propagate.
    """

    def __init__(self, app, enabled: Optional[bool] = None, service: Optional[ObservabilityService] = None):
        """Initialize the observability middleware.

        Args:
            app: ASGI application
            enabled: Whether observability is enabled. ``None`` (the default)
                falls back to ``settings.observability_enabled``, or ``False``
                when that setting does not exist.
            service: Optional ObservabilityService instance; a new one is
                created when omitted.
        """
        super().__init__(app)
        self.enabled = enabled if enabled is not None else getattr(settings, "observability_enabled", False)
        self.service = service or ObservabilityService()
        logger.info(f"Observability middleware initialized (enabled={self.enabled})")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process request and create observability trace.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            HTTP response

        Raises:
            Exception: Re-raises any exception from request processing after logging
        """
        # Skip when observability is disabled, and skip paths excluded by the
        # path filter (health checks, static files) to reduce noise.
        if not self.enabled or should_skip_observability(request.url.path):
            return await call_next(request)

        # Extract request context
        http_method = request.method
        http_url = str(request.url)
        user_email = None
        ip_address = request.client.host if request.client else None
        user_agent = request.headers.get("user-agent")

        # Try to extract user from request state (set by auth middleware)
        if hasattr(request.state, "user") and hasattr(request.state.user, "email"):
            user_email = request.state.user.email

        # Extract W3C Trace Context from headers (for distributed tracing)
        external_trace_id = None
        external_parent_span_id = None
        traceparent_header = request.headers.get("traceparent")
        if traceparent_header:
            parsed = parse_traceparent(traceparent_header)
            if parsed:
                external_trace_id, external_parent_span_id, _flags = parsed
                logger.debug(f"Extracted W3C trace context: trace_id={external_trace_id}, parent_span_id={external_parent_span_id}")

        db = None
        trace_id = None
        span_id = None
        start_time = time.time()
        tracing_ready = False

        try:
            # Create request-scoped database session and store in request.state
            # This session will be reused by route handlers via get_db() dependency,
            # eliminating duplicate session creation (Issue #3467)
            db = SessionLocal()
            logger.debug(f"[OBSERVABILITY] DB session created: {id(db)}")
            request.state.db = db

            # Start trace (use external trace_id if provided for distributed tracing)
            trace_id = self.service.start_trace(
                db=db,
                name=f"{http_method} {request.url.path}",
                trace_id=external_trace_id,  # Use external trace ID if provided
                parent_span_id=external_parent_span_id,  # Track parent span from upstream
                http_method=http_method,
                http_url=http_url,
                user_email=user_email,
                user_agent=user_agent,
                ip_address=ip_address,
                attributes={
                    "http.route": request.url.path,
                    "http.query": sanitize_trace_text(str(request.url.query)) if request.url.query else None,
                },
                resource_attributes={
                    "service.name": "mcp-gateway",
                    "service.version": getattr(settings, "version", "unknown"),
                },
            )

            # Store trace_id in request state for use in route handlers
            request.state.trace_id = trace_id

            # Set trace_id in context variable for access throughout async call stack
            current_trace_id.set(trace_id)
            # Bridge: also set the framework's ContextVar so the plugin executor sees it
            plugins_trace_id.set(trace_id)

            # Attach trace_id to database session for SQL query instrumentation
            attach_trace_to_session(db, trace_id)

            # Start request span
            span_id = self.service.start_span(db=db, trace_id=trace_id, name="http.request", kind="server", attributes={"http.method": http_method, "http.url": http_url})
            tracing_ready = True

        except Exception as e:
            # If trace setup failed, log and continue without tracing
            logger.warning(f"Failed to setup observability trace: {e}")
            if db is not None:
                self._discard_setup_session(request, db)

        if not tracing_ready:
            # Continue without tracing. This call is deliberately made AFTER the
            # except block has exited: awaiting the downstream app inside the
            # handler would implicitly chain any application exception to the
            # unrelated setup error (``__context__``), producing misleading
            # "During handling of the above exception..." tracebacks.
            return await call_next(request)

        # Process request (trace is set up at this point)
        # Route handlers will reuse request.state.db via get_db() dependency
        try:
            response = await call_next(request)
            self._record_response(db, trace_id, span_id, response, start_time)

            # NOTE: Transaction control delegated to get_db()
            # Middleware only manages session lifecycle (create/close), not transactions.
            # get_db() will commit on success or rollback on error to maintain
            # predictable transaction semantics for route handlers (Issue #3731).
            return response

        except Exception as e:
            self._record_exception(db, trace_id, span_id, e)
            self._rollback_or_invalidate(db)
            # Re-raise the original exception
            raise

        finally:
            # Always close database session and clean up request state
            if db is not None:
                try:
                    db.close()
                except Exception as close_error:
                    logger.warning(f"Failed to close database session: {close_error}")
                # Clean up request.state.db to prevent stale references
                if hasattr(request.state, "db"):
                    delattr(request.state, "db")

    @staticmethod
    def _discard_setup_session(request: Request, db) -> None:
        """Dispose of a session whose trace setup failed and detach it from the request.

        Every step is best-effort; failures are logged at debug level only.

        Args:
            request: Request whose ``state.db`` attribute may reference ``db``
            db: Database session created during the failed setup
        """
        try:
            db.rollback()  # Error path - rollback any partial transaction
        except Exception as rollback_error:
            logger.debug(f"Failed to rollback during cleanup: {rollback_error}")
            # Connection is broken - invalidate to remove from pool
            try:
                db.invalidate()
            except Exception:
                pass  # nosec B110 - deliberate best-effort cleanup
        try:
            db.close()
        except Exception as close_error:
            logger.debug(f"Failed to close database session during cleanup: {close_error}")
        # Clean up request.state.db to prevent get_db() from reusing a closed session
        if hasattr(request.state, "db"):
            delattr(request.state, "db")

    def _record_response(self, db, trace_id, span_id, response: Response, start_time: float) -> None:
        """Close the request span and trace for a response returned by the application.

        Failures of the observability service are logged and swallowed so the
        response is still delivered.

        Args:
            db: Request-scoped database session
            trace_id: Identifier returned by ``start_trace``
            span_id: Identifier returned by ``start_span``
            response: Response produced by the downstream application
            start_time: ``time.time()`` value captured before trace setup
        """
        status_code = response.status_code

        # End span successfully
        if span_id:
            try:
                self.service.end_span(
                    db,
                    span_id,
                    status="ok" if status_code < 400 else "error",
                    attributes={"http.status_code": status_code, "http.response_size": response.headers.get("content-length")},
                )
            except Exception as end_span_error:
                logger.warning(f"Failed to end span {span_id}: {end_span_error}")

        # End trace
        if trace_id:
            duration_ms = (time.time() - start_time) * 1000
            try:
                self.service.end_trace(
                    db,
                    trace_id,
                    status="ok" if status_code < 400 else "error",
                    http_status_code=status_code,
                    attributes={"response_time_ms": duration_ms},
                )
            except Exception as end_trace_error:
                logger.warning(f"Failed to end trace {trace_id}: {end_trace_error}")

    def _record_exception(self, db, trace_id, span_id, exc: Exception) -> None:
        """Mark the span and trace as failed for an exception raised by the application.

        Must be called while ``exc`` is being handled (from an ``except``
        block), because the stack trace is taken from ``traceback.format_exc()``.
        Failures of the observability service are logged and swallowed.

        Args:
            db: Request-scoped database session
            trace_id: Identifier returned by ``start_trace``
            span_id: Identifier returned by ``start_span``
            exc: Exception raised while processing the request
        """
        # Log exception in span
        if span_id:
            try:
                sanitized_error = sanitize_for_log(sanitize_trace_text(str(exc)))
                self.service.end_span(db, span_id, status="error", status_message=sanitized_error, attributes={"exception.type": type(exc).__name__, "exception.message": sanitized_error})

                # Add exception event
                # NOTE(review): the message fields are sanitized but the stack
                # trace is stored as produced - confirm this is intended.
                self.service.add_event(
                    db,
                    span_id,
                    name="exception",
                    severity="error",
                    message=sanitized_error,
                    exception_type=type(exc).__name__,
                    exception_message=sanitized_error,
                    exception_stacktrace=traceback.format_exc(),
                )
            except Exception as log_error:
                logger.warning(f"Failed to log exception in span: {log_error}")

        # End trace with error
        if trace_id:
            try:
                # Sanitized again inside this try so that a sanitizer failure is
                # logged here instead of masking the application exception.
                sanitized_error = sanitize_for_log(sanitize_trace_text(str(exc)))
                self.service.end_trace(db, trace_id, status="error", status_message=sanitized_error, http_status_code=500)
            except Exception as trace_error:
                logger.warning(f"Failed to end trace: {trace_error}")

    @staticmethod
    def _rollback_or_invalidate(db) -> None:
        """Roll back the shared session after a request error, invalidating it if rollback fails.

        Args:
            db: Request-scoped database session
        """
        try:
            db.rollback()
        except Exception as rollback_error:
            logger.warning(f"Failed to rollback database session: {rollback_error}")
            # Connection is broken - invalidate to remove from pool
            # This handles cases like PgBouncer query_wait_timeout where
            # the connection is dead and rollback itself fails
            try:
                db.invalidate()
            except Exception:
                pass  # nosec B110 - deliberate best-effort cleanup