Coverage for mcpgateway / middleware / db_query_logging.py: 98%
183 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Database query logging middleware for N+1 detection.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
7This middleware logs all database queries per request to help identify
8N+1 query patterns and other performance issues.
10Enable with:
11 DB_QUERY_LOG_ENABLED=true
13Output files:
14 - logs/db-queries.log (human-readable text)
15 - logs/db-queries.jsonl (JSON Lines for tooling)
16"""
18# Standard
19from contextvars import ContextVar
20from datetime import datetime, timezone
21import logging
22from pathlib import Path
23import re
24import threading
25import time
26from typing import Any, Dict, List, Optional, Pattern
28# Third-Party
29import orjson
30from sqlalchemy import event
31from sqlalchemy.engine import Engine
32from starlette.middleware.base import BaseHTTPMiddleware
33from starlette.requests import Request
34from starlette.responses import Response
36# First-Party
37from mcpgateway.config import get_settings
38from mcpgateway.middleware.path_filter import should_skip_db_query_logging
logger = logging.getLogger(__name__)

# ============================================================================
# Precompiled regex patterns for query normalization (compiled once at module load)
# ============================================================================
# Single-quoted SQL string literals, e.g. 'alice'.
_QUOTED_STRING_RE: Pattern[str] = re.compile(r"'[^']*'")
# Bare integer literals.
_NUMBER_RE: Pattern[str] = re.compile(r"\b\d+\b")
# IN (...) value lists, so lists of different lengths normalize identically.
_IN_CLAUSE_RE: Pattern[str] = re.compile(r"IN\s*\([^)]+\)", re.IGNORECASE)
# Any run of whitespace.
_WHITESPACE_RE: Pattern[str] = re.compile(r"\s+")
# First table name following FROM / INTO / UPDATE.
_TABLE_NAME_RE: Pattern[str] = re.compile(r"(?:FROM|INTO|UPDATE)\s+[\"']?(\w+)[\"']?", re.IGNORECASE)

# Per-request query accumulator, shared with the SQLAlchemy event handlers.
_request_context: ContextVar[Optional[Dict[str, Any]]] = ContextVar("db_query_request_context", default=None)

# Serializes log-file appends across threads.
_file_lock = threading.Lock()

# ids of engines that already have event listeners attached (prevents
# double-instrumentation of the same engine object).
_instrumented_engines: set = set()


def _normalize_query(sql: str) -> str:
    """Normalize a SQL query for pattern detection.

    Replaces literal values (quoted strings, numbers, IN lists) with
    placeholders and collapses whitespace so structurally identical
    queries compare equal.

    Args:
        sql: The SQL query string

    Returns:
        Normalized query string
    """
    # Order matters: literals are replaced before IN lists are collapsed,
    # and whitespace is normalized last.
    for pattern, placeholder in (
        (_QUOTED_STRING_RE, "'?'"),
        (_NUMBER_RE, "?"),
        (_IN_CLAUSE_RE, "IN (?)"),
    ):
        sql = pattern.sub(placeholder, sql)
    return _WHITESPACE_RE.sub(" ", sql).strip()


def _extract_table_name(sql: str) -> Optional[str]:
    """Extract the main table name from a SQL query.

    Args:
        sql: The SQL query string

    Returns:
        The first table referenced via FROM/INTO/UPDATE, or None
    """
    found = _TABLE_NAME_RE.search(sql)
    return found.group(1) if found else None


def _detect_n1_patterns(queries: List[Dict[str, Any]], threshold: int = 3) -> List[Dict[str, Any]]:
    """Detect potential N+1 query patterns.

    Groups queries by their normalized form and flags every group that is
    repeated at least ``threshold`` times.

    Args:
        queries: List of query dictionaries with 'sql' key
        threshold: Minimum repetitions to flag as N+1

    Returns:
        Detected patterns, most frequent first; each entry carries the
        (truncated) pattern, repetition count, table name, and the indices
        of the matching queries.
    """
    by_shape: Dict[str, List[int]] = {}
    for position, query in enumerate(queries):
        shape = _normalize_query(query.get("sql", ""))
        by_shape.setdefault(shape, []).append(position)

    flagged = [
        {
            "pattern": shape[:200],  # Truncate long patterns
            "count": len(positions),
            "table": _extract_table_name(shape),
            "query_indices": positions,
        }
        for shape, positions in by_shape.items()
        if len(positions) >= threshold
    ]
    # Stable sort keeps insertion order among equal counts, like the original.
    flagged.sort(key=lambda issue: issue["count"], reverse=True)
    return flagged
def _format_text_log(request_data: Dict[str, Any], queries: List[Dict[str, Any]], n1_issues: List[Dict[str, Any]]) -> str:
    """Format request and queries as human-readable text.

    Args:
        request_data: Request metadata
        queries: List of executed queries
        n1_issues: Detected N+1 patterns

    Returns:
        Formatted multi-line text block, ending with a trailing newline
    """
    separator = "=" * 80
    out: List[str] = [separator]

    # Header: "[timestamp] METHOD /path"
    stamp = request_data.get("timestamp", datetime.now(timezone.utc).isoformat())
    out.append(f"[{stamp}] {request_data.get('method', '?')} {request_data.get('path', '?')}")

    # Metadata line: optional user / correlation id, then counts
    meta: List[str] = []
    if request_data.get("user"):
        meta.append(f"User: {request_data['user']}")
    if request_data.get("correlation_id"):
        meta.append(f"Correlation-ID: {request_data['correlation_id']}")
    meta.append(f"Queries: {len(queries)}")
    total_ms = sum(q.get("duration_ms", 0) for q in queries)
    meta.append(f"Total: {total_ms:.1f}ms")
    out.append(" | ".join(meta))
    out.append(separator)

    # N+1 warnings at the top, if any were detected
    if n1_issues:
        out.append("")
        out.append("⚠️ POTENTIAL N+1 QUERIES DETECTED:")
        for issue in n1_issues:
            table_info = f" on '{issue['table']}'" if issue.get("table") else ""
            out.append(f"  • {issue['count']}x similar queries{table_info}")
            out.append(f"    Pattern: {issue['pattern'][:100]}...")
        out.append("")

    # One line per query, flagging members of an N+1 group
    for number, query in enumerate(queries, 1):
        duration = query.get("duration_ms", 0)
        sql = query.get("sql", "")
        n1_marker = next(
            (" ← N+1" for issue in n1_issues if number - 1 in issue.get("query_indices", [])),
            "",
        )
        if len(sql) > 200:
            sql = sql[:200] + "..."  # Truncate long queries
        out.append(f"  {number:3}. [{duration:6.1f}ms] {sql}{n1_marker}")

    # Footer / summary
    out.append("-" * 80)
    if n1_issues:
        out.append(f"⚠️ {len(n1_issues)} potential N+1 pattern(s) detected - see docs/docs/development/db-performance.md")
    out.append(f"Total: {len(queries)} queries, {total_ms:.1f}ms")
    out.append(separator)
    out.append("")

    return "\n".join(out)
def _format_json_log(request_data: Dict[str, Any], queries: List[Dict[str, Any]], n1_issues: List[Dict[str, Any]]) -> str:
    """Format request and queries as a single-line JSON record (JSON Lines).

    Args:
        request_data: Request metadata
        queries: List of executed queries
        n1_issues: Detected N+1 patterns

    Returns:
        JSON string (single line)
    """
    total_ms = sum(entry.get("duration_ms", 0) for entry in queries)

    serialized_queries = []
    for entry in queries:
        sql_text = entry.get("sql", "")
        serialized_queries.append(
            {
                "sql": sql_text[:500],  # Truncate long queries
                "duration_ms": round(entry.get("duration_ms", 0), 2),
                "table": _extract_table_name(sql_text),
            }
        )

    record = {
        "timestamp": request_data.get("timestamp", datetime.now(timezone.utc).isoformat()),
        "method": request_data.get("method"),
        "path": request_data.get("path"),
        "user": request_data.get("user"),
        "correlation_id": request_data.get("correlation_id"),
        "status_code": request_data.get("status_code"),
        "query_count": len(queries),
        "total_query_ms": round(total_ms, 2),
        "request_duration_ms": request_data.get("request_duration_ms"),
        "n1_issues": n1_issues or None,  # omit as null when nothing detected
        "queries": serialized_queries,
    }

    # default=str keeps any non-JSON-native values serializable.
    return orjson.dumps(record, default=str).decode()
def _append_to_log(path: Path, payload: str) -> None:
    """Append *payload* to *path*, creating parent directories as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "a", encoding="utf-8") as handle:
        handle.write(payload)


