Coverage for mcpgateway / services / observability_service.py: 100%

338 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/observability_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Observability Service Implementation. 

8This module provides OpenTelemetry-style observability for MCP Gateway, 

9capturing traces, spans, events, and metrics for all operations. 

10 

11It includes: 

12- Trace creation and management 

13- Span tracking with hierarchical nesting 

14- Event logging within spans 

15- Metrics collection and storage 

16- Query and filtering capabilities 

17- Integration with FastAPI middleware 

18 

19Examples: 

20 >>> from mcpgateway.services.observability_service import ObservabilityService # doctest: +SKIP 

21 >>> service = ObservabilityService() # doctest: +SKIP 

22 >>> trace_id = service.start_trace(db, "GET /tools", http_method="GET", http_url="/tools") # doctest: +SKIP 

23 >>> span_id = service.start_span(db, trace_id, "database_query", resource_type="database") # doctest: +SKIP 

24 >>> service.end_span(db, span_id, status="ok") # doctest: +SKIP 

25 >>> service.end_trace(db, trace_id, status="ok", http_status_code=200) # doctest: +SKIP 

26""" 

27 

28# Standard 

29from contextlib import contextmanager 

30from contextvars import ContextVar 

31from datetime import datetime, timezone 

32import logging 

33import re 

34import traceback 

35from typing import Any, Dict, List, Optional, Pattern, Tuple 

36import uuid 

37 

38# Third-Party 

39from sqlalchemy import desc 

40from sqlalchemy.exc import SQLAlchemyError 

41from sqlalchemy.orm import joinedload, Session 

42 

43# First-Party 

44from mcpgateway.db import ObservabilityEvent, ObservabilityMetric, ObservabilitySpan, ObservabilityTrace 

45 

46logger = logging.getLogger(__name__) 

47 

48# Precompiled regex for W3C Trace Context traceparent header parsing 

49# Format: version-trace_id-parent_id-trace_flags 

50_TRACEPARENT_RE: Pattern[str] = re.compile(r"^([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})$") 

51 

52# Context variable for tracking the current trace_id across async calls 

53current_trace_id: ContextVar[Optional[str]] = ContextVar("current_trace_id", default=None) 

54 

55 

def utc_now() -> datetime:
    """Get the present moment as a timezone-aware UTC datetime.

    Returns:
        datetime: The current time with ``timezone.utc`` attached as tzinfo
    """
    now = datetime.now(tz=timezone.utc)
    return now

63 

64 

def ensure_timezone_aware(dt: datetime) -> datetime:
    """Attach UTC to a naive datetime; leave aware datetimes untouched.

    SQLite returns naive datetimes even when stored with timezone info.
    Normalising them here keeps datetime arithmetic against aware values
    from failing.

    Args:
        dt: Datetime that may be naive or aware

    Returns:
        The same datetime if it already carries tzinfo, otherwise a copy
        tagged with UTC (the wall-clock fields are not shifted)
    """
    # Aware values pass straight through; only naive ones get UTC attached.
    return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)

80 

81 

def parse_traceparent(traceparent: str) -> Optional[Tuple[str, str, str]]:
    """Split a W3C Trace Context traceparent header into its parts.

    Format: version-trace_id-parent_id-trace_flags
    Example: 00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01

    The header is lower-cased before matching, so upper-case hex is accepted.

    Args:
        traceparent: W3C traceparent header value

    Returns:
        Tuple of (trace_id, parent_span_id, trace_flags), or None when the
        header is malformed, uses a version other than "00", or carries an
        all-zero trace_id / parent_id

    Examples:
        >>> parse_traceparent("00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01")  # doctest: +SKIP
        ('0af7651916cd43dd8448eb211c80319c', 'b7ad6b7169203331', '01')
    """
    # The precompiled module-level pattern enforces 2-32-16-2 hex digit groups.
    parsed = _TRACEPARENT_RE.match(traceparent.lower())
    if not parsed:
        logger.warning(f"Invalid traceparent format: {traceparent}")
        return None

    version = parsed.group(1)
    # Only version 00 is supported for now.
    if version != "00":
        logger.warning(f"Unsupported traceparent version: {version}")
        return None

    trace_id, parent_id, flags = parsed.group(2, 3, 4)

    # The regex fixes the lengths, so "only zeros" means the all-zero ID.
    if set(trace_id) == {"0"} or set(parent_id) == {"0"}:
        logger.warning("Invalid traceparent with zero trace_id or parent_id")
        return None

    return (trace_id, parent_id, flags)

119 

120 

def generate_w3c_trace_id() -> str:
    """Generate a W3C compliant trace ID (32 hex characters).

    A UUID4 rendered with ``.hex`` is already exactly 32 lowercase hex
    characters (128 bits), which is the size of a traceparent trace-id.
    The previous implementation appended 16 more characters and produced a
    48-character value that does not fit the traceparent format.

    Returns:
        32-character lowercase hex string

    Examples:
        >>> trace_id = generate_w3c_trace_id()  # doctest: +SKIP
        >>> len(trace_id)  # doctest: +SKIP
        32
    """
    return uuid.uuid4().hex

133 

134 

def generate_w3c_span_id() -> str:
    """Produce a random span ID sized for W3C Trace Context (16 hex chars).

    Returns:
        16-character lowercase hex string

    Examples:
        >>> span_id = generate_w3c_span_id()  # doctest: +SKIP
        >>> len(span_id)  # doctest: +SKIP
        16
    """
    # A UUID4 gives 32 hex digits; the leading half is used as the span ID.
    random_hex = uuid.uuid4().hex
    return random_hex[0:16]

147 

148 

def format_traceparent(trace_id: str, span_id: str, sampled: bool = True) -> str:
    """Build a W3C traceparent header value (always version ``00``).

    Args:
        trace_id: 32-character hex trace ID
        span_id: 16-character hex span ID
        sampled: Whether the trace is sampled; selects trace-flags "01" or "00"

    Returns:
        W3C traceparent header value

    Examples:
        >>> format_traceparent("0af7651916cd43dd8448eb211c80319c", "b7ad6b7169203331")  # doctest: +SKIP
        '00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01'
    """
    if sampled:
        trace_flags = "01"
    else:
        trace_flags = "00"
    return "00-{}-{}-{}".format(trace_id, span_id, trace_flags)

166 

167 

168class ObservabilityService: 

169 """Service for managing observability traces, spans, events, and metrics. 

170 

171 This service provides comprehensive observability capabilities similar to 

172 OpenTelemetry, allowing tracking of request flows through the system. 

173 

174 Examples: 

175 >>> service = ObservabilityService() # doctest: +SKIP 

