Coverage for mcpgateway / services / observability_service.py: 100%
338 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/observability_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Observability Service Implementation.
8This module provides OpenTelemetry-style observability for ContextForge,
9capturing traces, spans, events, and metrics for all operations.
11It includes:
12- Trace creation and management
13- Span tracking with hierarchical nesting
14- Event logging within spans
15- Metrics collection and storage
16- Query and filtering capabilities
17- Integration with FastAPI middleware
19Examples:
20 >>> from mcpgateway.services.observability_service import ObservabilityService # doctest: +SKIP
21 >>> service = ObservabilityService() # doctest: +SKIP
22 >>> trace_id = service.start_trace(db, "GET /tools", http_method="GET", http_url="/tools") # doctest: +SKIP
23 >>> span_id = service.start_span(db, trace_id, "database_query", resource_type="database") # doctest: +SKIP
24 >>> service.end_span(db, span_id, status="ok") # doctest: +SKIP
25 >>> service.end_trace(db, trace_id, status="ok", http_status_code=200) # doctest: +SKIP
26"""
28# Standard
29from contextlib import contextmanager
30from contextvars import ContextVar
31from datetime import datetime, timezone
32import logging
33import re
34import traceback
35from typing import Any, Dict, List, Optional, Pattern, Tuple
36import uuid
38# Third-Party
39from sqlalchemy import desc
40from sqlalchemy.exc import SQLAlchemyError
41from sqlalchemy.orm import joinedload, Session
43# First-Party
44from mcpgateway.db import ObservabilityEvent, ObservabilityMetric, ObservabilitySpan, ObservabilityTrace
logger = logging.getLogger(__name__)

# W3C Trace Context ``traceparent`` header, compiled once at import time.
# Layout: version(2 hex) - trace-id(32 hex) - parent-id(16 hex) - trace-flags(2 hex).
# Note: the trailing ``$`` also matches just before a final "\n"; use
# ``fullmatch`` when the whole value must match exactly.
_TRACEPARENT_RE: Pattern[str] = re.compile(
    r"^([0-9a-f]{2})"  # version
    r"-([0-9a-f]{32})"  # trace-id
    r"-([0-9a-f]{16})"  # parent-id
    r"-([0-9a-f]{2})$"  # trace-flags
)

# Context variable for tracking the current trace_id across async calls.
# NOTE: The plugin framework maintains a separate ContextVar in
# mcpgateway.plugins.framework.observability.current_trace_id.
# ObservabilityMiddleware bridges both — any new code path that sets this
# variable must also set the framework copy to keep plugin tracing in sync.
current_trace_id: ContextVar[Optional[str]] = ContextVar("current_trace_id", default=None)
def utc_now() -> datetime:
    """Get the current wall-clock time as a timezone-aware UTC datetime.

    Returns:
        datetime: ``datetime.now`` evaluated in UTC, with tzinfo set.
    """
    utc = timezone.utc
    return datetime.now(tz=utc)
def ensure_timezone_aware(dt: datetime) -> datetime:
    """Label a naive datetime as UTC; pass aware datetimes through unchanged.

    SQLite returns naive datetimes even when they were stored with timezone
    info. Labelling them as UTC lets them be subtracted from aware values
    such as ``utc_now()``.

    An already-aware datetime is returned as-is. It keeps whatever timezone
    it carries and is not converted to UTC.

    Args:
        dt: Datetime that may be naive or aware.

    Returns:
        ``dt`` itself when it already has tzinfo. Otherwise a copy of ``dt``
        with ``tzinfo=timezone.utc``; the wall-clock fields are not shifted.
    """
    if dt.tzinfo is not None:
        return dt
    return dt.replace(tzinfo=timezone.utc)
def parse_traceparent(traceparent: str) -> Optional[Tuple[str, str, str]]:
    """Parse W3C Trace Context traceparent header.

    Format: version-trace_id-parent_id-trace_flags
    Example: 00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01

    The entire value must match the format. Values with extra characters,
    such as a trailing newline, are rejected. Hex digits are accepted in
    either case and returned lowercase.

    Args:
        traceparent: W3C traceparent header value

    Returns:
        Tuple of (trace_id, parent_span_id, trace_flags), or None if the value
        is malformed, uses a version other than 00, or has an all-zero
        trace_id/parent_id.

    Examples:
        >>> parse_traceparent("00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01")  # doctest: +SKIP
        ('0af7651916cd43dd8448eb211c80319c', 'b7ad6b7169203331', '01')
    """
    # fullmatch, not match: the precompiled pattern ends in ``$``, which also
    # matches just before a trailing "\n". With ``match``, "00-...-01\n"
    # would be accepted.
    match = _TRACEPARENT_RE.fullmatch(traceparent.lower())

    if not match:
        # The header value is client-controlled. Log it with %r so embedded
        # newlines/control characters cannot forge extra log lines.
        logger.warning("Invalid traceparent format: %r", traceparent)
        return None

    version, trace_id, parent_id, flags = match.groups()

    # Only support version 00 for now
    if version != "00":
        logger.warning("Unsupported traceparent version: %s", version)
        return None

    # Validate trace_id and parent_id are not all zeros
    if trace_id == "0" * 32 or parent_id == "0" * 16:
        logger.warning("Invalid traceparent with zero trace_id or parent_id")
        return None

    return (trace_id, parent_id, flags)
def generate_w3c_trace_id() -> str:
    """Generate a W3C compliant trace ID (32 hex characters).

    Returns:
        32-character lowercase hex string

    Examples:
        >>> trace_id = generate_w3c_trace_id()  # doctest: +SKIP
        >>> len(trace_id)  # doctest: +SKIP
        32
    """
    # A UUID4's hex form is exactly 32 lowercase hex characters (128 bits),
    # the size of a W3C trace-id. Appending more characters would produce an
    # ID that no longer fits the traceparent format. The UUID4 version bits
    # also guarantee the result is never the invalid all-zero ID.
    return uuid.uuid4().hex
def generate_w3c_span_id() -> str:
    """Create a random W3C span ID.

    Returns:
        A 16-character lowercase hex string: the first 16 hex digits of a
        freshly generated UUID4.

    Examples:
        >>> span_id = generate_w3c_span_id()  # doctest: +SKIP
        >>> len(span_id)  # doctest: +SKIP
        16
    """
    full_hex = uuid.uuid4().hex
    return full_hex[:16]
def format_traceparent(trace_id: str, span_id: str, sampled: bool = True) -> str:
    """Build a version-00 W3C ``traceparent`` header value.

    Args:
        trace_id: 32-character hex trace ID
        span_id: 16-character hex span ID
        sampled: When truthy the trace-flags field is ``01``, otherwise ``00``

    Returns:
        W3C traceparent header value

    Examples:
        >>> format_traceparent("0af7651916cd43dd8448eb211c80319c", "b7ad6b7169203331")  # doctest: +SKIP
        '00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01'
    """
    trace_flags = "00"
    if sampled:
        trace_flags = "01"
    return f"00-{trace_id}-{span_id}-{trace_flags}"
172class ObservabilityService:
173 """Service for managing observability traces, spans, events, and metrics.
175 This service provides comprehensive observability capabilities similar to
176 OpenTelemetry, allowing tracking of request flows through the system.
178 Examples:
179 >>> service = ObservabilityService() # doctest: +SKIP
180 >>> trace_id = service.start_trace(db, "POST /tools/invoke") # doctest: +SKIP
181 >>> span_id = service.start_span(db, trace_id, "tool_execution") # doctest: +SKIP
182 >>> service.end_span(db, span_id, status="ok") # doctest: +SKIP
183 >>> service.end_trace(db, trace_id, status="ok") # doctest: +SKIP
184 """
186 def _safe_commit(self, db: Session, context: str) -> bool:
187 """Commit and rollback on failure without raising.
189 Args:
190 db: SQLAlchemy session for the current operation.
191 context: Short label for the commit context (used in logs).
193 Returns:
194 True when commit succeeds, False when a rollback was performed.
195 """
196 try:
197 db.commit()
198 return True
199 except SQLAlchemyError as exc:
200 logger.warning(f"Observability commit failed ({context}): {exc}")
201 try:
202 db.rollback()
203 except SQLAlchemyError as rollback_exc:
204 logger.debug(f"Observability rollback failed ({context}): {rollback_exc}")
205 return False
207 # ==============================
208 # Trace Management
209 # ==============================
211 def start_trace(
212 self,
213 db: Session,
214 name: str,
215 trace_id: Optional[str] = None,
216 parent_span_id: Optional[str] = None,
217 http_method: Optional[str] = None,
218 http_url: Optional[str] = None,
219 user_email: Optional[str] = None,
220 user_agent: Optional[str] = None,
221 ip_address: Optional[str] = None,
222 attributes: Optional[Dict[str, Any]] = None,
223 resource_attributes: Optional[Dict[str, Any]] = None,
224 ) -> str:
225 """Start a new trace.
227 Args:
228 db: Database session
229 name: Trace name (e.g., "POST /tools/invoke")
230 trace_id: External trace ID (for distributed tracing, W3C format)
231 parent_span_id: Parent span ID from upstream service
232 http_method: HTTP method (GET, POST, etc.)
233 http_url: Full request URL
234 user_email: Authenticated user email
235 user_agent: Client user agent string
236 ip_address: Client IP address
237 attributes: Additional trace attributes
238 resource_attributes: Resource attributes (service name, version, etc.)
240 Returns:
241 Trace ID (UUID string or W3C format)
243 Examples:
244 >>> trace_id = service.start_trace( # doctest: +SKIP
245 ... db,
246 ... "POST /tools/invoke",
247 ... http_method="POST",
248 ... http_url="https://api.example.com/tools/invoke",
249 ... user_email="user@example.com"
250 ... )
251 """
252 # Use provided trace_id or generate new UUID
253 if not trace_id:
254 trace_id = str(uuid.uuid4())
256 # Add parent context to attributes if provided
257 attrs = attributes or {}
258 if parent_span_id:
259 attrs["parent_span_id"] = parent_span_id
261 trace = ObservabilityTrace(
262 trace_id=trace_id,
263 name=name,
264 start_time=utc_now(),
265 status="unset",
266 http_method=http_method,
267 http_url=http_url,
268 user_email=user_email,
269 user_agent=user_agent,
270 ip_address=ip_address,
271 attributes=attrs,
272 resource_attributes=resource_attributes or {},
273 created_at=utc_now(),
274 )
275 db.add(trace)
276 self._safe_commit(db, "start_trace")
277 logger.debug(f"Started trace {trace_id}: {name}")
278 return trace_id
280 def end_trace(
281 self,
282 db: Session,
283 trace_id: str,
284 status: str = "ok",
285 status_message: Optional[str] = None,
286 http_status_code: Optional[int] = None,
287 attributes: Optional[Dict[str, Any]] = None,
288 ) -> None:
289 """End a trace.
291 Args:
292 db: Database session
293 trace_id: Trace ID to end
294 status: Trace status (ok, error)
295 status_message: Optional status message
296 http_status_code: HTTP response status code
297 attributes: Additional attributes to merge
299 Examples:
300 >>> service.end_trace( # doctest: +SKIP
301 ... db,
302 ... trace_id,
303 ... status="ok",
304 ... http_status_code=200
305 ... )
306 """
307 trace = db.query(ObservabilityTrace).filter_by(trace_id=trace_id).first()
308 if not trace:
309 logger.warning(f"Trace {trace_id} not found")
310 return
312 end_time = utc_now()
313 duration_ms = (end_time - ensure_timezone_aware(trace.start_time)).total_seconds() * 1000
315 trace.end_time = end_time
316 trace.duration_ms = duration_ms
317 trace.status = status
318 trace.status_message = status_message
319 if http_status_code is not None:
320 trace.http_status_code = http_status_code
321 if attributes:
322 trace.attributes = {**(trace.attributes or {}), **attributes}
324 self._safe_commit(db, "end_trace")
325 logger.debug(f"Ended trace {trace_id}: {status} ({duration_ms:.2f}ms)")
327 def get_trace(self, db: Session, trace_id: str, include_spans: bool = False) -> Optional[ObservabilityTrace]:
328 """Get a trace by ID.
330 Args:
331 db: Database session
332 trace_id: Trace ID
333 include_spans: Whether to load spans eagerly
335 Returns:
336 Trace object or None if not found
338 Examples:
339 >>> trace = service.get_trace(db, trace_id, include_spans=True) # doctest: +SKIP
340 >>> if trace: # doctest: +SKIP
341 ... print(f"Trace: {trace.name}, Spans: {len(trace.spans)}") # doctest: +SKIP
342 """
343 query = db.query(ObservabilityTrace).filter_by(trace_id=trace_id)
344 if include_spans:
345 query = query.options(joinedload(ObservabilityTrace.spans))
346 return query.first()
348 # ==============================
349 # Span Management
350 # ==============================
352 def start_span(
353 self,
354 db: Session,
355 trace_id: str,
356 name: str,
357 parent_span_id: Optional[str] = None,
358 kind: str = "internal",
359 resource_name: Optional[str] = None,
360 resource_type: Optional[str] = None,
361 resource_id: Optional[str] = None,
362 attributes: Optional[Dict[str, Any]] = None,
363 commit: bool = True,
364 ) -> str:
365 """Start a new span within a trace.
367 Args:
368 db: Database session
369 trace_id: Parent trace ID
370 name: Span name (e.g., "database_query", "tool_invocation")
371 parent_span_id: Parent span ID (for nested spans)
372 kind: Span kind (internal, server, client, producer, consumer)
373 resource_name: Resource name being operated on
374 resource_type: Resource type (tool, resource, prompt, etc.)
375 resource_id: Resource ID
376 attributes: Additional span attributes
377 commit: Whether to commit the transaction (default True).
378 Set to False when using fresh_db_session() which handles commits.
380 Returns:
381 Span ID (UUID string)
383 Examples:
384 >>> span_id = service.start_span( # doctest: +SKIP
385 ... db,
386 ... trace_id,
387 ... "tool_invocation",
388 ... resource_type="tool",
389 ... resource_name="get_weather"
390 ... )
391 """
392 span_id = str(uuid.uuid4())
393 span = ObservabilitySpan(
394 span_id=span_id,
395 trace_id=trace_id,
396 parent_span_id=parent_span_id,
397 name=name,
398 kind=kind,
399 start_time=utc_now(),
400 status="unset",
401 resource_name=resource_name,
402 resource_type=resource_type,
403 resource_id=resource_id,
404 attributes=attributes or {},
405 created_at=utc_now(),
406 )
407 db.add(span)
408 if commit:
409 self._safe_commit(db, "start_span")
410 logger.debug(f"Started span {span_id}: {name} (trace={trace_id})")
411 return span_id
413 def end_span(
414 self,
415 db: Session,
416 span_id: str,
417 status: str = "ok",
418 status_message: Optional[str] = None,
419 attributes: Optional[Dict[str, Any]] = None,
420 commit: bool = True,
421 ) -> None:
422 """End a span.
424 Args:
425 db: Database session
426 span_id: Span ID to end
427 status: Span status (ok, error)
428 status_message: Optional status message
429 attributes: Additional attributes to merge
430 commit: Whether to commit the transaction (default True).
431 Set to False when using fresh_db_session() which handles commits.
433 Examples:
434 >>> service.end_span(db, span_id, status="ok") # doctest: +SKIP
435 """
436 span = db.query(ObservabilitySpan).filter_by(span_id=span_id).first()
437 if not span:
438 logger.warning(f"Span {span_id} not found")
439 return
441 end_time = utc_now()
442 duration_ms = (end_time - ensure_timezone_aware(span.start_time)).total_seconds() * 1000
444 span.end_time = end_time
445 span.duration_ms = duration_ms
446 span.status = status
447 span.status_message = status_message
448 if attributes:
449 span.attributes = {**(span.attributes or {}), **attributes}
451 if commit:
452 self._safe_commit(db, "end_span")
453 logger.debug(f"Ended span {span_id}: {status} ({duration_ms:.2f}ms)")
455 @contextmanager
456 def trace_span(
457 self,
458 db: Session,
459 trace_id: str,
460 name: str,
461 parent_span_id: Optional[str] = None,
462 resource_type: Optional[str] = None,
463 resource_name: Optional[str] = None,
464 attributes: Optional[Dict[str, Any]] = None,
465 ):
466 """Context manager for automatic span lifecycle management.
468 Args:
469 db: Database session
470 trace_id: Parent trace ID
471 name: Span name
472 parent_span_id: Parent span ID (optional)
473 resource_type: Resource type
474 resource_name: Resource name
475 attributes: Additional attributes
477 Yields:
478 Span ID
480 Raises:
481 Exception: Re-raises any exception after logging it in the span
483 Examples:
484 >>> with service.trace_span(db, trace_id, "database_query") as span_id: # doctest: +SKIP
485 ... results = db.query(Tool).all() # doctest: +SKIP
486 """
487 span_id = self.start_span(db, trace_id, name, parent_span_id, resource_type=resource_type, resource_name=resource_name, attributes=attributes)
488 try:
489 yield span_id
490 self.end_span(db, span_id, status="ok")
491 except Exception as e:
492 self.end_span(db, span_id, status="error", status_message=str(e))
493 self.add_event(db, span_id, "exception", severity="error", message=str(e), exception_type=type(e).__name__, exception_message=str(e), exception_stacktrace=traceback.format_exc())
494 raise
496 @contextmanager
497 def trace_tool_invocation(
498 self,
499 db: Session,
500 tool_name: str,
501 arguments: Dict[str, Any],
502 integration_type: Optional[str] = None,
503 ):
504 """Context manager for tracing MCP tool invocations.
506 This automatically creates a span for tool execution, capturing timing,
507 arguments, results, and errors.
509 Args:
510 db: Database session
511 tool_name: Name of the tool being invoked
512 arguments: Tool arguments (will be sanitized)
513 integration_type: Integration type (MCP, REST, A2A, etc.)
515 Yields:
516 Tuple of (span_id, result_dict) - update result_dict with tool results
518 Raises:
519 Exception: Re-raises any exception from tool invocation after logging
521 Examples:
522 >>> with service.trace_tool_invocation(db, "weather", {"city": "NYC"}) as (span_id, result): # doctest: +SKIP
523 ... response = await http_client.post(...) # doctest: +SKIP
524 ... result["status_code"] = response.status_code # doctest: +SKIP
525 ... result["response_size"] = len(response.content) # doctest: +SKIP
526 """
527 trace_id = current_trace_id.get()
528 if not trace_id:
529 # No active trace, yield a no-op
530 result_dict: Dict[str, Any] = {}
531 yield (None, result_dict)
532 return
534 # Sanitize arguments (remove sensitive data)
535 safe_args = {k: ("***REDACTED***" if any(sensitive in k.lower() for sensitive in ["password", "token", "key", "secret"]) else v) for k, v in arguments.items()}
537 # Start tool invocation span
538 span_id = self.start_span(
539 db=db,
540 trace_id=trace_id,
541 name=f"tool.invoke.{tool_name}",
542 kind="client",
543 resource_type="tool",
544 resource_name=tool_name,
545 attributes={
546 "tool.name": tool_name,
547 "tool.integration_type": integration_type,
548 "tool.argument_count": len(arguments),
549 "tool.arguments": safe_args,
550 },
551 )
553 result_dict = {}
554 try:
555 yield (span_id, result_dict)
557 # End span with results
558 self.end_span(
559 db=db,
560 span_id=span_id,
561 status="ok",
562 attributes={
563 "tool.result": result_dict,
564 },
565 )
566 except Exception as e:
567 # Log error in span
568 self.end_span(db=db, span_id=span_id, status="error", status_message=str(e))
570 self.add_event(
571 db=db,
572 span_id=span_id,
573 name="tool.error",
574 severity="error",
575 message=str(e),
576 exception_type=type(e).__name__,
577 exception_message=str(e),
578 exception_stacktrace=traceback.format_exc(),
579 )
580 raise
582 # ==============================
583 # Event Management
584 # ==============================
586 def add_event(
587 self,
588 db: Session,
589 span_id: str,
590 name: str,
591 severity: Optional[str] = None,
592 message: Optional[str] = None,
593 exception_type: Optional[str] = None,
594 exception_message: Optional[str] = None,
595 exception_stacktrace: Optional[str] = None,
596 attributes: Optional[Dict[str, Any]] = None,
597 ) -> int:
598 """Add an event to a span.
600 Args:
601 db: Database session
602 span_id: Parent span ID
603 name: Event name
604 severity: Log severity (debug, info, warning, error, critical)
605 message: Event message
606 exception_type: Exception class name
607 exception_message: Exception message
608 exception_stacktrace: Exception stacktrace
609 attributes: Additional event attributes
611 Returns:
612 Event ID
614 Examples:
615 >>> event_id = service.add_event( # doctest: +SKIP
616 ... db, # doctest: +SKIP
617 ... span_id, # doctest: +SKIP
618 ... "database_connection_error", # doctest: +SKIP
619 ... severity="error", # doctest: +SKIP
620 ... message="Failed to connect to database" # doctest: +SKIP
621 ... ) # doctest: +SKIP
622 """
623 event = ObservabilityEvent(
624 span_id=span_id,
625 name=name,
626 timestamp=utc_now(),
627 severity=severity,
628 message=message,
629 exception_type=exception_type,
630 exception_message=exception_message,
631 exception_stacktrace=exception_stacktrace,
632 attributes=attributes or {},
633 created_at=utc_now(),
634 )
635 db.add(event)
636 if not self._safe_commit(db, "add_event"):
637 return 0
638 db.refresh(event)
639 logger.debug(f"Added event to span {span_id}: {name}")
640 return event.id
642 # ==============================
643 # Token Usage Tracking
644 # ==============================
646 def record_token_usage(
647 self,
648 db: Session,
649 span_id: Optional[str] = None,
650 trace_id: Optional[str] = None,
651 model: Optional[str] = None,
652 input_tokens: int = 0,
653 output_tokens: int = 0,
654 total_tokens: Optional[int] = None,
655 estimated_cost_usd: Optional[float] = None,
656 provider: Optional[str] = None,
657 ) -> None:
658 """Record token usage for LLM calls.
660 Args:
661 db: Database session
662 span_id: Span ID to attach token usage to
663 trace_id: Trace ID (will use current context if not provided)
664 model: Model name (e.g., "gpt-4", "claude-3-opus")
665 input_tokens: Number of input/prompt tokens
666 output_tokens: Number of output/completion tokens
667 total_tokens: Total tokens (calculated if not provided)
668 estimated_cost_usd: Estimated cost in USD
669 provider: LLM provider (openai, anthropic, etc.)
671 Examples:
672 >>> service.record_token_usage( # doctest: +SKIP
673 ... db, span_id="abc123",
674 ... model="gpt-4",
675 ... input_tokens=100,
676 ... output_tokens=50,
677 ... estimated_cost_usd=0.015
678 ... )
679 """
680 if not trace_id:
681 trace_id = current_trace_id.get()
683 if not trace_id:
684 logger.warning("Cannot record token usage: no active trace")
685 return
687 # Calculate total if not provided
688 if total_tokens is None:
689 total_tokens = input_tokens + output_tokens
691 # Estimate cost if not provided and we have model info
692 if estimated_cost_usd is None and model:
693 estimated_cost_usd = self._estimate_token_cost(model, input_tokens, output_tokens)
695 # Store in span attributes if span_id provided
696 if span_id:
697 span = db.query(ObservabilitySpan).filter_by(span_id=span_id).first()
698 if span:
699 attrs = span.attributes or {}
700 attrs.update(
701 {
702 "llm.model": model,
703 "llm.provider": provider,
704 "llm.input_tokens": input_tokens,
705 "llm.output_tokens": output_tokens,
706 "llm.total_tokens": total_tokens,
707 "llm.estimated_cost_usd": estimated_cost_usd,
708 }
709 )
710 span.attributes = attrs
711 self._safe_commit(db, "record_token_usage")
713 # Also record as metrics for aggregation
714 if input_tokens > 0:
715 self.record_metric(
716 db=db,
717 name="llm.tokens.input",
718 value=float(input_tokens),
719 metric_type="counter",
720 unit="tokens",
721 trace_id=trace_id,
722 attributes={"model": model, "provider": provider},
723 )
725 if output_tokens > 0:
726 self.record_metric(
727 db=db,
728 name="llm.tokens.output",
729 value=float(output_tokens),
730 metric_type="counter",
731 unit="tokens",
732 trace_id=trace_id,
733 attributes={"model": model, "provider": provider},
734 )
736 if estimated_cost_usd:
737 self.record_metric(
738 db=db,
739 name="llm.cost",
740 value=estimated_cost_usd,
741 metric_type="counter",
742 unit="usd",
743 trace_id=trace_id,
744 attributes={"model": model, "provider": provider},
745 )
747 logger.debug(f"Recorded token usage: {input_tokens} in, {output_tokens} out, ${estimated_cost_usd:.6f}")
749 def _estimate_token_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
750 """Estimate cost based on model and token counts.
752 Pricing as of January 2025 (prices may change).
754 Args:
755 model: Model name
756 input_tokens: Input token count
757 output_tokens: Output token count
759 Returns:
760 Estimated cost in USD
761 """
762 # Pricing per 1M tokens (input, output)
763 pricing = {
764 # OpenAI
765 "gpt-4": (30.0, 60.0),
766 "gpt-4-turbo": (10.0, 30.0),
767 "gpt-4o": (2.5, 10.0),
768 "gpt-4o-mini": (0.15, 0.60),
769 "gpt-3.5-turbo": (0.50, 1.50),
770 # Anthropic
771 "claude-3-opus": (15.0, 75.0),
772 "claude-3-sonnet": (3.0, 15.0),
773 "claude-3-haiku": (0.25, 1.25),
774 "claude-3.5-sonnet": (3.0, 15.0),
775 "claude-3.5-haiku": (0.80, 4.0),
776 # Fallback for unknown models
777 "default": (1.0, 3.0),
778 }
780 # Find matching pricing (case-insensitive, partial match)
781 model_lower = model.lower()
782 input_price, output_price = pricing.get("default")
784 for model_key, prices in pricing.items():
785 if model_key in model_lower:
786 input_price, output_price = prices
787 break
789 # Calculate cost (pricing is per 1M tokens)
790 input_cost = (input_tokens / 1_000_000) * input_price
791 output_cost = (output_tokens / 1_000_000) * output_price
793 return input_cost + output_cost
795 # ==============================
796 # Agent-to-Agent (A2A) Tracing
797 # ==============================
799 @contextmanager
800 def trace_a2a_request(
801 self,
802 db: Session,
803 agent_id: str,
804 agent_name: Optional[str] = None,
805 operation: Optional[str] = None,
806 request_data: Optional[Dict[str, Any]] = None,
807 ):
808 """Context manager for tracing Agent-to-Agent requests.
810 This automatically creates a span for A2A communication, capturing timing,
811 request/response data, and errors.
813 Args:
814 db: Database session
815 agent_id: Target agent ID
816 agent_name: Human-readable agent name
817 operation: Operation being performed (e.g., "query", "execute", "status")
818 request_data: Request payload (will be sanitized)
820 Yields:
821 Tuple of (span_id, result_dict) - update result_dict with A2A results
823 Raises:
824 Exception: Re-raises any exception from A2A call after logging
826 Examples:
827 >>> with service.trace_a2a_request(db, "agent-123", "WeatherAgent", "query") as (span_id, result): # doctest: +SKIP
828 ... response = await http_client.post(...) # doctest: +SKIP
829 ... result["status_code"] = response.status_code # doctest: +SKIP
830 ... result["response_time_ms"] = 45.2 # doctest: +SKIP
831 """
832 trace_id = current_trace_id.get()
833 if not trace_id:
834 # No active trace, yield a no-op
835 result_dict: Dict[str, Any] = {}
836 yield (None, result_dict)
837 return
839 # Sanitize request data
840 safe_data = {}
841 if request_data:
842 safe_data = {k: ("***REDACTED***" if any(sensitive in k.lower() for sensitive in ["password", "token", "key", "secret", "auth"]) else v) for k, v in request_data.items()}
844 # Start A2A span
845 span_id = self.start_span(
846 db=db,
847 trace_id=trace_id,
848 name=f"a2a.call.{agent_name or agent_id}",
849 kind="client",
850 resource_type="agent",
851 resource_name=agent_name or agent_id,
852 attributes={
853 "a2a.agent_id": agent_id,
854 "a2a.agent_name": agent_name,
855 "a2a.operation": operation,
856 "a2a.request_data": safe_data,
857 },
858 )
860 result_dict = {}
861 try:
862 yield (span_id, result_dict)
864 # End span with results
865 self.end_span(
866 db=db,
867 span_id=span_id,
868 status="ok",
869 attributes={
870 "a2a.result": result_dict,
871 },
872 )
873 except Exception as e:
874 # Log error in span
875 self.end_span(db=db, span_id=span_id, status="error", status_message=str(e))
877 self.add_event(
878 db=db,
879 span_id=span_id,
880 name="a2a.error",
881 severity="error",
882 message=str(e),
883 exception_type=type(e).__name__,
884 exception_message=str(e),
885 exception_stacktrace=traceback.format_exc(),
886 )
887 raise
889 # ==============================
890 # Transport Metrics
891 # ==============================
893 def record_transport_activity(
894 self,
895 db: Session,
896 transport_type: str,
897 operation: str,
898 message_count: int = 1,
899 bytes_sent: Optional[int] = None,
900 bytes_received: Optional[int] = None,
901 connection_id: Optional[str] = None,
902 error: Optional[str] = None,
903 ) -> None:
904 """Record transport-specific activity metrics.
906 Args:
907 db: Database session
908 transport_type: Transport type (sse, websocket, stdio, http)
909 operation: Operation type (connect, disconnect, send, receive, error)
910 message_count: Number of messages processed
911 bytes_sent: Bytes sent (if applicable)
912 bytes_received: Bytes received (if applicable)
913 connection_id: Connection/session identifier
914 error: Error message if operation failed
916 Examples:
917 >>> service.record_transport_activity( # doctest: +SKIP
918 ... db, transport_type="sse",
919 ... operation="send",
920 ... message_count=1,
921 ... bytes_sent=1024
922 ... )
923 """
924 trace_id = current_trace_id.get()
926 # Record message count
927 if message_count > 0:
928 self.record_metric(
929 db=db,
930 name=f"transport.{transport_type}.messages",
931 value=float(message_count),
932 metric_type="counter",
933 unit="messages",
934 trace_id=trace_id,
935 attributes={
936 "transport": transport_type,
937 "operation": operation,
938 "connection_id": connection_id,
939 },
940 )
942 # Record bytes sent
943 if bytes_sent:
944 self.record_metric(
945 db=db,
946 name=f"transport.{transport_type}.bytes_sent",
947 value=float(bytes_sent),
948 metric_type="counter",
949 unit="bytes",
950 trace_id=trace_id,
951 attributes={
952 "transport": transport_type,
953 "operation": operation,
954 "connection_id": connection_id,
955 },
956 )
958 # Record bytes received
959 if bytes_received:
960 self.record_metric(
961 db=db,
962 name=f"transport.{transport_type}.bytes_received",
963 value=float(bytes_received),
964 metric_type="counter",
965 unit="bytes",
966 trace_id=trace_id,
967 attributes={
968 "transport": transport_type,
969 "operation": operation,
970 "connection_id": connection_id,
971 },
972 )
974 # Record errors
975 if error:
976 self.record_metric(
977 db=db,
978 name=f"transport.{transport_type}.errors",
979 value=1.0,
980 metric_type="counter",
981 unit="errors",
982 trace_id=trace_id,
983 attributes={
984 "transport": transport_type,
985 "operation": operation,
986 "connection_id": connection_id,
987 "error": error,
988 },
989 )
991 logger.debug(f"Recorded {transport_type} transport activity: {operation} ({message_count} messages)")
993 # ==============================
994 # Metric Management
995 # ==============================
997 def record_metric(
998 self,
999 db: Session,
1000 name: str,
1001 value: float,
1002 metric_type: str = "gauge",
1003 unit: Optional[str] = None,
1004 resource_type: Optional[str] = None,
1005 resource_id: Optional[str] = None,
1006 trace_id: Optional[str] = None,
1007 attributes: Optional[Dict[str, Any]] = None,
1008 ) -> int:
1009 """Record a metric.
1011 Args:
1012 db: Database session
1013 name: Metric name (e.g., "http.request.duration")
1014 value: Metric value
1015 metric_type: Metric type (counter, gauge, histogram)
1016 unit: Metric unit (ms, count, bytes, etc.)
1017 resource_type: Resource type
1018 resource_id: Resource ID
1019 trace_id: Associated trace ID
1020 attributes: Additional metric attributes/labels
1022 Returns:
1023 Metric ID
1025 Examples:
1026 >>> metric_id = service.record_metric( # doctest: +SKIP
1027 ... db, # doctest: +SKIP
1028 ... "http.request.duration", # doctest: +SKIP
1029 ... 123.45, # doctest: +SKIP
1030 ... metric_type="histogram", # doctest: +SKIP
1031 ... unit="ms", # doctest: +SKIP
1032 ... trace_id=trace_id # doctest: +SKIP
1033 ... ) # doctest: +SKIP
1034 """
1035 metric = ObservabilityMetric(
1036 name=name,
1037 value=value,
1038 metric_type=metric_type,
1039 timestamp=utc_now(),
1040 unit=unit,
1041 resource_type=resource_type,
1042 resource_id=resource_id,
1043 trace_id=trace_id,
1044 attributes=attributes or {},
1045 created_at=utc_now(),
1046 )
1047 db.add(metric)
1048 if not self._safe_commit(db, "record_metric"):
1049 return 0
1050 db.refresh(metric)
1051 logger.debug(f"Recorded metric: {name} = {value} {unit or ''}")
1052 return metric.id
1054 # ==============================
1055 # Query Methods
1056 # ==============================
1058 # pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-locals
def query_traces(  # noqa: PLR0917
    self,
    db: Session,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    min_duration_ms: Optional[float] = None,
    max_duration_ms: Optional[float] = None,
    status: Optional[str] = None,
    status_in: Optional[List[str]] = None,
    status_not_in: Optional[List[str]] = None,
    http_status_code: Optional[int] = None,
    http_status_code_in: Optional[List[int]] = None,
    http_method: Optional[str] = None,
    http_method_in: Optional[List[str]] = None,
    user_email: Optional[str] = None,
    user_email_in: Optional[List[str]] = None,
    attribute_filters: Optional[Dict[str, Any]] = None,
    attribute_filters_or: Optional[Dict[str, Any]] = None,
    attribute_search: Optional[str] = None,
    name_contains: Optional[str] = None,
    order_by: str = "start_time_desc",
    limit: int = 100,
    offset: int = 0,
) -> List[ObservabilityTrace]:
    """Query traces with advanced filters.

    Supports both simple filters (single value) and list filters (multiple values with OR logic).
    All top-level filters are combined with AND logic unless using _or suffix.

    Args:
        db: Database session
        start_time: Filter traces whose start_time is >= this time
        end_time: Filter traces whose start_time is <= this time
        min_duration_ms: Filter traces with duration >= this value (milliseconds)
        max_duration_ms: Filter traces with duration <= this value (milliseconds)
        status: Filter by single status (ok, error)
        status_in: Filter by multiple statuses (OR logic)
        status_not_in: Exclude these statuses (NOT logic)
        http_status_code: Filter by single HTTP status code
        http_status_code_in: Filter by multiple HTTP status codes (OR logic)
        http_method: Filter by single HTTP method (GET, POST, etc.)
        http_method_in: Filter by multiple HTTP methods (OR logic)
        user_email: Filter by single user email
        user_email_in: Filter by multiple user emails (OR logic)
        attribute_filters: JSON attribute filters (AND logic - all must match); values compared as ``str(value)``
        attribute_filters_or: JSON attribute filters (OR logic - any must match); values compared as ``str(value)``
        attribute_search: Case-insensitive substring search within the JSON attributes cast to text.
            ``%``, ``_`` and backslashes in the search text are matched literally.
        name_contains: Case-insensitive filter for traces whose name contains this substring.
            ``%``, ``_`` and backslashes are matched literally.
        order_by: Sort order (start_time_desc, start_time_asc, duration_desc, duration_asc)
        limit: Maximum results (1-1000)
        offset: Result offset (must be >= 0)

    Returns:
        List of traces

    Raises:
        ValueError: If limit is outside 1-1000, offset is negative, or order_by is not supported

    Examples:
        >>> # Find slow errors from multiple endpoints
        >>> traces = service.query_traces(  # doctest: +SKIP
        ...     db,
        ...     status="error",
        ...     min_duration_ms=100.0,
        ...     http_method_in=["POST", "PUT"],
        ...     attribute_filters={"http.route": "/api/tools"},
        ...     limit=50
        ... )
        >>> # Exclude health checks and find slow requests
        >>> traces = service.query_traces(  # doctest: +SKIP
        ...     db,
        ...     min_duration_ms=1000.0,
        ...     name_contains="api",
        ...     status_not_in=["ok"],
        ...     order_by="duration_desc"
        ... )
    """
    # Third-Party
    # pylint: disable=import-outside-toplevel
    from sqlalchemy import cast, or_, String

    # pylint: enable=import-outside-toplevel

    def _contains_pattern(text: str) -> str:
        """Build an ILIKE substring pattern that treats backslashes, ``%`` and ``_`` in *text* literally."""
        # Escape the escape character first; otherwise the backslashes added for % and _ would be doubled.
        escaped = text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        return f"%{escaped}%"

    # Supported sort orders mapped to their ORDER BY clause (insertion order drives the error message).
    order_clauses = {
        "start_time_desc": desc(ObservabilityTrace.start_time),
        "start_time_asc": ObservabilityTrace.start_time,
        "duration_desc": desc(ObservabilityTrace.duration_ms),
        "duration_asc": ObservabilityTrace.duration_ms,
    }

    # Validate pagination
    if limit < 1 or limit > 1000:
        raise ValueError("limit must be between 1 and 1000")
    if offset < 0:
        raise ValueError("offset must be non-negative")

    # Validate order_by
    if order_by not in order_clauses:
        raise ValueError(f"order_by must be one of: {', '.join(order_clauses)}")

    query = db.query(ObservabilityTrace)

    # Time range filters (both bounds apply to the trace start time)
    if start_time:
        query = query.filter(ObservabilityTrace.start_time >= start_time)
    if end_time:
        query = query.filter(ObservabilityTrace.start_time <= end_time)

    # Duration filters (explicit None checks so 0.0 is a valid bound)
    if min_duration_ms is not None:
        query = query.filter(ObservabilityTrace.duration_ms >= min_duration_ms)
    if max_duration_ms is not None:
        query = query.filter(ObservabilityTrace.duration_ms <= max_duration_ms)

    # Status filters (with OR and NOT support)
    if status:
        query = query.filter(ObservabilityTrace.status == status)
    if status_in:
        query = query.filter(ObservabilityTrace.status.in_(status_in))
    if status_not_in:
        query = query.filter(~ObservabilityTrace.status.in_(status_not_in))

    # HTTP status code filters (with OR support)
    if http_status_code:
        query = query.filter(ObservabilityTrace.http_status_code == http_status_code)
    if http_status_code_in:
        query = query.filter(ObservabilityTrace.http_status_code.in_(http_status_code_in))

    # HTTP method filters (with OR support)
    if http_method:
        query = query.filter(ObservabilityTrace.http_method == http_method)
    if http_method_in:
        query = query.filter(ObservabilityTrace.http_method.in_(http_method_in))

    # User email filters (with OR support)
    if user_email:
        query = query.filter(ObservabilityTrace.user_email == user_email)
    if user_email_in:
        query = query.filter(ObservabilityTrace.user_email.in_(user_email_in))

    # Name substring filter; explicit ESCAPE so wildcards in user input match literally on every backend
    if name_contains:
        query = query.filter(ObservabilityTrace.name.ilike(_contains_pattern(name_contains), escape="\\"))

    # Attribute-based filtering with AND logic (all filters must match).
    # NOTE(review): `.astext` is provided by SQLAlchemy's PostgreSQL JSON comparator; confirm the
    # `attributes` column type supports it on every backend in use (generic JSON exposes `.as_string()`).
    if attribute_filters:
        for key, value in attribute_filters.items():
            query = query.filter(ObservabilityTrace.attributes[key].astext == str(value))

    # Attribute-based filtering with OR logic (any filter must match)
    if attribute_filters_or:
        or_conditions = [ObservabilityTrace.attributes[key].astext == str(value) for key, value in attribute_filters_or.items()]
        query = query.filter(or_(*or_conditions))

    # Free-text search across the serialized attributes. Values are bound parameters, so this
    # escaping is about LIKE wildcards (not SQL injection); SQLite has no default LIKE escape
    # character, hence the explicit escape argument.
    if attribute_search:
        query = query.filter(cast(ObservabilityTrace.attributes, String).ilike(_contains_pattern(attribute_search), escape="\\"))

    # Apply ordering (order_by validated above)
    query = query.order_by(order_clauses[order_by])

    # Apply pagination
    query = query.limit(limit).offset(offset)

    return query.all()
1232 # pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-locals
def query_spans(  # noqa: PLR0917
    self,
    db: Session,
    trace_id: Optional[str] = None,
    trace_id_in: Optional[List[str]] = None,
    resource_type: Optional[str] = None,
    resource_type_in: Optional[List[str]] = None,
    resource_name: Optional[str] = None,
    resource_name_in: Optional[List[str]] = None,
    name_contains: Optional[str] = None,
    kind: Optional[str] = None,
    kind_in: Optional[List[str]] = None,
    status: Optional[str] = None,
    status_in: Optional[List[str]] = None,
    status_not_in: Optional[List[str]] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    min_duration_ms: Optional[float] = None,
    max_duration_ms: Optional[float] = None,
    attribute_filters: Optional[Dict[str, Any]] = None,
    attribute_search: Optional[str] = None,
    order_by: str = "start_time_desc",
    limit: int = 100,
    offset: int = 0,
) -> List[ObservabilitySpan]:
    """Query spans with advanced filters.

    Supports filtering by trace, resource, kind, status, duration, and attributes.
    All top-level filters are combined with AND logic. List filters use OR logic.

    Args:
        db: Database session
        trace_id: Filter by single trace ID
        trace_id_in: Filter by multiple trace IDs (OR logic)
        resource_type: Filter by single resource type (tool, database, plugin, etc.)
        resource_type_in: Filter by multiple resource types (OR logic)
        resource_name: Filter by single resource name
        resource_name_in: Filter by multiple resource names (OR logic)
        name_contains: Case-insensitive filter for spans whose name contains this substring.
            ``%``, ``_`` and backslashes are matched literally.
        kind: Filter by span kind (client, server, internal)
        kind_in: Filter by multiple kinds (OR logic)
        status: Filter by single status (ok, error)
        status_in: Filter by multiple statuses (OR logic)
        status_not_in: Exclude these statuses (NOT logic)
        start_time: Filter spans whose start_time is >= this time
        end_time: Filter spans whose start_time is <= this time
        min_duration_ms: Filter spans with duration >= this value (milliseconds)
        max_duration_ms: Filter spans with duration <= this value (milliseconds)
        attribute_filters: JSON attribute filters (AND logic); values compared as ``str(value)``
        attribute_search: Case-insensitive substring search within the JSON attributes cast to text.
            ``%``, ``_`` and backslashes in the search text are matched literally.
        order_by: Sort order (start_time_desc, start_time_asc, duration_desc, duration_asc)
        limit: Maximum results (1-1000)
        offset: Result offset (must be >= 0)

    Returns:
        List of spans

    Raises:
        ValueError: If limit is outside 1-1000, offset is negative, or order_by is not supported

    Examples:
        >>> # Find slow database queries
        >>> spans = service.query_spans(  # doctest: +SKIP
        ...     db,
        ...     resource_type="database",
        ...     min_duration_ms=100.0,
        ...     order_by="duration_desc",
        ...     limit=50
        ... )
        >>> # Find tool invocation errors
        >>> spans = service.query_spans(  # doctest: +SKIP
        ...     db,
        ...     resource_type="tool",
        ...     status="error",
        ...     name_contains="invoke"
        ... )
    """
    # Third-Party
    # pylint: disable=import-outside-toplevel
    from sqlalchemy import cast, String

    # pylint: enable=import-outside-toplevel

    def _contains_pattern(text: str) -> str:
        """Build an ILIKE substring pattern that treats backslashes, ``%`` and ``_`` in *text* literally."""
        # Escape the escape character first; otherwise the backslashes added for % and _ would be doubled.
        escaped = text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        return f"%{escaped}%"

    # Supported sort orders mapped to their ORDER BY clause (insertion order drives the error message).
    order_clauses = {
        "start_time_desc": desc(ObservabilitySpan.start_time),
        "start_time_asc": ObservabilitySpan.start_time,
        "duration_desc": desc(ObservabilitySpan.duration_ms),
        "duration_asc": ObservabilitySpan.duration_ms,
    }

    # Validate pagination
    if limit < 1 or limit > 1000:
        raise ValueError("limit must be between 1 and 1000")
    if offset < 0:
        raise ValueError("offset must be non-negative")

    # Validate order_by
    if order_by not in order_clauses:
        raise ValueError(f"order_by must be one of: {', '.join(order_clauses)}")

    query = db.query(ObservabilitySpan)

    # Trace ID filters (with OR support)
    if trace_id:
        query = query.filter(ObservabilitySpan.trace_id == trace_id)
    if trace_id_in:
        query = query.filter(ObservabilitySpan.trace_id.in_(trace_id_in))

    # Resource type filters (with OR support)
    if resource_type:
        query = query.filter(ObservabilitySpan.resource_type == resource_type)
    if resource_type_in:
        query = query.filter(ObservabilitySpan.resource_type.in_(resource_type_in))

    # Resource name filters (with OR support)
    if resource_name:
        query = query.filter(ObservabilitySpan.resource_name == resource_name)
    if resource_name_in:
        query = query.filter(ObservabilitySpan.resource_name.in_(resource_name_in))

    # Name substring filter; explicit ESCAPE so wildcards in user input match literally on every backend
    if name_contains:
        query = query.filter(ObservabilitySpan.name.ilike(_contains_pattern(name_contains), escape="\\"))

    # Kind filters (with OR support)
    if kind:
        query = query.filter(ObservabilitySpan.kind == kind)
    if kind_in:
        query = query.filter(ObservabilitySpan.kind.in_(kind_in))

    # Status filters (with OR and NOT support)
    if status:
        query = query.filter(ObservabilitySpan.status == status)
    if status_in:
        query = query.filter(ObservabilitySpan.status.in_(status_in))
    if status_not_in:
        query = query.filter(~ObservabilitySpan.status.in_(status_not_in))

    # Time range filters (both bounds apply to the span start time)
    if start_time:
        query = query.filter(ObservabilitySpan.start_time >= start_time)
    if end_time:
        query = query.filter(ObservabilitySpan.start_time <= end_time)

    # Duration filters (explicit None checks so 0.0 is a valid bound)
    if min_duration_ms is not None:
        query = query.filter(ObservabilitySpan.duration_ms >= min_duration_ms)
    if max_duration_ms is not None:
        query = query.filter(ObservabilitySpan.duration_ms <= max_duration_ms)

    # Attribute-based filtering with AND logic.
    # NOTE(review): `.astext` is provided by SQLAlchemy's PostgreSQL JSON comparator; confirm the
    # `attributes` column type supports it on every backend in use (generic JSON exposes `.as_string()`).
    if attribute_filters:
        for key, value in attribute_filters.items():
            query = query.filter(ObservabilitySpan.attributes[key].astext == str(value))

    # Free-text search across the serialized attributes. Values are bound parameters, so this
    # escaping is about LIKE wildcards (not SQL injection); SQLite has no default LIKE escape
    # character, hence the explicit escape argument.
    if attribute_search:
        query = query.filter(cast(ObservabilitySpan.attributes, String).ilike(_contains_pattern(attribute_search), escape="\\"))

    # Apply ordering (order_by validated above)
    query = query.order_by(order_clauses[order_by])

    # Apply pagination
    query = query.limit(limit).offset(offset)

    return query.all()
def get_trace_with_spans(self, db: Session, trace_id: str) -> Optional[ObservabilityTrace]:
    """Fetch one trace by ID with its spans and each span's events eagerly loaded.

    Args:
        db: Database session
        trace_id: Trace ID

    Returns:
        The matching trace with ``spans`` and ``spans[*].events`` populated via joined
        eager loading, or ``None`` when no trace has this ID.

    Examples:
        >>> trace = service.get_trace_with_spans(db, trace_id)  # doctest: +SKIP
        >>> if trace:  # doctest: +SKIP
        ...     for span in trace.spans:  # doctest: +SKIP
        ...         print(f"Span: {span.name}, Events: {len(span.events)}")  # doctest: +SKIP
    """
    # Chain trace -> spans -> events so the whole tree arrives in a single round trip.
    eager_tree = joinedload(ObservabilityTrace.spans).joinedload(ObservabilitySpan.events)
    lookup = db.query(ObservabilityTrace).filter_by(trace_id=trace_id).options(eager_tree)
    return lookup.first()
def delete_old_traces(self, db: Session, before_time: datetime) -> int:
    """Remove every trace whose start time is strictly earlier than ``before_time``.

    The delete is issued as a single bulk ``Query.delete()`` and then committed through
    ``self._safe_commit``.

    NOTE(review): bulk ``Query.delete()`` does not run ORM-level relationship cascades, so
    removal of related spans/events presumably relies on database-level ON DELETE rules —
    confirm against the model definitions.

    Args:
        db: Database session
        before_time: Traces with ``start_time`` before this moment are deleted

    Returns:
        Number of rows reported deleted, or 0 when the commit did not succeed.

    Examples:
        >>> from datetime import timedelta  # doctest: +SKIP
        >>> cutoff = utc_now() - timedelta(days=30)  # doctest: +SKIP
        >>> deleted = service.delete_old_traces(db, cutoff)  # doctest: +SKIP
        >>> print(f"Deleted {deleted} old traces")  # doctest: +SKIP
    """
    stale_traces = db.query(ObservabilityTrace).filter(ObservabilityTrace.start_time < before_time)
    removed_count = stale_traces.delete()

    committed = self._safe_commit(db, "delete_old_traces")
    if not committed:
        # Commit failed: report nothing removed rather than the uncommitted row count.
        return 0

    logger.info(f"Deleted {removed_count} traces older than {before_time}")
    return removed_count