Coverage for mcpgateway/instrumentation/sqlalchemy.py: 100% (95 statements)
coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/instrumentation/sqlalchemy.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Automatic instrumentation for SQLAlchemy database queries.
9This module instruments SQLAlchemy to automatically capture database
10queries as observability spans, providing visibility into database
11performance.
13Examples:
14 >>> from mcpgateway.instrumentation import instrument_sqlalchemy # doctest: +SKIP
15 >>> instrument_sqlalchemy(engine) # doctest: +SKIP
16"""
18# Standard
19import logging
20import queue
21import threading
22import time
23from typing import Any, Optional
25# Third-Party
26from sqlalchemy import event
27from sqlalchemy.engine import Connection, Engine
29logger = logging.getLogger(__name__)
# Registry of in-flight queries keyed by id(conn); populated by
# _before_cursor_execute and popped by _after_cursor_execute.
# NOTE(review): despite the original comment this is a plain module-level
# dict shared across threads, NOT thread-local storage — confirm concurrent
# connections cannot collide (id(conn) keys make collisions unlikely).
_query_tracking: dict = {}

# Thread-local flag used to suppress recursive instrumentation while a
# span itself is being written.
_instrumentation_context = threading.local()

# Bounded queue for deferred span writes; a background thread drains it so
# span persistence never contends for database locks on the request path.
_span_queue: queue.Queue = queue.Queue(maxsize=1000)
# Lazily started by instrument_sqlalchemy(); daemon thread running _span_writer_worker.
_span_writer_thread: Optional[threading.Thread] = None
# Set to ask the span writer worker loop to exit.
_shutdown_event = threading.Event()
def _write_span_to_db(span_data: dict) -> None:
    """Write a single span to the database.

    Starts a span via ObservabilityService, immediately ends it, then
    overwrites the stored duration with the value measured at query time so
    the persisted span reflects the real query latency rather than the
    (later) background-write time. Failures are logged and swallowed so
    observability problems never break query execution.

    Args:
        span_data: Dictionary containing span information (trace_id, name,
            kind, resource_type, resource_name, status, duration_ms,
            start_attributes, end_attributes, row_count).
    """
    try:
        # Import here to avoid circular imports
        # First-Party
        # pylint: disable=import-outside-toplevel
        from mcpgateway.db import ObservabilitySpan, SessionLocal
        from mcpgateway.services.observability_service import ObservabilityService

        # pylint: enable=import-outside-toplevel

        service = ObservabilityService()
        db = SessionLocal()
        try:
            span_id = service.start_span(
                db=db,
                trace_id=span_data["trace_id"],
                name=span_data["name"],
                kind=span_data["kind"],
                resource_type=span_data["resource_type"],
                resource_name=span_data["resource_name"],
                attributes=span_data["start_attributes"],
            )

            # End span with measured duration in attributes
            service.end_span(
                db=db,
                span_id=span_id,
                status=span_data["status"],
                attributes=span_data["end_attributes"],
            )

            # Update the span duration to match what we actually measured
            span = db.query(ObservabilitySpan).filter_by(span_id=span_id).first()
            if span:
                span.duration_ms = span_data["duration_ms"]
                db.commit()

            # Lazy %-style args keep formatting cost off the path when
            # DEBUG logging is disabled.
            logger.debug(
                "Created span for %s query: %.2fms, %s rows",
                span_data["resource_name"],
                span_data["duration_ms"],
                span_data.get("row_count"),
            )

        finally:
            db.close()  # Commit already done above

    except Exception as e:  # pylint: disable=broad-except
        # Don't fail if span creation fails
        logger.warning("Failed to write query span: %s", e)
def _span_writer_worker() -> None:
    """Background worker thread that writes spans to the database.

    Runs in a separate (daemon) thread so span persistence never blocks the
    request thread and database lock contention is avoided. Loops until
    _shutdown_event is set; the queue get uses a timeout so shutdown is
    checked at least once per second.
    """
    logger.info("Span writer worker started")

    while not _shutdown_event.is_set():
        try:
            # Wait for span data with timeout to allow checking shutdown
            try:
                span_data = _span_queue.get(timeout=1.0)
            except queue.Empty:
                continue

            try:
                # Write the span to the database
                _write_span_to_db(span_data)
            finally:
                # Always balance the get() so a future queue.join() cannot
                # hang if the write raises.
                _span_queue.task_done()

        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error in span writer worker: %s", e)
            # Continue processing even if one span fails

    logger.info("Span writer worker stopped")
def instrument_sqlalchemy(engine: Engine) -> None:
    """Attach query-span instrumentation to a SQLAlchemy engine.

    Registers before/after cursor-execute listeners on *engine* and lazily
    launches the background span-writer thread if it is not already running.

    Args:
        engine: SQLAlchemy engine to instrument

    Examples:
        >>> from sqlalchemy import create_engine  # doctest: +SKIP
        >>> engine = create_engine("sqlite:///./mcp.db")  # doctest: +SKIP
        >>> instrument_sqlalchemy(engine)  # doctest: +SKIP
    """
    global _span_writer_thread  # pylint: disable=global-statement

    # Hook the query lifecycle events on this engine.
    for event_name, handler in (
        ("before_cursor_execute", _before_cursor_execute),
        ("after_cursor_execute", _after_cursor_execute),
    ):
        event.listen(engine, event_name, handler)

    # (Re)start the background writer if it has never run or has died.
    writer_needed = _span_writer_thread is None or not _span_writer_thread.is_alive()
    if writer_needed:
        _span_writer_thread = threading.Thread(target=_span_writer_worker, name="SpanWriterThread", daemon=True)
        _span_writer_thread.start()
        logger.info("Started background span writer thread")

    logger.info("SQLAlchemy instrumentation enabled")
def _before_cursor_execute(
    conn: Connection,
    _cursor: Any,
    statement: str,
    parameters: Any,
    _context: Any,
    executemany: bool,
) -> None:
    """Record the start of a SQL statement so its duration can be measured.

    Args:
        conn: Database connection
        _cursor: Database cursor (required by SQLAlchemy event API)
        statement: SQL statement
        parameters: Query parameters
        _context: Execution context (required by SQLAlchemy event API)
        executemany: Whether this is a bulk execution
    """
    # Key by connection identity; the matching after-hook pops this entry.
    entry = {
        "start_time": time.time(),
        "statement": statement,
        "parameters": parameters,
        "executemany": executemany,
    }
    _query_tracking[id(conn)] = entry
def _after_cursor_execute(
    conn: Connection,
    cursor: Any,
    statement: str,
    _parameters: Any,
    _context: Any,
    executemany: bool,
) -> None:
    """Finish timing a SQL statement and enqueue a span if a trace is active.

    Args:
        conn: Database connection
        cursor: Database cursor
        statement: SQL statement
        _parameters: Query parameters (required by SQLAlchemy event API)
        _context: Execution context (required by SQLAlchemy event API)
        executemany: Whether this is a bulk execution
    """
    info = _query_tracking.pop(id(conn), None)
    if not info:
        return

    # Skip instrumentation if we're already inside span creation (prevent recursion)
    if getattr(_instrumentation_context, "inside_span_creation", False):
        return

    # Never instrument queries against the observability tables themselves:
    # doing so would recurse and can cause lock contention.
    upper_sql = statement.upper()
    skip_tables = ("OBSERVABILITY_TRACES", "OBSERVABILITY_SPANS", "OBSERVABILITY_EVENTS", "OBSERVABILITY_METRICS")
    if any(table in upper_sql for table in skip_tables):
        logger.debug(f"Skipping instrumentation for observability table query: {statement[:100]}...")
        return

    # Calculate query duration
    elapsed_ms = (time.time() - info["start_time"]) * 1000

    # rowcount is optional metadata; some drivers report -1 or raise.
    row_count = None
    try:
        if hasattr(cursor, "rowcount") and cursor.rowcount >= 0:
            row_count = cursor.rowcount
    except Exception:  # pylint: disable=broad-except # nosec B110 - row_count is optional metadata
        pass

    # Correlate with a trace via the connection's info dict, if one was attached.
    trace_id = None
    if hasattr(conn, "info") and "trace_id" in conn.info:
        trace_id = conn.info["trace_id"]

    if not trace_id:
        # Log for debugging but don't fail
        logger.debug(f"Query executed without trace context: {statement[:100]}... ({elapsed_ms:.2f}ms)")
        return

    _create_query_span(
        trace_id=trace_id,
        statement=statement,
        duration_ms=elapsed_ms,
        row_count=row_count,
        executemany=executemany,
    )
def _create_query_span(
    trace_id: str,
    statement: str,
    duration_ms: float,
    row_count: Optional[int],
    executemany: bool,
) -> None:
    """Create an observability span for a database query.

    This function enqueues span data to be written by a background thread,
    avoiding database lock contention. Any failure is logged at DEBUG and
    swallowed so the query itself is never affected.

    Args:
        trace_id: Parent trace ID
        statement: SQL statement
        duration_ms: Query duration in milliseconds
        row_count: Number of rows affected/returned
        executemany: Whether this is a bulk execution
    """
    try:
        # Extract query type (SELECT, INSERT, UPDATE, DELETE, etc.).
        # A whitespace-only statement would make split()[0] raise IndexError,
        # so guard the token list explicitly.
        tokens = statement.split() if statement else []
        query_type = tokens[0].upper() if tokens else "UNKNOWN"

        # Truncate long queries for span name
        span_name = f"db.query.{query_type.lower()}"

        # Prepare span data
        span_data = {
            "trace_id": trace_id,
            "name": span_name,
            "kind": "client",
            "resource_type": "database",
            "resource_name": query_type,
            "duration_ms": duration_ms,
            "status": "ok",
            "start_attributes": {
                "db.statement": statement[:500],  # Truncate long queries
                "db.operation": query_type,
                "db.executemany": executemany,
                "db.duration_measured_ms": duration_ms,  # Store actual measured duration
            },
            "end_attributes": {
                "db.row_count": row_count,
            },
            "row_count": row_count,
        }

        # Enqueue for background processing (non-blocking)
        try:
            _span_queue.put_nowait(span_data)
            # Lazy %-args avoid formatting cost when DEBUG is disabled.
            logger.debug("Enqueued span for %s query: %.2fms", query_type, duration_ms)
        except queue.Full:
            logger.warning("Span queue is full, dropping span data")

    except Exception as e:  # pylint: disable=broad-except
        # Don't fail the query if span creation fails
        logger.debug("Failed to enqueue query span: %s", e)
def attach_trace_to_session(session: Any, trace_id: str) -> None:
    """Attach a trace ID to a database session.

    The instrumentation hooks later read this ID back from the connection's
    ``info`` dict to correlate queries with traces.

    Args:
        session: SQLAlchemy session
        trace_id: Trace ID to attach

    Examples:
        >>> from mcpgateway.db import SessionLocal  # doctest: +SKIP
        >>> db = SessionLocal()  # doctest: +SKIP
        >>> attach_trace_to_session(db, trace_id)  # doctest: +SKIP
    """
    bound_engine = getattr(session, "bind", None)
    if not bound_engine:
        return
    # Stash the trace id on the live connection's info dict.
    connection = session.connection()
    if hasattr(connection, "info"):
        connection.info["trace_id"] = trace_id