176 >>> trace_id = service.start_trace(db, "POST /tools/invoke") # doctest: +SKIP 

177 >>> span_id = service.start_span(db, trace_id, "tool_execution") # doctest: +SKIP 

178 >>> service.end_span(db, span_id, status="ok") # doctest: +SKIP 

179 >>> service.end_trace(db, trace_id, status="ok") # doctest: +SKIP 

180 """ 

181 

def _safe_commit(self, db: Session, context: str) -> bool:
    """Try to commit; on a SQLAlchemy error, roll back and report failure.

    Neither the commit failure nor a subsequent rollback failure is raised
    to the caller; both are only logged.

    Args:
        db: SQLAlchemy session for the current operation.
        context: Short label for the commit context (used in logs).

    Returns:
        True when commit succeeds, False when the commit failed and a
        rollback was attempted.
    """
    try:
        db.commit()
    except SQLAlchemyError as exc:
        logger.warning(f"Observability commit failed ({context}): {exc}")
        try:
            db.rollback()
        except SQLAlchemyError as rollback_exc:
            # A failed rollback is only logged at debug level.
            logger.debug(f"Observability rollback failed ({context}): {rollback_exc}")
        return False
    else:
        return True

202 

203 # ============================== 

204 # Trace Management 

205 # ============================== 

206 

def start_trace(
    self,
    db: Session,
    name: str,
    trace_id: Optional[str] = None,
    parent_span_id: Optional[str] = None,
    http_method: Optional[str] = None,
    http_url: Optional[str] = None,
    user_email: Optional[str] = None,
    user_agent: Optional[str] = None,
    ip_address: Optional[str] = None,
    attributes: Optional[Dict[str, Any]] = None,
    resource_attributes: Optional[Dict[str, Any]] = None,
) -> str:
    """Start a new trace.

    The trace row is added and committed through ``_safe_commit``; a failed
    commit is logged there and the trace ID is still returned.

    Args:
        db: Database session
        name: Trace name (e.g., "POST /tools/invoke")
        trace_id: External trace ID (for distributed tracing, W3C format)
        parent_span_id: Parent span ID from upstream service
        http_method: HTTP method (GET, POST, etc.)
        http_url: Full request URL
        user_email: Authenticated user email
        user_agent: Client user agent string
        ip_address: Client IP address
        attributes: Additional trace attributes (not modified by this call)
        resource_attributes: Resource attributes (service name, version, etc.)

    Returns:
        Trace ID (UUID string or W3C format)

    Examples:
        >>> trace_id = service.start_trace(  # doctest: +SKIP
        ...     db,
        ...     "POST /tools/invoke",
        ...     http_method="POST",
        ...     http_url="https://api.example.com/tools/invoke",
        ...     user_email="user@example.com"
        ... )
    """
    # Use provided trace_id or generate new UUID
    if not trace_id:
        trace_id = str(uuid.uuid4())

    # Work on a copy so that adding parent_span_id below never leaks back
    # into the dict object owned by the caller.
    attrs: Dict[str, Any] = dict(attributes) if attributes else {}
    if parent_span_id:
        attrs["parent_span_id"] = parent_span_id

    trace = ObservabilityTrace(
        trace_id=trace_id,
        name=name,
        start_time=utc_now(),
        status="unset",
        http_method=http_method,
        http_url=http_url,
        user_email=user_email,
        user_agent=user_agent,
        ip_address=ip_address,
        attributes=attrs,
        resource_attributes=resource_attributes or {},
        created_at=utc_now(),
    )
    db.add(trace)
    self._safe_commit(db, "start_trace")
    logger.debug(f"Started trace {trace_id}: {name}")
    return trace_id

275 

def end_trace(
    self,
    db: Session,
    trace_id: str,
    status: str = "ok",
    status_message: Optional[str] = None,
    http_status_code: Optional[int] = None,
    attributes: Optional[Dict[str, Any]] = None,
) -> None:
    """Close a trace: stamp end time, duration, status and extra attributes.

    If no trace with the given ID exists, a warning is logged and nothing
    is changed.

    Args:
        db: Database session
        trace_id: Trace ID to end
        status: Trace status (ok, error)
        status_message: Optional status message
        http_status_code: HTTP response status code (left untouched when None)
        attributes: Additional attributes to merge over the stored ones

    Examples:
        >>> service.end_trace(  # doctest: +SKIP
        ...     db,
        ...     trace_id,
        ...     status="ok",
        ...     http_status_code=200
        ... )
    """
    trace = db.query(ObservabilityTrace).filter_by(trace_id=trace_id).first()
    if not trace:
        logger.warning(f"Trace {trace_id} not found")
        return

    end_time = utc_now()
    # The stored start time may come back naive; normalise before subtracting.
    started_at = ensure_timezone_aware(trace.start_time)
    elapsed = end_time - started_at
    duration_ms = elapsed.total_seconds() * 1000

    trace.end_time = end_time
    trace.duration_ms = duration_ms
    trace.status = status
    trace.status_message = status_message
    if http_status_code is not None:
        trace.http_status_code = http_status_code
    if attributes:
        # Assign a new dict; incoming keys win over stored ones.
        merged = dict(trace.attributes or {})
        merged.update(attributes)
        trace.attributes = merged

    self._safe_commit(db, "end_trace")
    logger.debug(f"Ended trace {trace_id}: {status} ({duration_ms:.2f}ms)")

322 

def get_trace(self, db: Session, trace_id: str, include_spans: bool = False) -> Optional[ObservabilityTrace]:
    """Look up a single trace by its ID.

    Args:
        db: Database session
        trace_id: Trace ID
        include_spans: Whether to load spans eagerly (via ``joinedload``)

    Returns:
        Trace object or None if not found

    Examples:
        >>> trace = service.get_trace(db, trace_id, include_spans=True)  # doctest: +SKIP
        >>> if trace:  # doctest: +SKIP
        ...     print(f"Trace: {trace.name}, Spans: {len(trace.spans)}")  # doctest: +SKIP
    """
    lookup = db.query(ObservabilityTrace).filter_by(trace_id=trace_id)
    if not include_spans:
        return lookup.first()
    return lookup.options(joinedload(ObservabilityTrace.spans)).first()

343 

344 # ============================== 

345 # Span Management 

346 # ============================== 

347 

def start_span(
    self,
    db: Session,
    trace_id: str,
    name: str,
    parent_span_id: Optional[str] = None,
    kind: str = "internal",
    resource_name: Optional[str] = None,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    attributes: Optional[Dict[str, Any]] = None,
    commit: bool = True,
) -> str:
    """Open a new span inside an existing trace.

    Args:
        db: Database session
        trace_id: Parent trace ID
        name: Span name (e.g., "database_query", "tool_invocation")
        parent_span_id: Parent span ID (for nested spans)
        kind: Span kind (internal, server, client, producer, consumer)
        resource_name: Resource name being operated on
        resource_type: Resource type (tool, resource, prompt, etc.)
        resource_id: Resource ID
        attributes: Additional span attributes
        commit: Whether to commit the transaction (default True).
            Set to False when using fresh_db_session() which handles commits.

    Returns:
        Span ID (UUID string)

    Examples:
        >>> span_id = service.start_span(  # doctest: +SKIP
        ...     db,
        ...     trace_id,
        ...     "tool_invocation",
        ...     resource_type="tool",
        ...     resource_name="get_weather"
        ... )
    """
    span_id = str(uuid.uuid4())
    span_fields = dict(
        span_id=span_id,
        trace_id=trace_id,
        parent_span_id=parent_span_id,
        name=name,
        kind=kind,
        start_time=utc_now(),
        status="unset",
        resource_name=resource_name,
        resource_type=resource_type,
        resource_id=resource_id,
        attributes=attributes or {},
        created_at=utc_now(),
    )
    db.add(ObservabilitySpan(**span_fields))
    if commit:
        self._safe_commit(db, "start_span")
    logger.debug(f"Started span {span_id}: {name} (trace={trace_id})")
    return span_id

408 

def end_span(
    self,
    db: Session,
    span_id: str,
    status: str = "ok",
    status_message: Optional[str] = None,
    attributes: Optional[Dict[str, Any]] = None,
    commit: bool = True,
) -> None:
    """Close a span: stamp end time, duration, status and extra attributes.

    If no span with the given ID exists, a warning is logged and nothing
    is changed.

    Args:
        db: Database session
        span_id: Span ID to end
        status: Span status (ok, error)
        status_message: Optional status message
        attributes: Additional attributes to merge over the stored ones
        commit: Whether to commit the transaction (default True).
            Set to False when using fresh_db_session() which handles commits.

    Examples:
        >>> service.end_span(db, span_id, status="ok")  # doctest: +SKIP
    """
    span = db.query(ObservabilitySpan).filter_by(span_id=span_id).first()
    if not span:
        logger.warning(f"Span {span_id} not found")
        return

    end_time = utc_now()
    # The stored start time may come back naive; normalise before subtracting.
    started_at = ensure_timezone_aware(span.start_time)
    elapsed = end_time - started_at
    duration_ms = elapsed.total_seconds() * 1000

    span.end_time = end_time
    span.duration_ms = duration_ms
    span.status = status
    span.status_message = status_message
    if attributes:
        # Assign a new dict; incoming keys win over stored ones.
        merged = dict(span.attributes or {})
        merged.update(attributes)
        span.attributes = merged

    if commit:
        self._safe_commit(db, "end_span")
    logger.debug(f"Ended span {span_id}: {status} ({duration_ms:.2f}ms)")

450 

@contextmanager
def trace_span(
    self,
    db: Session,
    trace_id: str,
    name: str,
    parent_span_id: Optional[str] = None,
    resource_type: Optional[str] = None,
    resource_name: Optional[str] = None,
    attributes: Optional[Dict[str, Any]] = None,
):
    """Run a block of code inside a span that is opened and closed for you.

    The span is started before the body runs. On normal exit it is ended
    with status "ok"; on an exception it is ended with status "error", an
    "exception" event with the stacktrace is attached, and the exception
    propagates.

    Args:
        db: Database session
        trace_id: Parent trace ID
        name: Span name
        parent_span_id: Parent span ID (optional)
        resource_type: Resource type
        resource_name: Resource name
        attributes: Additional attributes

    Yields:
        Span ID

    Raises:
        Exception: Re-raises any exception after logging it in the span

    Examples:
        >>> with service.trace_span(db, trace_id, "database_query") as span_id:  # doctest: +SKIP
        ...     results = db.query(Tool).all()  # doctest: +SKIP
    """
    span_id = self.start_span(
        db,
        trace_id,
        name,
        parent_span_id=parent_span_id,
        resource_type=resource_type,
        resource_name=resource_name,
        attributes=attributes,
    )
    try:
        yield span_id
        self.end_span(db, span_id, status="ok")
    except Exception as e:
        self.end_span(db, span_id, status="error", status_message=str(e))
        # Record the failure details as an event on the (now closed) span.
        self.add_event(
            db,
            span_id,
            "exception",
            severity="error",
            message=str(e),
            exception_type=type(e).__name__,
            exception_message=str(e),
            exception_stacktrace=traceback.format_exc(),
        )
        raise

491 

@contextmanager
def trace_tool_invocation(
    self,
    db: Session,
    tool_name: str,
    arguments: Dict[str, Any],
    integration_type: Optional[str] = None,
):
    """Context manager for tracing MCP tool invocations.

    Creates a client span named ``tool.invoke.<tool_name>`` under the trace
    held in ``current_trace_id``. When no trace is active nothing is
    recorded and ``(None, {})`` is yielded.

    Argument values are replaced with ``***REDACTED***`` when the argument
    name contains "password", "token", "key", "secret" or "auth"
    (case-insensitive substring match). "auth" is included so that
    authorization-style arguments are redacted the same way
    ``trace_a2a_request`` redacts request data.

    Args:
        db: Database session
        tool_name: Name of the tool being invoked
        arguments: Tool arguments (will be sanitized); None is treated as empty
        integration_type: Integration type (MCP, REST, A2A, etc.)

    Yields:
        Tuple of (span_id, result_dict) - update result_dict with tool results

    Raises:
        Exception: Re-raises any exception from tool invocation after logging

    Examples:
        >>> with service.trace_tool_invocation(db, "weather", {"city": "NYC"}) as (span_id, result):  # doctest: +SKIP
        ...     response = await http_client.post(...)  # doctest: +SKIP
        ...     result["status_code"] = response.status_code  # doctest: +SKIP
        ...     result["response_size"] = len(response.content)  # doctest: +SKIP
    """
    trace_id = current_trace_id.get()
    if not trace_id:
        # No active trace, yield a no-op
        result_dict: Dict[str, Any] = {}
        yield (None, result_dict)
        return

    # Tolerate a missing argument mapping instead of failing inside tracing.
    arguments = arguments or {}

    # Sanitize arguments (remove sensitive data)
    sensitive_markers = ("password", "token", "key", "secret", "auth")
    safe_args = {k: ("***REDACTED***" if any(marker in k.lower() for marker in sensitive_markers) else v) for k, v in arguments.items()}

    # Start tool invocation span
    span_id = self.start_span(
        db=db,
        trace_id=trace_id,
        name=f"tool.invoke.{tool_name}",
        kind="client",
        resource_type="tool",
        resource_name=tool_name,
        attributes={
            "tool.name": tool_name,
            "tool.integration_type": integration_type,
            "tool.argument_count": len(arguments),
            "tool.arguments": safe_args,
        },
    )

    result_dict = {}
    try:
        yield (span_id, result_dict)

        # End span with whatever the caller put into result_dict
        self.end_span(
            db=db,
            span_id=span_id,
            status="ok",
            attributes={
                "tool.result": result_dict,
            },
        )
    except Exception as e:
        # Log error in span
        self.end_span(db=db, span_id=span_id, status="error", status_message=str(e))

        self.add_event(
            db=db,
            span_id=span_id,
            name="tool.error",
            severity="error",
            message=str(e),
            exception_type=type(e).__name__,
            exception_message=str(e),
            exception_stacktrace=traceback.format_exc(),
        )
        raise

577 

578 # ============================== 

579 # Event Management 

580 # ============================== 

581 

def add_event(
    self,
    db: Session,
    span_id: str,
    name: str,
    severity: Optional[str] = None,
    message: Optional[str] = None,
    exception_type: Optional[str] = None,
    exception_message: Optional[str] = None,
    exception_stacktrace: Optional[str] = None,
    attributes: Optional[Dict[str, Any]] = None,
) -> int:
    """Attach a timestamped event to a span and persist it.

    Args:
        db: Database session
        span_id: Parent span ID
        name: Event name
        severity: Log severity (debug, info, warning, error, critical)
        message: Event message
        exception_type: Exception class name
        exception_message: Exception message
        exception_stacktrace: Exception stacktrace
        attributes: Additional event attributes

    Returns:
        Event ID, or 0 when the commit failed and was rolled back

    Examples:
        >>> event_id = service.add_event(  # doctest: +SKIP
        ...     db,  # doctest: +SKIP
        ...     span_id,  # doctest: +SKIP
        ...     "database_connection_error",  # doctest: +SKIP
        ...     severity="error",  # doctest: +SKIP
        ...     message="Failed to connect to database"  # doctest: +SKIP
        ... )  # doctest: +SKIP
    """
    event_fields = dict(
        span_id=span_id,
        name=name,
        timestamp=utc_now(),
        severity=severity,
        message=message,
        exception_type=exception_type,
        exception_message=exception_message,
        exception_stacktrace=exception_stacktrace,
        attributes=attributes or {},
        created_at=utc_now(),
    )
    event = ObservabilityEvent(**event_fields)
    db.add(event)

    committed = self._safe_commit(db, "add_event")
    if not committed:
        return 0

    # Reload the row after the commit before reading its id.
    db.refresh(event)
    logger.debug(f"Added event to span {span_id}: {name}")
    return event.id

637 

638 # ============================== 

639 # Token Usage Tracking 

640 # ============================== 

641 

def record_token_usage(
    self,
    db: Session,
    span_id: Optional[str] = None,
    trace_id: Optional[str] = None,
    model: Optional[str] = None,
    input_tokens: int = 0,
    output_tokens: int = 0,
    total_tokens: Optional[int] = None,
    estimated_cost_usd: Optional[float] = None,
    provider: Optional[str] = None,
) -> None:
    """Record token usage for LLM calls.

    Usage is written to the span's attributes (when ``span_id`` resolves to
    a span) and additionally emitted as ``llm.tokens.input``,
    ``llm.tokens.output`` and ``llm.cost`` counter metrics. Without an
    explicit or context trace ID a warning is logged and nothing is stored.

    Args:
        db: Database session
        span_id: Span ID to attach token usage to
        trace_id: Trace ID (will use current context if not provided)
        model: Model name (e.g., "gpt-4", "claude-3-opus")
        input_tokens: Number of input/prompt tokens
        output_tokens: Number of output/completion tokens
        total_tokens: Total tokens (calculated if not provided)
        estimated_cost_usd: Estimated cost in USD (estimated from the model
            name when omitted and a model is given)
        provider: LLM provider (openai, anthropic, etc.)

    Examples:
        >>> service.record_token_usage(  # doctest: +SKIP
        ...     db, span_id="abc123",
        ...     model="gpt-4",
        ...     input_tokens=100,
        ...     output_tokens=50,
        ...     estimated_cost_usd=0.015
        ... )
    """
    if not trace_id:
        trace_id = current_trace_id.get()

    if not trace_id:
        logger.warning("Cannot record token usage: no active trace")
        return

    # Calculate total if not provided
    if total_tokens is None:
        total_tokens = input_tokens + output_tokens

    # Estimate cost if not provided and we have model info
    if estimated_cost_usd is None and model:
        estimated_cost_usd = self._estimate_token_cost(model, input_tokens, output_tokens)

    # Store in span attributes if span_id provided
    if span_id:
        span = db.query(ObservabilitySpan).filter_by(span_id=span_id).first()
        if span:
            # Assign a fresh dict (as end_span does) rather than mutating the
            # loaded one in place, so the change is a real attribute
            # replacement from the ORM's point of view.
            span.attributes = {
                **(span.attributes or {}),
                "llm.model": model,
                "llm.provider": provider,
                "llm.input_tokens": input_tokens,
                "llm.output_tokens": output_tokens,
                "llm.total_tokens": total_tokens,
                "llm.estimated_cost_usd": estimated_cost_usd,
            }
            self._safe_commit(db, "record_token_usage")

    # Also record as metrics for aggregation
    if input_tokens > 0:
        self.record_metric(
            db=db,
            name="llm.tokens.input",
            value=float(input_tokens),
            metric_type="counter",
            unit="tokens",
            trace_id=trace_id,
            attributes={"model": model, "provider": provider},
        )

    if output_tokens > 0:
        self.record_metric(
            db=db,
            name="llm.tokens.output",
            value=float(output_tokens),
            metric_type="counter",
            unit="tokens",
            trace_id=trace_id,
            attributes={"model": model, "provider": provider},
        )

    if estimated_cost_usd:
        self.record_metric(
            db=db,
            name="llm.cost",
            value=estimated_cost_usd,
            metric_type="counter",
            unit="usd",
            trace_id=trace_id,
            attributes={"model": model, "provider": provider},
        )

    # The cost stays None when neither a cost nor a model was supplied;
    # applying the ":.6f" format to None would raise TypeError.
    cost_text = f"${estimated_cost_usd:.6f}" if estimated_cost_usd is not None else "n/a"
    logger.debug(f"Recorded token usage: {input_tokens} in, {output_tokens} out, {cost_text}")

744 

745 def _estimate_token_cost(self, model: str, input_tokens: int, output_tokens: int) -> float: 

746 """Estimate cost based on model and token counts. 

