Coverage for mcpgateway / services / observability_service.py: 100%
338 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/observability_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Observability Service Implementation.
8This module provides OpenTelemetry-style observability for MCP Gateway,
9capturing traces, spans, events, and metrics for all operations.
11It includes:
12- Trace creation and management
13- Span tracking with hierarchical nesting
14- Event logging within spans
15- Metrics collection and storage
16- Query and filtering capabilities
17- Integration with FastAPI middleware
19Examples:
20 >>> from mcpgateway.services.observability_service import ObservabilityService # doctest: +SKIP
21 >>> service = ObservabilityService() # doctest: +SKIP
22 >>> trace_id = service.start_trace(db, "GET /tools", http_method="GET", http_url="/tools") # doctest: +SKIP
23 >>> span_id = service.start_span(db, trace_id, "database_query", resource_type="database") # doctest: +SKIP
24 >>> service.end_span(db, span_id, status="ok") # doctest: +SKIP
25 >>> service.end_trace(db, trace_id, status="ok", http_status_code=200) # doctest: +SKIP
26"""
28# Standard
29from contextlib import contextmanager
30from contextvars import ContextVar
31from datetime import datetime, timezone
32import logging
33import re
34import traceback
35from typing import Any, Dict, List, Optional, Pattern, Tuple
36import uuid
38# Third-Party
39from sqlalchemy import desc
40from sqlalchemy.exc import SQLAlchemyError
41from sqlalchemy.orm import joinedload, Session
43# First-Party
44from mcpgateway.db import ObservabilityEvent, ObservabilityMetric, ObservabilitySpan, ObservabilityTrace
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Precompiled regex for W3C Trace Context traceparent header parsing.
# Format: version-trace_id-parent_id-trace_flags
# Groups: (1) 2-hex version, (2) 32-hex trace_id, (3) 16-hex parent_id, (4) 2-hex flags.
# Only lowercase hex is accepted by the pattern itself.
_TRACEPARENT_RE: Pattern[str] = re.compile(r"^([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})$")

# Context variable for tracking the current trace_id across async calls.
# Defaults to None, which the tracing helpers in this module treat as "no active trace".
current_trace_id: ContextVar[Optional[str]] = ContextVar("current_trace_id", default=None)
def utc_now() -> datetime:
    """Get the present moment as a timezone-aware datetime in UTC.

    Returns:
        datetime: The current instant with ``tzinfo`` set to ``timezone.utc``.
    """
    return datetime.now(tz=timezone.utc)
def ensure_timezone_aware(dt: datetime) -> datetime:
    """Return ``dt`` with a timezone attached, treating naive values as UTC.

    SQLite hands back naive datetimes even when they were stored with
    timezone info, so this normalises values before datetime arithmetic.

    Args:
        dt: Datetime that may be naive or already timezone-aware.

    Returns:
        ``dt`` unchanged when it already carries ``tzinfo``; otherwise a copy
        of ``dt`` tagged with UTC (the wall-clock fields are not shifted).
    """
    already_aware = dt.tzinfo is not None
    return dt if already_aware else dt.replace(tzinfo=timezone.utc)
def parse_traceparent(traceparent: str) -> Optional[Tuple[str, str, str]]:
    """Split a W3C Trace Context ``traceparent`` header into its parts.

    Format: version-trace_id-parent_id-trace_flags
    Example: 00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01

    The header is lowercased before matching, so uppercase hex is tolerated.
    Every rejection is logged at warning level.

    Args:
        traceparent: W3C traceparent header value

    Returns:
        Tuple of (trace_id, parent_span_id, trace_flags), or None when the
        header is malformed, uses a version other than "00", or carries an
        all-zero trace_id or parent_id.

    Examples:
        >>> parse_traceparent("00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01")  # doctest: +SKIP
        ('0af7651916cd43dd8448eb211c80319c', 'b7ad6b7169203331', '01')
    """
    parsed = _TRACEPARENT_RE.match(traceparent.lower())
    if parsed is None:
        logger.warning(f"Invalid traceparent format: {traceparent}")
        return None

    version = parsed.group(1)
    trace_id = parsed.group(2)
    parent_id = parsed.group(3)
    flags = parsed.group(4)

    # Version 00 is the only one handled for now.
    if version != "00":
        logger.warning(f"Unsupported traceparent version: {version}")
        return None

    # An identifier made up solely of zeros is rejected.
    zero_trace = "0" * 32
    zero_parent = "0" * 16
    if trace_id == zero_trace or parent_id == zero_parent:
        logger.warning("Invalid traceparent with zero trace_id or parent_id")
        return None

    return (trace_id, parent_id, flags)
def generate_w3c_trace_id() -> str:
    """Generate a W3C compliant trace ID (32 hex characters).

    A W3C trace-id is 16 bytes rendered as 32 lowercase hex characters.
    ``uuid.uuid4().hex`` already has exactly that shape, so a single UUID is
    used as-is. Concatenating extra hex onto it would yield a 48-character
    value that is not a valid trace-id.

    Returns:
        32-character lowercase hex string

    Examples:
        >>> trace_id = generate_w3c_trace_id()  # doctest: +SKIP
        >>> len(trace_id)  # doctest: +SKIP
        32
    """
    return uuid.uuid4().hex
def generate_w3c_span_id() -> str:
    """Produce a W3C compliant span ID (16 hex characters).

    The value is the leading 16 hex characters of a freshly generated UUID4.

    Returns:
        16-character lowercase hex string

    Examples:
        >>> span_id = generate_w3c_span_id()  # doctest: +SKIP
        >>> len(span_id)  # doctest: +SKIP
        16
    """
    full_hex = uuid.uuid4().hex
    return full_hex[:16]
def format_traceparent(trace_id: str, span_id: str, sampled: bool = True) -> str:
    """Build a W3C traceparent header value from its components.

    The version field is always "00"; the trace-flags field is "01" when the
    trace is sampled and "00" otherwise. The IDs are inserted verbatim
    without validation.

    Args:
        trace_id: 32-character hex trace ID
        span_id: 16-character hex span ID
        sampled: Whether the trace is sampled (affects trace-flags)

    Returns:
        W3C traceparent header value

    Examples:
        >>> format_traceparent("0af7651916cd43dd8448eb211c80319c", "b7ad6b7169203331")  # doctest: +SKIP
        '00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01'
    """
    if sampled:
        flags = "01"
    else:
        flags = "00"
    return f"00-{trace_id}-{span_id}-{flags}"
168class ObservabilityService:
169 """Service for managing observability traces, spans, events, and metrics.
171 This service provides comprehensive observability capabilities similar to
172 OpenTelemetry, allowing tracking of request flows through the system.
174 Examples:
175 >>> service = ObservabilityService() # doctest: +SKIP
176 >>> trace_id = service.start_trace(db, "POST /tools/invoke") # doctest: +SKIP
177 >>> span_id = service.start_span(db, trace_id, "tool_execution") # doctest: +SKIP
178 >>> service.end_span(db, span_id, status="ok") # doctest: +SKIP
179 >>> service.end_trace(db, trace_id, status="ok") # doctest: +SKIP
180 """
182 def _safe_commit(self, db: Session, context: str) -> bool:
183 """Commit and rollback on failure without raising.
185 Args:
186 db: SQLAlchemy session for the current operation.
187 context: Short label for the commit context (used in logs).
189 Returns:
190 True when commit succeeds, False when a rollback was performed.
191 """
192 try:
193 db.commit()
194 return True
195 except SQLAlchemyError as exc:
196 logger.warning(f"Observability commit failed ({context}): {exc}")
197 try:
198 db.rollback()
199 except SQLAlchemyError as rollback_exc:
200 logger.debug(f"Observability rollback failed ({context}): {rollback_exc}")
201 return False
203 # ==============================
204 # Trace Management
205 # ==============================
207 def start_trace(
208 self,
209 db: Session,
210 name: str,
211 trace_id: Optional[str] = None,
212 parent_span_id: Optional[str] = None,
213 http_method: Optional[str] = None,
214 http_url: Optional[str] = None,
215 user_email: Optional[str] = None,
216 user_agent: Optional[str] = None,
217 ip_address: Optional[str] = None,
218 attributes: Optional[Dict[str, Any]] = None,
219 resource_attributes: Optional[Dict[str, Any]] = None,
220 ) -> str:
221 """Start a new trace.
223 Args:
224 db: Database session
225 name: Trace name (e.g., "POST /tools/invoke")
226 trace_id: External trace ID (for distributed tracing, W3C format)
227 parent_span_id: Parent span ID from upstream service
228 http_method: HTTP method (GET, POST, etc.)
229 http_url: Full request URL
230 user_email: Authenticated user email
231 user_agent: Client user agent string
232 ip_address: Client IP address
233 attributes: Additional trace attributes
234 resource_attributes: Resource attributes (service name, version, etc.)
236 Returns:
237 Trace ID (UUID string or W3C format)
239 Examples:
240 >>> trace_id = service.start_trace( # doctest: +SKIP
241 ... db,
242 ... "POST /tools/invoke",
243 ... http_method="POST",
244 ... http_url="https://api.example.com/tools/invoke",
245 ... user_email="user@example.com"
246 ... )
247 """
248 # Use provided trace_id or generate new UUID
249 if not trace_id:
250 trace_id = str(uuid.uuid4())
252 # Add parent context to attributes if provided
253 attrs = attributes or {}
254 if parent_span_id:
255 attrs["parent_span_id"] = parent_span_id
257 trace = ObservabilityTrace(
258 trace_id=trace_id,
259 name=name,
260 start_time=utc_now(),
261 status="unset",
262 http_method=http_method,
263 http_url=http_url,
264 user_email=user_email,
265 user_agent=user_agent,
266 ip_address=ip_address,
267 attributes=attrs,
268 resource_attributes=resource_attributes or {},
269 created_at=utc_now(),
270 )
271 db.add(trace)
272 self._safe_commit(db, "start_trace")
273 logger.debug(f"Started trace {trace_id}: {name}")
274 return trace_id
276 def end_trace(
277 self,
278 db: Session,
279 trace_id: str,
280 status: str = "ok",
281 status_message: Optional[str] = None,
282 http_status_code: Optional[int] = None,
283 attributes: Optional[Dict[str, Any]] = None,
284 ) -> None:
285 """End a trace.
287 Args:
288 db: Database session
289 trace_id: Trace ID to end
290 status: Trace status (ok, error)
291 status_message: Optional status message
292 http_status_code: HTTP response status code
293 attributes: Additional attributes to merge
295 Examples:
296 >>> service.end_trace( # doctest: +SKIP
297 ... db,
298 ... trace_id,
299 ... status="ok",
300 ... http_status_code=200
301 ... )
302 """
303 trace = db.query(ObservabilityTrace).filter_by(trace_id=trace_id).first()
304 if not trace:
305 logger.warning(f"Trace {trace_id} not found")
306 return
308 end_time = utc_now()
309 duration_ms = (end_time - ensure_timezone_aware(trace.start_time)).total_seconds() * 1000
311 trace.end_time = end_time
312 trace.duration_ms = duration_ms
313 trace.status = status
314 trace.status_message = status_message
315 if http_status_code is not None:
316 trace.http_status_code = http_status_code
317 if attributes:
318 trace.attributes = {**(trace.attributes or {}), **attributes}
320 self._safe_commit(db, "end_trace")
321 logger.debug(f"Ended trace {trace_id}: {status} ({duration_ms:.2f}ms)")
323 def get_trace(self, db: Session, trace_id: str, include_spans: bool = False) -> Optional[ObservabilityTrace]:
324 """Get a trace by ID.
326 Args:
327 db: Database session
328 trace_id: Trace ID
329 include_spans: Whether to load spans eagerly
331 Returns:
332 Trace object or None if not found
334 Examples:
335 >>> trace = service.get_trace(db, trace_id, include_spans=True) # doctest: +SKIP
336 >>> if trace: # doctest: +SKIP
337 ... print(f"Trace: {trace.name}, Spans: {len(trace.spans)}") # doctest: +SKIP
338 """
339 query = db.query(ObservabilityTrace).filter_by(trace_id=trace_id)
340 if include_spans:
341 query = query.options(joinedload(ObservabilityTrace.spans))
342 return query.first()
344 # ==============================
345 # Span Management
346 # ==============================
348 def start_span(
349 self,
350 db: Session,
351 trace_id: str,
352 name: str,
353 parent_span_id: Optional[str] = None,
354 kind: str = "internal",
355 resource_name: Optional[str] = None,
356 resource_type: Optional[str] = None,
357 resource_id: Optional[str] = None,
358 attributes: Optional[Dict[str, Any]] = None,
359 commit: bool = True,
360 ) -> str:
361 """Start a new span within a trace.
363 Args:
364 db: Database session
365 trace_id: Parent trace ID
366 name: Span name (e.g., "database_query", "tool_invocation")
367 parent_span_id: Parent span ID (for nested spans)
368 kind: Span kind (internal, server, client, producer, consumer)
369 resource_name: Resource name being operated on
370 resource_type: Resource type (tool, resource, prompt, etc.)
371 resource_id: Resource ID
372 attributes: Additional span attributes
373 commit: Whether to commit the transaction (default True).
374 Set to False when using fresh_db_session() which handles commits.
376 Returns:
377 Span ID (UUID string)
379 Examples:
380 >>> span_id = service.start_span( # doctest: +SKIP
381 ... db,
382 ... trace_id,
383 ... "tool_invocation",
384 ... resource_type="tool",
385 ... resource_name="get_weather"
386 ... )
387 """
388 span_id = str(uuid.uuid4())
389 span = ObservabilitySpan(
390 span_id=span_id,
391 trace_id=trace_id,
392 parent_span_id=parent_span_id,
393 name=name,
394 kind=kind,
395 start_time=utc_now(),
396 status="unset",
397 resource_name=resource_name,
398 resource_type=resource_type,
399 resource_id=resource_id,
400 attributes=attributes or {},
401 created_at=utc_now(),
402 )
403 db.add(span)
404 if commit:
405 self._safe_commit(db, "start_span")
406 logger.debug(f"Started span {span_id}: {name} (trace={trace_id})")
407 return span_id
409 def end_span(
410 self,
411 db: Session,
412 span_id: str,
413 status: str = "ok",
414 status_message: Optional[str] = None,
415 attributes: Optional[Dict[str, Any]] = None,
416 commit: bool = True,
417 ) -> None:
418 """End a span.
420 Args:
421 db: Database session
422 span_id: Span ID to end
423 status: Span status (ok, error)
424 status_message: Optional status message
425 attributes: Additional attributes to merge
426 commit: Whether to commit the transaction (default True).
427 Set to False when using fresh_db_session() which handles commits.
429 Examples:
430 >>> service.end_span(db, span_id, status="ok") # doctest: +SKIP
431 """
432 span = db.query(ObservabilitySpan).filter_by(span_id=span_id).first()
433 if not span:
434 logger.warning(f"Span {span_id} not found")
435 return
437 end_time = utc_now()
438 duration_ms = (end_time - ensure_timezone_aware(span.start_time)).total_seconds() * 1000
440 span.end_time = end_time
441 span.duration_ms = duration_ms
442 span.status = status
443 span.status_message = status_message
444 if attributes:
445 span.attributes = {**(span.attributes or {}), **attributes}
447 if commit:
448 self._safe_commit(db, "end_span")
449 logger.debug(f"Ended span {span_id}: {status} ({duration_ms:.2f}ms)")
451 @contextmanager
452 def trace_span(
453 self,
454 db: Session,
455 trace_id: str,
456 name: str,
457 parent_span_id: Optional[str] = None,
458 resource_type: Optional[str] = None,
459 resource_name: Optional[str] = None,
460 attributes: Optional[Dict[str, Any]] = None,
461 ):
462 """Context manager for automatic span lifecycle management.
464 Args:
465 db: Database session
466 trace_id: Parent trace ID
467 name: Span name
468 parent_span_id: Parent span ID (optional)
469 resource_type: Resource type
470 resource_name: Resource name
471 attributes: Additional attributes
473 Yields:
474 Span ID
476 Raises:
477 Exception: Re-raises any exception after logging it in the span
479 Examples:
480 >>> with service.trace_span(db, trace_id, "database_query") as span_id: # doctest: +SKIP
481 ... results = db.query(Tool).all() # doctest: +SKIP
482 """
483 span_id = self.start_span(db, trace_id, name, parent_span_id, resource_type=resource_type, resource_name=resource_name, attributes=attributes)
484 try:
485 yield span_id
486 self.end_span(db, span_id, status="ok")
487 except Exception as e:
488 self.end_span(db, span_id, status="error", status_message=str(e))
489 self.add_event(db, span_id, "exception", severity="error", message=str(e), exception_type=type(e).__name__, exception_message=str(e), exception_stacktrace=traceback.format_exc())
490 raise
492 @contextmanager
493 def trace_tool_invocation(
494 self,
495 db: Session,
496 tool_name: str,
497 arguments: Dict[str, Any],
498 integration_type: Optional[str] = None,
499 ):
500 """Context manager for tracing MCP tool invocations.
502 This automatically creates a span for tool execution, capturing timing,
503 arguments, results, and errors.
505 Args:
506 db: Database session
507 tool_name: Name of the tool being invoked
508 arguments: Tool arguments (will be sanitized)
509 integration_type: Integration type (MCP, REST, A2A, etc.)
511 Yields:
512 Tuple of (span_id, result_dict) - update result_dict with tool results
514 Raises:
515 Exception: Re-raises any exception from tool invocation after logging
517 Examples:
518 >>> with service.trace_tool_invocation(db, "weather", {"city": "NYC"}) as (span_id, result): # doctest: +SKIP
519 ... response = await http_client.post(...) # doctest: +SKIP
520 ... result["status_code"] = response.status_code # doctest: +SKIP
521 ... result["response_size"] = len(response.content) # doctest: +SKIP
522 """
523 trace_id = current_trace_id.get()
524 if not trace_id:
525 # No active trace, yield a no-op
526 result_dict: Dict[str, Any] = {}
527 yield (None, result_dict)
528 return
530 # Sanitize arguments (remove sensitive data)
531 safe_args = {k: ("***REDACTED***" if any(sensitive in k.lower() for sensitive in ["password", "token", "key", "secret"]) else v) for k, v in arguments.items()}
533 # Start tool invocation span
534 span_id = self.start_span(
535 db=db,
536 trace_id=trace_id,
537 name=f"tool.invoke.{tool_name}",
538 kind="client",
539 resource_type="tool",
540 resource_name=tool_name,
541 attributes={
542 "tool.name": tool_name,
543 "tool.integration_type": integration_type,
544 "tool.argument_count": len(arguments),
545 "tool.arguments": safe_args,
546 },
547 )
549 result_dict = {}
550 try:
551 yield (span_id, result_dict)
553 # End span with results
554 self.end_span(
555 db=db,
556 span_id=span_id,
557 status="ok",
558 attributes={
559 "tool.result": result_dict,
560 },
561 )
562 except Exception as e:
563 # Log error in span
564 self.end_span(db=db, span_id=span_id, status="error", status_message=str(e))
566 self.add_event(
567 db=db,
568 span_id=span_id,
569 name="tool.error",
570 severity="error",
571 message=str(e),
572 exception_type=type(e).__name__,
573 exception_message=str(e),
574 exception_stacktrace=traceback.format_exc(),
575 )
576 raise
578 # ==============================
579 # Event Management
580 # ==============================
582 def add_event(
583 self,
584 db: Session,
585 span_id: str,
586 name: str,
587 severity: Optional[str] = None,
588 message: Optional[str] = None,
589 exception_type: Optional[str] = None,
590 exception_message: Optional[str] = None,
591 exception_stacktrace: Optional[str] = None,
592 attributes: Optional[Dict[str, Any]] = None,
593 ) -> int:
594 """Add an event to a span.
596 Args:
597 db: Database session
598 span_id: Parent span ID
599 name: Event name
600 severity: Log severity (debug, info, warning, error, critical)
601 message: Event message
602 exception_type: Exception class name
603 exception_message: Exception message
604 exception_stacktrace: Exception stacktrace
605 attributes: Additional event attributes
607 Returns:
608 Event ID
610 Examples:
611 >>> event_id = service.add_event( # doctest: +SKIP
612 ... db, # doctest: +SKIP
613 ... span_id, # doctest: +SKIP
614 ... "database_connection_error", # doctest: +SKIP
615 ... severity="error", # doctest: +SKIP
616 ... message="Failed to connect to database" # doctest: +SKIP
617 ... ) # doctest: +SKIP
618 """
619 event = ObservabilityEvent(
620 span_id=span_id,
621 name=name,
622 timestamp=utc_now(),
623 severity=severity,
624 message=message,
625 exception_type=exception_type,
626 exception_message=exception_message,
627 exception_stacktrace=exception_stacktrace,
628 attributes=attributes or {},
629 created_at=utc_now(),
630 )
631 db.add(event)
632 if not self._safe_commit(db, "add_event"):
633 return 0
634 db.refresh(event)
635 logger.debug(f"Added event to span {span_id}: {name}")
636 return event.id
638 # ==============================
639 # Token Usage Tracking
640 # ==============================
642 def record_token_usage(
643 self,
644 db: Session,
645 span_id: Optional[str] = None,
646 trace_id: Optional[str] = None,
647 model: Optional[str] = None,
648 input_tokens: int = 0,
649 output_tokens: int = 0,
650 total_tokens: Optional[int] = None,
651 estimated_cost_usd: Optional[float] = None,
652 provider: Optional[str] = None,
653 ) -> None:
654 """Record token usage for LLM calls.
656 Args:
657 db: Database session
658 span_id: Span ID to attach token usage to
659 trace_id: Trace ID (will use current context if not provided)
660 model: Model name (e.g., "gpt-4", "claude-3-opus")
661 input_tokens: Number of input/prompt tokens
662 output_tokens: Number of output/completion tokens
663 total_tokens: Total tokens (calculated if not provided)
664 estimated_cost_usd: Estimated cost in USD
665 provider: LLM provider (openai, anthropic, etc.)
667 Examples:
668 >>> service.record_token_usage( # doctest: +SKIP
669 ... db, span_id="abc123",
670 ... model="gpt-4",
671 ... input_tokens=100,
672 ... output_tokens=50,
673 ... estimated_cost_usd=0.015
674 ... )
675 """
676 if not trace_id:
677 trace_id = current_trace_id.get()
679 if not trace_id:
680 logger.warning("Cannot record token usage: no active trace")
681 return
683 # Calculate total if not provided
684 if total_tokens is None:
685 total_tokens = input_tokens + output_tokens
687 # Estimate cost if not provided and we have model info
688 if estimated_cost_usd is None and model:
689 estimated_cost_usd = self._estimate_token_cost(model, input_tokens, output_tokens)
691 # Store in span attributes if span_id provided
692 if span_id:
693 span = db.query(ObservabilitySpan).filter_by(span_id=span_id).first()
694 if span:
695 attrs = span.attributes or {}
696 attrs.update(
697 {
698 "llm.model": model,
699 "llm.provider": provider,
700 "llm.input_tokens": input_tokens,
701 "llm.output_tokens": output_tokens,
702 "llm.total_tokens": total_tokens,
703 "llm.estimated_cost_usd": estimated_cost_usd,
704 }
705 )
706 span.attributes = attrs
707 self._safe_commit(db, "record_token_usage")
709 # Also record as metrics for aggregation
710 if input_tokens > 0:
711 self.record_metric(
712 db=db,
713 name="llm.tokens.input",
714 value=float(input_tokens),
715 metric_type="counter",
716 unit="tokens",
717 trace_id=trace_id,
718 attributes={"model": model, "provider": provider},
719 )
721 if output_tokens > 0:
722 self.record_metric(
723 db=db,
724 name="llm.tokens.output",
725 value=float(output_tokens),
726 metric_type="counter",
727 unit="tokens",
728 trace_id=trace_id,
729 attributes={"model": model, "provider": provider},
730 )
732 if estimated_cost_usd:
733 self.record_metric(
734 db=db,
735 name="llm.cost",
736 value=estimated_cost_usd,
737 metric_type="counter",
738 unit="usd",
739 trace_id=trace_id,
740 attributes={"model": model, "provider": provider},
741 )
743 logger.debug(f"Recorded token usage: {input_tokens} in, {output_tokens} out, ${estimated_cost_usd:.6f}")
745 def _estimate_token_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
746 """Estimate cost based on model and token counts.
748 Pricing as of January 2025 (prices may change).
750 Args:
751 model: Model name
752 input_tokens: Input token count
753 output_tokens: Output token count
755 Returns:
756 Estimated cost in USD
757 """
758 # Pricing per 1M tokens (input, output)
759 pricing = {
760 # OpenAI
761 "gpt-4": (30.0, 60.0),
762 "gpt-4-turbo": (10.0, 30.0),
763 "gpt-4o": (2.5, 10.0),
764 "gpt-4o-mini": (0.15, 0.60),
765 "gpt-3.5-turbo": (0.50, 1.50),
766 # Anthropic
767 "claude-3-opus": (15.0, 75.0),
768 "claude-3-sonnet": (3.0, 15.0),
769 "claude-3-haiku": (0.25, 1.25),
770 "claude-3.5-sonnet": (3.0, 15.0),
771 "claude-3.5-haiku": (0.80, 4.0),
772 # Fallback for unknown models
773 "default": (1.0, 3.0),
774 }
776 # Find matching pricing (case-insensitive, partial match)
777 model_lower = model.lower()
778 input_price, output_price = pricing.get("default")
780 for model_key, prices in pricing.items():
781 if model_key in model_lower:
782 input_price, output_price = prices
783 break
785 # Calculate cost (pricing is per 1M tokens)
786 input_cost = (input_tokens / 1_000_000) * input_price
787 output_cost = (output_tokens / 1_000_000) * output_price
789 return input_cost + output_cost
791 # ==============================
792 # Agent-to-Agent (A2A) Tracing
793 # ==============================
795 @contextmanager
796 def trace_a2a_request(
797 self,
798 db: Session,
799 agent_id: str,
800 agent_name: Optional[str] = None,
801 operation: Optional[str] = None,
802 request_data: Optional[Dict[str, Any]] = None,
803 ):
804 """Context manager for tracing Agent-to-Agent requests.
806 This automatically creates a span for A2A communication, capturing timing,
807 request/response data, and errors.
809 Args:
810 db: Database session
811 agent_id: Target agent ID
812 agent_name: Human-readable agent name
813 operation: Operation being performed (e.g., "query", "execute", "status")
814 request_data: Request payload (will be sanitized)
816 Yields:
817 Tuple of (span_id, result_dict) - update result_dict with A2A results
819 Raises:
820 Exception: Re-raises any exception from A2A call after logging
822 Examples:
823 >>> with service.trace_a2a_request(db, "agent-123", "WeatherAgent", "query") as (span_id, result): # doctest: +SKIP
824 ... response = await http_client.post(...) # doctest: +SKIP
825 ... result["status_code"] = response.status_code # doctest: +SKIP
826 ... result["response_time_ms"] = 45.2 # doctest: +SKIP
827 """
828 trace_id = current_trace_id.get()
829 if not trace_id:
830 # No active trace, yield a no-op
831 result_dict: Dict[str, Any] = {}
832 yield (None, result_dict)
833 return
835 # Sanitize request data
836 safe_data = {}
837 if request_data:
838 safe_data = {k: ("***REDACTED***" if any(sensitive in k.lower() for sensitive in ["password", "token", "key", "secret", "auth"]) else v) for k, v in request_data.items()}
840 # Start A2A span
841 span_id = self.start_span(
842 db=db,
843 trace_id=trace_id,
844 name=f"a2a.call.{agent_name or agent_id}",
845 kind="client",
846 resource_type="agent",
847 resource_name=agent_name or agent_id,
848 attributes={
849 "a2a.agent_id": agent_id,
850 "a2a.agent_name": agent_name,
851 "a2a.operation": operation,
852 "a2a.request_data": safe_data,
853 },
854 )
856 result_dict = {}
857 try:
858 yield (span_id, result_dict)
860 # End span with results
861 self.end_span(
862 db=db,
863 span_id=span_id,
864 status="ok",
865 attributes={
866 "a2a.result": result_dict,
867 },
868 )
869 except Exception as e:
870 # Log error in span
871 self.end_span(db=db, span_id=span_id, status="error", status_message=str(e))
873 self.add_event(
874 db=db,
875 span_id=span_id,
876 name="a2a.error",
877 severity="error",
878 message=str(e),
879 exception_type=type(e).__name__,
880 exception_message=str(e),
881 exception_stacktrace=traceback.format_exc(),
882 )
883 raise
885 # ==============================
886 # Transport Metrics
887 # ==============================
889 def record_transport_activity(
890 self,
891 db: Session,
892 transport_type: str,
893 operation: str,
894 message_count: int = 1,
895 bytes_sent: Optional[int] = None,
896 bytes_received: Optional[int] = None,
897 connection_id: Optional[str] = None,
898 error: Optional[str] = None,
899 ) -> None:
900 """Record transport-specific activity metrics.
902 Args:
903 db: Database session
904 transport_type: Transport type (sse, websocket, stdio, http)
905 operation: Operation type (connect, disconnect, send, receive, error)
906 message_count: Number of messages processed
907 bytes_sent: Bytes sent (if applicable)
908 bytes_received: Bytes received (if applicable)
909 connection_id: Connection/session identifier
910 error: Error message if operation failed
912 Examples:
913 >>> service.record_transport_activity( # doctest: +SKIP
914 ... db, transport_type="sse",
915 ... operation="send",
916 ... message_count=1,
917 ... bytes_sent=1024
918 ... )
919 """
920 trace_id = current_trace_id.get()
922 # Record message count
923 if message_count > 0:
924 self.record_metric(
925 db=db,
926 name=f"transport.{transport_type}.messages",
927 value=float(message_count),
928 metric_type="counter",
929 unit="messages",
930 trace_id=trace_id,
931 attributes={
932 "transport": transport_type,
933 "operation": operation,
934 "connection_id": connection_id,
935 },
936 )
938 # Record bytes sent
939 if bytes_sent:
940 self.record_metric(
941 db=db,
942 name=f"transport.{transport_type}.bytes_sent",
943 value=float(bytes_sent),
944 metric_type="counter",
945 unit="bytes",
946 trace_id=trace_id,
947 attributes={
948 "transport": transport_type,
949 "operation": operation,
950 "connection_id": connection_id,
951 },
952 )
954 # Record bytes received
955 if bytes_received:
956 self.record_metric(
957 db=db,
958 name=f"transport.{transport_type}.bytes_received",
959 value=float(bytes_received),
960 metric_type="counter",
961 unit="bytes",
962 trace_id=trace_id,
963 attributes={
964 "transport": transport_type,
965 "operation": operation,
966 "connection_id": connection_id,
967 },
968 )
970 # Record errors
971 if error:
972 self.record_metric(
973 db=db,
974 name=f"transport.{transport_type}.errors",
975 value=1.0,
976 metric_type="counter",
977 unit="errors",
978 trace_id=trace_id,
979 attributes={
980 "transport": transport_type,
981 "operation": operation,
982 "connection_id": connection_id,
983 "error": error,
984 },
985 )
987 logger.debug(f"Recorded {transport_type} transport activity: {operation} ({message_count} messages)")
989 # ==============================
990 # Metric Management
991 # ==============================
993 def record_metric(
994 self,
995 db: Session,
996 name: str,
997 value: float,
998 metric_type: str = "gauge",
999 unit: Optional[str] = None,
1000 resource_type: Optional[str] = None,
1001 resource_id: Optional[str] = None,
1002 trace_id: Optional[str] = None,
1003 attributes: Optional[Dict[str, Any]] = None,
1004 ) -> int:
1005 """Record a metric.
1007 Args:
1008 db: Database session
1009 name: Metric name (e.g., "http.request.duration")
1010 value: Metric value
1011 metric_type: Metric type (counter, gauge, histogram)
1012 unit: Metric unit (ms, count, bytes, etc.)
1013 resource_type: Resource type
1014 resource_id: Resource ID
1015 trace_id: Associated trace ID
1016 attributes: Additional metric attributes/labels
1018 Returns:
1019 Metric ID
1021 Examples:
1022 >>> metric_id = service.record_metric( # doctest: +SKIP
1023 ... db, # doctest: +SKIP
1024 ... "http.request.duration", # doctest: +SKIP
1025 ... 123.45, # doctest: +SKIP
1026 ... metric_type="histogram", # doctest: +SKIP
1027 ... unit="ms", # doctest: +SKIP
1028 ... trace_id=trace_id # doctest: +SKIP
1029 ... ) # doctest: +SKIP
1030 """
1031 metric = ObservabilityMetric(
1032 name=name,
1033 value=value,
1034 metric_type=metric_type,
1035 timestamp=utc_now(),
1036 unit=unit,
1037 resource_type=resource_type,
1038 resource_id=resource_id,
1039 trace_id=trace_id,
1040 attributes=attributes or {},
1041 created_at=utc_now(),
1042 )
1043 db.add(metric)
1044 if not self._safe_commit(db, "record_metric"):
1045 return 0
1046 db.refresh(metric)
1047 logger.debug(f"Recorded metric: {name} = {value} {unit or ''}")
1048 return metric.id
1050 # ==============================
1051 # Query Methods
1052 # ==============================
1054 # pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-locals
def query_traces(
    self,
    db: Session,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    min_duration_ms: Optional[float] = None,
    max_duration_ms: Optional[float] = None,
    status: Optional[str] = None,
    status_in: Optional[List[str]] = None,
    status_not_in: Optional[List[str]] = None,
    http_status_code: Optional[int] = None,
    http_status_code_in: Optional[List[int]] = None,
    http_method: Optional[str] = None,
    http_method_in: Optional[List[str]] = None,
    user_email: Optional[str] = None,
    user_email_in: Optional[List[str]] = None,
    attribute_filters: Optional[Dict[str, Any]] = None,
    attribute_filters_or: Optional[Dict[str, Any]] = None,
    attribute_search: Optional[str] = None,
    name_contains: Optional[str] = None,
    order_by: str = "start_time_desc",
    limit: int = 100,
    offset: int = 0,
) -> List[ObservabilityTrace]:
    """Query traces with advanced filters.

    Supports both simple filters (single value) and list filters (multiple values with OR logic).
    All top-level filters are combined with AND logic unless using _or suffix.

    Args:
        db: Database session
        start_time: Keep traces whose start_time is >= this time
        end_time: Keep traces whose start_time is <= this time
        min_duration_ms: Filter traces with duration >= this value (milliseconds)
        max_duration_ms: Filter traces with duration <= this value (milliseconds)
        status: Filter by single status (ok, error)
        status_in: Filter by multiple statuses (OR logic)
        status_not_in: Exclude these statuses (NOT logic)
        http_status_code: Filter by single HTTP status code
        http_status_code_in: Filter by multiple HTTP status codes (OR logic)
        http_method: Filter by single HTTP method (GET, POST, etc.)
        http_method_in: Filter by multiple HTTP methods (OR logic)
        user_email: Filter by single user email
        user_email_in: Filter by multiple user emails (OR logic)
        attribute_filters: JSON attribute filters (AND logic - all must match);
            values are compared after conversion with ``str()``
        attribute_filters_or: JSON attribute filters (OR logic - any must match);
            values are compared after conversion with ``str()``
        attribute_search: Free-text, case-insensitive search within the JSON
            attributes cast to text (partial match). ``%``, ``_`` and ``\\`` are
            matched literally.
        name_contains: Filter traces where name contains this substring
            (case-insensitive). ``%``, ``_`` and ``\\`` are matched literally.
        order_by: Sort order (start_time_desc, start_time_asc, duration_desc, duration_asc)
        limit: Maximum results (1-1000)
        offset: Result offset

    Returns:
        List of traces

    Raises:
        ValueError: If ``limit`` is outside 1-1000 or ``order_by`` is not a supported value

    Examples:
        >>> # Find slow errors from multiple endpoints
        >>> traces = service.query_traces(  # doctest: +SKIP
        ...     db,
        ...     status="error",
        ...     min_duration_ms=100.0,
        ...     http_method_in=["POST", "PUT"],
        ...     attribute_filters={"http.route": "/api/tools"},
        ...     limit=50
        ... )
        >>> # Exclude health checks and find slow requests
        >>> traces = service.query_traces(  # doctest: +SKIP
        ...     db,
        ...     min_duration_ms=1000.0,
        ...     name_contains="api",
        ...     status_not_in=["ok"],
        ...     order_by="duration_desc"
        ... )
    """
    # Validate arguments first so bad input fails fast, before any import or query work
    if limit < 1 or limit > 1000:
        raise ValueError("limit must be between 1 and 1000")

    valid_orders = ["start_time_desc", "start_time_asc", "duration_desc", "duration_asc"]
    if order_by not in valid_orders:
        raise ValueError(f"order_by must be one of: {', '.join(valid_orders)}")

    # Third-Party
    # pylint: disable=import-outside-toplevel
    from sqlalchemy import cast, or_, String

    # pylint: enable=import-outside-toplevel

    def _like_literal(text: str) -> str:
        """Escape LIKE metacharacters so *text* is matched literally.

        Args:
            text: Raw user-supplied search text

        Returns:
            Text with ``\\``, ``%`` and ``_`` prefixed by a backslash
        """
        # The escape character itself must be handled first, otherwise the
        # backslashes added for % and _ would be escaped a second time.
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    query = db.query(ObservabilityTrace)

    # Time range filters (both bounds apply to the trace start time)
    if start_time:
        query = query.filter(ObservabilityTrace.start_time >= start_time)
    if end_time:
        query = query.filter(ObservabilityTrace.start_time <= end_time)

    # Duration filters ("is not None" so that 0 is a usable bound)
    if min_duration_ms is not None:
        query = query.filter(ObservabilityTrace.duration_ms >= min_duration_ms)
    if max_duration_ms is not None:
        query = query.filter(ObservabilityTrace.duration_ms <= max_duration_ms)

    # Status filters (with OR and NOT support)
    if status:
        query = query.filter(ObservabilityTrace.status == status)
    if status_in:
        query = query.filter(ObservabilityTrace.status.in_(status_in))
    if status_not_in:
        query = query.filter(~ObservabilityTrace.status.in_(status_not_in))

    # HTTP status code filters (with OR support)
    if http_status_code:
        query = query.filter(ObservabilityTrace.http_status_code == http_status_code)
    if http_status_code_in:
        query = query.filter(ObservabilityTrace.http_status_code.in_(http_status_code_in))

    # HTTP method filters (with OR support)
    if http_method:
        query = query.filter(ObservabilityTrace.http_method == http_method)
    if http_method_in:
        query = query.filter(ObservabilityTrace.http_method.in_(http_method_in))

    # User email filters (with OR support)
    if user_email:
        query = query.filter(ObservabilityTrace.user_email == user_email)
    if user_email_in:
        query = query.filter(ObservabilityTrace.user_email.in_(user_email_in))

    # Name substring filter. The explicit escape character is required because
    # not every backend treats backslash as the LIKE escape by default.
    if name_contains:
        query = query.filter(ObservabilityTrace.name.ilike(f"%{_like_literal(name_contains)}%", escape="\\"))

    # Attribute-based filtering with AND logic (all filters must match)
    # NOTE(review): `.astext` looks like the PostgreSQL JSON comparator accessor;
    # confirm the `attributes` column type also exposes it on SQLite.
    if attribute_filters:
        for key, value in attribute_filters.items():
            query = query.filter(ObservabilityTrace.attributes[key].astext == str(value))

    # Attribute-based filtering with OR logic (any filter must match)
    if attribute_filters_or:
        or_conditions = [ObservabilityTrace.attributes[key].astext == str(value) for key, value in attribute_filters_or.items()]
        query = query.filter(or_(*or_conditions))

    # Free-text search: cast the JSON attributes to text and look for the
    # literal substring (wildcards in the search text are escaped).
    if attribute_search:
        safe_search = _like_literal(attribute_search)
        query = query.filter(cast(ObservabilityTrace.attributes, String).ilike(f"%{safe_search}%", escape="\\"))

    # Apply ordering
    if order_by == "start_time_desc":
        query = query.order_by(desc(ObservabilityTrace.start_time))
    elif order_by == "start_time_asc":
        query = query.order_by(ObservabilityTrace.start_time)
    elif order_by == "duration_desc":
        query = query.order_by(desc(ObservabilityTrace.duration_ms))
    else:  # duration_asc (validated above)
        query = query.order_by(ObservabilityTrace.duration_ms)

    # Apply pagination
    query = query.limit(limit).offset(offset)

    return query.all()
1228 # pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-locals
def query_spans(
    self,
    db: Session,
    trace_id: Optional[str] = None,
    trace_id_in: Optional[List[str]] = None,
    resource_type: Optional[str] = None,
    resource_type_in: Optional[List[str]] = None,
    resource_name: Optional[str] = None,
    resource_name_in: Optional[List[str]] = None,
    name_contains: Optional[str] = None,
    kind: Optional[str] = None,
    kind_in: Optional[List[str]] = None,
    status: Optional[str] = None,
    status_in: Optional[List[str]] = None,
    status_not_in: Optional[List[str]] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    min_duration_ms: Optional[float] = None,
    max_duration_ms: Optional[float] = None,
    attribute_filters: Optional[Dict[str, Any]] = None,
    attribute_search: Optional[str] = None,
    order_by: str = "start_time_desc",
    limit: int = 100,
    offset: int = 0,
) -> List[ObservabilitySpan]:
    """Query spans with advanced filters.

    Supports filtering by trace, resource, kind, status, duration, and attributes.
    All top-level filters are combined with AND logic. List filters use OR logic.

    Args:
        db: Database session
        trace_id: Filter by single trace ID
        trace_id_in: Filter by multiple trace IDs (OR logic)
        resource_type: Filter by single resource type (tool, database, plugin, etc.)
        resource_type_in: Filter by multiple resource types (OR logic)
        resource_name: Filter by single resource name
        resource_name_in: Filter by multiple resource names (OR logic)
        name_contains: Filter spans where name contains this substring
            (case-insensitive). ``%``, ``_`` and ``\\`` are matched literally.
        kind: Filter by span kind (client, server, internal)
        kind_in: Filter by multiple kinds (OR logic)
        status: Filter by single status (ok, error)
        status_in: Filter by multiple statuses (OR logic)
        status_not_in: Exclude these statuses (NOT logic)
        start_time: Keep spans whose start_time is >= this time
        end_time: Keep spans whose start_time is <= this time
        min_duration_ms: Filter spans with duration >= this value (milliseconds)
        max_duration_ms: Filter spans with duration <= this value (milliseconds)
        attribute_filters: JSON attribute filters (AND logic); values are
            compared after conversion with ``str()``
        attribute_search: Free-text, case-insensitive search within the JSON
            attributes cast to text. ``%``, ``_`` and ``\\`` are matched literally.
        order_by: Sort order (start_time_desc, start_time_asc, duration_desc, duration_asc)
        limit: Maximum results (1-1000)
        offset: Result offset

    Returns:
        List of spans

    Raises:
        ValueError: If ``limit`` is outside 1-1000 or ``order_by`` is not a supported value

    Examples:
        >>> # Find slow database queries
        >>> spans = service.query_spans(  # doctest: +SKIP
        ...     db,
        ...     resource_type="database",
        ...     min_duration_ms=100.0,
        ...     order_by="duration_desc",
        ...     limit=50
        ... )
        >>> # Find tool invocation errors
        >>> spans = service.query_spans(  # doctest: +SKIP
        ...     db,
        ...     resource_type="tool",
        ...     status="error",
        ...     name_contains="invoke"
        ... )
    """
    # Validate arguments first so bad input fails fast, before any import or query work
    if limit < 1 or limit > 1000:
        raise ValueError("limit must be between 1 and 1000")

    valid_orders = ["start_time_desc", "start_time_asc", "duration_desc", "duration_asc"]
    if order_by not in valid_orders:
        raise ValueError(f"order_by must be one of: {', '.join(valid_orders)}")

    # Third-Party
    # pylint: disable=import-outside-toplevel
    from sqlalchemy import cast, String

    # pylint: enable=import-outside-toplevel

    def _like_literal(text: str) -> str:
        """Escape LIKE metacharacters so *text* is matched literally.

        Args:
            text: Raw user-supplied search text

        Returns:
            Text with ``\\``, ``%`` and ``_`` prefixed by a backslash
        """
        # The escape character itself must be handled first, otherwise the
        # backslashes added for % and _ would be escaped a second time.
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    query = db.query(ObservabilitySpan)

    # Trace ID filters (with OR support)
    if trace_id:
        query = query.filter(ObservabilitySpan.trace_id == trace_id)
    if trace_id_in:
        query = query.filter(ObservabilitySpan.trace_id.in_(trace_id_in))

    # Resource type filters (with OR support)
    if resource_type:
        query = query.filter(ObservabilitySpan.resource_type == resource_type)
    if resource_type_in:
        query = query.filter(ObservabilitySpan.resource_type.in_(resource_type_in))

    # Resource name filters (with OR support)
    if resource_name:
        query = query.filter(ObservabilitySpan.resource_name == resource_name)
    if resource_name_in:
        query = query.filter(ObservabilitySpan.resource_name.in_(resource_name_in))

    # Name substring filter. The explicit escape character is required because
    # not every backend treats backslash as the LIKE escape by default.
    if name_contains:
        query = query.filter(ObservabilitySpan.name.ilike(f"%{_like_literal(name_contains)}%", escape="\\"))

    # Kind filters (with OR support)
    if kind:
        query = query.filter(ObservabilitySpan.kind == kind)
    if kind_in:
        query = query.filter(ObservabilitySpan.kind.in_(kind_in))

    # Status filters (with OR and NOT support)
    if status:
        query = query.filter(ObservabilitySpan.status == status)
    if status_in:
        query = query.filter(ObservabilitySpan.status.in_(status_in))
    if status_not_in:
        query = query.filter(~ObservabilitySpan.status.in_(status_not_in))

    # Time range filters (both bounds apply to the span start time)
    if start_time:
        query = query.filter(ObservabilitySpan.start_time >= start_time)
    if end_time:
        query = query.filter(ObservabilitySpan.start_time <= end_time)

    # Duration filters ("is not None" so that 0 is a usable bound)
    if min_duration_ms is not None:
        query = query.filter(ObservabilitySpan.duration_ms >= min_duration_ms)
    if max_duration_ms is not None:
        query = query.filter(ObservabilitySpan.duration_ms <= max_duration_ms)

    # Attribute-based filtering with AND logic
    # NOTE(review): `.astext` looks like the PostgreSQL JSON comparator accessor;
    # confirm the `attributes` column type also exposes it on SQLite.
    if attribute_filters:
        for key, value in attribute_filters.items():
            query = query.filter(ObservabilitySpan.attributes[key].astext == str(value))

    # Free-text search: cast the JSON attributes to text and look for the
    # literal substring (wildcards in the search text are escaped).
    if attribute_search:
        safe_search = _like_literal(attribute_search)
        query = query.filter(cast(ObservabilitySpan.attributes, String).ilike(f"%{safe_search}%", escape="\\"))

    # Apply ordering
    if order_by == "start_time_desc":
        query = query.order_by(desc(ObservabilitySpan.start_time))
    elif order_by == "start_time_asc":
        query = query.order_by(ObservabilitySpan.start_time)
    elif order_by == "duration_desc":
        query = query.order_by(desc(ObservabilitySpan.duration_ms))
    else:  # duration_asc (validated above)
        query = query.order_by(ObservabilitySpan.duration_ms)

    # Apply pagination
    query = query.limit(limit).offset(offset)

    return query.all()
def get_trace_with_spans(self, db: Session, trace_id: str) -> Optional[ObservabilityTrace]:
    """Fetch one trace together with its spans and each span's events.

    The spans relationship and the nested events relationship are requested
    through ``joinedload`` options on the trace query.

    Args:
        db: Database session
        trace_id: Trace ID to look up

    Returns:
        The first matching trace, or whatever ``first()`` yields when no row matches

    Examples:
        >>> trace = service.get_trace_with_spans(db, trace_id)  # doctest: +SKIP
        >>> if trace:  # doctest: +SKIP
        ...     for span in trace.spans:  # doctest: +SKIP
        ...         print(f"Span: {span.name}, Events: {len(span.events)}")  # doctest: +SKIP
    """
    # Select the trace row by its identifier
    trace_query = db.query(ObservabilityTrace).filter_by(trace_id=trace_id)

    # Load trace -> spans -> events as part of the same query options
    spans_and_events = joinedload(ObservabilityTrace.spans).joinedload(ObservabilitySpan.events)

    return trace_query.options(spans_and_events).first()
def delete_old_traces(self, db: Session, before_time: datetime) -> int:
    """Remove every trace whose start_time is earlier than ``before_time``.

    Args:
        db: Database session
        before_time: Cutoff; traces that started strictly before it are deleted

    Returns:
        Number of traces deleted, or 0 when the commit does not succeed

    Examples:
        >>> from datetime import timedelta  # doctest: +SKIP
        >>> cutoff = utc_now() - timedelta(days=30)  # doctest: +SKIP
        >>> deleted = service.delete_old_traces(db, cutoff)  # doctest: +SKIP
        >>> print(f"Deleted {deleted} old traces")  # doctest: +SKIP
    """
    # Issue a single query-level delete for all traces older than the cutoff
    expired_traces = db.query(ObservabilityTrace).filter(ObservabilityTrace.start_time < before_time)
    deleted = expired_traces.delete()

    # Only report the count once the delete has actually been committed
    if self._safe_commit(db, "delete_old_traces"):
        logger.info(f"Deleted {deleted} traces older than {before_time}")
        return deleted

    return 0