Coverage for mcpgateway / middleware / observability_middleware.py: 100%

121 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/middleware/observability_middleware.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Observability Middleware for automatic request/response tracing. 

8 

9This middleware automatically captures HTTP requests and responses as observability traces, 

10providing comprehensive visibility into all gateway operations. 

11 

12Examples: 

13 >>> from mcpgateway.middleware.observability_middleware import ObservabilityMiddleware # doctest: +SKIP 

14 >>> app.add_middleware(ObservabilityMiddleware) # doctest: +SKIP 

15""" 

16 

17# Standard 

18import logging 

19import time 

20import traceback 

21from typing import Callable, Optional 

22 

23# Third-Party 

24from starlette.middleware.base import BaseHTTPMiddleware 

25from starlette.requests import Request 

26from starlette.responses import Response 

27 

28# First-Party 

29from mcpgateway.config import settings 

30from mcpgateway.db import SessionLocal 

31from mcpgateway.instrumentation.sqlalchemy import attach_trace_to_session 

32from mcpgateway.middleware.path_filter import should_skip_observability 

33from mcpgateway.plugins.framework.observability import current_trace_id as plugins_trace_id 

34from mcpgateway.services.observability_service import current_trace_id, ObservabilityService, parse_traceparent 

35from mcpgateway.utils.log_sanitizer import sanitize_for_log 

36from mcpgateway.utils.trace_redaction import sanitize_trace_text 

37 

# Module-level logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)

39 

40 

class ObservabilityMiddleware(BaseHTTPMiddleware):
    """Middleware for automatic HTTP request/response tracing.

    Captures every HTTP request as a trace with timing, status codes,
    and user context. Automatically creates spans for the request lifecycle.

    This middleware is disabled by default and can be enabled via the
    MCPGATEWAY_OBSERVABILITY_ENABLED environment variable.
    """

    def __init__(self, app, enabled: Optional[bool] = None, service: Optional[ObservabilityService] = None):
        """Initialize the observability middleware.

        Args:
            app: ASGI application
            enabled: Whether observability is enabled. ``None`` (the default) falls back to
                ``settings.observability_enabled``, or ``False`` when that setting is absent.
            service: Optional ObservabilityService instance; a new ``ObservabilityService()``
                is created when omitted.
        """
        super().__init__(app)
        self.enabled = enabled if enabled is not None else getattr(settings, "observability_enabled", False)
        # Explicit None check so an injected service is never replaced merely for being falsy.
        self.service = service if service is not None else ObservabilityService()
        logger.info("Observability middleware initialized (enabled=%s)", self.enabled)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process request and create observability trace.

        A request-scoped database session is created and exposed as ``request.state.db``
        so route handlers can reuse it; the trace and an ``http.request`` server span are
        recorded through that session. If trace setup fails, the request is still served,
        but without tracing.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            HTTP response

        Raises:
            Exception: Re-raises any exception from request processing after logging
        """
        # Skip if observability is disabled, and skip health checks / static files to reduce noise
        if not self.enabled or should_skip_observability(request.url.path):
            return await call_next(request)

        # Extract request context
        http_method = request.method
        # NOTE(review): the full URL (including any query string) is recorded as-is here,
        # while the "http.query" attribute below is passed through sanitize_trace_text.
        # Confirm start_trace/start_span redact http_url themselves; otherwise secrets
        # carried in query parameters may be stored.
        http_url = str(request.url)
        ip_address = request.client.host if request.client else None
        user_agent = request.headers.get("user-agent")

        # Try to extract user from request state (set by auth middleware)
        user_email = None
        if hasattr(request.state, "user") and hasattr(request.state.user, "email"):
            user_email = request.state.user.email

        # Extract W3C Trace Context from headers (for distributed tracing)
        external_trace_id, external_parent_span_id = self._extract_trace_context(request)

        db = None
        trace_id = None
        span_id = None
        # Tokens returned by ContextVar.set(); kept so a failed setup can restore prior values.
        trace_token = None
        plugins_trace_token = None
        # Monotonic clock: the measured duration must not jump if the wall clock is adjusted.
        start_time = time.perf_counter()

        try:
            # Create request-scoped database session and store in request.state.
            # This session will be reused by route handlers via get_db() dependency,
            # eliminating duplicate session creation (Issue #3467)
            db = SessionLocal()
            logger.debug("[OBSERVABILITY] DB session created: %s", id(db))
            request.state.db = db

            # Start trace (use external trace_id if provided for distributed tracing)
            trace_id = self.service.start_trace(
                db=db,
                name=f"{http_method} {request.url.path}",
                trace_id=external_trace_id,  # Use external trace ID if provided
                parent_span_id=external_parent_span_id,  # Track parent span from upstream
                http_method=http_method,
                http_url=http_url,
                user_email=user_email,
                user_agent=user_agent,
                ip_address=ip_address,
                attributes={
                    "http.route": request.url.path,
                    "http.query": sanitize_trace_text(str(request.url.query)) if request.url.query else None,
                },
                resource_attributes={
                    "service.name": "mcp-gateway",
                    "service.version": getattr(settings, "version", "unknown"),
                },
            )

            # Store trace_id in request state for use in route handlers
            request.state.trace_id = trace_id

            # Set trace_id in context variable for access throughout async call stack
            trace_token = current_trace_id.set(trace_id)
            # Bridge: also set the framework's ContextVar so the plugin executor sees it
            plugins_trace_token = plugins_trace_id.set(trace_id)

            # Attach trace_id to database session for SQL query instrumentation
            attach_trace_to_session(db, trace_id)

            # Start request span
            span_id = self.service.start_span(db=db, trace_id=trace_id, name="http.request", kind="server", attributes={"http.method": http_method, "http.url": http_url})

        except Exception as e:
            # If trace setup failed, log, undo the partial setup and continue without tracing
            logger.warning("Failed to setup observability trace: %s", e)
            self._abandon_trace_setup(request, db, trace_id, trace_token, plugins_trace_token)
            return await call_next(request)

        # Process request (trace is set up at this point).
        # Route handlers will reuse request.state.db via get_db() dependency.
        try:
            response = await call_next(request)
            self._finish_trace_success(db, trace_id, span_id, response, start_time)

            # NOTE: Transaction control delegated to get_db()
            # Middleware only manages session lifecycle (create/close), not transactions.
            # get_db() will commit on success or rollback on error to maintain
            # predictable transaction semantics for route handlers (Issue #3731).
            return response

        except Exception as e:
            self._finish_trace_error(db, trace_id, span_id, e)
            # Re-raise the original exception
            raise

        finally:
            # Always close database session and clean up request state.
            # db is always set here because setup failures return early above.
            if db is not None:
                try:
                    db.close()
                except Exception as close_error:
                    logger.warning("Failed to close database session: %s", close_error)
            # Clean up request.state.db to prevent stale references
            if hasattr(request.state, "db"):
                delattr(request.state, "db")

    @staticmethod
    def _extract_trace_context(request: Request) -> tuple:
        """Parse the W3C ``traceparent`` header of ``request``.

        Args:
            request: Incoming HTTP request

        Returns:
            ``(trace_id, parent_span_id)`` from ``parse_traceparent``, or ``(None, None)``
            when the header is missing or cannot be parsed.
        """
        traceparent_header = request.headers.get("traceparent")
        if not traceparent_header:
            return None, None
        parsed = parse_traceparent(traceparent_header)
        if not parsed:
            return None, None
        external_trace_id, external_parent_span_id, _flags = parsed
        logger.debug("Extracted W3C trace context: trace_id=%s, parent_span_id=%s", external_trace_id, external_parent_span_id)
        return external_trace_id, external_parent_span_id

    @staticmethod
    def _rollback_or_invalidate(db, failure_message: str, log_level: int) -> None:
        """Roll back ``db``; if the rollback fails, invalidate its connection.

        A failing rollback means the connection is broken (e.g. PgBouncer
        query_wait_timeout), so it is invalidated to remove it from the pool.
        Never raises.

        Args:
            db: Database session to roll back
            failure_message: Prefix for the log record emitted when rollback fails
            log_level: ``logging`` level for that record
        """
        try:
            db.rollback()
        except Exception as rollback_error:
            logger.log(log_level, "%s: %s", failure_message, rollback_error)
            try:
                db.invalidate()
            except Exception:
                pass  # nosec B110 - best-effort; the session is being discarded anyway

    def _abandon_trace_setup(self, request: Request, db, trace_id, trace_token, plugins_trace_token) -> None:
        """Undo a partially completed trace setup so the request truly proceeds untraced.

        Restores both trace context variables and removes ``request.state.trace_id`` if
        setup had already assigned them, so downstream code does not reference a trace
        whose database work is rolled back below. Then rolls back and closes ``db`` (if
        created) and removes ``request.state.db`` so ``get_db()`` cannot reuse a closed
        session. Never raises.

        Args:
            request: Incoming HTTP request
            db: Session created during setup, or ``None`` if creation failed
            trace_id: Trace id returned by ``start_trace``, or ``None`` if it was not reached
            trace_token: Token from ``current_trace_id.set``, or ``None``
            plugins_trace_token: Token from ``plugins_trace_id.set``, or ``None``
        """
        for context_var, token in ((plugins_trace_id, plugins_trace_token), (current_trace_id, trace_token)):
            if token is None:
                continue
            try:
                context_var.reset(token)
            except (ValueError, RuntimeError) as reset_error:
                logger.debug("Failed to reset trace context variable: %s", reset_error)

        # Only remove the trace_id this middleware stored (it is assigned right after start_trace).
        if trace_id is not None and getattr(request.state, "trace_id", None) == trace_id:
            delattr(request.state, "trace_id")

        if db is not None:
            # Error path - rollback any partial transaction
            self._rollback_or_invalidate(db, "Failed to rollback during cleanup", logging.DEBUG)
            try:
                db.close()
            except Exception as close_error:
                logger.debug("Failed to close database session during cleanup: %s", close_error)
        # Clean up request.state.db to prevent get_db() from reusing a closed session
        if hasattr(request.state, "db"):
            delattr(request.state, "db")

    def _finish_trace_success(self, db, trace_id, span_id, response: Response, start_time: float) -> None:
        """End the request span and trace after the downstream app returned ``response``.

        Status is ``"ok"`` for status codes below 400 and ``"error"`` otherwise. Failures
        to record are logged as warnings and never propagate.

        Args:
            db: Request-scoped database session
            trace_id: Trace to end (skipped if falsy)
            span_id: Span to end (skipped if falsy)
            response: Response returned by ``call_next``
            start_time: ``time.perf_counter()`` value taken when the request started
        """
        status_code = response.status_code
        status = "ok" if status_code < 400 else "error"

        # End span
        if span_id:
            try:
                self.service.end_span(
                    db,
                    span_id,
                    status=status,
                    attributes={"http.status_code": status_code, "http.response_size": response.headers.get("content-length")},
                )
            except Exception as end_span_error:
                logger.warning("Failed to end span %s: %s", span_id, end_span_error)

        # End trace
        if trace_id:
            duration_ms = (time.perf_counter() - start_time) * 1000
            try:
                self.service.end_trace(
                    db,
                    trace_id,
                    status=status,
                    http_status_code=status_code,
                    attributes={"response_time_ms": duration_ms},
                )
            except Exception as end_trace_error:
                logger.warning("Failed to end trace %s: %s", trace_id, end_trace_error)

    def _finish_trace_error(self, db, trace_id, span_id, error: Exception) -> None:
        """Record ``error`` on the span and trace, then roll back ``db``.

        Every step is best-effort and logged on failure, so nothing here can mask
        ``error``, which the caller re-raises.

        Args:
            db: Request-scoped database session
            trace_id: Trace to end with an error (skipped if falsy)
            span_id: Span to end with an error (skipped if falsy)
            error: Exception raised while processing the request
        """
        # Log exception in span
        if span_id:
            try:
                sanitized_error = sanitize_for_log(sanitize_trace_text(str(error)))
                self.service.end_span(db, span_id, status="error", status_message=sanitized_error, attributes={"exception.type": type(error).__name__, "exception.message": sanitized_error})

                # The formatted traceback includes the raw str(error) that is redacted above,
                # so it must go through the same trace redaction before being stored.
                stacktrace = sanitize_trace_text("".join(traceback.format_exception(type(error), error, error.__traceback__)))

                # Add exception event
                self.service.add_event(
                    db,
                    span_id,
                    name="exception",
                    severity="error",
                    message=sanitized_error,
                    exception_type=type(error).__name__,
                    exception_message=sanitized_error,
                    exception_stacktrace=stacktrace,
                )
            except Exception as log_error:
                logger.warning("Failed to log exception in span: %s", log_error)

        # End trace with error
        if trace_id:
            try:
                sanitized_error = sanitize_for_log(sanitize_trace_text(str(error)))
                self.service.end_trace(db, trace_id, status="error", status_message=sanitized_error, http_status_code=500)
            except Exception as trace_error:
                logger.warning("Failed to end trace: %s", trace_error)

        # Rollback the shared session on error
        self._rollback_or_invalidate(db, "Failed to rollback database session", logging.WARNING)