Coverage for mcpgateway / middleware / observability_middleware.py: 94%
94 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/middleware/observability_middleware.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Observability Middleware for automatic request/response tracing.
9This middleware automatically captures HTTP requests and responses as observability traces,
10providing comprehensive visibility into all gateway operations.
12Examples:
13 >>> from mcpgateway.middleware.observability_middleware import ObservabilityMiddleware # doctest: +SKIP
14 >>> app.add_middleware(ObservabilityMiddleware) # doctest: +SKIP
15"""
17# Standard
18import logging
19import time
20import traceback
21from typing import Callable
23# Third-Party
24from starlette.middleware.base import BaseHTTPMiddleware
25from starlette.requests import Request
26from starlette.responses import Response
28# First-Party
29from mcpgateway.config import settings
30from mcpgateway.db import SessionLocal
31from mcpgateway.instrumentation.sqlalchemy import attach_trace_to_session
32from mcpgateway.middleware.path_filter import should_skip_observability
33from mcpgateway.services.observability_service import current_trace_id, ObservabilityService, parse_traceparent
35logger = logging.getLogger(__name__)
class ObservabilityMiddleware(BaseHTTPMiddleware):
    """Middleware for automatic HTTP request/response tracing.

    Captures every HTTP request as a trace with timing, status codes,
    and user context. Automatically creates spans for the request lifecycle.

    This middleware is disabled by default and can be enabled via the
    MCPGATEWAY_OBSERVABILITY_ENABLED environment variable.

    Tracing is best-effort: a failure while setting up or finishing a trace is
    logged and never prevents the request itself from being served.
    """

    def __init__(self, app, enabled: bool = None):
        """Initialize the observability middleware.

        Args:
            app: ASGI application
            enabled: Whether observability is enabled. When ``None``, falls back
                to ``settings.observability_enabled`` (``False`` if that
                setting is absent).
        """
        super().__init__(app)
        self.enabled = enabled if enabled is not None else getattr(settings, "observability_enabled", False)
        self.service = ObservabilityService()
        logger.info(f"Observability middleware initialized (enabled={self.enabled})")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process request and create observability trace.

        Opens a dedicated database session, starts a trace plus an
        ``http.request`` span, forwards the request, then closes the span and
        trace with the response status (or with the exception details when the
        downstream handler raises). The session is always released afterwards.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            HTTP response

        Raises:
            Exception: Re-raises any exception from request processing after logging
        """
        # Skip if observability is disabled
        if not self.enabled:
            return await call_next(request)

        # Skip paths rejected by the path filter (per the original comment:
        # health checks and static files) to reduce noise
        if should_skip_observability(request.url.path):
            return await call_next(request)

        # Extract request context
        http_method = request.method
        http_url = str(request.url)
        user_email = None
        ip_address = request.client.host if request.client else None
        user_agent = request.headers.get("user-agent")

        # Try to extract user from request state (set by auth middleware)
        if hasattr(request.state, "user") and hasattr(request.state.user, "email"):
            user_email = request.state.user.email

        # Extract W3C Trace Context from headers (for distributed tracing)
        external_trace_id = None
        external_parent_span_id = None
        traceparent_header = request.headers.get("traceparent")
        if traceparent_header:
            parsed = parse_traceparent(traceparent_header)
            if parsed:
                external_trace_id, external_parent_span_id, _flags = parsed
                logger.debug(f"Extracted W3C trace context: trace_id={external_trace_id}, parent_span_id={external_parent_span_id}")

        db = None
        trace_id = None
        span_id = None
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by wall-clock adjustments that happen while the request is in flight.
        start_time = time.perf_counter()

        try:
            # Create database session
            db = SessionLocal()

            # Start trace (use external trace_id if provided for distributed tracing)
            trace_id = self.service.start_trace(
                db=db,
                name=f"{http_method} {request.url.path}",
                trace_id=external_trace_id,  # Use external trace ID if provided
                parent_span_id=external_parent_span_id,  # Track parent span from upstream
                http_method=http_method,
                http_url=http_url,
                user_email=user_email,
                user_agent=user_agent,
                ip_address=ip_address,
                attributes={
                    "http.route": request.url.path,
                    "http.query": str(request.url.query) if request.url.query else None,
                },
                resource_attributes={
                    "service.name": "mcp-gateway",
                    "service.version": getattr(settings, "version", "unknown"),
                },
            )

            # Store trace_id in request state for use in route handlers
            request.state.trace_id = trace_id

            # Set trace_id in context variable for access throughout async call stack
            current_trace_id.set(trace_id)

            # Attach trace_id to database session for SQL query instrumentation
            attach_trace_to_session(db, trace_id)

            # Start request span
            span_id = self.service.start_span(db=db, trace_id=trace_id, name="http.request", kind="server", attributes={"http.method": http_method, "http.url": http_url})

        except Exception as e:
            # If trace setup failed, log and continue without tracing
            # NOTE(review): if the failure happens after the trace_id was
            # published (request.state / current_trace_id), those values are
            # left in place - confirm whether they should be cleared here.
            logger.warning(f"Failed to setup observability trace: {e}")
            # Close db if it was created
            if db is not None:
                try:
                    try:
                        db.rollback()  # Error path - rollback any partial transaction
                    finally:
                        # Release the session even when the rollback itself fails
                        db.close()
                except Exception as close_error:
                    logger.debug(f"Failed to close database session during cleanup: {close_error}")
            # Continue without tracing
            return await call_next(request)

        # Process request (trace is set up at this point)
        try:
            response = await call_next(request)
            status_code = response.status_code
            # Any 4xx/5xx response is recorded as an error outcome
            outcome = "ok" if status_code < 400 else "error"

            # End span successfully
            if span_id:
                try:
                    self.service.end_span(
                        db,
                        span_id,
                        status=outcome,
                        attributes={"http.status_code": status_code, "http.response_size": response.headers.get("content-length")},
                    )
                except Exception as end_span_error:
                    logger.warning(f"Failed to end span {span_id}: {end_span_error}")

            # End trace
            if trace_id:
                duration_ms = (time.perf_counter() - start_time) * 1000
                try:
                    self.service.end_trace(
                        db,
                        trace_id,
                        status=outcome,
                        http_status_code=status_code,
                        attributes={"response_time_ms": duration_ms},
                    )
                except Exception as end_trace_error:
                    logger.warning(f"Failed to end trace {trace_id}: {end_trace_error}")

            return response

        except Exception as e:
            # Log exception in span
            if span_id:
                try:
                    self.service.end_span(db, span_id, status="error", status_message=str(e), attributes={"exception.type": type(e).__name__, "exception.message": str(e)})

                    # Add exception event
                    self.service.add_event(
                        db,
                        span_id,
                        name="exception",
                        severity="error",
                        message=str(e),
                        exception_type=type(e).__name__,
                        exception_message=str(e),
                        exception_stacktrace=traceback.format_exc(),
                    )
                except Exception as log_error:
                    logger.warning(f"Failed to log exception in span: {log_error}")

            # End trace with error
            if trace_id:
                try:
                    self.service.end_trace(db, trace_id, status="error", status_message=str(e), http_status_code=500)
                except Exception as trace_error:
                    logger.warning(f"Failed to end trace: {trace_error}")

            # Re-raise the original exception
            raise

        finally:
            # Always close database session - observability service handles its own commits
            if db is not None:
                try:
                    try:
                        if db.in_transaction():
                            db.rollback()
                    finally:
                        # Release the session even when the transaction check
                        # or the rollback fails
                        db.close()
                except Exception as close_error:
                    logger.warning(f"Failed to close database session: {close_error}")