def _write_logs(request_data: Dict[str, Any], queries: List[Dict[str, Any]]) -> None:
    """Write query logs to file(s).

    No-op when there are no queries or fewer than the configured minimum.

    Args:
        request_data: Request metadata
        queries: List of executed queries
    """
    settings = get_settings()

    # Skip if no queries or below threshold
    if not queries or len(queries) < settings.db_query_log_min_queries:
        return

    # Detect N+1 patterns (optional, per settings)
    n1_issues: List[Dict[str, Any]] = []
    if settings.db_query_log_detect_n1:
        n1_issues = _detect_n1_patterns(queries, settings.db_query_log_n1_threshold)

    destination = settings.db_query_log_format.lower()

    # Hold the lock across both writes so concurrent requests don't interleave.
    with _file_lock:
        if destination in ("text", "both"):
            _append_to_log(Path(settings.db_query_log_file), _format_text_log(request_data, queries, n1_issues))
        if destination in ("json", "both"):
            _append_to_log(Path(settings.db_query_log_json_file), _format_json_log(request_data, queries, n1_issues) + "\n")
def _before_cursor_execute(conn: Any, _cursor: Any, _statement: str, _parameters: Any, _context: Any, _executemany: bool) -> None:
    """SQLAlchemy ``before_cursor_execute`` hook: stamp the query start time.

    Does nothing unless a request context is active (i.e. the middleware is
    tracking the current request).

    Args:
        conn: Database connection
        _cursor: Database cursor (unused, required by SQLAlchemy event signature)
        _statement: SQL statement to execute (unused, required by SQLAlchemy event signature)
        _parameters: Query parameters (unused, required by SQLAlchemy event signature)
        _context: Execution context (unused, required by SQLAlchemy event signature)
        _executemany: Whether this is an executemany call (unused, required by SQLAlchemy event signature)
    """
    if _request_context.get() is None:
        return
    # Stash the start time on the connection; the 'after' hook pops it.
    conn.info["_query_start_time"] = time.perf_counter()
