Coverage for mcpgateway / instrumentation / sqlalchemy.py: 100%

95 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/instrumentation/sqlalchemy.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Automatic instrumentation for SQLAlchemy database queries. 

8 

9This module instruments SQLAlchemy to automatically capture database 

10queries as observability spans, providing visibility into database 

11performance. 

12 

13Examples: 

14 >>> from mcpgateway.instrumentation import instrument_sqlalchemy # doctest: +SKIP 

15 >>> instrument_sqlalchemy(engine) # doctest: +SKIP 

16""" 

17 

18# Standard 

19import logging 

20import queue 

21import threading 

22import time 

23from typing import Any, Optional 

24 

25# Third-Party 

26from sqlalchemy import event 

27from sqlalchemy.engine import Connection, Engine 

28 

29logger = logging.getLogger(__name__) 

30 

31# Thread-local storage for tracking queries in progress 

32_query_tracking = {} 

33 

34# Thread-local flag to prevent recursive instrumentation 

35_instrumentation_context = threading.local() 

36 

37# Background queue for deferred span writes to avoid database locks 

38_span_queue: queue.Queue = queue.Queue(maxsize=1000) 

39_span_writer_thread: Optional[threading.Thread] = None 

40_shutdown_event = threading.Event() 

41 

42 

43def _write_span_to_db(span_data: dict) -> None: 

44 """Write a single span to the database. 

45 

46 Args: 

47 span_data: Dictionary containing span information 

48 """ 

49 try: 

50 # Import here to avoid circular imports 

51 # First-Party 

52 # pylint: disable=import-outside-toplevel 

53 from mcpgateway.db import ObservabilitySpan, SessionLocal 

54 from mcpgateway.services.observability_service import ObservabilityService 

55 

56 # pylint: enable=import-outside-toplevel 

57 

58 service = ObservabilityService() 

59 db = SessionLocal() 

60 try: 

61 span_id = service.start_span( 

62 db=db, 

63 trace_id=span_data["trace_id"], 

64 name=span_data["name"], 

65 kind=span_data["kind"], 

66 resource_type=span_data["resource_type"], 

67 resource_name=span_data["resource_name"], 

68 attributes=span_data["start_attributes"], 

69 ) 

70 

71 # End span with measured duration in attributes 

72 service.end_span( 

73 db=db, 

74 span_id=span_id, 

75 status=span_data["status"], 

76 attributes=span_data["end_attributes"], 

77 ) 

78 

79 # Update the span duration to match what we actually measured 

80 span = db.query(ObservabilitySpan).filter_by(span_id=span_id).first() 

81 if span: 

82 span.duration_ms = span_data["duration_ms"] 

83 db.commit() 

84 

85 logger.debug(f"Created span for {span_data['resource_name']} query: " f"{span_data['duration_ms']:.2f}ms, {span_data.get('row_count')} rows") 

86 

87 finally: 

88 db.close() # Commit already done above 

89 

90 except Exception as e: # pylint: disable=broad-except 

91 # Don't fail if span creation fails 

92 logger.warning(f"Failed to write query span: {e}") 

93 

94 

95def _span_writer_worker() -> None: 

96 """Background worker thread that writes spans to the database. 

97 

98 This runs in a separate thread to avoid blocking the main request thread 

99 and to prevent database lock contention. 

100 """ 

101 logger.info("Span writer worker started") 

102 

103 while not _shutdown_event.is_set(): 

104 try: 

105 # Wait for span data with timeout to allow checking shutdown 

106 try: 

107 span_data = _span_queue.get(timeout=1.0) 

108 except queue.Empty: 

109 continue 

110 

111 # Write the span to the database 

112 _write_span_to_db(span_data) 

113 _span_queue.task_done() 

114 

115 except Exception as e: # pylint: disable=broad-except 

116 logger.error(f"Error in span writer worker: {e}") 

117 # Continue processing even if one span fails 

118 

119 logger.info("Span writer worker stopped") 

120 

121 

122def instrument_sqlalchemy(engine: Engine) -> None: 

123 """Instrument a SQLAlchemy engine to capture query spans. 

124 

125 Args: 

126 engine: SQLAlchemy engine to instrument 

127 

128 Examples: 

129 >>> from sqlalchemy import create_engine # doctest: +SKIP 

130 >>> engine = create_engine("sqlite:///./mcp.db") # doctest: +SKIP 

131 >>> instrument_sqlalchemy(engine) # doctest: +SKIP 

132 """ 

133 global _span_writer_thread # pylint: disable=global-statement 

134 

135 # Register event listeners 

136 event.listen(engine, "before_cursor_execute", _before_cursor_execute) 

137 event.listen(engine, "after_cursor_execute", _after_cursor_execute) 

138 

139 # Start background span writer thread if not already running 

140 if _span_writer_thread is None or not _span_writer_thread.is_alive(): 

141 _span_writer_thread = threading.Thread(target=_span_writer_worker, name="SpanWriterThread", daemon=True) 

142 _span_writer_thread.start() 

143 logger.info("Started background span writer thread") 

144 

145 logger.info("SQLAlchemy instrumentation enabled") 

146 

147 

148def _before_cursor_execute( 

149 conn: Connection, 

150 _cursor: Any, 

151 statement: str, 

152 parameters: Any, 

153 _context: Any, 

154 executemany: bool, 

155) -> None: 

156 """Event handler called before SQL query execution. 

157 

158 Args: 

159 conn: Database connection 

160 _cursor: Database cursor (required by SQLAlchemy event API) 

161 statement: SQL statement 

162 parameters: Query parameters 

163 _context: Execution context (required by SQLAlchemy event API) 

