Coverage for mcpgateway / middleware / observability_middleware.py: 100%
96 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/middleware/observability_middleware.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Observability Middleware for automatic request/response tracing.
9This middleware automatically captures HTTP requests and responses as observability traces,
10providing comprehensive visibility into all gateway operations.
12Examples:
13 >>> from mcpgateway.middleware.observability_middleware import ObservabilityMiddleware # doctest: +SKIP
14 >>> app.add_middleware(ObservabilityMiddleware) # doctest: +SKIP
15"""
17# Standard
18import logging
19import time
20import traceback
21from typing import Callable, Optional
23# Third-Party
24from starlette.middleware.base import BaseHTTPMiddleware
25from starlette.requests import Request
26from starlette.responses import Response
28# First-Party
29from mcpgateway.config import settings
30from mcpgateway.db import SessionLocal
31from mcpgateway.instrumentation.sqlalchemy import attach_trace_to_session
32from mcpgateway.middleware.path_filter import should_skip_observability
33from mcpgateway.plugins.framework.observability import current_trace_id as plugins_trace_id
34from mcpgateway.services.observability_service import current_trace_id, ObservabilityService, parse_traceparent
# Module-level logger named after this module (``__name__``); all middleware diagnostics go through it.
logger = logging.getLogger(__name__)
class ObservabilityMiddleware(BaseHTTPMiddleware):
    """Middleware for automatic HTTP request/response tracing.

    Captures every HTTP request as a trace with timing, status codes,
    and user context. Automatically creates spans for the request lifecycle.

    This middleware is disabled by default and can be enabled via the
    MCPGATEWAY_OBSERVABILITY_ENABLED environment variable.
    """

    def __init__(self, app, enabled: Optional[bool] = None, service: Optional[ObservabilityService] = None):
        """Initialize the observability middleware.

        Args:
            app: ASGI application
            enabled: Whether observability is enabled. When ``None``, falls back to
                ``settings.observability_enabled`` (``False`` if that setting is absent).
            service: Optional ObservabilityService instance; a new one is created when omitted.
        """
        super().__init__(app)
        self.enabled = enabled if enabled is not None else getattr(settings, "observability_enabled", False)
        self.service = service or ObservabilityService()
        logger.info(f"Observability middleware initialized (enabled={self.enabled})")

    @staticmethod
    def _extract_user_email(request: Request) -> Optional[str]:
        """Return ``request.state.user.email`` if an upstream layer attached a user with an email, else ``None``.

        Args:
            request: Incoming HTTP request

        Returns:
            The user's email, or ``None`` when no user (or no ``email`` attribute) is present.
        """
        # Starlette's State raises AttributeError for unset keys, so getattr's default applies.
        return getattr(getattr(request.state, "user", None), "email", None)

    @staticmethod
    def _extract_trace_context(request: Request) -> tuple:
        """Extract the upstream W3C Trace Context (``traceparent`` header) for distributed tracing.

        Args:
            request: Incoming HTTP request

        Returns:
            A ``(trace_id, parent_span_id)`` pair. Both are ``None`` when the header
            is absent or ``parse_traceparent`` rejects it.
        """
        traceparent_header = request.headers.get("traceparent")
        if not traceparent_header:
            return None, None
        parsed = parse_traceparent(traceparent_header)
        if not parsed:
            return None, None
        trace_id, parent_span_id, _flags = parsed
        logger.debug(f"Extracted W3C trace context: trace_id={trace_id}, parent_span_id={parent_span_id}")
        return trace_id, parent_span_id

    @staticmethod
    def _close_session(db, force_rollback: bool, log: Callable[[str], None], prefix: str) -> None:
        """Roll back and close a database session without ever raising.

        ``close()`` is always attempted, even if the rollback fails, so the
        session is not leaked.

        Args:
            db: The session to dispose of; ``None`` is a no-op.
            force_rollback: Roll back unconditionally (error path). Otherwise roll back
                only when ``db.in_transaction()`` reports an open transaction.
            log: Logging callable used to report cleanup failures (e.g. ``logger.warning``).
            prefix: Message prefix for reported failures; the error text is appended.
        """
        if db is None:
            return
        try:
            if force_rollback or db.in_transaction():
                db.rollback()
        except Exception as rollback_error:
            log(f"{prefix}: {rollback_error}")
        finally:
            try:
                db.close()
            except Exception as close_error:
                log(f"{prefix}: {close_error}")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process request and create observability trace.

        The request is forwarded untouched when observability is disabled or
        ``should_skip_observability`` excludes the path. Otherwise:

        1. A trace is started on a dedicated session (reusing an upstream W3C trace id if one was sent).
        2. The trace id is published via ``request.state.trace_id`` and both trace-id context variables.
        3. An ``http.request`` server span is opened.
        4. Span and trace are closed once ``call_next`` returns or raises.

        If trace setup fails, the request is still served, just without tracing.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            HTTP response

        Raises:
            Exception: Re-raises any exception from request processing after logging
        """
        # Skip when disabled, and skip health checks / static files to reduce noise.
        if not self.enabled or should_skip_observability(request.url.path):
            return await call_next(request)

        http_method = request.method
        http_url = str(request.url)
        user_email = self._extract_user_email(request)
        ip_address = request.client.host if request.client else None
        user_agent = request.headers.get("user-agent")
        external_trace_id, external_parent_span_id = self._extract_trace_context(request)

        db = None
        trace_id = None
        span_id = None
        # Monotonic clock: the measured duration is immune to wall-clock adjustments.
        start_time = time.perf_counter()

        try:
            db = SessionLocal()

            # Use the upstream trace id (if any) so this request joins the distributed trace.
            trace_id = self.service.start_trace(
                db=db,
                name=f"{http_method} {request.url.path}",
                trace_id=external_trace_id,
                parent_span_id=external_parent_span_id,
                http_method=http_method,
                http_url=http_url,
                user_email=user_email,
                user_agent=user_agent,
                ip_address=ip_address,
                attributes={
                    "http.route": request.url.path,
                    "http.query": str(request.url.query) if request.url.query else None,
                },
                resource_attributes={
                    "service.name": "mcp-gateway",
                    "service.version": getattr(settings, "version", "unknown"),
                },
            )

            # Expose the trace id to route handlers and to the async call stack.
            request.state.trace_id = trace_id
            current_trace_id.set(trace_id)
            # Bridge: also set the plugin framework's ContextVar so the plugin executor sees it.
            plugins_trace_id.set(trace_id)

            # Tag SQL issued through this session for query instrumentation.
            attach_trace_to_session(db, trace_id)

            span_id = self.service.start_span(db=db, trace_id=trace_id, name="http.request", kind="server", attributes={"http.method": http_method, "http.url": http_url})
        except Exception as e:
            # Tracing is best-effort: never fail the request because tracing could not be set up.
            logger.warning(f"Failed to setup observability trace: {e}")
            self._close_session(db, force_rollback=True, log=logger.debug, prefix="Failed to close database session during cleanup")
            return await call_next(request)

        # Process request (trace is set up at this point)
        try:
            response = await call_next(request)
            status_code = response.status_code
            outcome = "ok" if status_code < 400 else "error"

            if span_id:
                try:
                    self.service.end_span(
                        db,
                        span_id,
                        status=outcome,
                        attributes={"http.status_code": status_code, "http.response_size": response.headers.get("content-length")},
                    )
                except Exception as end_span_error:
                    logger.warning(f"Failed to end span {span_id}: {end_span_error}")

            if trace_id:
                duration_ms = (time.perf_counter() - start_time) * 1000
                try:
                    self.service.end_trace(
                        db,
                        trace_id,
                        status=outcome,
                        http_status_code=status_code,
                        attributes={"response_time_ms": duration_ms},
                    )
                except Exception as end_trace_error:
                    logger.warning(f"Failed to end trace {trace_id}: {end_trace_error}")

            return response

        except Exception as e:
            # Record the failure on the span, then attach the full stack trace as an event.
            if span_id:
                try:
                    self.service.end_span(db, span_id, status="error", status_message=str(e), attributes={"exception.type": type(e).__name__, "exception.message": str(e)})
                    self.service.add_event(
                        db,
                        span_id,
                        name="exception",
                        severity="error",
                        message=str(e),
                        exception_type=type(e).__name__,
                        exception_message=str(e),
                        exception_stacktrace=traceback.format_exc(),
                    )
                except Exception as log_error:
                    logger.warning(f"Failed to log exception in span: {log_error}")

            if trace_id:
                try:
                    self.service.end_trace(db, trace_id, status="error", status_message=str(e), http_status_code=500)
                except Exception as trace_error:
                    logger.warning(f"Failed to end trace: {trace_error}")

            # Re-raise the original exception
            raise

        finally:
            # The observability service handles its own commits. Discard any leftover
            # transaction and always close the session, even if the rollback fails.
            self._close_session(db, force_rollback=False, log=logger.warning, prefix="Failed to close database session")