# Tables to exclude from query logging (internal/observability tables)
_EXCLUDED_TABLES = {
    "observability_traces",
    "observability_spans",
    "observability_events",
    "observability_metrics",
    "structured_log_entries",
    "audit_logs",
    "security_events",
}


def _should_exclude_query(statement: str) -> bool:
    """Check if a query should be excluded from logging.

    The match is a case-insensitive substring test against the internal
    table names above, so logging the logger's own queries is avoided.

    Args:
        statement: SQL statement

    Returns:
        True if the query should be excluded
    """
    haystack = statement.upper()
    return any(table.upper() in haystack for table in _EXCLUDED_TABLES)
def _after_cursor_execute(conn: Any, _cursor: Any, statement: str, parameters: Any, _context: Any, executemany: bool) -> None:
    """SQLAlchemy event handler called after query execution.

    Computes the query duration from the timestamp stashed by
    ``_before_cursor_execute`` and appends the query details to the active
    per-request context. Queries against internal observability tables are
    skipped entirely.

    Args:
        conn: Database connection
        _cursor: Database cursor (unused, required by SQLAlchemy event signature)
        statement: SQL statement that was executed
        parameters: Query parameters
        _context: Execution context (unused, required by SQLAlchemy event signature)
        executemany: Whether this was an executemany call
    """
    ctx = _request_context.get()
    if ctx is None:
        return

    # Skip internal observability queries
    if _should_exclude_query(statement):
        conn.info.pop("_query_start_time", None)  # Clean up stale timestamp
        return

    # Calculate duration. Compare against None explicitly: perf_counter() is an
    # arbitrary-epoch float, so 0.0 is a legal reading and a plain truthiness
    # test could silently discard a valid start time.
    start_time = conn.info.pop("_query_start_time", None)
    duration_ms = (time.perf_counter() - start_time) * 1000 if start_time is not None else 0

    # Get settings for parameter inclusion
    settings = get_settings()

    query_info = {
        "sql": statement,
        "duration_ms": duration_ms,
        "executemany": executemany,
    }

    if settings.db_query_log_include_params and parameters:
        # Sanitize parameters - don't include actual values by default
        query_info["param_count"] = len(parameters) if isinstance(parameters, (list, tuple, dict)) else 1

    ctx["queries"].append(query_info)
