Coverage for mcpgateway / middleware / observability_middleware.py: 94%

94 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/middleware/observability_middleware.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Observability Middleware for automatic request/response tracing. 

8 

9This middleware automatically captures HTTP requests and responses as observability traces, 

10providing comprehensive visibility into all gateway operations. 

11 

12Examples: 

13 >>> from mcpgateway.middleware.observability_middleware import ObservabilityMiddleware # doctest: +SKIP 

14 >>> app.add_middleware(ObservabilityMiddleware) # doctest: +SKIP 

15""" 

16 

17# Standard 

18import logging 

19import time 

20import traceback 

21from typing import Callable 

22 

23# Third-Party 

24from starlette.middleware.base import BaseHTTPMiddleware 

25from starlette.requests import Request 

26from starlette.responses import Response 

27 

28# First-Party 

29from mcpgateway.config import settings 

30from mcpgateway.db import SessionLocal 

31from mcpgateway.instrumentation.sqlalchemy import attach_trace_to_session 

32from mcpgateway.middleware.path_filter import should_skip_observability 

33from mcpgateway.services.observability_service import current_trace_id, ObservabilityService, parse_traceparent 

34 

35logger = logging.getLogger(__name__) 

36 

37 

class ObservabilityMiddleware(BaseHTTPMiddleware):
    """Middleware for automatic HTTP request/response tracing.

    Captures every HTTP request as a trace with timing, status codes,
    and user context. Automatically creates spans for the request lifecycle.

    This middleware is disabled by default and can be enabled via the
    MCPGATEWAY_OBSERVABILITY_ENABLED environment variable.
    """

    def __init__(self, app, enabled: bool = None):
        """Initialize the observability middleware.

        Args:
            app: ASGI application
            enabled: Whether observability is enabled. When ``None`` (the default),
                falls back to ``settings.observability_enabled``, or ``False`` if
                that setting does not exist.
        """
        super().__init__(app)
        self.enabled = enabled if enabled is not None else getattr(settings, "observability_enabled", False)
        self.service = ObservabilityService()
        logger.info(f"Observability middleware initialized (enabled={self.enabled})")

    @staticmethod
    def _release_session(db, force_rollback: bool) -> None:
        """Roll back (when required) and always close a database session.

        ``close()`` runs in a ``finally`` so that a failing rollback can never
        leave the session open.

        Args:
            db: Database session created by ``SessionLocal()``
            force_rollback: When True, roll back unconditionally; when False,
                roll back only if ``db.in_transaction()`` reports an open transaction

        Raises:
            Exception: Any error raised by the rollback or close calls; callers
                are expected to log it
        """
        try:
            if force_rollback or db.in_transaction():
                db.rollback()
        finally:
            db.close()

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process request and create observability trace.

        The request is passed straight through (no trace) when the middleware is
        disabled, when the path is excluded by ``should_skip_observability``, or
        when setting up the trace fails. Failures while ending spans/traces are
        logged and never mask the response or the original exception.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            HTTP response

        Raises:
            Exception: Re-raises any exception from request processing after logging
        """
        # Skip if observability is disabled
        if not self.enabled:
            return await call_next(request)

        # Skip health checks and static files to reduce noise
        if should_skip_observability(request.url.path):
            return await call_next(request)

        # Extract request context
        http_method = request.method
        http_url = str(request.url)
        user_email = None
        ip_address = request.client.host if request.client else None
        user_agent = request.headers.get("user-agent")

        # Try to extract user from request state (set by auth middleware)
        if hasattr(request.state, "user") and hasattr(request.state.user, "email"):
            user_email = request.state.user.email

        # Extract W3C Trace Context from headers (for distributed tracing)
        external_trace_id = None
        external_parent_span_id = None
        traceparent_header = request.headers.get("traceparent")
        if traceparent_header:
            parsed = parse_traceparent(traceparent_header)
            if parsed:
                external_trace_id, external_parent_span_id, _flags = parsed
                logger.debug(f"Extracted W3C trace context: trace_id={external_trace_id}, parent_span_id={external_parent_span_id}")

        db = None
        trace_id = None
        span_id = None
        # Monotonic clock: this value is only used to compute an elapsed duration,
        # so it must not be affected by system clock adjustments.
        start_time = time.perf_counter()

        try:
            # Create database session
            db = SessionLocal()

            # Start trace (use external trace_id if provided for distributed tracing)
            trace_id = self.service.start_trace(
                db=db,
                name=f"{http_method} {request.url.path}",
                trace_id=external_trace_id,  # Use external trace ID if provided
                parent_span_id=external_parent_span_id,  # Track parent span from upstream
                http_method=http_method,
                http_url=http_url,
                user_email=user_email,
                user_agent=user_agent,
                ip_address=ip_address,
                attributes={
                    "http.route": request.url.path,
                    "http.query": str(request.url.query) if request.url.query else None,
                },
                resource_attributes={
                    "service.name": "mcp-gateway",
                    "service.version": getattr(settings, "version", "unknown"),
                },
            )

            # Store trace_id in request state for use in route handlers
            request.state.trace_id = trace_id

            # Set trace_id in context variable for access throughout async call stack
            current_trace_id.set(trace_id)

            # Attach trace_id to database session for SQL query instrumentation
            attach_trace_to_session(db, trace_id)

            # Start request span
            span_id = self.service.start_span(db=db, trace_id=trace_id, name="http.request", kind="server", attributes={"http.method": http_method, "http.url": http_url})

        except Exception as e:
            # If trace setup failed, log and continue without tracing
            logger.warning(f"Failed to setup observability trace: {e}")
            # Close db if it was created
            if db:
                try:
                    # Error path - rollback any partial transaction, then close
                    # (close still runs if the rollback itself fails)
                    self._release_session(db, force_rollback=True)
                except Exception as close_error:
                    logger.debug(f"Failed to close database session during cleanup: {close_error}")
            # Continue without tracing
            return await call_next(request)

        # Process request (trace is set up at this point)
        try:
            response = await call_next(request)
            status_code = response.status_code

            # End span successfully
            if span_id:
                try:
                    self.service.end_span(
                        db,
                        span_id,
                        status="ok" if status_code < 400 else "error",
                        attributes={"http.status_code": status_code, "http.response_size": response.headers.get("content-length")},
                    )
                except Exception as end_span_error:
                    logger.warning(f"Failed to end span {span_id}: {end_span_error}")

            # End trace
            if trace_id:
                duration_ms = (time.perf_counter() - start_time) * 1000
                try:
                    self.service.end_trace(
                        db,
                        trace_id,
                        status="ok" if status_code < 400 else "error",
                        http_status_code=status_code,
                        attributes={"response_time_ms": duration_ms},
                    )
                except Exception as end_trace_error:
                    logger.warning(f"Failed to end trace {trace_id}: {end_trace_error}")

            return response

        except Exception as e:
            # Log exception in span
            if span_id:
                try:
                    self.service.end_span(db, span_id, status="error", status_message=str(e), attributes={"exception.type": type(e).__name__, "exception.message": str(e)})

                    # Add exception event
                    self.service.add_event(
                        db,
                        span_id,
                        name="exception",
                        severity="error",
                        message=str(e),
                        exception_type=type(e).__name__,
                        exception_message=str(e),
                        exception_stacktrace=traceback.format_exc(),
                    )
                except Exception as log_error:
                    logger.warning(f"Failed to log exception in span: {log_error}")

            # End trace with error
            if trace_id:
                try:
                    self.service.end_trace(db, trace_id, status="error", status_message=str(e), http_status_code=500)
                except Exception as trace_error:
                    logger.warning(f"Failed to end trace: {trace_error}")

            # Re-raise the original exception
            raise

        finally:
            # Always close database session - observability service handles its own commits
            if db:
                try:
                    # Roll back only a still-open transaction; close runs regardless
                    self._release_session(db, force_rollback=False)
                except Exception as close_error:
                    logger.warning(f"Failed to close database session: {close_error}")