747 

748 Pricing as of January 2025 (prices may change). 

749 

750 Args: 

751 model: Model name 

752 input_tokens: Input token count 

753 output_tokens: Output token count 

754 

755 Returns: 

756 Estimated cost in USD 

757 """ 

758 # Pricing per 1M tokens (input, output) 

759 pricing = { 

760 # OpenAI 

761 "gpt-4": (30.0, 60.0), 

762 "gpt-4-turbo": (10.0, 30.0), 

763 "gpt-4o": (2.5, 10.0), 

764 "gpt-4o-mini": (0.15, 0.60), 

765 "gpt-3.5-turbo": (0.50, 1.50), 

766 # Anthropic 

767 "claude-3-opus": (15.0, 75.0), 

768 "claude-3-sonnet": (3.0, 15.0), 

769 "claude-3-haiku": (0.25, 1.25), 

770 "claude-3.5-sonnet": (3.0, 15.0), 

771 "claude-3.5-haiku": (0.80, 4.0), 

772 # Fallback for unknown models 

773 "default": (1.0, 3.0), 

774 } 

775 

776 # Find matching pricing (case-insensitive, partial match) 

777 model_lower = model.lower() 

778 input_price, output_price = pricing.get("default") 

779 

780 for model_key, prices in pricing.items(): 

781 if model_key in model_lower: 

782 input_price, output_price = prices 

783 break 

784 

785 # Calculate cost (pricing is per 1M tokens) 

786 input_cost = (input_tokens / 1_000_000) * input_price 

787 output_cost = (output_tokens / 1_000_000) * output_price 

788 

789 return input_cost + output_cost 

790 

791 # ============================== 

792 # Agent-to-Agent (A2A) Tracing 

793 # ============================== 

794 

@contextmanager
def trace_a2a_request(
    self,
    db: Session,
    agent_id: str,
    agent_name: Optional[str] = None,
    operation: Optional[str] = None,
    request_data: Optional[Dict[str, Any]] = None,
):
    """Trace an Agent-to-Agent request as a client span.

    Creates a span named ``a2a.call.<agent_name or agent_id>`` under the
    trace held in ``current_trace_id``. When no trace is active nothing is
    recorded and ``(None, {})`` is yielded.

    Args:
        db: Database session
        agent_id: Target agent ID
        agent_name: Human-readable agent name
        operation: Operation being performed (e.g., "query", "execute", "status")
        request_data: Request payload (will be sanitized)

    Yields:
        Tuple of (span_id, result_dict) - update result_dict with A2A results

    Raises:
        Exception: Re-raises any exception from A2A call after logging

    Examples:
        >>> with service.trace_a2a_request(db, "agent-123", "WeatherAgent", "query") as (span_id, result):  # doctest: +SKIP
        ...     response = await http_client.post(...)  # doctest: +SKIP
        ...     result["status_code"] = response.status_code  # doctest: +SKIP
        ...     result["response_time_ms"] = 45.2  # doctest: +SKIP
    """
    trace_id = current_trace_id.get()
    if not trace_id:
        # Without an active trace there is nothing to attach a span to.
        yield (None, {})
        return

    # Redact values whose key name hints at a credential.
    sensitive_markers = ["password", "token", "key", "secret", "auth"]
    safe_data = {}
    if request_data:
        for field, value in request_data.items():
            if any(marker in field.lower() for marker in sensitive_markers):
                safe_data[field] = "***REDACTED***"
            else:
                safe_data[field] = value

    span_attributes = {
        "a2a.agent_id": agent_id,
        "a2a.agent_name": agent_name,
        "a2a.operation": operation,
        "a2a.request_data": safe_data,
    }
    span_id = self.start_span(
        db=db,
        trace_id=trace_id,
        name=f"a2a.call.{agent_name or agent_id}",
        kind="client",
        resource_type="agent",
        resource_name=agent_name or agent_id,
        attributes=span_attributes,
    )

    result_dict: Dict[str, Any] = {}
    try:
        yield (span_id, result_dict)

        # Close the span with whatever the caller put into result_dict.
        self.end_span(db=db, span_id=span_id, status="ok", attributes={"a2a.result": result_dict})
    except Exception as e:
        # Mark the span failed, then attach the exception details as an event.
        self.end_span(db=db, span_id=span_id, status="error", status_message=str(e))

        self.add_event(
            db=db,
            span_id=span_id,
            name="a2a.error",
            severity="error",
            message=str(e),
            exception_type=type(e).__name__,
            exception_message=str(e),
            exception_stacktrace=traceback.format_exc(),
        )
        raise

884 

885 # ============================== 

886 # Transport Metrics 

887 # ============================== 

888 

def record_transport_activity(
    self,
    db: Session,
    transport_type: str,
    operation: str,
    message_count: int = 1,
    bytes_sent: Optional[int] = None,
    bytes_received: Optional[int] = None,
    connection_id: Optional[str] = None,
    error: Optional[str] = None,
) -> None:
    """Emit counter metrics describing activity on one transport.

    Up to four counters are recorded, each named
    ``transport.<transport_type>.<suffix>`` and linked to the trace ID held
    in ``current_trace_id`` (which may be None): ``messages``,
    ``bytes_sent``, ``bytes_received`` and ``errors``. A counter is skipped
    when its input is zero, None or empty.

    Args:
        db: Database session
        transport_type: Transport type (sse, websocket, stdio, http)
        operation: Operation type (connect, disconnect, send, receive, error)
        message_count: Number of messages processed
        bytes_sent: Bytes sent (if applicable)
        bytes_received: Bytes received (if applicable)
        connection_id: Connection/session identifier
        error: Error message if operation failed

    Examples:
        >>> service.record_transport_activity(  # doctest: +SKIP
        ...     db, transport_type="sse",
        ...     operation="send",
        ...     message_count=1,
        ...     bytes_sent=1024
        ... )
    """
    trace_id = current_trace_id.get()

    def _emit(suffix: str, amount: float, unit: str, extra: Optional[Dict[str, Any]] = None) -> None:
        """Record one transport counter with the shared labels plus ``extra``."""
        labels: Dict[str, Any] = {
            "transport": transport_type,
            "operation": operation,
            "connection_id": connection_id,
        }
        if extra:
            labels.update(extra)
        self.record_metric(
            db=db,
            name=f"transport.{transport_type}.{suffix}",
            value=amount,
            metric_type="counter",
            unit=unit,
            trace_id=trace_id,
            attributes=labels,
        )

    if message_count > 0:
        _emit("messages", float(message_count), "messages")

    if bytes_sent:
        _emit("bytes_sent", float(bytes_sent), "bytes")

    if bytes_received:
        _emit("bytes_received", float(bytes_received), "bytes")

    if error:
        # Each reported error counts as one, with the message kept as a label.
        _emit("errors", 1.0, "errors", {"error": error})

    logger.debug(f"Recorded {transport_type} transport activity: {operation} ({message_count} messages)")

988 

989 # ============================== 

990 # Metric Management 

991 # ============================== 

992 

def record_metric(
    self,
    db: Session,
    name: str,
    value: float,
    metric_type: str = "gauge",
    unit: Optional[str] = None,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    trace_id: Optional[str] = None,
    attributes: Optional[Dict[str, Any]] = None,
) -> int:
    """Persist a single metric data point.

    Args:
        db: Database session
        name: Metric name (e.g., "http.request.duration")
        value: Metric value
        metric_type: Metric type (counter, gauge, histogram)
        unit: Metric unit (ms, count, bytes, etc.)
        resource_type: Resource type
        resource_id: Resource ID
        trace_id: Associated trace ID
        attributes: Additional metric attributes/labels

    Returns:
        Metric ID, or 0 when the commit failed and was rolled back

    Examples:
        >>> metric_id = service.record_metric(  # doctest: +SKIP
        ...     db,  # doctest: +SKIP
        ...     "http.request.duration",  # doctest: +SKIP
        ...     123.45,  # doctest: +SKIP
        ...     metric_type="histogram",  # doctest: +SKIP
        ...     unit="ms",  # doctest: +SKIP
        ...     trace_id=trace_id  # doctest: +SKIP
        ... )  # doctest: +SKIP
    """
    metric_fields = dict(
        name=name,
        value=value,
        metric_type=metric_type,
        timestamp=utc_now(),
        unit=unit,
        resource_type=resource_type,
        resource_id=resource_id,
        trace_id=trace_id,
        attributes=attributes or {},
        created_at=utc_now(),
    )
    metric = ObservabilityMetric(**metric_fields)
    db.add(metric)

    committed = self._safe_commit(db, "record_metric")
    if not committed:
        return 0

    # Reload the row after the commit before reading its id.
    db.refresh(metric)
    logger.debug(f"Recorded metric: {name} = {value} {unit or ''}")
    return metric.id

1049 

1050 # ============================== 

1051 # Query Methods 

1052 # ============================== 

1053 

1054 # pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-locals 

def query_traces(
    self,
    db: Session,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    min_duration_ms: Optional[float] = None,
    max_duration_ms: Optional[float] = None,
    status: Optional[str] = None,
    status_in: Optional[List[str]] = None,
    status_not_in: Optional[List[str]] = None,
    http_status_code: Optional[int] = None,
    http_status_code_in: Optional[List[int]] = None,
    http_method: Optional[str] = None,
    http_method_in: Optional[List[str]] = None,
    user_email: Optional[str] = None,
    user_email_in: Optional[List[str]] = None,
    attribute_filters: Optional[Dict[str, Any]] = None,
    attribute_filters_or: Optional[Dict[str, Any]] = None,
    attribute_search: Optional[str] = None,
    name_contains: Optional[str] = None,
    order_by: str = "start_time_desc",
    limit: int = 100,
    offset: int = 0,
) -> List[ObservabilityTrace]:
    """Query traces with advanced filters.

    Supports both simple filters (single value) and list filters (multiple values with OR logic).
    All top-level filters are combined with AND logic unless using _or suffix.

    Substring filters (``name_contains`` and ``attribute_search``) match the
    given text literally: ``%``, ``_`` and ``\\`` in the input are escaped and
    the LIKE expression declares ``\\`` as its escape character, so the
    behaviour does not depend on a backend's default escape character.

    Args:
        db: Database session
        start_time: Filter traces after this time
        end_time: Filter traces before this time
        min_duration_ms: Filter traces with duration >= this value (milliseconds)
        max_duration_ms: Filter traces with duration <= this value (milliseconds)
        status: Filter by single status (ok, error)
        status_in: Filter by multiple statuses (OR logic)
        status_not_in: Exclude these statuses (NOT logic)
        http_status_code: Filter by single HTTP status code
        http_status_code_in: Filter by multiple HTTP status codes (OR logic)
        http_method: Filter by single HTTP method (GET, POST, etc.)
        http_method_in: Filter by multiple HTTP methods (OR logic)
        user_email: Filter by single user email
        user_email_in: Filter by multiple user emails (OR logic)
        attribute_filters: JSON attribute filters (AND logic - all must match)
        attribute_filters_or: JSON attribute filters (OR logic - any must match)
        attribute_search: Free-text search within JSON attributes (partial match)
        name_contains: Filter traces where name contains this substring
        order_by: Sort order (start_time_desc, start_time_asc, duration_desc, duration_asc)
        limit: Maximum results (1-1000)
        offset: Result offset

    Returns:
        List of traces

    Raises:
        ValueError: If ``limit`` is outside 1-1000 or ``order_by`` is not a supported sort order

    Examples:
        >>> # Find slow errors from multiple endpoints
        >>> traces = service.query_traces(  # doctest: +SKIP
        ...     db,
        ...     status="error",
        ...     min_duration_ms=100.0,
        ...     http_method_in=["POST", "PUT"],
        ...     attribute_filters={"http.route": "/api/tools"},
        ...     limit=50
        ... )
        >>> # Exclude health checks and find slow requests
        >>> traces = service.query_traces(  # doctest: +SKIP
        ...     db,
        ...     min_duration_ms=1000.0,
        ...     name_contains="api",
        ...     status_not_in=["ok"],
        ...     order_by="duration_desc"
        ... )
    """
    # Validate arguments first so bad input fails before any query is built.
    if limit < 1 or limit > 1000:
        raise ValueError("limit must be between 1 and 1000")

    valid_orders = ["start_time_desc", "start_time_asc", "duration_desc", "duration_asc"]
    if order_by not in valid_orders:
        raise ValueError(f"order_by must be one of: {', '.join(valid_orders)}")

    # Third-Party
    # pylint: disable=import-outside-toplevel
    from sqlalchemy import cast, or_, String

    # pylint: enable=import-outside-toplevel

    like_escape = "\\"

    def _contains_pattern(text: str) -> str:
        """Build a LIKE pattern that matches ``text`` literally anywhere in a value.

        Args:
            text: Raw user-supplied search text

        Returns:
            ``%...%`` pattern with LIKE metacharacters escaped
        """
        # The escape character itself must be escaped first.
        escaped = text.replace(like_escape, like_escape * 2).replace("%", like_escape + "%").replace("_", like_escape + "_")
        return f"%{escaped}%"

    query = db.query(ObservabilityTrace)

    # Time range filters (both bounds apply to the trace start time)
    if start_time:
        query = query.filter(ObservabilityTrace.start_time >= start_time)
    if end_time:
        query = query.filter(ObservabilityTrace.start_time <= end_time)

    # Duration filters
    if min_duration_ms is not None:
        query = query.filter(ObservabilityTrace.duration_ms >= min_duration_ms)
    if max_duration_ms is not None:
        query = query.filter(ObservabilityTrace.duration_ms <= max_duration_ms)

    # Status filters (with OR and NOT support)
    if status:
        query = query.filter(ObservabilityTrace.status == status)
    if status_in:
        query = query.filter(ObservabilityTrace.status.in_(status_in))
    if status_not_in:
        query = query.filter(~ObservabilityTrace.status.in_(status_not_in))

    # HTTP status code filters (with OR support).
    # Compared against None so that a numeric value is never dropped by truthiness.
    if http_status_code is not None:
        query = query.filter(ObservabilityTrace.http_status_code == http_status_code)
    if http_status_code_in:
        query = query.filter(ObservabilityTrace.http_status_code.in_(http_status_code_in))

    # HTTP method filters (with OR support)
    if http_method:
        query = query.filter(ObservabilityTrace.http_method == http_method)
    if http_method_in:
        query = query.filter(ObservabilityTrace.http_method.in_(http_method_in))

    # User email filters (with OR support)
    if user_email:
        query = query.filter(ObservabilityTrace.user_email == user_email)
    if user_email_in:
        query = query.filter(ObservabilityTrace.user_email.in_(user_email_in))

    # Name substring filter (case-insensitive, wildcards in the input are literal)
    if name_contains:
        query = query.filter(ObservabilityTrace.name.ilike(_contains_pattern(name_contains), escape=like_escape))

    # Attribute-based filtering with AND logic (all filters must match).
    # Values are compared as text, so the expected value is stringified.
    # NOTE(review): ``.astext`` is the PostgreSQL JSON comparator accessor - confirm
    # the column type exposes it on every backend this service runs against.
    if attribute_filters:
        for key, value in attribute_filters.items():
            query = query.filter(ObservabilityTrace.attributes[key].astext == str(value))

    # Attribute-based filtering with OR logic (any filter must match)
    if attribute_filters_or:
        or_conditions = [ObservabilityTrace.attributes[key].astext == str(value) for key, value in attribute_filters_or.items()]
        query = query.filter(or_(*or_conditions))

    # Free-text search across the serialized attributes.
    # The value is a bound parameter; escaping only neutralizes LIKE wildcards.
    if attribute_search:
        attributes_text = cast(ObservabilityTrace.attributes, String)
        query = query.filter(attributes_text.ilike(_contains_pattern(attribute_search), escape=like_escape))

    # Apply ordering
    if order_by == "start_time_desc":
        query = query.order_by(desc(ObservabilityTrace.start_time))
    elif order_by == "start_time_asc":
        query = query.order_by(ObservabilityTrace.start_time)
    elif order_by == "duration_desc":
        query = query.order_by(desc(ObservabilityTrace.duration_ms))
    else:  # duration_asc (validated above)
        query = query.order_by(ObservabilityTrace.duration_ms)

    # Apply pagination
    query = query.limit(limit).offset(offset)

    return query.all()

1227 

1228 # pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-locals 

def query_spans(
    self,
    db: Session,
    trace_id: Optional[str] = None,
    trace_id_in: Optional[List[str]] = None,
    resource_type: Optional[str] = None,
    resource_type_in: Optional[List[str]] = None,
    resource_name: Optional[str] = None,
    resource_name_in: Optional[List[str]] = None,
    name_contains: Optional[str] = None,
    kind: Optional[str] = None,
    kind_in: Optional[List[str]] = None,
    status: Optional[str] = None,
    status_in: Optional[List[str]] = None,
    status_not_in: Optional[List[str]] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    min_duration_ms: Optional[float] = None,
    max_duration_ms: Optional[float] = None,
    attribute_filters: Optional[Dict[str, Any]] = None,
    attribute_search: Optional[str] = None,
    order_by: str = "start_time_desc",
    limit: int = 100,
    offset: int = 0,
) -> List[ObservabilitySpan]:
    """Query spans with advanced filters.

    Supports filtering by trace, resource, kind, status, duration, and attributes.
    All top-level filters are combined with AND logic. List filters use OR logic.

    Substring filters (``name_contains`` and ``attribute_search``) match the
    given text literally: ``%``, ``_`` and ``\\`` in the input are escaped and
    the LIKE expression declares ``\\`` as its escape character, so the
    behaviour does not depend on a backend's default escape character.

    Args:
        db: Database session
        trace_id: Filter by single trace ID
        trace_id_in: Filter by multiple trace IDs (OR logic)
        resource_type: Filter by single resource type (tool, database, plugin, etc.)
        resource_type_in: Filter by multiple resource types (OR logic)
        resource_name: Filter by single resource name
        resource_name_in: Filter by multiple resource names (OR logic)
        name_contains: Filter spans where name contains this substring
        kind: Filter by span kind (client, server, internal)
        kind_in: Filter by multiple kinds (OR logic)
        status: Filter by single status (ok, error)
        status_in: Filter by multiple statuses (OR logic)
        status_not_in: Exclude these statuses (NOT logic)
        start_time: Filter spans after this time
        end_time: Filter spans before this time
        min_duration_ms: Filter spans with duration >= this value (milliseconds)
        max_duration_ms: Filter spans with duration <= this value (milliseconds)
        attribute_filters: JSON attribute filters (AND logic)
        attribute_search: Free-text search within JSON attributes
        order_by: Sort order (start_time_desc, start_time_asc, duration_desc, duration_asc)
        limit: Maximum results (1-1000)
        offset: Result offset

    Returns:
        List of spans

    Raises:
        ValueError: If ``limit`` is outside 1-1000 or ``order_by`` is not a supported sort order

    Examples:
        >>> # Find slow database queries
        >>> spans = service.query_spans(  # doctest: +SKIP
        ...     db,
        ...     resource_type="database",
        ...     min_duration_ms=100.0,
        ...     order_by="duration_desc",
        ...     limit=50
        ... )
        >>> # Find tool invocation errors
        >>> spans = service.query_spans(  # doctest: +SKIP
        ...     db,
        ...     resource_type="tool",
        ...     status="error",
        ...     name_contains="invoke"
        ... )
    """
    # Validate arguments first so bad input fails before any query is built.
    if limit < 1 or limit > 1000:
        raise ValueError("limit must be between 1 and 1000")

    valid_orders = ["start_time_desc", "start_time_asc", "duration_desc", "duration_asc"]
    if order_by not in valid_orders:
        raise ValueError(f"order_by must be one of: {', '.join(valid_orders)}")

    # Third-Party
    # pylint: disable=import-outside-toplevel
    from sqlalchemy import cast, String

    # pylint: enable=import-outside-toplevel

    like_escape = "\\"

    def _contains_pattern(text: str) -> str:
        """Build a LIKE pattern that matches ``text`` literally anywhere in a value.

        Args:
            text: Raw user-supplied search text

        Returns:
            ``%...%`` pattern with LIKE metacharacters escaped
        """
        # The escape character itself must be escaped first.
        escaped = text.replace(like_escape, like_escape * 2).replace("%", like_escape + "%").replace("_", like_escape + "_")
        return f"%{escaped}%"

    query = db.query(ObservabilitySpan)

    # Trace ID filters (with OR support)
    if trace_id:
        query = query.filter(ObservabilitySpan.trace_id == trace_id)
    if trace_id_in:
        query = query.filter(ObservabilitySpan.trace_id.in_(trace_id_in))

    # Resource type filters (with OR support)
    if resource_type:
        query = query.filter(ObservabilitySpan.resource_type == resource_type)
    if resource_type_in:
        query = query.filter(ObservabilitySpan.resource_type.in_(resource_type_in))

    # Resource name filters (with OR support)
    if resource_name:
        query = query.filter(ObservabilitySpan.resource_name == resource_name)
    if resource_name_in:
        query = query.filter(ObservabilitySpan.resource_name.in_(resource_name_in))

    # Name substring filter (case-insensitive, wildcards in the input are literal)
    if name_contains:
        query = query.filter(ObservabilitySpan.name.ilike(_contains_pattern(name_contains), escape=like_escape))

    # Kind filters (with OR support)
    if kind:
        query = query.filter(ObservabilitySpan.kind == kind)
    if kind_in:
        query = query.filter(ObservabilitySpan.kind.in_(kind_in))

    # Status filters (with OR and NOT support)
    if status:
        query = query.filter(ObservabilitySpan.status == status)
    if status_in:
        query = query.filter(ObservabilitySpan.status.in_(status_in))
    if status_not_in:
        query = query.filter(~ObservabilitySpan.status.in_(status_not_in))

    # Time range filters (both bounds apply to the span start time)
    if start_time:
        query = query.filter(ObservabilitySpan.start_time >= start_time)
    if end_time:
        query = query.filter(ObservabilitySpan.start_time <= end_time)

    # Duration filters
    if min_duration_ms is not None:
        query = query.filter(ObservabilitySpan.duration_ms >= min_duration_ms)
    if max_duration_ms is not None:
        query = query.filter(ObservabilitySpan.duration_ms <= max_duration_ms)

    # Attribute-based filtering with AND logic.
    # Values are compared as text, so the expected value is stringified.
    # NOTE(review): ``.astext`` is the PostgreSQL JSON comparator accessor - confirm
    # the column type exposes it on every backend this service runs against.
    if attribute_filters:
        for key, value in attribute_filters.items():
            query = query.filter(ObservabilitySpan.attributes[key].astext == str(value))

    # Free-text search across the serialized attributes.
    # The value is a bound parameter; escaping only neutralizes LIKE wildcards.
    if attribute_search:
        attributes_text = cast(ObservabilitySpan.attributes, String)
        query = query.filter(attributes_text.ilike(_contains_pattern(attribute_search), escape=like_escape))

    # Apply ordering
    if order_by == "start_time_desc":
        query = query.order_by(desc(ObservabilitySpan.start_time))
    elif order_by == "start_time_asc":
        query = query.order_by(ObservabilitySpan.start_time)
    elif order_by == "duration_desc":
        query = query.order_by(desc(ObservabilitySpan.duration_ms))
    else:  # duration_asc (validated above)
        query = query.order_by(ObservabilitySpan.duration_ms)

    # Apply pagination
    query = query.limit(limit).offset(offset)

    return query.all()

1394 

def get_trace_with_spans(self, db: Session, trace_id: str) -> Optional[ObservabilityTrace]:
    """Fetch one trace together with its spans and each span's events.

    The spans relationship and the nested events relationship are both
    requested through ``joinedload`` options on the trace query.

    Args:
        db: Database session
        trace_id: Trace ID

    Returns:
        The first trace matching ``trace_id`` with spans and events loaded,
        or None when the query yields no row

    Examples:
        >>> trace = service.get_trace_with_spans(db, trace_id)  # doctest: +SKIP
        >>> if trace:  # doctest: +SKIP
        ...     for span in trace.spans:  # doctest: +SKIP
        ...         print(f"Span: {span.name}, Events: {len(span.events)}")  # doctest: +SKIP
    """
    trace_query = db.query(ObservabilityTrace).filter_by(trace_id=trace_id)
    # Load trace -> spans -> events eagerly via the chained loader option.
    spans_with_events = joinedload(ObservabilityTrace.spans).joinedload(ObservabilitySpan.events)
    return trace_query.options(spans_with_events).first()

1412 

def delete_old_traces(self, db: Session, before_time: datetime) -> int:
    """Remove every trace whose start time is earlier than ``before_time``.

    The delete is issued as a single query-level delete and then committed
    through ``_safe_commit``.

    NOTE(review): a query-level delete does not load the rows into the session;
    confirm that related spans/events are removed by database-level cascades.

    Args:
        db: Database session
        before_time: Delete traces before this time

    Returns:
        Number of traces deleted, or 0 when ``_safe_commit`` reports that the
        commit did not succeed

    Examples:
        >>> from datetime import timedelta  # doctest: +SKIP
        >>> cutoff = utc_now() - timedelta(days=30)  # doctest: +SKIP
        >>> deleted = service.delete_old_traces(db, cutoff)  # doctest: +SKIP
        >>> print(f"Deleted {deleted} old traces")  # doctest: +SKIP
    """
    expired_traces = db.query(ObservabilityTrace).filter(ObservabilityTrace.start_time < before_time)
    deleted = expired_traces.delete()

    if self._safe_commit(db, "delete_old_traces"):
        logger.info(f"Deleted {deleted} traces older than {before_time}")
        return deleted

    # Commit did not go through, so report nothing as deleted.
    return 0