164 executemany: Whether this is a bulk execution 

165 """ 

166 # Store start time for this query 

167 conn_id = id(conn) 

168 _query_tracking[conn_id] = { 

169 "start_time": time.time(), 

170 "statement": statement, 

171 "parameters": parameters, 

172 "executemany": executemany, 

173 } 

174 

175 

176def _after_cursor_execute( 

177 conn: Connection, 

178 cursor: Any, 

179 statement: str, 

180 _parameters: Any, 

181 _context: Any, 

182 executemany: bool, 

183) -> None: 

184 """Event handler called after SQL query execution. 

185 

186 Args: 

187 conn: Database connection 

188 cursor: Database cursor 

189 statement: SQL statement 

190 _parameters: Query parameters (required by SQLAlchemy event API) 

191 _context: Execution context (required by SQLAlchemy event API) 

192 executemany: Whether this is a bulk execution 

193 """ 

194 conn_id = id(conn) 

195 tracking = _query_tracking.pop(conn_id, None) 

196 

197 if not tracking: 

198 return 

199 

200 # Skip instrumentation if we're already inside span creation (prevent recursion) 

201 if getattr(_instrumentation_context, "inside_span_creation", False): 

202 return 

203 

204 # Skip instrumentation for observability tables to prevent recursion and lock issues 

205 statement_upper = statement.upper() 

206 if any(table in statement_upper for table in ["OBSERVABILITY_TRACES", "OBSERVABILITY_SPANS", "OBSERVABILITY_EVENTS", "OBSERVABILITY_METRICS"]): 

207 logger.debug(f"Skipping instrumentation for observability table query: {statement[:100]}...") 

208 return 

209 

210 # Calculate query duration 

211 duration_ms = (time.time() - tracking["start_time"]) * 1000 

212 

213 # Get row count if available 

214 row_count = None 

215 try: 

216 if hasattr(cursor, "rowcount") and cursor.rowcount >= 0: 

217 row_count = cursor.rowcount 

218 except Exception: # pylint: disable=broad-except # nosec B110 - row_count is optional metadata 

219 pass 

220 

221 # Try to get trace context from connection info 

222 trace_id = None 

223 if hasattr(conn, "info") and "trace_id" in conn.info: 

224 trace_id = conn.info["trace_id"] 

225 

226 # If we have a trace_id, create a span 

227 if trace_id: 

228 _create_query_span( 

229 trace_id=trace_id, 

230 statement=statement, 

231 duration_ms=duration_ms, 

232 row_count=row_count, 

233 executemany=executemany, 

234 ) 

235 else: 

236 # Log for debugging but don't fail 

237 logger.debug(f"Query executed without trace context: {statement[:100]}... ({duration_ms:.2f}ms)") 

238 

239 

240def _create_query_span( 

241 trace_id: str, 

242 statement: str, 

243 duration_ms: float, 

244 row_count: Optional[int], 

245 executemany: bool, 

246) -> None: 

247 """Create an observability span for a database query. 

248 

249 This function enqueues span data to be written by a background thread, 

250 avoiding database lock contention. 

251 

252 Args: 

253 trace_id: Parent trace ID 

254 statement: SQL statement 

255 duration_ms: Query duration in milliseconds 

256 row_count: Number of rows affected/returned 

257 executemany: Whether this is a bulk execution 

258 """ 

259 try: 

260 # Extract query type (SELECT, INSERT, UPDATE, DELETE, etc.) 

261 query_type = statement.strip().split()[0].upper() if statement else "UNKNOWN" 

262 

263 # Truncate long queries for span name 

264 span_name = f"db.query.{query_type.lower()}" 

265 

266 # Prepare span data 

267 span_data = { 

268 "trace_id": trace_id, 

269 "name": span_name, 

270 "kind": "client", 

271 "resource_type": "database", 

272 "resource_name": query_type, 

273 "duration_ms": duration_ms, 

274 "status": "ok", 

275 "start_attributes": { 

276 "db.statement": statement[:500], # Truncate long queries 

277 "db.operation": query_type, 

278 "db.executemany": executemany, 

279 "db.duration_measured_ms": duration_ms, # Store actual measured duration 

280 }, 

281 "end_attributes": { 

282 "db.row_count": row_count, 

283 }, 

284 "row_count": row_count, 

285 } 

286 

287 # Enqueue for background processing (non-blocking) 

288 try: 

289 _span_queue.put_nowait(span_data) 

290 logger.debug(f"Enqueued span for {query_type} query: {duration_ms:.2f}ms") 

291 except queue.Full: 

292 logger.warning("Span queue is full, dropping span data") 

293 

294 except Exception as e: # pylint: disable=broad-except 

295 # Don't fail the query if span creation fails 

296 logger.debug(f"Failed to enqueue query span: {e}") 

297 

298 

299def attach_trace_to_session(session: Any, trace_id: str) -> None: 

300 """Attach a trace ID to a database session. 

301 

302 This allows the instrumentation to correlate queries with traces. 

303 

304 Args: 

305 session: SQLAlchemy session 

306 trace_id: Trace ID to attach 

307 

308 Examples: 

309 >>> from mcpgateway.db import SessionLocal # doctest: +SKIP 

310 >>> db = SessionLocal() # doctest: +SKIP 

311 >>> attach_trace_to_session(db, trace_id) # doctest: +SKIP 

312 """ 

313 if hasattr(session, "bind") and session.bind: 

314 # Get a connection and attach trace_id to its info dict 

315 connection = session.connection() 

316 if hasattr(connection, "info"): 

317 connection.info["trace_id"] = trace_id