def instrument_engine_for_logging(engine: Engine) -> None:
    """Attach query-timing event listeners to an engine (idempotent).

    A given engine object is only instrumented once, keyed by ``id()``.

    Args:
        engine: SQLAlchemy engine to instrument
    """
    key = id(engine)
    if key in _instrumented_engines:
        return  # Already instrumented

    for hook, handler in (
        ("before_cursor_execute", _before_cursor_execute),
        ("after_cursor_execute", _after_cursor_execute),
    ):
        event.listen(engine, hook, handler)
    _instrumented_engines.add(key)
    logger.info("Database query logging instrumentation enabled")
class DBQueryLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware to log database queries per request.

    This middleware:
    1. Creates a request context to collect queries
    2. Captures request metadata (method, path, user, correlation ID)
    3. After the request, writes all queries to log file(s)
    4. Detects and flags potential N+1 query patterns

    The context is stored in a ContextVar so the SQLAlchemy event handlers
    can append to it without an explicit reference, and concurrent requests
    each see their own context.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Process request and log database queries.

        Args:
            request: The incoming request
            call_next: Next middleware/handler

        Returns:
            Response from the handler
        """
        settings = get_settings()

        # Pass through untouched when the feature is disabled.
        if not settings.db_query_log_enabled:
            return await call_next(request)

        # Skip static files and health checks
        path = request.url.path
        if should_skip_db_query_logging(path):
            return await call_next(request)

        # Create request context; the event handlers append to "queries".
        ctx: Dict[str, Any] = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "method": request.method,
            "path": path,
            "user": None,
            "correlation_id": request.headers.get(settings.correlation_id_header),
            "queries": [],
        }

        # Try to get user from request state (set by auth middleware)
        if hasattr(request.state, "user"):
            ctx["user"] = getattr(request.state.user, "username", str(request.state.user))
        elif hasattr(request.state, "username"):
            ctx["user"] = request.state.username

        # Set context for SQLAlchemy event handlers; token lets us restore
        # the previous value afterwards.
        token = _request_context.set(ctx)

        try:
            start_time = time.perf_counter()
            response = await call_next(request)
            request_duration = (time.perf_counter() - start_time) * 1000

            ctx["status_code"] = response.status_code
            ctx["request_duration_ms"] = round(request_duration, 2)

            return response
        finally:
            # Write logs even if the handler raised (status_code/duration may
            # then be absent from ctx); logging failures must never break the
            # request, so they are downgraded to a warning.
            try:
                _write_logs(ctx, ctx["queries"])
            except Exception as e:
                logger.warning(f"Failed to write query log: {e}")

            # Reset context so the ContextVar doesn't leak across requests.
            _request_context.reset(token)
def setup_query_logging(app: Any, engine: Engine) -> None:
    """Set up database query logging for an application.

    Does nothing unless ``db_query_log_enabled`` is set; otherwise attaches
    the SQLAlchemy listeners and installs the middleware.

    Args:
        app: FastAPI application
        engine: SQLAlchemy engine
    """
    settings = get_settings()
    if not settings.db_query_log_enabled:
        return

    # Instrument the engine, then install the request-scoped middleware.
    instrument_engine_for_logging(engine)
    app.add_middleware(DBQueryLoggingMiddleware)

    logger.info(
        f"Database query logging enabled: format={settings.db_query_log_format}, "
        f"text_file={settings.db_query_log_file}, "
        f"json_file={settings.db_query_log_json_file}"
    )