Coverage for mcpgateway / middleware / observability_middleware.py: 100%

96 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/middleware/observability_middleware.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Observability Middleware for automatic request/response tracing. 

8 

9This middleware automatically captures HTTP requests and responses as observability traces, 

10providing comprehensive visibility into all gateway operations. 

11 

12Examples: 

13 >>> from mcpgateway.middleware.observability_middleware import ObservabilityMiddleware # doctest: +SKIP 

14 >>> app.add_middleware(ObservabilityMiddleware) # doctest: +SKIP 

15""" 

16 

17# Standard 

18import logging 

19import time 

20import traceback 

21from typing import Callable, Optional 

22 

23# Third-Party 

24from starlette.middleware.base import BaseHTTPMiddleware 

25from starlette.requests import Request 

26from starlette.responses import Response 

27 

28# First-Party 

29from mcpgateway.config import settings 

30from mcpgateway.db import SessionLocal 

31from mcpgateway.instrumentation.sqlalchemy import attach_trace_to_session 

32from mcpgateway.middleware.path_filter import should_skip_observability 

33from mcpgateway.plugins.framework.observability import current_trace_id as plugins_trace_id 

34from mcpgateway.services.observability_service import current_trace_id, ObservabilityService, parse_traceparent 

35 

36logger = logging.getLogger(__name__) 

37 

38 

class ObservabilityMiddleware(BaseHTTPMiddleware):
    """Middleware for automatic HTTP request/response tracing.

    Captures every HTTP request as a trace with timing, status codes,
    and user context. Automatically creates spans for the request lifecycle.

    This middleware is disabled by default and can be enabled via the
    MCPGATEWAY_OBSERVABILITY_ENABLED environment variable.

    Tracing is best-effort: failures while recording a trace are logged and
    never replace the response (or exception) produced by the wrapped app.
    """

    def __init__(self, app, enabled: Optional[bool] = None, service: Optional[ObservabilityService] = None):
        """Initialize the observability middleware.

        Args:
            app: ASGI application
            enabled: Whether observability is enabled. When ``None``, falls back to
                ``settings.observability_enabled`` (``False`` if the setting is absent).
            service: Optional ObservabilityService instance; a new one is created when omitted
        """
        super().__init__(app)
        self.enabled = enabled if enabled is not None else getattr(settings, "observability_enabled", False)
        self.service = service or ObservabilityService()
        logger.info(f"Observability middleware initialized (enabled={self.enabled})")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process request and create observability trace.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            HTTP response

        Raises:
            Exception: Re-raises any exception from request processing after logging
        """
        # Skip if observability is disabled
        if not self.enabled:
            return await call_next(request)

        # Skip health checks and static files to reduce noise
        if should_skip_observability(request.url.path):
            return await call_next(request)

        # Extract request context
        http_method = request.method
        http_url = str(request.url)
        user_email = None
        ip_address = request.client.host if request.client else None
        user_agent = request.headers.get("user-agent")

        # Try to extract user from request state (set by auth middleware)
        if hasattr(request.state, "user") and hasattr(request.state.user, "email"):
            user_email = request.state.user.email

        # Extract W3C Trace Context from headers (for distributed tracing)
        external_trace_id = None
        external_parent_span_id = None
        traceparent_header = request.headers.get("traceparent")
        if traceparent_header:
            parsed = parse_traceparent(traceparent_header)
            if parsed:
                external_trace_id, external_parent_span_id, _flags = parsed
                logger.debug(f"Extracted W3C trace context: trace_id={external_trace_id}, parent_span_id={external_parent_span_id}")

        db = None
        trace_id = None
        span_id = None
        # Monotonic clock: the value is only used to compute an elapsed duration,
        # so it must not be affected by wall-clock adjustments.
        start_time = time.perf_counter()

        try:
            # Create database session
            db = SessionLocal()

            # Start trace (use external trace_id if provided for distributed tracing)
            trace_id = self.service.start_trace(
                db=db,
                name=f"{http_method} {request.url.path}",
                trace_id=external_trace_id,  # Use external trace ID if provided
                parent_span_id=external_parent_span_id,  # Track parent span from upstream
                http_method=http_method,
                http_url=http_url,
                user_email=user_email,
                user_agent=user_agent,
                ip_address=ip_address,
                attributes={
                    "http.route": request.url.path,
                    "http.query": str(request.url.query) if request.url.query else None,
                },
                resource_attributes={
                    "service.name": "mcp-gateway",
                    "service.version": getattr(settings, "version", "unknown"),
                },
            )

            # Store trace_id in request state for use in route handlers
            request.state.trace_id = trace_id

            # Set trace_id in context variable for access throughout async call stack
            current_trace_id.set(trace_id)
            # Bridge: also set the framework's ContextVar so the plugin executor sees it
            plugins_trace_id.set(trace_id)

            # Attach trace_id to database session for SQL query instrumentation
            attach_trace_to_session(db, trace_id)

            # Start request span
            span_id = self.service.start_span(
                db=db,
                trace_id=trace_id,
                name="http.request",
                kind="server",
                attributes={"http.method": http_method, "http.url": http_url},
            )

        except Exception as e:
            # If trace setup failed, log and continue without tracing
            logger.warning(f"Failed to setup observability trace: {e}")
            # NOTE(review): if the failure happened after the trace was started,
            # request.state.trace_id and the context variables set above remain
            # populated for the untraced request - confirm this is acceptable.
            # Close db if it was created
            if db:
                try:
                    try:
                        db.rollback()  # Error path - rollback any partial transaction
                    finally:
                        # Always attempt close, even when rollback itself raised,
                        # so the session is not leaked.
                        db.close()
                except Exception as close_error:
                    logger.debug(f"Failed to close database session during cleanup: {close_error}")
            # Continue without tracing
            return await call_next(request)

        # Process request (trace is set up at this point)
        try:
            response = await call_next(request)
            status_code = response.status_code

            # End span successfully
            if span_id:
                try:
                    self.service.end_span(
                        db,
                        span_id,
                        status="ok" if status_code < 400 else "error",
                        attributes={"http.status_code": status_code, "http.response_size": response.headers.get("content-length")},
                    )
                except Exception as end_span_error:
                    logger.warning(f"Failed to end span {span_id}: {end_span_error}")

            # End trace
            if trace_id:
                duration_ms = (time.perf_counter() - start_time) * 1000
                try:
                    self.service.end_trace(
                        db,
                        trace_id,
                        status="ok" if status_code < 400 else "error",
                        http_status_code=status_code,
                        attributes={"response_time_ms": duration_ms},
                    )
                except Exception as end_trace_error:
                    logger.warning(f"Failed to end trace {trace_id}: {end_trace_error}")

            return response

        except Exception as e:
            # Log exception in span
            if span_id:
                try:
                    self.service.end_span(db, span_id, status="error", status_message=str(e), attributes={"exception.type": type(e).__name__, "exception.message": str(e)})

                    # Add exception event
                    self.service.add_event(
                        db,
                        span_id,
                        name="exception",
                        severity="error",
                        message=str(e),
                        exception_type=type(e).__name__,
                        exception_message=str(e),
                        exception_stacktrace=traceback.format_exc(),
                    )
                except Exception as log_error:
                    logger.warning(f"Failed to log exception in span: {log_error}")

            # End trace with error
            if trace_id:
                try:
                    self.service.end_trace(db, trace_id, status="error", status_message=str(e), http_status_code=500)
                except Exception as trace_error:
                    logger.warning(f"Failed to end trace: {trace_error}")

            # Re-raise the original exception
            raise

        finally:
            # Always close database session - observability service handles its own commits
            if db:
                try:
                    try:
                        if db.in_transaction():
                            db.rollback()
                    finally:
                        # Close even if the transaction check or rollback raised,
                        # otherwise the session would be leaked.
                        db.close()
                except Exception as close_error:
                    logger.warning(f"Failed to close database session: {close_error}")