Coverage for mcpgateway / middleware / db_query_logging.py: 98%

183 statements  

coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

# -*- coding: utf-8 -*-
"""Database query logging middleware for N+1 detection.

Copyright 2025
SPDX-License-Identifier: Apache-2.0

This middleware logs all database queries per request to help identify
N+1 query patterns and other performance issues.

Enable with:
    DB_QUERY_LOG_ENABLED=true

Output files:
    - logs/db-queries.log (human-readable text)
    - logs/db-queries.jsonl (JSON Lines for tooling)
"""

# Standard
from contextvars import ContextVar
from datetime import datetime, timezone
import logging
from pathlib import Path
import re
import threading
import time
from typing import Any, Dict, List, Optional, Pattern

# Third-Party
import orjson
from sqlalchemy import event
from sqlalchemy.engine import Engine
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

# First-Party
from mcpgateway.config import get_settings
from mcpgateway.middleware.path_filter import should_skip_db_query_logging

logger = logging.getLogger(__name__)

# ============================================================================
# Precompiled regex patterns for query normalization (compiled once at module load)
# ============================================================================
_QUOTED_STRING_RE: Pattern[str] = re.compile(r"'[^']*'")
_NUMBER_RE: Pattern[str] = re.compile(r"\b\d+\b")
_IN_CLAUSE_RE: Pattern[str] = re.compile(r"IN\s*\([^)]+\)", re.IGNORECASE)
_WHITESPACE_RE: Pattern[str] = re.compile(r"\s+")
_TABLE_NAME_RE: Pattern[str] = re.compile(r"(?:FROM|INTO|UPDATE)\s+[\"']?(\w+)[\"']?", re.IGNORECASE)

# Context variable to track queries per request
_request_context: ContextVar[Optional[Dict[str, Any]]] = ContextVar("db_query_request_context", default=None)

# Lock for thread-safe file writing
_file_lock = threading.Lock()

# Engines we've already instrumented (by id), so listeners are registered at
# most once per engine
_instrumented_engines: set = set()


def _normalize_query(sql: str) -> str:
    """Normalize a SQL query for pattern detection.

    Replaces specific values with placeholders so that structurally similar
    queries collapse to the same string. Uses precompiled regex patterns for
    performance.

    Args:
        sql: The SQL query string

    Returns:
        Normalized query string
    """
    # Replace quoted string literals
    normalized = _QUOTED_STRING_RE.sub("'?'", sql)
    # Replace numeric literals
    normalized = _NUMBER_RE.sub("?", normalized)
    # Collapse IN clauses with multiple values
    normalized = _IN_CLAUSE_RE.sub("IN (?)", normalized)
    # Normalize whitespace
    normalized = _WHITESPACE_RE.sub(" ", normalized).strip()
    return normalized
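
# Quick illustration (hedged; not part of the module's test suite): queries
# that differ only in literal values collapse to one pattern.
#
#   >>> _normalize_query("SELECT * FROM tools WHERE id = 42")
#   'SELECT * FROM tools WHERE id = ?'
#   >>> _normalize_query("SELECT * FROM tools WHERE name = 'search'")
#   "SELECT * FROM tools WHERE name = '?'"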


def _extract_table_name(sql: str) -> Optional[str]:
    """Extract the main table name from a SQL query.

    Uses a precompiled regex pattern for performance.

    Args:
        sql: The SQL query string

    Returns:
        Table name or None
    """
    # Match FROM <table>, INTO <table>, or UPDATE <table>
    match = _TABLE_NAME_RE.search(sql)
    if match:
        return match.group(1)
    return None
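
# Illustration (hedged): only the first FROM/INTO/UPDATE target is captured,
# so JOINed tables are not reported.
#
#   >>> _extract_table_name("SELECT id FROM servers JOIN tools ON ...")
#   'servers'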


def _detect_n1_patterns(queries: List[Dict[str, Any]], threshold: int = 3) -> List[Dict[str, Any]]:
    """Detect potential N+1 query patterns.

    Args:
        queries: List of query dictionaries with 'sql' key
        threshold: Minimum repetitions to flag as N+1

    Returns:
        List of detected N+1 patterns, sorted by repetition count (descending)
    """
    patterns: Dict[str, List[int]] = {}

    for idx, q in enumerate(queries):
        normalized = _normalize_query(q.get("sql", ""))
        if normalized not in patterns:
            patterns[normalized] = []
        patterns[normalized].append(idx)

    n1_issues = []
    for pattern, indices in patterns.items():
        if len(indices) >= threshold:
            table = _extract_table_name(pattern)
            n1_issues.append(
                {
                    "pattern": pattern[:200],  # Truncate long patterns
                    "count": len(indices),
                    "table": table,
                    "query_indices": indices,
                }
            )

    return sorted(n1_issues, key=lambda x: x["count"], reverse=True)
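
# Example (hedged sketch): three queries differing only in the id literal are
# flagged once the default threshold of 3 is reached.
#
#   >>> qs = [{"sql": f"SELECT * FROM tools WHERE id = {i}"} for i in (1, 2, 3)]
#   >>> issues = _detect_n1_patterns(qs)
#   >>> issues[0]["count"], issues[0]["table"]
#   (3, 'tools')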


def _format_text_log(request_data: Dict[str, Any], queries: List[Dict[str, Any]], n1_issues: List[Dict[str, Any]]) -> str:
    """Format request and queries as human-readable text.

    Args:
        request_data: Request metadata
        queries: List of executed queries
        n1_issues: Detected N+1 patterns

    Returns:
        Formatted text string
    """
    lines = []
    separator = "=" * 80

    # Header
    lines.append(separator)
    timestamp = request_data.get("timestamp", datetime.now(timezone.utc).isoformat())
    method = request_data.get("method", "?")
    path = request_data.get("path", "?")
    lines.append(f"[{timestamp}] {method} {path}")

    # Metadata line
    meta_parts = []
    if request_data.get("user"):
        meta_parts.append(f"User: {request_data['user']}")
    if request_data.get("correlation_id"):
        meta_parts.append(f"Correlation-ID: {request_data['correlation_id']}")
    meta_parts.append(f"Queries: {len(queries)}")
    total_ms = sum(q.get("duration_ms", 0) for q in queries)
    meta_parts.append(f"Total: {total_ms:.1f}ms")
    lines.append(" | ".join(meta_parts))
    lines.append(separator)

    # N+1 warnings at top if detected
    if n1_issues:
        lines.append("")
        lines.append("⚠️ POTENTIAL N+1 QUERIES DETECTED:")
        for issue in n1_issues:
            table_info = f" on '{issue['table']}'" if issue.get("table") else ""
            lines.append(f"{issue['count']}x similar queries{table_info}")
            lines.append(f" Pattern: {issue['pattern'][:100]}...")
        lines.append("")

    # Query list
    for idx, q in enumerate(queries, 1):
        duration = q.get("duration_ms", 0)
        sql = q.get("sql", "")

        # Check if this query is part of an N+1 pattern
        n1_marker = ""
        for issue in n1_issues:
            if idx - 1 in issue.get("query_indices", []):
                n1_marker = " ← N+1"
                break

        # Truncate long queries
        if len(sql) > 200:
            sql = sql[:200] + "..."

        lines.append(f" {idx:3}. [{duration:6.1f}ms] {sql}{n1_marker}")

    # Footer
    lines.append("-" * 80)
    if n1_issues:
        lines.append(f"⚠️ {len(n1_issues)} potential N+1 pattern(s) detected - see docs/docs/development/db-performance.md")
    lines.append(f"Total: {len(queries)} queries, {total_ms:.1f}ms")
    lines.append(separator)
    lines.append("")

    return "\n".join(lines)
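
# For a flagged request the text log looks roughly like this (illustrative
# rendering, not captured from a real run; separators are 80 characters wide):
#
#   ================================================================================
#   [2026-02-11T07:10:00+00:00] GET /tools
#   Queries: 4 | Total: 12.3ms
#   ================================================================================
#
#   ⚠️ POTENTIAL N+1 QUERIES DETECTED:
#   3x similar queries on 'tools'
#    Pattern: SELECT * FROM tools WHERE id = ?...
#
#     1. [   3.1ms] SELECT * FROM servers WHERE enabled = 1
#     2. [   2.9ms] SELECT * FROM tools WHERE id = 1 ← N+1
#   ...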


def _format_json_log(request_data: Dict[str, Any], queries: List[Dict[str, Any]], n1_issues: List[Dict[str, Any]]) -> str:
    """Format request and queries as JSON.

    Args:
        request_data: Request metadata
        queries: List of executed queries
        n1_issues: Detected N+1 patterns

    Returns:
        JSON string (single line)
    """
    total_ms = sum(q.get("duration_ms", 0) for q in queries)

    log_entry = {
        "timestamp": request_data.get("timestamp", datetime.now(timezone.utc).isoformat()),
        "method": request_data.get("method"),
        "path": request_data.get("path"),
        "user": request_data.get("user"),
        "correlation_id": request_data.get("correlation_id"),
        "status_code": request_data.get("status_code"),
        "query_count": len(queries),
        "total_query_ms": round(total_ms, 2),
        "request_duration_ms": request_data.get("request_duration_ms"),
        "n1_issues": n1_issues if n1_issues else None,
        "queries": [
            {
                "sql": q.get("sql", "")[:500],  # Truncate long queries
                "duration_ms": round(q.get("duration_ms", 0), 2),
                "table": _extract_table_name(q.get("sql", "")),
            }
            for q in queries
        ],
    }

    return orjson.dumps(log_entry, default=str).decode()
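
# Each call produces one JSON Lines record, so the JSONL file can be filtered
# with standard tooling. For example, to keep only requests with N+1 findings
# (illustrative one-liner):
#
#   jq 'select(.n1_issues != null)' logs/db-queries.jsonl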


def _write_logs(request_data: Dict[str, Any], queries: List[Dict[str, Any]]) -> None:
    """Write query logs to file(s).

    Args:
        request_data: Request metadata
        queries: List of executed queries
    """
    settings = get_settings()

    # Skip if no queries or below threshold
    if not queries or len(queries) < settings.db_query_log_min_queries:
        return

    # Detect N+1 patterns
    n1_issues = []
    if settings.db_query_log_detect_n1:
        n1_issues = _detect_n1_patterns(queries, settings.db_query_log_n1_threshold)

    log_format = settings.db_query_log_format.lower()

    with _file_lock:
        # Write text log
        if log_format in ("text", "both"):
            text_path = Path(settings.db_query_log_file)
            text_path.parent.mkdir(parents=True, exist_ok=True)
            with open(text_path, "a", encoding="utf-8") as f:
                f.write(_format_text_log(request_data, queries, n1_issues))

        # Write JSON log
        if log_format in ("json", "both"):
            json_path = Path(settings.db_query_log_json_file)
            json_path.parent.mkdir(parents=True, exist_ok=True)
            with open(json_path, "a", encoding="utf-8") as f:
                f.write(_format_json_log(request_data, queries, n1_issues) + "\n")


def _before_cursor_execute(conn: Any, _cursor: Any, _statement: str, _parameters: Any, _context: Any, _executemany: bool) -> None:
    """SQLAlchemy event handler called before query execution.

    Args:
        conn: Database connection
        _cursor: Database cursor (unused, required by SQLAlchemy event signature)
        _statement: SQL statement to execute (unused, required by SQLAlchemy event signature)
        _parameters: Query parameters (unused, required by SQLAlchemy event signature)
        _context: Execution context (unused, required by SQLAlchemy event signature)
        _executemany: Whether this is an executemany call (unused, required by SQLAlchemy event signature)
    """
    ctx = _request_context.get()
    if ctx is None:
        return

    # Store the start time on the connection so the after-execute hook can
    # compute the query duration
    conn.info["_query_start_time"] = time.perf_counter()


# Tables to exclude from query logging (internal/observability tables)
_EXCLUDED_TABLES = {
    "observability_traces",
    "observability_spans",
    "observability_events",
    "observability_metrics",
    "structured_log_entries",
    "audit_logs",
    "security_events",
}


def _should_exclude_query(statement: str) -> bool:
    """Check if a query should be excluded from logging.

    Args:
        statement: SQL statement

    Returns:
        True if the query should be excluded
    """
    statement_upper = statement.upper()
    for table in _EXCLUDED_TABLES:
        if table.upper() in statement_upper:
            return True
    return False
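
# Note that this is a plain substring check, so any statement that merely
# mentions an excluded table name is skipped as well. For example:
#
#   >>> _should_exclude_query("INSERT INTO audit_logs VALUES (1)")
#   True
#   >>> _should_exclude_query("SELECT * FROM tools")
#   False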


def _after_cursor_execute(conn: Any, _cursor: Any, statement: str, parameters: Any, _context: Any, executemany: bool) -> None:
    """SQLAlchemy event handler called after query execution.

    Args:
        conn: Database connection
        _cursor: Database cursor (unused, required by SQLAlchemy event signature)
        statement: SQL statement that was executed
        parameters: Query parameters
        _context: Execution context (unused, required by SQLAlchemy event signature)
        executemany: Whether this was an executemany call
    """
    ctx = _request_context.get()
    if ctx is None:
        return

    # Skip internal observability queries
    if _should_exclude_query(statement):
        conn.info.pop("_query_start_time", None)  # Clean up
        return

    # Calculate duration
    start_time = conn.info.pop("_query_start_time", None)
    duration_ms = (time.perf_counter() - start_time) * 1000 if start_time is not None else 0

    # Get settings for parameter inclusion
    settings = get_settings()

    query_info = {
        "sql": statement,
        "duration_ms": duration_ms,
        "executemany": executemany,
    }

    if settings.db_query_log_include_params and parameters:
        # Record only the parameter count, never the actual values, to avoid
        # leaking sensitive data into log files
        query_info["param_count"] = len(parameters) if isinstance(parameters, (list, tuple, dict)) else 1

    ctx["queries"].append(query_info)


def instrument_engine_for_logging(engine: Engine) -> None:
    """Instrument a SQLAlchemy engine for query logging.

    Safe to call repeatedly; an engine is only instrumented once.

    Args:
        engine: SQLAlchemy engine to instrument
    """
    engine_id = id(engine)
    if engine_id in _instrumented_engines:
        return

    event.listen(engine, "before_cursor_execute", _before_cursor_execute)
    event.listen(engine, "after_cursor_execute", _after_cursor_execute)
    _instrumented_engines.add(engine_id)
    logger.info("Database query logging instrumentation enabled")
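
# Direct use (a hedged sketch; `setup_query_logging` below is the normal entry
# point), e.g. instrumenting an ad-hoc engine in a test harness:
#
#   from sqlalchemy import create_engine
#
#   engine = create_engine("sqlite://")  # illustrative in-memory engine
#   instrument_engine_for_logging(engine)
#   instrument_engine_for_logging(engine)  # second call is a no-op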


class DBQueryLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware to log database queries per request.

    This middleware:
    1. Creates a request context to collect queries
    2. Captures request metadata (method, path, user, correlation ID)
    3. After the request, writes all queries to log file(s)
    4. Detects and flags potential N+1 query patterns
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Process request and log database queries.

        Args:
            request: The incoming request
            call_next: Next middleware/handler

        Returns:
            Response from the handler
        """
        settings = get_settings()

        if not settings.db_query_log_enabled:
            return await call_next(request)

        # Skip static files and health checks
        path = request.url.path
        if should_skip_db_query_logging(path):
            return await call_next(request)

        # Create request context
        ctx: Dict[str, Any] = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "method": request.method,
            "path": path,
            "user": None,
            "correlation_id": request.headers.get(settings.correlation_id_header),
            "queries": [],
        }

        # Try to get user from request state (set by auth middleware)
        if hasattr(request.state, "user"):
            ctx["user"] = getattr(request.state.user, "username", str(request.state.user))
        elif hasattr(request.state, "username"):
            ctx["user"] = request.state.username

        # Set context for SQLAlchemy event handlers
        token = _request_context.set(ctx)

        try:
            start_time = time.perf_counter()
            response = await call_next(request)
            request_duration = (time.perf_counter() - start_time) * 1000

            ctx["status_code"] = response.status_code
            ctx["request_duration_ms"] = round(request_duration, 2)

            return response
        finally:
            # Write logs even when the handler raised
            try:
                _write_logs(ctx, ctx["queries"])
            except Exception as e:
                logger.warning(f"Failed to write query log: {e}")

            # Reset context so queries from other requests are not attributed here
            _request_context.reset(token)
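
# Because attribution happens through a ContextVar, concurrent requests do not
# see each other's queries as long as DB work runs on the request's own task.
# A minimal smoke test (hedged sketch; assumes an app wired up via
# `setup_query_logging` below and the logging settings enabled):
#
#   from starlette.testclient import TestClient
#
#   client = TestClient(app)
#   client.get("/tools")  # appends one block to logs/db-queries.log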


def setup_query_logging(app: Any, engine: Engine) -> None:
    """Set up database query logging for an application.

    Args:
        app: FastAPI application
        engine: SQLAlchemy engine
    """
    settings = get_settings()

    if not settings.db_query_log_enabled:
        return

    # Instrument the engine
    instrument_engine_for_logging(engine)

    # Add middleware
    app.add_middleware(DBQueryLoggingMiddleware)

    logger.info(f"Database query logging enabled: format={settings.db_query_log_format}, text_file={settings.db_query_log_file}, json_file={settings.db_query_log_json_file}")
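
# Typical wiring at application startup (a hedged sketch; the engine import
# path here is a hypothetical placeholder, not necessarily the gateway's
# actual module):
#
#   from fastapi import FastAPI
#
#   from mcpgateway.db import engine  # hypothetical import path
#
#   app = FastAPI()
#   setup_query_logging(app, engine)  # no-op unless DB_QUERY_LOG_ENABLED=true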